#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>
#include <math.h>

#include "classic.h"
#include "naive.h"
#include "bitvector.h"
#include "xutil.h"

/**
 * Builds the classic two-level rank directory for B and stores it in
 * B->table; succinct_classic_postprocess() frees it again.
 *
 * With n = B->bits the table parameters are
 *   b    = floor(log2(n) / 2)      block length in bits
 *   s    = b * floor(log2(n))      superblock length in bits
 *   bits = floor(log2(n)) + 1      width of one Rs entry
 *   logS = floor(log2(s)) + 1      width of one Rb entry
 * and three packed tables are filled:
 *   Rs - entry i holds the number of 1s in the first i*s bits of B,
 *        for i = 0..floor(n/s);
 *   Rb - entry i holds the number of 1s from the start of the superblock
 *        containing block i up to bit i*b, for i = 0..floor(n/b);
 *   Rp - for every b-bit pattern k: one 1-bit entry holding 0, followed,
 *        for each m = 1..b-1, by succinct_naive_rank() of the pattern with
 *        i = m, stored in floor(log2(m)) + 1 bits.
 *
 * NOTE(review): for n < 4, b == 0 and s == 0, so the Rs size computation
 * and the Rs loop bound (bits/classic->s) divide by zero; very small
 * bitvectors are not handled as written.
 */
void succinct_classic_preprocess(bitvector_t *restrict B) {
	uint64_t i, j, k;
	uint64_t bits = B->bits;
	double logSize = log2(bits), logB;
	uint64_t index = 0;    // running bit offset into the table being filled
	uint64_t tmp = 0;      // superblock index of the previous Rb entry
	uint64_t size;         // allocation size in bytes
	uint64_t sp, smodbit;  // popcount accumulator / bits left after whole words
	uint64_t read;         // bits counted from the superblock start (Rb pass)

	// Create table structure to construct O(1) lookups
	classic_t *classic = xmalloc(sizeof(classic_t));
	classic->b = floor((double)logSize/2);       // block length
	classic->s = classic->b*floor(logSize);      // superblock length
	classic->logS = floor(log2(classic->s)+1);   // bits per Rb entry
	classic->bits = floor(log2(bits)+1);         // bits per Rs entry

	// Cache calculations
	logB = log2(classic->b);
	smodbit = classic->s % WORD;  // superblock bits past its last whole word
	read = classic->b;

	// Rs: floor(n/s)+1 entries of classic->bits bits, rounded up to words.
	// xmemset presumably zero-fills the buffer (see xutil.h).
	size = ceil((double)((bits/classic->s+1)*classic->bits)/WORD)
			*sizeof(bitvector);
	classic->Rs = xmalloc(size);
	xmemset(classic->Rs, size);

	// Rb: floor(n/b)+1 entries of logS bits, rounded up to words.
	size = ceil((((double)bits/classic->b+1)*classic->logS)/WORD)
			*sizeof(bitvector);
	classic->Rb = xmalloc(size);
	xmemset(classic->Rb, size);

	// Rp: 2^b patterns * b * log2(b) bits, rounded up to words. The
	// (b == 1) term adds one word because log2(1) == 0 would size it to 0.
	size = (ceil((double)(((bitvector)1 << classic->b)*classic->b*logB)/WORD) 
			 + (classic->b == 1))*sizeof(bitvector);
	classic->Rp = xmalloc(size);
	xmemset(classic->Rp, size);

	// Calculating Rs: Rs[0] = 0, Rs[i] = Rs[i-1] + popcount(superblock i-1)
	bitvector_set_bits(classic->Rs, index, classic->bits, 0);
	index += classic->bits;

	for (i = 1; i <= floor(bits/classic->s); i++) {
		sp = 0;

		// Popcount the whole WORD-sized chunks of superblock i-1.
		for (j = 0; j < floor(classic->s/WORD); j++) {
			sp += __builtin_popcountll(
				bitvector_get_bits(
						B->B, (i-1)*classic->s+j*WORD, WORD	
				)
			);
		}

		// Remaining smodbit bits; j now equals floor(s/WORD), i.e. it points
		// just past the last whole word of the superblock.
		if (smodbit) {
			sp += __builtin_popcountll(
				bitvector_get_bits(B->B, (i-1)*classic->s+j*WORD, smodbit)
			);
		}

		// Cumulative count: previous entry plus this superblock's popcount.
		bitvector_set_bits(
				classic->Rs, 
				index, 
				classic->bits, 
				bitvector_get_bits(
					classic->Rs, index-classic->bits, classic->bits
				) + sp
		);

		index += classic->bits;
	}

	// Reset values to be used later
	index = 0;

	// Calculating Rb: every entry is relative to the start of its superblock.
	bitvector_set_bits(classic->Rb, index, classic->logS, 0);
	index += classic->logS;

	for (i = 1; i <= floor(bits/classic->b); i++) {
		sp = 0;
		// A superblock holds s/b == floor(logSize) blocks, so j is the
		// superblock containing block i.
		j = floor(i/floor(logSize));

		if (tmp != j) {
			// First block of a new superblock: its relative count is 0 and
			// the next entry will cover the first b bits of superblock j.
			tmp = j;
			read = classic->b;
			bitvector_set_bits(classic->Rb, index, classic->logS, 0);
		} else {
			// Recount the first `read` bits of superblock j from its start
			// (whole words first, then the smodbit-bit tail).
			smodbit = read % WORD;

			for (k = 0; k < floor(read/WORD); k++) {
				sp += __builtin_popcountll(
					bitvector_get_bits(B->B, j*classic->s + k*WORD, WORD)
				);
			}

			if (smodbit) {
				sp += __builtin_popcountll(
					bitvector_get_bits(B->B, j*classic->s + k*WORD, smodbit)
				);
			}

			bitvector_set_bits(classic->Rb, index, classic->logS, sp);

			// The next block in this superblock covers b more bits.
			read += classic->b;
		}

		// Every number can at most fill log2(s) bits
		index += classic->logS;
	}

	// Wrap a one-word scratch bitvector of b bits so that
	// succinct_naive_rank() can be evaluated on every candidate pattern.
	uint64_t N = ceil((double)classic->b/WORD);
	struct succinct_t succ_bv;
	bitvector n[1];
	bitvector_t bv = (bitvector_t) {
		.B = &n[0],
		.N = N,
		.size = N*sizeof(bitvector),
		.bits = classic->b
	}; 

	succ_bv.B = &bv;

	index = 0;

	// Calculating Rp: for each pattern k, a 1-bit entry 0 followed by
	// entries m = 1..b-1 of floor(log2(m))+1 bits each.
	for (k = 0; k < ((bitvector)1 << classic->b); k++) {
		// Place pattern k in the high-order b bits of the scratch word
		// (assumes bitvector is WORD bits wide).
		n[0] = k << (WORD-classic->b);

		bitvector_set_bits(classic->Rp, index, 1, 0);
		index++;

		for (i = 1; i < classic->b; i++) {
			succ_bv.i = i;
			// `bits` is reused as the entry width here; the bitvector
			// length it held is no longer needed at this point.
			bits = floor(log2(i)+1);
			bitvector_set_bits(
					classic->Rp,
					index,
					bits,
					succinct_naive_rank(&succ_bv)
			);
			index += bits;
		}
	}

	// Attach the tables; succinct_classic_postprocess() releases them.
	B->table = classic;
}

