SRM472 Div1 Medium(600) TwoSidedCards

TwoSidedCards

このレベルの問題をコンテスト中に通せれば、赤狙えるんだろうなと思うけど、難しい。

どの数字が同じカードに書かれているかが重要で、カードの順番に意味は無い。カードをサイクルで分ける。Example 3を例にすると、

 {1,3,5}, {2,4,6,7,8}
 {3,5,1}, {4,6,7,8,2}

それぞれのグループの数字は異なるので、グループごとに並べ方が何通りかを数える。
右のグループを考える。TをTaroの側、HをHanakoの側としてTTTTTならば全ての数字が異なる。TTHHTならば8が2回現れる。THTHHならば2回現れる数字は2つ。k>0について、T→H, H→Tという切れ目を2k個入れる場合の数はnC2kで、このとき2回現れる数字はk個。また、k>0の場合はTaroとHanakoが逆になったときに表になる数字が異なる。カードがn枚のグループの裏表を決めたとき、2回現れる数字の数をkとすると、並べ方はn!/2k通り。

#include <vector>
#include <algorithm>

using namespace std;

class intm
{
    const static int M = 1000000007;
    int m;
public:
    intm(int n):m(n){}
    int value(){return m;}
    intm operator+(intm a){return intm(m)+=a;}
    intm operator-(intm a){return intm(m)-=a;}
    intm operator*(intm a){return intm(m)*=a;}
    intm operator/(intm a){return intm(m)/=a;}
    intm operator+=(intm a){m=(m+a.m)%M; return *this;}
    intm operator-=(intm a){m=int(((long long)(m)-a.m+M)%M); return *this;}
    intm operator*=(intm a){m=int(((long long)(m)*a.m)%M); return *this;}
    intm operator/=(intm a){m=(a.pow(M-2)*m).m; return *this;}
    intm pow(int a){intm r=1,t=m; while(a){if(a&1)r*=t;t*=t;a>>=1;} return r;}
};

class TwoSidedCards
{
    intm fact( int n ) { intm r=1; while(n)r*=n--; return r; }
    intm comb( int n, int k ) { return fact(n)/fact(k)/fact(n-k); }
public:
    int theCount( vector <int> taro, vector <int> hanako );
};

int TwoSidedCards::theCount( vector <int> taro, vector <int> hanako )
{
    int N = (int)taro.size();

    //  ループの数と大きさを調べる
    vector<int> loop;

    for ( int i=0; i<N; i++ )
    if ( taro[i] != -1 )
    {
        int c = 0;
        int p = i;
        while ( taro[p] != -1 )
        {
            c++;
            int t = taro[p];
            taro[p] = -1;

            p = (int)( find( hanako.begin(), hanako.end(), t ) - hanako.begin() );
        }

        loop.push_back( c );
    }

    //  サイズnのループについて並べ方は
    //   n! + n!/2 2 C(n,2) + n!/2/2 2 C(n,4) + …… + n!/2^k 2 C(n,2k)
    intm ans = 1;

    for ( int i=0; i<(int)loop.size(); i++ )
    {
        int n = loop[i];
        intm t = fact(n);
        for ( int j=1; j*2<=n; j++ )
            t += fact(n) / intm(2).pow(j) * 2*comb(n,j*2);

        ans *= t;
    }

    ans *= fact(N);
    for ( int i=0; i<(int)loop.size(); i++ )
        ans /= fact(loop[i]);

    return ans.value();
}