玄学优化&其他算法

个人用到的一些优化/调试技巧,和一些杂七杂八的算法

玄学启发

  • 指数级别:dfs + 剪枝,状态压缩 dp
  • $n \le 100 \rightarrow O(n^3)$:floyd,dp,高斯消元
  • $n \le 1000 \rightarrow O(n^2),\ O(n^2\log n)$:dp,二分,朴素版 Dijkstra、朴素版 Prim、Bellman-Ford
  • $n \le 10^4 \rightarrow O(n\sqrt{n})$:块状链表、分块、莫队
  • $n \le 10^5 \rightarrow O(n\log n)$:各种 sort,线段树、树状数组、set/map、heap、拓扑排序、dijkstra+heap、prim+heap、spfa、求凸包、求半平面交、二分、CDQ 分治、整体二分
  • $n \le 10^6 \rightarrow O(n)$,以及常数较小的 $O(n\log n)$ 算法:单调队列、hash、双指针扫描、并查集、kmp、AC 自动机;常数较小的 $O(n\log n)$ 做法:sort、树状数组、heap、dijkstra、spfa
  • $n \le 10^7 \rightarrow O(n)$:双指针扫描、kmp、AC 自动机、线性筛素数
  • $n \le 10^9 \rightarrow O(\sqrt{n})$:判断质数
  • $n \le 10^{18} \rightarrow O(\log n)$:最大公约数,快速幂
  • $n \le 10^{1000} \rightarrow O((\log n)^2)$:高精度加减乘除
  • $n \le 10^{100000} \rightarrow O(k\log k)$($k$ 表示位数):高精度加减、FFT/NTT

快读快写

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// Reads one signed decimal int from stdin via getchar().
// Characters before the first digit are skipped; a '-' seen while skipping
// makes the result negative. If EOF is reached before any digit, returns 0
// instead of spinning forever.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();  // int, not char: it must be able to hold EOF
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Writes x to stdout in decimal (no separator) using putchar().
// The magnitude is computed in unsigned arithmetic, so INT_MIN does not
// overflow when negated.
inline void print(int x) {
    unsigned int u = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        u = 0u - u;  // well-defined modular negation, also correct for INT_MIN
    }
    char digits[3 * sizeof(unsigned int)];  // 3 decimal digits per byte is always enough
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while (u != 0);
    while (len > 0) putchar(digits[--len]);  // most significant digit first
}

fread快读

1
2
3
4
5
6
7
8
// fread-based input: `buf` is refilled in chunks; p1 is the read cursor and p2
// the end of valid data. obuf/O are an output buffer for code that fills and
// flushes it itself (nothing in this snippet writes to it).
char buf[1<<23],*p1=buf,*p2=buf,obuf[1<<23],*O=obuf;
// Shadows the stdio getchar() for all code after this line. Each refill reads
// up to sizeof(buf) bytes, so the whole buffer is used. Returns EOF at end of
// input. Bytes are widened through unsigned char so a 0xFF byte is not
// mistaken for EOF.
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,sizeof(buf),stdin),p1==p2)?EOF:static_cast<unsigned char>(*p1++))

// Reads one signed decimal int using the buffered getchar() above.
// Returns 0 if EOF is reached before any digit.
inline int read() {
    int x = 0, f = 1;
    int ch = getchar();  // int so that EOF is representable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

inline一定要写,一定要写,一定要写!!!

__int128

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// Reads one signed decimal integer into a GCC/Clang __int128.
// Non-digits before the number are skipped ('-' makes it negative).
// Stops at EOF instead of looping forever when no digit follows.
inline __int128 read() {
    __int128 x = 0, f = 1;
    int ch = getchar();  // int, not char, so EOF can be detected
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Writes a __int128 to stdout in decimal (iostream/printf cannot print it).
// The magnitude is computed as unsigned __int128, so the most negative value
// does not overflow on negation.
inline void print(__int128 x) {
    unsigned __int128 u = static_cast<unsigned __int128>(x);
    if (x < 0) {
        putchar('-');
        u = 0 - u;  // modular negation of the unsigned value
    }
    char digits[40];  // 2^128 - 1 has 39 decimal digits
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + static_cast<int>(u % 10));
        u /= 10;
    } while (u != 0);
    while (len > 0) putchar(digits[--len]);  // most significant digit first
}

O2/O3优化

