Codeforces 367E Sereja and Intervals DP

제목 대의:
바로 현재 길이가 m인 구간에서 하나의 구간을 찾아내는 것이다. 이 구간은 n개가 있어야 하고 그 중 임의의 두 구간은 관계를 포함하지 않으며 최소한 한 구간이 존재해야 한다. 왼쪽 경계의 값은 x이다.
(1<=n*m<=100000, 1<=x<=m), 이런 구간의 배열 종수를 구하고 마지막 결과는 10^9+7 모드 출력에 대한
대략적인 사고방식:
그냥 dp...상태 이동 방정식은 코드 주석을 보십시오
코드는 다음과 같습니다.
Result  :  Accepted     Memory  :  1408 KB     Time  :  140 ms
/*
 * Author: Gatevin
 * Created Time:  2015/3/23 21:21:17
 * File Name: Chitoge_Kirisaki.cpp
 */
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;

const lint mod = 1e9 + 7;
lint dp[2][320][320];
int n, m, x;

/*
 *  dp[i][l][r]    i    ,  l     r         
 *    i        ,    ,           ,          
 *    
 * dp[i][l][r] = dp[i - 1][l][r] + dp[i - 1][l - 1][r] + dp[i - 1][l][r - 1] + dp[i - 1][l - 1][r - 1]
 *   4      
 *  i == x ,       ,           
 *   dp[x][l][r] = dp[i - 1][l - 1][r] + dp[i - 1][l - 1][r - 1]
 *    dp[0][0][0] = 1,   i              
 */
void Add(lint& a, lint b)
{
    a += b;
    if(a >= mod) a -= mod;
    return;
}

lint solve()
{
    dp[0][0][0] = 1;
    int now = 1;
    register int i, l, r;
    for(i = 1; i <= m; i++, now ^= 1)
        for(l = 0; l <= n; l++)
            for(r = 0, dp[now][l][0] = 0; r <= l; r++, dp[now][l][r] = 0)
                if(i == x)
                {
                    if(l) Add(dp[now][l][r], dp[now ^ 1][l - 1][r]);
                    if(l && r) Add(dp[now][l][r], dp[now ^ 1][l - 1][r - 1]);
                }
                else
                {
                    Add(dp[now][l][r], dp[now ^ 1][l][r]);
                    if(l - 1 >= r) Add(dp[now][l][r], dp[now ^ 1][l - 1][r]);
                    if(r) Add(dp[now][l][r], dp[now ^ 1][l][r - 1]);
                    if(l && r) Add(dp[now][l][r], dp[now ^ 1][l - 1][r - 1]);
                }
    lint ret = dp[now ^ 1][n][n];
    for(int i = 1; i <= n; i++)
        ret = (ret*i) % mod;
    return ret;
}

int main()
{
    scanf("%d %d %d", &n, &m, &x);
    lint ans;
    if(n > m) ans = 0;
    else ans = solve();
    printf("%I64d
", ans); return 0; }

좋은 웹페이지 즐겨찾기