习题:Ralph and Mushrooms(tarjan)

题目

传送门

思路

有一个比较明显的性质,如果一条路能够被多次经过,那么这条路上的蘑菇一定会被采完,也就是指蘑菇的数量为0

考虑如果判断一条路能否被多次经过,比较容易的就能想到用tarjan来缩点

再者就是一条路上怎么统计一共可以采多少次蘑菇,即计算\(\sum_{i=0}^{len}(w-\sum_{j=1}^{i}j)\),这里的len满足\(\sum_{i=1}^{len}i<=w\)

同时len是最大的,统计每一个点的答案二分len即可

很容易发现其可以分离,即\((len+1)*w-\sum_{i=0}^{len}\sum_{j=1}^{i}j\),后面的式子很明显是可以预处理出来的

缩完点之后直接在DAG上跑DP就可以了,\(dp_i\)表示以i号节点为起点所能到达的最大收益,\(dp_i=val_i+\max_{v\in son}dp_v\)

代码

#include<iostream>
#include<cstring>
#include<vector>
#include<cstdio>
#include<stack>
using namespace std;
namespace IO
{
    void read(int &x)
    {
        x=0;
        int f=1;
        char c=getchar();
        while('0'>c||c>'9')
        {
            if(c=='-')
                f=-1;
            c=getchar();
        }
        while('0'<=c&&c<='9')
        {
            x=(x<<3)+(x<<1)+c-'0';
            c=getchar();
        }
        x*=f;
    }
    void read(long long &x)
    {
        x=0;
        int f=1;
        char c=getchar();
        while('0'>c||c>'9')
        {
            if(c=='-')
                f=-1;
            c=getchar();
        }
        while('0'<=c&&c<='9')
        {
            x=(x<<3)+(x<<1)+c-'0';
            c=getchar();
        }
        x*=f;
    }
    void write(int x)
    {
        if(x<10)
            putchar(x+'0');
        else
        {
            write(x/10);
            putchar(x%10+'0');
        }
    }
    void write(long long x)
    {
        if(x<10)
            putchar(x+'0');
        else
        {
            write(x/10);
            putchar(x%10+'0');
        }
    }
}
using namespace IO;
struct node
{
    int e;
    long long w;
};
int n,m;
int st;
int dfn[1000005],low[1000005],bel[1000005],cnt,scc;
bool ins[1000005];
long long init[1000005];
long long dp[1000005],val[1000005];
stack<int> s;
vector<node> g[1000005],dag[1000005];
void tarjan(int u)
{
    s.push(u);
    ins[u]=1;
    dfn[u]=low[u]=++cnt;
    for(int i=0;i<g[u].size();i++)
    {
        int v=g[u][i].e;
        if(dfn[v]==0)
        {
            tarjan(v);
            low[u]=min(low[u],low[v]);
        }
        else if(ins[v])
            low[u]=min(low[u],dfn[v]);
    }
    if(dfn[u]==low[u])
    {
        int v;
        scc++;
        do
        {
            v=s.top();
            s.pop();
            ins[v]=0;
            bel[v]=scc;
        }while(u!=v);
    }
}
long long calc(long long w)
{
    long long l=0,r=40000,mid;
    while(l+1<r)
    {
        mid=(l+r)>>1;
        if(mid*(mid+1)/2<w)
            l=mid;
        else
            r=mid;
    }
    while((l+1)*(l+2)/2<=w)
        l++;
    return w*(l+1)-init[l];
}
void solve(int u)
{
    if(dp[u]!=-1)
        return;
    dp[u]=0;
    for(int i=0;i<dag[u].size();i++)
    {
        int v=dag[u][i].e;
        solve(v);
        dp[u]=max(dp[u],dp[v]+dag[u][i].w);
    }
    dp[u]+=val[u];
}
int main()
{
    memset(dp,-1,sizeof(dp));
    for(int i=1;i<=20000;i++)
        init[i]=init[i-1]+1ll*i*(i+1)/2;
    read(n);
    read(m);
    for(int i=1,u,v;i<=m;i++)
    {
        long long w;
        read(u);read(v);read(w);
        g[u].push_back((node){v,w});
    }
    for(int i=1;i<=n;i++)
        if(dfn[i]==0)
            tarjan(i);
    for(int i=1;i<=n;i++)
        for(int j=0;j<g[i].size();j++)
            if(bel[i]==bel[g[i][j].e])
                val[bel[i]]+=calc(g[i][j].w);
            else
                dag[bel[i]].push_back((node){bel[g[i][j].e],g[i][j].w});
   	read(st);
    solve(bel[st]);
    write(dp[bel[st]]);
    return 0;
}
posted @ 2020-08-24 09:53  loney_s  阅读(92)  评论(0)    收藏  举报