1
2
3
4
// Ask GCC to compile the functions that follow as if with -O2 / -O3
// (a numeric argument N means -ON). GCC's manual says the optimize
// pragma/attribute is meant for debugging and is not suitable for production code.
// Only the "GCC" pragma namespace exists: "#pragma G++ optimize(...)" is not
// recognized and is ignored (with a warning under -Wall), so it is omitted.
#pragma GCC optimize(2)
#pragma GCC optimize(3)

时间复杂度测试

1
2
3
4
clock_t start, finish;
start = clock();
// ... code being measured ...
finish = clock();
// clock() returns processor-time ticks, not nanoseconds: divide by
// CLOCKS_PER_SEC for seconds (scaled by 1000 here to print milliseconds).
cout << 1000.0 * (finish - start) / CLOCKS_PER_SEC << "ms" << endl;

离散化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
vector<int> num;  // the sorted, de-duplicated values being compressed

// Returns the 1-based rank of x in `num`: one plus the index of the first
// element that is not less than x. `num` must already be sorted.
int findx(int x) {
    auto pos = lower_bound(num.begin(), num.end(), x);
    return static_cast<int>(pos - num.begin()) + 1;
}

// Read a[1..n] and collect every value in `num`.
for (int i = 1; i <= n; i++) {
    cin >> a[i];
    num.push_back(a[i]);
}
sort(num.begin(), num.end());
// Remove duplicates from `num` itself; findx relies on `num` being sorted and unique.
num.erase(unique(num.begin(), num.end()), num.end());
// Remap every element of the 1-indexed array a[1..n]. num.size() is the count
// of DISTINCT values, so it is not the right loop bound here.
for (int i = 1; i <= n; i++)
    a[i] = findx(a[i]); // 1-based rank of a[i] among the distinct values

整数二分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// Placeholder predicate for the integer binary searches below: return true iff
// x has the property being searched for.
// NOTE(review): the body must `return` a bool. As written, calling it falls
// off the end of a non-void function, which is undefined behavior.
bool check(int x) {/* ... */} // checks whether x satisfies the property

// Use when [l, r] is split into [l, mid] and [mid + 1, r]:
// assuming check() is false...false, true...true on [l, r], returns the
// smallest x with check(x) true (returns r if check is false everywhere).
int bsearch_1(int l, int r) {
    while (l < r) {
        // Midpoint in 64-bit: `l + r` overflows int when both bounds are large
        // (e.g. ~1e9). The shift keeps floor-division semantics for negatives.
        int mid = static_cast<int>((static_cast<long long>(l) + r) >> 1);
        if (check(mid)) r = mid;
        else l = mid + 1;
    }
    return l;
}
// Use when [l, r] is split into [l, mid - 1] and [mid, r]:
// assuming check() is true...true, false...false on [l, r], returns the
// largest x with check(x) true (returns l if check is false everywhere).
int bsearch_2(int l, int r) {
    while (l < r) {
        // Round the midpoint up so that `l = mid` always makes progress;
        // computed in 64-bit so `l + r + 1` cannot overflow int.
        int mid = static_cast<int>((static_cast<long long>(l) + r + 1) >> 1);
        if (check(mid)) l = mid;
        else r = mid - 1;
    }
    return l;
}

浮点数二分

注意:不要用 eps 作为结束条件,这样可能会出锅。当 l、r 的数值较大时,相邻两个 double 之间的间距可能已经大于 eps,r - l 永远降不到 eps 以下,循环就停不下来。应该把结束条件设定为二分的次数,比如 200(根据题目的时间复杂度而设定),只要不 TLE,次数越大越好;每二分一次就把当前次数加一。

1
2
3
4
5
6
7
8
9
10
11
// Placeholder predicate for the floating-point binary search below: return
// true iff x has the property being searched for.
// NOTE(review): the body must `return` a bool. As written, calling it falls
// off the end of a non-void function, which is undefined behavior.
bool check(double x) {/* ... */} // checks whether x satisfies the property

// Floating-point binary search on [l, r], assuming check() is false on the
// left part and true on the right part. Converges to the boundary and returns l.
//
// The search runs a fixed number of halvings instead of `while (r - l > eps)`.
// When l and r are large, adjacent doubles can be farther apart than eps, so
// r - l may never fall below eps and the eps loop never ends.
// Each iteration halves the interval; the default 100 iterations shrink it by
// 2^100 (about 1e30). Raise `iterations` if the precision requirement and
// time limit allow.
double bsearch_3(double l, double r, int iterations = 100) {
    for (int iter = 0; iter < iterations; ++iter) {
        double mid = (l + r) / 2;
        if (check(mid)) r = mid;
        else l = mid;
    }
    return l;
}

