【CSA49G】【XSY3315】jump DP

题目大意

  有一个数轴。yww 最开始在位置 \(0\)yww 总共要跳跃很多次。每次 yww 可以往右跳 \(1\) 单位长度,或者跳到位置 \(1\)

  定义位置序列为 yww 在每次跳跃之后所在的位置组成的序列(显然不包括 \(0\))。

  有 \(k\) 个数是好的,分别为 \(a_1,a_2,\ldots,a_k\)

  定义一个位置序列是好的当且仅当:

  • 所有好的数的出现次数之和为 \(n\)
  • 序列中最后一个数是好的。
  • 对于每一个长度为 \(m\) 的子区间,区间内至少有一个好的数。

  定义两个位置序列本质相同当且仅当:

  • 这两个位置序列的长度相同。
  • 不存在一个数 \(t\),满足其中一个序列的第 \(t\) 项是一个好的数,且另一个序列的第 \(t\) 项不是一个好的数。

  求所有本质不同的好的位置序列的长度之和。

  对 \({10}^9+7\) 取模。

  \(k\leq 100,n,m\leq {10}^9\)

题解

  记 \(b_i=a_i-a_{i-1}\)

  考虑对于一个最终位置是好的点的位置序列,求出和这个序列本质相同的序列中,最终能到达那些点。

  这样就有一个 \(O(2^kn)\) 的DP了。

  如果你把这个做法写出来,就会发现其实只有 \(O(k)\) 个状态是有用的。

  这是为什么呢?

  对于一个位置集合 \(S\),考虑集合内最大的元素 \(x\),对于一个更小的 \(y\)\(y\in S\) 当且仅当 \(b_1\leq b_{x-y+1}\)\(b_{2\ldots y}=b_{x-y+2 \ldots x}\) 这样对于每个 \(x\)\(S\) 是唯一的。

  直接套个矩阵快速幂优化DP就 \(O(k^3\log n)\) 了。


  还可以换一种方向思考。

  考虑对于一个好的序列,每次取一个最短的前缀,满足后面还是合法的好的序列。

  如果一个前缀能被分成几个前缀拼在一起,就不能取这个前缀了。

  处理出删掉 \(b_1\) 之后的 ex_kmp 数组就可以快速求出每个前缀是否可选。

  这就是一个常系数齐次线性递推。

  直接BM+倍增取模可以做到 \(O(k^2+k\log k\log n)\)

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<functional>
#include<cmath>
#include<vector>
#include<assert.h>
#include<map>
//using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
using std::vector;
using std::map;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
void open2(const char *s){
#ifdef DEBUG
	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const int K=110;
const ll p=1000000007;
int k,m,n;
template <typename T>
T operator +(T a,T b)
{
	return T(a.first+b.first,a.second+b.second);
}
template <typename T>
T operator *(T a,T b)
{
	return T(a.first*b.first,a.first*b.second+a.second*b.first);
}
struct mat
{
	pii a[K][K];
	mat()
	{
		memset(a,0,sizeof a);
	}
	pii *operator [](int x)
	{
		return a[x];
	}
};
mat operator *(mat a,mat b)
{
	mat c;
	for(int i=1;i<=k;i++)
		for(int j=1;j<=k;j++)
		{
			std::pair<__int128,__int128> s;
			for(int l=1;l<=k;l++)
			{
				s.first+=(ll)a[i][l].first*b[l][j].first;
				s.second+=(ll)a[i][l].first*b[l][j].second;
				s.second+=(ll)a[i][l].second*b[l][j].first;
			}
			c[i][j]=pii(s.first%p,s.second%p);
		}
	return c;
}
mat fp(mat a,ll b)
{
	mat s;
	for(int i=1;i<=k;i++)
		s[i][i]=pll(1,0);
	for(;b;b>>=1,a=a*a)
		if(b&1)
			s=s*a;
	return s;
}
int b[K][K];
int a[K];
mat c,d;
map<int,int> s;
int sum(ll l,ll r)
{
	return (l+r)*(r-l+1)/2%p;
}
int main()
{
	open("jump");
	scanf("%d%d%d",&k,&m,&n);
	k=min(k,n);
	for(int i=1;i<=k;i++)
		scanf("%d",&a[i]);
	for(int i=k;i>=2;i--)
		a[i]-=a[i-1];
	if(a[1]>m)
	{
		printf("0\n");
		return 0;
	}
	for(int i=2;i<=k;i++)
		if(a[i]>m)
			k=i;
	for(int i=1;i<=k;i++)
	{
		b[i][i]=1;
		for(int j=1;j<i;j++)
		{
			int flag=1;
			if(a[1]>a[i-j+1])
				flag=0;
			for(int l=2;flag&&l<=j;l++)
				if(a[l]!=a[i-j+l])
					flag=0;
			b[i][j]=flag;
		}
	}
	for(int i=1;i<=k;i++)
	{
		int cnt=m-a[1]+1;
		ll _s=sum(a[1],m);
		s.clear();
		for(int j=min(i,k-1);j>=1;j--)
			if(b[i][j]&&!s[a[j+1]])
			{
				s[a[j+1]]=1;
				c[i][j+1]=pll(1,a[j+1]);
				if(a[j+1]>=a[1])
				{
					cnt--;
					_s=(_s-a[j+1])%p;
				}
			}
		c[i][1]=pll(cnt,_s);
	}
	d[1][1]=pll(m-a[1]+1,sum(a[1],m));
	c=fp(c,n-1);
	d=d*c;
//	for(int i=1;i<n;i++)
//		d=d*c;
	ll ans=0;
	for(int i=1;i<=k;i++)
		ans=(ans+d[1][i].second)%p;
	ans=(ans%p+p)%p;
	printf("%lld\n",ans);
	return 0;
}
posted @ 2019-01-03 16:21  ywwyww  阅读(406)  评论(0编辑  收藏  举报