5D 调包
调包
时间限制:6s
空间限制:128MB
题目描述
pzr 非常喜欢多项式,在学习了 NTT 之后,他希望通过编程实现一些简单的代数运算:
给定三个数组 \(a_1,a_2...a_n\) , \(b_1,b_2...b_n\), \(c_1, c_2, .... c_n\) 求多项式
- \((a_1x^2+b_1x+c_1)(a_2x^2+b_2x+c_2)...(a_nx^2+b_nx+c_n)\)
- 由于各项系数可能过大,各项系数需要对 \(998244353\) 取模。
但是,程序刚写完一半的时候, pzr 发现自己还没有做完 "计算机系统基础" 课程的作业,所以他希望你能帮忙 完善 这个程序。
输入格式
第一行一个正整数 \(n\)
第二行 \(n\) 个整数\(a_i\)
第三行 \(n\) 个整数\(b_i\)
第四行 \(n\) 个整数\(c_i\)
输出格式
请调用下方程序给出的 output 方法进行输出。
样例输入1
2
1 1
1 1
1 1
样例输出1
size : 5
data : 1x^4 + 2x^3 + 3x^2 + 2x + 1
样例1解释
\((x^2+x+1)(x^2+x+1) = (x^4 + 2x^3 + 3x^2 + 2x + 1)\)
样例输入2
7
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
样例输出2
size : 15
data : 5040x^14 + 126756x^13 + 1516396x^12 + 11519767x^11 + 62246114x^10 + 253481256x^9 + 803402740x^8 + 20053360x^7 + 43117588x^6 + 437588950x^5 + 57982840x^4 + 722627136x^3 + 398748401x^2 + 490375814x + 586051200
数据范围
对于 50 % 的数据, \(1\le n\le 1000\)
对于 100 % 的数据, \(1\le n\le 2*10^5,1\le a_i,b_i\le 10^9\)
请注意程序的时间效率。
题目提供的代码及说明
下方代码提供了 多项式类 Poly, 可能需要用到的函数或运算符 包括:
构造函数
- 函数原型:
poly(vector< int > v);
- 函数功能:传入一个int类型的vector,分别表示多项式的各项系数。将构造对应的Poly对象
- 时间复杂度:\(\mathcal{O}(n)\),其中 \(n\) 是多项式的长度。
- 函数原型:
经重载的 * 和 *= 运算符
- 函数原型:
poly& operator*=(const poly& t);
poly operator*(const poly& t);
函数功能:对于两个Poly类的对象,返回两个多项式的乘积。各项系数将自动取模。
时间复杂度:\(\mathcal{O}((n+m)\log (n+m))\),其中 \(n,m\) 是参与运算的两个多项式的长度。
- 函数原型:
output函数
- 函数原型:
void output();
- 函数功能:输出这个多项式,用于本题的判题。
- 时间复杂度:\(\mathcal{O}(n)\),其中 \(n\) 是多项式的长度。
- 函数原型:
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
namespace convolution {
namespace modint {
template < int m > class mint
{
unsigned int x = 0;
public:
const int mod = m;
mint inv() const { return pow(mod - 2); }
mint pow(long long t) const
{
assert(t >= 0 && x > 0); mint res = 1, cur = x;
for (; t; t >>= 1) { if (t & 1) res *= cur; cur *= cur; }
return res;
}
mint() = default;
mint(unsigned int t) : x(t) {}
mint(int t) { t %= mod; if (t < 0) t += mod; x = t; }
mint(long long t) { t %= mod; if (t < 0) t += mod; x = t; }
explicit operator int() { return x; }
mint& operator+=(const mint& t){ x += t.x; if (x >= mod) x -= mod; return *this; }
mint& operator-=(const mint& t){ x += mod - t.x; if (x >= mod) x -= mod; return *this; }
mint& operator*=(const mint& t){ x = ( unsigned long long )x * t.x % mod;return *this; }
mint& operator/=(const mint& t){ *this *= t.inv(); return *this; }
mint operator+(const mint& t) { return mint(*this) += t; }
mint operator-(const mint& t) { return mint(*this) -= t; }
mint operator*(const mint& t) { return mint(*this) *= t; }
mint operator/(const mint& t) { return mint(*this) /= t; }
void operator=(const mint& t) { x = t.x; }
bool operator==(const mint& t) { return x == t.x; }
bool operator!=(const mint& t) { return x != t.x; }
bool operator<(const mint& t) { return x < t.x; }
bool operator<=(const mint& t) { return x <= t.x; }
bool operator>(const mint& t) { return x > t.x; }
bool operator>=(const mint& t) { return x >= t.x; }
friend istream& operator>>(istream& is, mint& t) { return is >> t.x; }
friend ostream& operator<<(ostream& os, mint& t) { return os << t.x; }
friend mint operator+(int y, const mint& t) { return mint(y) + t.x; }
friend mint operator-(int y, const mint& t) { return mint(y) - t.x; }
friend mint operator*(int y, const mint& t) { return mint(y) * t.x; }
friend mint operator/(int y, const mint& t) { return mint(y) / t.x; }
};
} // namespace modint
const double pi = acos(-1);
template < class mint, int g = 3 > class poly
{
static vector< int > r;
vector< mint > a;
void init(int limit)
{
if(r.size() < limit) r.resize(limit);
r[0] = 0;
int l = 31 - __builtin_clz(limit);
for (int i = 1; i < limit; i++)
r[i] = (r[i >> 1] >> 1 | (i & 1) << (l - 1));
}
int size() { return a.size(); }
void resize(int x) { a.resize(x); }
void ntt(int limit, int type)
{
init(limit);
for (int i = 0; i < limit; i++) {
if (i < r[i]) swap(a[i], a[r[i]]);
}
mint Wn, w;
for (int mid = 1; mid < limit; mid <<= 1) {
Wn = mint(g).pow((mint(g).mod - 1) / (mid << 1));
if (type == -1) Wn = Wn.inv();
int size = mid << 1;
for (int i = 0; i < limit; i += size) {
w = 1;
int j = i + mid;
for (int k = i; k < j; k++, w *= Wn) {
mint x = a[k], y = w * a[k + mid];
a[k] = x + y;
a[k + mid] = x - y;
}
}
}
}
poly convolution_ntt(poly b)
{
int siz = a.size() + b.size();
int limit = 1;
while (limit < siz)
limit <<= 1;
a.resize(limit);
b.resize(limit);
ntt(limit, 1);
b.ntt(limit, 1);
for (int i = 0; i < limit; i++)
a[i] *= b.a[i];
ntt(limit, -1);
mint t = mint(limit).inv();
for (int i = 0; i < limit; i++)
a[i] *= t;
a.resize(siz - 1);
return *this;
}
public:
poly(vector< int > v) {for (auto it : v) a.push_back(it);}
poly& operator*=(const poly& t) {convolution_ntt(t); return *this;}
poly operator*(const poly& t) { return poly(*this) *= t; }
void output()
{
cout << "size : " << a.size() << '\n';
cout << "data : ";
for (int i = a.size() - 1; i >= 0; i--) {
cout << a[i];
if (i) cout << "x" << (i > 1 ? "^" + to_string(i) : "") << " + ";
}
cout << '\n';
}
};
} // namespace convolution
const int mod = 998244353;
using namespace convolution;
using Poly = poly< modint::mint< mod >, 3 >;
template <> vector<int> Poly::r = {}; //= vector<int>{};
使用例:
(请复制上面给出的代码)
int main()
{
vector< int > v = { 2, 7 };
Poly p1 = v; // 2x + 7
Poly p2 = vector< int > { 14, 8, 7, 3, 4 }; // 14x^4 + 8x^3 + 7x^2 + 3x + 4;
Poly p3 = vector< int > { 1, 4, 7, 6 };
Poly p4 = p1 * p2 * p3;
p4.output();
(p1*p2).output();
p1 *= p2;
p1.output();
return 0;
}