seriruの技術屋ブログ

競技プログラミングやゲーム開発など技術に関することを発信します

いろはちゃんコンテスト Day2 E 連呼

このコンテストDay1, Day2どちらもとても楽しい。
2日目のテーマは「Typical」だったそうで勉強になった。

問題

atcoder.jp

問題概要

 N 個の  A M 個の  B がある。
1文字目が  A、最後の文字が  B となり、連続する部分文字列に AAAを持つ文字列の総数を求めよ。

参考

非常にわかりやすい解説が公式であがっているのでそちらを参考にしました。

https://img.atcoder.jp/iroha2019-day2/editorial-E.pdf

考察 and 解説

まず、" AAA"を含む文字列の数を数えるのは難しいので、" AAA"を含まない文字列の総数を求めることを考えます。

もし、" AAA"を含まない文字列の総数が求まれば、 A から始まり  B で終わる文字列の個数からこれを引いた値が答えになります。

" AAA"を含まない文字列は、

 AABBBBBAABBBABBBABBB

のように " A" を部分文字列として最大でも 2つしか持たない文字列になるはずです。

これを A Bの境目で区切ると、左から奇数番目が  Aだけ の文字列、左から偶数番目が  Bだけ の文字列になります。

f:id:ryo_seriru:20190506172220p:plain

また、この文字列の要素数は等しくなります。

各要素数 K とすると、文字列全体の要素数 2K となります。

また奇数番目と偶数番目は独立しているので、要素数  K 個のときの" AAA"を含まない文字列の総数は

奇数番目の場合の数 * 偶数番目の場合の数

になることがわかります。

この  K [1, \ min(|A|, |B|)] の範囲で全探索すれば良いです。

奇数番目の場合の数

文字列の各要素は必ず  A 1つは持つので、 _{k} C_{|A| - k} が求める場合の数です。

f:id:ryo_seriru:20190506185016p:plain

偶数番目の場合の数

文字列の各要素は必ず  B 1つは持ちます。

残った  B を各要素に振り分けることになるのでこれは、

 a_{1} + a_{2} + ... + a_{k} = |B| - k

となる整数解の個数と同じです。

よって、重複組合せを用いて  _{|B| - 1} C_{k - 1} が求める場合の数になります。

f:id:ryo_seriru:20190506185036p:plain

全体の個数

 A から始まり  B で終わる文字列の個数は

 _{|A| + |B| - 2}C_{ |B| - 1 }

になります。

したがって求める文字列の総数は、全体から 1 \le k \le min(|A|, |B|)の場合の数を引いた値になります。

また高速に二項係数を求める方法として、

drken1215.hatenablog.com

を参考にさせていただきました。

提出コード

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <functional>
#include <queue>
#include <string>
#include <cstring>
#include <numeric>
#include <cstdlib>
#include <cmath>
using namespace std;

typedef long long ll;

#define INF 10e17 // 4倍しても(4回足しても)long longを溢れない
#define rep(i,n) for(int i=0; i<n; i++)
#define rep_r(i,n,m) for(int i=m; i<n; i++)
#define END cout << endl
#define MOD 1000000007
#define pb push_back
#define sorti(x) sort(x.begin(), x.end())
#define sortd(x) sort(x.begin(), x.end(), std::greater<int>())
#define debug(x) std::cerr << (x) << std::endl;
#define roll(x) for (auto itr : x) { debug(itr); }

const ll MAX = 510000;

ll fac[MAX], finv[MAX], inv[MAX];

void init() {
  fac[0] = fac[1] = 1;
  finv[0] = finv[1] = 1;
  inv[1] = 1;
  for(int i = 2; i < MAX; i++){
    fac[i] = fac[i-1] * i % MOD;
    inv[i] = MOD - inv[MOD%i] * (MOD/i) % MOD;
    finv[i] = finv[i-1] * inv[i] % MOD;
  }
}

// n!! * k!^(-1) * ((n - k)!)^(-1)
ll comb(ll n, ll k) {
  if (n < k) return 0;
  if (n < 0 || k < 0) return 0;
  return fac[n] * (finv[k]* finv[n - k] % MOD) % MOD;
}

int main() {
  init();
  ll n, m;
  cin >> n >> m;

  ll ans = 0; 
  // comb(n,k) -> nCk
  for (int k = 1; k <= min(n,m); ++k) {
    ans += comb(k, n - k) * comb(m - 1, k - 1) % MOD;
    ans %= MOD;
  }
  cout << (comb(n + m - 2, m - 1) - ans + MOD) % MOD << endl;
}