标签: FFT

[洛谷4245]【模板】MTT 题解 & 任意模数卷积算法(MTT)原理与实现

[洛谷4245]【模板】MTT 题解 & 任意模数卷积算法(MTT)原理与实现

题目地址:洛谷:【P4245】【模板】MTT – 洛谷

题目描述

给定一个 n 次多项式 F(x)、一个 m 次多项式 G(x) 以及模数 p,求 F(x)·G(x) 的各项系数对 p 取模的结果。模数 p 任意给定,不保证适合 NTT。

输入输出样例

输入样例#1:

5 8 28
19 32 0 182 99 95
77 54 15 3 98 66 21 20 38

输出样例#1:

7 18 25 19 5 13 12 2 9 22 5 27 6 26

说明

数据范围:n, m ≤ 1e5。

题解

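由于 NTT 要求模数形如 p = a·2^k + 1(且存在原根),而本题模数任意,我们需要使用任意模数卷积算法,也就是俗称的 MTT。具体做法是:把每个系数拆成 32768x + y(0 ≤ y < 32768)的形式(要求系数小于 2^30,这样 x、y 都小于 2^15),于是两个多项式各拆成高、低两部分,共四个序列;对高低部分两两卷积得到四个乘积,各自取模后按 2^30、2^15、2^15、1 的权重拼起来。拆分后每个卷积结果最大约为 n·2^30(约 1e14),远小于 double 能精确表示整数的上限 2^53(约 9e15),因此一般的数据能够满足精度要求。
单位根最好用 cos/sin 直接预处理成表,而不是在蝶形运算中反复相乘递推,这样可以避免误差累积造成的精度问题。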

代码

// Code by KSkun, 2018/3
#include <cstdio>
#include <cmath>

#include <algorithm>

// 64-bit signed integer used for coefficients and intermediate products.
using LL = long long;

// Buffered single-character reader over stdin.
// Refills a 100000-byte static buffer with fread() whenever it is exhausted;
// returns EOF (narrowed to char) once fread() delivers no more bytes.
inline char fgc() {
    static char buf[100000];
    static char *p1 = buf;
    static char *p2 = buf;
    if (p1 == p2) {
        // Buffer drained: rewind the read cursor and pull the next chunk.
        p1 = buf;
        p2 = buf + fread(buf, 1, 100000, stdin);
        if (p1 == p2) {
            return EOF;
        }
    }
    return *p1++;
}

// Reads the next decimal integer from stdin through fgc().
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. If the input ends before a digit is seen, 0 is
// returned instead of looping forever.
inline LL readint() {
    // fgc() returns EOF narrowed to char, so compare against the same
    // narrowed value; this works whether plain char is signed or unsigned.
    const char eof = static_cast<char>(EOF);
    LL res = 0, neg = 1;
    char c = fgc();
    while (c < '0' || c > '9') {
        if (c == eof) return 0;
        if (c == '-') neg = -1;
        c = fgc();
    }
    while (c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = fgc();
    }
    return res * neg;
}

// Capacity of every global work array (2^20); must exceed the padded FFT length.
const int MAXN = 1048576;
// Pi, obtained as arccos(-1).
const double PI = std::acos(-1.0);
// Modulus read from the input; every output coefficient is reduced by it.
int MO;

// Minimal double-precision complex number used by the FFT.
// fft()/main() need only +, -, * and division by a real scalar; the other
// operators are general-purpose helpers.
struct Complex {
    double real, imag;
    Complex(double real = 0, double imag = 0) : real(real), imag(imag) {}
    // Complex conjugate: re - i*im.
    inline Complex conj() const {
        return Complex(real, -imag);
    }
    // Squared modulus |z|^2 = re^2 + im^2 (denominator of complex division).
    inline double norm() const {
        return real * real + imag * imag;
    }
    inline Complex operator+(Complex rhs) const {
        return Complex(real + rhs.real, imag + rhs.imag);
    }
    inline Complex operator-(Complex rhs) const {
        return Complex(real - rhs.real, imag - rhs.imag);
    }
    inline Complex operator*(Complex rhs) const {
        return Complex(real * rhs.real - imag * rhs.imag, imag * rhs.real + real * rhs.imag);
    }
    inline Complex& operator*=(Complex rhs) {
        return (*this) = (*this) * rhs;
    }
    // Real scalar times complex.
    friend inline Complex operator*(double x, Complex cp) {
        return Complex(x * cp.real, x * cp.imag);
    }
    inline Complex operator/(double x) const {
        return Complex(real / x, imag / x);
    }
    inline Complex& operator/=(double x) {
        return (*this) = (*this) / x;
    }
    // x / z = x * conj(z) / |z|^2.
    friend inline Complex operator/(double x, Complex cp) {
        return x * cp.conj() / cp.norm();
    }
    // w / z = w * conj(z) / |z|^2.
    inline Complex operator/(Complex rhs) const {
        return (*this) * rhs.conj() / rhs.norm();
    }
    inline Complex& operator/=(Complex rhs) {
        return (*this) = (*this) / rhs;
    }
    // Modulus |z|.
    inline double length() const {
        return sqrt(real * real + imag * imag);
    }
};

// n: first polynomial's degree, later reused as the FFT length.
// m: second polynomial's degree, later the product's degree.
// len: log2 of the FFT length.
int n, m, len;
// Bit-reversal permutation table for the FFT.
int rev[MAXN];
// x, y: input coefficients; z: product coefficients modulo MO.
LL x[MAXN], y[MAXN], z[MAXN];
// a, b: high/low 15-bit halves of x; c, d: high/low halves of y.
Complex a[MAXN], b[MAXN], c[MAXN], d[MAXN];
// e, f, g, h: pairwise products of the halves (transformed back by FFT).
Complex e[MAXN], f[MAXN], g[MAXN], h[MAXN];
// Precomputed roots of unity.
Complex wn[MAXN];

// In-place iterative radix-2 FFT over the first n entries of arr.
// f == 1 uses the roots in wn as stored; f == -1 negates their imaginary
// parts and divides the result by n (inverse transform).
// Relies on the globals n (power-of-two length), rev (bit-reversal table)
// and wn (root table indexed as wn[(n / half) * k]).
inline void fft(Complex *arr, int f) {
    // Reorder into bit-reversed positions; swap each pair only once.
    for (int idx = 0; idx < n; ++idx)
        if (idx < rev[idx]) std::swap(arr[idx], arr[rev[idx]]);

    for (int half = 1; half < n; half <<= 1) {
        const int stride = n / half;
        for (int base = 0; base < n; base += half << 1) {
            for (int k = 0; k < half; ++k) {
                const Complex &root = wn[stride * k];
                Complex tw(root.real, f * root.imag);
                Complex lo = arr[base + k];
                Complex hi = tw * arr[base + k + half];
                arr[base + k] = lo + hi;
                arr[base + k + half] = lo - hi;
            }
        }
    }

    if (f != -1) return;
    for (int idx = 0; idx < n; ++idx) arr[idx] /= n;
}

int main() {
    n = readint(); m = readint(); MO = readint();
    for(int i = 0; i <= n; i++) {
        x[i] = readint() % MO;
    }
    for(int i = 0; i <= m; i++) {
        y[i] = readint() % MO;
    }
    m += n;
    for(n = 1; n <= m; n <<= 1) len++;
    for(int i = 0; i < n; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
    }
    for(int i = 1; i < n; i <<= 1) {
        for(int k = 0; k < i; k++) {
            wn[n / i * k] = Complex {cos(PI * k / i), sin(PI * k / i)};
        }
    }
    for(int i = 0; i < n; i++) {
        a[i].real = x[i] >> 15; b[i].real = x[i] & 0x7fff;
        c[i].real = y[i] >> 15; d[i].real = y[i] & 0x7fff;
    }
    fft(a, 1); fft(b, 1); fft(c, 1); fft(d, 1);
    for(int i = 0; i < n; i++) {
        e[i] = a[i] * c[i]; f[i] = b[i] * c[i];
        g[i] = a[i] * d[i]; h[i] = b[i] * d[i];
    }
    fft(e, -1); fft(f, -1); fft(g, -1); fft(h, -1);
    for(int i = 0; i < n; i++) {
        z[i] = (((LL(round(e[i].real)) % MO) << 30) % MO 
            + ((LL(round(f[i].real)) % MO) << 15) % MO
            + ((LL(round(g[i].real)) % MO) << 15) % MO 
            + LL(round(h[i].real)) % MO) % MO;
    }
    for(int i = 0; i <= m; i++) {
        printf("%lld ", z[i]);
    }
    return 0;
}