January 02, 2021
크루스칼(kruskal) 알고리즘, 최소 스패닝 트리(MST)
10만 개의 정점을 최소 비용으로 연결하는 문제이다.
두 정점 A, B의 연결 비용이 $\min(|x_A - x_B|,\ |y_A - y_B|,\ |z_A - z_B|)$이므로 N(N-1)/2개의 모든 연결 비용을 구할 필요가 없다. 따라서 각 축을 기준으로 정렬한 뒤, 정렬 순서상 연속된 정점끼리만 이어주면 3(N-1)개의 간선이 생긴다. 이를 비용 기준으로 오름차순으로 정렬하여 최소 스패닝 트리를 만들어주면 문제를 해결할 수 있다.
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
// A point with three integer coordinates plus an integer tag.
// Member order matters: the struct is brace-initialised as {x, y, z, idx}.
struct Point {
  int x;
  int y;
  int z;
  int idx;  // identifier that stays attached to the point when it is reordered
};
// An edge of weight `dist` joining the vertices identified by `u` and `v`.
// Member order matters: the struct is brace-initialised as {dist, u, v}.
struct Edge {
  int dist, u, v;

  // Orders edges by weight only; the endpoints do not take part in the
  // comparison. The operator is const-qualified so that it also works on
  // const objects and const references, which is what comparators and
  // standard algorithms are allowed to hand to it.
  bool operator<(const Edge& e) const { return dist < e.dist; }
};
// Input points; createEdges() re-sorts this vector in place.
vector<Point> v;
// Candidate edges appended by createEdges() and later sorted in main().
vector<Edge> edges;
// Disjoint-set parent links; parent[i] == i marks the representative of a set.
vector<int> parent;
// Number of points, read from standard input in main().
int n;
void createEdges() {
sort(v.begin(), v.end(),
[](const Point& a, const Point& b) { return a.x < b.x; });
for (int i = 0; i < v.size() - 1; i++) {
edges.push_back({abs(v[i].x - v[i + 1].x), v[i].idx, v[i + 1].idx});
}
sort(v.begin(), v.end(),
[](const Point& a, const Point& b) { return a.y < b.y; });
for (int i = 0; i < v.size() - 1; i++) {
edges.push_back({abs(v[i].y - v[i + 1].y), v[i].idx, v[i + 1].idx});
}
sort(v.begin(), v.end(),
[](const Point& a, const Point& b) { return a.z < b.z; });
for (int i = 0; i < v.size() - 1; i++) {
edges.push_back({abs(v[i].z - v[i + 1].z), v[i].idx, v[i + 1].idx});
}
}
// Returns the representative of the set containing `u` in the global
// disjoint-set array `parent`, and re-points every node visited on the way
// directly at that representative (path compression).
//
// Implemented iteratively: sets are merged without any rank/size heuristic,
// so a parent chain can grow long and a recursive walk could exhaust the
// call stack on such a chain.
int find(int u) {
  // Pass 1: walk up to the representative.
  int root = u;
  while (root != parent[root]) root = parent[root];

  // Pass 2: compress the path just walked.
  while (u != root) {
    const int next = parent[u];
    parent[u] = root;
    u = next;
  }
  return root;
}
// Reads the number of points and their coordinates from standard input,
// builds the candidate edges, and prints the total weight of the spanning
// tree selected by Kruskal's algorithm.
// Returns 1 without printing when the point count is missing or negative.
int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  // Reject a missing or negative count before it reaches resize().
  if (!(cin >> n) || n < 0) return 1;

  v.reserve(n);
  parent.resize(n);
  for (int i = 0, x, y, z; i < n; i++) {
    cin >> x >> y >> z;
    v.push_back({x, y, z, i});
    parent[i] = i;  // every point starts in its own set
  }

  // With fewer than two points there is nothing to connect.
  if (n >= 2) createEdges();
  sort(edges.begin(), edges.end());

  // The total is kept in 64 bits so that summing up to n - 1 int weights
  // cannot overflow.
  long long ans = 0;
  int cnt = 0;
  for (const auto& e : edges) {
    const int rootU = find(e.u);
    const int rootV = find(e.v);
    if (rootU == rootV) continue;  // endpoints already connected
    parent[rootU] = rootV;
    ans += e.dist;
    if (++cnt == n - 1) break;  // spanning tree is complete
  }

  cout << ans << '\n';
  return 0;
}
import java.io.*;
import java.util.*;
/** A point with three integer coordinates and a caller-supplied integer tag. */
class Point {
    int x;
    int y;
    int z;
    /** Identifier that stays attached to the point when a list of points is reordered. */
    int idx;

