Created
May 15, 2015 06:47
-
-
Save heikoheiko/a7b5d173e5bfc1ab9db0 to your computer and use it in GitHub Desktop.
ecdsa recover in cython with gmp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
cython + gmp implementation of ecdsa recover | |
based on: | |
https://github.com/vbuterin/pybitcointools | |
""" | |
# Elliptic curve parameters (secp256k1) | |
P = 2**256 - 2**32 - 977 | |
N = 115792089237316195423570985008687907852837564279074904382605163141518161494337 | |
cdef int A = 0 | |
cdef int B = 7 | |
Gx = 55066263022277343669578718895168534326250603453777594175500187360389116729240 | |
Gy = 32670510020758816978083085130507043184471273380659243275938904335757337482424 | |
G = (Gx, Gy) | |
# ##################################### | |
cdef extern from "gmp.h": | |
ctypedef struct mpz_t: | |
pass | |
cdef void mpz_init(mpz_t) | |
cdef void mpz_add(mpz_t, mpz_t, mpz_t) | |
cdef void mpz_tdiv_q(mpz_t, mpz_t, mpz_t) | |
cdef void mpz_add_ui(mpz_t, mpz_t, unsigned long int) | |
cdef void mpz_sub_ui(mpz_t, mpz_t, unsigned long int) | |
cdef void mpz_mul_ui(mpz_t, mpz_t, unsigned long int) | |
cdef void mpz_mul_si(mpz_t, mpz_t, long int) | |
cdef void mpz_addmul_ui(mpz_t, mpz_t, unsigned long int) | |
cdef void mpz_submul_ui(mpz_t, mpz_t, unsigned long int) | |
cdef unsigned long int mpz_get_ui(mpz_t) | |
cdef void mpz_set(mpz_t, mpz_t) | |
cdef void mpz_clear(mpz_t) | |
cdef void mpz_set_ui(mpz_t, unsigned long int) | |
cdef void mpz_set_si(mpz_t, long int) | |
cdef void mpz_init_set_ui(mpz_t, unsigned long int) | |
cdef void mpz_init_set(mpz_t, mpz_t,) | |
cdef int mpz_cmp(mpz_t, mpz_t) | |
cdef void mpz_mul_2exp(mpz_t, mpz_t, unsigned long int) | |
cdef unsigned long int mpz_mod_ui(mpz_t, mpz_t, unsigned long int) | |
cdef void mpz_divexact_ui(mpz_t, mpz_t, unsigned long int) | |
cdef unsigned long int mpz_fdiv_ui(mpz_t, unsigned long int) | |
cdef void mpz_fdiv_q(mpz_t q, const mpz_t n, const mpz_t d) | |
cdef void mpz_mul(mpz_t, mpz_t, mpz_t) | |
cdef void mpz_mod(mpz_t, mpz_t, mpz_t) | |
cdef void mpz_powm(mpz_t, mpz_t, mpz_t, mpz_t) | |
cdef void mpz_pow_ui(mpz_t, mpz_t, unsigned long int) | |
cdef void mpz_sub(mpz_t, mpz_t, mpz_t) | |
ui32 = 2**32 | |
ui32M1 = ui32 - 1 | |
cdef mpz_t m_ui32 | |
cdef mpz_t m_ui32M1 | |
mpz_init_set_ui(m_ui32, 1) | |
mpz_mul_2exp(m_ui32, m_ui32, 32) # Set rop to op1 times 2 raised to op2. | |
mpz_sub_ui(m_ui32M1, m_ui32, 1) | |
cdef mpz_t tmp_z | |
mpz_init(tmp_z) | |
cdef mpz_t mNull | |
mpz_init_set_ui(mNull, 0) | |
cdef class Mpz: | |
cdef mpz_t z | |
def __init__(self, l=0): | |
self.from_pylong(l) | |
def __cinit__(self): | |
mpz_init(self.z) | |
def __dealloc__(self): | |
mpz_clear(self.z) | |
def from_pylong(self, l): | |
assert isinstance(l, (long, int)) | |
l = long(l) | |
cdef unsigned long i = 0 | |
cdef unsigned long r | |
mpz_init_set_ui(self.z, 0) | |
while abs(l) > ui32M1: | |
r = l % ui32 | |
mpz_set_ui(tmp_z, r) | |
mpz_mul_2exp(tmp_z, tmp_z, 32 * i) | |
mpz_add(self.z, self.z, tmp_z) | |
l //= ui32 | |
i += 1 | |
mpz_set_si(tmp_z, l) | |
mpz_mul_2exp(tmp_z, tmp_z, 32 * i) | |
mpz_add(self.z, self.z, tmp_z) | |
def as_pylong(self): | |
cdef unsigned long int r | |
cdef unsigned long int d = ui32 | |
l = 0 | |
i = 0 | |
mpz_set(tmp_z, self.z) | |
is_signed = False | |
if mpz_cmp(tmp_z, mNull) < 0: | |
is_signed = True | |
mpz_mul_si(tmp_z, tmp_z, -1) | |
while mpz_cmp(tmp_z, m_ui32M1) > 0: | |
r = mpz_fdiv_ui(tmp_z, d) | |
l += r * ui32 ** i | |
mpz_divexact_ui(tmp_z, tmp_z, d) | |
i += 1 | |
r = mpz_get_ui(tmp_z) | |
l += r * ui32**i | |
if is_signed: | |
l *= -1 | |
return l | |
cdef mpz_to_long(mpz_t m): | |
mm = Mpz() | |
mpz_set(mm.z, m) | |
return mm.as_pylong() | |
cdef void set_mpz_from_long(mpz_t m, l): | |
mm = Mpz(l) | |
mpz_set(m, mm.z) | |
############################################ | |
cdef mpz_t mOne | |
mpz_init_set_ui(mOne, 1) | |
cdef mpz_t mTwo | |
mpz_init_set_ui(mTwo, 2) | |
cdef mpz_t mysq | |
mpz_init(mysq) | |
cdef mpz_t mS | |
mpz_init(mS) | |
cdef mpz_t mM | |
mpz_init(mM) | |
cdef mpz_t mTmp | |
mpz_init(mTmp) | |
cdef mpz_t mTmp2 | |
mpz_init(mTmp2) | |
cdef mpz_t mP | |
mpz_init(mP) | |
set_mpz_from_long(mP, P) | |
cdef mpz_t mN | |
mpz_init(mN) | |
set_mpz_from_long(mN, N) | |
cdef mpz_t mA | |
mpz_init(mA) | |
set_mpz_from_long(mA, A) | |
cdef int zero = 0 | |
cdef mpz_t mU1 | |
mpz_init(mU1) | |
cdef mpz_t mU2 | |
mpz_init(mU2) | |
cdef mpz_t mS1 | |
mpz_init(mS1) | |
cdef mpz_t mS2 | |
mpz_init(mS2) | |
cdef mpz_t mH | |
mpz_init(mH) | |
cdef mpz_t mH2 | |
mpz_init(mH2) | |
cdef mpz_t mH3 | |
mpz_init(mH3) | |
cdef mpz_t mU1H2 | |
mpz_init(mU1H2) | |
cdef mpz_t mR | |
mpz_init(mR) | |
cdef inv(a, n): | |
if a == 0: | |
return 0 | |
lm = 1 | |
hm = 0 | |
low = a % n | |
high = n | |
while low > 1: | |
r = high // low | |
nm = hm - lm * r | |
new = high - low * r | |
hm = lm | |
lm = nm | |
high = low | |
low = new | |
return lm % n | |
cdef class Jacobian: | |
cdef mpz_t x | |
cdef mpz_t y | |
cdef mpz_t z | |
def __cinit__(self): | |
mpz_init(self.x) | |
mpz_init(self.y) | |
mpz_init(self.z) | |
def __dealloc__(self): | |
mpz_clear(self.x) | |
mpz_clear(self.y) | |
mpz_clear(self.z) | |
cdef Jacobian copy(self): | |
j = Jacobian() | |
mpz_set(j.x, self.x) | |
mpz_set(j.y, self.y) | |
mpz_set(j.z, self.z) | |
return j | |
def equals(self, Jacobian other): | |
return (mpz_cmp(self.x, other.x) == zero) and (mpz_cmp(self.y, other.y) == zero) \ | |
and (mpz_cmp(self.z, other.z) == zero) | |
def from_point(self, xyz): | |
set_mpz_from_long(self.x, xyz[0]) | |
set_mpz_from_long(self.y, xyz[1]) | |
mpz_set(self.z, mOne) | |
def as_point(self): | |
z = inv(mpz_to_long(self.z), P) | |
return ((mpz_to_long(self.x) * z**2) % P, (mpz_to_long(self.y) * z**3) % P) | |
cdef void jdouble(self): | |
if mpz_cmp(self.y, mNull) == zero: | |
mpz_set(self.x, mNull) | |
mpz_set(self.z, mNull) | |
else: | |
# ysq = (self.y ** 2) % P | |
mpz_pow_ui(mysq, self.y, 2) | |
mpz_mod(mysq, mysq, mP) | |
# S = (4 * self.x * ysq) % P | |
mpz_mul(mS, self.x, mysq) | |
mpz_mul_ui(mS, mS, 4) | |
mpz_mod(mS, mS, mP) | |
# M = (3 * self.x ** 2 + A * self.z ** 4) % P | |
mpz_pow_ui(mM, self.x, 2) | |
mpz_mul_ui(mM, mM, 3) | |
mpz_pow_ui(mTmp, self.z, 4) | |
mpz_mul(mTmp, mTmp, mA) | |
mpz_add(mM, mM, mTmp) | |
mpz_mod(mM, mM, mP) | |
# self.x = (M**2 - 2 * S) % P | |
mpz_pow_ui(self.x, mM, 2) | |
mpz_mul_ui(mTmp, mS, 2) # this can be cached | |
mpz_sub(self.x, self.x, mTmp) | |
mpz_mod(self.x, self.x, mP) | |
# self.z = (2 * self.y * self.z) % P # relies on old y | |
mpz_mul_ui(self.z, self.z, 2) | |
mpz_mul(self.z, self.z, self.y) | |
mpz_mod(self.z, self.z, mP) | |
# self.y = (M * (S - self.x) - 8 * ysq ** 2) % P | |
mpz_sub(self.y, mS, self.x) | |
mpz_mul(self.y, self.y, mM) | |
mpz_pow_ui(mTmp, mysq, 2) | |
mpz_mul_ui(mTmp, mTmp, 8) | |
mpz_sub(self.y, self.y, mTmp) | |
mpz_mod(self.y, self.y, mP) | |
cdef void add(self, Jacobian q): | |
if mpz_cmp(self.y, mNull) == zero: | |
mpz_set(self.x, q.x) | |
mpz_set(self.y, q.y) | |
mpz_set(self.z, q.z) | |
elif mpz_cmp(q.y, mNull) == zero: | |
pass | |
else: | |
# U1 = (self.x * q.z ** 2) % P | |
mpz_pow_ui(mU1, q.z, 2) | |
mpz_mul(mU1, mU1, self.x) | |
mpz_mod(mU1, mU1, mP) | |
# U2 = (q.x * self.z ** 2) % P | |
mpz_pow_ui(mU2, self.z, 2) | |
mpz_mul(mU2, mU2, q.x) | |
mpz_mod(mU2, mU2, mP) | |
# S1 = (self.y * q.z ** 3) % P | |
mpz_pow_ui(mS1, q.z, 3) | |
mpz_mul(mS1, mS1, self.y) | |
mpz_mod(mS1, mS1, mP) | |
# S2 = (q.y * self.z ** 3) % P | |
mpz_pow_ui(mS2, self.z, 3) | |
mpz_mul(mS2, mS2, q.y) | |
mpz_mod(mS2, mS2, mP) | |
if mpz_cmp(mU1, mU2) == zero: | |
if mpz_cmp(mS1, mS2) != zero: | |
mpz_set(self.x, mNull) | |
mpz_set(self.y, mNull) | |
mpz_set(self.z, mOne) | |
else: | |
self.jdouble() | |
else: | |
# H = U2 - U1 | |
mpz_sub(mH, mU2, mU1) | |
# R = S2 - S1 | |
mpz_sub(mR, mS2, mS1) | |
# H2 = (H * H) % P | |
mpz_mul(mH2, mH, mH) | |
mpz_mod(mH2, mH2, mP) | |
# H3 = (H * H2) % P | |
mpz_mul(mH3, mH, mH2) | |
mpz_mod(mH3, mH3, mP) | |
# U1H2 = (U1 * H2) % P | |
mpz_mul(mU1H2, mU1, mH2) | |
mpz_mod(mU1H2, mU1H2, mP) | |
# self.x = (R ** 2 - H3 - 2 * U1H2) % P | |
mpz_pow_ui(self.x, mR, 2) | |
mpz_sub(self.x, self.x, mH3) | |
mpz_mul(mTmp, mU1H2, mTwo) | |
mpz_sub(self.x, self.x, mTmp) | |
mpz_mod(self.x, self.x, mP) | |
# self.y = (R * (U1H2 - self.x) - S1 * H3) % P | |
mpz_sub(self.y, mU1H2, self.x) | |
mpz_mul(self.y, self.y, mR) | |
mpz_mul(mTmp, mS1, mH3) | |
mpz_sub(self.y, self.y, mTmp) | |
mpz_mod(self.y, self.y, mP) | |
# works | |
# self.z = H * self.z * q.z | |
mpz_mul(self.z, self.z, mH) | |
mpz_mul(self.z, self.z, q.z) | |
cdef void py_multiply(self, l): | |
cdef mpz_t m | |
mpz_init(m) | |
set_mpz_from_long(m, l) | |
self.multiply(m) | |
cdef void multiply(self, mpz_t n): | |
if mpz_cmp(self.y, mNull) == zero or mpz_cmp(n, mNull) == zero: | |
mpz_set(self.x, mNull) | |
mpz_set(self.y, mNull) | |
mpz_set(self.z, mOne) | |
# elif n == 1: # | |
elif mpz_cmp(n, mOne) == 0: | |
pass | |
# elif n < 0 or n >= N: | |
elif mpz_cmp(n, mNull) < 0 or not (mpz_cmp(n, mN) < 0): | |
# self.multiply(n % mN) | |
mpz_mod(n, n, mN) | |
self.multiply(n) | |
# elif (n % 2) == 0: | |
elif mpz_fdiv_ui(n, 2) == 0: | |
mpz_fdiv_q(n, n, mTwo) | |
self.multiply(n) | |
self.jdouble() | |
else: | |
# elif (n % 2) == 1: | |
c = self.copy() | |
mpz_fdiv_q(n, n, mTwo) | |
self.multiply(n) | |
self.jdouble() | |
self.add(c) | |
del | |
############################################# | |
def hash_to_int(msghash): | |
assert len(msghash) == 32 | |
z = 0 | |
for c in msghash: | |
z *= 256 | |
z += ord(c) | |
return z | |
def ecdsa_raw_recover(msghash, vrs): | |
v, r, s = vrs | |
x = r | |
beta = pow(x * x * x + A * x + B, (P + 1) // 4, P) | |
y = beta if v % 2 ^ beta % 2 else (P - beta) | |
z = hash_to_int(msghash) | |
j = Jacobian() | |
j.from_point(G) | |
j.py_multiply((N - z) % N) | |
j2 = Jacobian() | |
j2.from_point((x, y)) | |
j2.py_multiply(s) | |
j.add(j2) | |
j.py_multiply(inv(r, N)) | |
Q = j.as_point() | |
if ecdsa_raw_verify(msghash, vrs, Q): | |
return Q | |
return False | |
def ecdsa_raw_verify(msghash, vrs, pub): | |
v, r, s = vrs | |
w = inv(s, N) | |
z = hash_to_int(msghash) | |
u1, u2 = z*w % N, r*w % N | |
j = Jacobian() | |
j.from_point(G) | |
j.py_multiply(u1) | |
j2 = Jacobian() | |
j2.from_point(pub) | |
j2.py_multiply(u2) | |
j.add(j2) | |
x, y = j.as_point() | |
return r == x |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment