2019杭电多校三 C. Yukikaze and Demons (点分治)

大意: 给定树, 每个点有一个十进制数位, 求有多少条路径组成的十进制数被$k$整除.

 

点分治, 可以参考CF715C, 转化为求满足$10^a x+b\equiv 0 \pmod{k}$的$x$的个数.

要注意

  • $tmp$不要设成全局!! (dfs2是递归的, 每一层都要保存自己的一份$up$备份, 设成全局会被更深层的递归覆盖.)
  • 如果$y \bmod z = 0$, 那么$(x \bmod y) \bmod z = x \bmod z$ (所以$10$的幂先对$m$取模, 再对$m$的约数取模, 结果仍然正确)
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <math.h>
#include <set>
#include <map>
#include <queue>
#include <string>
#include <string.h>
#include <bitset>
#define REP(i,a,n) for(int i=a;i<=n;++i)
#define PER(i,a,n) for(int i=n;i>=a;--i)
#define hr putchar(10)
#define pb push_back
#define lc (o<<1)
#define rc (lc|1)
#define mid ((l+r)>>1)
#define ls lc,l,mid
#define rs rc,mid+1,r
#define x first
#define y second
#define io std::ios::sync_with_stdio(false)
#define endl '\n'
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;




// Capacity of every fixed-size array below (vertices are indexed from 1).
const int N = 1e5+10;
// sum : size of the component currently being searched by getrt
// n   : number of vertices of the current test case (read in work)
// rt  : best root found so far by getrt (0 acts as a sentinel with mx[0] = n)
// m   : the modulus read in work
// p10 : p10[i] = 10^i mod m (p10[0] is set to 1 in main)
int sum, n, rt, m, p10[N];
// sz  : subtree sizes computed by getrt
// mx  : for each vertex, the largest part left after removing it (computed by getrt)
// vis : set to 1 in solve once a vertex has been used as a decomposition root
// b   : b[0] is a counter; b[1..b[0]] holds the sorted distinct values m / gcd(p10[i], m)
int sz[N], mx[N], vis[N], b[N];
// Digits of the vertices, 1-based; converted from characters to numeric values in work.
char s[N];
// Adjacency lists of the tree.
vector<int> g[N];
// ans  : matches accumulated by dfs1
// ans1 : matches accumulated by calc/dfs3 (halved when the result is printed)
// Phi  : not referenced anywhere in the visible code
ll ans, ans1, Phi;