高精度加法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
int base = 10; // radix of each stored digit

// Returns A + B. Operands and result are little-endian digit vectors
// (index 0 = least significant digit), each digit in [0, base).
// Takes const references so temporaries can be passed.
vector<int> add(const vector<int> &A, const vector<int> &B) {
    if (A.size() < B.size())
        return add(B, A);  // make A the longer operand
    vector<int> res;
    res.reserve(A.size() + 1);
    int t = 0;  // running sum / carry
    for (size_t i = 0; i < A.size(); ++i) {
        t += A[i];
        if (i < B.size())
            t += B[i];
        res.push_back(t % base);
        t /= base;
    }
    if (t)
        res.push_back(t);  // final carry
    return res;
}

// Reads two non-negative decimal integers and prints their sum.
void solve() {
    string a, b;
    cin >> a >> b;
    // Little-endian digit vectors: walk each string from its last character.
    vector<int> A, B;
    for (auto it = a.rbegin(); it != a.rend(); ++it) A.push_back(*it - '0');
    for (auto it = b.rbegin(); it != b.rend(); ++it) B.push_back(*it - '0');
    vector<int> res = add(A, B);
    // Most significant digit first; the final digit is followed by a newline.
    for (auto it = res.rbegin(); it != res.rend(); ++it)
        printf(it + 1 == res.rend() ? "%d\n" : "%d", *it);
}

高精度减法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
int base = 10; // radix of each stored digit

// Returns true iff A >= B. Digits are little-endian and assumed to have no
// high-order zeros, so a longer vector is a larger number.
bool cmp(const vector<int> &A, const vector<int> &B) {
    if (A.size() != B.size())
        return A.size() > B.size();
    for (size_t i = A.size(); i-- > 0;)  // from the most significant digit
        if (A[i] != B[i])
            return A[i] > B[i];
    return true;  // equal
}

// Returns A - B. Requires A >= B (test with cmp first). High-order zeros of
// the result are stripped, but at least one digit is kept.
vector<int> sub(const vector<int> &A, const vector<int> &B) {
    vector<int> res;
    res.reserve(A.size());
    int t = 0;  // digit difference; -1 carries a borrow into the next digit
    for (size_t i = 0; i < A.size(); ++i) {
        t += A[i];
        if (i < B.size())
            t -= B[i];
        res.push_back((t + base) % base);
        t = (t < 0) ? -1 : 0;
    }
    while (res.size() > 1 && res.back() == 0)
        res.pop_back();
    return res;
}

// Reads two non-negative decimal integers a and b and prints a - b
// (with a leading '-' when a < b).
void solve() {
    string a, b;
    cin >> a >> b;
    vector<int> A, B;
    for (auto it = a.rbegin(); it != a.rend(); ++it) A.push_back(*it - '0');
    for (auto it = b.rbegin(); it != b.rend(); ++it) B.push_back(*it - '0');
    const bool nonNegative = cmp(A, B);
    if (!nonNegative) putchar('-');
    vector<int> res = nonNegative ? sub(A, B) : sub(B, A);
    // Most significant digit first; the last one is followed by a newline.
    for (auto it = res.rbegin(); it != res.rend(); ++it)
        printf(it + 1 == res.rend() ? "%d\n" : "%d", *it);
}

高精度乘低精度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int base = 10;

// Returns A * b. A is a little-endian digit vector; b is assumed non-negative.
// The carry is kept in a long long: with an int carry, A[i] * b overflows once
// b exceeds about INT_MAX / base.
vector<int> mul(const vector<int> &A, int b) {
    vector<int> res;
    long long t = 0;  // carry
    for (size_t i = 0; i < A.size() || t; ++i) {  // logical ||: digits left OR carry left
        if (i < A.size())
            t += static_cast<long long>(A[i]) * b;
        res.push_back(static_cast<int>(t % base));
        t /= base;
    }
    while (res.size() > 1 && res.back() == 0)  // e.g. b == 0
        res.pop_back();
    return res;
}

