#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <getopt.h>

#include "../xutil.h"
#include "../matrix.h"
#include "../naive.h"
#include "../drepper.h"
#include "../naive_idx.h"
#include "../frigo.h"
#include "../strassen.h"
#include "../strassen_idx.h"

/*
 * Write the command-line synopsis to stderr.
 *
 * argv[0] is substituted as the program name on the first line; each
 * option follows on its own tab-indented line.
 */
void printusage(char *argv[]) {
	const char *prog = argv[0];

	fprintf(stderr, "Usage: %s \n", prog);
	fputs("\t[-m rows of A]\n"
	      "\t[-n cols of A]\n"
	      "\t[-p cols of B]\n"
	      "\t[-i iterations]\n"
	      "\t[-s seed]\n",
	      stderr);
}

int main(int argc, char *argv[]) {
	int opt;
	uint32_t i, j, k;
	uint32_t m = 0, n = 0, p = 0;
	uint32_t iterations = 0;
	uint32_t seed = 0;
	struct matrix_mult str;
	struct strassen_mult str_str;

	matrix_t *A, *B, *C, *D;

	while ((opt = getopt(argc, argv, "i:s:m:n:p:")) != -1) {
		switch (opt) {
		case 'm':
			m = strtoull(optarg, NULL, 10);
			break;
		case 'n':
			n = strtoull(optarg, NULL, 10);
			break;
		case 'p':
			p = strtoull(optarg, NULL, 10);
			break;
		case 's':
			seed = strtoull(optarg, NULL, 10);
			break;
		case 'i':
			iterations = strtoull(optarg, NULL, 10);
			break;
		default:
			printusage(argv);
			exit(EXIT_FAILURE);
		}
	}

	if (!seed) {
		srand(seed);
	}

	if (!m || !n || !p) {

		if (!iterations) {
			printusage(argv);
			exit(EXIT_FAILURE);
		}

		for (i = 0; i < iterations; i++) {
			m = ceil(rand()/pow(2, 21));
			n = ceil(rand()/pow(2, 21));
			p = ceil(rand()/pow(2, 21));

			printf(
					"Testing on A(%"PRIu32"x%"PRIu32") * B(%"PRIu32"x%"PRIu32") = " 
					"C(%"PRIu32"x%"PRIu32")\n", m, n, n, p, m, p
			);

			A = matrix_create(m, n);
			B = matrix_create(n, p);
			C = matrix_create(m, p);
			D = matrix_create(m, p);

			matrix_randomize(A, (unsigned int)rand());
			matrix_randomize(B, (unsigned int)rand());

			str.A = A;
			str.B = B;
			str.C = C;

			str_str.A = A;
			str_str.B = B;
			str_str.C = D;
			str_str.th = 0;

			naive(&str);

			str.C = D;

			matrix_zero(D);
			frigo(&str);

			for (j = 0; j < m; j++) {
				for (k = 0; k < p; k++) {
					if ( MAT(C, j, k) != MAT(D, j, k) ) {
						matrix_destroy(A);
						matrix_destroy(B);
						matrix_destroy(C);
						matrix_destroy(D);

						xerror(
								"Naive and frigo matrices were different!", 
								__LINE__, 
								__FILE__
						); 
						return EXIT_FAILURE;
					}
				}
			}

			matrix_zero(D);

			for (j = 0; j < D->m; j++) {
				for (k = 0; k < D->n; k++) {
					if (MAT(D, j, k) != 0) {
						printf("%"PRId32"\n", MAT(D, j, k));
						//printf("%"PRId32", %"PRId32"\n", MAT(D, j, k), MAT(C, j, k));
					}
				}
			}
			strassen(&str_str);

			for (j = 0; j < m; j++) {
				for (k = 0; k < p; k++) {
					if (MAT(C, j, k) != MAT(D, j, k)) {
						matrix_destroy(A);
						matrix_destroy(B);
						matrix_destroy(C);
						matrix_destroy(D);

						xerror(
								"Naive and strassen matrices were different!", 
								__LINE__, 
								__FILE__
						); 
						return EXIT_FAILURE;
					}
				}
			}

			matrix_zero(D);
			naive_idx(&str);

			for (j = 0; j < m; j++) {
				for (k = 0; k < p; k++) {
					if ( MAT(C, j, k) != MAT(D, j, k) ) { 
						matrix_destroy(A);
						matrix_destroy(B);
						matrix_destroy(C);
						matrix_destroy(D);

						xerror(
								"Naive and naive_idx matrices were different!", 
								__LINE__, 
								__FILE__
						); 
						return EXIT_FAILURE;
					}
				}
			}

			matrix_zero(D);
			drepper(&str);

			for (j = 0; j < m; j++) {
				for (k = 0; k < p; k++) {
					if ( MAT(C, j, k) != MAT(D, j, k) ) { 
						matrix_destroy(A);
						matrix_destroy(B);
						matrix_destroy(C);
						matrix_destroy(D);

						xerror(
								"Naive and drepper matrices were different!", 
								__LINE__, 
								__FILE__
						); 
						return EXIT_FAILURE;
					}
				}
			}

			matrix_zero(D);
			strassen_idx(&str_str);

			for (j = 0; j < m; j++) {
				for (k = 0; k < p; k++) {
					if (MAT(C, j, k) != MAT(D, j, k)) {
						matrix_destroy(A);
						matrix_destroy(B);
						matrix_destroy(C);
						matrix_destroy(D);

						xerror(
								"Naive and strassen_idx matrices were different!", 
								__LINE__, 
								__FILE__
						); 
						return EXIT_FAILURE;
					}
				}
			}

			matrix_destroy(A);
			matrix_destroy(B);
			matrix_destroy(C);
			matrix_destroy(D);
		}
	} else {
		printf(
				"Testing on A(%"PRIu32"x%"PRIu32") * B(%"PRIu32"x%"PRIu32") = " 
				"C(%"PRIu32"x%"PRIu32")\n", m, n, n, p, m, p
		);

		A = matrix_create(m, n);
		B = matrix_create(n, p);
		C = matrix_create(m, p);
		D = matrix_create(m, p);

		matrix_randomize(A, (unsigned int)rand());
		matrix_randomize(B, (unsigned int)rand());

		str.A = A;
		str.B = B;
		str.C = C;

		str_str.A = A;
		str_str.B = B;
		str_str.C = D;
		str_str.th = 0;

		naive(&str);

		str.C = D;

		matrix_zero(D);
		frigo(&str);

		for (j = 0; j < m; j++) {
			for (k = 0; k < p; k++) {
				if ( MAT(C, j, k) != MAT(D, j, k) ) {
					matrix_destroy(A);
					matrix_destroy(B);
					matrix_destroy(C);
					matrix_destroy(D);

					xerror(
							"Naive and frigo matrices were different!", 
							__LINE__, 
							__FILE__
					); 
					return EXIT_FAILURE;
				}
			}
		}

		matrix_zero(D);
		strassen(&str_str);

		for (j = 0; j < m; j++) {
			for (k = 0; k < p; k++) {
				if (MAT(C, j, k) != MAT(D, j, k)) {
					matrix_destroy(A);
					matrix_destroy(B);
					matrix_destroy(C);
					matrix_destroy(D);
					xerror(
							"Naive and strassen matrices were different!", 
							__LINE__, 
							__FILE__
					); 
					return EXIT_FAILURE;
				}
			}
		}

		matrix_zero(D);
		naive_idx(&str);

		for (j = 0; j < m; j++) {
			for (k = 0; k < p; k++) {
				if ( MAT(C, j, k) != MAT(D, j, k) ) { 
					matrix_destroy(A);
					matrix_destroy(B);
					matrix_destroy(C);
					matrix_destroy(D);

					xerror(
							"Naive and naive_idx matrices were different!", 
							__LINE__, 
							__FILE__
					); 
					return EXIT_FAILURE;
				}
			}
		}

		matrix_zero(D);
		strassen_idx(&str_str);

		for (j = 0; j < m; j++) {
			for (k = 0; k < p; k++) {
				if (MAT(C, j, k) != MAT(D, j, k)) {
					matrix_destroy(A);
					matrix_destroy(B);
					matrix_destroy(C);
					matrix_destroy(D);

					xerror(
							"Naive and strassen_idx matrices were different!", 
							__LINE__, 
							__FILE__
					); 
					return EXIT_FAILURE;
				}
			}
		}

		matrix_destroy(A);
		matrix_destroy(B);
		matrix_destroy(C);
		matrix_destroy(D);
	}

	printf("SUCCESS\n");

	return EXIT_SUCCESS;
}
