This weekend, CTFZone 2023 was running, and after solving a few of the easier challenges, I ended up spending all of my time on the last crypto question, ‘Wise sage’, which I managed to blood ~6 hours before the end of the CTF.

import random as rnd
import logging
import re
import socketserver
from hashlib import blake2b
from Crypto.Cipher import AES
from gmpy2 import isqrt
from py_ecc.fields import optimized_bn128_FQ
from py_ecc.optimized_bn128.optimized_curve import (
    Optimized_Point3D,
    normalize,
    G1,
    multiply,
    curve_order,
    add,
    neg,
)
from flag import flag

hello_string = """Hello, adventurer! A wise sage told me that you can predict the future. Let's see if you can predict my private key. Shhhhhhhh! Send it securely
"""
help_string = """
Usage:
ecdh <(x,y)> - send the point to establish an ECDH tunnel
answer <hex> - send the answer to the question "What is my private key?" encrypted in hex from
help - this info
"""
prompt = ">"


def egcd_step(prev_row, current_row):
    (s0, t0, r0) = prev_row
    (s1, t1, r1) = current_row
    q_i = r0 // r1
    return (s0 - q_i * s1, t0 - q_i * t1, r0 - q_i * r1)


def find_decomposers(lmbda, modulus):
    mod_root = isqrt(modulus)

    egcd_trace = [(1, 0, modulus), (0, 1, lmbda)]
    while egcd_trace[-2][2] >= mod_root:
        egcd_trace.append(egcd_step(egcd_trace[-2], egcd_trace[-1]))

    (_, t_l, r_l) = egcd_trace[-3]
    (_, t_l_plus_1, r_l_plus_1) = egcd_trace[-2]
    (_, t_l_plus_2, r_l_plus_2) = egcd_trace[-1]
    (a_1, b_1) = (r_l_plus_1, -t_l_plus_1)
    if (r_l**2 + t_l**2) <= (r_l_plus_2**2 + t_l_plus_2**2):
        (a_2, b_2) = (r_l, -t_l)
    else:
        (a_2, b_2) = (r_l_plus_2, -t_l_plus_2)

    return (a_1, b_1, a_2, b_2)


lmbda = 4407920970296243842393367215006156084916469457145843978461
beta = 2203960485148121921418603742825762020974279258880205651966

(a_1, b_1, a_2, b_2) = find_decomposers(lmbda, curve_order)


def compute_balanced_representation(scalar, modulus):
    c_1 = (b_2 * scalar) // modulus
    c_2 = (-b_1 * scalar) // modulus
    k_1 = scalar - c_1 * a_1 - c_2 * a_2
    k_2 = -c_1 * b_1 - c_2 * b_2
    return (k_1, k_2)


def multiply_with_endomorphism(x: int, y: int, scalar: int):
    assert scalar >= 0 and scalar < curve_order
    point = (optimized_bn128_FQ(x), optimized_bn128_FQ(y), optimized_bn128_FQ.one())
    endo_point = (
        optimized_bn128_FQ(x) * optimized_bn128_FQ(beta),
        optimized_bn128_FQ(y),
        optimized_bn128_FQ.one(),
    )
    (k1, k2) = compute_balanced_representation(scalar, curve_order)
    print("K decomposed:", k1, k2)
    if k1 < 0:
        point = neg(point)
        k1 = -k1
    if k2 < 0:
        endo_point = neg(endo_point)
        k2 = -k2
    return normalize(add(multiply(point, k1), multiply(endo_point, k2)))


(HOST, PORT) = ("0.0.0.0", 1337)


class CheckHandler(socketserver.BaseRequestHandler):
    """
    The RequestHandler class for our server.

    It is instantiated once per connection to the server, and must
    override the handle() method to implement communication to the
    client.
    """

    def handle(self):
        # Generate the private key
        private_key = rnd.randint(2, curve_order - 1)
        print("Private key:", private_key)
        public_key = multiply_with_endomorphism(G1[0].n, G1[1].n, private_key)
        self.request.sendall((hello_string + help_string + prompt).encode())
        session_key = normalize(G1)
        while True:
            # self.request is the TCP socket connected to the client
            try:
                data = self.request.recv(2048).strip().decode()
                (command, arguments) = data.split(" ")
                if command == "help":
                    self.request.sendall((help_string + prompt).encode())
                elif command == "ecdh":
                    (p_x, p_y) = map(
                        int,
                        re.match(
                            r"\((\d+),(\d+)\)", "".join(arguments).replace(" ", "")
                        )
                        .group()[1:-1]
                        .split(","),
                    )
                    session_key = multiply_with_endomorphism(p_x, p_y, private_key)
                    print(session_key)
                    hash_check = blake2b((str(session_key) + "0").encode()).hexdigest()
                    self.request.sendall(
                        (
                            f"Public key: {str(public_key)}, session check: {hash_check}\n"
                            + prompt
                        ).encode()
                    )
                elif command == "answer":
                    fct = bytes.fromhex("".join(arguments))
                    if len(fct) < 32:
                        self.request.sendall(("IV + CT too short" + prompt).encode())
                        continue
                    key = blake2b((str(session_key) + "1").encode()).digest()[
                        : AES.key_size[-1]
                    ]
                    aes = AES.new(key, AES.MODE_CBC, fct[: AES.block_size])
                    pt = aes.decrypt(fct[AES.block_size :])
                    answer = int(pt.decode())
                    if answer == private_key:
                        self.request.sendall(f"Yay! Here is your flag: {flag}".encode())
                        return
                    else:
                        self.request.sendall("No, wrong answer".encode())
                        return

            except ValueError:
                logging.error("Conversion problem or bad data")
                self.request.sendall(("Malformed data\n" + prompt).encode())
                continue
            except ConnectionResetError:
                logging.error("Connection reset by client")
                return
            except UnicodeDecodeError:
                logging.error("Client sent weird data")
                return
            except AttributeError:
                logging.error("Malformed command")
                self.request.sendall(("Malformed command\n" + prompt).encode())
                continue


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s",
        datefmt="%m-%d %H:%M",
        filename="server.log",
        filemode="w",
    )
    server = socketserver.TCPServer((HOST, PORT), CheckHandler, bind_and_activate=False)
    server.allow_reuse_address = True
    server.server_bind()
    server.server_activate()
    logging.info(f"Started listenning on {HOST}:{PORT}")
    server.serve_forever()

