省选模拟赛 LYK loves string(string)

题目描述

LYK喜欢字符串,它认为一个长度为n的字符串一定会有n*(n+1)/2个子串,但是这些子串是不一定全部都不同的,也就是说,不相同的子串可能没有那么多个。LYK认为,两个字符串不同当且仅当它们的长度不同或者某一位上的字符不同。LYK想知道,在字符集大小为k的情况下,有多少种长度为n的字符串,且该字符串共有m个不相同的子串。

由于答案可能很大,你只需输出答案对1e9+7取模后的结果即可。

输入格式(string.in)

一行3个数n,m,k

输出格式(string.out)

 一行,表示方案总数。

输入样例

2 3 3

输出样例

6

样例解释

共有6种可能,分别是ab,ac,ba,bc,ca,cb

数据范围

对于20%的数据:1<=n,k<=5

对于40%的数据:1<=n<=5,1<=k<=1000000000

对于60%的数据:1<=n<=8,1<=k<=1000000000

对于100%的数据:1<=n<=10,1<=m<=100,1<=k<=1000000000

分析:很容易想歪的一道题.

   一开始想到dp,这要怎么dp呢?状压dp吗?状态不好用0/1表示......

   考虑到n很小,尝试搜索. 搜第i位的字符是哪一个?显然不行,字符集太大了. 其实不同子串的个数只与每个字符的相对大小有关.所以搜每一位的字符的相对大小即可.  如果最后搜出来的有k个不同相对大小的字符,答案乘上A(k,i)即可(排列数).

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

typedef long long ll;

const ll mod = 1e9+7;

ll n,m,k,f[30][300][30],vis[1500],tot,cnt,num[1500],a[30000],ans;

void solve()
{
    cnt = 0;
    tot = 0;
    memset(vis,0,sizeof(vis));
    for (int i = 1; i <= n; i++)
    {
        if (!vis[num[i]])
        {
            tot++;
            vis[num[i]] = 1;
        }
    }
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
    {
        int r = i + j - 1;
        if (r > n)
            break;
        ll temp = 0;
        for (int k = j; k <= r; k++)
            temp = temp * 10 + num[k];
        a[++cnt] = temp;
    }
    sort(a + 1,a + 1 + cnt);
    cnt = unique(a + 1,a + 1 + cnt) - a - 1;
    f[n][cnt][tot]++;
}

void dfs(int dep)
{
    if (dep == n + 1)
    {
        solve();
        return;
    }
    memset(vis,0,sizeof(vis));
    tot = 0;
    for (int i = 1; i < dep; i++)
    {
        if (!vis[num[i]])
        {
            tot++;
            vis[num[i]] = 1;
        }
    }
    num[dep] = tot + 1;
    dfs(dep + 1);
    int flag[15];
    memset(flag,0,sizeof(flag));
    for (int i = 1; i < dep; i++)
    {
        if (!flag[num[i]])
        {
            flag[num[i]] = 1;
            num[dep] = num[i];
            dfs(dep + 1);
        }
    }
}

ll A(ll x,ll y)
{
    ll res = 1;
    for (ll i = 1; i <= y; i++)
        res = res * (x - i + 1) % mod;
    return res % mod;
}

int main()
{
    scanf("%lld%lld%lld",&n,&m,&k);
    dfs(1);
    for (int i = 1; i <= n; i++)
    {
        ans += f[n][m][i] * A(k,i) % mod;
        ans %= mod;
    }
    printf(" %lld\n",ans % mod);

    return 0;
}

 

posted @ 2018-03-25 23:24  zbtrs  阅读(267)  评论(0编辑  收藏  举报