ABC 216
Table of Contents
https://atcoder.jp/contests/abc216
E - Amusement Park
https://atcoder.jp/contests/abc216/tasks/abc216_e
貪欲解法
明らかに、楽しさが大きいアトラクションから乗るのが最適である。 愚直に priority queue などに数値を入れてシミュレーションをすると時間がかかりすぎて TLE になる。
void solve() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
ll n, k;
cin >> n >> k;
vll a(n);
rep(i, n) cin >> a[i];
map<ll, ll> mp;
for (ll x : a)
mp[x]++;
ll ans = 0;
while (mp.rbegin() != mp.rend() && k) {
auto it = mp.rbegin();
auto nit = next(it);
ll v1 = it->first;
ll n1 = it->second;
ll v2 = 0;
if (nit != mp.rend())
v2 = nit->first;
if ((v1 - v2) * n1 <= k) {
k -= (v1 - v2) * n1;
ans += cal(v2, v1) * n1;
mp.erase(v1);
if (v2)
mp[v2] += n1;
} else {
ll q = k / n1;
ll r = k % n1;
ans += cal(v1 - q, v1) * n1;
v1 -= q;
ans += v1 * r;
k = 0;
}
}
cout << ans << endl;
}
二分探索解法
楽しさが 0 以下のアトラクションに乗る意味はないから楽しさが 0 より大きい場合を考える。 問題は楽しさが $1, 2, \cdots, A_1, 1, 2, \cdots, A_2, \cdots, 1,2, \cdots, A_N$ である $\sum_i A_i$ 個のアトラクションから $K$ 個選んで乗ったときの楽しさの総和の最大値を求める問題と読み替えることができる。
$\sum_i A_i \leq K$ のとき、全てのアトラクションに乗ることができるので、楽しさの総和は $\sum_i \frac{A_i(A_i+1)}{2}$ である。
$\sum_i A_i > K$ のときを考える。 $f(x)$ を楽しさが $x$ を超えるアトラクションの数とする。$f(x) < K$ となる最小の $x$ を $x_{\min}$ とすると答えは
$$ \left(\sum_{A_{i} > x_{\min}} \frac{A_i (A_i + 1)}{2} - \frac{x_{\min} (x_{\min}+1)}{2} \right) + x_{\min} (K - f(x_{\min})) $$となる。 $f(x)$ の計算は $O(N)$ で求めることができ、$x_{\min}$ の値は二分探索で求めることができるので十分高速。
計算式の説明
第1項目は $x_{\min}$ より大きいアトラクションに乗ったときの楽しさの総和。特に追加で説明することなし。
第2項目は $x_{\min}$ のアトラクションに乗ったときの楽しさの総和である。 $x_{\min}$ のアトラクションを $K-f(x_{\min})$ 回乗れることは保証されているのか?という疑問が出るがこれは保証されている。
$x_{\min}$ は $f(x) < K$ を満たす最小値なので $f(x_{\min}-1) \geq K$ となる。つまり $x_{\min}-1$ より大きい($x_{\min}$ 以上)のアトラクションに全て乗ると $K$ 回以上乗ることになる。
すなわち、$x_{\min}$ のアトラクションに乗らないと $K$ を超えないが、乗ると $K$ を超えるので $K-f(x_{\min})$ 回乗ることは可能である。
void solve() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
ll n, k;
cin >> n >> k;
vll a(n);
rep(i, n) cin >> a[i];
sort(rall(a));
{
// 0 で accumulate すると int で計算されてしまうので注意
// ちゃんと long long で計算されるように初期値に注意する
ll acc = accumulate(all(a), 0ll);
if (acc <= k) {
ll ans = 0;
rep(i, n) {
ans += a[i] * (a[i] + 1) / 2;
}
cout << ans << endl;
return;
}
}
ll ac = INF, wa = -1;
while (abs(wa - ac) > 1) {
ll wj = (ac + wa) / 2;
ll cnt = 0;
rep(i, n) {
if (a[i] - wj > 0)
cnt += a[i] - wj;
else
break;
}
if (cnt < k) {
ac = wj;
} else {
wa = wj;
}
}
ll ans = 0;
ll cnt = 0;
rep(i, n) {
ll x = a[i];
if (x > ac) {
ans += (sum(ac, x));
cnt += x - ac;
}
}
ans += ac * (k - cnt);
cout << ans << endl;
}