takapt0226's diary

競技プログラミングのことを書きます

1~nまでの逆元をO(n)で求める方法

参考にしたのは以下の@rng_58さんによるツイートです
導出が自明ではないのでメモ(少なくとも僕にとっては)



逆元列挙のコード

以下のコードで1~nまでの逆元(mod m)をO(n)で求められます

// 1~nの逆元を求める(mod m)
vector<ll> list_mod_inverse(ll n, ll m)
{
    vector<ll> inv(n + 1);
    inv[1] = 1;
    for (int i = 2; i <= n; ++i)
        inv[i] = inv[m % i] * (m - m / i) % m;
    return inv;
}
導出

説明では、mを法としています。また、式中の/除算は切り捨て除算です。
上記のコードでは i^{-1} \equiv (m % i)^{-1} * (m - m / i)として、逆元を求めています。この式を導出します。

mは m = k * i + (m % i)と表すことができます。この式中のkは(m / i)に等しくなります。よって、 m = (m / i) * i + (m % i)となります。この式を変形をしていきます。

(※以下の式は (mod \: m)です。)
 m \equiv 0より、
 (m / i) * i + (m % i) \equiv 0

 (m % i)^{-1}を掛ける
 (m % i)^{-1} * (m / i) * i + 1 \equiv 0

 i^{-1}を掛ける
 (m % i)^{-1} * (m / i) + i^{-1} \equiv 0
 i^{-1} \equiv -(m % i)^{-1} * (m / i)

負の数を消す
 i^{-1} \equiv (m % i)^{-1} * -(m / i)
 i^{-1} \equiv (m % i)^{-1} * (m - (m / i))
以上です