TechFUL: スーパー転倒数
問題概要
長さ の順列 が与えられる.
次の 2 つの条件を満たす の組の個数を求めよ.
制約
- は長さ の順列
解法
数列の転倒数は、その数列をマージソートする過程で計算する.
そのため、スーパー転倒数も数列をマージソートする過程で計算できないか考えてみる.
マージソートの実装や性質については drken さんの記事が参考になる.
スーパー転倒数の計算
をそれぞれ2つの数列 に分割する.
数列 のスーパー転倒数には次の3種類のスーパー転倒が起こる.
の要素のみで起こるスーパー転倒
の要素のみで起こるスーパー転倒
と の要素間で起こるスーパー転倒
1つ目ケースは、 を再び分割して再帰的に計算することができる (2つ目のケースも同様).
3つ目のケースを考える. それぞれをソートした数列を とする. また、 を の順に対応するように並び替えた数列を とする. と の要素間で起こる転倒数は かつ となる の組の個数に等しい.
これは、次のようなアルゴリズムで計算できる.
- に設定する.
とする.
になったら 5. に移動する. - のとき、
に を追加して 2. に戻る. - のとき、
に の中で より小さい要素の個数を追加する.
に を追加して 2. に戻る. - は と の要素間で起こるスーパー転倒数になっている.
操作 4. の で より小さい要素の個数の計算は BIT や セグ木を使用して 時間で計算できる.
計算例
例として、本問のサンプルケース2を扱う. である.
をそれぞれ2つの数列 に分割する.
先程の例で確認してみる. である.
より
- で より小さい要素の個数は 3 より、
より
- で より小さい要素の個数は 2 より、
より
より
より
- で より小さい要素の個数は 1 より、
より
- で より小さい要素の個数は 1 より、
より
より終了
以上より、 と の要素間のスーパー転倒数は 7 となる.
これに、 の要素のみのスーパー転倒数と の要素のみのスーパー転倒数を加えることで、 のスーパー転倒数が求まる.
このアルゴリズムは 時間で計算できる.
#include <bits/stdc++.h> using namespace std; // Binary Indexed Tree template <class T> struct BIT { vector<T> d; BIT(int n) { init(n); } void init(int n) { d.assign(n + 1, 0); } // a: 1-indexed inline void add(int a, T x) { for (int i = a; i < (int)d.size(); i += i & -i) { d[i] = d[i] + x; } } // [1, a] // a: 1-indexed inline T sum(int a) { T res = 0; for (int i = a; i > 0; i -= i & -i) { res = res + d[i]; } return res; } // [a, b) // a, b: 1-indexed inline T sum(int a, int b) { return sum(b - 1) - sum(a - 1); } }; // 昇順にソートされた数列 v において、v[i] >= x となる最小の i を返す. int LBP(vector<int>& v, int x) { return lower_bound(v.begin(), v.end(), x) - v.begin(); } void merge_sort(vector<int>& a, vector<int>& b, int& num) { const int n = a.size(); if (n == 1) return; vector<int> a1, a2, b1, b2; for (int i = 0; i < n / 2; ++i) { a1.push_back(a[i]); b1.push_back(b[i]); } for (int i = n / 2; i < n; ++i) { a2.push_back(a[i]); b2.push_back(b[i]); } vector<int> b1_sorted = b1; sort(b1_sorted.begin(), b1_sorted.end()); merge_sort(a1, b1, num); merge_sort(a2, b2, num); BIT<int> bit(b1.size()); for (int bb : b1) { int pos = LBP(b1_sorted, bb) + 1; bit.add(pos, 1); } int i = 0, j = 0; for (int k = 0; k < n; ++k) { if (i != a1.size() && (j == a2.size() || a1[i] < a2[j])) { a[k] = a1[i]; b[k] = b1[i]; int pos = LBP(b1_sorted, b1[i]) + 1; bit.add(pos, -1); i++; } else { int pos = LBP(b1_sorted, b2[j]); num += bit.sum(pos); a[k] = a2[j]; b[k] = b2[j]; j++; } } } int main() { int N; cin >> N; vector<int> a(N), b(N); for (int i = 0; i < N; ++i) cin >> a[i]; for (int i = 0; i < N; ++i) cin >> b[i]; int ans = 0; merge_sort(a, b, ans); cout << ans << endl; }