5D 调包

调包

时间限制:6s

空间限制:128MB

题目描述

pzr 非常喜欢多项式,在学习了 NTT 之后,他希望通过编程实现一些简单的代数运算:

给定三个数组 \(a_1,a_2...a_n\) , \(b_1,b_2...b_n\), \(c_1, c_2, .... c_n\) 求多项式

  • \((a_1x^2+b_1x+c_1)(a_2x^2+b_2x+c_2)...(a_nx^2+b_nx+c_n)\)
  • 由于各项系数可能过大,各项系数需要对 \(998244353\) 取模。

但是,程序刚写完一半的时候, pzr 发现自己还没有做完 "计算机系统基础" 课程的作业,所以他希望你能帮忙 完善 这个程序。

输入格式

第一行一个正整数 \(n\)

第二行 \(n\) 个整数\(a_i\)

第三行 \(n\) 个整数\(b_i\)

第四行 \(n\) 个整数\(c_i\)

输出格式

请调用下方程序给出的 output 方法进行输出。

样例输入1

2
1 1 
1 1
1 1 

样例输出1

size : 5
data : 1x^4 + 2x^3 + 3x^2 + 2x + 1

样例1解释

\((x^2+x+1)(x^2+x+1) = (x^4 + 2x^3 + 3x^2 + 2x + 1)\)

样例输入2

7
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21

样例输出2

size : 15
data : 5040x^14 + 126756x^13 + 1516396x^12 + 11519767x^11 + 62246114x^10 + 253481256x^9 + 803402740x^8 + 20053360x^7 + 43117588x^6 + 437588950x^5 + 57982840x^4 + 722627136x^3 + 398748401x^2 + 490375814x + 586051200

数据范围

对于 50 % 的数据, \(1\le n\le 1000\)

对于 100 % 的数据, \(1\le n\le 2*10^5,1\le a_i,b_i\le 10^9\)

请注意程序的时间效率。

题目提供的代码及说明

下方代码提供了 多项式类 Poly, 可能需要用到的函数或运算符 包括:

  • 构造函数

    • 函数原型:poly(vector< int > v);
    • 函数功能:传入一个int类型的vector,分别表示多项式的各项系数。将构造对应的Poly对象
    • 时间复杂度:\(\mathcal{O}(n)\),其中 \(n\) 是多项式的长度。
  • 经重载的 * 和 *= 运算符

    • 函数原型:poly& operator*=(const poly& t);  

    poly operator*(const poly& t);

    • 函数功能:对于两个Poly类的对象,返回两个多项式的乘积。各项系数将自动取模。

    • 时间复杂度:\(\mathcal{O}((n+m)\log (n+m))\),其中 \(n,m\) 是参与运算的两个多项式的长度。

  • output函数

    • 函数原型:void output();
    • 函数功能:输出这个多项式,用于本题的判题。
    • 时间复杂度:\(\mathcal{O}(n)\),其中 \(n\) 是多项式的长度。
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;

