[백준] 2887 행성 터널

문제 링크

2887번: 행성 터널

알고리즘

크루스칼(kruskal) 알고리즘, 최소 스패닝 트리(MST)

풀이

10만 개의 정점을 최소 비용으로 연결하는 문제이다.

연결 비용이 min(xAxB,yAyB,zAzB)min(|x_A-x_B|, |y_A-y_B|, |z_A-z_B|)이므로 N*(N-1)개의 모든 연결 비용을 구할 필요가 없다. 따라서 각 축을 기준으로 정렬한 뒤, 연속된 정점을 이어주면 3(N-1)개의 간선이 생긴다. 이를 비용 기준으로 오름차 정렬하여 최소 스패닝 트리를 만들어주면 문제를 해결할 수 있다.

  • 시간 복잡도는 정렬 알고리즘에 의하여 3Nlog3N3Nlog3N이다.
  • 최소 비용의 최댓값은 20억이므로 정수형으로 충분하다.
  • 완성된 최소 스패닝 트리는 N-1개의 간선을 갖는다.

코드

  • C++
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

struct Point {
    int x, y, z, idx;
};

struct Edge {
    int dist, u, v;
    bool operator<(const Edge& e) { return dist < e.dist; }
};

vector<Point> v;
vector<Edge> edges;
vector<int> parent;
int n;

void createEdges() {
    sort(v.begin(), v.end(),
         [](const Point& a, const Point& b) { return a.x < b.x; });
    for (int i = 0; i < v.size() - 1; i++) {
        edges.push_back({abs(v[i].x - v[i + 1].x), v[i].idx, v[i + 1].idx});
    }

    sort(v.begin(), v.end(),
         [](const Point& a, const Point& b) { return a.y < b.y; });
    for (int i = 0; i < v.size() - 1; i++) {
        edges.push_back({abs(v[i].y - v[i + 1].y), v[i].idx, v[i + 1].idx});
    }

    sort(v.begin(), v.end(),
         [](const Point& a, const Point& b) { return a.z < b.z; });
    for (int i = 0; i < v.size() - 1; i++) {
        edges.push_back({abs(v[i].z - v[i + 1].z), v[i].idx, v[i + 1].idx});
    }
}

int find(int u) {
    if (u == parent[u]) return u;
    return parent[u] = find(parent[u]);
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    cin >> n;
    parent.resize(n);

    for (int i = 0, x, y, z; i < n; i++) {
        cin >> x >> y >> z;
        v.push_back({x, y, z, i});
        parent[i] = i;
    }

    createEdges();

    sort(edges.begin(), edges.end());

    int cnt = 0, ans = 0;
    for (auto& e : edges) {
        int u = find(e.u);
        int v = find(e.v);
        if (u == v) continue;
        parent[u] = v;
        ans += e.dist;
        if (++cnt == n - 1) break;
    }

    cout << ans << endl;

    return 0;
}
  • Java
import java.io.*;
import java.util.*;

class Point {

    int x, y, z, idx;

    public Point(int x, int y, int z, int idx) {
        this.x = x;
        this.y = y;
        this.z = z;
        this.idx = idx;
    }
}

class Edge {

    int u, v, dist;

    public Edge(int u, int v, int dist) {
        this.u = u;
        this.v = v;
        this.dist = dist;
    }
}

public class Main {

    static ArrayList<Point> points = new ArrayList<>();
    static ArrayList<Edge> edges = new ArrayList<>();
    static int[] parent;
    static int n;

    static void input() throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        parent = new int[n];
        StringTokenizer st;

        for (int i = 0, x, y, z; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            x = Integer.parseInt(st.nextToken());
            y = Integer.parseInt(st.nextToken());
            z = Integer.parseInt(st.nextToken());
            points.add(new Point(x, y, z, i));
            parent[i] = i;
        }
    }

    static void createEdges() {
        points.sort(Comparator.comparingInt(o -> o.x));
        for (int i = 0; i < n - 1; i++) {
            int dist = Math.abs(points.get(i).x - points.get(i + 1).x);
            edges.add(new Edge(points.get(i).idx, points.get(i + 1).idx, dist));
        }

        points.sort(Comparator.comparingInt(o -> o.y));
        for (int i = 0; i < n - 1; i++) {
            int dist = Math.abs(points.get(i).y - points.get(i + 1).y);
            edges.add(new Edge(points.get(i).idx, points.get(i + 1).idx, dist));
        }

        points.sort(Comparator.comparingInt(o -> o.z));
        for (int i = 0; i < n - 1; i++) {
            int dist = Math.abs(points.get(i).z - points.get(i + 1).z);
            edges.add(new Edge(points.get(i).idx, points.get(i + 1).idx, dist));
        }
    }

    static int find(int u) {
        if (u == parent[u]) {
            return u;
        }
        return parent[u] = find(parent[u]);
    }

    static void solve() {
        createEdges();
        edges.sort(Comparator.comparingInt(o -> o.dist));
        int cnt = 0, ans = 0;
        for (Edge e : edges) {
            int u = find(e.u);
            int v = find(e.v);
            if (u == v) {
                continue;
            }
            parent[u] = v;
            ans += e.dist;
            if (++cnt == n - 1) {
                break;
            }
        }
        System.out.println(ans);
    }

    public static void main(String[] args) throws IOException {
        input();
        solve();
    }
}

Written by@WOOJIN
自强不息,厚德载物

GitHub