[CodeForces 466C] Number of Ways

Given an array of n integers, count the number of ways to split it into 3 contiguous parts with equal sums. Each part must be non-empty.
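To pin down exactly what is being counted, here is a minimal brute-force sketch (not from the original post) that tries every pair of cut points. It takes O(n^2) time, so it is only meant for small inputs or for cross-checking a faster solution. The class name `BruteForceWays` and method `count` are hypothetical.

```java
import java.util.Arrays;

// Hypothetical brute-force reference: enumerates every pair of cut points.
class BruteForceWays {
    // Counts pairs (j, k) with 0 <= j < k <= n - 2 such that
    // a[0..j], a[j+1..k] and a[k+1..n-1] all have the same sum.
    static long count(int[] a) {
        int n = a.length;
        long[] prefix = new long[n + 1]; // prefix[t] = a[0] + ... + a[t - 1]
        for (int t = 0; t < n; t++) {
            prefix[t + 1] = prefix[t] + a[t];
        }
        long total = prefix[n];
        long ways = 0;
        for (int j = 0; j <= n - 3; j++) {          // last index of part 1
            for (int k = j + 1; k <= n - 2; k++) {  // last index of part 2
                long first = prefix[j + 1];
                long second = prefix[k + 1] - prefix[j + 1];
                long third = total - prefix[k + 1];
                if (first == second && second == third) {
                    ways++;
                }
            }
        }
        return ways;
    }

    public static void main(String[] args) {
        int[][] samples = { {1, 2, 3, 0, 3}, {0, 1, -1, 0}, {4, 1} };
        for (int[] s : samples) {
            System.out.println(Arrays.toString(s) + " -> " + count(s));
        }
    }
}
```

For [1, 2, 3, 0, 3] the two valid splits are [1, 2] | [3] | [0, 3] and [1, 2] | [3, 0] | [3].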


Algorithm: O(N) runtime

1. If the total sum is not divisible by 3 (total sum % 3 != 0), return 0.

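2. Keep a count of the prefix sums equal to 1/3 of the total sum, denoted psOneThird. Each such prefix is a candidate 1st part, so when a[i] is examined in step 3(a), psOneThird must only include prefixes that end before a[i].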

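3. Start with currSum = a[0], and set psOneThird to 1 if a[0] equals 1/3 of the total sum (otherwise 0). Then visit the elements from index 1 to n - 2, add each newly visited element a[i] to currSum, and do the following: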

(a). If currSum equals 2/3 of the total sum, the unvisited elements a[i+1..n-1] sum to 1/3 of the total, and they form a non-empty 3rd part because i ≤ n - 2. So a[i] can be the last element of the 2nd part. Every earlier prefix a[0..j] (j < i) that sums to 1/3 of the total can serve as the 1st part, and the middle part a[j+1..i] then sums to 2/3 - 1/3 = 1/3 as well. There are exactly psOneThird such prefixes, so add psOneThird to the final answer.

(b). If currSum equals 1/3 of the total sum, increment psOneThird: the prefix a[0..i] can be the 1st part for any 2nd part that ends after a[i].
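Steps 1 to 3 can be packaged as a pure function that takes the array and returns the count, which is easier to test than a method that reads and writes streams. This is a sketch under that assumption; the class name `NumberOfWays` and method `countWays` are hypothetical. It also returns 0 early when n < 3, a case the original loop bounds handle implicitly.

```java
// Hypothetical standalone version of steps 1-3: O(n) time, O(1) extra space.
class NumberOfWays {
    static long countWays(int[] a) {
        int n = a.length;
        long total = 0;
        for (int x : a) {
            total += x;
        }
        // Step 1: fewer than 3 elements, or a total not divisible by 3, means no valid split.
        if (n < 3 || total % 3 != 0) {
            return 0;
        }
        long third = total / 3;
        long currSum = a[0];                          // prefix sum a[0] + ... + a[i]
        long psOneThird = (currSum == third) ? 1 : 0; // step 2: candidate 1st parts seen so far
        long ans = 0;
        // Step 3: a[i] is a candidate last element of the 2nd part, 1 <= i <= n - 2.
        for (int i = 1; i <= n - 2; i++) {
            currSum += a[i];
            if (currSum == 2 * third) { // (a) check 2/3 first
                ans += psOneThird;
            }
            if (currSum == third) {     // (b) then record a new 1/3 prefix
                psOneThird++;
            }
        }
        return ans;
    }

    public static void main(String[] args) {
        System.out.println(countWays(new int[]{1, 2, 3, 0, 3})); // 2
        System.out.println(countWays(new int[]{0, 0, 0, 0}));    // 3
    }
}
```

For [1, 2, 3, 0, 3] the total is 9, psOneThird becomes 1 at i = 1 (prefix [1, 2]), and the prefix sum hits 6 at i = 2 and i = 3, adding 1 each time.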


Why does the above algorithm work? 

In every valid split, the 2nd part ends on one of the elements a[1] to a[n - 2]: it cannot end on a[0] because the 1st part must be non-empty, and it cannot end on a[n - 1] because the 3rd part must be non-empty. Splits whose 2nd parts end on different elements are different splits, so we can count the splits for each possible end a[i] separately and add up the counts. A split whose 2nd part ends on a[i] can exist only if the prefix a[0..i] sums to 2/3 of the total, which is exactly the check in step 3(a). Once the end of the 2nd part is fixed at a[i], a split is determined by where the 1st part ends, and any j < i whose prefix a[0..j] sums to 1/3 of the total works. That number is exactly psOneThird at that moment, because psOneThird has only counted prefixes that end before a[i].
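The counting argument above can be summarized in one formula. The notation is introduced only for this summary: $S$ is the total sum, $P_i = a[0] + \dots + a[i]$ is the prefix sum ending at a[i], and $[\,\cdot\,]$ is 1 when the condition inside holds and 0 otherwise.

```latex
\text{answer} \;=\; \sum_{i=1}^{n-2} \Bigl[\, P_i = \tfrac{2S}{3} \,\Bigr] \cdot \Bigl|\, \{\, j \;:\; 0 \le j < i,\ P_j = \tfrac{S}{3} \,\} \,\Bigr|
```

The set on the right has exactly psOneThird elements at the moment a[i] is visited, which is why the loop adds psOneThird whenever $P_i = 2S/3$.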


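One key note: for each a[i] we must first check whether the current sum is 2/3 of the total sum, and only then whether it is 1/3. This matters in the corner case where the total sum is 0: then 1/3 and 2/3 of the total are both 0, so a single prefix can pass both checks. Checking 2/3 first ensures that the prefix ending at a[i] is not counted as a 1st part for a 2nd part that also ends at a[i], which would leave the 2nd part empty.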
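A small sketch makes the corner case concrete by running the loop with both check orders. The class name `CheckOrderDemo` and the `twoThirdsFirst` flag are hypothetical and exist only for this comparison. For [0, 0, 0, 0] the valid splits correspond to the pairs (0, 1), (0, 2) and (1, 2), where each pair gives the last indices of the 1st and 2nd parts, so the correct answer is 3.

```java
// Hypothetical comparison of the two check orders described in the note above.
class CheckOrderDemo {
    static long count(int[] a, boolean twoThirdsFirst) {
        long total = 0;
        for (int x : a) {
            total += x;
        }
        if (a.length < 3 || total % 3 != 0) {
            return 0;
        }
        long third = total / 3, currSum = a[0], ans = 0;
        long psOneThird = (currSum == third) ? 1 : 0;
        for (int i = 1; i <= a.length - 2; i++) {
            currSum += a[i];
            if (twoThirdsFirst) {
                if (currSum == 2 * third) ans += psOneThird;
                if (currSum == third) psOneThird++;
            } else {
                // Wrong order: a[0..i] is counted as a 1st part before it is paired
                // with a 2nd part that also ends at a[i] (an empty middle part).
                if (currSum == third) psOneThird++;
                if (currSum == 2 * third) ans += psOneThird;
            }
        }
        return ans;
    }

    public static void main(String[] args) {
        int[] zeros = {0, 0, 0, 0};
        System.out.println("2/3 first: " + count(zeros, true));  // 3
        System.out.println("1/3 first: " + count(zeros, false)); // 5 (overcounts)
    }
}
```

When the total is non-zero, 1/3 and 2/3 of it differ, so the two orders give the same answer; for example, both return 2 for [1, 2, 3, 0, 3].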


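    // FastScanner is an input helper that is not shown in this post; q is unused here.
    private static void solve(int q, FastScanner in, PrintWriter out) {
        int n = in.nextInt();
        int[] a = new int[n];
        long sum = 0; // long: the total of n ints can overflow int
        for(int i = 0; i < n; i++) {
            a[i] = in.nextInt();
            sum += a[i];
        }
        if(sum % 3 != 0) {
            // Step 1: three equal parts are impossible.
            out.println(0);
        }
        else {
            // currSum is the prefix sum a[0] + ... + a[i]; ans can grow to about n^2 / 2, so it is a long.
            long currSum = a[0], oneThirdSum = sum / 3, ans = 0;
            // Number of prefixes a[0..j], j < i, that sum to 1/3 of the total (candidate 1st parts).
            int psOneThird = (a[0] == oneThirdSum ? 1 : 0);
            // a[i] is a candidate last element of the 2nd part, so i runs from 1 to n - 2.
            for(int i = 1; i < n - 1; i++) {
                currSum += a[i];
                // Step 3(a): check 2/3 before 1/3 so a zero total never pairs a prefix with itself.
                if(currSum == oneThirdSum * 2) {
                    ans += psOneThird;
                }
                // Step 3(b): a[0..i] becomes a candidate 1st part for later 2nd parts.
                if(currSum == oneThirdSum) {
                    psOneThird++;
                }
            }
            out.println(ans);
        }
        out.close();
    }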
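One way to gain confidence in the O(n) loop, especially around the zero-sum corner case, is to compare it against brute force on many small random arrays. The sketch below does that; the class name `StressTest`, the value range [-2, 2], the array lengths and the trial count are arbitrary choices, and `fast` restates the loop from `solve` without the input/output handling.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical stress test: O(n) counting loop vs. brute force over all cut points.
class StressTest {
    // Same loop as solve(), minus the input/output handling.
    static long fast(int[] a) {
        long total = 0;
        for (int x : a) total += x;
        if (a.length < 3 || total % 3 != 0) return 0;
        long third = total / 3, currSum = a[0], ans = 0;
        long psOneThird = (currSum == third) ? 1 : 0;
        for (int i = 1; i <= a.length - 2; i++) {
            currSum += a[i];
            if (currSum == 2 * third) ans += psOneThird;
            if (currSum == third) psOneThird++;
        }
        return ans;
    }

    // Tries every pair of cut points: part 1 ends at j, part 2 ends at k.
    static long brute(int[] a) {
        int n = a.length;
        long ways = 0;
        for (int j = 0; j <= n - 3; j++) {
            for (int k = j + 1; k <= n - 2; k++) {
                long s1 = 0, s2 = 0, s3 = 0;
                for (int t = 0; t <= j; t++) s1 += a[t];
                for (int t = j + 1; t <= k; t++) s2 += a[t];
                for (int t = k + 1; t < n; t++) s3 += a[t];
                if (s1 == s2 && s2 == s3) ways++;
            }
        }
        return ways;
    }

    // Returns the number of random arrays checked; throws on the first mismatch.
    static int run(long seed, int trials) {
        Random rng = new Random(seed);
        for (int trial = 0; trial < trials; trial++) {
            int n = 1 + rng.nextInt(10);
            int[] a = new int[n];
            // Small values so that equal-sum splits, including zero totals, are common.
            for (int i = 0; i < n; i++) a[i] = rng.nextInt(5) - 2;
            long f = fast(a), b = brute(a);
            if (f != b) {
                throw new IllegalStateException(
                    "mismatch on " + Arrays.toString(a) + ": " + f + " vs " + b);
            }
        }
        return trials;
    }

    public static void main(String[] args) {
        System.out.println("checked " + run(42L, 10_000) + " arrays");
    }
}
```

The brute force recomputes each part sum from scratch, which is slow but keeps it obviously correct; with arrays of at most 10 elements that cost does not matter.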
posted @ 2019-10-29 02:10  Review->Improve