Nimberの乗法・Sqrt・逆元・二次方程式

f:id:natsugiri:20200329072140j:plain

Nimber(Grundy数)はその和と積が定義されている。
Nimber - Wikipedia

Nimber和

 x \oplus y := {\rm mex} ( \{ x’ \oplus y : x’ < x \} \cup \{ x \oplus y’ : y’ < y \} )
これは単純に言えばxとyのビット毎のXORになっている。
丁寧に言うと次の性質がある。

  • Nimber和の性質1:  x < 2^n のときこれらのNimber和は普通の非負整数の和(正確には順序数の和)に等しい ( x \oplus 2^n = x + 2^n  )
  • Nimber和の性質2: 同じ値のNimber和は0 ( x \oplus x = 0)

Nimber積

 x \otimes y := {\rm mex} ( \{ (x’ \otimes y) \oplus (x \otimes y’) \oplus (x’ \otimes y’) : x’ < x, y’ < y \} )

  • Nimber積の性質1:  x < 2^{2^n} のときこれらのNimber積は普通の積(順序数の積)に等しい ( x \otimes 2^{2^n} = x \times 2^{2^n} )
  • Nimber積の性質2: 同じ値  2^{2^n} 同士のNimber積はその3/2倍  (2^{2^n} \otimes 2^{2^n} = \frac{3}{2} \times 2^{2^n} = 2^{2^n} \oplus 2^{2^n-1}) (例  2\otimes 2 = 3, 4 \otimes 4 = 6, 16 \otimes 16 = 24 )
  • その他の性質

練習

実際に 6 \otimes 9 を計算してみる
6と9をそれぞれ 2^nの和の形で書く
 (2+4) \otimes (1+8)
Nimber和の性質1より
 (2 \oplus 4) \otimes (1 \oplus 8)
それぞれの数を 2^{2^n} の積の形で書く
 (2 \oplus 4) \otimes (1 \oplus (2 \times 4))
Nimber積の性質1より
 (2 \oplus 4) \otimes (1 \oplus (2 \otimes 4))
分配法則で
 2 \oplus 4 \oplus (2 \otimes 2 \otimes 4) \oplus (4 \otimes 2 \otimes 4)
ここで、
 2 \otimes 2 \otimes 4 = 3 \otimes 4 = (1 \oplus 2) \otimes 4 = 4 \oplus 8 = 12
 4 \otimes 2 \otimes 4 = 2 \otimes 6 = 2 \otimes (2 \oplus 4) = 3 \oplus 8 = 11
よって
 6 \otimes 9 = 2 \oplus 4 \oplus 12 \oplus 11 = 1

ゲーム

Nimber積に対応するゲームがある。

1手ごとに1000円支払って硬貨を1つ取り除き、その左・下・左下に軸平行な長方形の頂点の位置に3つの硬貨を置く。この動作を2人のプレーヤーが交互に、動かせなくなるまで繰り返す。
このゲームの座標  (x, y) の硬貨に対応するNimberはNimber積 x \otimes y である。定義そのままなので明らか。

f:id:natsugiri:20200329072241j:plain   
f:id:natsugiri:20200329072254j:plain
ゲーム

実装

0以上 2^{2^n}未満の値のNimber集合は有限体を成す(減法は加法に等しい、除法も定義できる)。つまり、32bit や 64bitの unsigned整数においてオーバーフローせずに計算できる。
加法はXORなので何も実装する必要はない。乗法はProblem - F - Codeforcesに実装があり、2つ目の関数をメモ化してやるが計算量が悪い。
速い別のアプローチとしてKaratsuba法が使える。以降Nimber積の記号を省略する。
 x, y < 2^{2^n},  P = \sqrt{2^{2^n}} =  2^{2^{n-1}},  H=P/2 = 2^{2^{n-1}-1} とする。 x, yの上位・下位  2^{n-1} ビットをそれぞれ  0 \le x_h, x_l, y_h, y_l < P と表す。

  •  x = x_h P \oplus x_l
  •  y = y_h P \oplus y_l
  •  P \otimes P = \frac{3}{2} P = P \oplus H

例えば  x, yが任意の64bit整数、 x_h, x_l, y_h, y_lが32bit整数、 P=2^{32} H=2^{31} とか
計算の過程で長さが半分になり、小さくなったらメモ化テーブルを読む。

