JavaのModIntを考える

序論

プログラミングコンテストでは「答えの数を1000000007で割った余りを求めよ」という問題が大変よく出題される。1000000007は素数で、ほかにも 998244353 が代わりに出題されることもよくある。これらは真の解は大きくなりすぎて計算できない場合でも、加減乗算をするたびにその素数で割った余りを求めてあげれば解が変わらない。

ModInt

ソースコード中では問題で与えられた大きな素数は定数変数の mod と書くことにする。
a * (b + c) を求めたいときに、演算の度に余りを求めると

ans = a * ((b + c) % mod) % mod;

複雑な式になると %mod を書き忘れてオーバーフローで誤答になる危険があるので何とか記法を工夫したい。このモチベーションはとても大きい。

課題点

演算子オーバーロードがない

+ - * などの二項演算子オーバーロードできるプログラミング言語も存在するが Java はできない。そのためどう工夫しても数式的な中置記法をあきらめなければならない。

new のコスト

(a + b) % mod を求めたいだけなのに new で Object を作るのは実行時間に影響があるかもしれないため、new を使う場合はそのオーバーヘッドを知っておきたい。

結果

適当に加算と乗算を1億回するコードを、mod の行を工夫して実行時間を測る。
テストコードはこれ
ModIntTest · GitHub

for (int i = 0; i < a.length; i++) {
    for (int j = 0; j < b.length; j++) {
        ans = (ans * ((a[i] + b[j]) % mod)) % mod;
    }
}

単位はms。

Test AtCoder Codeforces Ideone
1 simple 439 1898 406
2 non final 803 1978 2053
3 function 444 1944 661
4 function2 486 2048 416
5 mint 582 3843 1317
6 mutable 560 2519 700

なぜか Codeforcesが非常に遅い。Codeforcesはそもそも除算・剰余算が驚くほど遅いため、記法の工夫の以前の問題として、演算回数をモンゴメリ乗算などで減らす方がいい。
AtCoder は new が予想よりも早い。正直、どの実装でもいい。
Ideone が最も予想に近い比率の時間になる。関数にすれば遅くなるし、newを使う回数だけもっと遅くなる。

テスト

Test1 simple

for (int i = 0; i < a.length; i++) {
    for (int j = 0; j < b.length; j++) {
        ans = (ans * ((a[i] + b[j]) % mod)) % mod;
    }
}

modを全て手打ちする。typoをしない限りは簡単かつ高速。
加算は mod の2倍を超えないので「剰余算」の代わりに「条件分岐と減算」でもできるが AtCoder では実行時間はほぼ変わらない。遅くなることも有る。

Test2 non final

// final を書き忘れる
long mod = 1000000007; // NG

final にするだけで速くなるので final は付けよう。もしくは

long mod() { return 1000000007; } // GOOD

のような定数関数にしても高速。オーバーライドしても高速なので 998244353 と共存させることもできる。

Test3 functions

ans = mul(ans, add(a[i], b[j]));

ただしstatic関数を定義する。

class ModIntStatic {
    static final long mod = 1000000007;

    static long add(long x, long y) {
        return (x + y) % mod;
    }

    static long mul(long x, long y) {
        return x * y % mod;
    }
}

関数記法になってしまうが実装が軽く実行速度が速く、可読性もよい。long のままなので関数を通し忘れるとオーバーフローの危険があるが、% mod を書き忘れるよりは防ぎやすい気がする。

Test4 function2

interface ModInt {
    public long mod();

    class ModInt998244353 implements ModInt { public long mod() { return 998244353; } }
    class ModInt1000000007 implements ModInt { public long mod() { return 1000000007; } }

    default long add(long x, long y) {
        return (x + y) % mod();
    }

    default long mul(long x, long y) {
        return x * y % mod();
    }
}

static ではなくなったが同様に高速。オーバーライドで複数の素数に対応できる。

Test5 Mint

ans = ans.mul(a[i].add(b[j]));

ただし、ans, a[i], b[j] は Mint オブジェクト。

class Mint {
    static final long mod = 1000000007;

    final long x;
    Mint(long x) { this.x = x; }

    Mint add(Mint y) {
        return new Mint((x + y.x) % mod);
    }

    Mint mul(Mint y) {
        return new Mint(x * y.x % mod);
    }
}

メソッドにしたので語順が中置記法になった。クラス同士でしか演算できないため、最も安全だと思われる。AtCoder は new がなぜか高速なのでこれも十分良いが、ほかのコンテストサイトでは遅いかもしれない。BigInteger と同じ API にできるのも良い。

Test6 Mutable Mint

tmp.assign(a[i]).addAssign(b[j]);
ans.mulAssign(tmp);

ただし ans, tmp, a[i], b[j] は Mint オブジェクト。

class Mint {
    static final long mod = 1000000007;

    long x;

    Mint(long x) { this.x = x; }

    Mint assign(Mint y) {
        x = y.x;
        return this;
    }

    Mint addAssign(Mint y) {
        x = (x + y.x) % mod;
        return this;
    }

    Mint mulAssign(Mint y) {
        x = x * y.x % mod;
        return this;
    }
}

new がなくなったのでどんな環境でも速いはず。ただし、数式が複雑になるだけ tmp 変数を用意しなければならない。また、変更可能なオブジェクトは関数の引数に渡したときに関数内で書き換えられることを心配する必要があるので、慎重に使わなければならない。