[백준] 1613번 역사

  • 출처 : https://www.acmicpc.net/problem/1613

  • 문제 : 백준 1613번 역사

  • 틀린 풀이 : 위상 정렬, DFS

    1. 각 노드를 정렬 (위상 정렬)
    2. 연결되지 않은 노드 분리 (DFS)
    3. 연결된 노드일 경우 정렬된 리스트에서 인덱스 비교하여 1 또는 -1 출력
    4. 연결되지 않았을 경우 0 출력
  • 틀린 이유 : 연결되어 있지만 관계를 모르는 경우가 존재

  • 맞은 풀이 : Floyd-Warshall, DP

    1. 2차원 거리 리스트 초기화
    2. 리스트[k][i][j] : 1부터 k까지의 노드만을 지나 i에서 j까지 도달하는 거리
    3. 모든 노드 간 거리 계산(N^2)을 N번 반복(3중 for문)
      • 계산법 : 리스트[i][j] = min(리스트[i][j], 리스트[i][k]+리스트[k][j])
    4. 평가하려는 출발지로부터 도착지까지 도달할 수 없으면 거리 무한대
  • 정리

    • 굳이 거리를 계산할 필요 없을 것 같아 bool 타입으로 바꿔보았다

    • Floyd-Warshall 알고리즘

      • O(N^3)
      • 3차원 배열을 사용하지 않아도 되는 이유는 이해하기 조금 어렵다
      • all-to-all 최단 경로 알고리즘
        • one-to-one 또는 one-to-all : 벨만-포드, 다익스트라 알고리즘
    • 다익스트라 알고리즘(N^2)을 모든 노드에 적용한 것(N)과 시간 복잡도 면에서 유사

틀린 코드

import sys, bisect


def init():
    n, k = map(int, sys.stdin.readline().rstrip().split(' '))
    order_list = []
    directed_list = [[] for _ in range(n+1)]
    undirected_list = [[] for _ in range(n+1)]
    for i in range(k):
        start, end = map(int, sys.stdin.readline().rstrip().split(' '))
        order_list.append((start, end))
        directed_list[start].append(end)
        undirected_list[start].append(end)
        undirected_list[end].append(start)
    s = int(sys.stdin.readline().rstrip())
    pair_list = []
    for i in range(s):
        start, end = map(int, sys.stdin.readline().rstrip().split(' '))
        pair_list.append((start, end))
    part_list = [0 for _ in range(n+1)]
    index_list = [-1 for _ in range(n+1)]
    sorted_list = []
    return n, k, order_list, s, pair_list, part_list, undirected_list, directed_list, sorted_list, index_list


def dfs_ts(visited, curr_num):
    visited[curr_num] = True
    for next_num in directed_list[curr_num]:
        if not visited[next_num]:
            dfs_ts(visited, next_num)
    sorted_list.append(curr_num)
    index_list[curr_num] = len(sorted_list) - 1
    

def set_sorted_list():
    visited = [False for i in range(n+1)]
    for i in range(1, n+1):
        if not visited[i]:
            dfs_ts(visited, i)


def dfs(curr_num, part):
    part_list[curr_num] = part
    for next_num in undirected_list[curr_num]:
        if part_list[next_num] == 0:
            dfs(next_num, part)


def set_part_list():
    part = 1
    for i in range(1, n+1):
        if part_list[i] == 0:
            dfs(i, part)
            part += 1


def check(pair):
    start, end = pair
    if part_list[start] == part_list[end]:
        if index_list[start] < index_list[end]:
            return 1
        else:
            return -1
    else:
        return 0


n, k, order_list, s, pair_list, part_list, undirected_list, directed_list, sorted_list, index_list = init()
set_sorted_list()
set_part_list()
for pair in pair_list:
    sys.stdout.write(f'{check(pair)}\n')

맞은 코드

import sys


def init():
    n, k = map(int, sys.stdin.readline().rstrip().split(' '))
    order_list = [tuple(map(int, sys.stdin.readline().rstrip().split(' '))) for _ in range(k)]
    s = int(sys.stdin.readline().rstrip())
    pair_list = [tuple(map(int, sys.stdin.readline().rstrip().split(' '))) for _ in range(s)]
    # dist_list = [[float('inf') for _ in range(n+1)] for _ in range(n+1)]
    dist_list = [[False for _ in range(n+1)] for _ in range(n+1)]
    return n, k, order_list, s, pair_list, dist_list


n, k, order_list, s, pair_list, dist_list = init()
for start, end in order_list:
    # dist_list[start][end] = 1
    dist_list[start][end] = True
for k in range(1, n+1):
    for i in range(1, n+1):
        for j in range(1, n+1):
            # dist_list[i][j] = min(dist_list[i][j], dist_list[i][k]+dist_list[k][j])
            dist_list[i][j] = dist_list[i][j] or (dist_list[i][k] and dist_list[k][j])
for start, end in pair_list:
    # if dist_list[start][end] == float('inf') and dist_list[end][start] == float('inf'):
    #     sys.stdout.write('0\n')
    # elif dist_list[start][end] == float('inf'):
    #     sys.stdout.write('1\n')
    # elif dist_list[end][start] == float('inf'):
    #     sys.stdout.write('-1\n')
    if not dist_list[start][end] and not dist_list[end][start]:
        sys.stdout.write('0\n')
    elif not dist_list[start][end]:
        sys.stdout.write('1\n')
    elif not dist_list[end][start]:
        sys.stdout.write('-1\n')

좋은 웹페이지 즐겨찾기