快速数论变换(NTT)原理及实现

快速数论变换(NTT)原理及实现

概述

上次写了一篇狗屎文章快速傅里叶变换(FFT)原理与实现 | KSkun’s Blog,中间的FFT过程使用复数实现,由于使用浮点数存在精度误差。有什么应用于整数没有精度误差的方法吗?就是NTT了。
在NTT中,过程大致与FFT相同,只是用原根来代替单位根的作用。

原根?

你需要的数学姿势是:数学笔记:数论(欧拉函数、阶、原根) | KSkun’s Blog
为什么原根能够代替单位根?因为单位根具有的性质原根也可以具有。
证明参见:快速数论变换 (NTT) – riteme.site

实现?

大致与FFT相同。不同点大致就是要到处取模以及使用逆元操作。

代码?

本代码能够通过【P3803】【模板】多项式乘法(FFT) – 洛谷一题。代码中默认模数为998244353,一个原根为3。

// Code by KSkun, 2018/3
#include <cstdio>

#include <algorithm>

typedef long long LL;

inline char fgc() {
    static char buf[100000], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1++;
}

inline LL readint() {
    register LL res = 0, neg = 1;
    char c = fgc();
    while(c < '0' || c > '9') {
        if(c == '-') neg = -1;
        c = fgc();
    }
    while(c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = fgc();
    }
    return res * neg;
}

const int MAXN = 1 << 22, G = 3, MO = 998244353;

int n, m, len, rev[MAXN];
LL a[MAXN], b[MAXN];

inline LL fpow(LL n, LL k) {
    LL t = 1;
    while(k) {
        if(k & 1) t = (t * n) % MO;
        n = (n * n) % MO;
        k >>= 1;
    }
    return t;
}

inline void ntt(LL *arr, int f) {
    for(int i = 0; i < n; i++) {
        if(i < rev[i]) std::swap(arr[i], arr[rev[i]]);
    }
    for(int i = 1; i < n; i <<= 1) {
        LL gn = fpow(G, (MO - 1) / (i << 1));
        if(f == -1) gn = fpow(gn, MO - 2);
        for(int j = 0; j < n; j += i << 1) {
            LL w = 1;
            for(int k = 0; k < i; k++) {
                LL x = arr[j + k], y = w * arr[j + k + i] % MO;
                arr[j + k] = (x + y) % MO;
                arr[j + k + i] = ((x - y) % MO + MO) % MO;
                w = (w * gn) % MO;
            }
        }
    }
}

int main() {
    n = readint(); m = readint();
    for(int i = 0; i <= n; i++) {
        a[i] = readint();
    }
    for(int i = 0; i <= m; i++) {
        b[i] = readint();
    }
    m += n;
    for(n = 1; n <= m; n <<= 1) len++;
    for(int i = 0; i < n; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
    }
    ntt(a, 1);
    ntt(b, 1);
    for(int i = 0; i <= n; i++) {
        a[i] = (a[i] * b[i]) % MO;
    }
    ntt(a, -1);
    int invn = fpow(n, MO - 2);
    for(int i = 0; i <= m; i++) {
        printf("%lld ", a[i] * invn % MO);
    }
    return 0;
}


发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

This site is protected by reCAPTCHA and the Google Privacy Policy and Terms of Service apply.

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据