标签: BSGS

BSGS算法(大步小步法)及其扩展原理及应用

BSGS算法(大步小步法)及其扩展原理及应用

BSGS算法(Baby-step giant-step)

算法用于解决解高次同余方程的问题,问题形式如:有同余方程a^x \equiv b \pmod{p},p为质数,求最小非负整数解x使得原方程成立。该算法的复杂度可以达到O(\sqrt{n} \log n)甚至更低。

原理

根据欧拉定理,我们知道模的剩余类有产生循环的情况,即a^0, a^1, \ldots, a^{n-1}模n(质数)意义下的剩余类与a^n, a^{n+1}, \ldots, a^{2n-1}的剩余类相同,因此我们要的答案一定在 [0, n-1] 内。
我们考虑先求出一部分a的幂次模p意义下的值,将它们存起来,然后使得剩下没有求值的部分能够想个办法利用已求值直接查出来。我们想起了根号,不如直接令求值的长度为m=\lceil \sqrt{p} \rceil
下面要考虑的是没有求值部分怎么来求,比如a^m, \ldots, a^{2m-1}这一段,如果有解,一定是a^i \cdot a^m \equiv b \pmod{p}的情况,我们把那个a^m移到右边来,就变成了a^i \equiv b' \pmod{p}b' = ba^{-m}。既然这样,我们为什么不考虑直接查这个b'有没有对应的i的答案。这样,没有求值的部分的每一段就可以只进行一次查询来判断该段内是否有解。
这个保存,我们可以用一个map<int, int>来存,x[i]为余数为i的a^x中最小的x值,则插入查询都为O(\log n),总复杂度达到O(\sqrt{n} \log n)。当然还可以用HashMap存,复杂度会更优一些。

代码

本代码可以通过2417 — Discrete Logging一题。

STL map版本

// Code by KSkun, 2018/4
#include <cstdio>
#include <cmath>

#include <map>

typedef long long LL;

inline LL fpow(LL n, LL k, LL p) {
    LL res = 1; n %= p;
    while(k) {
        if(k & 1) res = res * n % p;
        n = n * n % p;
        k >>= 1;
    }
    return res;
}

std::map<LL, LL> x;

inline LL bsgs(LL a, LL b, LL p) {
    a %= p; b %= p;
    if(a == 0) return b == 0 ? 1 : -1;
    if(b == 1) return 0;
    LL m = ceil(sqrt(p - 1)), inv = fpow(a, p - m - 1, p);
    x.clear();
    x[1] = m; // use m instead of 0
    for(LL i = 1, e = 1; i < m; i++) {
        e = e * a % p;
        if(!x[e]) x[e] = i;
    }
    for(LL i = 0; i < m; i++) {
        if(x[b]) {
            LL res = x[b];
            return i * m + (res == m ? 0 : res);
        }
        b = b * inv % p;
    }
    return -1;
}

LL p, b, n;

int main() {
    while(scanf("%lld%lld%lld", &p, &b, &n) != EOF) {
        LL res = bsgs(b, n, p);
        if(res != -1) printf("%lld\n", res); else puts("no solution");
    }
    return 0;
}

HashMap版本

// Code by KSkun, 2018/4
#include <cstdio>
#include <cmath>
#include <cstring>

#include <algorithm>

typedef long long LL;

inline LL fpow(LL n, LL k, LL p) {
    LL res = 1; n %= p;
    while(k) {
        if(k & 1) res = res * n % p;
        n = n * n % p;
        k >>= 1;
    }
    return res;
}

const int MO = 611977, MAXN = 1000005;

struct HashMap {
    int head[MO + 5], key[MAXN], value[MAXN], nxt[MAXN], tot;
    inline void clear() {
        tot = 0;
        memset(head, -1, sizeof(head));
    }
    HashMap() {
        clear();
    }
    inline void insert(int k, int v) {
        int idx = k % MO;
        for(int i = head[idx]; ~i; i = nxt[i]) {
            if(key[i] == k) {
                value[i] = v;
                return;
            }
        }
        key[tot] = k; value[tot] = v; nxt[tot] = head[idx]; head[idx] = tot++;
    }
    inline int operator[](const int &k) const {
        int idx = k % MO;
        for(int i = head[idx]; ~i; i = nxt[i]) {
            if(key[i] == k) return value[i];
        }
        return -1;
    }
} x;

