/*
 * ecm.c
 * by jeff hamblin (hamblin@cs.wisc.edu)
 *
 * lenstra's elliptic curve method of factoring 
 * this is my first implementation of this, so it is rather sloppy
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <math.h>
#include <gmp.h>


int sieve(long bound, int *primes);

int main(int argc, char **argv)
{
	mpz_t composite, gcd, scratch1, scratch2; 
	mpz_t x, y, x1, y1, x2, y2, a, b, discriminant; /* ec stuff */
	long smooth_bound;
	int *primes;
	long pi_bound, exponent;
	int prime_count = 0, i, j;

	if (argc != 3)
	{
		printf("usage: %s composite smoothness_bound\n", argv[0]);
		exit(-1);
	}

	mpz_init_set_str(composite, argv[1], 10);
	smooth_bound = atoi(argv[2]);
	mpz_init(gcd);
	mpz_init(x);
	mpz_init(y);
	mpz_init(x1);
	mpz_init(y1);
	mpz_init(x2);
	mpz_init(y2);
	mpz_init(a);
	mpz_init(b);
	mpz_init(discriminant);
	mpz_init(scratch1);
	mpz_init(scratch2);
	primes = malloc(sizeof(int) * smooth_bound);
	pi_bound = sieve(smooth_bound, primes);
	printf("found %d primes under smoothness-bound %d\n", pi_bound, smooth_bound);
	srand(getpid());

	while (mpz_mod_ui(scratch1, composite, 2) == 0)
	{
		prime_count ++;
		printf("factor: 2\n");
		mpz_div_ui(composite, composite, 2);
	}

	if (mpz_cmp_ui(composite, 1) == 0)
	{
		exit(prime_count);
	}

	if (mpz_probab_prime_p(composite, 25) == 1) 
	{
		prime_count ++;
		printf("factor: ");
		mpz_out_str(stdout, 10, composite);
		printf("\n");
		exit(prime_count);
	}

	while (1)
	{
		/*
		 * set up elliptic curve.  choose x, y, a randomly in Z/n.
		 * choose b s.t. y^2 = x^3 + ax + b
		 *
		 * mpz_random isn't guaranteed to do anything random at all, so
		 * i multiply the results of mpz_random() by rand() and mod by 
		 * the composite to put the number in the correct range.
		 */
		
		mpz_random(x, mpz_size(composite));
		mpz_random(y, mpz_size(composite));
		mpz_random(a, mpz_size(composite));
		mpz_mul_ui(x, x, rand());
		mpz_mul_ui(y, y, rand());
		mpz_mul_ui(a, a, rand());
		mpz_mod(x, x, composite);
		mpz_mod(y, y, composite);
		mpz_mod(a, a, composite);

		mpz_set_ui(x, 0);
		mpz_set_ui(y, 1);

		/* 
		 * now i figure out what b value satisfies the EC equation 
		 */

		mpz_pow_ui(scratch1, x, 3); /* = x^3 */
		mpz_pow_ui(scratch2, y, 2); /* = y^2 */
		mpz_sub(scratch2, scratch2, scratch1); /* = y^2 - x^3 */
		mpz_mul(scratch1, a, x);
		mpz_sub(b, scratch2, scratch1); /* b = y^2 - x^3 - ax */

		/* 
		 * now calculate the discriminant = 4a^3+27b^2 and determine the
		 * gcd of discriminant and composite
		 */

		mpz_pow_ui(scratch1, a, 3);
		mpz_pow_ui(scratch2, b, 2);
		mpz_mul_ui(scratch1, scratch1, 4); /* = 4a^3 */
		mpz_mul_ui(scratch2, scratch2, 27); /* = 27b^2 */
		mpz_add(discriminant, scratch1, scratch2);
		mpz_gcd(gcd, discriminant, composite);

		if (mpz_cmp(gcd, composite) == 0) /* (discriminant, composite) ==  composite */
		{
			/* this ec failed */
		}
		else 
		{
			if (mpz_cmp_ui(gcd, 1) > 0) 
			{
				prime_count ++;
				printf("lucky factor found using point (");
				mpz_out_str(stdout, 10, x);
				printf(", ");
				mpz_out_str(stdout, 10, y);
				printf(")\n");
				printf("and curve parameters\n");
				printf("a = ");
				mpz_out_str(stdout, 10, a);
				printf("\nb = ");
				mpz_out_str(stdout, 10, b);
				printf("\nfactor: ");
				mpz_out_str(stdout, 10, gcd);
				printf("\n---\n");
				mpz_div(composite, composite, gcd);

				if (mpz_cmp_ui(composite, 1) == 0)
				{
					exit(prime_count);
				}
	
				if (mpz_probab_prime_p(composite, 25) == 1)
				{
					prime_count ++;
					printf("factor: ");
					mpz_out_str(stdout, 10, composite);
					printf("\n---\n");
					exit(prime_count);
				}
			}
			else /* gcd is 1 */
			{
				mpz_set(x1, x);
				mpz_set(y1, y);
				for (i = 0; i < smooth_bound; i ++)
				{
					if (primes[i] != 0) /* i is prime by sieve */
					{
						exponent = mpz_sizeinbase(composite, 3);
						exponent /= log((double) i);
						//printf("log_%d = %d\n", i, exponent);
						/* compute prime i ^ log_i(n) */
						for (j = 1; j < exponent; j ++)
						{
							/* x2 = ((3x^2 + a) / 2y )^2 - 2x */
							/* y2 = ((3x^2 + a) / 2y )(x - x2) - y */

							mpz_mul(scratch1, x1, x1);	
							mpz_mul_ui(scratch1, scratch1, 3);	/* 3(x1)^2 */
							mpz_add(scratch1, scratch1, a); /* 3(x1)^2 + a */	
							mpz_mul_ui(scratch2, y1, 2);	/* 2(y1) */
							mpz_div(x2, scratch1, scratch2);
							mpz_set(y2, x2);
							mpz_mul(x2, x2, x2);
							mpz_mul_ui(scratch1, x1, 2); /* 2(x1) */
							mpz_sub(x2, x2, scratch1);

							mpz_sub(scratch1, x1, x2);
							mpz_mul(y2, y2, scratch1);
							mpz_sub(y2, y2, y1);
							
							mpz_sub(scratch1, x1, x2);
							mpz_gcd(gcd, scratch1, composite);
							if ((mpz_cmp(gcd, composite) != 0) &&
								(mpz_cmp_ui(gcd, 1) != 0)) 
							{
								
								prime_count ++;
								printf("unlucky factor found using point (");
								mpz_out_str(stdout, 10, x);
								printf(", ");
								mpz_out_str(stdout, 10, y);
								printf(")\n");
								printf("and curve parameters\n");
								printf("a = ");
								mpz_out_str(stdout, 10, a);
								printf("\nb = ");
								mpz_out_str(stdout, 10, b);
								printf("\nfactor: ");
								mpz_out_str(stdout, 10, gcd);
								printf("\n---\n");
								mpz_div(composite, composite, gcd);

								if (mpz_cmp_ui(composite, 1) == 0)
								{
									exit(prime_count);
								}
	
								if (mpz_probab_prime_p(composite, 25) == 1)
								{
									prime_count ++;
									printf("factor: ");
									mpz_out_str(stdout, 10, composite);
									printf("\n---\n");
									exit(prime_count);
								}
							}
							mpz_set(x1, x2);
							mpz_set(y1, y2);
						}
					}	
				}
			}
		}
	}
}

/*
 * Sieve of Eratosthenes over the integers 0 .. bound-1.
 *
 * bound:  exclusive upper limit of the range to classify.
 * primes: caller-provided array of at least `bound` ints.  On return
 *         primes[k] is 1 when k is prime and 0 otherwise, for every
 *         0 <= k < bound.  No element at or beyond index `bound` is
 *         read or written.
 *
 * Returns the number of primes strictly below bound.  A bound of 2 or less
 * yields 0; a bound of 0 or less leaves the array untouched.
 */
int sieve(long bound, int *primes) {
	int count = 0;
	long i, j;

	if (bound <= 0)
	{
		return 0;
	}

	/* 0 and 1 are not prime; everything else is a candidate until struck */
	for (i = 0; i < bound; i++)
	{
		primes[i] = (i >= 2);
	}

	for (i = 2; i < bound; i++)
	{
		if (primes[i] == 0) /* already struck: i is a composite */
		{
			continue;
		}
		count++;

		/*
		 * smaller multiples of i were struck by smaller primes, so start
		 * at i*i; the guard also keeps i*i from overflowing
		 */
		if (i <= bound / i)
		{
			for (j = i * i; j < bound; j += i)
			{
				primes[j] = 0;
			}
		}
	}
	return count;
}
