类欧几里得
例题
https://atcoder.jp/contests/practice2/tasks/practice2_c
在\(O(\log (n+m+k+b))\)的时间复杂度求:
其中\(n,m,k,b\)都是整数。
类欧几里得
如果没有下取整函数,我们直接将里面的一次函数求前缀和(用等差数列转成一个二次函数)即可,这是简单的。
但由于有下取整函数,这个问题似乎难以解决。并且之前的套路也无法快速的处理下取整函数。
接下来要将一个神奇的做法——类欧几里得算法(这里只讲最简单的情况,其实这个算法还可以解决很多变式,打算以后开博客讲)
首先我们学过欧几里得算法:\(\gcd(x,y)=\gcd(y,x\%y)\),但是其实,这个等式是建立在下面两个等式成立的基础上的:
类欧几里得也是这个思路:
设\(solve(k,b,m,n)=\sum_{i=0}^{n-1} \lfloor{\frac{ki+b}{m}}\rfloor\),尝试进行递归。
对于:\(\sum_{i=0}^{n-1} \lfloor{\frac{ki+b}{m}}\rfloor\) ,如果\(k,b>m\),我们可以直接拆开,故:
于是:\(solve(k,b,m,n)=solve(k\%m,b\%m,m,n) + \frac{n(n-1)}{2} \lfloor\frac{k}{m}\rfloor + n\lfloor\frac{b}{m}\rfloor\),也就像欧几里得算法中第二个式子一样。
之后我们要解决\(k,b\leq m\)的情况,我们希望对\((k,b,m,n)\)进行一个变换,再递归下去。
此时我们设\(smallest(j)\)表示让\(\lfloor{\frac{ki+b}{m}}\rfloor=j\)的最小的一个\(i\)。并且设\(MAX=\lfloor{\frac{nk+b}{m}}\rfloor\),也就是在函数范围内能取到最大的数。特别的,我们设\(smallest(MAX+1)=n\)。
于是:
其实\(smallest(i+1)-smallest(i)\)就是一个前缀和,等价于等于\(i\)的\(\lfloor{\frac{ki+b}{m}}\rfloor=j\)的数量。
并且此时,注意到\(\sum_{i=0}^{MAX-1} smallest(i+1)\)若能表示成一个下取整的形式,那么我们就可以递归了!
\(smallest(i)\)还是很好求的,它等于\(\lceil \frac{mi-b}{k} \rceil\)。这个式子是上取整,这可不太妙,意味着我们不好转成一个递归的形式。但是大家都知道,上取整是可以转成下取整的,故:\(smallest(i)=\lceil \frac{mi-b}{k} \rceil=\lfloor \frac{mi-b+k-1}{k} \rfloor\)。
所以带入推出的式子中:
所以综上所述:
那么也就是\(solve(k,b,m,n)=n\cdot MAX-solve(m,m+k-b-1,k,MAX)\),也就是类似于欧几里得算法的第一个变换。
至此,我们在讨论边界情况\(n=0\)时,答案为\(0\),就可以愉快的递归了!
时间复杂度显然和欧几里得算法一样,是\(\log\)级别的。
代码还是很简单的:
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
ll solve(ll k,ll b,ll m,ll n) {
if(n==0) return 0;
if(k>=m||b>=m) {
return solve(k%m,b%m,m,n)+(n*(n-1)/2ll)*(k/m)+n*(b/m);
} else {
ll MAX=(k*(n-1)+b)/m;
return n*MAX-solve(m,m+k-b-1,k,MAX);
}
}
int main() {
int k,b,m,n,T;
scanf("%d",&T);
while(T--) {
std::cin>>n>>m>>k>>b;
std::cout<<solve(k,b,m,n)<<'\n';
}
return 0;
}
常用套路
我们在求解一些整除问题时可以使用这个算法。我们可能会遇到统计两个函数内没有某个数\(D\)的倍数的问题,此时我们可以二分找到两个函数差的绝对值\(\leq 1\)的左右两端,此时内部最多只能出现一个\(D\)的倍数。只需用类欧几里得求区间和,再相减即可。
比如下面的两道题:ARC111E和ARC123E。
ARC111E
我们设:
\(f(x)=bx+a,g(x)=cx+a\),由于题中说了\(b<c\),故\(f(x)\leq g(x)\)。
若\(g(x)-f(x)\geq d\),那么显然区间\([f(x),g(x)]\)内部至少有一个数是\(d\)的倍数。
我们二分出最大的点\(n\),满足\(g(x)-f(x)<d\)。(其实也可以\(O(1)\)算,不过有精度误差问题,并且最后类欧几里得时你还是要\(O(\log n)\)的计算)
那么对于\(i=0,1,...,n\),区间\([f(x),g(x)]\)内只能有一个\(d\)的倍数,要么没有。
我们直接求出\(cnt=\sum_{i=0}^{n-1} \lfloor \frac{g(x)}{d} \rfloor - \sum_{i=0}^{n-1} \lfloor \frac{f(x)-1}{d} \rfloor\),这样我们就算出区间内有多少个\(d\)的倍数了,然而每个点最多只能有一个\(d\)的倍数,所以\(cnt\)实际上就是多少个\(i\)使得区间内有\(d\)的倍数。我们直接将\((n+1)\)减去\(cnt\)即可算出区间内没有\(d\)的倍数的\(i\)的个数。
也就是\(answer=n+1-(\sum_{i=0}^{n-1} \lfloor \frac{g(x)}{d} \rfloor - \sum_{i=0}^{n-1} \lfloor \frac{f(x)-1}{d} \rfloor)\),显然这个求和式是类欧几里得板子。
注意题中是不算\(0\)的。如果\(a\mod d\)不为\(0\)时,我们会将\(0\)也统计进去,此时我们要将答案减一。
时间复杂度为\(\log(a+b+c)\)。
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
ll lgcd(ll k,ll b,ll m,ll n) {
if(n==0) return 0ll;
if(k>=m||b>=m) {
return lgcd(k%m,b%m,m,n)+(k/m)*(n*(n-1)/2ll)+(b/m)*n;
} else {
ll MAX=((n-1)*k+b)/m;
return n*MAX-lgcd(m,m+k-b-1,k,MAX);
}
}
void solve() {
ll a,b,c,d;
scanf("%lld%lld%lld%lld",&a,&b,&c,&d);
ll lef=0,rig=1e10,p=-1;
while(lef<=rig) {
ll mid=lef+rig>>1ll;
if((a+c*mid)-(a+b*mid)<d) {
p=mid;
lef=mid+1;
} else {
rig=mid-1;
}
}
ll ans=p+1-(lgcd(c,a,d,p+1)-lgcd(b,a-1,d,p+1));
if(a%d!=0) ans--;
printf("%lld\n",ans);
return;
}
int main() {
int T;
scanf("%d",&T);
while(T--) solve();
return 0;
}
ARC123E
其实这题和ARC111E是类似的,不过添加了另一侧的情况。
如果\(Bx<By\),我们交换\(Ax,Ay,Bx,By\)。否则我们设:
这题就是想让我们求\(F(x)=G(x)\)的\(1\leq x\leq n\)的个数。
那么还是同上题的套路,我们求出左端点\(l\)为满足\(f(x)-g(x)\leq 1\)的最小的点,右端点\(r\)为满足\(g(x)-f(x)\leq 1\)最大的点。并且\(1\leq l,r\leq n\)。
如果\(l,r\)不存在,说明没有这样的点,直接输出\(0\)。
为了以防万一,我还特判了\(l>r\),输出\(0\);以及当\(Bx=By\)时,若\(Ax=Ay\),输出\(n\),否则输出\(0\)。
不同于上一题,这题有左右两个部分,我们要找到\(f(x)\)和\(g(x)\)的交点\(p\),并分几种情况讨论:
第一种是\(l\leq p \leq r\),也就是下图:

