[알고리즘] 백준 - 도로포장

백준 1162번 - 도로포장

실패한 내 풀이

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Brute-force attempt at BOJ 1162 (road paving): try every way of paving
 * min(K, M) roads and run Dijkstra from city 1 to city N for each choice.
 *
 * <p>This is correct but exponential: there are C(M, min(K, M)) choices, so it
 * only finishes for small inputs.
 */
public class baekjoon_1162 {

    static int N, M, K;
    /** graph[u] holds {cost, v, roadIndex} for every road incident to city u. */
    static List<int[]>[] graph;
    /** roads.get(i) is {start, end, cost} of the i-th input road. */
    static List<int[]> roads;
    /** backTrackingVisited[i] is true while road i is paved in the current choice. */
    static boolean[] backTrackingVisited;
    /** Best travel time found so far; long because a path sum can exceed Integer.MAX_VALUE. */
    static long min_ans = Long.MAX_VALUE;

    @SuppressWarnings("unchecked") // generic array creation for the adjacency list
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] inputs = br.readLine().trim().split("\\s+");
        N = Integer.parseInt(inputs[0]); // number of cities
        M = Integer.parseInt(inputs[1]); // number of roads
        K = Integer.parseInt(inputs[2]); // maximum number of roads to pave

        graph = new ArrayList[N + 1];
        for (int i = 0; i <= N; i++) {
            graph[i] = new ArrayList<>();
        }
        roads = new ArrayList<>(M);
        // The backtracking chooses roads, so the selection flags are indexed by road (size M).
        backTrackingVisited = new boolean[M];
        min_ans = Long.MAX_VALUE; // reset so main can be run more than once

        for (int i = 0; i < M; i++) {
            inputs = br.readLine().trim().split("\\s+");
            int start = Integer.parseInt(inputs[0]);
            int end = Integer.parseInt(inputs[1]);
            int cost = Integer.parseInt(inputs[2]); // travel time of this road
            // Both directions carry the road index, so paving road i affects both
            // directions and parallel roads between the same cities stay distinct.
            graph[start].add(new int[]{cost, end, i});
            graph[end].add(new int[]{cost, start, i});
            roads.add(new int[]{start, end, cost});
        }
        solve();
        System.out.println(min_ans);
    }

    /**
     * Tries every set of exactly min(K, M) paved roads. Paving more roads never
     * makes a path longer (assuming non-negative road costs), so this covers
     * "at most K", including the case K > M.
     */
    private static void solve() {
        backTracking(0, 0, Math.min(K, M));
    }

    /**
     * Chooses the remaining {@code maxCount - curCount} roads from indices
     * {@code curPos..M-1} and evaluates each complete choice with Dijkstra.
     * Recursion depth is at most {@code maxCount}, not M.
     *
     * @param curPos   first road index that may still be chosen
     * @param curCount number of roads chosen so far
     * @param maxCount number of roads to choose in total
     */
    private static void backTracking(int curPos, int curCount, int maxCount) {
        if (curCount == maxCount) {
            min_ans = Math.min(min_ans, dijkstra());
            return;
        }
        // Stop early enough that the remaining picks still fit in the remaining roads.
        for (int i = curPos; i <= M - (maxCount - curCount); i++) {
            backTrackingVisited[i] = true;
            backTracking(i + 1, curCount + 1, maxCount);
            backTrackingVisited[i] = false;
        }
    }

    /**
     * Shortest travel time from city 1 to city N, treating every road flagged in
     * {@link #backTrackingVisited} as cost 0.
     *
     * @return the shortest time, or Long.MAX_VALUE if N is unreachable
     */
    private static long dijkstra() {
        long[] distances = new long[N + 1];
        Arrays.fill(distances, Long.MAX_VALUE);
        distances[1] = 0;
        PriorityQueue<long[]> pq = new PriorityQueue<>(Comparator.comparingLong((long[] entry) -> entry[0]));
        pq.add(new long[]{0, 1}); // {cost, node}
        while (!pq.isEmpty()) {
            long[] polled = pq.poll();
            long cost = polled[0];
            int node = (int) polled[1];

            if (cost > distances[node]) {
                continue; // stale entry: node was already settled with a smaller cost
            }
            if (node == N) {
                return cost;
            }

            for (int[] edge : graph[node]) {
                int adjNode = edge[1];
                long adjCost = backTrackingVisited[edge[2]] ? 0 : edge[0];
                long nextCost = cost + adjCost;
                if (nextCost < distances[adjNode]) {
                    distances[adjNode] = nextCost;
                    pq.add(new long[]{nextCost, adjNode});
                }
            }
        }

        return distances[N];
    }
}

처음에는 K가 그리 크지 않아서, 백트래킹(완전탐색)으로 포장할 도로 K개를 모두 골라 보고 각 경우마다 다익스트라를 돌리려고 했다. 하지만 메모리 초과가 났다. 애초에 도로 M개 중 K개를 고르는 경우의 수가 C(M, K)개나 되기 때문에, 메모리 문제를 해결하더라도 시간 안에 끝날 수 없는 방법이었다.

다른 사람 풀이

import java.util.*;

/**
 * BOJ 1162 (road paving) solved with Dijkstra over (city, roads paved) states.
 *
 * <p>{@code distance[v][c]} is the shortest known time to reach city {@code v}
 * having paved exactly {@code c} roads. From each state, every incident road can
 * either be driven normally (same {@code c}, its time added) or paved (only if
 * {@code c < K}; moves to {@code c + 1} with no time added). The answer is the
 * minimum of {@code distance[N][0..K]}.
 */
public class Main {
    /** edges[u] holds {v, time} for every road incident to city u; allocated by main. */
    public static List<int[]>[] edges;
    /** Maximum number of roads that may be paved. */
    public static int K;

    /**
     * Reads "N M K" followed by M lines "a b time" from standard input and
     * prints the minimum travel time from city 1 to city N.
     */
    @SuppressWarnings("unchecked") // generic array creation for the adjacency list
    public static void main(String[] args) {
        Scanner sca = new Scanner(System.in);
        N = sca.nextInt();
        int M = sca.nextInt();
        K = sca.nextInt();
        // Sized from the input rather than a fixed 10001, so any N fits.
        edges = new ArrayList[N + 1];
        for (int i = 0; i <= N; i++) {
            edges[i] = new ArrayList<>();
        }
        for (int i = 0; i < M; i++) {
            int a = sca.nextInt();
            int b = sca.nextInt();
            int dis = sca.nextInt();
            edges[a].add(new int[]{b, dis});
            edges[b].add(new int[]{a, dis});
        }
        dijkstra();
        System.out.println(curMin);
    }

    /** Minimum travel time from city 1 to city N; written by dijkstra. */
    public static long curMin = Long.MAX_VALUE;
    /** Number of cities; the search runs from city 1 to city N. */
    public static int N;
    /** distance[v][c]: shortest time to city v having paved exactly c roads. */
    public static long[][] distance;

    /**
     * Runs Dijkstra from city 1 over (city, paved count) states using
     * {@link #edges}, {@link #N} and {@link #K}, and stores the answer in
     * {@link #curMin}.
     */
    public static void dijkstra() {
        // Sized from N and K rather than a fixed 10001 x 21, so K > 20 also works.
        distance = new long[N + 1][K + 1];
        for (long[] row : distance) {
            Arrays.fill(row, Long.MAX_VALUE);
        }
        curMin = Long.MAX_VALUE; // never carry an answer over from a previous run

        // Queue entries are {node, dist, cut}, ordered by dist.
        PriorityQueue<long[]> pq = new PriorityQueue<>(Comparator.comparingLong((long[] entry) -> entry[1]));
        pq.add(new long[]{1, 0, 0});
        distance[1][0] = 0;
        while (!pq.isEmpty()) {
            long[] list = pq.poll();
            int curNode = (int) list[0];
            long curDist = list[1];
            int cutCount = (int) list[2];

            if (curDist > distance[curNode][cutCount]) {
                continue; // stale entry: a shorter path to this state was already processed
            }

            for (int[] edge : edges[curNode]) {
                int next = edge[0];
                int edgeCost = edge[1];

                // Pave this road: no time added, but one more of the K pavings is used.
                if (cutCount < K && distance[next][cutCount + 1] > curDist) {
                    distance[next][cutCount + 1] = curDist;
                    pq.add(new long[]{next, curDist, cutCount + 1});
                }

                // Drive it unpaved: paving count unchanged, road time added.
                long viaRoad = curDist + edgeCost;
                if (distance[next][cutCount] > viaRoad) {
                    distance[next][cutCount] = viaRoad;
                    pq.add(new long[]{next, viaRoad, cutCount});
                }
            }
        }
        for (int i = 0; i <= K; i++) {
            curMin = Math.min(curMin, distance[N][i]);
        }
    }
}

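DP를 다익스트라와 결합해서 풀었다. distance[v][k]를 "도로 k개를 포장하고 도시 v에 도착했을 때의 최소 시간"으로 정의하고, 각 도로마다 포장하지 않고 지나가는 경우(k 유지, 시간 추가)와 포장하고 지나가는 경우(k+1, 시간 0, 단 k < K일 때만)를 모두 우선순위 큐에 넣는다. 마지막에 distance[N][0..K] 중 최솟값이 답이다.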

좋은 웹페이지 즐겨찾기