#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "strassen.h"
#include "matrix.h"
#include "xutil.h"

// Edge length, in elements, of the square tiles that mm_strassen_naive_mul
// walks over. CLS is defined outside this file (presumably the cache line
// size in bytes -- TODO confirm).
#define SM (CLS/sizeof(uint32_t))

// Recursion cutoff: mm_strassen_mul switches to the naive kernel as soon as
// any of the three dimensions is <= THRESHOLD. Set by strassen().
uint32_t THRESHOLD;

// Base addresses of the scratch storage for the temporaries X (R1) and
// Y (R2) used by mm_strassen_mul. strassen() allocates one buffer, points
// R1 at its start and R2 at an offset inside the same buffer.
int32_t *restrict R1;
int32_t *restrict R2;

/*
 * Element-wise block addition.
 *
 * Writes the m x n block of C whose top-left corner is (ci, cj) with the sum
 * of the m x n block of A at (ai, aj) and the m x n block of B at (bi, bj).
 * Each destination element is computed from the source elements at the same
 * relative position only.
 */
void mm_strassen_add(
		const matrix_t *restrict A, const matrix_t *restrict B, 
		matrix_t *restrict C, const uint32_t ai, const uint32_t aj, 
		const uint32_t bi, const uint32_t bj, const uint32_t ci, 
		const uint32_t cj, const uint32_t m, const uint32_t n) {
	for (uint32_t row = 0; row < m; row++) {
		// Absolute row of the current line inside each matrix
		const uint32_t arow = ai + row;
		const uint32_t brow = bi + row;
		const uint32_t crow = ci + row;

		for (uint32_t col = 0; col < n; col++) {
			const uint32_t acol = aj + col;
			const uint32_t bcol = bj + col;
			const uint32_t ccol = cj + col;

			MAT(C, crow, ccol) = MAT(A, arow, acol) + MAT(B, brow, bcol);
		}
	}
}

/*
 * Element-wise block subtraction.
 *
 * Writes the m x n block of C whose top-left corner is (ci, cj) with the
 * difference (A block at (ai, aj)) - (B block at (bi, bj)). Each destination
 * element is computed from the source elements at the same relative position
 * only.
 */
void mm_strassen_sub(
		const matrix_t *restrict A, const matrix_t *restrict B, 
		matrix_t *restrict C, const uint32_t ai, const uint32_t aj, 
		const uint32_t bi, const uint32_t bj, const uint32_t ci, 
		const uint32_t cj, const uint32_t m, const uint32_t n) {
	for (uint32_t row = 0; row < m; row++) {
		// Absolute row of the current line inside each matrix
		const uint32_t arow = ai + row;
		const uint32_t brow = bi + row;
		const uint32_t crow = ci + row;

		for (uint32_t col = 0; col < n; col++) {
			const uint32_t acol = aj + col;
			const uint32_t bcol = bj + col;
			const uint32_t ccol = cj + col;

			MAT(C, crow, ccol) = MAT(A, arow, acol) - MAT(B, brow, bcol);
		}
	}
}

/*
 * Conventional (non-recursive) block multiplication, used as the fallback
 * below the recursion threshold and for the peeled rows/columns.
 *
 * Overwrites the m x p block of C at (ci, cj) with the product of the
 * m x n block of A at (ai, aj) and the n x p block of B at (bi, bj).
 *
 * The destination block is zeroed first and the product is then accumulated
 * tile by tile, with tiles of at most SM x SM elements, so that the innermost
 * loop runs over consecutive elements of one row of B and one row of C.
 */
void mm_strassen_naive_mul(
		const matrix_t *restrict A, const matrix_t *restrict B, 
		matrix_t *restrict C, const uint32_t ai, const uint32_t aj, 
		const uint32_t bi, const uint32_t bj, const uint32_t ci, 
		const uint32_t cj, const uint32_t m, const uint32_t n, 
		const uint32_t p) {

	// Clear the destination block, one row at a time
	for (uint32_t row = 0; row < m; row++) {
		memset(&MAT(C, (ci + row), cj), 0, p*sizeof(uint32_t));
	}

	// Row strides (elements per full matrix row) of the three matrices
	const uint32_t a_stride = A->n;
	const uint32_t b_stride = B->n;
	const uint32_t c_stride = C->n;

	for (uint32_t i = 0; i < m; i += SM) {
		// Number of rows in this tile; the last tile may be shorter
		uint32_t tile_rows = SM;
		if (i + SM > m) {
			tile_rows = m - i;
		}

		for (uint32_t j = 0; j < p; j += SM) {
			// Number of columns in this tile
			uint32_t tile_cols = SM;
			if (j + SM > p) {
				tile_cols = p - j;
			}

			for (uint32_t k = 0; k < n; k += SM) {
				// Depth (shared dimension) covered by this tile
				uint32_t tile_depth = SM;
				if (k + SM > n) {
					tile_depth = n - k;
				}

				// First row of the C tile and of the matching A tile
				int32_t *restrict c_row =
					&C->data[c_stride*(i+ci) + j + cj];
				int32_t *restrict a_row =
					&A->data[a_stride*(i+ai) + k + aj];

				for (uint32_t ii = 0; ii < tile_rows; ++ii) {
					// Restart at the first row of the B tile for every
					// row of A
					int32_t *restrict b_row =
						&B->data[b_stride*(bi + k) + j + bj];

					for (uint32_t kk = 0; kk < tile_depth; ++kk) {
						for (uint32_t jj = 0; jj < tile_cols; ++jj) {
							c_row[jj] += a_row[kk] * b_row[jj];
						}

						b_row += b_stride;
					}

					c_row += c_stride;
					a_row += a_stride;
				}
			}
		}
	}
}

