트리 dp

5219 단어 dpcodeforces트리 DP

제목 링크:


codeforces 161D

제목 대의:


나무 한 그루를 주십시오. 각 변의 변권은 1입니다. 두 점 사이의 경로 길이가 k인 점 쌍은 몇 개입니까?

제목 분석:

  • 정의 상태 dp[i][k]는 i를 뿌리로 하는 서브트리의 점 도착점 i의 길이가 k인 점의 개수를 대표한다.V를 u와 인접한 점의 집합으로 정의하고 p는 u의 아버지
  • 그리고 전이 방정식은 매우 간단하다.
    dp[u][j]=∑v∈Vdp[v][j−1]
  • 그리고 우리는 처리된 dp수조를 이용하여 점 i에서 모든 점까지의 지름 길이가 k의 개수와 같은 조작을 할 수 있다.
  • 전이 방정식은 다음과 같다.
    dp[u][j]+=dp[p][j−1]−dp[u][j−2]
  • 아버지가 이미 유지를 받았기 때문에 현재 dp[p][j-1]는 p점에서 모든 점 중 길이가 k인 점의 개수를 표시하고 현재 자수에 존재하는 점을 뺀다. 그 다음에 u가 아닌 자수에 있는 점에서 u까지의 거리가 j인 점의 개수를 나타낸다. 마지막으로 각 점을 들어 그들의 ∑ni=1dp[u][k]를 통계한다. 요구에 부합되는 경로마다 두 번 계산했기 때문에 마지막 결과는 2로 나누어야 한다.

  • AC 코드:

    #include <iostream>
    #include <cstring>
    #include <cstdio>
    #include <vector>
    #include <algorithm>
    #define MAX 50007
    
    using namespace std;
    
    typedef long long LL;
    
    int n,k,a,b;
    LL dp[MAX][507],ans;
    vector<int> e[MAX];
    
    void add ( int u , int v )
    {
        e[u].push_back ( v );
        e[v].push_back ( u );
    }
    
    void Clear ( )
    {
        for ( int i = 0 ; i < MAX ; i++ )
            e[i].clear();
    }
    
    void dfs ( int u , int p )
    {
        dp[u][0] = 1;
        for ( int i = 1 ; i <= k ; i++ )
            dp[u][i] = 0;
        for ( int i = 0 ; i < e[u].size() ; i++ )
        {
            int v = e[u][i];
            if ( v == p ) continue;
            dfs ( v , u );
            for ( int j = 1 ; j <= k ; j++ )
                dp[u][j] += dp[v][j-1];
        }
    }
    
    void solve ( int u , int p )
    {
        for ( int i = 0 ; i < e[u].size() ; i++ )
        {
            int v = e[u][i];
            if ( v == p ) continue;
            for ( int j = k; j >= 1 ; j-- )
            {
                dp[v][j] += dp[u][j-1];
                if ( j > 1 ) dp[v][j] -= dp[v][j-2];
            }
            solve ( v , u );
        }
    }
    
    int main ( )
    {
        while ( ~scanf ( "%d%d" , &n , &k ) )
        {
            Clear();
            for ( int i = 1 ; i < n ; i++ )
            {
                scanf ( "%d%d" , &a , &b );
                add ( a , b );
            }
            ans = 0;
            dfs ( 1 , -1 );
            solve ( 1 , -1 );
            /*for ( int i = 1; i <= n ; i++ ) for ( int j = 0 ; j <= k ; j++ ) cout << i << " " << j << " " << dp[i][j] << endl;*/
            for ( int i = 1 ; i <= n ; i++ )
                ans += dp[i][k];
            printf ( "%I64d
    "
    , ans/2LL ); } }

    좋은 웹페이지 즐겨찾기