    /**
     * @param x   x coordinate
     * @param y   y coordinate
     * @param z   z coordinate
     * @param idx identifier stored alongside the coordinates
     */
    public Point(int x, int y, int z, int idx) {
        this.idx = idx;
        this.z = z;
        this.y = y;
        this.x = x;
    }
}
/** An edge of weight {@code dist} joining the vertices identified by {@code u} and {@code v}. */
class Edge {
    int u;
    int v;
    int dist;

    /**
     * @param u    identifier of the first endpoint
     * @param v    identifier of the second endpoint
     * @param dist weight of the edge
     */
    public Edge(int u, int v, int dist) {
        this.dist = dist;
        this.v = v;
        this.u = u;
    }
}
/**
 * Reads points from standard input and prints the total weight of the spanning
 * tree selected by Kruskal's algorithm over the per-axis neighbour edges.
 */
public class Main {
    /** Input points; {@link #createEdges()} re-sorts this list in place. */
    static ArrayList<Point> points = new ArrayList<>();
    /** Candidate edges filled by {@link #createEdges()}. */
    static ArrayList<Edge> edges = new ArrayList<>();
    /** Disjoint-set parent links; {@code parent[i] == i} marks a set representative. */
    static int[] parent;
    /** Number of points, read by {@link #input()}. */
    static int n;

    /**
     * Reads the point count and then one "x y z" line per point from standard
     * input, storing each point with its input position as identifier and
     * placing every point in its own disjoint set.
     *
     * @throws IOException if reading standard input fails
     */
    static void input() throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        parent = new int[n];
        for (int i = 0; i < n; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int x = Integer.parseInt(st.nextToken());
            int y = Integer.parseInt(st.nextToken());
            int z = Integer.parseInt(st.nextToken());
            points.add(new Point(x, y, z, i));
            parent[i] = i;
        }
    }

    /**
     * Builds the candidate edge set: for each of the x, y and z axes the points
     * are sorted along that axis and neighbouring points are linked.
     * Side effects: reorders {@code points} and appends to {@code edges}.
     */
    static void createEdges() {
        addAxisEdges(p -> p.x);
        addAxisEdges(p -> p.y);
        addAxisEdges(p -> p.z);
    }

    /**
     * Sorts {@code points} by the coordinate that {@code axis} extracts and adds
     * one edge per pair of points adjacent in that order, weighted by the
     * absolute difference of that coordinate.
     */
    private static void addAxisEdges(java.util.function.ToIntFunction<Point> axis) {
        points.sort(Comparator.comparingInt(axis));
        for (int i = 0; i + 1 < points.size(); i++) {
            Point a = points.get(i);
            Point b = points.get(i + 1);
            int dist = Math.abs(axis.applyAsInt(a) - axis.applyAsInt(b));
            edges.add(new Edge(a.idx, b.idx, dist));
        }
    }

    /**
     * Returns the representative of the set containing {@code u} and re-points
     * every node visited on the way directly at it (path compression).
     * Written iteratively: sets are merged without a rank/size heuristic, so a
     * parent chain can grow long and recursion could overflow the call stack.
     */
    static int find(int u) {
        int root = u;
        while (root != parent[root]) {
            root = parent[root];
        }
        while (u != root) {
            int next = parent[u];
            parent[u] = root;
            u = next;
        }
        return root;
    }

    /**
     * Runs Kruskal's algorithm over the candidate edges and prints the total
     * weight of the chosen edges to standard output.
     */
    static void solve() {
        createEdges();
        edges.sort(Comparator.comparingInt(o -> o.dist));
        // 64-bit total so that summing up to n - 1 int weights cannot overflow.
        long ans = 0;
        int cnt = 0;
        for (Edge e : edges) {
            int rootU = find(e.u);
            int rootV = find(e.v);
            if (rootU == rootV) {
                continue; // endpoints already connected
            }
            parent[rootU] = rootV;
            ans += e.dist;
            if (++cnt == n - 1) {
                break; // spanning tree is complete
            }
        }
        System.out.println(ans);
    }

    public static void main(String[] args) throws IOException {
        input();
        solve();
    }
}