// Reads a decimal string a and an int b, then prints a * b.
void solve() {
    string a;
    int b;
    cin >> a >> b;
    vector<int> A;
    for (auto it = a.rbegin(); it != a.rend(); ++it) A.push_back(*it - '0');
    const vector<int> res = mul(A, b);
    // Most significant digit first; the last one is followed by a newline.
    for (auto it = res.rbegin(); it != res.rend(); ++it)
        printf(it + 1 == res.rend() ? "%d\n" : "%d", *it);
}

高精度除以低精度(输出商和余数)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
int base = 10;

// Long division: returns the quotient of A / b and writes the remainder to r,
// so that A == quotient * b + r. A and the quotient are little-endian digit
// vectors; b must be positive. r is a pure output and is overwritten.
vector<int> div(const vector<int> &A, int b, int &r) {
    vector<int> res;
    res.reserve(A.size());
    // 64-bit partial remainder: cur * base overflows int once b > INT_MAX / base.
    long long cur = 0;
    for (size_t i = A.size(); i-- > 0;) {  // most significant digit first
        cur = cur * base + A[i];
        res.push_back(static_cast<int>(cur / b));
        cur %= b;
    }
    r = static_cast<int>(cur);  // cur < b, so it fits in int
    reverse(res.begin(), res.end());  // quotient was produced most-significant first
    while (res.size() > 1 && res.back() == 0)
        res.pop_back();
    return res;
}

// Reads a decimal string a and an int b, then prints a / b on one line and
// the remainder on the next.
void solve() {
    string a;
    int b;
    int r = 0;
    vector<int> A;
    cin >> a >> b;
    for (auto it = a.rbegin(); it != a.rend(); ++it) A.push_back(*it - '0');
    const vector<int> res = div(A, b, r);
    for (auto it = res.rbegin(); it != res.rend(); ++it) printf("%d", *it);
    printf("\n%d\n", r);
}

二维前缀和

1
2
3
S[i, j] = 第i行j列格子左上部分所有元素的和
以(x1, y1)为左上角,(x2, y2)为右下角的子矩阵的和为:
S[x2, y2] - S[x1 - 1, y2] - S[x2, y1 - 1] + S[x1 - 1, y1 - 1]

一维差分

1
给区间[l, r]中的每个数加上c:B[l] += c, B[r + 1] -= c

二维差分

每次将某个子矩阵中所有元素的值加上 $c$,输出所有操作完成后的矩阵。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
int n, m, q;
const int Max_n = 1000 + 10;
int a[Max_n][Max_n], sub[Max_n][Max_n], pre[Max_n][Max_n];

void insert(int a1, int b1, int a2, int b2, int val) {
sub[a1][b1] += val;
sub[a2 + 1][b1] -= val;
sub[a1][b2 + 1] -= val;
sub[a2 + 1][b2 + 1] += val;
}

// Reads an n x m matrix and q rectangle updates "add val to (a1,b1)-(a2,b2)",
// then prints the resulting matrix row by row.
void solve() {
    n = read();
    m = read();
    q = read();
    // Seed the difference array with the initial matrix, one 1x1 update per cell.
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            a[i][j] = read();
            insert(i, j, i, j, a[i][j]);
        }
    }
    for (int k = 1; k <= q; ++k) {
        const int a1 = read();
        const int b1 = read();
        const int a2 = read();
        const int b2 = read();
        const int val = read();
        insert(a1, b1, a2, b2, val);
    }
    // A 2-D prefix sum over the difference array recovers the final values.
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j)
            pre[i][j] = pre[i - 1][j] + pre[i][j - 1] - pre[i - 1][j - 1] + sub[i][j];
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j)
            printf("%d%c", pre[i][j], j == m ? '\n' : ' ');
}

位运算

1
2
求n的第k位数字 : n >> k & 1
返回 n 的最后一位 1 所代表的值(lowbit):lowbit(n) = n & -n; 例如 lowbit(12) = 4

双指针

1
2
3
4
5
6
7
8
for (int i = 0, j = 0; i < n; i ++ ) {
while (j < i && check(i, j)) j ++ ;

// 具体问题的逻辑
}
常见问题分类:
(1) 对于一个序列,用两个指针维护一段区间
(2) 对于两个序列,维护某种次序,比如归并排序中合并两个有序序列的操作

ST表求区间最值

