矩阵线段树

线段树维护矩阵

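（原文此处的配图“无标题.png”已丢失。其核心关系为 $\begin{pmatrix}1&1\\1&0\end{pmatrix}^{n}=\begin{pmatrix}F_{n+1}&F_{n}\\F_{n}&F_{n-1}\end{pmatrix}$，因此 $F_a$ 是 $\begin{pmatrix}1&1\\1&0\end{pmatrix}^{a-1}$ 的左上角元素；操作 1 即把区间内每个矩阵左乘 $\begin{pmatrix}1&1\\1&0\end{pmatrix}^{k}$。）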

直接维护矩阵

摒弃之前难看的代码，换上一份简洁清爽的写法：

#include<bits/stdc++.h>
using namespace std;
using LL = long long;
// Fast integer reader: parses the next decimal integer from stdin.
// Every non-digit character before the first digit is skipped; if a '-' is
// among them the result is negated. Reading stops at the first non-digit
// after the number (that character is consumed).
// Returns 0 if EOF is reached before any digit is found.
template<class T = int> T mian(){
	T s=0,f=1;
	int ch; // int, not char: must be able to hold EOF and be a valid isdigit() argument
	while(!isdigit(ch=getchar())){
		if(ch==EOF)return 0; // input exhausted: bail out instead of spinning forever
		if(ch=='-')f=-1;
	}
	for(s=ch-'0';isdigit(ch=getchar());s=s*10+ch-'0');
	return s*f;
}
// Problem-size bound (5e5 + 5). NOTE(review): not referenced in the code shown here.
constexpr int maxn = 500005;
// Modulus for all matrix arithmetic (1e9 + 7).
constexpr int p = 1000000007;
struct Matrix{
	LL c[2][2];
	Matrix(){clear();}
	void clear(){memset(c,0,sizeof(c));}
	void e(){clear();c[0][0]=c[1][1]=1;}
	LL *operator[](int x){return c[x];}
	const LL*operator[](int x)const{return c[x];}
	friend Matrix operator*(const Matrix &a,const Matrix &b){
		Matrix ans;
		ans[0][0]=(a[0][0]*b[0][0]%p+a[0][1]*b[1][0]%p)%p;
		ans[0][1]=(a[0][0]*b[0][1]%p+a[0][1]*b[1][1]%p)%p;
		ans[1][0]=(a[1][0]*b[0][0]%p+a[1][1]*b[1][0]%p)%p;
		ans[1][1]=(a[1][0]*b[0][1]%p+a[1][1]*b[1][1]%p)%p;
		return ans;
	}
	friend Matrix operator + (const Matrix &a,const Matrix &b){
		Matrix ans;
		ans[0][0]=(a[0][0]+b[0][0])%p;
		ans[0][1]=(a[0][1]+b[0][1])%p;
		ans[1][0]=(a[1][0]+b[1][0])%p;
		ans[1][1]=(a[1][1]+b[1][1])%p;
		return ans;
	}
}base;

// Binary exponentiation: returns b^n (modulo p via Matrix::operator*).
// Assumes n >= 0; n == 0 yields the identity.
// NOTE(review): a negative n would not terminate on arithmetic-shift targets.
Matrix ksm(Matrix b,int n){
	Matrix res;
	res.e();
	while(n){
		if(n&1)
			res=res*b;
		b=b*b;
		n>>=1;
	}
	return res;
}

// Node of a pointer-based segment tree whose leaves are matrices.
// sum    : sum of the matrices of all leaves in [l, r] (already including
//          every tag applied to this node).
// add    : pending matrix that still has to be left-multiplied into both children.
// isfucked: non-zero when a pending tag is present in `add`.
struct Seg_Node{
	Seg_Node *lch=nullptr,*rch=nullptr; // children; both null for a leaf
	int l=0,r=0,isfucked=0;
	Matrix sum,add;

	Seg_Node(){}

	// Split point: children cover [l, mid] and [mid+1, r].
	int mid(){return (l+r)>>1;}

	// Recompute this node's sum from its children.
	void push_up(){
		sum=lch->sum+rch->sum;
	}

	// Left-multiply the whole subtree by x. By distributivity x*(A+B) = x*A + x*B,
	// so the stored sum is updated directly and x is folded into the pending tag.
	void plus(Matrix x){
		sum=x*sum;
		add=x*add;
		isfucked=1;
	}

	// Hand the pending tag down to both children, then reset it to the identity.
	// Must only be called on internal nodes (children are dereferenced).
	void push_down(){
		if(isfucked){
			lch->plus(add);
			rch->plus(add);
			add.e();
			isfucked=0;
		}
	}

};

// Pointer alias for tree nodes, and the root of the tree (set by build()).
using ptr = Seg_Node*;
ptr root = nullptr;

// Recursively allocate the subtree for [l, r] and store it in o.
// Leaves are created left to right; each one reads the next integer a from
// stdin and stores base^(a-1) as its matrix. Every node starts with an
// identity tag and no pending update.
void build(int l,int r,ptr &o=root){
	o=new Seg_Node;
	o->l=l;
	o->r=r;
	o->add.e();
	o->isfucked=0;
	if(l!=r){
		const int m=o->mid();
		build(l,m,o->lch);
		build(m+1,r,o->rch);
		o->push_up();
		return;
	}
	const int a=mian();
	o->sum=ksm(base,a-1);
}

// Left-multiply the matrix of every leaf in [l, r] by val, lazily.
// Fully covered nodes just receive the tag. Otherwise the pending tag is pushed
// down, the overlapping children are visited, and the sum is rebuilt.
// Assumes l <= r and that [l, r] lies inside the tree's range.
void addval(int l,int r,Matrix val,ptr o=root){
	const bool covered = l<=o->l && o->r<=r;
	if(covered){
		o->plus(val);
		return;
	}
	o->push_down();
	const int m=o->mid();
	if(l<=m)
		addval(l,r,val,o->lch);
	if(m<r)
		addval(l,r,val,o->rch);
	o->push_up();
}

// Sum, modulo p, of the top-left entry [0][0] of every leaf matrix in [l, r].
// Pending tags on the visited path are pushed down along the way.
// Assumes l <= r and that [l, r] lies inside the tree's range.
LL getsum(int l,int r,ptr o=root){
	if(o->l>=l && o->r<=r)
		return o->sum[0][0];
	o->push_down();
	const int m=o->mid();
	LL total=0;
	if(l<=m)
		total=(total+getsum(l,r,o->lch))%p;
	if(m<r)
		total=(total+getsum(l,r,o->rch))%p;
	return total;
}

int n,m; // array length and number of operations, read in main()

// Reads n and m, then builds the tree (build() reads the n array values).
// Each of the m operations is "op l r [k]":
//   op == 1 : left-multiply every leaf in [l, r] by base^k
//   op == 2 : print getsum(l, r)
// Any other op value is ignored (its l and r are still consumed).
int main(){
	// base = [[1, 1], [1, 0]]
	base[0][0]=1;
	base[0][1]=1;
	base[1][0]=1;
	base[1][1]=0;
	n=mian();
	m=mian();
	build(1,n);
	for(int q=0;q<m;++q){
		const int op=mian();
		const int l=mian();
		const int r=mian();
		switch(op){
		case 1:
			addval(l,r,ksm(base,mian()));
			break;
		case 2:
			printf("%lld\n",getsum(l,r));
			break;
		default:
			break;
		}
	}
	return 0;
}
posted @ 2019-01-15 17:22  kraylas  阅读(352)  评论(0)  编辑  收藏  举报