[Algorithm] - 1. 최단거리 알고리즘

0. 최단 거리 알고리즘에 대한 설명

: 특정한 하나의 정점에서, 다른 정점들으로 가는 최소 거리를 계산하는 방법이다. 같은 방법으로 플로이드- 와샬 방법이 있지만, 이 경우는 모든 정점에서 모든 정점으로 가는 최소 거리를 계산하는 방법이라 조금 차이가 있다.
: 간단히 정리하자면, 정점간 최단경로를 모두 구해야 하는 경우에는 보통 플로이드 와샬이 우세하다.
: 시작점으로부터 나머지 정점까지 최단거리만 구해도 되는 경우면 다익스트라가 낫다.

https://codedoc.tistory.com/95

하나의 정점일 경우는 다익스트라, 전체의 정점의 거리일 경우에는 플로이드와샬을 쓴다고 생각하면 쉽다.

: 코드는 노드의 개수를 N, 간선의 개수를 M으로 입력받는다고 가정한다.

1. 다익스트라

: 유의할 점으로는, 음의 가중치를 지닌 경우 사용할 수 없다는 점이 있다.

1.1 Python

: heapq를 사용해야 시간 안에 통과할 수 있는 경우가 많다.

코드는 나동빈님의 '이것이 취업을 위한 코딩테스트다'를 참고했다.

A. 간선과 노드를 이용한 기본적인 다익스트라 문제

가장 대표적인 문제의 예시는 다음과 같다.
https://www.acmicpc.net/problem/1753

: 아래와 같은 코드를 쓰면 된다.

import heapq
import sys
input = sys.stdin.readline

N = int(input())
M = int(input())
graph = [[]for i in range(N + 1)]
distance = [int(1E9)] * (N + 1)

for i in range(M):
    start, end, time = map(int, input().split())
    graph[start].append((end, time))

def dijkstra(start):
    queue = []
    heapq.heappush(queue, (0, start))
    distance[start] = 0
    while len(queue) > 0:
        dist, now = heapq.heappop(queue)
        if distance[now] < dist:
            continue
        for i in graph[now]:
            cost = dist + i[1]
            if cost < distance[i[0]]:
                went[i[0]] = []
                for j in went[now]:
                    went[i[0]].append(j)
                went[i[0]].append(i[0])
                distance[i[0]] = cost
                heapq.heappush(queue, (cost, i[0]))

dijkstra(start)

: 노드의 방문 순서까지 출력해야 되는 아래와 같은 문제의 경우에는, 다익스트라 갱신 과정에서 배열에 입출력하면 된다.

https://www.acmicpc.net/problem/11779

def dijkstra(start):
    queue = []
    heapq.heappush(queue, (0, start))
    distance[start] = 0
    while len(queue) > 0:
        dist, now = heapq.heappop(queue)
        if distance[now] < dist:
            continue
        for i in graph[now]:
            cost = dist + i[1]
            # 이 밑에 부분을 보면 된다. 
            if cost < distance[i[0]]:
                went[i[0]] = []
                for j in went[now]:
                    went[i[0]].append(j)
                went[i[0]].append(i[0])
                distance[i[0]] = cost
                heapq.heappush(queue, (cost, i[0]))

B. 배열의 최소 이동거리를 구하는 다익스트라 문제

: BFS를 사용해도 되지만, 아래 사진과 같이 heapq를 사용했을 때의 속도가 훨씬 빨랐다.
: 위의 다익스트라와 동일하게, 전체 배열에 INF 값을 배정하고, queue를 이용하여 돌리면서 heapq로 값을 정렬해 배열을 갱신한다.

https://www.acmicpc.net/problem/4485

import heapq
import sys

input = sys.stdin.readline
dx = [0,0,1,-1]
dy = [1,-1,0,0]
def dijkstra(graph, N):
    queue = []
    cost_sum = [[int(1e9)] * N for _ in range(N)]
    heapq.heappush(queue, [graph[0][0], 0,0])
    cost_sum[0][0] = 0
    answer = 0
    while queue:
        cost, x, y = heapq.heappop(queue)
        if x == N-1 and y == N-1:
            answer = cost
            break
        for i in range(4):
            nx = x + dx[i]
            ny = y + dy[i]
            if 0<=nx< N and 0<=ny < N:
                nCost = cost + graph[nx][ny]
                if nCost < cost_sum[nx][ny]:
                    cost_sum[nx][ny] = nCost
                    heapq.heappush(queue, [nCost,nx,ny])
    return answer

