bzoj2164 采矿

题目描述:

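（原文此处题面缺失。）题意概要：一棵以 $1$ 为根的树，每个点 $i$ 有一张长为 $m$ 的收益表 $k_{i,1..m}$（由题目给定的伪随机数生成器产生并排序），在该点放 $j$ 个人可得 $k_{i,j}$。操作有两种：重新生成某个点的收益表；询问 $(u,v)$（$v$ 是 $u$ 的祖先或 $u$ 本身），把 $m$ 个人分配到 $u$ 子树内的任意点，以及 $v$ 到 $u$ 路径上（不含 $u$）的至多一个点，求最大总收益。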

题解:

树链剖分 + 线段树。

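先做树链剖分，按重儿子优先的 DFS 序建线段树。每个线段树结点维护两个数组：一是区间内只选一个点、在该点放 $i$ 个人的最大收益；二是在区间内所有点上任意分配 $i$ 个人的最大收益。询问时，前者在 $v$ 到 $u$ 路径（不含 $u$）覆盖的若干段重链上查询；后者在 $u$ 的子树区间 $[tin_u, tout_u]$ 上查询。最后把两者做一次背包合并，取 $i=m$ 的值。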

前者的合并是逐位取最大值，$O(m)$；后者的合并是背包（$(\max,+)$ 卷积），$O(m^2)$。

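设操作数为 $C$，总复杂度为 $O(nm^2 + C(m\log^2 n + m^2\log n))$，可以通过。其中：
- 建树为 $O(nm^2)$；
- 每次询问的链上部分为 $O(m\log^2 n)$，子树部分为 $O(m^2\log n)$；
- 每次修改为 $O(m^2\log n)$。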

代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
// Array capacities: N bounds the vertex count (vertices are 1-indexed) and M
// bounds m (value tables are indexed 0..m).
const int N = 20050;
const int M = 55;
// Constants of the pseudo-random generator read() below.
const int X = (1<<16);
const int Y = 2147483647;

// Reads one integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Digits are then accumulated up to the first non-digit.
// If end of input is reached before any digit, x becomes 0. The old version
// kept getchar()'s result in a char, so EOF looked like an ordinary non-digit
// and the skip loop spun forever.
template<typename T>
inline void read(T&x)
{
    T f = 1,c = 0;
    int ch = getchar();   // int, so EOF stays distinguishable from data bytes
    while(ch != EOF && (ch<'0'||ch>'9'))
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        c=c*10+(ch-'0');
        ch=getchar();
    }
    x = f*c;
}

// Global state:
//   n        number of vertices
//   m        length of every value table (answers are read from s[m])
//   A, B, Q  state and modulus of the generator read()
//   k[i]     value table of vertex i, entries 1..m, sorted ascending
//   hed, cnt head pointers and edge counter of the parent->child lists
int n,m,A,B,Q,k[N][M],hed[N],cnt;

// Pseudo-random generator used by the problem's input format. It advances the
// state (A, B) and returns (A xor B) mod Q.
// The recurrence relies on 32-bit wrap-around before masking with Y. B*X
// overflows a signed int as soon as B >= 2^15, which is undefined behaviour,
// so the sums and products are done in unsigned int (well-defined mod 2^32).
// Masking with Y keeps the result <= 2^31-1, so converting back to int is exact.
inline int read()
{
    const unsigned mask = static_cast<unsigned>(Y);
    const unsigned mul = static_cast<unsigned>(X);
    A = static_cast<int>((static_cast<unsigned>(A^B) + static_cast<unsigned>(B/X)
                          + static_cast<unsigned>(B)*mul) & mask);
    B = static_cast<int>((static_cast<unsigned>(A^B) + static_cast<unsigned>(A/X)
                          + static_cast<unsigned>(A)*mul) & mask);
    return (A^B)%Q;
}
// Fills row[1..m] with m consecutive outputs of the generator read(), then
// sorts that range ascending. row[0] is not touched.
void get_k(int*row)
{
    int *first = row + 1;
    int *last = row + 1 + m;
    for(int *p = first; p != last; ++p)
        *p = read();
    std::sort(first, last);
}
// One edge of the parent->child adjacency lists. `to` is the target vertex;
// `nxt` is the index of the next edge leaving the same vertex (0 ends the list).
struct EG
{
    int to;
    int nxt;
} e[N];

// Adds the directed edge f -> t by pushing it onto the front of f's list.
// Edge indices start at 1, so index 0 can serve as the list terminator.
void ae(int f,int t)
{
    const int id = ++cnt;
    e[id].nxt = hed[f];
    e[id].to = t;
    hed[f] = id;
}
// Heavy-light decomposition data, indexed by vertex unless noted otherwise.
int dep[N],   // depth; dep[0] is 0, so the root gets depth 1
    fa[N],    // parent vertex (whatever dfs0 was given as f for the root)
    son[N],   // heavy child, 0 when the vertex has no children
    top[N],   // head of the heavy chain containing the vertex
    siz[N],   // subtree size
    tin[N],   // DFS position of the vertex
    tout[N],  // largest DFS position inside the vertex's subtree
    pla[N],   // pla[pos] = vertex placed at DFS position pos
    tim;      // running DFS position counter
// First DFS pass. Records parent, depth and subtree size of u's subtree, and
// picks as son[u] the child with the largest subtree; on ties the child
// visited first is kept. Edges only point from parent to child, so no
// back-edge check is needed.
void dfs0(int u,int f)
{
    fa[u] = f;
    dep[u] = dep[f] + 1;
    siz[u] = 1;
    for(int eid = hed[u]; eid != 0; eid = e[eid].nxt)
    {
        const int child = e[eid].to;
        dfs0(child, u);
        siz[u] += siz[child];
        if(siz[son[u]] < siz[child])
            son[u] = child;
    }
}
// Second DFS pass. Assigns DFS positions and visits the heavy child first, so
// every heavy chain occupies consecutive positions. Top is the head of the
// chain u belongs to; light children start new chains headed by themselves.
// Afterwards [tin[u], tout[u]] is exactly u's subtree.
void dfs1(int u,int Top)
{
    tin[u] = ++tim;
    pla[tim] = u;
    top[u] = Top;
    const int heavy = son[u];
    if(heavy != 0)
        dfs1(heavy, Top);
    for(int eid = hed[u]; eid; eid = e[eid].nxt)
    {
        const int child = e[eid].to;
        if(child == heavy)
            continue;
        dfs1(child, child);
    }
    tout[u] = tim;
}
struct node
{
    ll s[M];
    void reset(int i)
    {
        memset(s,0,sizeof(s));
        if(!i)return ;
        for(int j=1;j<=m;j++)
            s[j]=k[i][j];
    }
    node operator + (const node&a)const
    {
        node ret;ret.reset(0);
        for(int i=1;i<=m;i++)ret.s[i]=max(s[i],a.s[i]);
        return ret;
    }
    node operator * (const node&a)const
    {
        node ret;ret.reset(0);
        for(int i=1;i<=m;i++)
            for(int j=0;j<=i;j++)
                ret.s[i]=max(ret.s[i],a.s[j]+s[i-j]);
        return ret;
    }
};
// Segment tree over DFS positions; the leaf for position l holds the table of
// vertex pla[l].
//   s1: children merged with node::operator+ (pointwise max), giving the best
//       result when all workers go to one vertex of the range.
//   s2: children merged with node::operator* ((max,+) convolution), giving the
//       best split of workers among all vertices of the range.
// Every routine takes the node's range [l, r] and index u, with the root being
// (1, n, 1) as used by main.
struct segtree
{
    node s1[N<<2],s2[N<<2];

    // Recomputes internal node u from its two children.
    void update(int u)
    {
        const int lc = u<<1;
        const int rc = lc|1;
        s1[u] = s1[lc] + s1[rc];
        s2[u] = s2[lc] * s2[rc];
    }
    // Loads the current table of the vertex at DFS position pos into leaf u.
    void load_leaf(int u,int pos)
    {
        const int vert = pla[pos];
        s1[u].reset(vert);
        s2[u].reset(vert);
    }
    void build(int l,int r,int u)
    {
        if(l != r)
        {
            const int mid = (l+r)>>1;
            build(l,mid,u<<1);
            build(mid+1,r,u<<1|1);
            update(u);
        }
        else
            load_leaf(u,l);
    }
    // Reloads position qx from pla[qx]'s current table and fixes its ancestors.
    void insert(int l,int r,int u,int qx)
    {
        if(l == r)
        {
            load_leaf(u,l);
            return;
        }
        const int mid = (l+r)>>1;
        if(qx > mid)
            insert(mid+1,r,u<<1|1,qx);
        else
            insert(l,mid,u<<1,qx);
        update(u);
    }
    // s1-combination of positions [ql, qr] (requires l <= ql <= qr <= r).
    node qs1(int l,int r,int u,int ql,int qr)
    {
        if(l == ql && r == qr)
            return s1[u];
        const int mid = (l+r)>>1;
        if(qr <= mid)
            return qs1(l,mid,u<<1,ql,qr);
        if(ql > mid)
            return qs1(mid+1,r,u<<1|1,ql,qr);
        const node lhs = qs1(l,mid,u<<1,ql,mid);
        const node rhs = qs1(mid+1,r,u<<1|1,mid+1,qr);
        return lhs + rhs;
    }
    // s2-combination of positions [ql, qr] (requires l <= ql <= qr <= r).
    node qs2(int l,int r,int u,int ql,int qr)
    {
        if(l == ql && r == qr)
            return s2[u];
        const int mid = (l+r)>>1;
        if(qr <= mid)
            return qs2(l,mid,u<<1,ql,qr);
        if(ql > mid)
            return qs2(mid+1,r,u<<1|1,ql,qr);
        const node lhs = qs2(l,mid,u<<1,ql,mid);
        const node rhs = qs2(mid+1,r,u<<1|1,mid+1,qr);
        return lhs * rhs;
    }
}tr;
int main()
{
//  freopen("tt.in","r",stdin);
    read(n),read(m),read(A),read(B),read(Q);
    for(int i=1;i<=n;i++)
        get_k(k[i]);
    for(int i=2,f;i<=n;i++)
        read(f),ae(f,i);
    dfs0(1,0),dfs1(1,1);tr.build(1,n,1);
    int C;read(C);
    for(int op,u,v,p,i=1;i<=C;i++)
    {
        read(op);
        if(!op)
        {
            read(p);
            get_k(k[p]);
            tr.insert(1,n,1,tin[p]);
        }else
        {
            read(u),read(v);
            node ans;
            if(u==v)
            {
                ans = tr.qs2(1,n,1,tin[u],tout[u]);
            }else
            {
                node k1,k2;k1.reset(0),k2.reset(0);
                int j;
                for(j=fa[u];top[j]!=top[v];j=fa[top[j]])
                    k1 = k1+tr.qs1(1,n,1,tin[top[j]],tin[j]);
                k1 = k1+tr.qs1(1,n,1,tin[v],tin[j]);
                k2 = tr.qs2(1,n,1,tin[u],tout[u]);
                ans = k1*k2;
            }
            printf("%lld\n",ans.s[m]);
        }
    }
    return 0;
}
View Code

 

posted @ 2019-07-01 16:06  LiGuanlin  阅读(...)  评论(...编辑  收藏