/*
 * Naive block multiplication that accumulates into C instead of
 * overwriting it.
 *
 * Adds the product of the m x n block of A at (ai, aj) and the n x p block
 * of B at (bi, bj) to the m x p block of C at (ci, cj):
 *
 *     C[ci+i][cj+j] += sum over k in [0, n) of A[ai+i][aj+k] * B[bi+k][bj+j]
 *
 * The previous implementation ignored n and always added only the k = 0
 * term, which is correct solely for n == 1 (the only value the peeling code
 * in this file passes). All n terms are now accumulated, so n == 0 leaves C
 * untouched and n > 1 yields the full product; the result for n == 1 is
 * unchanged.
 */
void mm_strassen_naive_mul_add(
		const matrix_t *restrict A, const matrix_t *restrict B, 
		matrix_t *restrict C, const uint32_t ai, const uint32_t aj, 
		const uint32_t bi, const uint32_t bj, const uint32_t ci, 
		const uint32_t cj, const uint32_t m, const uint32_t n, 
		uint32_t p) {

	for (uint32_t i = 0; i < m; i++) {
		// k is the middle loop so the innermost loop walks along one row
		// of B and one row of C.
		for (uint32_t k = 0; k < n; k++) {
			for (uint32_t j = 0; j < p; j++) {
				MAT(C, (ci + i), (cj + j)) +=
					MAT(A, (ai + i), (aj + k)) * MAT(B, (bi + k), (bj + j));
			}
		}
	}
}

/*
 * Recursive Strassen-Winograd block multiplication.
 *
 * Overwrites the m x p block of C at (ci, cj) with the product of the
 * m x n block of A at (ai, aj) and the n x p block of B at (bi, bj).
 *
 * xo and yo are the element offsets into the global scratch areas R1 and R2
 * at which this recursion level places its temporaries X and Y; the
 * recursive calls are handed the offsets just past them.
 *
 * When any of m, n, p is <= THRESHOLD the work is delegated to
 * mm_strassen_naive_mul. Otherwise the seven products are computed on the
 * halves mh = m/2, nh = n/2, ph = p/2 (rounded down), and for odd dimensions
 * the left-over row/column is fixed up afterwards with the naive kernels.
 *
 * NOTE: X, Y and the quadrants of C are reused as storage for intermediate
 * results throughout the schedule, so the statements below must stay in
 * exactly this order.
 */
