LLL Algorithm — Yet Another Paper-Reading Problem
Dec 9, 2023
So recently there has been a paper-reading problem in UCup Stage 12, 2023 ICPC Asia Hefei Regional: SQRT Problem…
Problem Statement
Miss Burger has three positive integers \( n \), \( a \), and \( b \). She wants to find a positive integer solution \( x \) (\( 1 \leq x \leq n - 1 \)) that satisfies the following two conditions:
- \( x^2 \equiv a \ (\text{mod} \ n) \)
- \( \left\lfloor \sqrt[3]{x^2} \right\rfloor = b \)
Additionally, it is guaranteed that \( n \) is an odd number and \( \gcd(a, n) = 1 \). Here \( \gcd(x, y) \) denotes the greatest common divisor of \( x \) and \( y \). We also guarantee that there exists a unique solution.
Input
The first line contains a single integer \( n \) (\( 3 \leq n \leq 10^{100} - 1 \)).
The second line contains a single integer \( a \) (\( 1 \leq a \leq n - 1 \)).
The third line contains a single integer \( b \) (\( 1 \leq b \leq n - 1 \)).
Output
Output a single integer denoting the solution \( x \).
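For tiny inputs, the two conditions can be checked directly. The following is a minimal brute-force sketch; the names `icbrt` and `brute_force` are illustrative and not from the original solution. It is only feasible for small \( n \), but it works as a reference when testing a faster solver.

```python
def icbrt(v: int) -> int:
    """Return the largest m >= 0 with m**3 <= v (integer cube root), for v >= 0."""
    lo, hi = 0, 1
    while hi ** 3 <= v:  # grow an upper bound so that lo**3 <= v < hi**3
        hi *= 2
    while hi - lo > 1:  # binary search, keeping lo**3 <= v < hi**3
        mid = (lo + hi) // 2
        if mid ** 3 <= v:
            lo = mid
        else:
            hi = mid
    return lo


def brute_force(n: int, a: int, b: int):
    """Return the first x in [1, n - 1] with x^2 = a (mod n) and floor(cbrt(x^2)) = b, else None."""
    for x in range(1, n):
        if x * x % n == a and icbrt(x * x) == b:
            return x
    return None


print(brute_force(15, 4, 4))  # 8
```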
Solution
First, read Lecture 3 of this wonderful set of lecture notes by Prof. Chris Peikert (University of Michigan). It explains a simple version of Coppersmith's method. Lecture 4 covers the full version, which is highly interesting but not needed for this problem.
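We observe that \( \left\lfloor \sqrt[3]{x^2} \right\rfloor = b \) is equivalent to \( b^3 \le x^2 < (b+1)^3 \), so it confines \( x \) to a small range \( [l, r] \) with \( b^\frac{3}{2} \le l \le r < (b+1)^\frac{3}{2} \) and \( r \le n - 1 \). How small is this range? If \( b + 1 \le n^\frac{2}{3} \), then \( r - l < (b+1)^\frac{3}{2} - b^\frac{3}{2} \). This gap is increasing in \( b \), so it is at most \( n - (n^\frac{2}{3} - 1)^{\frac{3}{2}} \). Otherwise \( l \ge b^\frac{3}{2} > (n^\frac{2}{3} - 1)^{\frac{3}{2}} \) and \( r < n \), which gives the same bound. Observe that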
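\[ r - l \le n - \left(n^\frac{2}{3} - 1\right)^{\frac{3}{2}} = n - n\left(1 - n^{-\frac{2}{3}}\right)^{\frac{3}{2}} < n - n\left(1 - \frac{3}{2}n^{-\frac{2}{3}}\right) = \frac{3}{2}n^\frac{1}{3}, \]
where the strict inequality uses \( (1 - t)^{\frac{3}{2}} > 1 - \frac{3}{2}t \) for \( 0 < t \le 1 \). The left side is convex in \( t \), and the right side is its tangent line at \( t = 0 \).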
Therefore, we can write \( x = l + k \) with \( 0 \le k \le r - l < \frac{3}{2}n^\frac{1}{3} \). Substituting this into \( x^2 \equiv a \ (\text{mod} \ n) \) turns the original condition into \( k^2 + 2lk + l^2 - a \equiv 0 \ (\text{mod} \ n) \), which is a monic quadratic congruence with a relatively small root \( k \).
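The width bound can also be sanity-checked with exact integer arithmetic. The sketch below uses hypothetical helpers (`icbrt`, `candidate_range`, `width_bound_holds`). It computes \( [l, r] \) in closed form: \( l = \lceil b^{3/2} \rceil \), and \( r \) is the largest \( x \le n - 1 \) with \( x^2 < (b+1)^3 \). It then tests \( r - l < \frac{3}{2}n^\frac{1}{3} \) in the equivalent integer form \( 8(r - l)^3 < 27n \).

```python
from math import isqrt


def icbrt(v: int) -> int:
    """Largest m >= 0 with m**3 <= v."""
    lo, hi = 0, 1
    while hi ** 3 <= v:
        hi *= 2
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** 3 <= v:
            lo = mid
        else:
            hi = mid
    return lo


def candidate_range(n: int, b: int):
    """All x in [1, n - 1] with floor(cbrt(x^2)) == b form the interval [l, r] (b >= 1)."""
    l = isqrt(b ** 3 - 1) + 1  # ceil(sqrt(b^3)): smallest x with x^2 >= b^3
    r = isqrt((b + 1) ** 3 - 1)  # largest x with x^2 < (b + 1)^3
    return l, min(r, n - 1)


def width_bound_holds(n: int, b: int) -> bool:
    """Exact check of r - l < (3/2) n^(1/3), i.e. 8 (r - l)^3 < 27 n."""
    l, r = candidate_range(n, b)
    return 8 * (r - l) ** 3 < 27 * n


for n in (101, 10**6 + 3, 10**99 + 289):
    top = icbrt((n - 1) ** 2)  # largest b that can occur when x <= n - 1
    print(all(width_bound_holds(n, b) for b in range(top - 3, top + 1)))  # True
```

The closed form for \( l \) is an alternative to finding the left end by binary search. Both give the smallest \( x \) with \( \lfloor \sqrt[3]{x^2} \rfloor \ge b \).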
We now turn our attention to the simple Coppersmith method. The \( 3 \times 3 \) matrix in Lecture 3 recovers the root of this equation whenever \( 0 \le k < d = \frac{1}{6}n^\frac{1}{3} \). Our bound only guarantees \( k < \frac{3}{2}n^\frac{1}{3} \), which is a factor of \( 9 \) larger. Nevertheless, we can partition the range \( [l, r] \) into at most 9 intervals \( \{ [l, l + d), [l + d, l + 2d), \cdots, [l + 8d, l + 9d) \} \). We then run LLL once per interval, shifting \( l \) to the interval's left end each time, so LLL runs at most 9 times before we find the solution.
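Where does the constant \( \frac{1}{6} \) come from? One way to see it is sketched below. It assumes the standard LLL guarantee \( \|b_1\|_2 \le 2^{(m-1)/4} \det(\Lambda)^{1/m} \) for \( \delta = \frac{3}{4} \) in dimension \( m = 3 \); the lecture may present its constants differently.

Write \( f(k) = k^2 + 2lk + (l^2 - a) \). The lattice is spanned by the coefficient vectors of \( f(dx) \), \( n \), and \( n \cdot dx \). Every lattice vector therefore has the form \( (c_0, c_1 d, c_2 d^2) \) for a polynomial \( h(x) = c_0 + c_1 x + c_2 x^2 \) with \( h(k) \equiv 0 \ (\text{mod} \ n) \). Let \( h_1 \) be the polynomial of the first reduced vector \( b_1 \). Then:

\[ \Lambda = \begin{pmatrix} l^2 - a & 2ld & d^2 \\ n & 0 & 0 \\ 0 & nd & 0 \end{pmatrix}, \qquad \det \Lambda = n^2 d^3, \]
\[ \|b_1\|_1 \le \sqrt{3}\,\|b_1\|_2 \le \sqrt{3}\cdot 2^{\frac{1}{2}} \left(n^2 d^3\right)^{\frac{1}{3}} = \sqrt{6}\, n^\frac{2}{3} d \le \frac{\sqrt{6}}{6}\, n < n \quad \text{for } d \le \tfrac{1}{6}n^\frac{1}{3}. \]

For \( 0 \le k < d \), this gives \( |h_1(k)| \le \sum_j |c_j| d^j = \|b_1\|_1 < n \). Combined with \( h_1(k) \equiv 0 \ (\text{mod} \ n) \), it forces \( h_1(k) = 0 \) over the integers. So \( k \) is an integer root of an ordinary quadratic, and the constant \( \frac{1}{6} \) leaves some slack in this estimate.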
I used the LLL code from here with a little modification to implement the solution:
```python
import sys
from fractions import Fraction
from math import isqrt
from typing import List, Sequence


class Vector(list):
    """A vector of exact rationals supporting the operations LLL needs."""

    def __init__(self, x):
        super().__init__(map(Fraction, x))

    def sdot(self) -> Fraction:
        return self.dot(self)

    def dot(self, rhs: "Vector") -> Fraction:
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return sum(map(lambda x: x[0] * x[1], zip(self, rhs)))

    def proj_coff(self, rhs: "Vector") -> Fraction:
        # Coefficient of the projection of rhs onto self.
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return self.dot(rhs) / self.sdot()

    def proj(self, rhs: "Vector") -> "Vector":
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return self.proj_coff(rhs) * self

    def __sub__(self, rhs: "Vector") -> "Vector":
        rhs = Vector(rhs)
        assert len(self) == len(rhs)
        return Vector(x - y for x, y in zip(self, rhs))

    def __mul__(self, rhs: Fraction) -> "Vector":
        return Vector(x * rhs for x in self)

    def __rmul__(self, lhs: Fraction) -> "Vector":
        return Vector(x * lhs for x in self)

    def __repr__(self) -> str:
        return "[{}]".format(", ".join(str(x) for x in self))


def gramschmidt(v: Sequence[Vector]) -> Sequence[Vector]:
    # Classical Gram-Schmidt orthogonalization (zero vectors are dropped).
    u: List[Vector] = []
    for vi in v:
        ui = Vector(vi)
        for uj in u:
            ui = ui - uj.proj(vi)
        if any(ui):
            u.append(ui)
    return u


def reduction(
    basis: Sequence[Sequence[int]], delta: Fraction = Fraction(3, 4)
) -> Sequence[Sequence[int]]:
    # Textbook LLL with exact arithmetic; Gram-Schmidt is recomputed after each change.
    n = len(basis)
    basis = list(map(Vector, basis))
    ortho = gramschmidt(basis)

    def mu(i: int, j: int) -> Fraction:
        return ortho[j].proj_coff(basis[i])

    k = 1
    while k < n:
        # Size reduction: make |mu(k, j)| <= 1/2 for all j < k.
        for j in range(k - 1, -1, -1):
            mu_kj = mu(k, j)
            if abs(mu_kj) > Fraction(1, 2):
                basis[k] = basis[k] - basis[j] * round(mu_kj)
                ortho = gramschmidt(basis)
        # Lovasz condition: advance, or swap and step back.
        if ortho[k].sdot() >= (delta - mu(k, k - 1) ** 2) * ortho[k - 1].sdot():
            k += 1
        else:
            basis[k], basis[k - 1] = basis[k - 1], basis[k]
            ortho = gramschmidt(basis)
            k = max(k - 1, 1)
    return [list(map(int, b)) for b in basis]


def icube(x):
    # Integer cube root: largest m with m^3 <= x.
    l, r, ans = 0, x, 0
    while l <= r:
        m = (l + r) // 2
        if m * m * m <= x:
            l, ans = m + 1, m
        else:
            r = m - 1
    return ans


input = sys.stdin.readline
N = int(input())
A = int(input())
B = int(input())


def find_left():
    # Smallest x with floor(cbrt(x^2)) >= B, i.e. the left end l of the range.
    l, r, ans = 0, N, 0
    while l <= r:
        m = (l + r) // 2
        if icube(m * m) < B:
            l = m + 1
        else:
            r, ans = m - 1, m
    return ans


L = find_left()


def is_answer(x):
    return 1 <= x <= N - 1 and x * x % N == A and icube(x * x) == B


def coppersmith(poly: List, mod: int, d: int):
    # Rows: coefficients of poly(d*x), mod, and mod*(d*x); reduce, then undo the d^i scaling.
    n = len(poly)
    pd = [d**i for i in range(n)]
    mat = [[poly[i] * pd[i] for i in range(n)]] + [
        [pd[i] * mod if i == j else 0 for i in range(n)] for j in range(n - 1)
    ]
    mat = reduction(mat)
    return [mat[0][i] // pd[i] for i in range(n)]


# Window length floor(N^(1/3) / 6), computed exactly (no floating point).
d = icube(N) // 6
if d < 1:
    # Small N (below 216): the lattice step needs d >= 1, so brute-force instead.
    for x in range(1, N):
        if is_answer(x):
            sys.stdout.write(f"{x}\n")
            sys.exit(0)
else:
    while True:
        # f(k) = (L + k)^2 - A = k^2 + 2Lk + (L^2 - A), with root k in [0, d).
        poly = [L * L - A, 2 * L, 1]
        reduced = coppersmith(poly=poly, mod=N, d=d)
        # The short vector gives h(k) = a k^2 + b k + c = 0 over the integers.
        a, b, c = reduced[2], reduced[1], reduced[0]
        ans = []
        if a != 0:
            delta = b * b - 4 * a * c
            if delta >= 0:
                delta = isqrt(delta)
                ans = [(-b + delta) // (2 * a), (-b - delta) // (2 * a)]
        elif b != 0:
            ans = [-c // b]
        for x in ans:
            if is_answer(L + x):
                sys.stdout.write(f"{L + x}\n")
                sys.exit(0)
        L = L + d  # move on to the next window [L + d, L + 2d)
```
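The root-finding step in the loop above rounds with floor division and relies on `is_answer` to reject non-roots. A slightly stricter alternative is sketched below; the helper name `integer_roots` is ours, not part of the original code. It returns only the exact integer roots of \( ax^2 + bx + c \), including the linear case \( a = 0 \).

```python
from math import isqrt


def integer_roots(a: int, b: int, c: int) -> list:
    """Return the sorted distinct integer roots of a*x^2 + b*x + c = 0."""
    if a == 0:
        if b == 0:
            return []  # constant polynomial: no roots (or every x, which we ignore)
        return [-c // b] if c % b == 0 else []  # exact only when b divides c
    disc = b * b - 4 * a * c
    if disc < 0:
        return []  # no real roots
    s = isqrt(disc)
    if s * s != disc:
        return []  # irrational roots
    roots = set()
    for num in (-b + s, -b - s):
        if num % (2 * a) == 0:  # keep only roots that are integers
            roots.add(num // (2 * a))
    return sorted(roots)


print(integer_roots(1, -5, 6))  # [2, 3]
```

In the loop, `ans = integer_roots(a, b, c)` could replace the `if a != 0 ... elif b != 0 ...` branch. The `is_answer` check is still needed to confirm the candidate.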
Some Optimization
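In fact, the polynomial \( h_1(x) \) obtained from the shortest vector \( b_1 \) may not solve the problem directly, because the root \( k \) may be as large as \( 9d \) instead of below \( d \). Even so, \( b_1 \) is still short enough to pin down \( h_1(k) \). Write \( h_1(x) = \sum_j c_j x^j \), so that \( b_1 = (c_0, c_1 d, c_2 d^2) \). Using \( 0 \le k \le 9d \), we get
\[ |h_1(k)| \le \sum_{j=0}^{2} |c_j| k^j \le 81 \sum_{j=0}^{2} |c_j| d^j < 81n, \]
where the last step is the same bound \( \|b_1\|_1 < n \) that makes the simple Coppersmith method work. Since \( h_1(k) \equiv 0 \ (\text{mod} \ n) \), we have \( h_1(k) = in \) for some integer \( -81 \le i \le 81 \). So a single LLL call suffices: we try every equation \( h_1(x) - in = 0 \) in this range of \( i \) and check each integer root.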
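The enumeration over multiples of \( n \) can be sketched as follows. This is an illustrative helper; its name and the `bound`/`k_max` parameters are our own. It tries every \( i \) with \( |i| \le \) `bound`, and it repeats a small exact integer-root routine so that it runs on its own.

```python
from math import isqrt


def integer_roots(a: int, b: int, c: int) -> list:
    """Sorted distinct integer roots of a*x^2 + b*x + c = 0 (empty if none, or if all x)."""
    if a == 0:
        return [-c // b] if b != 0 and c % b == 0 else []
    disc = b * b - 4 * a * c
    if disc < 0 or isqrt(disc) ** 2 != disc:
        return []
    s = isqrt(disc)
    return sorted({(-b + t) // (2 * a) for t in (s, -s) if (-b + t) % (2 * a) == 0})


def roots_from_multiples(h, n: int, bound: int, k_max: int) -> list:
    """Hypothetical helper: all integers k in [0, k_max] with h(k) = i*n for some |i| <= bound.

    h = [c0, c1, c2] holds the coefficients of c0 + c1*x + c2*x^2.
    """
    c0, c1, c2 = h
    found = set()
    for i in range(-bound, bound + 1):
        # h(x) - i*n = 0 is an ordinary integer quadratic.
        for k in integer_roots(c2, c1, c0 - i * n):
            if 0 <= k <= k_max:
                found.add(k)
    return sorted(found)


# Toy example: h(x) = 2x^2 + 5x + 69 and n = 101, where h(7) = 202 = 2n.
print(roots_from_multiples([69, 5, 2], 101, 81, 20))  # [7]
```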
The only problem is that \( h_1 \) may degenerate into a trivial polynomial. For example, \( h_1(x) = n \) belongs to the lattice but says nothing about \( k \). In that case we use the second row \( b_2 \) of the reduced basis instead, which gives a non-trivial polynomial \( h_2(x) \). This works because when \( b_1 = (\pm n, 0, 0) \), removing \( b_1 \) and the first column from the reduced matrix \( A \) leaves a matrix \( A' \) that still satisfies both LLL conditions. In \( A' \), \( b_2 \) (without its first entry) now plays the role of the short vector. Therefore, we can likewise try every equation \( h_2(x) - in = 0 \) with \( -81 \le i \le 81 \) and check each integer root.
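```python
# Replaces coppersmith() and the main part of the program above;
# the LLL code, icube, find_left and is_answer are unchanged.
def coppersmith(poly: List, mod: int, d: int):
    n = len(poly)
    pd = [d**i for i in range(n)]
    mat = [[poly[i] * pd[i] for i in range(n)]] + [
        [pd[i] * mod if i == j else 0 for i in range(n)] for j in range(n - 1)
    ]
    mat = reduction(mat)
    # If b_1 is the trivial constant polynomial (+-mod), fall back to b_2.
    if mat[0][1] == 0 and mat[0][2] == 0:
        return [mat[1][i] // pd[i] for i in range(n)]
    return [mat[0][i] // pd[i] for i in range(n)]


d = icube(N) // 6  # exact floor(N^(1/3) / 6)
if d < 1:
    # Small N: the lattice step needs d >= 1, so brute-force instead.
    for x in range(1, N):
        if is_answer(x):
            sys.stdout.write(f"{x}\n")
            sys.exit(0)
else:
    # A single LLL call covers the whole range [L, L + 9d].
    poly = [L * L - A, 2 * L, 1]
    reduced = coppersmith(poly=poly, mod=N, d=d)
    # Solve h(x) = i * N for i = 81, 80, ...: start from c - 81N and add N each round.
    a, b, c = reduced[2], reduced[1], reduced[0] - 81 * N
    while True:
        ans = []
        if a != 0:
            delta = b * b - 4 * a * c
            if delta >= 0:
                delta = isqrt(delta)
                ans = [(-b + delta) // (2 * a), (-b - delta) // (2 * a)]
        elif b != 0:
            ans = [-c // b]
        for x in ans:
            if is_answer(L + x):
                sys.stdout.write(f"{L + x}\n")
                sys.exit(0)
        c = c + N
```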
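To exercise either version locally, one can generate random instances. The sketch below is a hypothetical test generator, not part of the original solution. It picks a random odd \( n \) with a given number of digits (at least 2) and a secret \( x \) coprime to \( n \), then derives \( a \) and \( b \). It does not enforce the uniqueness guarantee from the statement, so a solver may legitimately print a different valid \( x \). Compare answers with a checker like `is_answer` rather than by equality.

```python
import random
from math import gcd


def icbrt(v: int) -> int:
    """Largest m >= 0 with m**3 <= v."""
    lo, hi = 0, 1
    while hi ** 3 <= v:
        hi *= 2
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** 3 <= v:
            lo = mid
        else:
            hi = mid
    return lo


def make_instance(digits: int = 100, seed=None):
    """Return (n, a, b, x): n odd with `digits` digits (digits >= 2), gcd(a, n) = 1,
    x^2 = a (mod n) and floor(cbrt(x^2)) = b. Uniqueness of x is NOT enforced."""
    rng = random.Random(seed)
    n = rng.randrange(10 ** (digits - 1), 10 ** digits) | 1  # setting the low bit makes n odd
    while True:
        x = rng.randrange(1, n)
        if gcd(x, n) == 1:  # gcd(x, n) = 1 implies gcd(x^2 mod n, n) = 1
            break
    return n, x * x % n, icbrt(x * x), x


n, a, b, x = make_instance(seed=2023)
print(n, a, b, sep="\n")  # the three input lines
```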