LG10704

看到 \(\sum \limits_{i=1}^n a_i \le 10^9\),可以很快发现不同的 \(a_i\) 大约只有 \(42000\) 个。设去重后的 \(a_i\)\(k\) 个,则不难得到一个 \(O(k^2)\) 的做法,可以获得 \(70\) 分。

for( int i = 1 ; i <= n ; i ++ )
{
	if( a[i] != a[i - 1] ) b[++ tot] = a[i],cnt[tot] = 1;
	else cnt[tot] ++;
}
for( int i = 1 ; i <= tot ; i ++ )
{
	s = ( s + 1ll * ( m / ( b[i] * b[i] ) ) * cnt[i] % MOD * cnt[i] ) % MOD;
	for( int j = i + 1 ; j <= tot ; j ++ )
	{
		if( b[i] * b[j] > m ) break;
		ans = ( ans + 1ll * ( m / ( b[i] * b[j] ) ) * cnt[i] % MOD * cnt[j] ) % MOD;
	}
}

然而面对大约 \(42000\) 的数据,这个算法仍需改进。注意到题目中有一个向下取整的符号,这暗示着虽然 \(a_i \times a_j\) 不同,但最终得到的值却可能是一样的。我们就从这一点下手优化。

设去重后的数组为 \(b\),并且 \(b\) 已经排序。对于每个 \(b_i\),首先确定 \(\left \lfloor \frac{m}{b_ib_j} \right \rfloor\) 相同时 \(j\) 的取值范围。显然,相同时 \(j\) 是连续的,因此可以直接二分求得左右边界即可。再通过前缀和维护这段区间内的数的总个数,即可求得该段的贡献。以此类推。

时间复杂度大约为 \(O(k \sqrt m \log k )\),实际上完全跑不满,因此能够通过。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <map>
#define int long long
#define MOD 998244353
//#define ll long long

using namespace std;

int n;
int a[1000001],b[1000001],tot,cnt[1000001],mx,s,ans,m,S[1000001],L,R,nw;

inline int read()
{
	int x=0,f=1;char ch=getchar();
	while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
	while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
	return x*f;
}

int ck( int x , int ti , int Ll )
{
	int l = Ll,r = tot,mid,Rr = Ll;
	while( l < r )
	{
		mid = ( l + r ) / 2 + 1;
		if( ( m / ( b[x] * b[mid] ) ) < ti ) r = mid - 1;
		else l = mid,Rr = mid;
	}
	return Rr;
}

signed main()
{
	n = read(),m = read();
	for( int i = 1 ; i <= n ; i ++ )
		a[i] = read();
	sort( a + 1 , a + n + 1 );
	for( int i = 1 ; i <= n ; i ++ )
	{
		if( a[i] != a[i - 1] ) b[++ tot] = a[i],cnt[tot] = 1;
		else cnt[tot] ++;
	}
	for( int i = 1 ; i <= tot ; i ++ )
		S[i] = S[i - 1] + cnt[i];
	for( int i = 1 ; i <= tot ; i ++ )
	{
		s = ( s + 1ll * ( m / ( b[i] * b[i] ) ) * cnt[i] % MOD * cnt[i] ) % MOD;
		if( i == tot ) continue;
		L = i + 1,nw = ( m / ( b[i] * b[L] ) );
		while( L <= tot )
		{
			R = ck( i , nw , L );
			ans = ( ans + nw * cnt[i] % MOD * ( S[R] - S[L - 1] ) ) % MOD;
			L = R + 1;
			if( L > tot ) break;
			nw = m / ( b[i] * b[L] ); 
		}
	}
	cout << ( ans * 2 + s ) % MOD;
	return 0;
}
posted @ 2025-09-08 18:31  FormulaOne  阅读(10)  评论(0)    收藏  举报