BAEKJOON #1922 네트워크 연결 (Graph, 최소 스패닝 트리) - python

네트워크 연결

출처 : 백준 #1922

시간 제한메모리 제한
2초256MB

문제

도현이는 컴퓨터와 컴퓨터를 모두 연결하는 네트워크를 구축하려 한다. 하지만 아쉽게도 허브가 있지 않아 컴퓨터와 컴퓨터를 직접 연결하여야 한다. 그런데 모두가 자료를 공유하기 위해서는 모든 컴퓨터가 연결이 되어 있어야 한다. (a와 b가 연결이 되어 있다는 말은 a에서 b로의 경로가 존재한다는 것을 의미한다. a에서 b를 연결하는 선이 있고, b와 c를 연결하는 선이 있으면 a와 c는 연결이 되어 있다.)

그런데 이왕이면 컴퓨터를 연결하는 비용을 최소로 하여야 컴퓨터를 연결하는 비용 외에 다른 곳에 돈을 더 쓸 수 있을 것이다. 이제 각 컴퓨터를 연결하는데 필요한 비용이 주어졌을 때 모든 컴퓨터를 연결하는데 필요한 최소비용을 출력하라. 모든 컴퓨터를 연결할 수 없는 경우는 없다.


입력

첫째 줄에 컴퓨터의 수 N (1 ≤ N ≤ 1000)가 주어진다.

둘째 줄에는 연결할 수 있는 선의 수 M (1 ≤ M ≤ 100,000)가 주어진다.

셋째 줄부터 M+2번째 줄까지 총 M개의 줄에 각 컴퓨터를 연결하는데 드는 비용이 주어진다. 이 비용의 정보는 세 개의 정수로 주어지는데, 만약에 a b c 가 주어져 있다고 하면 a컴퓨터와 b컴퓨터를 연결하는데 비용이 c (1 ≤ c ≤ 10,000) 만큼 든다는 것을 의미한다. a와 b는 같을 수도 있다.


출력

모든 컴퓨터를 연결하는데 필요한 최소비용을 첫째 줄에 출력한다.


입출력 예시

예제 입력 1

6
9
1 2 5
1 3 4
2 3 2
2 4 7
3 4 6
3 5 11
4 5 3
4 6 8
5 6 8

예제 출력 1

23


풀이

생각

  • 모든 컴퓨터들이 각각 갈 수 있는 경로가 존재해야 한다.
  • 경로 유지비가 최소가 되어야 한다.
  • 따라서 최소 스패닝 트리라고 생각할 수 있다.
  • 최소 스패닝 트리를 푸는 방법에는 크루스칼 알고리즘, 프림 알고리즘이 있다.
  • 크루스칼 알고리즘은 일종의 Greedy 알고리즘이다. 최소 비용대로 고려하여 최적해를 찾기 때문이다.

풀이 설명

  • 각 컴퓨터를 연결하여 주는 union_parent과 각 컴퓨터가 같은 집합 안에 속해있는지를 알려주는 find_parent 함수를 이용하였다.
  • 핵심은 비용과 연결 노드들을 받는 arr 배열의 정렬이다.
    • 정렬이 되어야 최솟값부터 찾아내므로 최소 비용으로 스패닝 트리를 완성할 수 있기 때문이다.

python code (Kruskal Algorithm)

# 백준 1922번 네트워크 연결
from sys import stdin

def find_parent(parent, x):
    if parent[x] != x:
        parent[x] = find_parent(parent, parent[x])
    return parent[x]

def union_parent(parent, a, b):
    a = find_parent(parent, a)
    b = find_parent(parent, b)
    if a < b:
        parent[b] = a
    else:
        parent[a] = b


input = stdin.readline

n = int(input())
m = int(input())
parent = [0] * (n+1)
for i in range(1, n+1):
    parent[i] = i

arr = []
result = 0
for i in range(m):
    a, b, c = map(int, input().split())
    arr.append((c, a, b))
arr.sort() # sort를 꼭 해줘야 한다 - 그래야 최소 값부터 find & union을 하니까
for x in arr:
    c, a, b = x
    if find_parent(parent, a) != find_parent(parent, b):
        union_parent(parent, a, b)
        result += c
print(result)

python code (Prim Algorithm)

# 백준 1922번 네트워크 연결 (Prim)
from sys import stdin
import heapq

input = stdin.readline
n = int(input())
m = int(input())

graph = [[] for _ in range(n+1)]
for i in range(m):
    a, b, c = map(int, input().split())
    graph[a].append((c, b))
    graph[b].append((c, a)) # 양방향 연결이므로 양쪽 노드 다 해줘야 함!

def prim(graph, start):
    edge_cnt = -1
    visited = [False] * (n+1)
    queue = []
    heapq.heappush(queue, (0, start))
    total_weight = 0

    while edge_cnt < n-1:
        if not queue:
            return False
        weight, node = heapq.heappop(queue)
        if visited[node] != True:
            edge_cnt += 1
            visited[node] = True
            total_weight += weight
            for x in graph[node]:
                heapq.heappush(queue, (x[0], x[1]))
            # print(weight, queue)
    return total_weight

print(prim(graph, 1))       
# min_weight = int(1e9)     # greedy 알고리즘... 모든 노드에서 동일하게 값(최적해)이 나온다.
# for i in range(1, n+1):
#     w = prim(graph, i)
#     print(w)
#     if min_weight > w:
#         min_weight = w
# print(min_weight)

좋은 웹페이지 즐겨찾기