Toolofv 님의 블로그

[Python] 백준 - 1197 최소 스패닝 트리 본문

Algorithm

[Python] 백준 - 1197 최소 스패닝 트리

Toolofv 2024. 7. 8. 22:26

 

문제

그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.

최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.

입력

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.

출력

첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.

 

 

문제해결방법 - 

 

1. 최소 스패닝 트리는 싸이클이 없으며, 가중치가 최소값으로 연결된 트리 구조이다.

2. 유니온 파인드(Union Find)프림(Prim) 알고리즘의 방법으로 구현할 수 있다.

 

1) 유니온 파인드(Union Find)

 

자료구조 - parent리스트(자기자신을 최상위 부모 노드로 갖는 리스트), edge리스트(시작점, 도착점, 간선비용을 저장한 리스트)

FIND 함수는 parent리스트를 이용하여 최상위 부모 노드를 찾아 출력하는 함수이다.

UNION 함수는 인자 (x, y)의 최상위 부모 노드를 연결시켜주는 함수이다.

 

-> edge리스트의 간선비용을 최소값부터 돌 수 있도록 정렬한 후, 시작점과 끝점(최상위 부모 노드)이 연결되지 않았을 때, 간선 비용을 더해준다. (연결이 되어 있다면 싸이클이기 때문에 더해주지 않고 그 전까지의 간선 비용의 합을 출력한다.)

 

2) 프림(Prim)

 

자료구조 - 다익스트라 알고리즘과 비슷하지만, dfs, bfs에서 사용하는 방문리스트와 시작점, 끝점, 간선 비용 리스트만 있으면 된다.

 

Prim 함수는 간선 비용이 최소값(Heap자료구조로 최소의 간선 비용으로만 정렬된다.)인 경우부터 진행되며, 직전 방문한 노드이거나, 싸이클이 완성되었을 시 continue로 인해 다음 코드로 진행되지 않는다.

 

 

- 유니온 파인드(Union Find)

import sys
import heapq
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

n, m = map(int, input().split())
edge, ans = [], 0
parent = [i for i in range(n+1)]
for i in range(m):
	a, b, c = map(int, input().split())
	edge.append((c, a, b))
edge.sort(key = lambda x : x[0])

def FIND(x):
	if parent[x] != x:
		parent[x] = FIND(parent[x])
	return parent[x]

def UNION(x, y):
	x = FIND(x)
	y = FIND(y)
	if x < y:
		parent[y] = x
	else:
		parent[x] = y

for i in range(m):
	c, a, b = edge[i]
	if FIND(a) != FIND(b):
		ans += c
	UNION(a, b)
print(ans)

 

- 프림(Prim)

import sys
import heapq
sys.setrecursionlimit(10**6)
input = sys.stdin.readline

n, m = map(int, input().split())
graph = [[] for _ in range(n+1)]
INF = sys.maxsize
v = [INF for _ in range(n+1)]
for _ in range(m):
	a, b, c = map(int, input().split())
	graph[a].append((b, c))
	graph[b].append((a, c))

def prim(start):
	q = []
	heapq.heappush(q, (0, start))
	ans, cnt = 0, 0
	while q:
		cost, idx = heapq.heappop(q)
		if v[idx] == 1:
			continue
		v[idx] = 1
		ans += cost
		cnt += 1
		for next, next_c in graph[idx]:
			if v[next] == INF:
				heapq.heappush(q, (next_c, next))
	return ans
print(prim(1))
반응형