#!/usr/local/bin/python

"""
Derive RSA modulus n, given two signed messages and their signatures.
Copyright 2007 Nate Lawson

S1^e - M1 = K1 * n
S2^e - M2 = K2 * n
n ~= gcd(K1 * n, K2 * n) = gcd(K1, K2) * n  (exactly n when K1 and K2 are coprime)

Algorithm described to me by Paul Kocher and Joshua Jaffe.
"""

# Toy RSA key pair for a tiny demonstration system.  The modulus is
# n = 5 * 7 = 35, so phi(n) = 4 * 6 = 24, and e * d = 145 = 6 * 24 + 1,
# i.e. d is the inverse of e modulo phi(n).
e, d = 5, 29
n = 5 * 7

# Signed values used by the demo; each must be smaller than n, and at
# least two are needed.  Adding more values generally sharpens the
# recovered modulus, with diminishing returns.  In a real system these
# would be the padded hashes that were signed, not the raw signed data.
msgs = [3, 22, 5]

import operator

def derive_n(e, sigs, msgs):
	"""Figure out RSA modulus n from a set of signatures and messages.

	For every valid signature, s**e - m == k * n for some integer k, so n
	divides every such difference and hence their greatest common
	divisor.  The gcd of all differences is n * gcd(k_1, k_2, ...); each
	additional message pair can only shrink that leftover multiplier,
	which is why more messages give a more accurate answer.

	Unlike intersecting per-pair trial-division factor sets, the gcd is
	always a multiple of n (for valid signatures) and never loses prime
	factors that are too large for trial division.

	e    -- (guessed) public exponent
	sigs -- signatures, paired with msgs by position
	msgs -- the signed values (hash plus padding)

	Returns the non-negative gcd of all s**e - m differences: a multiple
	of n when every signature is valid for e, ideally n itself.
	Raises ValueError if fewer than two message/signature pairs are given.
	"""
	diffs = [(s ** e) - m for m, s in zip(msgs, sigs)]
	# A single difference is just one arbitrary multiple of n; the gcd
	# trick needs at least two.
	if len(diffs) < 2:
		raise ValueError('need at least two message/signature pairs')
	result = diffs[0]
	for diff in diffs[1:]:
		result = gcd(result, diff)
	# Differences can be zero or negative; report a non-negative modulus.
	return abs(result)

def gcd(a, b):
	"""Calculate the non-negative greatest common divisor of a and b.

	Iterative Euclidean algorithm: operands such as S**e for realistic
	exponents need far more reduction steps than the interpreter's
	recursion limit allows, so no recursion is used.  gcd(0, 0) is 0.
	"""
	a, b = abs(a), abs(b)
	while a:
		a, b = b % a, a
	return b

def pairs(vals):
	"""Generate all pairs of items from a flat list.

	Returns a list of 2-tuples in nested-loop order:
	(vals[0], vals[1]), (vals[0], vals[2]), ..., (vals[1], vals[2]), ...
	Fewer than two items yields an empty list.

	itertools.combinations replaces the former recursion, whose depth grew
	with len(vals) and overflowed the recursion limit on long lists.
	"""
	import itertools
	return list(itertools.combinations(vals, 2))

def trialfactor(val):
	"""Return the distinct factors of val found by trial division.

	Trial division uses the primes up to 1013 listed below.  Every prime
	that divides val is reported once and divided out completely.  If a
	cofactor greater than 1 remains (it has no prime divisor in the
	table, e.g. a large prime factor of an RSA modulus), it is appended
	as the last element instead of being discarded; when large it may
	itself be composite.

	If nothing is found at all (val is 0, 1 or negative), [val] is
	returned.
	"""
	primes = [
		2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
		31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
		73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
		127, 131, 137, 139, 149, 151, 157, 163, 167, 173,
		179, 181, 191, 193, 197, 199, 211, 223, 227, 229,
		233, 239, 241, 251, 257, 263, 269, 271, 277, 281,
		283, 293, 307, 311, 313, 317, 331, 337, 347, 349,
		353, 359, 367, 373, 379, 383, 389, 397, 401, 409,
		419, 421, 431, 433, 439, 443, 449, 457, 461, 463,
		467, 479, 487, 491, 499, 503, 509, 521, 523, 541,
		547, 557, 563, 569, 571, 577, 587, 593, 599, 601,
		607, 613, 617, 619, 631, 641, 643, 647, 653, 659,
		661, 673, 677, 683, 691, 701, 709, 719, 727, 733,
		739, 743, 751, 757, 761, 769, 773, 787, 797, 809,
		811, 821, 823, 827, 829, 839, 853, 857, 859, 863,
		877, 881, 883, 887, 907, 911, 919, 929, 937, 941,
		947, 953, 967, 971, 977, 983, 991, 997, 1009, 1013 ]
	fact = []
	rest = val
	for p in primes:
		# All smaller primes are already divided out, so once p*p exceeds
		# what is left, the remainder is either 1 or a prime.
		if p * p > rest:
			break
		if rest % p == 0:
			fact.append(p)
			# Strip every power of p so the leftover cofactor is exact.
			while rest % p == 0:
				rest //= p
	# Keep the part trial division could not break down.
	if rest > 1:
		fact.append(rest)
	# No factors at all (0, 1, negative input): return the number itself.
	if not fact:
		fact.append(val)
	return fact

# Simulate factory by calculating signatures for all messages
sigs = [(x**d) % n for x in msgs]

# Run attack to derive modulus n
guess_e = 5
print 'messages', msgs
print 'signatures', sigs
print 'guess of e =', guess_e
print 'RSA modulus n is', derive_n(guess_e, sigs, msgs)

