dp[i]表示前i个贼全部被抓住的方案数,先离散化,转移时可以将左端点减1到右端点的dp值累加到右端点的dp值上,并把后面所有的dp值都乘2,这显然可以用线段树维护,没有用的也对答案有乘2的贡献。
#include <stdio.h> #include <algorithm> using namespace std; #define ll long long int const int maxn=500001,mod=1000000009; struct oper { ll l,r; }a[maxn]; struct node { ll l,r; ll w; ll f; }tree[maxn*8]; ll b[maxn]; ll n,m,ans=1; bool cmp(oper xt,oper yt) { return xt.l<yt.l; } void build(ll lt,ll rt,ll cur) { tree[cur].l=lt; tree[cur].r=rt; if(lt==rt) return; ll mid=(lt+rt)/2; build(lt,mid,cur*2); build(mid+1,rt,cur*2+1); } void down(ll cur) { tree[cur*2].w=(tree[cur*2].w*tree[cur].f)%mod; tree[cur*2+1].w=(tree[cur*2+1].w*tree[cur].f)%mod; tree[cur*2].f=(tree[cur*2].f*tree[cur].f)%mod; tree[cur*2+1].f=(tree[cur*2+1].f*tree[cur].f)%mod; tree[cur].f=1; } ll ask(ll cur,ll s,ll t) { if(t<s) return 0; if(tree[cur].l>=s&&tree[cur].r<=t) return tree[cur].w; if(tree[cur].f!=1) down(cur); ll mid=(tree[cur].l+tree[cur].r)/2; if(mid>=t) return ask(cur*2,s,t); else if(mid<s) return ask(cur*2+1,s,t); else return (ask(cur*2,s,t)+ask(cur*2+1,mid+1,t))%mod; } void add(ll cur,ll p,ll v) { if(tree[cur].l==tree[cur].r) { tree[cur].w=(tree[cur].w+v)%mod; return; } if(tree[cur].f!=1) down(cur); ll mid=(tree[cur].l+tree[cur].r)/2; if(mid>=p) add(cur*2,p,v); else add(cur*2+1,p,v); tree[cur].w=(tree[cur*2].w+tree[cur*2+1].w)%mod; } void mult(ll cur,ll s,ll t) { if(t<s) return;//千万注意 if(tree[cur].l>=s&&tree[cur].r<=t) { tree[cur].w=(tree[cur].w*2)%mod; tree[cur].f=(tree[cur].f*2)%mod; return; } if(tree[cur].f!=1) down(cur); ll mid=(tree[cur].l+tree[cur].r)/2; if(mid>=t) mult(cur*2,s,t); else if(mid<s) mult(cur*2+1,s,t); else { mult(cur*2,s,mid); mult(cur*2+1,mid+1,t); } tree[cur].w=(tree[cur*2].w+tree[cur*2+1].w)%mod; } int main() { scanf("%lld%lld",&n,&m); build(0,m,1); ll i; for(i=1;i<=n;i++) scanf("%lld%lld",&a[i].l,&a[i].r); for(i=1;i<=m;i++) scanf("%lld",&b[i]); sort(b+1,b+m+1); ll num=0; for(i=1;i<=n;i++) { a[i].l=lower_bound(b+1,b+m+1,a[i].l)-b;//离散化完全 a[i].r=upper_bound(b+1,b+m+1,a[i].r)-b-1; if(a[i].l>a[i].r) ans=(ans*2)%mod; else a[++num]=a[i]; } n=num; sort(a+1,a+n+1,cmp); ll temp=1; add(1,0,temp);//覆盖0个初始化 for(i=1;i<=n;i++) { temp=ask(1,a[i].l-1,a[i].r); add(1,a[i].r,temp); mult(1,a[i].r+1,m); } temp=ask(1,m,m); ans=(ans*temp)%mod; printf("%lld\n",ans%mod); return 0; }