使用模 2^64-2^24+1 的数论变换(NTT)与 Karatsuba 分治实现的 base-2^64 大整数乘法

#include <bits/stdc++.h>

using namespace std;

namespace mbase {

// Size threshold in 64-bit limbs: v_mult_karats hands the product to
// v_mult_simple as soon as either operand is shorter than this.
const size_t kara_least=32;

typedef uint64_t u64;
typedef uint32_t u32;

// NTT modulus P0 = 2^64 - 2^24 + 1.
// P0-1 = [2^24, 3, 5, 5, 11, 17, 31, 41, 61681]  (prime factorisation)
// G0 is the base that ntt/ntt2 raise to (P0-1)/m to obtain their roots of
// unity; presumably a primitive root modulo P0 -- TODO confirm.
const u64 P0 = 18446744073692774401ull,
          G0=129;

// A limb with all 64 bits set.
const u64 F64=0xffffffffffffffffull;
// Upper 32 bits of x, moved down into the low half of the result.
constexpr uint64_t H32(uint64_t x) {
    return x >> 32;
}
// Lower 32 bits of x; the upper half of the result is zero.
constexpr uint64_t L32(uint64_t x) {
    return static_cast<uint32_t>(x);
}

// Index of the highest set bit of x (floor of log2), or -1 when x is zero.
// Halving search: test the upper 32, 16, 8, ... bits in turn.
int log2i(u64 x) {
    if(x==0) {
        return -1;
    }
    int exponent=0;
    int step=32;
    while(step!=0) {
        const u64 upper=x>>step;
        if(upper!=0) {
            x=upper;
            exponent+=step;
        }
        step>>=1;
    }
    return exponent;
}

// Full 64x64 -> 128-bit product: a0 receives the low 64 bits of x*y and a1
// the high 64 bits.
// Uses the compiler's unsigned __int128 when it is available; otherwise the
// product is assembled from four 32x32 partial products. The portable path
// used to sit behind an unconditional return and could never run, so it is
// now selected by the preprocessor instead.
void mult_to128(u64 x,u64 y,u64& a0,u64& a1) {
#if defined(__SIZEOF_INT128__)
    const unsigned __int128 wide=static_cast<unsigned __int128>(x)*y;
    a0=static_cast<u64>(wide);
    a1=static_cast<u64>(wide>>64);
#else
    const u32 xh=static_cast<u32>(x>>32), xl=static_cast<u32>(x);
    const u32 yh=static_cast<u32>(y>>32), yl=static_cast<u32>(y);
    const u64 x0=static_cast<u64>(xl)*yl;
    u64 x1=static_cast<u64>(xl)*yh;
    const u64 x2=static_cast<u64>(xh)*yl;
    u64 x3=static_cast<u64>(xh)*yh;
    // (2^32-1)^2 + (2^32-1) < 2^64, so this first addition cannot wrap.
    x1+=(x0>>32);
    // The second one can; a wrap is worth 2^32 in the high word.
    x1+=x2;
    if(x1<x2) x3+=(static_cast<u64>(1)<<32);
    a1=x3+(x1>>32);
    a0=(x1<<32) | (x0&0xffffffff);
#endif
}
void square_to128(u64 x,u64& a0,u64& a1) {
    a1=H32(x)*H32(x);
    a0=L32(x)*L32(x);
    u64 p=H32(x)*L32(x);
    a0+=(p<<33);
    if(a0 < (p<<33)) ++a1;
    a1+=(p>>31);
}

// Adds 1 to the little-endian limb sequence starting at v. The carry walks
// upward through every limb that wraps to zero; there is no length bound, so
// the caller must guarantee that some limb absorbs the carry.
void v_add1(u64* v) {
    for(;;) {
        *v+=1;
        if(*v!=0) return;
        ++v;
    }
}

// Subtracts 1 from the little-endian limb sequence starting at v. The borrow
// walks upward through every limb that was zero (and becomes all ones);
// there is no length bound, so the caller must guarantee a non-zero limb.
void v_minus1(u64* v) {
    for(;;) {
        const u64 before=*v;
        *v=before-1;
        if(before!=0) return;
        ++v;
    }
}

// In-place addition of little-endian limb arrays: a += b.
//   a, sa : destination and the number of limbs it currently holds
//   b, sb : addend and its limb count
// When b is longer than a, b's upper limbs are copied to a[sa..sb) first
// (the copy used to target a[0], clobbering a's low limbs).
// A final carry keeps propagating upward even past max(sa,sb), so the caller
// must provide writable limbs above the operands that can absorb it.
void v_add(u64* a,size_t sa,const u64* b,size_t sb) {
    if(sb>sa) {
        std::memcpy(a+sa,b+sa,sizeof(u64)*(sb-sa));
        sb=sa;   // only the low sa limbs still need a real addition
    }
    u64 carry=0;
    for(size_t i=0; i<sb; ++i) {
        const u64 sum=a[i]+b[i];
        const u64 wrapped=(sum<b[i]) ? 1 : 0;
        const u64 total=sum+carry;
        // At most one of the two additions can wrap for a given limb.
        carry=wrapped | ((total<carry) ? 1 : 0);
        a[i]=total;
    }
    for(size_t i=sb; carry!=0; ++i) {
        a[i]+=1;
        carry=(a[i]==0) ? 1 : 0;
    }
}

// Three-way comparison of two little-endian limb arrays of equal length sz.
// Returns 1 when a > b, -1 when a < b and 0 when they are equal (or sz == 0).
// The index stays a size_t: the former int index truncated sz for arrays
// longer than INT_MAX limbs.
int my_comp(const u64* a,const u64* b,size_t sz) {
    for(size_t i=sz; i-->0;) {
        if(a[i]!=b[i]) return (a[i]>b[i]) ? 1 : -1;
    }
    return 0;
}

// In-place subtraction a -= b on little-endian limbs, bounded by sa.
// Returns 1 when a borrow runs off the end of a's sa limbs (a[sa] is left
// untouched and the caller has to account for the borrow), 0 otherwise.
// NOTE(review): the loop reads a[i] for every i < sb, so callers presumably
// pass sa >= sb -- confirm.
u64 v_minus_ret(u64* a,size_t sa,const u64* b,size_t sb) {
    size_t pos=0;
    while(pos<sb) {
        const bool borrow=a[pos]<b[pos];
        a[pos]-=b[pos];
        ++pos;
        if(!borrow) continue;
        // A zero limb that must give up a borrow becomes 0 - b - 1 = ~b
        // (all ones once b is exhausted) and passes the borrow on.
        while(pos<sa && a[pos]==0) {
            a[pos]=(pos<sb) ? ~b[pos] : F64;
            ++pos;
        }
        if(pos==sa) {
            return 1;
        }
        // First non-zero limb absorbs the borrow; b[pos], if any, is
        // subtracted from it on the next pass of the outer loop.
        a[pos]-=1;
    }
    return 0;
}
// In-place subtraction a -= b on little-endian limbs, where b has sb limbs.
// Precondition: a >= b. A borrow is chased upward without any bound on a, so
// some higher limb of a must be non-zero to absorb it.
// Reads and writes happen in the same order as a limb-by-limb subtraction,
// which matters when a and b are overlapping views of one array.
void v_minus_free(u64* a,const u64* b,size_t sb) {
    size_t pos=0;
    while(pos<sb) {
        const bool borrow=a[pos]<b[pos];
        a[pos]-=b[pos];
        ++pos;
        if(!borrow) continue;
        // A zero limb hit by a borrow becomes ~b (or all ones past the end
        // of b) and hands the borrow to the next limb.
        while(a[pos]==0) {
            a[pos]=(pos<sb) ? ~b[pos] : F64;
            ++pos;
        }
        a[pos]-=1;
    }
}
// Schoolbook multiplication, accumulated into c one output column at a time.
//   a, sa / b, sb : little-endian operands and their limb counts
//   c             : output of sa+sb limbs; the product is added onto whatever
//                   c holds, so it is expected to be zero on entry
// Does nothing when either operand is empty.
void v_mult_simple(const u64* a,size_t sa,const u64* b,size_t sb,
                   u64* c) {
    if(sa==0 || sb==0) return;
    const size_t columns=sa+sb-1;
    for(size_t col=0; col<columns; ++col) {
        // Range of a-indices whose partner b[col-ai] exists.
        const size_t first=(col>=sb) ? col+1-sb : 0;
        const size_t last=min(col,sa-1);
        u64 low_sum=0, high_sum=0;
        for(size_t ai=first; ai<=last; ++ai) {
            u64 lo,hi;
            mult_to128(a[ai],b[col-ai],lo,hi);
            low_sum+=lo;
            high_sum+=hi;
            // Wraps of the two running sums are pushed one and two limbs up.
            c[col+1]+=(low_sum<lo);
            if(high_sum<hi) c[col+2]+=1;
        }
        c[col]+=low_sum;
        c[col+1]+=high_sum+(c[col]<low_sum);
        if(c[col+1]<high_sum) {
            c[col+2]+=1;
        }
    }
}

// Karatsuba multiplication of two little-endian limb arrays.
//   aa, sa / bb, sb : operands and their limb counts (high zero limbs are
//                     trimmed first)
//   c               : output of sa+sb limbs. Small cases are forwarded to
//                     v_mult_simple, which accumulates into c, so c is
//                     expected to be zero on entry (v_mult passes a
//                     zero-initialised array).
//   BUFF            : scratch space shared by the recursion; when null, the
//                     call allocates what it needs and frees it before
//                     returning.
void v_mult_karats(const u64* aa,size_t sa,const u64* bb,
                   size_t sb,u64* c,u64* BUFF=nullptr) {
    // Ignore high-order zero limbs of both operands.
    while(sa&&aa[sa-1]==0) --sa;
    while(sb&&bb[sb-1]==0) --sb;
    bool alloc_buff=false;

    // Below the threshold the schoolbook routine is used directly.
    if(sa<kara_least || sb<kara_least) {
        v_mult_simple(aa,sa,bb,sb,c);
        return;
    }

    // From here on a is the operand with no more limbs than b (sa <= sb).
    const u64* a=(sa<sb?aa:bb),*b=(sa<sb?bb:aa);
    if(sb<sa)  swap(sa,sb);
    // Unbalanced operands: cut b into slices of sa limbs and multiply each
    // slice by a, writing the partial product at its limb offset.
    if(sa < (sb>>1)+2) {
        //v_mult_simple(a,sa,b,sb,c); return;
        u64* tmp;
        if(!BUFF) tmp=new u64[sa];
        else tmp=BUFF;
        for(size_t shift=0; shift<sb; shift+=sa) {
            if(shift) {
                // The sa limbs at c+shift already hold the upper part of the
                // previous partial product: park them in tmp and clear the
                // area so the recursive call starts from zeros.
                memcpy(tmp,c+shift,sizeof(u64)*sa);
                memset(c+shift,0,sizeof(u64)*sa);
            }
            v_mult_karats(a,sa,b+shift,min(sa,sb-shift),c+shift,
                          BUFF?BUFF+sa:nullptr);
            if(shift) {
                // Add the parked limbs back on top of the new product.
                v_add(c+shift,sa,tmp,sa);
            }
        }
        if(!BUFF) delete[] tmp;
        return;
    }
    // Top-level call without caller-provided scratch: allocate one
    // zero-initialised buffer for the whole recursion.
    // NOTE(review): the size sums a roughly-halving sequence of lengths and
    // multiplies by 4; the bound behind that formula is not derivable from
    // this file -- confirm before changing it.
    if(!BUFF) {
        size_t buff_cur=((sb+1)>>1)+1,buff_all=0;
        while(buff_cur>=kara_least/2) {
            buff_all += buff_cur;
            buff_cur = ((buff_cur+1)>>1)+2;
        }
        BUFF=new u64[buff_all*4] {};
        alloc_buff=true;
    }

    // Split both operands at k limbs.
    size_t k=(sb>>1);

    // Low halves -> c[0..2k), high halves -> c[2k..sa+sb).
    v_mult_karats(a,k,b,k,c,BUFF);
    v_mult_karats(a+k,sa-k,b+k,sb-k,c+k*2,BUFF);

    // sc = number of limbs of c in use (index of highest non-zero limb + 1).
    size_t sc=sa+sb;
    while(c[sc-1]==0) --sc;
    // NOTE(review): the statements up to the two v_minus1 loops rewrite c in
    // place. They appear to apply the "- (low product + high product),
    // shifted by k limbs" part of the Karatsuba identity before the middle
    // product is added below; minu2/minu3 record borrows that are settled by
    // the v_minus1 calls. The ranges involved are views into c that can
    // overlap, so statement order matters -- confirm before modifying.
    u64 minu2=v_minus_ret(c+k,k,c+k*2,k), minu3=0;

    if(minu2==0) {
        // No borrow: minu3 becomes 1 when the k-limb difference is non-zero.
        for(size_t i=0; i<k; ++i) {
            if(c[k+i]) {
                minu3=1;
                break;
            }
        }
    }
    if(minu2||minu3) {
        // c[2k..3k) = two's complement of c[k..2k)
        for(size_t i=0; i<k; ++i) {
            c[k*2+i]=~c[k+i];
        }
        v_add1(c+k*2);
    } else {
        memset(c+k*2,0,k*sizeof(u64));
    }
    v_minus_free(c+k*2,c+k*3,sc-k*3);
    v_minus_free(c+k,c,k);
    while(minu2--) v_minus1(c+k*2);
    while(minu3--) v_minus1(c+k*3);

    // Build A = a_low + a_high and B = b_low + b_high in the scratch buffer;
    // each sum gets one extra limb for its carry.
    size_t sA=1+max(sa-k,k), sB=1+max(sb-k,k);
    u64* A=BUFF, *B=A+sb-k+1;

    memcpy(A,a,sizeof(u64)*k);
    memset(A+k,0,sizeof(u64)*(sA-k));
    v_add(A,sA,a+k,sa-k);
    while(sA>k&&A[sA-1]==0) --sA;

    memcpy(B,b,sizeof(u64)*k);
    memset(B+k,0,sizeof(u64)*(sB-k));
    v_add(B,sB,b+k,sb-k);
    while(sB>k&&B[sB-1]==0) --sB;

    // C = A*B, computed into zeroed scratch placed after B; the scratch past
    // C is handed to the recursive call.
    u64* C=B+((sb+1)>>1)+1;
    memset(C,0,sizeof(u64)*(sB+sA));
    v_mult_karats(A,sA,B,sB,C, C+(sb-k)*2+2);
    // Add the middle product at limb offset k.
    v_add(c+k,sa+sb-k,C,sA+sB);

    if(alloc_buff) {
        delete[] BUFF;
    }
}

// x += y for arithmetic modulo P0: when the 64-bit addition wraps, P0 is
// subtracted from the wrapped sum. The result is not otherwise compared
// against P0 here; qmult does that final step itself.
inline void QaddXy(u64 &x, u64 y) {
    const u64 wrapped=x+y;
    x=(wrapped<y) ? wrapped-P0 : wrapped;
}

u64 qmult(u64 x, u64 y) {
    u64 a0,a1;
    mult_to128(x,y,a0,a1);
    QaddXy(a0,P0-a1); //assert(a1<P0), need x<P0 or y<P0)
    QaddXy(a0,a1<<24); // P = 2^64 - 2^24 + 1 > (a1<<24)
    QaddXy(a0,((a1>>40)<<24)-(a1>>40)); //2^48 <P
    return (a0<P0) ? a0 : a0-P0;
}
// Returns x*x modulo P0, reduced below P0 (same folding as qmult).
// The final reduction matters: qpow feeds this result straight back in, and
// an unreduced input close to 2^64 makes the high word a1 reach P0 or more,
// at which point P0-a1 wraps and the result is wrong. Returning a value
// below P0 keeps a1 < P0 on the next squaring.
// Requires x < P0.
u64 qsquare(u64 x) {
    u64 a0,a1;
    square_to128(x,a0,a1);
    QaddXy(a0,P0-a1);
    QaddXy(a0,a1<<24);
    QaddXy(a0,((a1>>40)<<24)-(a1>>40));
    return (a0<P0) ? a0 : a0-P0;
}
// Returns a^b modulo P0 by binary exponentiation: one squaring per bit of b
// and one multiplication per set bit.
u64 qpow(u64 a, u64 b)  {
    u64 result=1;
    for(; b!=0; b>>=1) {
        if((b&1)!=0) {
            result=qmult(result,a);
        }
        a=qsquare(a);
    }
    return result;
}
/*
u64 qmult129(u64 x){
    return QaddXy(x-(x>>57),QaddXy((x>>57)<<24,x<<7));
}*/

// In-place number-theoretic transform of a[0..fsz) modulo P0.
//   fsz : transform length; expected to be a power of two that divides P0-1
//         (at most 2^24). Lengths <= 0 are ignored.
//   opt : 1 for the forward transform, -1 for the inverse (output reversed
//         and scaled by fsz^-1).
// Inputs must already be reduced below P0.
// The bit-reversal table is cached in a function-local static between calls
// of the same length, so the function is not safe to call concurrently.
void ntt(u64 *a, int fsz,int opt=1) {
    // Checked up front: the old guard was only reached when a different
    // length had been cached, so a first call with fsz==0 and opt==-1 ran
    // reverse() on an invalid range.
    if(fsz<=0) return;
    static std::vector<int> rev;
    if(rev.size()!=static_cast<size_t>(fsz)) {
        rev.assign(fsz,0);
        for(int i=0; i<fsz; ++i) rev[i]=((i&1)*(fsz>>1))+(rev[i>>1]>>1);
    }
    for(int i=0; i<fsz; ++i) if(rev[i]<i) std::swap(a[i],a[rev[i]]);
    for(int m=2; m<=fsz; m<<=1) {
        const int k=m>>1;
        const u64 gn = qpow(G0, (P0-1)/m);   // root of unity of order m
        for (u64* ai=a; ai < a+fsz; ai += m) {
            u64 g = 1;
            for (int j = 0; j < k; ++j) {
                const u64 tmp = qmult(g, ai[j+k]);
                const u64 top = ai[j];
                // (top - tmp) and (top + tmp) modulo P0 without overflow.
                ai[j+k] = (top<tmp) ? top-tmp + P0 : top-tmp;
                ai[j] = (top<P0-tmp) ? top+tmp : top-(P0-tmp);
                g = qmult(g,gn);
            }
        }
    }
    if(opt==-1) {
        std::reverse(a+1,a+fsz);
        const u64 fsz_1 = qpow(fsz, P0-2);   // fsz^-1 via Fermat
        for(int i=0; i<fsz; ++i) a[i]=qmult(fsz_1,a[i]);
    }
}

void ntt2(u64* a,int fsz,int opt=1){
    static int* r=nullptr, rsz=0;
    static u64* P=nullptr;
    if(fsz!=rsz) {
        if(fsz==0) return;
        if(r) delete[] r;
        if(P) delete[] P;
        r = new int[fsz];r[0]=0;
        P = new u64[fsz];
        rsz=fsz;
        for(int i=0; i<fsz; ++i) r[i]=((i&1)*(fsz>>1))+(r[i>>1]>>1);
        u64 P1=qpow(G0,(P0-1)/fsz);
        P[fsz/2]=1;
        for(int i=fsz/2+1; i<fsz; ++i) P[i]=qmult(P[i-1],P1);
        for(int k=fsz/4; k; k/=2){
            for(int i=0;i<k;++i){
                P[k+i]=P[k*2+i*2];
            }
        }
    }
    for(int i=0; i<fsz; ++i) if(r[i]<i) swap(a[i],a[r[i]]);
    for(int m=2; m<=fsz; m<<=1) {
        int k=m>>1;
        for (u64* ai=a; ai < a+fsz; ai += m) {
            int gi = k;
            for (int j = 0; j < k; j+=1) {
                u64 tmp = qmult(P[gi], ai[j+k]);
                ai[j+k] = (ai[j]<tmp) ? ai[j]-tmp + P0 : ai[j]-tmp;
                ai[j] = (ai[j]<P0-tmp) ? ai[j]+tmp : ai[j]-(P0-tmp);
                gi = gi+1;
            }
        }
    }
    if(opt==-1) {
        reverse(a+1,a+fsz);
        uint64_t fsz_1 = qpow(fsz, P0-2);
        for(int i=0; i<fsz; ++i) a[i]=qmult(fsz_1,a[i]);
    }
}

// Cuts the n little-endian 64-bit words of s into consecutive bw-bit chunks,
// least significant first, and stores one chunk per element of t.
// A chunk that straddles two words is assembled from both; the loop stops
// once the last word has been consumed, i.e. after ceil(64*n/bw) chunks.
// NOTE(review): the mask and shifts assume 0 < bw < 64.
void bit_split(u64* t,const u64* s,size_t n, int bw) {
    const u64 mask=(1ull<<bw)-1;
    size_t out=0, word=0;
    int offset=0;                 // bit position of the next chunk in s[word]
    while(word<n) {
        u64 chunk=(s[word]>>offset)&mask;
        offset+=bw;
        if(offset>=64) {
            // Chunk reached the end of this word: continue in the next one
            // and pull in the bits that belong to the current chunk.
            offset-=64;
            ++word;
            if(word<n) chunk|=(s[word]<<(bw-offset))&mask;
        }
        t[out]=chunk;
        ++out;
    }
}
// Inverse of bit_split: packs the n chunks of s (bw bits each, least
// significant first) back into 64-bit words of t.
// t[0] is OR-ed into, so it must be zero on entry; every later word is
// assigned when the packing first reaches it. Chunks are expected to have no
// bits set at or above bit bw.
void bit_merge(u64* t,const u64* s,size_t n, int bw) {
    size_t out=0;
    int offset=0;                 // bit position of the next chunk in t[out]
    for(size_t in=0; in<n; ++in) {
        const u64 chunk=s[in];
        t[out]|=chunk<<offset;
        offset+=bw;
        if(offset<64) continue;
        // Word is full: start the next one with the chunk's leftover bits.
        offset-=64;
        ++out;
        if(out<n) t[out]=chunk>>(bw-offset);
    }
}

// Multiplies two little-endian limb arrays through an NTT modulo P0.
//   a, sa / b, sb : operands and their limb counts (both expected > 0)
//   c             : receives a newly allocated array holding the product;
//                   any previous value of c is overwritten without being
//                   freed, and the caller releases the result with delete[]
//   fsz           : receives the length of c (the transform length)
// The operands are cut into bw-bit chunks, with bw chosen so that no
// convolution coefficient can reach P0. Because 2^24 is the largest power of
// two dividing P0-1, transforms longer than 2^24 have no root of unity and
// would silently produce a wrong product; that case is now asserted.
void v_mult_ntt(const u64* a,size_t sa,const u64* b,size_t sb,
                u64* &c, size_t& fsz) {
    int bw = 28;
    while((64*(sa+sb)+bw-1)/bw*((1ull<<bw)-1) > (P0>>bw)) --bw;
    const size_t csz=(64*sb+bw-1)/bw+(64*sa+bw-1)/bw;

    // Smallest power of two that is >= csz (for csz >= 2).
    fsz=2ull<<log2i(csz-1);
    assert(fsz<=(static_cast<size_t>(1)<<24));
    const int len=static_cast<int>(fsz);

    c = new u64[fsz] {};
    std::vector<u64> t(fsz);      // scratch for the second operand / result
    bit_split(c,a,sa,bw);
    bit_split(t.data(),b,sb,bw);
    ntt2(c,len);
    ntt2(t.data(),len);
    for(size_t i=0; i<fsz; ++i) {
        t[i]=qmult(t[i],c[i]);
    }
    ntt2(t.data(),len,-1);
    // Carry propagation: bring every coefficient back to bw bits.
    u64 up=0;
    const u64 msk=(1ull<<bw)-1;
    for(size_t i=0; i<fsz; ++i) {
        t[i] = t[i] + up;
        up = t[i]>>bw;
        t[i]=t[i]&msk;
    }
    // bit_merge ORs into its first word, so c has to start from zeros.
    std::memset(c,0,sizeof(u64)*fsz);
    bit_merge(c,t.data(),fsz,bw);
}

// Multiplication front end: allocates the result and dispatches to one of
// the three algorithms.
//   aa, sa / bb, sb : little-endian operands and their limb counts
//   c, fsz          : receive the newly allocated product and its length;
//                     the caller releases c with delete[]
//   method          : 1 = schoolbook, 2 = Karatsuba, 3 = NTT,
//                     0 = choose automatically
// Any other method value trips assert(0).
void v_mult(const u64* aa,size_t sa,const u64* bb,size_t sb,
            u64* &c, size_t& fsz,int method=0) {
    if(method==0) {
        if(sa<64 || sb<64) {
            method=1;
        } else {
            // NOTE(review): the two sides look like cost estimates for the
            // NTT and Karatsuba paths; the constants are empirical.
            const bool prefer_ntt=
                log2(sa+sb)*(sa+sb)*6.9<pow(min(sa,sb),-0.585)*sa*sb;
            method=prefer_ntt ? 3 : 2;
        }
    }
    switch(method) {
    case 1:
        fsz=sa+sb;
        c=new u64[fsz] {};
        v_mult_simple(aa,sa,bb,sb,c);
        return;
    case 2:
        fsz=sa+sb;
        c=new u64[fsz] {};
        v_mult_karats(aa,sa,bb,sb,c);
        return;
    case 3:
        v_mult_ntt(aa,sa,bb,sb,c,fsz);
        return;
    default:
        break;
    }
    assert(0);
}

// Converts the little-endian base-2^64 number a[0..z) into little-endian
// digits of base x.
//   x  : target base, at least 2
//   a,z: source limbs and their count (z > 0)
//   b  : digit buffer. When null, a buffer large enough for any z-limb value
//        is allocated with new[] and its length stored in bz; otherwise the
//        caller's buffer of bz elements is used and cleared first.
//   bz : capacity of b; digits above the most significant one are zero.
// The previous body never read a and never entered its inner loop, so the
// output was always zero. The conversion is Horner's scheme over the limbs,
// most significant first: digits = digits * 2^64 + limb, carried in base x.
void v_to_baseX(u64 x,const u64* a,size_t z,u64* &b,size_t& bz){
    assert(z);
    assert(x>=2);
    if(!b){
        bz = std::ceil(64.0000000001*z/std::log2(static_cast<double>(x)));
        b=new u64[bz]{};
    } else {
        std::fill(b,b+bz,static_cast<u64>(0));
    }
    size_t used=0;                       // digits produced so far
    for(size_t i=z; i-->0;){
        u64 carry=a[i];
        for(size_t j=0;j<used;++j){
            // b[j] < x, so the quotient of b[j]*2^64 + carry by x fits in 64 bits.
            const unsigned __int128 cur=
                (static_cast<unsigned __int128>(b[j])<<64) | carry;
            b[j]=static_cast<u64>(cur%x);
            carry=static_cast<u64>(cur/x);
        }
        while(carry!=0){
            assert(used<bz);             // caller-supplied buffer too small
            b[used]=carry%x;
            carry/=x;
            ++used;
        }
    }
}
/*
void v_to_chars(u64 ,){

}*/

}//namespace mbase