void mm_strassen_mul(
		const matrix_t *restrict A, const matrix_t *restrict B, 
		matrix_t *restrict C, const uint32_t ai, const uint32_t aj, 
		const uint32_t bi, const uint32_t bj, const uint32_t ci, 
		const uint32_t cj, const uint32_t m, const uint32_t n, 
		const uint32_t p, const uint32_t xo, const uint32_t yo) {

	// If we get below the threshold, then switch to ordinary matrix 
	// multiplication to speed up the algorithm.
	if (n <= THRESHOLD || m <= THRESHOLD || p <= THRESHOLD) {
		mm_strassen_naive_mul(A, B, C, ai, aj, bi, bj, ci, cj, m, n, p);
		return;
	}

	// Used to determine if we need to apply dynamic peeling: each value is
	// 1 when the corresponding dimension is odd and 0 otherwise.
	const uint32_t m0 = m & 1;
	const uint32_t n0 = n & 1;
	const uint32_t p0 = p & 1;

	// Split the matrices in half leaving us with 4 blocks
	const uint32_t mh = m >> 1;
	const uint32_t nh = n >> 1;
	const uint32_t ph = p >> 1;

	// Set up the temporary matrices, used to hold temporary values when doing
	// the strassen calculations. We use the temporary memory R1 and R2 to 
	// avoid allocating data space for new matrices at each recursive level.
	// X holds mh x nh operands as well as the mh x ph product P1, hence its
	// row length of MAX(nh, ph).
	matrix_t X = (matrix_t) {
		.data = R1 + xo,
		.m = mh,
		.n = MAX(nh, ph) 
	};
	matrix_t Y = (matrix_t) {
		.data = R2 + yo,
		.m = nh,
		.n = ph 
	};

	// Calculate indices for the blocks in matrix A
	const uint32_t A11i = ai,    A11j = aj;
	const uint32_t A12i = ai,    A12j = aj+nh;
	const uint32_t A21i = ai+mh, A21j = aj;
	const uint32_t A22i = A21i,  A22j = A12j;

	// Calculate indices for the blocks in matrix B
	const uint32_t B11i = bi,    B11j = bj;
	const uint32_t B12i = bi,    B12j = bj+ph;
	const uint32_t B21i = bi+nh, B21j = bj;
	const uint32_t B22i = B21i,  B22j = B12j;

	// Calculate indices for the blocks in matrix C
	const uint32_t C11i = ci,    C11j = cj;
	const uint32_t C12i = ci,    C12j = cj+ph;
	const uint32_t C21i = ci+mh, C21j = cj;
	const uint32_t C22i = C21i,  C22j = C12j;

	// Scratch offsets for the next recursion level: just past this level's
	// X and Y.
	const uint32_t offsetx = xo+mh*MAX(nh, ph);
	const uint32_t offsety = yo+nh*ph;

	/* We use the schedule in table 1 proposed by Brice Boyer et al. */

	// S3 = A11 - A21 => X
	mm_strassen_sub(A, A, &X, A11i, A11j, A21i, A21j, 0, 0, mh, nh);

	// T3 = B22 - B12 => Y
	mm_strassen_sub(B, B, &Y, B22i, B22j, B12i, B12j, 0, 0, nh, ph);

	// P7 = S3  * T3  => C21
	mm_strassen_mul(&X, &Y, C, 0, 0, 0, 0, C21i, C21j, mh, nh, ph, 
			offsetx, offsety);

	// S1 = A21 + A22 => X
	mm_strassen_add(A, A, &X, A21i, A21j, A22i, A22j, 0, 0, mh, nh);

	// T1 = B12 - B11 => Y
	mm_strassen_sub(B, B, &Y, B12i, B12j, B11i, B11j, 0, 0, nh, ph);

	// P5 = S1  * T1  => C22
	mm_strassen_mul(&X, &Y, C, 0, 0, 0, 0, C22i, C22j, mh, nh, ph, 
			offsetx, offsety);

	// S2 = S1 - A11  => X
	mm_strassen_sub(&X, A, &X, 0, 0, A11i, A11j, 0, 0, mh, nh);

	// T2 = B22 - T1  => Y
	mm_strassen_sub(B, &Y, &Y, B22i, B22j, 0, 0, 0, 0, nh, ph);

	// P6 = S2  * T2  => C12
	mm_strassen_mul(&X, &Y, C, 0, 0, 0, 0, C12i, C12j, mh, nh, ph, 
			offsetx, offsety);

	// S4 = A12 - S2  => X
	mm_strassen_sub(A, &X, &X, A12i, A12j, 0, 0, 0, 0, mh, nh);

	// P3 = S4  * B22 => C11
	mm_strassen_mul(&X, B, C, 0, 0, B22i, B22j, C11i, C11j, mh, nh, ph, 
			offsetx, offsety);

	// P1 = A11 * B11 => X
	mm_strassen_mul(A, B, &X, A11i, A11j, B11i, B11j, 0, 0, mh, nh, ph, 
			offsetx, offsety);

	// U2 = P1  + P6  => C12
	mm_strassen_add(&X, C, C, 0, 0, C12i, C12j, C12i, C12j, mh, ph);

	// U3 = U2  + P7  => C21
	mm_strassen_add(C, C, C, C12i, C12j, C21i, C21j, C21i, C21j, mh, ph);

	// U4 = U2  + P5  => C12
	mm_strassen_add(C, C, C, C12i, C12j, C22i, C22j, C12i, C12j, mh, ph);

	// U7 = U3  + P5  => C22
	mm_strassen_add(C, C, C, C21i, C21j, C22i, C22j, C22i, C22j, mh, ph);

	// U5 = U4  + P3  => C12
	mm_strassen_add(C, C, C, C12i, C12j, C11i, C11j, C12i, C12j, mh, ph);

	// T4 = T2  - B21 => Y
	mm_strassen_sub(&Y, B, &Y, 0, 0, B21i, B21j, 0, 0, nh, ph);

	// P4 = A22 * T4  => C11
	mm_strassen_mul(A, &Y, C, A22i, A22j, 0, 0, C11i, C11j, mh, nh, ph, 
			offsetx, offsety);

	// U6 = U3  - P4  => C21
	mm_strassen_sub(C, C, C, C21i, C21j, C11i, C11j, C21i, C21j, mh, ph);

	// P2 = A12 * B21 => C11
	mm_strassen_mul(A, B, C, A12i, A12j, B21i, B21j, C11i, C11j, mh, nh, ph, 
			offsetx, offsety);

	// U1 = P1  + P2  => C11
	mm_strassen_add(&X, C, C, 0, 0, C11i, C11j, C11i, C11j, mh, ph);

	// Make dynamic peeling as done in Huss-Lederman et al.
	// In the sketches below, upper-case names are the even-sized parts
	// handled above and lower-case names are the peeled-off last
	// row/column.
	
	// C11 = a12 * b21 + C11
	if ( n0 ) {
		//        #*1     1*# 
		// C11 += a12  *  b21
		// (last column of A times last row of B, added to the even part)
		mm_strassen_naive_mul_add(A, B, C, ai, aj+n-n0, bi+n-n0, bj, ci, cj, 
				m-m0, n0, p-p0);
	}

	//                     ( b12 )
	// c12 = ( A11 a12 ) * ( b22 )
	if ( p0 ) {
		//               ( b12 )
		// ( A11 a12 ) * ( b22 ) => c12
		// (last column of C, for all rows except a peeled last row)
		mm_strassen_naive_mul(A, B, C, ai, aj, bi, bj+p-p0, ci, cj+p-p0, m-m0, 
				n, p0);
	}

	//                             ( B11 b12 )
	// ( c21 c22 ) = ( a21 a22 ) * ( b21 b22 )
	// (last row of C, across all p columns)
	if ( m0 ) {
		mm_strassen_naive_mul(A, B, C, ai+m-m0, aj, bi, bj, ci+m-m0, cj, m0, n, 
				p);
	}
}