/**
 * Releases the classic rank/select tables (Rs, Rb, Rp and the classic_t
 * itself) attached to B->table by succinct_classic_preprocess().
 *
 * Does nothing when B is NULL or no table is attached. B->table is reset
 * to NULL afterwards, so calling this twice does not double-free and later
 * code cannot reach the freed tables through B.
 *
 * @param B  bitvector whose classic tables should be freed.
 */
void succinct_classic_postprocess(bitvector_t *restrict B) {
	if (NULL == B || NULL == B->table) {
		return;
	}

	classic_t *classic = (classic_t *) B->table;

	/* free(NULL) is a no-op, so partially built tables need no guards. */
	free(classic->Rs);
	free(classic->Rb);
	free(classic->Rp);
	free(classic);

	/* Do not leave a dangling pointer behind in the bitvector. */
	B->table = NULL;
}

/**
 * Classic constant-table rank: the number of 1 bits among the first
 * succ->i bits of succ->B.
 *
 * The result is the sum of three lookups in the tables attached to
 * B->table by succinct_classic_preprocess():
 *   Rs - cumulative count at the start of the superblock containing i,
 *   Rb - count from that superblock's start to the start of i's block,
 *   Rp - count inside the block, addressed by the block's b-bit pattern
 *        and by i % b.
 *
 * @param succ  query; succ->B is the bitvector and succ->i the prefix
 *              length (expected to be at most B->bits).
 * @return      the rank of succ->i; 0 when succ->i == 0.
 */
