Oct 9, 2024

Constant Optimization and Cheesing through NAQ

One of the most difficult problems in NAQ 2024 is Balatro, with only 5 recorded solves. I attempted it during the contest and failed; afterwards I found that my solution differed quite a bit from the intended one (and also got TLE). I spent quite a bit of time just trying to speed up my solution, which I found interesting in itself.

These are the observations I made during my solve, which roughly reflected how I came up with the solution:

Therefore, we can brute-force over which set of multiply cards to use, then sort the resulting add-card contributions and greedily pick the largest ones, which gives a complexity of \(2^{11}\cdot O(n\log n)\).

// Version 1: brute force over which multiply cards to use, then sort all card
// contributions and greedily take the largest ones.
#include <bits/stdc++.h>

int N, K;
int tp[210000], val[210000];   // tp[i] == 1 iff card i was read with type 'm'
int act[210000];               // act[i] == 1 iff multiply card i is currently selected
std::vector<int> ind[210000];  // ind[v]: positions of 'm' cards with value v, increasing
                               // NOTE(review): assumes multiplier values < 210000
std::vector<int> ll;           // distinct 'm' values, sorted ascending
long long ans[210000];         // ans[k]: best prefix sum recorded for k chosen cards
long long res[210000];         // scratch: per-card contribution for the current selection

// Enumerates selections of multiply cards, deciding one distinct value per level.
//   level: index into ll of the value being decided.
//   tot:   number of multiply cards selected so far (kept <= K).
// At a leaf, every add card is scaled by the product of the selected multiply
// cards after it; contributions are sorted descending and prefix sums update ans.
void search(int level, int tot) {
  if (level >= static_cast<int>(ll.size())) {
    long long mul = 1;  // product of selected multipliers to the right of i
    for (int i = N - 1; i >= 0; --i) {
      if (tp[i] == 1) {
        res[i] = 0;
        if (act[i]) {
          mul *= val[i];
        }
      } else {
        res[i] = val[i] * mul;
      }
    }
    // res holds long long values: the comparator must take long long too,
    // otherwise each value is narrowed to int before comparing and any
    // contribution above INT_MAX is sorted into the wrong place.
    std::sort(res, res + N, std::greater<long long>());
    long long sum = 0;
    for (int i = 0; tot + i + 1 <= N; ++i) {
      sum += res[i];
      ans[tot + i + 1] = std::max(ans[tot + i + 1], sum);
    }
    return;
  }

  search(level + 1, tot);  // take no card with this value

  const std::vector<int>& cards = ind[ll[level]];
  const int count = static_cast<int>(cards.size());
  // Take the latest i+1 cards with this value: a later multiply card scales a
  // superset of the add cards an earlier one with the same value would scale.
  for (int i = 0; i < count && tot + i + 1 <= K; ++i) {
    act[cards[count - 1 - i]] = 1;
    search(level + 1, tot + i + 1);
  }
  // Undo exactly the selections made above before returning.
  for (int i = 0; i < count && tot + i + 1 <= K; ++i) {
    act[cards[count - 1 - i]] = 0;
  }
}

int main() {
  scanf("%d%d", &N, &K);
  for (int i = 0; i < N; ++i) {
    char c;
    scanf(" %c%d", &c, &val[i]);
    tp[i] = (c == 'm');
    if (c == 'm') {
      ind[val[i]].push_back(i);
      ll.push_back(val[i]);
    }
  }
  std::sort(ll.begin(), ll.end());
  ll.erase(std::unique(ll.begin(), ll.end()), ll.end());
  search(0, 0);
  for (int i = 1; i <= N; ++i) {
    printf("%lld\n", ans[i]);
  }
}

I thought it plausible that this program could pass, given the rather generous time limit of 12 seconds. Unfortunately, it turns out that the program takes 30 seconds on worst-case inputs and gets TLE.

Obviously, the bottleneck here is the sort, which contributes the \(O(n\log n)\) factor. I first tried playing around with different sort implementations, without success. Eventually I realized that the \(11\) multiply cards cut the sequence into \(12\) ranges. Within each range, the relative order of the add cards is fixed no matter which multiply cards we use, because every add card in a range is scaled by the same product. Therefore, we can pre-sort each range once and replace the per-subset sort with a merge driven by a heap of \(12\) elements. This gives a complexity of \(2^{11}\cdot\log 12\cdot O(n)\).

// Version 2: same search, but add cards are pre-sorted once inside each maximal
// run between multiply cards (every card in a run is scaled by the same product),
// and the per-leaf sort becomes a k-way merge of the runs driven by a heap.
#include <bits/stdc++.h>

int N, K;
int val[210000];               // card values; each add-card run is sorted descending in main
int col[210000];               // col[i]: index into `range` of the run containing add card i
std::vector<int> ind[210000];  // ind[v]: positions of 'm' cards with value v, increasing
                               // NOTE(review): assumes multiplier values < 210000
std::vector<int> ll;           // distinct 'm' values, sorted ascending
std::vector<std::pair<int, int>> range;  // maximal [first, second] runs of add cards
long long mul[210000];         // mul[j]: product of selected multipliers after run j
long long ans[210000];         // ans[k]: best prefix sum recorded for k chosen cards

// Enumerates selections of multiply cards, deciding one distinct value per level.
//   level: index into ll of the value being decided.
//   tot:   number of multiply cards selected so far (kept <= K).
void search(int level, int tot) {
  if (level >= static_cast<int>(ll.size())) {
    // One heap slot per run. The heap is sized from range.size() rather than a
    // fixed 20: duplicate multiplier values allow more multiply cards, and so
    // more runs, than the number of distinct values.
    static std::vector<std::pair<long long, int>> pq;
    pq.resize(range.size());
    int pqs = 0;
    for (int j = 0; j < static_cast<int>(range.size()); ++j) {
      pq[pqs++] = {mul[j] * val[range[j].first], range[j].first};
    }
    std::make_heap(pq.begin(), pq.begin() + pqs);
    long long sum = 0;
    int cnt = 0;
    while (pqs > 0) {
      const auto [v, i] = pq[0];  // largest remaining contribution and its card
      std::pop_heap(pq.begin(), pq.begin() + pqs);
      --pqs;
      sum += v;
      ++cnt;
      if (ans[tot + cnt] < sum) {
        ans[tot + cnt] = sum;
      }
      // Advance within the same run: the next card there is the next largest.
      if (i + 1 <= range[col[i]].second) {
        pq[pqs++] = {mul[col[i]] * val[i + 1], i + 1};
        std::push_heap(pq.begin(), pq.begin() + pqs);
      }
    }
    return;
  }

  search(level + 1, tot);  // take no card with this value

  const int factor = ll[level];
  const std::vector<int>& cards = ind[factor];
  const int count = static_cast<int>(cards.size());
  // Take the latest i+1 cards with this value; multiply card x scales every run
  // that starts before x.
  for (int i = 0; i < count && tot + i + 1 <= K; ++i) {
    const int x = cards[count - 1 - i];
    for (int j = 0; j < static_cast<int>(range.size()) && range[j].first < x; ++j) {
      mul[j] *= factor;
    }
    search(level + 1, tot + i + 1);
  }
  // Undo the scaling applied above (exact, since each factor was multiplied in).
  for (int i = 0; i < count && tot + i + 1 <= K; ++i) {
    const int x = cards[count - 1 - i];
    for (int j = 0; j < static_cast<int>(range.size()) && range[j].first < x; ++j) {
      mul[j] /= factor;
    }
  }
}

int main() {
  scanf("%d%d", &N, &K);
  int l = 0;  // first position of the current run of add cards
  for (int i = 0; i < N; ++i) {
    char c;
    scanf(" %c%d", &c, &val[i]);
    col[i] = static_cast<int>(range.size());
    if (c == 'm') {
      ind[val[i]].push_back(i);
      ll.push_back(val[i]);
      if (i - 1 >= l) {
        range.push_back(std::make_pair(l, i - 1));
      }
      l = i + 1;
    }
  }
  if (N - 1 >= l) {
    range.push_back(std::make_pair(l, N - 1));
  }
  for (const auto& [lo, hi] : range) {
    std::sort(val + lo, val + hi + 1, std::greater<int>());
  }
  std::fill(mul, mul + range.size(), 1);
  std::sort(ll.begin(), ll.end());
  ll.erase(std::unique(ll.begin(), ll.end()), ll.end());
  search(0, 0);
  // Report the best value found using at most i cards.
  long long aans = 0;
  for (int i = 1; i <= N; ++i) {
    aans = std::max(aans, ans[i]);
    printf("%lld\n", aans);
  }
}

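It runs fast enough on my machine but still gets TLE on the judge. Fortunately, a bit of bit magic helps here. Each heap entry's value and index are packed into a single 64-bit integer (value in the high bits, index in the low 20 bits), so every heap comparison is a single integer comparison. An `O3`/`unroll-loops` pragma is added on top: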

#pragma GCC optimize("O3,unroll-loops")
// Version 3: same search and k-way merge as the heap-of-pairs version, but each
// heap entry is packed into a single 64-bit integer (key << 20 | index) so the
// heap compares plain integers instead of pairs.
#include <bits/stdc++.h>

// Low bits of a packed heap entry hold the card index; 2^20 exceeds the
// 210000-entry arrays below, so every index fits.
constexpr int kIndexBits = 20;
constexpr long long kIndexMask = (1LL << kIndexBits) - 1;

int N, K;
int val[210000];               // card values; each add-card run is sorted descending in main
int col[210000];               // col[i]: index into `range` of the run containing add card i
std::vector<int> ind[210000];  // ind[v]: positions of 'm' cards with value v, increasing
                               // NOTE(review): assumes multiplier values < 210000
std::vector<int> ll;           // distinct 'm' values, sorted ascending
std::vector<std::pair<int, int>> range;  // maximal [first, second] runs of add cards
long long mul[210000];         // mul[j]: product of selected multipliers after run j
long long ans[210000];         // ans[k]: best prefix sum recorded for k chosen cards

// Packs (key, index) so that integer order equals (key, index) pair order.
// Only valid for 0 <= key < 2^(63 - kIndexBits) = 2^43; outside that range the
// signed shift overflows. NOTE(review): presumably the problem's bounds keep
// mul * val below 2^43 — confirm against the statement.
inline long long pack(long long key, int index) {
  return key << kIndexBits | index;
}

// Enumerates selections of multiply cards, deciding one distinct value per level.
//   level: index into ll of the value being decided.
//   tot:   number of multiply cards selected so far (kept <= K).
void search(int level, int tot) {
  if (level >= static_cast<int>(ll.size())) {
    // One heap slot per run, sized from range.size() rather than a fixed 20:
    // duplicate multiplier values allow more runs than distinct values.
    static std::vector<long long> pq;
    pq.resize(range.size());
    int pqs = 0;
    for (int j = 0; j < static_cast<int>(range.size()); ++j) {
      pq[pqs++] = pack(mul[j] * val[range[j].first], range[j].first);
    }
    std::make_heap(pq.begin(), pq.begin() + pqs);
    long long sum = 0;
    int cnt = 0;
    while (pqs > 0) {
      const long long v = pq[0] >> kIndexBits;
      const int i = static_cast<int>(pq[0] & kIndexMask);
      std::pop_heap(pq.begin(), pq.begin() + pqs);
      --pqs;
      sum += v;
      ++cnt;
      if (ans[tot + cnt] < sum) {
        ans[tot + cnt] = sum;
      }
      // Advance within the same run: the next card there is the next largest.
      if (i + 1 <= range[col[i]].second) {
        pq[pqs++] = pack(mul[col[i]] * val[i + 1], i + 1);
        std::push_heap(pq.begin(), pq.begin() + pqs);
      }
    }
    return;
  }

  search(level + 1, tot);  // take no card with this value

  const int factor = ll[level];
  const std::vector<int>& cards = ind[factor];
  const int limit = std::min(static_cast<int>(cards.size()), K - tot);
  // Take the latest i+1 cards with this value; multiply card x scales every run
  // that starts before x.
  for (int i = 0; i < limit; ++i) {
    const int x = cards[cards.size() - 1 - i];
    for (int j = 0; j < static_cast<int>(range.size()) && range[j].first < x; ++j) {
      mul[j] *= factor;
    }
    search(level + 1, tot + i + 1);
  }
  // Undo the scaling applied above (exact, since each factor was multiplied in).
  for (int i = 0; i < limit; ++i) {
    const int x = cards[cards.size() - 1 - i];
    for (int j = 0; j < static_cast<int>(range.size()) && range[j].first < x; ++j) {
      mul[j] /= factor;
    }
  }
}

int main() {
  scanf("%d%d", &N, &K);
  int l = 0;  // first position of the current run of add cards
  for (int i = 0; i < N; ++i) {
    char c;
    scanf(" %c%d", &c, &val[i]);
    col[i] = static_cast<int>(range.size());
    if (c == 'm') {
      ind[val[i]].push_back(i);
      ll.push_back(val[i]);
      if (i - 1 >= l) {
        range.push_back(std::make_pair(l, i - 1));
      }
      l = i + 1;
    }
  }
  if (N - 1 >= l) {
    range.push_back(std::make_pair(l, N - 1));
  }
  for (const auto& [lo, hi] : range) {
    std::sort(val + lo, val + hi + 1, std::greater<int>());
  }
  std::fill(mul, mul + range.size(), 1);
  std::sort(ll.begin(), ll.end());
  ll.erase(std::unique(ll.begin(), ll.end()), ll.end());
  search(0, 0);
  // Report the best value found using at most i cards.
  long long aans = 0;
  for (int i = 1; i <= N; ++i) {
    aans = std::max(aans, ans[i]);
    printf("%lld\n", aans);
  }
}

This code passes in 10.75 seconds on the judge.

At this point I am so done with my garbage solution and ready to appreciate the editorial:

...so we can solve it in \(O(n \log n)\) using a Li Chao tree, or \(O(n)\) using SMAWK...the overall running time is \(O(29\cdot n \log n)\) or \(O(29\cdot n)\)...

Okay, maybe my solution is not so garbage after all.