namespace convolution {

namespace modint {

template < int m > class mint
{
    unsigned int x = 0;

public:
    const int mod = m;
    mint inv() const { return pow(mod - 2); }
    mint pow(long long t) const
    {
        assert(t >= 0 && x > 0); mint res = 1, cur = x;
        for (; t; t >>= 1) { if (t & 1) res *= cur; cur *= cur; }
        return res;
    }
    mint() = default;
    mint(unsigned int t) : x(t) {}
    mint(int t) { t %= mod; if (t < 0) t += mod; x = t; }
    mint(long long t) { t %= mod; if (t < 0) t += mod; x = t; }
    explicit operator int() { return x; }

    mint& operator+=(const mint& t){ x += t.x; if (x >= mod) x -= mod; return *this; }
    mint& operator-=(const mint& t){ x += mod - t.x; if (x >= mod) x -= mod; return *this; }
    mint& operator*=(const mint& t){ x = ( unsigned long long )x * t.x % mod;return *this; }
    mint& operator/=(const mint& t){ *this *= t.inv(); return *this; }
    mint operator+(const mint& t) { return mint(*this) += t; }
    mint operator-(const mint& t) { return mint(*this) -= t; }
    mint operator*(const mint& t) { return mint(*this) *= t; }
    mint operator/(const mint& t) { return mint(*this) /= t; }
    void operator=(const mint& t) { x = t.x; }
    bool operator==(const mint& t) { return x == t.x; }
    bool operator!=(const mint& t) { return x != t.x; }
    bool operator<(const mint& t) { return x < t.x; }
    bool operator<=(const mint& t) { return x <= t.x; }
    bool operator>(const mint& t) { return x > t.x; }
    bool operator>=(const mint& t) { return x >= t.x; }
    friend istream& operator>>(istream& is, mint& t) { return is >> t.x; }
    friend ostream& operator<<(ostream& os, mint& t) { return os << t.x; }
    friend mint operator+(int y, const mint& t) { return mint(y) + t.x; }
    friend mint operator-(int y, const mint& t) { return mint(y) - t.x; }
    friend mint operator*(int y, const mint& t) { return mint(y) * t.x; }
    friend mint operator/(int y, const mint& t) { return mint(y) / t.x; }
};

}  // namespace modint

const double pi = acos(-1);
template < class mint, int g = 3 > class poly
{
    static vector< int > r;
    vector< mint > a;
    void init(int limit)
    {
        if(r.size() < limit) r.resize(limit);
        r[0] = 0;
        int l = 31 - __builtin_clz(limit);
        for (int i = 1; i < limit; i++)
            r[i] = (r[i >> 1] >> 1 | (i & 1) << (l - 1));
    }
    int size() { return a.size(); }
    void resize(int x) { a.resize(x); }
    void ntt(int limit, int type)
    {
        init(limit);
        for (int i = 0; i < limit; i++) {
            if (i < r[i]) swap(a[i], a[r[i]]);
        }
        mint Wn, w;
        for (int mid = 1; mid < limit; mid <<= 1) {
            Wn = mint(g).pow((mint(g).mod - 1) / (mid << 1));
            if (type == -1) Wn = Wn.inv();
            int size = mid << 1;
            for (int i = 0; i < limit; i += size) {
                w = 1;
                int j = i + mid;
                for (int k = i; k < j; k++, w *= Wn) {
                    mint x = a[k], y = w * a[k + mid];
                    a[k] = x + y;
                    a[k + mid] = x - y;
                }
            }
        }
    }

    poly convolution_ntt(poly b)
    {
        int siz = a.size() + b.size();
        int limit = 1;
        while (limit < siz)
            limit <<= 1;
        a.resize(limit);
        b.resize(limit);
        ntt(limit, 1);
        b.ntt(limit, 1);
        for (int i = 0; i < limit; i++)
            a[i] *= b.a[i];
        ntt(limit, -1);
        mint t = mint(limit).inv();
        for (int i = 0; i < limit; i++)
            a[i] *= t;
        a.resize(siz - 1);
        return *this;
    }
public:
    poly(vector< int > v) {for (auto it : v) a.push_back(it);}
    poly& operator*=(const poly& t) {convolution_ntt(t); return *this;}
    poly operator*(const poly& t) { return poly(*this) *= t; }
    void output()
    {
        cout << "size : " << a.size() << '\n';
        cout << "data : ";
        for (int i = a.size() - 1; i >= 0; i--) {
            cout << a[i];
            if (i) cout << "x" << (i > 1 ? "^" + to_string(i) : "") << " + ";
        }
        cout << '\n';
    }
};

}  // namespace convolution

const int mod = 998244353;
using namespace convolution;
using Poly = poly< modint::mint< mod >, 3 >;
template <> vector<int> Poly::r = {}; //= vector<int>{};

使用例:


(请复制上面给出的代码)

int main()
{
    vector< int > v = { 2, 7 };
    Poly p1 = v;                                // 2x + 7
    Poly p2 = vector< int > { 14, 8, 7, 3, 4 };  // 14x^4 + 8x^3 + 7x^2 + 3x + 4;
    Poly p3 = vector< int > { 1, 4, 7, 6 };
    Poly p4 = p1 * p2 * p3;   
    p4.output();
    (p1*p2).output();
    p1 *= p2;
    p1.output();
    return 0;
}

信息

ID
1359
难度
9
分类
(无)
标签
(无)
递交数
11
已通过
3
通过率
27%
上传者

相关

在下列比赛中:

悬赏令第五周