uint64_t succinct_classic_rank(struct succinct_t *restrict succ) {
	bitvector_t *restrict B = succ->B;
	classic_t *restrict classic = (classic_t *)B->table;
	const uint64_t pos = succ->i;
	const uint64_t s = classic->s;
	const uint64_t b = classic->b;
	const uint64_t logS = classic->logS;
	const uint64_t width = classic->bits;
	uint64_t k;

	if (pos == 0) {
		return 0;
	}

	const uint64_t block = pos / b;
	const uint64_t offset = pos % b;
	const uint64_t start = block * b;
	/* The final block may be shorter than b bits. */
	const uint64_t avail = (start + b <= B->bits) ? b : B->bits - start;

	/* Size in bits of one pattern's record in Rp: a 1-bit entry for
	 * offset 0 followed by floor(log2(k)) + 1 bits for each k in 1..b-1. */
	uint64_t stride = 1;
	for (k = 1; k < b; k++) {
		stride += (uint64_t)floor(log2(k) + 1);
	}

	/* The block's bits select the pattern; a short final block is padded
	 * with zeros on the right so it lines up with a b-bit pattern. */
	uint64_t pattern = 0;
	if (avail > 0 && start + avail <= B->bits) {
		pattern = bitvector_get_bits(B->B, start, avail) << (b - avail);
	}

	/* Bit address of entry `offset` inside that pattern's record; entries
	 * 0 and 1 are one bit wide each. */
	uint64_t paddr = stride * pattern;
	for (k = 0; k < offset; k++) {
		paddr += (k > 1) ? (uint64_t)floor(log2(k) + 1) : 1;
	}
	const uint64_t pwidth = (offset > 1) ? (uint64_t)floor(log2(offset) + 1) : 1;

	return bitvector_get_bits(classic->Rs, (pos / s) * width, width)
		+ bitvector_get_bits(classic->Rb, block * logS, logS)
		+ bitvector_get_bits(classic->Rp, paddr, pwidth);
}

/**
 * Classic select: locates the succ->i'th occurrence of a 1 bit in succ->B.
 *
 * The search runs in three steps:
 *   1. Binary-search the superblock counts in Rs for the last superblock
 *      whose preceding count is below the target.
 *   2. Popcount b-bit blocks forward from there until the block holding
 *      the target is reached.
 *   3. Scan that block bit by bit.
 *
 * @param succ  query; succ->B is the bitvector (classic tables attached in
 *              B->table) and succ->i the occurrence number j.
 * @return      0 when j == 0; B->bits when the block scan reaches the end
 *              of B without finding j ones; otherwise the index k given to
 *              BIT(B, k) at which the running count of ones first equals j.
 *
 * NOTE(review): the final scan starts at BIT(B, blockstart + 1). That is
 * only the block's first bit if BIT() is 1-based — confirm in bitvector.h.
 */
uint64_t succinct_classic_select(struct succinct_t *restrict succ) {
	bitvector_t *restrict B = succ->B;
	classic_t *restrict classic = (classic_t *)B->table;
	const uint64_t target = succ->i;
	const uint64_t s = classic->s;
	const uint64_t b = classic->b;
	const uint64_t width = classic->bits;

	if (target == 0) {
		return 0;
	}

	/* Binary search over Rs, keeping Rs[lo] < target (Rs[0] == 0). */
	int64_t lo = 0;
	int64_t hi = floor(B->bits / s);

	while (hi - lo > 1) {
		const uint64_t probe = lo + ((hi - lo) >> 1);

		if (bitvector_get_bits(classic->Rs, probe * width, width) < target) {
			lo = probe;
		} else {
			hi = probe;
		}
	}

	/* Settle between lo and hi. `before` counts the ones preceding the
	 * superblock where the block scan starts. */
	uint64_t before = bitvector_get_bits(classic->Rs, hi * width, width);
	if (before < target) {
		lo = hi;
	} else {
		before = bitvector_get_bits(classic->Rs, lo * width, width);
	}

	/* Walk b-bit blocks (the last one may be shorter) until the running
	 * count reaches target. On exit `pos` is one past the last block read,
	 * `step` is that block's length and `inblock` its popcount. */
	uint64_t pos = (uint64_t)lo * s;
	uint64_t inblock = 0;
	uint64_t step = (B->bits - pos > b) ? b : B->bits - pos;

	if (step > 0) {
		while (before + inblock < target) {
			step = (B->bits - pos > b) ? b : B->bits - pos;
			if ( unlikely(step == 0) ) {
				break;
			}
			before += inblock;
			inblock = __builtin_popcountll(bitvector_get_bits(B->B, pos, step));
			pos += step;
		}
	}

	/* Rewind to the start of the block holding the answer; ending up at
	 * B->bits means the scan ran off the end of the bitvector. */
	pos -= step;
	if (pos == B->bits) {
		return B->bits;
	}

	/* Bit-by-bit scan of the located block. */
	for (uint64_t scanned = 0; scanned < b; scanned++, pos++) {
		before += BIT(B, pos + 1);
		if (before == target) {
			return pos + 1;
		}
	}

	return B->bits;
}
