/* fermat.c
 * by jeff hamblin (hamblin@cs.wisc.edu)
 *
 * given n = ab, find a and b.  to do this, we try to represent n as the
 * difference of two perfect squares.  factors of 2 are removed first, since
 * a number that is twice an odd number is never a difference of two squares.
 * for i = 1, 2, ..., let x = floor(sqrt(n)) + i and calculate xx - n until
 * xx - n = yy, a perfect square.  Now,
 *		xx - n = yy,  thus
 *		n = xx - yy
 *		n = (x - y)(x + y)
 *		a = x - y
 *		b = x + y
 *
 * be sure to link in the gmp library, as usual (e.g. cc fermat.c -lgmp).
 */

#include <stdio.h>
#include <stdlib.h>
#include <gmp.h>

int main(int argc, char **argv)
{
	mpz_t composite, a, b, x, xx, diff, y;
	int two_flag = 0;

	if (argc != 2)
	{
		printf("usage: %s composite\n", argv[0]);
		exit(-1);
	}

	mpz_init_set_str(composite, argv[1], 10);
	mpz_init(a);
	mpz_init(b);
	mpz_init(x);
	mpz_init(xx);
	mpz_init(y);
	mpz_init(diff);

	mpz_sqrt(x, composite);

	printf("---\nBeginning Fermat Factorization...\n---\n");

	while (mpz_mod_ui(a, composite, 2) == 0)
	{
		mpz_div_ui(composite, composite, 2);
		two_flag ++;
	}


	if (mpz_probab_prime_p(composite, 25) == 1)
	{
		printf("Number to factor is prime!\n");
		exit(0);
	}

	if (mpz_perfect_square_p(composite) == 1)
	{
		printf("Number to factor is perfect square!\n");
		mpz_sqrt(a, composite);
		mpz_out_str(stdout, 10, composite);
		printf(" = \n\t");
		mpz_out_str(stdout, 10, a);
		printf("\n\t*\n\t");
		mpz_out_str(stdout, 10, a);
		printf("\n---\n");
		exit(0);
	}


	do {
		mpz_add_ui(x, x, 1);
		mpz_mul(xx, x, x);
		mpz_sub(diff, xx, composite);
/*
		mpz_out_str(stdout, 10, xx);
		printf(" - ");
		mpz_out_str(stdout, 10, composite);
		printf(" = ");
		mpz_out_str(stdout, 10, diff);
		printf("\n");
*/		
	}
	while (mpz_perfect_square_p(diff) == 0);
	
		mpz_out_str(stdout, 10, xx);
		printf(" - ");
		mpz_out_str(stdout, 10, composite);
		printf(" = ");
		mpz_out_str(stdout, 10, diff);
		printf("\n");

	mpz_sqrt(y, diff);	
	mpz_sub(a, x, y);
	mpz_add(b, x, y);
	
	printf("\n---\n");
	mpz_out_str(stdout, 10, composite);
	printf(" = \n\t");
	while (two_flag)
	{
		printf("2\n\t*\n\t");
		two_flag --;
	}
	mpz_out_str(stdout, 10, a);
	printf("\n\t*\n\t");
	mpz_out_str(stdout, 10, b);
	printf("\n---\n");
	exit(0);
}