inline LL bsgs(LL a, LL b, LL p) {
    a %= p; b %= p;
    if(a == 0) return b == 0 ? 1 : -1;
    if(b == 1) return 0;
    LL m = ceil(sqrt(p - 1)), inv = fpow(a, p - m - 1, p);
    x.clear();
    x.insert(1, 0);
    for(LL i = 1, e = 1; i < m; i++) {
        e = e * a % p;
        if(x[e] == -1) x.insert(e, i);
    }
    for(LL i = 0; i < m; i++) {
        if(x[b] != -1) {
            LL res = x[b];
            return i * m + res;
        }
        b = b * inv % p;
    }
    return -1;
}

LL p, b, n;

int main() {
    while(scanf("%lld%lld%lld", &p, &b, &n) != EOF) {
        LL res = bsgs(b, n, p);
        if(res != -1) printf("%lld\n", res); else puts("no solution");
    }
    return 0;
}

扩展BSGS算法(exBSGS)

原理

咦,那p非质数怎么办呢?并不是说p-1以内没有答案,而是p可能会很大,根号的复杂度受不了啊!
想个办法来缩小p的范围。
我们想起了一些同余性质,比如a \equiv b \pmod{m} \Leftrightarrow \frac{a}{d} \equiv \frac{b}{d} \pmod{\frac{m}{d}},其中d为a、b、m的正公因数。我们想办法如此提公因数。
从方程左边拆一个a出来,提公因数,提完了就变成这样一个式子a^{x-1} \cdot \frac{a}{d} \equiv \frac{b}{d} \pmod{\frac{p}{d}}。直到某个时候\mathrm{gcd}(a, \frac{p}{\prod_i d_i}) = 1
如果在提公因数的过程中,遇到\mathrm{gcd}(a, \frac{p}{\prod_i d_i}) \neq 1且d不能整除b的情况,说明式子无解。因为a^x, p中的一个共同的因数b中没有,显然不存在这样的b。
结束以后,我们得到的会是一个这样的式子
a^{x-k} \cdot \frac{a^k}{\prod_i d_i} \equiv \frac{b}{\prod_i d_i} \pmod{\frac{p}{\prod_i d_i}}
把分母搞掉
a^x \equiv b \pmod{\frac{p}{\prod_i d_i}}
这个直接扔给普通BSGS做就好了。

代码

本代码可以通过3243 — Clever Y一题。

STL map版(会TLE)

// Code by KSkun, 2018/4
#include <cstdio>
#include <cmath>

#include <map>

typedef long long LL;

inline char fgc() {
    static char buf[100000], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1++;
}

inline LL readint() {
    register LL res = 0, neg = 1;
    char c = fgc();
    while(c < '0' || c > '9') {
        if(c == '-') neg = -1;
        c = fgc();
    }
    while(c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = fgc();
    }
    return res * neg;
}

inline LL fpow(LL n, LL k, LL p) {
    LL res = 1; n %= p;
    while(k) {
        if(k & 1) res = res * n % p;
        n = n * n % p;
        k >>= 1;
    }
    return res;
}

inline LL exgcd(LL a, LL b, LL &x, LL &y) {
    if(!b) {
        x = 1; y = 0;
        return a;
    }
    LL res = exgcd(b, a % b, x, y);
    LL t = x; x = y; y = t - a / b * y;
    return res;
}

std::map<LL, LL> x;

inline LL bsgs(LL a, LL b, LL p) {
    a %= p; b %= p;
    if(a == 0) return b == 0 ? 1 : -1;
    if(b == 1) return 0;
    LL m = ceil(sqrt(p - 1)), inv, y;
    exgcd(fpow(a, m, p), p, inv, y); inv = (inv % p + p) % p;
    x.clear();
    x[1] = m; // use m instead of 0
    for(LL i = 1, e = 1; i < m; i++) {
        e = e * a % p;
        if(!x[e]) x[e] = i;
    }
    for(LL i = 0; i < m; i++) {
        if(x[b]) {
            LL res = x[b];
            return i * m + (res == m ? 0 : res);
        }
        b = b * inv % p;
    }
    return -1;
}