2の冪の乗法

 y = 2^i の場合、一般のNimber積よりも計算量が良い実装がある。とくに、 z H の計算が頻出するのでそれらの計算量も改善する。
 y_l = 0 のとき  xy = (x_h P \oplus x_l) y_h P = (x_h \oplus x_l) y_h P \oplus x_h y_h H
 y_h = 0 のとき  xy = (x_h P \oplus x_l) y_l = x_h y_l P \oplus x_l y_l

計算量  \Theta (w ^ {\lg 3} ) wはワード長、ビット数を 2^nに切り上げた値)

乗法

 \begin{array}{rcl}
x y & = &x_h y_h (P \oplus H) \oplus x_h y_l P \oplus x_l y_h P \oplus x_l y_l \\
     &  = & (x_h y_h \oplus x_h y_l \oplus x_l y_h) P \oplus (x_h y_h H \oplus x_l y_l)
\end{array}

ここで半bit長のNimber積を4個求める

  •  z_0 = x_l y_l
  •  z_1 = (x_h \oplus x_l) (y_h \oplus y_l)
  •  z_2 = x_h y_h
  •  z_3 = z_2 H ここだけ2の冪の乗法

 xy = (z_0 \oplus z_1) P \oplus (z_0 \oplus z_3)

計算量  \Theta (w ^ {\lg 3} \lg w)

Square

 xのiビット目を x_i \in \{0, 1\}とする。
 x \otimes x = \oplus \sum_i (x_i \otimes 2^i \otimes 2^i) より、Nimberの二乗はNimber積より良い実装がある。

 \begin{array}{rcl}
{\rm sq}(x) & = & (x_h P \oplus x_l) (x_h P \oplus x_l) \\
     & = & x_h x_h P P \oplus x_l x_l \\
    & = & {\rm sq}(x_h) P \oplus {\rm sq}(x_h) H \oplus {\rm sq}(x_l)
\end{array}

計算量  \Theta (w ^ {\lg 3})

Sqrt

Nimberの二乗は全単射。sqrtはsquareの逆写像
 \begin{array}{rcl}
y &:=& {\rm sqrt}(x) \\
(x_h, x_l) &=& ({\rm sq}(y_h), {\rm sq}(y_h) H \oplus {\rm sq}(y_l) ) \\
y_h &=& {\rm sqrt}(x_h) \\
y_l &=& {\rm sqrt}(x_h H \oplus x_l)
\end{array}

計算量  \Theta (w ^ {\lg 3})

逆数

 x \neq 0 について、 xy = 1となる yが必ず1つ存在する。

逆数1:バイナリ法(繰り返し二乗法)

Nimber積の指数関数を  {\rm power}(x, m) で表すとする。 0 < x < 2^{2^n} について、
 {\rm power}(x, 2^{2^n}) = x
 {\rm power}(x, 2^{2^n}-1) = 1
 {\rm power}(x, 2^{2^n}-2) = {\rm inverse}(x)
が成り立つ(要証明)。

計算量  \Theta (w ^ {1 + \lg 3} \lg w)

逆数2:逆行列

Nimber積  z = x y = (x_h P \oplus x_l)(y_h P \oplus y_l) = (x_h y_h \oplus x_h y_l \oplus x_l y_h) P \oplus (x_h y_h H \oplus x_l y_l) を行列積で表す。

 \left(\begin{array}{c}
z_h \\ z_l
\end{array}\right) = 
\left(\begin{array}{cc}
x_h \oplus x_l & x_h \\
x_h H & x_l
\end{array}\right)
\left(\begin{array}{c}
y_h \\ y_l
\end{array}\right)

ここで逆行列の公式が使える。

 \displaystyle
ad-bc \neq 0 のとき
 \left(\begin{array}{cc}
a & b \\
c & d
\end{array}\right) の逆行列は
\frac{1}{ad-bc}
\left(\begin{array}{cc}
d & -b \\
 -c & a
\end{array}\right)

逆行列を求める2通りの方法と例題 | 高校数学の美しい物語

 xの行列表現の逆行列 \displaystyle \frac{1}{(x_h \oplus x_l) x_l \oplus x_h x_h H}