/*
 * Entry point: computes str->C = str->A * str->B with the recursive
 * Strassen-Winograd routine mm_strassen_mul.
 *
 * The recursion cutoff (global THRESHOLD) is chosen as follows, highest
 * priority first:
 *   1. the environment variable STRASSEN_STOP, when it starts with a
 *      positive decimal number that fits in 32 bits;
 *   2. str->th, when it is greater than zero;
 *   3. the built-in default of 23.
 * An unparsable, zero or negative STRASSEN_STOP is ignored. Previously such
 * a value left THRESHOLD at whatever it held before the call (zero on the
 * first call), and a negative value wrapped around to a huge cutoff.
 *
 * The scratch buffer shared by all recursion levels is allocated here and
 * released before returning; the globals R1 and R2 are only valid for the
 * duration of the call.
 */
void strassen(struct strassen_mult *restrict str) {
	/* 
	 * According to Huss-Lederman et al. 13 should be picked for even sized 
	 * matrices through the recursion and 23 should be picked, when dynamic 
	 * peeling is used.
	 */
	enum { DEFAULT_THRESHOLD = 23 };

	matrix_t *restrict A = str->A;
	matrix_t *restrict B = str->B;
	matrix_t *restrict C = str->C;
	uint32_t th = str->th;

	assert(A->n == B->m);

	const char *val = getenv("STRASSEN_STOP");

	if (val != NULL) {
		char *end;
		const long parsed = strtol(val, &end, 10);

		// Accept the override only if a number was actually read and it is
		// a usable cutoff. Like atoi, trailing characters after the number
		// are tolerated.
		if (end != val && parsed > 0
				&& (unsigned long) parsed <= UINT32_MAX) {
			th = (uint32_t) parsed;
		}
	}

	THRESHOLD = th > 0 ? th : DEFAULT_THRESHOLD;

	// Allocate extra memory that is used as temporary memory in the strassen
	// recursion. Every level needs mh*MAX(nh, ph) elements for X and nh*ph
	// elements for Y, i.e. at most a quarter of what its parent works on, so
	// all levels together stay below a third of the top-level sizes. The
	// sizes are computed in size_t so the products cannot wrap in 32 bits.
	const size_t rows = A->m;
	const size_t inner = A->n;
	const size_t cols = B->n;

	const size_t x_elems = rows * MAX(inner, cols);
	const size_t y_elems = inner * cols;

	size_t total = (x_elems + y_elems) / 3;
	if (total == 0) {
		// malloc(0) is allowed to return NULL; whether xmalloc treats that
		// as a failure is not visible here, so never request zero bytes.
		total = 1;
	}

	R1 = xmalloc(total * sizeof *R1);
	R2 = R1 + x_elems / 3;

	mm_strassen_mul(A, B, C, 0, 0, 0, 0, 0, 0, A->m, A->n, B->n, 0, 0);

	// Free extra memory and do not leave the globals dangling
	free(R1);
	R1 = NULL;
	R2 = NULL;
}

