いろはちゃんコンテスト Day2 E 連呼
このコンテストDay1, Day2どちらもとても楽しい。
2日目のテーマは「Typical」だったそうで勉強になった。
問題
問題概要
個の 、 個の がある。
1文字目が 、最後の文字が となり、連続する部分文字列にを持つ文字列の総数を求めよ。
参考
非常にわかりやすい解説が公式であがっているのでそちらを参考にしました。
https://img.atcoder.jp/iroha2019-day2/editorial-E.pdf
考察 and 解説
まず、""を含む文字列の数を数えるのは難しいので、""を含まない文字列の総数を求めることを考えます。
もし、""を含まない文字列の総数が求まれば、 から始まり で終わる文字列の個数からこれを引いた値が答えになります。
""を含まない文字列は、
のように "" を部分文字列として最大でもつしか持たない文字列になるはずです。
これをとの境目で区切ると、左から奇数番目が だけ の文字列、左から偶数番目が だけ の文字列になります。
また、この文字列の要素数は等しくなります。
また奇数番目と偶数番目は独立しているので、要素数 個のときの""を含まない文字列の総数は
奇数番目の場合の数 * 偶数番目の場合の数
になることがわかります。
この を ] の範囲で全探索すれば良いです。
奇数番目の場合の数
文字列の各要素は必ず を つは持つので、 が求める場合の数です。
偶数番目の場合の数
文字列の各要素は必ず を つは持ちます。
残った を各要素に振り分けることになるのでこれは、
となる整数解の個数と同じです。
よって、重複組合せを用いて が求める場合の数になります。
全体の個数
から始まり で終わる文字列の個数は
になります。
したがって求める文字列の総数は、全体からの場合の数を引いた値になります。
また高速に二項係数を求める方法として、
を参考にさせていただきました。
提出コード
#include <iostream> #include <cstdio> #include <algorithm> #include <vector> #include <functional> #include <queue> #include <string> #include <cstring> #include <numeric> #include <cstdlib> #include <cmath> using namespace std; typedef long long ll; #define INF 10e17 // 4倍しても(4回足しても)long longを溢れない #define rep(i,n) for(int i=0; i<n; i++) #define rep_r(i,n,m) for(int i=m; i<n; i++) #define END cout << endl #define MOD 1000000007 #define pb push_back #define sorti(x) sort(x.begin(), x.end()) #define sortd(x) sort(x.begin(), x.end(), std::greater<int>()) #define debug(x) std::cerr << (x) << std::endl; #define roll(x) for (auto itr : x) { debug(itr); } const ll MAX = 510000; ll fac[MAX], finv[MAX], inv[MAX]; void init() { fac[0] = fac[1] = 1; finv[0] = finv[1] = 1; inv[1] = 1; for(int i = 2; i < MAX; i++){ fac[i] = fac[i-1] * i % MOD; inv[i] = MOD - inv[MOD%i] * (MOD/i) % MOD; finv[i] = finv[i-1] * inv[i] % MOD; } } // n!! * k!^(-1) * ((n - k)!)^(-1) ll comb(ll n, ll k) { if (n < k) return 0; if (n < 0 || k < 0) return 0; return fac[n] * (finv[k]* finv[n - k] % MOD) % MOD; } int main() { init(); ll n, m; cin >> n >> m; ll ans = 0; // comb(n,k) -> nCk for (int k = 1; k <= min(n,m); ++k) { ans += comb(k, n - k) * comb(m - 1, k - 1) % MOD; ans %= MOD; } cout << (comb(n + m - 2, m - 1) - ans + MOD) % MOD << endl; }