cnt = 1
while 1:
    N = int(input())
    if N == 0:
        break
    cost_graph = [[0] * N for _ in range(N)]
    for i in range(N):
        cost_graph[i] = list(map(int,input().split()))
    print("Problem "+ str(cnt) + ": " + str(dijkstra(cost_graph, N)))
    cnt += 1

1.2 Java

: priority queue를 사용해야 빠르다.

A. 간선과 노드를 이용한 기본적인 다익스트라 문제

import java.io.*;
import java.util.*;

class Node implements Comparable<Node>{
    int end, weight;

    public Node(int end, int weight){
        this.end = end;
        this.weight = weight;
    }

    @Override
    public int compareTo(Node o) {
        return weight - o.weight;
    }
}

public class Main {
    private static final BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    private static final BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
    private static final int INF = 100_000_000;
    static int v,e,k;
    static List<Node>[] graph;
    static int[] dist;


    public static void main(String[] args) throws IOException {
        StringTokenizer st = new StringTokenizer(br.readLine());
        v = Integer.parseInt(st.nextToken());
        e = Integer.parseInt(st.nextToken());
        k = Integer.parseInt(br.readLine());
        graph = new ArrayList[v + 1];
        dist = new int[v + 1];

        Arrays.fill(dist, INF);

        for(int i = 1; i <= v; i++){
        	graph[i] = new ArrayList<>();
        }
        // 리스트에 그래프 정보를 초기화
        for(int i = 0 ; i < e; i++){
            st = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(st.nextToken());
            int end = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
            // start에서 end로 가는 weight 가중치
            graph[start].add(new Node(end, weight));
        }

        StringBuilder sb = new StringBuilder();
        // 다익스트라 알고리즘
        dijkstra(k);
        // 출력 부분
        for(int i = 1; i <= v; i++){
            if(dist[i] == INF) sb.append("INF\n");
            else sb.append(dist[i] + "\n");
        }

        bw.write(sb.toString());
        bw.close();
        br.close();
    }

    private static void dijkstra(int start){
       PriorityQueue<Node> queue = new PriorityQueue<>();
       boolean[] check = new boolean[v + 1];
       queue.add(new Node(start, 0));
       dist[start] = 0;

       while(!queue.isEmpty()){
           Node curNode = queue.poll();
           int cur = curNode.end;

           if(check[cur] == true) continue;
           check[cur] = true;

           for(Node node : graph	[cur]){
               if(dist[node.end] > dist[cur] + node.weight){
                   dist[node.end] = dist[cur] + node.weight;
                   queue.add(new Node(node.end, dist[node.end]));
               }
           }
       }
    }
}

B. 배열의 최소 이동거리를 구하는 다익스트라 문제

: 파이썬과 코드의 흐름은 같다. 따로 포인트 구조체를 만들어서 sum_cnt 배열에 그 배열까지 갈수 있는 최소값을 갱신하여 N-1에 닿을때까지 반복문을 돌린다.

import java.io.*;
import java.io.InputStreamReader;
import java.util.*;