// Greatest common divisor by the Euclidean algorithm; gcd(a, 0) == a.
int gcd(int a, int b) {
    while (b) {
        const int rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Extended Euclidean algorithm.
// Returns d = gcd(a, b) and stores coefficients in x and y such that
// a*x + b*y == d.
int exgcd(int a, int b, int &x, int &y) {
    int d;
    // Recurse on (b, a % b) with the coefficient slots swapped, then fix up y.
    if (b) d=exgcd(b,a%b,y,x), y-=a/b*x;
    else d=a,x=1,y=0;
    return d;
}

// Tries to solve the linear congruence a*x = b (mod p).
//
// With d = gcd(a, p): if d divides b, the congruence is normalised in place to
// x = b' (mod p'), i.e. a becomes 1, p becomes p/d and b becomes the unique
// solution in [0, p/d). Otherwise the arguments are left untouched.
//
// Returns true iff a == 1 on exit (i.e. the congruence is solvable).
bool chk(int &a, int &b, int &p) {
    int inv, unused;
    const int d = exgcd(a, p, inv, unused);
    if (b % d == 0) {
        a = 1;
        p /= d;
        // (b/d) and inv can each be close to p in magnitude, so their product
        // does not fit in int; do the multiplication in 64 bits.
        const long long prod = static_cast<long long>(b / d) * inv;
        b = static_cast<int>((prod % p + p) % p);
    }
    return a == 1;
}

// Computes sz[] and mx[] for the unvisited component containing x (entered
// from fa) and moves rt to the vertex with the smallest mx seen so far.
// `sum` must hold the component size and mx[rt] the current best value.
void getrt(int x, int fa) {
    sz[x] = 1;
    int heaviest = 0;  // largest child subtree hanging below x
    for (int y : g[x]) {
        if (vis[y] || y == fa) continue;
        getrt(y, x);
        sz[x] += sz[y];
        heaviest = max(heaviest, sz[y]);
    }
    // The part "above" x counts as one more piece once x is removed.
    mx[x] = max(heaviest, sum - sz[x]);
    if (mx[x] < mx[rt]) rt = x;
}

// Returns the 1-based position in the sorted list b[1..b[0]] of the first
// entry that is not less than x.
int ID(int x) {
    const int cnt = b[0];
    int *begin_it = b + 1;
    int *pos = lower_bound(begin_it, begin_it + cnt, x);
    return pos - b;
}

// One counting table per entry of b[]; cleared in calc, filled by dfs2, queried by dfs1.
map<int,int> mp[40];

// mp[i][j] = number of values recorded by dfs2 that are congruent to j modulo b[i].
// dfs1 reduces 10^h * x = y (mod m) to x = y' (mod b[i]) with b[i] = m / gcd(10^h, m)
// and looks y' up in mp[i].
void dfs1(int x, int fa, int dep, int down) {
    //求10^dep*x=(m-down)%m
    int a = p10[dep], b = (m-down)%m, p = m;
    if (chk(a,b,p)) { 
        auto &u = mp[ID(p)];
        if (u.count(b)) ans += u[b];
    }
    for (int y:g[x]) if (!vis[y]&&y!=fa) {
        dfs1(y,x,dep+1,((ll)down*10ll+s[y])%m);
    }
}
int up[40];
void dfs2(int x, int fa, int dep) {
    REP(i,1,*b) { 
        ++mp[i][up[i]];
    }
	int tmp[40];
    for (int y:g[x]) if (!vis[y]&&y!=fa) {
        REP(i,1,*b) tmp[i]=up[i],up[i]=((ll)s[y]*p10[dep]+up[i])%b[i];
        dfs2(y,x,dep+1);
        REP(i,1,*b) up[i]=tmp[i];
    }
}
void dfs3(int x, int fa, int down, int dep, int up) {
    ans1 += !up+!down;
    for (int y:g[x]) if (!vis[y]&&y!=fa) {
        dfs3(y,x,((ll)down*10+s[y])%m,dep+1,((ll)s[y]*p10[dep]+up)%m);
    }
}

vector<int> q;
void calc(int x) {
    REP(i,1,*b) mp[i].clear();
    if (s[x]%m==0) ++ans1;
    for (int y:q) {
        dfs1(y,x,1,s[y]%m);
        REP(i,1,*b) up[i] = (s[x]+10ll*s[y])%b[i];
        dfs2(y,x,2);
        dfs3(y,x,(10ll*s[x]+s[y])%m,2,(s[x]+10ll*s[y])%m);
    }
}

// Centroid-decomposition driver: marks x as used, counts the paths through x
// in both subtree orders, then recurses into every remaining component.
void solve(int x) {
    vis[x] = 1;
    q.clear();
    for (int y : g[x]) {
        if (!vis[y]) q.push_back(y);
    }
    calc(x);
    reverse(q.begin(), q.end());
    calc(x);
    for (int y : g[x]) {
        if (vis[y]) continue;
        // Reset the sentinel root; sz[y] is taken from the latest getrt pass.
        rt = 0;
        mx[0] = n;
        sum = sz[y];
        getrt(y, 0);
        solve(rt);
    }
}


void work() {
    scanf("%d%d%s", &n, &m, s+1);
    REP(i,1,n) p10[i]=p10[i-1]*10ll%m;
    REP(i,1,n) s[i]-='0';
    ans = ans1 = 0;
    REP(i,1,n) g[i].clear(),vis[i]=0;
    REP(i,2,n) {
        int u, v;
        scanf("%d%d", &u, &v);
        g[u].pb(v);
        g[v].pb(u);
    }
    if (m==1) return printf("%lld\n", (ll)n*n),void();
    *b = 0;
    REP(i,0,min(n,30)) b[++*b]=m/gcd(p10[i],m);
    sort(b+1,b+1+*b),*b=unique(b+1,b+1+*b)-b-1;
    sum=mx[rt=0]=n,getrt(1,0),solve(rt);
    printf("%lld\n", ans+ans1/2);
}

// Entry point: reads the number of test cases and handles them one by one.
int main() {
    p10[0] = 1;  // 10^0; work() builds the remaining powers from it
    int t;
    scanf("%d", &t);
    for (; t != 0; --t) {
        work();
    }
    return 0;
}

 

posted @ 2019-07-29 22:51  uid001  阅读(286)  评论(0编辑  收藏  举报