using namespace mbase;

// Formats i as lowercase hexadecimal, zero-padded to at least w characters,
// and returns the characters in reverse order (least significant digit
// first).
template< typename T >
std::string int_to_hex( T i ,int w) {
    std::stringstream out;
    out << std::hex << std::setfill('0') << std::setw(w) << i;
    const std::string digits=out.str();
    return std::string(digits.rbegin(),digits.rend());
}

// Benchmark / cross-check driver.
// Reads pairs "k1 k2" (operand sizes in limbs) from stdin until input ends,
// multiplies two random operands of those sizes with Karatsuba (only when
// k1+k2 < 1e6) and with the NTT, prints the average clock() ticks per
// multiplication, and reports "mis 23" when the two products differ.
int main() {
    default_random_engine generator;
    uniform_int_distribution<uint64_t> distribution(0ULL,
            ULLONG_MAX);
    int k1,k2;
    // Stop on EOF or unparsable input instead of spinning forever on stale,
    // uninitialised sizes.
    while(cin>>k1>>k2) {
        // Non-positive sizes would divide by zero below and reach new[].
        if(k1<=0 || k2<=0) continue;
        const long long total=static_cast<long long>(k1)+k2;
        int cs=static_cast<int>(1e6/total);   // repetitions per timing
        if(cs==0) cs=1;

        // a keeps 88 zero limbs of padding above its k1 random limbs.
        vector<u64> a(static_cast<size_t>(k1)+88), b(k2);
        for(int i=0; i<k1; ++i) a[i] = distribution(generator);
        for(int i=0; i<k2; ++i) b[i] = distribution(generator);

        const bool run_karatsuba=(total<1e6);
        u64* c2=nullptr;
        u64* c3=nullptr;
        size_t cs2=0,cs3=0;

        double t=clock();
        if(run_karatsuba) {
            for(int i=0; i<cs; ++i) {
                // v_mult allocates a fresh result on every call; release the
                // previous one so only the last product is kept.
                delete[] c2;
                c2=nullptr;
                v_mult(a.data(),k1,b.data(),k2,c2,cs2,2);
            }
            cout<<"m2  "<<(clock()-t)/cs<<endl;
        }

        t=clock();
        for(int i=0; i<cs; ++i) {
            delete[] c3;
            c3=nullptr;
            v_mult(a.data(),k1,b.data(),k2,c3,cs3,3);
        }
        cout<<"m3  "<<(clock()-t)/cs<<endl;

        if(run_karatsuba && my_comp(c2,c3,cs2)) {
            for(size_t i=0; i<cs2; ++i) cout<<c2[i]<<" ";
            cout<<endl;
            for(size_t i=0; i<cs2; ++i) cout<<c3[i]<<" ";
            cout<<endl;
            cout<<"mis 23"<<endl;
        }
        delete[] c2;
        delete[] c3;
        cout<<"pass"<<endl;
    }
}

 

posted @ 2022-07-09 16:39  141421356  阅读(41)  评论(0)    收藏  举报