public class Main {
	static class point implements Comparable<point>{
		int x, y, cost;
		public point(int x, int y, int cost) {
			super();
			this.x = x;
			this.y = y;
			this.cost = cost;
		}
		@Override
		public int compareTo(point O) {
			return this.cost- O.cost;
		}
	}
	public static int dijkstra() {
		PriorityQueue<point> pq = new PriorityQueue<point>();
		sum_cnt[0][0] = graph[0][0];
		pq.offer(new point(0,0,graph[0][0]));
		while(!pq.isEmpty()) {
			point temp = pq.poll();
			for(int i = 0; i <4; i ++) {
				int nx = temp.x + dx[i];
				int ny = temp.y + dy[i];
				if(0<= nx && nx <N && 0<= ny && ny<N) { 
					if (sum_cnt[nx][ny] > sum_cnt[temp.x][temp.y] + graph[nx][ny]) {
						sum_cnt[nx][ny] = sum_cnt[temp.x][temp.y] + graph[nx][ny];
						pq.offer(new point(nx,ny,graph[nx][ny]));
					}
				}
				
			}
		}
		return sum_cnt[N-1][N-1];
	}
	static int N;
	static int[][] graph;
	static int [][] sum_cnt;
	static int dx[] = {-1,1,0,0};
	static int dy[] = {0,0,1,-1};
	
	
    public static void main(String[] args) throws Exception{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));  
        StringTokenizer st = null; 
        StringBuilder sb = new StringBuilder();
        int cnt = 1;
        while(true) {
        	N = Integer.parseInt(br.readLine());
        	if(N==0) {
        		break;
        	}
        	graph = new int [N+1][N+1];
        	sum_cnt = new int[N+1][N+1];
        	for(int i = 0; i <N; i ++) {
        		st = new StringTokenizer(br.readLine());
        		for(int j = 0; j <N; j ++) {
        			graph[i][j] = Integer.parseInt(st.nextToken());
        			sum_cnt[i][j] = Integer.MAX_VALUE;
        		}
        	}
        	sb.append("Problem " + cnt + ": " + dijkstra() + "\n"); 
        	cnt +=1;
        }
        System.out.println(sb); 
        br.close();
    }
}

2. 플로이드-와샬

: 3중 반복문을 돌리면 된다. 대표적인 문제는 아래 문제이다.
: 아래 문제는 입력 받을때를 제외하고 가장 기본적인 풀이이다.

https://www.acmicpc.net/problem/11404

2.1 Python

import sys
input = sys.stdin.readline
N = int(input())
M = int(input())
INF = int(1e9)
graph = [[INF] * (N) for i in range(N)]

while(M):
    fromv, tov, time = map(int, input().split())
    fromv -=1
    tov-=1
    graph[fromv][tov] = min(graph[fromv][tov],time)
    M-=1
    
for k in range(N):
    for i in range(N):
        for j in range(N):
            graph[i][j] = min(graph[i][j], graph[i][k] + graph[k][j])
            if i ==j:
                graph[i][j] =0

for i in range(N):
    for j in range(N):
        if graph[i][j] == INF:
            print(0, end=' ')
        else:
            print(graph[i][j],end =' ')
    print()

2.2 Java

import java.io.*;
import java.util.*;

public class Main {
	static final int INF =  987654321;
	
    public static void main(String[] args) throws NumberFormatException,IOException {
    	BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st;
        int N = Integer.parseInt(br.readLine());
        int M = Integer.parseInt(br.readLine());
        int[][] arr = new int[N + 1][N + 1];
        
        for(int i = 0; i <N; i ++) {
        	for(int j = 0; j <N; j ++) {
        		arr[i][j] = INF;
        		if(i==j) {
        			arr[i][j] = 0;
        		}
        	}
        }
        for(int i = 0; i <M; i ++) {
        	st = new StringTokenizer(br.readLine());
        	int a = Integer.parseInt(st.nextToken());
        	int b = Integer.parseInt(st.nextToken());
        	int c = Integer.parseInt(st.nextToken());
        	arr[a-1][b-1] = Math.min(arr[a-1][b-1], c);
        }
        for(int k = 0; k <N; k++) {
        	for(int i = 0; i <N; i++) {
        		for(int j = 0; j <N; j++) {
        			if(arr[i][j] > arr[i][k] + arr[k][j]) {
        				arr[i][j] = arr[i][k] + arr[k][j];
        			}
        		}
        	}
        }
        StringBuilder sb = new StringBuilder();
        for(int i = 0; i <N; i++) {
    		for(int j = 0; j <N; j++) {
    			if(arr[i][j] == INF) {
    				arr[i][j] = 0;
    			}
    			sb.append(arr[i][j] + " ");
    		}
    		sb.append("\n");
    	}
        bw.write(sb.toString());
        bw.flush();
        bw.close();
        br.close();
    }

}

좋은 웹페이지 즐겨찾기