\left(\begin{array}{cc}
x_l & x_h \\
 x_h H & x_h \oplus x_l
\end{array}\right)

 \displaystyle\left(\begin{array}{c}0 \\ 1\end{array}\right)を掛けてベクトル表現にすると  \displaystyle {\rm inverse}(x) = \left(\begin{array}{c}
y_h \\y_l
\end{array}\right) = \frac{1}{(x_h \oplus x_l) x_l \oplus x_h x_h H}
\left(\begin{array}{c}
x_h \\ x_h \oplus x_l
\end{array}\right)

Determinantが半ビット長の逆数になるので再帰的に計算する。

計算量  \Theta (w ^ { \lg 3} \lg w)

二次方程式

math.stackexchange.com

 f(x) = x x \oplus b x について、 f(x) = cになるような x求める。

 b = 0

 x x = cより x = {\rm sqrt}(c)が唯一の解。

 b = 1

解は2つある。 f(x) = f(x \oplus 1) より、偶数解のみ求めればよい。例えば  c = 0 ならば x = 0 が偶数解、 x=1 が奇数解。

 f(x) は定義域が  2^{m} \le x < 2^{m+1} の偶数ならば値域  2^{m-1} \le f(x) < 2^{m} へ写す全単射である。つまり、 cの最上位ビットで xの範囲が決まる。
解を求める関数を  f^{-1}(c) とする。

 \begin{array}{rcl}
f(x) &=& x x \oplus x = c \\
f(x_h P \oplus x_l) &=& (x_h x_h \oplus x_h) P \oplus (x_l x_l \oplus x_l \oplus x_h x_h H) = c_h P \oplus c_l \\
x_h &=& f^{-1}(c_h) \\
x_l &=& f^{-1}(x_h x_h H \oplus c_l)
\end{array}

f^{-1}(c_h) の偶数解  x_h再帰的に計算して得られる。この時の  x_l の存在を考える。
 (x_h x_h H \oplus c_l) < 2^{2^{n-1}} であるが、 (x_h x_h H \oplus c_l) < 2^{2^{n-1}-1} ならば上記の式の通りに  x_l を求めればよい。もし  (x_h x_h H \oplus c_l) \ge 2^{2^{n-1}-1} ならば、(最上位ビットが1ならば) x_l \ge 2^{2^{n-1}} となり、 x_lの解が範囲外になり間違っている。この場合の  x_h の正しい値は奇数解  x_h' = x_h \oplus 1 である。 x_l の式の定数項の最上位ビットが消えて、 x_l = f^{-1}(x_h' x_h' H \oplus c_l)再帰的に得られる。

 xcで値の範囲が異なるので注意。 c < 2^{63} ならば  x < 2^{64} の範囲で計算できる。

 b > 1

解は2つある。一つ求めれば、 f(x) = f(x \oplus b)より、もう一つも簡単に求められる。
 x = bz を代入して、 f(bz) = bbzz \oplus bbz = c よって、 f'(z) = zz \oplus z = {\rm inverse}({\rm sq}(b) ) c を解けばよい。

計算量  \Theta (w ^ { \lg 3} \lg w)

コード

