Constant Optimization on Binary Exponentiation and Matrix Multiplication

Dec 1, 2023

During a discussion of a certain CSES problem (Counting Tilings) in the Purdue CPU Club, we came across a debate on whether matrix multiplication can be used to solve this problem. While it is surely not the intended solution (and much more complicated than the intended one), the complexity is \( O(k^3\log m) \) for a \( k\times k \) transfer matrix with \( k = 2^n \le 1024 \) and \( m \le 1000 \), so the discussion quickly turned into figuring out whether some matrix multiplication template code is fast enough. I decided to make the point that one does not need template code to do matrix multiplication efficiently, so I went ahead and tried to write something that works without a template, which turned out to be an optimization hell…
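For a sense of scale, here is a back-of-the-envelope count. It assumes the bitmask transfer-matrix formulation used in the code below (\( k = 2^n \) column states, answer equal to entry \( (0,0) \) of the \( m \)-th power) and a textbook binary exponentiation with dense products:

\[
\text{answer} = \left(T^{m}\right)_{0,0}, \qquad k = 2^{n} \le 2^{10} = 1024, \qquad k^{3} \le 2^{30} \approx 1.07\times 10^{9} \text{ multiply-adds per product}
\]
\[
m = 1000 = 1111101000_2:\quad 10 \text{ squarings} + 6 \text{ multiplications into the result} = 16 \text{ products} \approx 16\cdot 2^{30} \approx 1.7\times 10^{10}
\]

That is on the order of \( 10^{10} \) modular multiply-adds if nothing is skipped, which is why the constant factor is the whole story here.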

First Version

The code itself is very easy to implement:

```cpp
#include <bits/stdc++.h>
#define __ __attribute__((optimize("-O3")))
#define _ \
  __ __inline__ \
  __attribute__((__gnu_inline__, __always_inline__, __artificial__))
const int SIZE = 1 << 20;
const int MOD = 1E9 + 7;
int N, M, N1, MASK;
int mat[SIZE], ans[SIZE];

// mata = mata * matb (mod MOD); matrices are N1 x N1, entry (i, j) at i << N | j
_ void mut(int mata[], int matb[]) {
  static int matc[SIZE];
  memset(matc, 0, SIZE * sizeof(int));
  for (int k = 0; k < N1; ++k)
    for (int i = 0; i < N1; ++i)
      for (int j = 0; j < N1; ++j)
        matc[i << N | j] = (matc[i << N | j] + 1ll * mata[i << N | k] * matb[k << N | j]) % MOD;
  memcpy(mata, matc, SIZE * sizeof(int));
}

__ int main() {
  scanf("%d%d", &N, &M);
  N1 = 1 << N;
  MASK = N1 - 1;
  // mat[mask1][mask2] = 1 iff a column whose cells in mask1 are already covered
  // can be completed by starting horizontal dominoes at mask2 and filling the
  // remaining cells with vertical dominoes (every free run must be even)
  for (int mask1 = 0; mask1 < N1; ++mask1) {
    for (int mask2 = 0; mask2 < N1; ++mask2) {
      if (!(mask1 & mask2)) {
        int rem = N1 + mask1 + mask2;  // bit N is a sentinel closing the last run
        int succ = 1;
        while (rem > 0) {
          int low = __builtin_ctz(rem);
          if (low % 2) { succ = 0; break; }
          rem >>= low + 1;
        }
        mat[mask1 << N | mask2] = succ;
      }
    }
  }
  // ans starts as the N1 x N1 identity
  for (int i = 0; i < N1; ++i) ans[i << N | i] = 1;
  // binary exponentiation: ans = mat^M
  while (M > 0) {
    if (M & 1) mut(ans, mat);
    mut(mat, mat);
    M >>= 1;
  }
  printf("%d\n", ans[0]);
}
```
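The only non-obvious part above is the test that fills mat. Here is a minimal sketch of the same check pulled out into standalone helpers; the names valid_transition and build_transfer are hypothetical, and the bit tricks (including the GCC/Clang builtin __builtin_ctz) are the ones from the code above:

```cpp
#include <cassert>
#include <vector>

// mask1: cells of the current column already covered by horizontal dominoes
//        sticking in from the previous column.
// mask2: cells where a new horizontal domino starts (sticking into the next column).
// The remaining cells must be covered by vertical dominoes, so every maximal
// run of free cells must have even length.
bool valid_transition(unsigned n, unsigned mask1, unsigned mask2) {
    assert(mask1 < (1u << n) && mask2 < (1u << n));
    if (mask1 & mask2) return false;           // a cell cannot be covered twice
    unsigned rem = (1u << n) | mask1 | mask2;   // sentinel bit n closes the last run
    while (rem > 0) {
        unsigned low = __builtin_ctz(rem);      // length of the current run of free cells
        if (low % 2) return false;              // odd run: vertical dominoes cannot fill it
        rem >>= low + 1;                        // skip the run and the covered cell after it
    }
    return true;
}

// Full k x k transfer matrix (k = 2^n), row-major with entry (a, b) at a << n | b.
std::vector<unsigned> build_transfer(unsigned n) {
    unsigned k = 1u << n;
    std::vector<unsigned> t(k * k, 0);
    for (unsigned a = 0; a < k; ++a)
        for (unsigned b = 0; b < k; ++b)
            t[a << n | b] = valid_transition(n, a, b);
    return t;
}
```

For example, with n = 2 an empty column can be closed by one vertical domino (mask2 = 0) or by two horizontal ones (mask2 = 3), but not by a single horizontal domino (mask2 = 1).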

This version runs fast enough on my machine, but unfortunately not on whatever machine CSES uses. Well, that's fine; we can always optimize some more.

Second Version: Isolate the Slow Part

In any optimization, the most important job is to find the bottleneck, i.e. where the code spends most of its time. Here it is easy to see that the bottleneck is the matrix multiplication, with its three nested loops of \( N1 \) iterations each:

```cpp
for (int k = 0; k < N1; ++k)
  for (int i = 0; i < N1; ++i)
    for (int j = 0; j < N1; ++j)
      matc[i << N | j] = (matc[i << N | j] + 1ll * mata[i << N | k] * matb[k << N | j]) % MOD;
```
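If you would rather measure than infer, timing individual calls is enough to see where the time goes. A minimal sketch, assuming a hypothetical helper time_ms built on std::chrono::steady_clock:

```cpp
#include <cassert>
#include <chrono>

// Run f once and return the elapsed wall-clock time in milliseconds.
// steady_clock is monotonic, so the result is never negative.
template <class F>
double time_ms(F&& f) {
    auto start = std::chrono::steady_clock::now();
    f();
    auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(stop - start).count();
}
```

For example, wrapping the calls as time_ms([&] { mut(mat, mat); }) and time_ms([&] { mut(ans, mat); }) and printing the results to stderr separates the cost of each kind of product.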

Our transfer matrix is sparse (an entry can only be nonzero when mask1 & mask2 == 0), so we only need to do work when both factors are nonzero. Let us also reorder the loops to i-k-j so that the row of matc being updated stays in the cache:

```cpp
for (int i = 0; i < N1; ++i)
  for (int k = 0; k < N1; ++k)
    if (mata[i << N | k])
      for (int j = 0; j < N1; ++j)
        if (matb[k << N | j])
          matc[i << N | j] = (matc[i << N | j] + 1ll * mata[i << N | k] * matb[k << N | j]) % MOD;
```
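To convince yourself that reordering and skipping zeros does not change the result, it helps to compare both orders on small inputs. A minimal sketch, assuming dense row-major std::vector matrices and the hypothetical names mul_naive and mul_sparse:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const std::uint32_t MOD = 1000000007u;

// Reference product in the first version's k-i-j order (dense, no skipping).
// a, b: dim x dim, row-major; returns a * b (mod MOD).
std::vector<std::uint32_t> mul_naive(const std::vector<std::uint32_t>& a,
                                     const std::vector<std::uint32_t>& b, unsigned dim) {
    std::vector<std::uint32_t> c(dim * dim, 0);
    for (unsigned k = 0; k < dim; ++k)
        for (unsigned i = 0; i < dim; ++i)
            for (unsigned j = 0; j < dim; ++j)
                c[i * dim + j] = (c[i * dim + j] + 1ull * a[i * dim + k] * b[k * dim + j]) % MOD;
    return c;
}

// i-k-j order with zero skipping: row i of c is reused across the whole k loop,
// and a zero a[i][k] skips the entire inner j loop.
std::vector<std::uint32_t> mul_sparse(const std::vector<std::uint32_t>& a,
                                      const std::vector<std::uint32_t>& b, unsigned dim) {
    std::vector<std::uint32_t> c(dim * dim, 0);
    for (unsigned i = 0; i < dim; ++i)
        for (unsigned k = 0; k < dim; ++k) {
            std::uint32_t va = a[i * dim + k];
            if (va == 0) continue;
            for (unsigned j = 0; j < dim; ++j)
                if (b[k * dim + j])
                    c[i * dim + j] = (c[i * dim + j] + 1ull * va * b[k * dim + j]) % MOD;
        }
    return c;
}
```

Addition modulo MOD is commutative, so the order in which the terms are accumulated does not affect the final entries; only the memory access pattern changes.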

The code is now twice as fast. Unfortunately, that is not enough to AC. We can also simplify the inner loop by computing the indices in the outer loops and by using unsigned int instead of int:

```cpp
for (uint i = 0, iN = 0; i < N1; ++i, iN += N1)
  for (uint k = 0, kN = 0, vala; k < N1; ++k, kN += N1)
    if (vala = mata[iN | k])
      for (uint j = 0, iNj = iN, kNj = kN; j < N1; ++j, ++iNj, ++kNj) {
        if (matb[kNj])
          matc[iNj] = (matc[iNj] + 1ll * vala * matb[kNj]) % MOD;
      }
```

The code is 1.5 times faster now. Unfortunately, it is still not enough to AC.
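The same hoisting can also be written with row pointers instead of running indices, which some people find easier to read. A sketch, assuming dense row-major uint arrays, a result buffer that does not alias the inputs, and the hypothetical name mul_rows:

```cpp
#include <cassert>
#include <cstring>

typedef unsigned int uint;
const uint MOD = 1000000007u;

// c = a * b (mod MOD) for dim x dim row-major matrices; c must not alias a or b.
void mul_rows(uint* c, const uint* a, const uint* b, uint dim) {
    std::memset(c, 0, sizeof(uint) * dim * dim);
    for (uint i = 0; i < dim; ++i) {
        uint* crow = c + i * dim;            // row i of the result, computed once
        const uint* arow = a + i * dim;      // row i of a
        for (uint k = 0; k < dim; ++k) {
            uint vala = arow[k];
            if (!vala) continue;             // skip zero entries of a
            const uint* brow = b + k * dim;  // row k of b
            for (uint j = 0; j < dim; ++j)
                if (brow[j])
                    crow[j] = (crow[j] + 1ull * vala * brow[j]) % MOD;
        }
    }
}
```

The inner loop now touches only crow[j] and brow[j], so no index arithmetic is left beyond the increment of j.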

Third Version: Vectorization

We know from our compilers course that manually unrolling a loop can make it faster. Fortunately, in this problem the matrix dimension \( 2^N \) is a power of two, so for sufficiently large \( N \) (at least \( N \ge 3 \)) it is a multiple of 8 and we do not need to worry about leftover iterations:

```cpp
for (uint i = 0, iN = 0; i < N1; ++i, iN += N1)
  for (uint k = 0, kN = 0, vala; k < N1; ++k, kN += N1)
    if (vala = mata[iN + k])
      for (uint j = 0, iNj = iN, kNj = kN; j < N1; j += 8, iNj += 8, kNj += 8) {
#define unroll(r) \
  if (matb[kNj + r]) \
    matc[iNj + r] = (matc[iNj + r] + 1ll * vala * matb[kNj + r]) % MOD;
        unroll(0); unroll(1); unroll(2); unroll(3);
        unroll(4); unroll(5); unroll(6); unroll(7);
      }
```
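Here is the same unrolled inner loop as a standalone row update, which makes the "multiple of 8" assumption explicit. The function name row_axpy_unrolled8 is hypothetical; crow and brow play the roles of matc + iN and matb + kN above:

```cpp
#include <cassert>

typedef unsigned int uint;
const uint MOD = 1000000007u;

// crow[j] = (crow[j] + vala * brow[j]) % MOD for j in [0, len), 8 entries per iteration.
// Assumes len is a multiple of 8, as the post does for N1 = 2^N with N >= 3.
void row_axpy_unrolled8(uint* crow, const uint* brow, uint vala, uint len) {
    assert(len % 8 == 0);
    for (uint j = 0; j < len; j += 8) {
#define UNROLL(r)    \
    if (brow[j + r]) \
        crow[j + r] = (crow[j + r] + 1ull * vala * brow[j + r]) % MOD
        UNROLL(0); UNROLL(1); UNROLL(2); UNROLL(3);
        UNROLL(4); UNROLL(5); UNROLL(6); UNROLL(7);
#undef UNROLL
    }
}
```

If the length were not a multiple of 8, a short scalar loop after the unrolled one would handle the leftover entries; the post instead sidesteps this by keeping a plain loop for small N.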

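The unrolled loop assumes the row length is a multiple of 8, so for small \( N \) we keep the previous loop (now mut_slow, with the unrolled one as mut_fast) so that the code does not WA on the smaller cases. Both routines now take the destination as a separate first argument: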

```cpp
#define fast_pow(method)     \
  while (M > 0) {            \
    if (M & 1)               \
      method(ans, ans, mat); \
    method(mat, mat, mat);   \
    M >>= 1;                 \
  }
if (N > 5) {
  fast_pow(mut_fast);
} else {
  fast_pow(mut_slow);
}
```

The code is 1.5 times faster now. Unfortunately, it still does not AC.
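The fast_pow macro could also be written as an ordinary function template, which keeps type checking while still letting the compiler inline the multiplication. A sketch, assuming a multiplication callback with the mut_fast / mut_slow shape mul(res, a, b) that tolerates res aliasing a or b (those two do, since they go through a scratch buffer):

```cpp
#include <cassert>

typedef unsigned int uint;

// Hypothetical template version of the fast_pow macro: ans = ans * mat^m.
template <class Mat, class Mul>
void fast_pow(Mat& ans, Mat& mat, uint m, Mul mul) {
    while (m > 0) {
        if (m & 1) mul(ans, ans, mat);  // fold the current bit into ans
        mul(mat, mat, mat);             // square for the next bit
        m >>= 1;
    }
}
```

With the arrays from the post, the dispatch would read if (N > 5) fast_pow(ans, mat, M, mut_fast); else fast_pow(ans, mat, M, mut_slow);. The same template also works for plain integers, which is convenient for testing.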

Fourth Version: Abuse Sparse Matrices

At this point I was almost out of ideas. Fortunately, I noticed that we only need ans[0][0] in the output. Since row 0 of ans * mat depends only on row 0 of ans, we can make ans even sparser by setting only ans[0][0] = 1 and leaving every other entry zero, instead of starting from the full identity.
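With only row 0 nonzero, ans is effectively a row vector, and multiplying it by mat is a vector-matrix product costing about \( N1^2 \) instead of \( N1^3 \) operations. A sketch of that step, plus a deliberately slow reference that applies it m times; the names row_times and entry00_of_power are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

const std::uint64_t MOD = 1000000007ull;

// out = v * t for a row vector v and a dim x dim row-major matrix t (mod MOD).
std::vector<std::uint64_t> row_times(const std::vector<std::uint64_t>& v,
                                     const std::vector<std::uint64_t>& t, std::size_t dim) {
    std::vector<std::uint64_t> out(dim, 0);
    for (std::size_t k = 0; k < dim; ++k)
        if (v[k])                                   // skip zero entries, like if (vala = ...)
            for (std::size_t j = 0; j < dim; ++j)
                out[j] = (out[j] + v[k] * t[k * dim + j]) % MOD;
    return out;
}

// Slow reference: (t^m)[0][0] computed as e_0 * t * t * ... * t (m factors).
std::uint64_t entry00_of_power(const std::vector<std::uint64_t>& t, std::size_t dim, unsigned m) {
    std::vector<std::uint64_t> v(dim, 0);
    v[0] = 1;                                       // e_0: ans[0][0] = 1, everything else 0
    for (unsigned s = 0; s < m; ++s) v = row_times(v, t, dim);
    return v[0];
}
```

For n = 2 the transfer matrix (rows 1 0 0 1 / 0 0 1 0 / 0 1 0 0 / 1 0 0 0) gives the Fibonacci counts of 2 x m tilings: 1, 2, 3, 5, ... and 89 for m = 10.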

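I then did some testing and found that method(ans, ans, mat) is very fast (< 10 ms) while method(mat, mat, mat) is very slow (> 50 ms). The reason is that only row 0 of ans is ever nonzero, so the if (vala = mata[...]) check skips every other row and the product degenerates into a vector-matrix product. I then attempted some black magic in the hope that it would work…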

```cpp
#define fast_pow(method)       \
  while (M > 8) {              \
    if (M & 1)                 \
      method(ans, ans, mat);   \
    method(mat, mat, mat);     \
    M >>= 1;                   \
  }                            \
  for (uint i = 0; i < M; ++i) \
    method(ans, ans, mat);
if (N > 5) {
  fast_pow(mut_fast);
} else {
  fast_pow(mut_slow);
}
```

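Instead of finishing the binary exponentiation, I brute-forced the top few bits of M: once M is at most 8, the loop stops squaring and simply multiplies ans by mat M more times. This trades up to four expensive squarings for at most eight cheap vector-matrix products. The code is 1.5 times faster.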
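To check that the trick still computes mat^M and to see what it saves, we can replay both loops on exponents alone. In this hypothetical bookkeeping sketch, ans and mat are represented only by the powers of the original matrix they hold, and the two kinds of products are counted:

```cpp
#include <cassert>

struct PowCost {
    unsigned long long ans_exp;  // ans ends up holding T^ans_exp
    unsigned squarings;          // method(mat, mat, mat): dense, expensive
    unsigned row_products;       // method(ans, ans, mat): cheap, one nonzero row
};

// threshold = 0 replays the plain loop (while (M > 0));
// threshold = 8 replays the fourth version (while (M > 8), then M leftover products).
PowCost simulate(unsigned m, unsigned threshold) {
    PowCost c{0, 0, 0};
    unsigned long long mat_exp = 1;  // mat starts as T^1
    while (m > threshold) {
        if (m & 1) { c.ans_exp += mat_exp; ++c.row_products; }
        mat_exp *= 2; ++c.squarings;
        m >>= 1;
    }
    for (unsigned i = 0; i < m; ++i) { c.ans_exp += mat_exp; ++c.row_products; }
    return c;
}
```

For M = 1000 the plain loop does 10 squarings and 6 row products, while the threshold-8 loop does 7 squarings and 10 row products. With the timings above (> 50 ms per squaring, < 10 ms per row product), giving up three squarings for four extra row products is a clear win.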

Fortunately, this time the code ACs with a time of 0.96 s:

```cpp
#include <cstdio>
#include <cstring>
#define __ __attribute__((optimize("-O3")))
#define _ \
  __ __inline__ \
  __attribute__((__gnu_inline__, __always_inline__, __artificial__))
typedef unsigned int uint;
const uint SIZE = 1 << 20;
const uint MOD = 1E9 + 7;
uint N, M, N1, MASK;
uint mat[SIZE], ans[SIZE];

// res = mata * matb, inner loop unrolled by 8 (requires N1 % 8 == 0)
_ void mut_fast(uint res[], uint *mata, uint *matb) {
  static uint matc[SIZE];
  memset(matc, 0, SIZE * sizeof(uint));
  for (uint i = 0, iN = 0; i < N1; ++i, iN += N1)
    for (uint k = 0, kN = 0, vala; k < N1; ++k, kN += N1)
      if (vala = mata[iN + k])
        for (uint j = 0, iNj = iN, kNj = kN; j < N1; j += 8, iNj += 8, kNj += 8) {
#define unroll(r) \
  if (matb[kNj + r]) \
    matc[iNj + r] = (matc[iNj + r] + 1ll * vala * matb[kNj + r]) % MOD;
          unroll(0); unroll(1); unroll(2); unroll(3);
          unroll(4); unroll(5); unroll(6); unroll(7);
        }
  memcpy(res, matc, SIZE * sizeof(uint));
}

// res = mata * matb, plain inner loop for small N
_ void mut_slow(uint res[], uint *mata, uint *matb) {
  static uint matc[SIZE];
  memset(matc, 0, SIZE * sizeof(uint));
  for (uint i = 0, iN = 0; i < N1; ++i, iN += N1)
    for (uint k = 0, kN = 0, vala; k < N1; ++k, kN += N1)
      if (vala = mata[iN | k])
        for (uint j = 0, iNj = iN, kNj = kN; j < N1; j += 1, iNj += 1, kNj += 1) {
          if (matb[kNj])
            matc[iNj] = (matc[iNj] + 1ll * vala * matb[kNj]) % MOD;
        }
  memcpy(res, matc, SIZE * sizeof(uint));
}

__ int main() {
  scanf("%u%u", &N, &M);
  N1 = 1 << N;
  MASK = N1 - 1;
  for (uint mask1 = 0; mask1 < N1; ++mask1) {
    for (uint mask2 = 0; mask2 < N1; ++mask2) {
      if (!(mask1 & mask2)) {
        uint rem = N1 + mask1 + mask2;
        uint succ = 1;
        while (rem > 0) {
          uint low = __builtin_ctz(rem);
          if (low % 2) { succ = 0; break; }
          rem >>= low + 1;
        }
        mat[mask1 << N | mask2] = succ;
      }
    }
  }
  ans[0] = 1;  // only row 0 of ans is ever needed
#define fast_pow(method)       \
  while (M > 8) {              \
    if (M & 1)                 \
      method(ans, ans, mat);   \
    method(mat, mat, mat);     \
    M >>= 1;                   \
  }                            \
  for (uint i = 0; i < M; ++i) \
    method(ans, ans, mat);
  if (N > 5) {
    fast_pow(mut_fast);
  } else {
    fast_pow(mut_slow);
  }
  printf("%u\n", ans[0]);
}
```

Aftermath

I was quick to share the code in the club, hoping to prove that one does not need a template to write three for loops. Nevertheless, the replies quickly became:

im just going to steal that template
so clean
just abuse the compiler

🤡.