【poj 3261】Milk Patterns 后缀数组

Milk Patterns

题意

给出n个数字,以及一个k,求至少出现k次的最长子序列的长度

思路

和poj 1743思路差不多,二分长度,把后缀分成若干组,每组任意后缀公共前缀都>=当前二分的长度。统计是否有某个组后缀数量>=k,如果有当前长度就可以。

代码

// #include <bits/stdc++.h>
#include <stdio.h>
#include <algorithm>
#include <string.h>
#define pb push_back
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int N = 1e6 + 10;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;

int sa[N], rk[N], oldrk[N], cnt[N], pos[N], ht[N], n, m, x;
int arr[N];
bool cmp(int a, int b, int k)
{
    return oldrk[a] == oldrk[b] && oldrk[a + k] == oldrk[b + k];
}
void getsa()
{
    memset(cnt, 0, sizeof(cnt));
    m = 1000010;
    for (int i = 1; i <= n; i++)
        ++cnt[rk[i] = (arr[i] + 1)];
    for (int i = 1; i <= m; i++)
        cnt[i] += cnt[i - 1];
    for (int i = n; i; i--)
        sa[cnt[rk[i]]--] = i;
    for (int k = 1; k <= n; k <<= 1)
    {
        int num = 0;
        for (int i = n - k + 1; i <= n; i++)
            pos[++num] = i;
        for (int i = 1; i <= n; i++)
        {
            if (sa[i] > k)
                pos[++num] = sa[i] - k;
        }
        memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; i++)
            ++cnt[rk[i]];
        for (int i = 1; i <= m; i++)
            cnt[i] += cnt[i - 1];
        for (int i = n; i; i--)
            sa[cnt[rk[pos[i]]]--] = pos[i];
        memcpy(oldrk, rk, sizeof(rk));
        num = 0;
        for (int i = 1; i <= n; i++)
            rk[sa[i]] = cmp(sa[i], sa[i - 1], k) ? num : ++num;
        if (num == n)
            break;
        m = num;
    }
    for (int i = 1; i <= n; i++)
        rk[sa[i]] = i;
    int k = 0;
    for (int i = 1; i <= n; i++)
    {
        if (k)
            --k;
        while (arr[i + k] == arr[sa[rk[i] - 1] + k])
            ++k;
        ht[rk[i]] = k;
    }
}
int judge(int len)
{
    int num = 1;
    for (int i = 2; i <= n; i++)
    {
        if (ht[i] >= len)
            ++num;
        else
            num = 1;
        if (num >= x)
            return 1;
    }
    return 0;
}
int main()
{
    while (~scanf("%d%d", &n, &x))
    {
        for (int i = 1; i <= n; i++)
            scanf("%d", &arr[i]);
        getsa();
        int l = 0, r = n, ans = 0;
        while (l <= r)
        {
            int mid = (l + r) / 2;
            if (judge(mid))
            {
                ans = mid;
                l = mid + 1;
            }
            else
                r = mid - 1;
        }
        printf("%d\n", ans);
    }
    return 0;
}
posted @ 2020-05-12 08:47  Valk3  阅读(115)  评论(0编辑  收藏  举报