백준 트리 디자이너 호석(22253)

트리 디자이너 호석

1. 힌트

1) 트리는 재귀적이기 때문에 경우의 수는 자식 트리에 대한 경우의 수를 가지고 만들어낼 수 있습니다.

2) 경우의 수를 적절하게 나눠서 중복되거나 빠지는 경우가 없도록 해야합니다. 우선, 단 하나의 정점도 고르지 않는 경우는 제외해야합니다. 만약 리프노드라면 경우의 수는 얼마일까요? 선택 하느냐 안하느냐로 2개라고 생각할 수 있지만, 이러면 경우의 수가 겹치게 됩니다.

3) 최종적인 점화식 구성은 dp(here,prev,flag)dp(here, prev, flag)

2. 접근

1) 트리의 특성

문제의 초반에 주어진 조건인 정점의 개수가 NN개이고 간선의 개수가 N1N-1

일반적인 그래프에서는 사이클이 존재할 수 있기 때문에 문제를 재귀적으로 구성해서 풀 수 없습니다. 하지만 트리에서는 이야기가 다릅니다. 트리는 간선에 방향을 준다면 DAG이기때문에 상태공간을 잘 정의해서 다이나믹 프로그래밍을 적용시킬 수 있습니다.

2) DAG 만들기

문제에 입력으로 주어지는 트리는 부모 자식 관계로 주어지는 것이 아니라 단순히 연결 관계로 주어지기 때문에 간선에 방향성을 주어서 DAG로 만들어야 합니다. 문제에서 11번 정점이 루트라고 주었기 때문에 11번 정점부터 그래프를 탐색하면서 새로운 정점을 만날때 마다 (here,there)(here, there)간에 간선을 추가해주면서 그래프를 만들 수 있습니다. 똑같은 정점을 여러 번 방문하지 않기 때문에 시간 복잡도는 O(V+E)O(V + E)

3) 점화식 구성

문제는 점화식 구성입니다. 경우의 수를 잘못 나눴다가는 중복되거나 빼먹을 수 있기 때문입니다. 점화식을 쉽게 구성하는 방법은 작은 크기의 데이터를 만들어보고 손으로 직접 해보는 것입니다. 다음과 같은 트리에서는 55가지 방법이 가능합니다.

대충 점화식의 인자부터 넣으면서 식을 세워봅시다. 당연히 현재 몇 번째 정점인지는 넣어야겠죠 또, 우리는 오름차순으로 정점을 골라야 하므로 맨 마지막에 고른 정점의 번호도 저장을 해둡시다. 그러면 다음과 같습니다.

dp(here,prev)=dp(here, prev)=

이 점화식의 Base case부터 생각해봅시다. 트리에서 Base case라고 하면 당연히 리프노드인 경우겠죠. 경우의 수는 그 리프노드를 안 고르는 경우 11가지, 고를 수 있다면 고르는 경우 11가지입니다. 그런데 이러면 문제의 정의와 모순이 생깁니다. 우리는 오름차순으로 정점을 선택하는 경우의 수라고 정의했기 때문이죠. 예제 입력 1에서도 알 수 있듯이 한 번도 고르지 않은 경우는 경우의 수에 포함이 되지 않습니다.
어쩔수 없이 새로운 인자를 추가해줍시다. flagflagherehere번째 정점을 선택했는지 안했는지 여부입니다.
Base Case : flag=1flag = 1

dp(here,prev,flag)=dp(here, prev, flag)=

점화식은 간단하게 모든 자식들에 대해서 고를 수 있다면 고르는 경우, 고르지 않는 경우 두 가지로 나눌 수 있습니다.
dp(here,prev,flag)=dp(there,prev,0)+dp(there,S[there],1)dp(here, prev, flag)= \sum dp(there,prev,0) +\sum dp(there, S[there], 1)

3. 구현

dfs(here, prev)

adj는 단순히 연결 관계만 나타내는 간선입니다. 이를 부모-자식 관계의 간선 children으로 바꿔주기 위해 그래프를 한 번 탐색해줍니다.

dp(here, prev, flag)

모듈러 연산자는 비싼 연산자입니다. 이를 MOD보다 커지면 MOD를 빼 주는 식으로 구현할 수 있습니다.

public class Main {
    static int N;
    static int[] S;
    static ArrayList<ArrayList<Integer>> adj;
    static ArrayList<ArrayList<Integer>> children;
    
    static void dfs(int here, int prev) {
        for (int there : adj.get(here)) if (there != prev) {
            children.get(here).add(there);
            dfs(there, here);
        }
    }

    static int[][][] cache;

    static final int MOD = 1000000007;

    // here번째 정점을 루트로 하는 서브 트리에서
    // 가장 마지막으로 고른 정점에 적힌 정수가 prev이고 (안 골랐어도 0)
    // here번째 정점을 골랐는지 여부 flag가 주어질 때,
    // 오름차순으로 정점을 고르는 경우의 수
    static int dp(int here, int prev, int flag) {
        // 리프 노드인 경우 here 정점을 고르는 경우만 경우를 하나 찾은 것
        if (children.get(here).isEmpty()) return flag;
        if (cache[here][prev][flag] != -1) return cache[here][prev][flag];
        // here 정점을 선택하고 here의 자손으로부터는 안 고르는 경우 1
        int sum = flag;
        for (int there : children.get(here)) {
            // there을 고를 수 있으면 고르는 경우
            if (prev <= S[there]) sum += dp(there, S[there], 1);
            if (sum >= MOD) sum -= MOD;
            // 고르지 않는 경우
            sum += dp(there, prev, 0);
            if (sum >= MOD) sum -= MOD;
        }
        return cache[here][prev][flag] = sum;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        N = Integer.parseInt(br.readLine());
        S = new int[N + 1];
        StringTokenizer st = new StringTokenizer(br.readLine(), " ");
        for (int i = 1; i <= N; i++) S[i] = Integer.parseInt(st.nextToken());
        adj = new ArrayList<>();
        for (int i = 0; i <= N; i++) adj.add(new ArrayList<>());
        for (int i = 0; i < N - 1; i++) {
            st = new StringTokenizer(br.readLine(), " ");
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            adj.get(u).add(v); adj.get(v).add(u);
        }
        children = new ArrayList<>();
        for (int i = 0; i <= N; i++) children.add(new ArrayList<>());
        dfs(1, 1);
        cache = new int[N + 1][10][2];
        for (int i = 0; i < cache.length; i++)
            for (int j = 0; j < cache[i].length; j++)
                Arrays.fill(cache[i][j], -1);
        System.out.println((dp(1, 0, 0) + dp(1, S[1], 1)) % MOD);
    }

}

좋은 웹페이지 즐겨찾기