inline LL gcd(LL a, LL b) {
    if(!b) return a;
    return gcd(b, a % b);
}

inline LL exbsgs(LL a, LL b, LL p) {
    if(b == 1) return 0;
    LL tb = b, tmp = 1, k = 0;
    for(int g = gcd(a, p); g != 1; g = gcd(a, p)) {
        if(tb % g) return -1;
        tb /= g; p /= g; tmp = tmp * a / g % p;
        k++;
        if(tmp == tb) return k;
    }
    return bsgs(a, b, p);
}

LL a, b, p;

int main() {
    for(;;) {
        a = readint(); p = readint(); b = readint();
        if(!a && !b && !p) break;
        LL res = exbsgs(a, b, p);
        if(res != -1) printf("%lld\n", res); else puts("No Solution");
    }
    return 0;
}

HashMap版

// Code by KSkun, 2018/4
#include <cstdio>
#include <cmath>
#include <cstring>

#include <algorithm>

typedef long long LL;

inline char fgc() {
    static char buf[100000], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF 
        : *p1++;
}

inline LL readint() {
    register LL res = 0, neg = 1;
    char c = fgc();
    while(c < '0' || c > '9') {
        if(c == '-') neg = -1;
        c = fgc();
    }
    while(c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = fgc();
    }
    return res * neg;
}

inline LL fpow(LL n, LL k, LL p) {
    LL res = 1; n %= p;
    while(k) {
        if(k & 1) res = res * n % p;
        n = n * n % p;
        k >>= 1;
    }
    return res;
}

inline LL exgcd(LL a, LL b, LL &x, LL &y) {
    if(!b) {
        x = 1; y = 0;
        return a;
    }
    LL res = exgcd(b, a % b, x, y);
    LL t = x; x = y; y = t - a / b * y;
    return res;
}

const int MO = 611977, MAXN = 1000005;

struct HashMap {
    LL head[MO + 5], key[MAXN], value[MAXN], nxt[MAXN], tot;
    inline void clear() {
        tot = 0;
        memset(head, -1, sizeof(head));
    }
    HashMap() {
        clear();
    }
    inline void insert(LL k, LL v) {
        int idx = k % MO;
        for(int i = head[idx]; ~i; i = nxt[i]) {
            if(key[i] == k) {
                value[i] = std::min(value[i], v);
                return;
            }
        }
        key[tot] = k; value[tot] = v; nxt[tot] = head[idx]; head[idx] = tot++;
    }
    inline LL operator[](const LL &k) const {
        int idx = k % MO;
        for(int i = head[idx]; ~i; i = nxt[i]) {
            if(key[i] == k) return value[i];
        }
        return -1;
    }
} x;

inline LL bsgs(LL a, LL b, LL p) {
    a %= p; b %= p;
    if(a == 0) return b == 0 ? 1 : -1;
    if(b == 1) return 0;
    LL m = ceil(sqrt(p - 1)), inv, y;
    exgcd(fpow(a, m, p), p, inv, y); inv = (inv % p + p) % p;
    x.clear();
    x.insert(1, 0);
    for(LL i = 1, e = 1; i < m; i++) {
        e = e * a % p;
        if(x[e] == -1) x.insert(e, i);
    }
    for(LL i = 0; i < m; i++) {
        if(x[b] != -1) {
            LL res = x[b];
            return i * m + res;
        }
        b = b * inv % p;
    }
    return -1;
}

inline LL gcd(LL a, LL b) {
    if(!b) return a;
    return gcd(b, a % b);
}

inline LL exbsgs(LL a, LL b, LL p) {
    if(b == 1) return 0;
    LL tb = b, tmp = 1, k = 0;
    for(int g = gcd(a, p); g != 1; g = gcd(a, p)) {
        if(tb % g) return -1;
        tb /= g; p /= g; tmp = tmp * a / g % p;
        k++;
        if(tmp == tb) return k;
    }
    return bsgs(a, b, p);
}

LL a, b, p;

int main() {
    for(;;) {
        a = readint(); p = readint(); b = readint();
        if(!a && !b && !p) break;
        LL res = exbsgs(a, b, p);
        if(res != -1) printf("%lld\n", res); else puts("No Solution");
    }
    return 0;
}

参考资料