seriruの技術屋ブログ

競技プログラミングやゲーム開発など技術に関することを発信します

AtCoder Beginner Contest 116 D Various Sushi

問題

atcoder.jp

問題概要

 N 個の寿司があり, それぞれの寿司にはネタ  t_i とおいしさ  d_i が割り振られている。
この寿司から  K 個選んで食べる時, おいしさ  d_i の総和と食べたネタ  t_i の種類の二乗の和の最大値を求めよ。

考察

ネタの種類が少なく, それぞれの寿司のおいしさが大きい場合はおいしい順から寿司を食べていけばよいですが, そうでない場合, 複数のネタの中から数個を選んで食べる必要があります。

 N = 6,  K = 3 の次の図のような例を考えて見ます。
ネタが  3 種類あり, それぞれのネタが乗っている  2 つの寿司があります。

f:id:ryo_seriru:20190123140608p:plain:w200

とりあえずネタの種類を考えずにおいしい順から取ると,

f:id:ryo_seriru:20190123141449p:plain:w400

となり, 満足度は  9 + 7 + 6 + (2 * 2) = 26 になります。

次に, 食べるネタの種類を増やすことを考えて見ます。
まだ食べていないのはネタ  3 ですが, ネタ  2 と交換すると食べるネタの種類の総数は変わりません。
従って,  2 つ以上食べたネタ  1 との交換を考えます。
このとき, おいしさが大きい寿司より小さい寿司を交換した方が明らかに満足度が高くなります。
同様に, まだ食べていないネタは, おいしさの小さい寿司よりも大きい寿司を交換した方が良いです。

よって, ネタ  1 のおいしさ  7 の寿司とネタ  3 のおいしさ  3 の寿司を交換します。
すると満足度は  9 + 6 + 3 + (3 * 3) = 27 となり, 全ての種類のネタを試すことができました。
これが最大値です。

f:id:ryo_seriru:20190123142252p:plain:w400

このように,  N 個の寿司のおいしさの大きい順に寿司を選びます。
その後,  2 つ以上食べたネタで一番おいしくない寿司とまだ食べていないネタの中で一番おいしい寿司を交換する, という操作を  K 回, または, 全体のネタの種類回行えば最大値が求まります。

まだ食べていないネタと食べたネタの交換は, priority_queueなどを使用すると十分に高速に行うことができます。

提出コード

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

#define rep(i,n) for(int i=0; i<n; i++)
#define sorti(x) sort(x.begin(), x.end())
#define pb push_back
#define int long long 

void show_ans(ll ans) {
  cout << ans << endl;
}

signed main() {
  int n,k;
  cin >> n >> k;
  map<int, int> mp;
  // ネタの種類
  ll total = 0;
  // 入力データを保存
  vector<pair<int,int> > data;

  ll t,s;
  rep(i,n) {
    cin >> t >> s;
    // ネタのおいしさ、ネタの種類(pairのsortは標準で第一要素を見て行うため)
    data.pb(make_pair(s,t));
    // まだ記録していないネタだったらインクリメント
    if (mp[t] == 0) total++;
    mp[t] += 1;
  }

  // ネタのおいしさで昇順にsort
  sorti(data);
  // どのネタを食べたかチェックするmap
  map<ll,ll> check;
  
  ll kd = 0; // 食べたネタの種類数
  ll cnt = 0 // 食べた数をカウント
  ll itr = n-1, ans = 0;
  // ネタのおいしさを降順に(priority_queueは標準で大きい方から取り出せる)
  priority_queue<pair<int,int>, vector<pair<int,int> >, greater<pair<int,int> > > sushi;
  for (int i = n-1; i >= 0; --i) {
    ll various = data[i].second, value = data[i].first;
    cnt += 1;
    itr = i;
    ans += value;
    // 選んだ寿司をキューに保存
    sushi.push(data[i]);
    // すでに食べた寿司か?’
    if (check[various] == 0) kd++;
    check[various] += 1;
    
    if (cnt == k) break;
  }
  
  // 大きい方からK個選んだときの合計を記録
  ll tmpans = ans;
  // 答えをK個選んだ時で初期化
  ans = ans + (kd * kd);
  //  N == K のとき, 全ての寿司の種類を食べたときは終わり
  if (itr == 0 or total == kd) {
    show_ans(ans);
    return 0;
  }

  for (int i = itr; i >= 0; --i) {
    // 全部交換し終わったら終了
    if (sushi.size() == 0) break;

    ll various = data[i].second, value = data[i].first;

    // 寿司がすでに食べたものだったら, それよりおいしさが小さい寿司を食べる必要はないのでcontinue
    if (check[various] > 0) continue;
    
    bool flag = false;
    while (sushi.size()) {
      pair<int,int> p = sushi.top();
      sushi.pop();
      if (check[p.second] > 1) {
        // ネタを1つ減らす
        check[p.second] -= 1;
        // おいしさを減らす
        tmpans -= p.first;
        break;
      } else {
        if (sushi.size() == 0) {
          flag = true;
        }
      }
    }

    if (flag) {
      break;
    }
    kd += 1;
    check[various] += 1;
    tmpans += value;
    ans = max(ans, tmpans + (kd * kd));
    if (kd == k) break;
  }

  show_ans(ans);
}

実装が重い...