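// HDU 6198 — Du Jiao's Berlekamp–Massey template (杜教BM).
// Approach (打表+递推式): brute-force the first few answers, let BM recover the
// linear recurrence they satisfy, then evaluate it for large n. Kept as a reusable template.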
#include<bits/stdc++.h>
#define ll long long
#define IO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
using namespace std;//head
// Modulus for all arithmetic in bm. 998244353 is prime, so every nonzero residue is
// invertible, which init() relies on when dividing by the previous discrepancy.
const int mod=998244353;
int n,casn,m,k;  // leftover template globals; the code shown here does not read them

// Du Jiao's Berlekamp-Massey template ("杜教BM").
//   init(s)        - shortest linear recurrence (mod `mod`) that generates the terms s;
//   solve(n, a, b) - term n of a given recurrence, by computing x^n modulo its
//                    characteristic polynomial with square-and-multiply (O(k^2 log n)).
// Typical use: inita(first_terms) once, then get(n) per query; or get(first_terms, n).
// get()/solve() return values in [0, mod); term indices are 0-based.
namespace bm{
const int maxl=1e4+10;  // recurrence order k must satisfy k < maxl
// res  : coefficients of the current power of x modulo the characteristic polynomial
// base : only cleared by solve(); kept for compatibility
// _c   : scratch for a full product, which needs 2k slots (hence 2*maxl)
// _md  : characteristic polynomial, x^k + sum_{j<k} _md[j] x^j
ll res[maxl],base[maxl],_c[2*maxl],_md[maxl];
vector<ll> md;  // indices j < k with _md[j] != 0, so reductions skip zero terms

// Inverse of a modulo c (extended Euclid). Requires gcd(a, c) == 1; result in [0, c).
ll inv(ll a,ll c=mod) {
    a%=c;
    if(a<0) a+=c;
    ll b=c,u=0,v=1;
    while(a) {
        const ll t=b/a;
        b-=t*a; swap(a,b);
        u-=t*v; swap(u,v);
    }
    if(u<0) u+=c;
    return u;
}

// a <- a*b reduced modulo the characteristic polynomial stored in _md/md by solve().
// a and b hold k coefficients each (solve() passes the same array for both). Entries may
// be negative but stay strictly inside (-mod, mod), so every product fits in a 64-bit ll.
void mul(ll *a,ll *b,int k) {
    for(int i=0;i<k+k;i++) _c[i]=0;
    for(int i=0;i<k;i++) if(a[i])
        for(int j=0;j<k;j++) _c[i+j]=(_c[i+j]+a[i]*b[j])%mod;
    // Fold degrees 2k-1..k back down using x^k == -sum_j _md[j] x^j.
    for(int i=k+k-1;i>=k;i--) if(_c[i])
        for(ll j:md) _c[i-k+j]=(_c[i-k+j]-_c[i]*_md[j])%mod;
    for(int i=0;i<k;i++) a[i]=_c[i];
}

// Term n (0-based) of the sequence defined by
//   b[t] = a[0]*b[t-1] + a[1]*b[t-2] + ... + a[k-1]*b[t-k]   (k = a.size()),
// whose first k terms are b[0..k-1]. n is expected to be >= 0. Returns a value in [0, mod).
// Coefficients and initial terms may be arbitrary ll values; they are reduced modulo mod.
int solve(ll n,const vector<ll>& a,const vector<ll>& b) {
    const int k=static_cast<int>(a.size());
    assert(k<maxl && b.size()>=static_cast<size_t>(k));
    for(int i=0;i<k;i++) _md[k-1-i]=-(a[i]%mod);
    _md[k]=1;
    md.clear();
    for(int i=0;i<k;i++) if(_md[i]!=0) md.push_back(i);
    for(int i=0;i<k;i++) res[i]=base[i]=0;
    res[0]=1;  // res = x^0
    // pnt = bit length of n. Testing (n>>pnt) instead of (1ll<<pnt)<=n never shifts into
    // or past the sign bit, so indices n >= 2^62 no longer overflow.
    int pnt=0;
    while((n>>pnt)>0) pnt++;
    for(int p=pnt;p>=0;p--) {
        mul(res,res,k);             // res <- res^2
        if((n>>p)&1) {              // res <- res*x, then reduce the new x^k term
            for(int i=k-1;i>=0;i--) res[i+1]=res[i];
            res[0]=0;
            for(ll j:md) res[j]=(res[j]-res[k]*_md[j])%mod;
        }
    }
    ll ans=0;
    for(int i=0;i<k;i++) ans=(ans+res[i]*(b[i]%mod))%mod;
    if(ans<0) ans+=mod;
    return ans;
}

// Berlekamp-Massey over Z_mod. Returns the connection polynomial coe (coe[0] == 1) of the
// shortest linear recurrence found for s, i.e. s[t] == -(coe[1]*s[t-1] + coe[2]*s[t-2] + ...)
// (mod mod). Terms are first reduced into [0, mod) so the products below cannot overflow.
vector<ll> init(vector<ll> s) {
    for(ll &x:s) x=(x%mod+mod)%mod;
    vector<ll> coe(1,1),prev(1,1);  // current and last-lengthened connection polynomials
    int len=0,m=1;  // len: current recurrence length; m: steps since prev was saved
    ll b=1;         // discrepancy at the step where prev was saved
    for(int t=0;t<static_cast<int>(s.size());t++) {
        ll d=0;  // discrepancy of coe at position t
        for(int i=0;i<=len;i++) d=(d+coe[i]*s[t-i])%mod;
        if(d==0) { ++m; continue; }
        const ll scale=mod-d*inv(b)%mod;  // -d/b (mod mod)
        const bool lengthen=2*len<=t;
        vector<ll> before;
        if(lengthen) before=coe;
        if(coe.size()<prev.size()+m) coe.resize(prev.size()+m,0);
        // coe <- coe - (d/b) * x^m * prev
        for(size_t i=0;i<prev.size();i++) coe[i+m]=(coe[i+m]+scale*prev[i])%mod;
        if(lengthen) { len=t+1-len; prev=std::move(before); b=d; m=1; }
        else ++m;
    }
    return coe;
}

// Recurrence coefficients of s in solve()'s convention:
// s[t] == c[0]*s[t-1] + c[1]*s[t-2] + ... (mod mod).
vector<ll> recurrence(const vector<ll>& s) {
    vector<ll> coe=init(s);
    coe.erase(coe.begin());
    for(ll &x:coe) x=(mod-x)%mod;
    return coe;
}

vector<ll> c,a;  // state set by inita(): recurrence coefficients and the given terms

// Learns (via BM) the recurrence of the terms _a for later get(n) calls.
void inita(vector<ll> _a){
    a=std::move(_a);
    c=recurrence(a);
}

// Term n (0-based) of the sequence passed to the most recent inita() call.
int get(ll n) {
    return solve(n,c,a);
}

// One-shot version: learns the recurrence of a and returns its term n (0-based).
int get(const vector<ll>& a,ll n) {
    return solve(n,recurrence(a),a);
}
}  // namespace bm
int main(){IO;
ll n;
bm::inita(vector<ll>{4,12,33,88,232,609,1596});
while(cin>>n){
cout<<bm::get(n-1)<<'\n';
}
}

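// 浙公网安备 33010602011771号 — appears to be the site registration footer of the web page this code was copied from; not part of the program.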