namespace NIMBER {

using ULL = unsigned long long;

unsigned mul_table[256][256];
unsigned sqrt_table[256];
unsigned inverse_table[256];
unsigned quadratic_equation_b1_table[128]; // even solutions;

bool _auto_init() {
    mul_table[1][1] = 1;
    for (int t=1; t<=4; t+=t) {
	REP (xh, 1<<t) REP (yh, 1<<t) {
	    unsigned xhyhH = mul_table[1<<(t-1)][mul_table[xh][yh]];
	    REP (xl, 1<<t) REP (yl, 1<<t) {
		mul_table[xh<<t|xl][yh<<t|yl] =
		    (mul_table[xh^xl][yh^yl]^mul_table[xl][yl])<<t ^ mul_table[xl][yl] ^ xhyhH;
	    }
	}
    }
    REP (x, 256) sqrt_table[mul_table[x][x]] = x;
    REP (x, 256) REP (y, 256) if (mul_table[x][y] == 1u) inverse_table[x] = y;
    for (int x=0; x<256; x+=2) quadratic_equation_b1_table[mul_table[x][x] ^ x] = x;
    return true;
}
bool _auto_init_done = _auto_init();

int middle(ULL x) {
    if ((x>>8) == 0) return 4;
    if ((x>>16) == 0) return 8;
    if ((x>>32) == 0) return 16;
    return 32;
}

ULL mulp2(ULL x, ULL y) {
    assert(y && (y&(y-1u)) == 0); // assert y = 2^i;
    int p = middle(x|y);
    if (p == 4) return mul_table[x][y];
    ULL mask = ~0u>>(32-p);
    ULL xh = x >> p;
    ULL xl = x & mask;
    ULL yh = y >> p;
    ULL yl = y & mask;
    if (yh) {
	return (mulp2(xh^xl, yh)<<p) ^ mulp2(mulp2(xh, yh), 1u<<(p-1));
    } else {
	return (mulp2(xh, yl)<<p) ^ mulp2(xl, yl);
    }
}

ULL mul(ULL x, ULL y) {
    int p = middle(x|y);
    if (p == 4) return mul_table[x][y];
    ULL mask = ~0u>>(32-p);
    ULL xh = x >> p;
    ULL xl = x & mask;
    ULL yh = y >> p;
    ULL yl = y & mask;
    ULL z0 = mul(xl, yl);
    ULL z1 = mul(xh^xl, yh^yl);
    ULL z2 = mul(xh, yh);
    ULL z3 = mulp2(z2, 1u<<(p-1));
    return ((z0^z1)<<p) ^ z3 ^ z0;
}

ULL sq(ULL x) {
    int p = middle(x);
    if (p == 4) return mul_table[x][x];
    ULL mask = ~0u>>(32-p);
    ULL xh = x >> p;
    ULL xl = x & mask;
    ULL z = sq(xh);
    return (z<<p) ^ mulp2(z, 1u<<(p-1)) ^ sq(xl);
}

ULL sqrt(ULL x) {
    int p = middle(x);
    if (p == 4) return sqrt_table[x];
    ULL mask = ~0u>>(32-p);
    ULL xh = x >> p;
    ULL xl = x & mask;
    return (sqrt(xh)<<p) ^ sqrt(mulp2(xh, 1u<<(p-1)) ^ xl);
}

ULL power(ULL x, ULL y) {
    ULL ret = (y&1? x: 1);
    for (y>>=1; y; y>>=1) {
	x = sq(x);
	if (y&1) ret = mul(ret, x);
    }
    return ret;
}

ULL slow_inverse(ULL x) {
    if (x>>32) return power(x, -2ULL);
    return power(x, (1ULL<<32)-2ULL);
}

ULL inverse(ULL x) {
    int p = middle(x);
    if (p == 4) return inverse_table[x];
    ULL mask = ~0u>>(32-p);
    ULL xh = x >> p;
    ULL xl = x & mask;
    ULL det = mul(xl, xh^xl) ^ mulp2(sq(xh), 1u<<(p-1));
    ULL inv_det = inverse(det);
    return (mul(inv_det, xh)<<p) ^ mul(inv_det, xh^xl);
}

// find x: xx + x = c;
// answer: x, x+1;
ULL quadratic_equation_b1(ULL c) {
    assert(~c>>63&1); // assert c < 2^{63};
    int p = middle(c<<1);
    if (p == 4) return quadratic_equation_b1_table[c];
    ULL mask = ~0u>>(32-p);
    ULL H = 1u<<(p-1);
    ULL ch = c >> p;
    ULL cl = c & mask;
    ULL xh = quadratic_equation_b1(ch);
    ULL z = mulp2(sq(xh), H);
    if ((z ^ cl) & H) { xh ^= 1; z ^= H; }
    ULL xl = quadratic_equation_b1(z ^ cl);
    return (xh<<p) ^ xl;
}

// find x: xx + bx = c;
// answer: x, x+b;
ULL quadratic_equation(ULL b, ULL c) {
    if (b == 0) return sqrt(c);
    ULL d = (b == 1? c: mul(c, inverse(sq(b))));
    assert(~d>>63&1); // assert c/(b^2) < 2^{63};
    ULL x = quadratic_equation_b1(d);
    return (b == 1? x: mul(b, x));
}
};