TechFUL: スーパー転倒数

問題概要

長さ  N の順列  A, B が与えられる.

次の 2 つの条件を満たす  (i, j) の組の個数を求めよ.

  •  A \lbrack i \rbrack \gt A \lbrack j \rbrack
  •  B \lbrack i \rbrack \lt B \lbrack j \rbrack

問題のリンク

制約

  •  1 \lt N \leq 50000
  •  1 \lt A \lbrack i \rbrack, B \lbrack i \rbrack \leq N
  •  A, B は長さ  N の順列

解法

数列の転倒数は、その数列をマージソートする過程で計算する.

そのため、スーパー転倒数も数列をマージソートする過程で計算できないか考えてみる.

マージソートの実装や性質については drken さんの記事が参考になる.

qiita.com

スーパー転倒数の計算

 A, B をそれぞれ2つの数列  A_{1}, A_{2}, B_{1}, B_{2} に分割する.

数列  A, B のスーパー転倒数には次の3種類のスーパー転倒が起こる.

  •  A_{1}, B_{1} の要素のみで起こるスーパー転倒

  •  A_{2}, B_{2} の要素のみで起こるスーパー転倒

  •  A_{1}, B_{1} A_{2}, B_{2} の要素間で起こるスーパー転倒

1つ目ケースは、 A_{1}, B_{1} を再び分割して再帰的に計算することができる (2つ目のケースも同様).

3つ目のケースを考える.  A_{1}, A_{2} それぞれをソートした数列を  A_{1}', A_{2}' とする. また、 B_{1}, B_{2} A_{1}', A_{2}' の順に対応するように並び替えた数列を  B_{1}', B_{2}' とする.  A_{1}, B_{1} A_{2}, B_{2} の要素間で起こる転倒数は  A_{1}' \lbrack i \rbrack \gt A_{2}' \lbrack j \rbrack かつ  B_{1}' \lbrack i \rbrack \lt B_{2}' \lbrack j \rbrack となる  (i, j) の組の個数に等しい.

これは、次のようなアルゴリズムで計算できる.

  1.  i = 0, j = 0, num = 0, v = \min\lbrace A_{1}'\lbrack i\rbrack, A_{2}'\lbrack j\rbrack \rbrace に設定する.
     A_{1}' \lbrack |A_{1}'| \rbrack = \infty, A_{2}' \lbrack |A_{2}'| \rbrack = \infty とする.
     i = |A_{1}'|, j = |A_{2}'| になったら 5. に移動する.
  2.  v = A_{1}\lbrack i\rbrack のとき、
     i 1 を追加して 2. に戻る.
  3.  v = A_{2}\lbrack j\rbrack' のとき、
     num B_{1}' \lbrack i \rbrack, B_{1}' \lbrack i+1 \rbrack, \ldots, B_{1}' \lbrack |B_{1}'|-1 \rbrack の中で  B_{2}' \lbrack j \rbrack より小さい要素の個数を追加する.
     j 1 を追加して 2. に戻る.
  4.  num A_{1}', B_{1}' A_{2}', B_{2}' の要素間で起こるスーパー転倒数になっている.

操作 4. の  B_{1}' \lbrack i \rbrack, B_{1}' \lbrack i+1 \rbrack, \ldots, B_{1}' \lbrack |B_{1}'|-1 \rbrack B_{2}' \lbrack j \rbrack より小さい要素の個数の計算は BIT や セグ木を使用して  O(log |B_{1}'|) 時間で計算できる.

計算例

例として、本問のサンプルケース2を扱う.  A = (7, 4, 3, 2, 6, 1, 5), B = (2, 1, 4, 3, 6, 5, 7) である.

 A, B をそれぞれ2つの数列  A_{1} = (7, 4, 3), A_{2} = (2, 6, 1, 5), B_{1} = (2, 1, 4), B_{2} = (3, 6, 5, 7) に分割する.