The idea behind the server is really simple. A private value $s$ is computed, its associated ECDH public key w.r.t. the point $G = (1, 2)$ is output, and we then have a prompt with two commands.

The second command, ‘answer’ is simply the win function. Once $s$ is recovered, sending its encryption via this command gives us the flag.

The first commmand, ’ecdh’, is our oracle. This command takes in two coordinates, $x, y$, and then computes $sP$, where $P$ is the point on the curve BN128, and then prints us the hash of the result of the multiplication.

Well, this is not exactly correct. First of all, the server never checks that our point actually lies on the curve, and this suggests an ‘invalid curve’ type of approach (props to my teammate Eggroll for pointing it out). Moreover, multiplication is not computed via a simple double-and-add algorithm, it instead uses an endomorphism to speed up the computation.

Getting a bit more technical

On the curve BN128, the map $f((x, y)) = (\beta x, y)$ is an endomorphism, acting as multiplication by $\lambda$. This is to say that $f((x, y)) = \lambda \cdot (x, y)$, where $\cdot$ is representing multiplication on the curve by a scalar (I will now drop the dot for conciseness).

Now, how does multiplication on the server actually work? Basically, the scalar $s$ is rewritten as $k_0 + \lambda k_1$, for balanced values of $k_0, k_1$, and multiplication of a point $P$ is computed as $k_0 P + k_1f((P_x, P_y))$. This paper explains it.

The way we go about it appears to be: we fix a curve, for every small factor of the order of the curve we fix a point with that order, we query the oracle with that point, and we are now supposed to extract information from it. However, there are some details that we need to take care of.

  • Via some testing, one can notice that the number of unique orders of curves over $\mathbb{F}_q$ with $a = 0$ in their Weierstrass form is really limited, there are only 6 such unique orders. The original curve has prime order, so there is nothing we can learn from it, but the other curves have orders with some small prime factors, so the invalid curve approach might work.

  • When the curve is different from the original, $f$ does not behave exactly as defined above, but it seems to be the case that, assuming the curve to be cyclic, then taking a generator $G$ of the curve, taking some multiple of it $P$ such that it has the small prime order of our choice, then $f$ behaves like a morphism, except that the value of $\lambda$ (a.k.a the multiplier) changes w.r.t. the original one provided in the source code.

We can easily find the ’new’ values of $\lambda$ by solving one discrete log, finding $x$ such that $f(P) = xP$. This is quick because the order of $P$ is small, and $k_1P + k_2f(P) = (k_1 + \lambda k_2)P$, so by taking increasing multiples of $P$, we eventually learn the value of $k_1 + \lambda k_2 \bmod |P|$

Iterating over the various factors of the various orders gives us a bunch of bivariate modular equations in $k_1, k_2$, however solving those turns out not to be enough. The parameters of the challenge were fixed so that you need to use ALL of the information from the small factors, and currently we are avoiding what happens when the curve is not cyclic.

As a general approach, one can always choose $P$, and try to compute its associated value of $\lambda$. This can sometimes fail, meaning that $f(P)$ does not lie in the subgroup generated by P. But this means that we can iterate over all of the pairs $(i, j)$ (where $i$ and $j$ are taken modulo the orders of their respective points) and find one such that $iP + jf(P)$ matches our hash. Because the two points span disjoint subgroups, we actually learn $k_1 \bmod |P|$ and $k_2 \bmod |f(P)|$, with a search that is quadratic, but yields 2 independent modular equations.

After gathering all of the equations from all of the reasonably sized factors of the various orders, we use BKZ to solve them simultaneously, learning $k_0, k_1$. At this point it is just a matter of recovering $s$ from its representation, which is trivial (I ended up using LLL again because I totally forgot $s=k_0 + \lambda k_1$). My solution is implemented in Sage.