LibreOJ #2325. 「清华集训 2017」小Y和恐怖的奴隶主(矩阵快速幂优化DP)

  哇这题剧毒,卡了好久常数才过T_T

  设$f(i,s)$为到第$i$轮攻击,怪物状态为$s$时对boss的期望伤害,$sum$为状态$s$所表示的怪物个数,得到朴素的DP方程$f(i,s)=\sum \frac{1}{sum+1}*(f(i+1,s')+[s==s'])$

  状态数只有$C_{8+3}^3=165$个,所以就可以矩乘优化了。再加上一个用于转移的$1$,矩阵大小是$166*166$的,因为多组询问,所以可以先把$2$的所有次幂的矩阵都预处理出来。

  然后会发现复杂度是$O(T*166^3*N)$的,无法承受...

  其实答案矩阵只有一列...用它从左往右乘就能把矩阵乘法优化到$O(166^2)$了,总时间复杂度$O(166^3*logn+T*166^2*logn)$

  $16$亿过$2$秒,长见识了...

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=200, mod=998244353;
const ll inf=8226880250554875800;
struct mtx{int mp[maxn][maxn], n, m;mtx(){memset(mp, 0, sizeof(mp)); n=m=0;}}
base[60];
mtx operator * (mtx a, mtx b)
{
    mtx c; c.n=a.n; c.m=b.m;
    for(int i=0;i<=a.n;i++)
    for(int j=0;j<=b.m;j++)
    {
        ll s=0;
        for(int k=0;k<=a.m;k++)
        s+=1ll*a.mp[i][k]*b.mp[k][j], s>inf && (s%=mod);
        c.mp[i][j]=s%mod;
    }
    return c;
}
int T, m, K, tott;
ll n;
int st[maxn], mi[maxn], pos[1<<10];
inline int power(int a, int b)
{
    int ans=1;
    for(;b;b>>=1, a=1ll*a*a%mod)
    if(b&1) ans=1ll*ans*a%mod;
    return ans;
}
int main()
{
    scanf("%d%d%d", &T, &m, &K);
    mi[0]=1; for(int i=1;i<=m;i++) mi[i]=mi[i-1]*(K+1);
    for(int i=0;i<mi[m];i++) 
    {
        int sum=0;
        for(int j=0;j<m;j++) sum+=i/mi[j]%(K+1);
        if(sum<=K) st[tott]=i, pos[i]=tott++;
    }
    base[0].mp[tott][tott]=1;
    base[0].n=base[0].m=tott;
    for(int i=0;i<tott;i++)
    {
        int sum=0;
        for(int j=0;j<m;j++) sum+=st[i]/mi[j]%(K+1);
        int inv=power(sum+1, mod-2);
        base[0].mp[i][tott]=base[0].mp[i][i]=inv;
        for(int j=0;j<m;j++)
        if(st[i]/mi[j]%(K+1))
        {
            int x=st[i]-mi[j];
            if(j) x+=mi[j-1];
            if(j && sum<K) x+=mi[m-1];
            base[0].mp[i][pos[x]]=1ll*inv*(st[i]/mi[j]%(K+1))%mod;
        }
    }
    for(int i=1;i<60;i++) base[i]=base[i-1]*base[i-1];
    while(T--)
    {
        scanf("%lld", &n); mtx ans; ans.n=tott; ans.mp[tott][0]=1;
        int digit=0; for(;n;n>>=1, digit++) if(n&1) ans=base[digit]*ans;
        printf("%d\n", ans.mp[pos[mi[m-1]]][0]);
    }
}
View Code
posted @ 2018-02-11 18:28  Sakits  阅读(253)  评论(0编辑  收藏