本质上是运用了倍增思想的 DP,定义 $f[i][j]$ 表示从 $i$ 开始、长度为 $2^j$ 的区间中的最大值。

$f[i][j] = \max(f[i][j-1],\ f[i + 2^{j-1}][j-1])$

优点:代码短,速度快

缺点:只能静态维护

预处理时间复杂度:$O(n\log n)$,查询时间复杂度:$O(1)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
int dp[MAXN][32]; // dp[i][j]: maximum of the 2^j elements a[i .. i + 2^j - 1]
int n; // sequence length
int a[MAXN]; // the sequence, 1-indexed
int lg[MAXN]; // lg[i] = floor(log2(i)), i.e. index of the highest set bit of i
int m; // number of queries

// Fills lg[1..n] with integer floor(log2(i)) via lg[i] = lg[i / 2] + 1,
// which avoids floating-point log2.
void pre() {
    lg[1] = 0;
    for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
}

// Builds the sparse table. Only intervals inside [1, n] are computed
// (i + 2^j - 1 <= n), so every index stays <= n and dp needs no extra rows.
void init() {
    for (int i = 1; i <= n; i++) dp[i][0] = a[i];
    for (int j = 1; j <= lg[n]; j++)
        for (int i = 1; i + (1 << j) - 1 <= n; i++)
            dp[i][j] = max(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
}

// Reads n, m and the sequence, then answers m range-maximum queries [l, r].
void solve() {
    n = read(), m = read();
    for (int i = 1; i <= n; i++) a[i] = read();
    pre();
    init();
    while (m--) {
        int l = read();  // query: maximum of a[l..r]
        int r = read();
        int k = lg[r - l + 1];  // largest k with 2^k <= r - l + 1
        // The two length-2^k windows starting at l and ending at r cover [l, r].
        printf("%d\n", max(dp[l][k], dp[r - (1 << k) + 1][k]));
    }
}

01分数规划

求解例如:从若干个数中选出 $k$ 个,求 $(\max/\min)\ res = \dfrac{\sum a_i \times x_i}{\sum b_i \times x_i}$,其中 $x_i \in \{0, 1\}$ 表示是否选第 $i$ 个数。

化简一下式子得 $y = \sum (a_i - res \times b_i) \times x_i = 0$

二分 $res$ 的值:按 $a_i - res \times b_i$ 排序,取 $k$ 个值(求最大值时取最大的 $k$ 个)之后判断它们的和与 $0$ 的大小关系(假设 $b_i > 0$)。若求最大值,和大于 $0$ 说明 $res$ 还能再取大;和小于 $0$ 说明 $res$ 取大了,应当减小。

GCC内置位运算函数__builtin

判断 $n$ 的二进制中有多少个 $1$

__builtin_popcount(unsigned int n)

1
2
int n = 15; // binary 1111: four 1-bits
cout << __builtin_popcount(n) // number of 1-bits in (unsigned)n
     << endl;                 // prints 4

判断 $n$ 的二进制中 $1$ 的个数的奇偶性(奇数个返回 1,偶数个返回 0)

__builtin_parity(unsigned int n)

1
2
3
4
int n = 15; // binary 1111: four 1-bits
int m = 7;  // binary 111: three 1-bits
cout << __builtin_parity(n) << endl  // even count -> prints 0
     << __builtin_parity(m) << endl; // odd count  -> prints 1

返回 $n$ 的二进制中最低位的 $1$ 的位置,从 $1$ 开始计数($n = 0$ 时返回 $0$)

__builtin_ffs(int x)

1
2
3
4
int n = 1; // binary 1
int m = 8; // binary 1000
cout << __builtin_ffs(n) << endl  // lowest 1-bit is bit 1 (1-based) -> prints 1
     << __builtin_ffs(m) << endl; // lowest 1-bit is bit 4 (1-based) -> prints 4

返回 $n$ 的二进制末尾连续 $0$ 的个数;$n$ 为 $0$ 时结果未定义,需要特判

__builtin_ctz(unsigned int n)

1
2
3
4
int n = 1; // binary 1
int m = 8; // binary 1000
// __builtin_ctz takes unsigned int (use __builtin_ctzll for unsigned long long).
// __builtin_ctz(0) is undefined, so zero inputs must be handled separately.
cout << __builtin_ctz(n) << endl; // prints 0
cout << __builtin_ctz(m) << endl; // prints 3

返回前导 $0$ 的个数($x$ 为 $0$ 时结果未定义)

__builtin_clz (unsigned int x)

可以顺便求出最高位的 $1$ 的位置:对 32 位的 unsigned int,最高位 1 的下标(从 0 开始)为 31 - __builtin_clz(x)

整数三分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// Integer ternary search on [l, r], assuming f is strictly unimodal there.
// Keep exactly ONE of the two update rules below; running both in sequence
// makes the second overwrite the first.
int l = 1, r = 100;
while (l < r) {
    int lmid = l + (r - l) / 3;
    int rmid = r - (r - l) / 3;  // lmid < rmid whenever l < r
    lans = f(lmid), rans = f(rmid);
    // Minimum of a valley-shaped (凹) function:
    if (lans <= rans) r = rmid - 1;
    else l = lmid + 1;
    // For the maximum of a hill-shaped (凸) function use instead:
    //   if (lans <= rans) l = lmid + 1; else r = rmid - 1;
}
// The loop ends with l == r at the extremum. Evaluate f there instead of
// reusing lans/rans, which are stale (and never set if the loop did not run).
cout << f(l) << endl;

浮点数三分

注意:不要用 eps 作为结束条件,这样可能会出锅。当 l、r 的数值较大时,相邻两个 double 之间的间距可能已经大于 eps,循环可能停不下来。应该把结束条件设定为三分的次数,比如 200(根据题目的时间复杂度而设定),只要不 TLE,次数越大越好;每三分一次就把当前次数加一。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
const double eps = 1e-8;
while (r - l < eps) {
double lmid = l + (r - l) / 3;
double rmid = r - (r - l) / 3;
lans = f(lmid), rans = f(rmid);
//求凹函数的极小值
if (lans <= rans) r = rmid;
else l = lmid;
//求凸函数的极大值
if (lans <= rans) l = lmid;
else r = rmid;
}
//输出l和r都行
cout << l << endl;

KMP

$next$ 数组:$next[i]$ 表示字符串 $s$ 中以 $s[i]$ 为结尾的子串与 $s$ 的前缀相同的最大长度(不含 $s[1..i]$ 本身,即最长相同真前后缀的长度)

例如 $s = abbaabba$ 时,$next[8] = 4$(下标从 1 开始,最长相同前后缀为 abba)

已知 $next$ 数组能够求出最小循环节:长度为 $len$ 的串的最小循环节长度为 $len - next[len]$。准确地说,求出的是 $[begin, j]$($begin < j \le n$)这一段的最小循环节。若需要保证完全循环,即最小循环节长度整除子串长度,需要特判一下。

下标以 $1$ 开始

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// s[] is the text and p[] the pattern, both stored 1-indexed (s[1..n], p[1..m]);
// n is the length of s and m the length of p.
int n, m;
int ne[N];  // ne[i]: length of the longest proper prefix of p[1..i] that is also its suffix
char s[M], p[N];
// Build the next (failure) array of the pattern:
for (int i = 2, j = 0; i <= m; i++) {
    while (j && p[i] != p[j + 1]) j = ne[j];
    if (p[i] == p[j + 1]) j++;
    ne[i] = j;
}

// Matching:
for (int i = 1, j = 0; i <= n; i++) {
    while (j && s[i] != p[j + 1]) j = ne[j];
    if (s[i] == p[j + 1]) j++;
    if (j == m) {
        // p matches s[i - m + 1 .. i]; fall back via ne so overlapping matches are found.
        j = ne[j];
        // handle the match here
    }
}

下标以 $0$ 开始

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int n, m;
char s[N], p[N];
int ne[N];

// Reads pattern length m, pattern p, text length n and text s (all 0-indexed),
// then prints the 0-based start of every occurrence of p in s, each followed
// by a space.
void solve() {
    cin >> m >> p >> n >> s;
    // ne[i]: index of the last character of the longest proper border of
    // p[0..i], or -1 if there is none.
    ne[0] = -1;
    int j = -1;
    for (int i = 1; i < m; ++i) {
        while (j != -1 && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) ++j;
        ne[i] = j;
    }

    j = -1;
    for (int i = 0; i < n; ++i) {
        while (j != -1 && s[i] != p[j + 1]) j = ne[j];
        if (s[i] == p[j + 1]) ++j;
        if (j + 1 == m) {  // the whole pattern is matched, ending at s[i]
            cout << i - j << ' ';
            j = ne[j];
        }
    }
}

从任意一位作为起始位求 $next$ 数组(下标从 0 开始)

题目是把一个字符串压缩,压缩方法如下:

若字符串为 ababab 可以压缩成3ab,也可以压缩1a1b1a1b1a1b

问压缩完后的字符串长度最小的方案。

定义十进制下循环次数的位数为 $cnt$,最小循环节长度为 $len$。若当前这段子串恰好由若干个完整的最小循环节组成,转移方程为

$dp[j+i] = \min(dp[j+i],\ dp[i] + cnt + len)$

否则的话,只能把整段当作循环 1 次来写(相当于没有压缩),转移方程就变成了

$dp[j+i] = \min(dp[j+i],\ dp[i] + j + 1)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
char s[MAXN];
int n;
int decimal[MAXN]; // decimal[v]: number of decimal digits of v
ll dp[MAXN];       // dp[i]: minimum compressed length of the first i characters
int nxt[MAXN];     // nxt[k]: longest proper border length of the current string's length-k prefix

// Failure function of the len-character string s (0-indexed): fills
// nxt[0..len], where nxt[k] is the length of the longest proper prefix of
// s[0..k-1] that is also a suffix of it.
void kmp(int len, char *s) {
    nxt[0] = 0;
    nxt[1] = 0;
    for (int i = 1; i < len; ++i) {
        int k = nxt[i];
        while (k != 0 && s[i] != s[k]) k = nxt[k];
        nxt[i + 1] = (s[k] == s[i]) ? k + 1 : 0;
    }
}

// Precomputes decimal[v] (digit count) for v in [0, 8000] and initializes
// dp[i] = i + 1 for i in [1, n], matching the "count 1 + the i characters"
// transition used in solve.
void prework() {
    for (int v = 0; v <= 8000; v++)
        decimal[v] = v < 10 ? 1 : v < 100 ? 2 : v < 1000 ? 3 : 4;
    for (int i = 1; i <= n; i++) dp[i] = i + 1;
}

// Reads string s and prints the minimum length after compressing it into
// "count + block" pieces.
void solve() {
    scanf("%s", s);
    n = strlen(s);
    prework();
    for (int i = 0; i < n; i++) {
        kmp(n - i, s + i);  // borders of every prefix of the suffix s[i..]
        for (int j = 1; j + i <= n; j++) {
            // Segment s[i .. i + j - 1] has length j; len is its minimal period.
            int len = j - nxt[j];
            if (j % len) {
                // Not an exact repetition: write it once as "1" + the j characters.
                dp[j + i] = min(dp[j + i], dp[i] + 1 + j);
            }
            else {
                // (j / len) copies of a length-len block: digits of the count + the block.
                dp[j + i] = min(dp[j + i], dp[i] + decimal[j / len] + len);
            }
        }
    }
    // dp holds ll values; printing them with "%d" was undefined behavior.
    // NOTE(review): assumes ll is a typedef for long long.
    printf("%lld\n", dp[n]);
}

cin读入优化

1
2
3
4
5
// Speed up iostreams: stop synchronizing with C stdio and untie cin from cout
// (so reading cin no longer flushes cout). cout itself is not tied to any
// stream by default, so no cout.tie call is needed.
// After sync_with_stdio(false), do not mix scanf/printf/getchar with cin/cout:
// the two libraries then buffer independently and output can come out of order.
ios::sync_with_stdio(false);
cin.tie(nullptr);
getline(cin, str);
// A getline right after `cin >> x` returns the rest of that line (usually
// empty). Skip the leftover newline first with cin.ignore() or `cin >> ws`.

重载运算符

1
2
3
4
5
6
7
8
9
// Ordering used by sort/set: larger num comes first; ties go to the smaller id.
// (A std::priority_queue pops its greatest element, so with this ordering it
// yields the smallest num first, and among equal num the largest id.)
struct node {
    int id;
    int num;
    bool operator< (const node& a) const {
        return num != a.num ? num > a.num : id < a.id;
    }
};

数组切段

1
2
3
4
for (int l = 1, r = 1; l <= n; l = r + 1, r = l) {
while (r < n && (a[l] == a[r + 1])) ++r;
printf("[%d,%d] ",l ,r);
}