同ARC111E,我们求出\((p-l+1)-\sum_{i=l}^{p} (F(i)-G(i)) + (r-p)-\sum_{i=p+1}^{r} (G(i)-F(i))\),就是答案。
还有一种是\(l,r\)都在\(p\)左侧,答案就是\((r-l+1)-\sum_{i=l}^{r} (F(i)-G(i))\)。
最后一种是\(l,r\)都在\(p\)右侧,答案就是\((r-l+1)-\sum_{i=l}^{r} (G(i)-F(i))\)。
实现时可以写一个函数,解决左侧和右侧的不同统计方式。(其实就是一侧\(F(i)>G(i)\),另一侧是\(F(i)<G(i)\))
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
ll n,k1,k2,b1,b2,m1,m2;
double f(ll x) {
return double(k1*x+b1)/(double)m1;
}
double g(ll x) {
return double(k2*x+b2)/(double)m2;
}
ll solve(ll k,ll b,ll m,ll n) {
if(n==0) return 0ll;
if(k>=m||b>=m) {
return solve(k%m,b%m,m,n)+(k/m)*(n*(n-1)/2ll)+(b/m)*n;
} else {
ll MAX=((n-1)*k+b)/m;
return n*MAX-solve(m,m+k-b-1,k,MAX);
}
}
ll get(ll l,ll r,ll type) {
if(type==1) {
return (r-l+1)-(solve(k1,b1,m1,r+1)-solve(k1,b1,m1,l))+(solve(k2,b2,m2,r+1)-solve(k2,b2,m2,l));
} else {
return (r-l+1)-(solve(k2,b2,m2,r+1)-solve(k2,b2,m2,l))+(solve(k1,b1,m1,r+1)-solve(k1,b1,m1,l));
}
}
void solve() {
ll a,b,c,d;
scanf("%lld%lld%lld%lld%lld",&n,&a,&b,&c,&d);
m1=b,k1=1,b1=a*b;
m2=d,k2=1,b2=c*d;
//f(x)=(k1*x+b1)/m1; (k1=k2=1)
//g(x)=(k2*x+b2)/m2;
if(m1<m2) {
std::swap(m1,m2);
std::swap(k1,k2);
std::swap(b1,b2);
} else if(m1==m2) {
if(b1==b2) printf("%lld\n",n);
else printf("0\n");
return;
}
ll l=-1,r=-1,lef,rig;
lef=1,rig=n;
while(lef<=rig) {
ll mid=lef+rig>>1;
if(g(mid)-f(mid)<=1) {
r=mid; lef=mid+1;
} else {
rig=mid-1;
}
}
lef=1,rig=n;
while(lef<=rig) {
ll mid=lef+rig>>1;
if(f(mid)-g(mid)<=1) {
l=mid; rig=mid-1;
} else {
lef=mid+1;
}
}
if(l==-1||r==-1||l>r) {
printf("0\n");
return;
}
ll p=(b2*m1-b1*m2)/(k1*m2-k2*m1),ans=0;
if(p<l) {
ans=get(l,r,2);
} else if(r<=p) {
ans=get(l,r,1);
} else {
ans=get(l,p,1)+get(p+1,r,2);
}
printf("%lld\n",ans);
return;
}
int main() {
int T;
scanf("%d",&T);
while(T--) solve();
return 0;
}
Luogu P433 ALADIN
首先我们观察到:
左半部分就是一个等差数列,求法是简单的,右边部分则是我们刚才提到的类欧几里得板子。
于是我们可以轻松计算一段区间修改了。考虑这题是多次修改区间,使用离散化+线段树就可以做到\(O(\log n)\)查询,加上类欧几里德就是两个\(\log\),是可以通过的。
注意线段树维护的是区间,并且要将区间拆成几个部分。比如我们将所有区间左右节点排序后,有:
这几个点,那么我们建立\(8\)个区间,分别是:
并用线段树维护这\(8\)个区间即可。
#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl
using ll=long long;
ll solve(ll k,ll b,ll m,ll n) {
//类欧几里德
if(n==0) return 0ll;
if(k>=m||b>=m) {
return solve(k%m,b%m,m,n)+(k/m)*(n*(n-1)/2ll)+(b/m)*n;
} else {
ll MAX=((n-1)*k+b)/m;
return n*MAX-solve(m,m+k-b-1,k,MAX);
}
}
const int maxn=300005;
int n,q;
std::vector<int> vec;
std::vector<std::pair<int,int>> seg;
struct node {
ll sum;
int real_left,real_right;
int lzy,lzy_A,lzy_B,lzy_begin;
}tree[maxn<<2];
void build(int pos,int lef,int rig) {
tree[pos].real_left=seg[lef-1].first,tree[pos].real_right=seg[rig-1].second;
tree[pos].sum=0; tree[pos].lzy=0;
if(lef!=rig) {
int mid=lef+rig>>1;
build(pos<<1,lef,mid);
build(pos<<1|1,mid+1,rig);
}
}
ll upd(int pos,int begin,int A,int B) {
ll first=tree[pos].real_left-begin+1,last=tree[pos].real_right-begin+1;
tree[pos].sum=(first+last)*(tree[pos].real_right-tree[pos].real_left+1)/2ll*(ll)A;
ll b_=(ll)(-begin)*A+A,k_=A,m_=B;
tree[pos].sum-=(solve(k_,b_,m_,tree[pos].real_right+1)-solve(k_,b_,m_,tree[pos].real_left))*(ll)B;
/////tagged//////
tree[pos].lzy=1,tree[pos].lzy_begin=begin,tree[pos].lzy_A=A,tree[pos].lzy_B=B;
}
void pushdown(int pos) {
if(tree[pos].lzy==0) return;
upd(pos<<1,tree[pos].lzy_begin,tree[pos].lzy_A,tree[pos].lzy_B);
upd(pos<<1|1,tree[pos].lzy_begin,tree[pos].lzy_A,tree[pos].lzy_B);
tree[pos].lzy=0;
}
void update(int l,int r,int A,int B,int pos,int lef,int rig) {
if(l<=lef&&rig<=r) {
upd(pos,seg[l-1].first,A,B);//注意这里是seg[l-1].first而不是vec[l-1]!!!
} else if(l<=rig&&r>=lef) {
pushdown(pos);
int mid=lef+rig>>1;
update(l,r,A,B,pos<<1,lef,mid);
update(l,r,A,B,pos<<1|1,mid+1,rig);
tree[pos].sum=tree[pos<<1].sum+tree[pos<<1|1].sum;
}
}
ll query(int l,int r,int pos,int lef,int rig) {
if(l<=lef&&rig<=r) {
return tree[pos].sum;
} else if(l<=rig&&r>=lef) {
pushdown(pos);
int mid=lef+rig>>1;
return query(l,r,pos<<1,lef,mid)+query(l,r,pos<<1|1,mid+1,rig);
}
return 0ll;
}
int opt[maxn],l[maxn],r[maxn],a[maxn],b[maxn];
int main() {
scanf("%d%d",&n,&q);
for(int i=1;i<=q;i++) {
scanf("%d%d%d",&opt[i],&l[i],&r[i]);
if(opt[i]==1) scanf("%d%d",&a[i],&b[i]);
vec.push_back(l[i]);
vec.push_back(r[i]);
}
std::sort(vec.begin(),vec.end());
vec.erase(std::unique(vec.begin(),vec.end()),vec.end());
for(int i=0;i<(int)vec.size();i++) {
seg.push_back({vec[i],vec[i]});
if(i&&vec[i]-vec[i-1]>=2) seg.push_back({vec[i-1]+1,vec[i]-1});
}
std::sort(seg.begin(),seg.end());
build(1,1,(int)seg.size());
for(int i=1;i<=q;i++) {
int from=std::lower_bound(seg.begin(),seg.end(),std::make_pair(l[i],l[i]))-seg.begin()+1;
int to=std::lower_bound(seg.begin(),seg.end(),std::make_pair(r[i],r[i]))-seg.begin()+1;
if(opt[i]==1) {
update(from,to,a[i],b[i],1,1,(int)seg.size());
} else {
printf("%lld\n",query(from,to,1,1,(int)seg.size()));
}
}
return 0;
}
浙公网安备 33010602011771号