先程の例で確認してみる. A_{1}' = (3, 4, 7), B_{1}' = (4, 1, 2), A_{2}' = (1, 2, 6, 5), B_{2}' = (5, 3, 7, 6) である.

  •  i = 0, j = 0, num = 0

  •  A_{1}' \lbrack i \rbrack = A_{1}' \lbrack 0 \rbrack = 3, A_{2}' \lbrack j \rbrack = A_{2}' \lbrack 0 \rbrack = 1 より  A_{1}' \lbrack i \rbrack \gt A_{2}' \lbrack j \rbrack

    •  i = 0
    •  B_{1}' \lbrack i:2 \rbrack = B_{1}' \lbrack 0:2 \rbrack = (4, 1, 2) B_{2}' \lbrack j \rbrack = B_{2}' \lbrack 0 \rbrack = 5 より小さい要素の個数は 3 より、 num = 0 + 3 = 3
    •  j = 1
  •  A_{1}' \lbrack i \rbrack = A_{1}' \lbrack 0 \rbrack = 3, A_{2}' \lbrack j \rbrack = A_{2}' \lbrack 1 \rbrack = 2 より  A_{1}' \lbrack i \rbrack \gt A_{2}' \lbrack j \rbrack

    •  i = 0
    •  B_{1}' \lbrack i:2 \rbrack = B_{1}' \lbrack 0:2 \rbrack = (4, 1, 2) B_{2}' \lbrack j \rbrack = B_{2}' \lbrack 1 \rbrack = 3 より小さい要素の個数は 2 より、 num = 3 + 2 = 5
    •  j = 2
  •  A_{1}' \lbrack i \rbrack = A_{1}' \lbrack 0 \rbrack = 3, A_{2}' \lbrack j \rbrack = A_{2}' \lbrack 2 \rbrack = 6 より  A_{1}' \lbrack i \rbrack \lt A_{2}' \lbrack j \rbrack

    •  i = 1, j = 2, num = 5
  •  A_{1}' \lbrack i \rbrack = A_{1}' \lbrack 1 \rbrack = 4, A_{2}' \lbrack j \rbrack = A_{2}' \lbrack 2 \rbrack = 6 より  A_{1}' \lbrack i \rbrack \lt A_{2}' \lbrack j \rbrack

    •  i = 2, j = 2, num = 5
  •  A_{1}' \lbrack i \rbrack = A_{1}' \lbrack 2 \rbrack = 7, A_{2}' \lbrack j \rbrack = A_{2}' \lbrack 2 \rbrack = 6 より  A_{1}' \lbrack i \rbrack \gt A_{2}' \lbrack j \rbrack

    •  i = 2
    •  B_{1}' \lbrack i:2 \rbrack = B_{1}' \lbrack 2:2 \rbrack = (2) B_{2}' \lbrack j \rbrack = B_{2}' \lbrack 2 \rbrack = 7 より小さい要素の個数は 1 より、 num = 5 + 1 = 6
    •  j = 3
  •  A_{1}' \lbrack i \rbrack = A_{1}' \lbrack 2 \rbrack = 7, A_{2}' \lbrack j \rbrack = A_{2}' \lbrack 3 \rbrack = 5 より  A_{1}' \lbrack i \rbrack \gt A_{2}' \lbrack j \rbrack

    •  i = 2
    •  B_{1}' \lbrack i:2 \rbrack = B_{1}' \lbrack 2:2 \rbrack = (2) B_{2}' \lbrack j \rbrack = B_{2}' \lbrack 3 \rbrack = 6 より小さい要素の個数は 1 より、 num = 6 + 1 = 7
    •  j = 4
  •  A_{1}' \lbrack i \rbrack = A_{1}' \lbrack 2 \rbrack = 7, A_{2}' \lbrack j \rbrack = A_{2}' \lbrack 4 \rbrack = ∞ より  A_{1}' \lbrack i \rbrack \lt A_{2}' \lbrack j \rbrack

    •  i = 3, j = 4, num = 7
  •  i = 3, j = 4 より終了

以上より、 A_{1}', B_{1}' A_{2}', B_{2}' の要素間のスーパー転倒数は 7 となる.

これに、 A_{1}', B_{1}' の要素のみのスーパー転倒数と  A_{2}', B_{2}' の要素のみのスーパー転倒数を加えることで、 A, B のスーパー転倒数が求まる.

このアルゴリズム O(N \log^{2} N) 時間で計算できる.

#include <bits/stdc++.h>
using namespace std;

// Binary Indexed Tree
template <class T> struct BIT {
    vector<T> d;

    BIT(int n) { init(n); }
    void init(int n) { d.assign(n + 1, 0); }

    // a: 1-indexed
    inline void add(int a, T x) {
        for (int i = a; i < (int)d.size(); i += i & -i) {
            d[i] = d[i] + x;
        }
    }

    // [1, a]
    // a: 1-indexed
    inline T sum(int a) {
        T res = 0;
        for (int i = a; i > 0; i -= i & -i) {
            res = res + d[i];
        }
        return res;
    }

    // [a, b)
    // a, b: 1-indexed
    inline T sum(int a, int b) {
        return sum(b - 1) - sum(a - 1);
    }
};

// 昇順にソートされた数列 v において、v[i] >= x となる最小の i を返す.
int LBP(vector<int>& v, int x) {
    return lower_bound(v.begin(), v.end(), x) - v.begin();
}

void merge_sort(vector<int>& a, vector<int>& b, int& num) {
    const int n = a.size();
    if (n == 1) return;
    vector<int> a1, a2, b1, b2;
    for (int i = 0; i < n / 2; ++i) {
        a1.push_back(a[i]);
        b1.push_back(b[i]);
    }
    for (int i = n / 2; i < n; ++i) {
        a2.push_back(a[i]);
        b2.push_back(b[i]);
    }

    vector<int> b1_sorted = b1;
    sort(b1_sorted.begin(), b1_sorted.end());

    merge_sort(a1, b1, num);
    merge_sort(a2, b2, num);

    BIT<int> bit(b1.size());
    for (int bb : b1) {
        int pos = LBP(b1_sorted, bb) + 1;
        bit.add(pos, 1);
    }

    int i = 0, j = 0;
    for (int k = 0; k < n; ++k) {
        if (i != a1.size() && (j == a2.size() || a1[i] < a2[j])) {
            a[k] = a1[i];
            b[k] = b1[i];
            int pos = LBP(b1_sorted, b1[i]) + 1;
            bit.add(pos, -1);
            i++;
        } else {
            int pos = LBP(b1_sorted, b2[j]);
            num += bit.sum(pos);
            a[k] = a2[j];
            b[k] = b2[j];
            j++;
        }
    }
}

int main() {
    int N;
    cin >> N;
    vector<int> a(N), b(N);
    for (int i = 0; i < N; ++i) cin >> a[i];
    for (int i = 0; i < N; ++i) cin >> b[i];
    int ans = 0;
    merge_sort(a, b, ans);
    cout << ans << endl;
}