#!/usr/bin/python3

import numpy
import fasta

def predict(hmm, inputfile):
    """Build the log-space Viterbi score table for the sequence in a file.

    Entry ``w[k][n]`` is the log-probability of the most likely state path
    that emits symbols ``0..n`` of the sequence and ends in state ``k``.
    Impossible combinations are stored as ``-inf``.

    Args:
        hmm: Mapping with the keys read here:
            ``'states'`` (only its length ``K`` is used), ``'A'`` (``K x K``
            transition probabilities, ``A[j][k]`` for moving from ``j`` to
            ``k``), ``'phi'`` (emission probabilities, ``phi[k][x]`` for
            state ``k`` emitting symbol index ``x``) and ``'aIndex'``
            (symbol -> column index into ``phi``).
        inputfile: Passed unchanged to ``fasta.read``; the result is used as
            an indexable sequence of symbols (presumably a string of
            nucleotides -- confirm against the ``fasta`` module).

    Returns:
        A ``float64`` array of shape ``(K, N)`` in Fortran order, where ``N``
        is the sequence length. An empty sequence yields a ``(K, 0)`` array.

    Note:
        Column 0 is initialised from the emission probabilities alone; no
        initial-state distribution is read from ``hmm``.
    """
    dna = fasta.read(inputfile)
    N = len(dna)
    K = len(hmm['states'])
    symbol_index = hmm['aIndex']

    # Take every logarithm once up front instead of once per (j, k, n).
    # log(0) is intentionally -inf ("impossible"), so silence the
    # divide-by-zero warning NumPy would otherwise emit for zero entries.
    with numpy.errstate(divide='ignore'):
        log_A = numpy.log(numpy.asarray(hmm['A'], dtype=numpy.float64))[:K, :K]
        log_phi = numpy.log(numpy.asarray(hmm['phi'], dtype=numpy.float64))[:K]

    # -numpy.inf rather than numpy.NINF: the alias was removed in NumPy 2.0.
    # Fortran order keeps each column (one sequence position) contiguous.
    w = numpy.full((K, N), -numpy.inf, dtype=numpy.float64, order='F')
    if N == 0:
        return w

    # First column: log-probability of each state emitting the first symbol.
    w[:, 0] = log_phi[:, symbol_index[dna[0]]]

    for n in range(1, N):
        x = symbol_index[dna[n]]

        # candidates[j, k] = score of being in j at n-1 and then moving to k;
        # the best predecessor for every k is the column-wise maximum.
        candidates = w[:, n - 1, None] + log_A
        best_prev = numpy.max(candidates, axis=0)

        # -inf propagates naturally when the emission or every predecessor
        # is impossible, so no special-casing is needed.
        w[:, n] = log_phi[:, x] + best_prev

    return w

def backtrack(hmm, f, w):
    """Recover the most likely state path from a Viterbi score table.

    Starting from the best-scoring state in the last column of ``w``, walk
    backwards and pick, for each position ``n``, the predecessor ``k`` that
    maximises ``w[k][n] + log(A[k][z[n+1]])``. The emission term of position
    ``n+1`` is the same for every predecessor, so it does not influence the
    choice and is not needed here.

    Args:
        hmm: Mapping; only ``hmm['A']`` (transition probabilities,
            ``A[j][k]`` for moving from ``j`` to ``k``) is read.
        f: Unused. Kept so existing callers that pass the input file name
            keep working; the sequence length is taken from ``w`` instead of
            re-reading the file.
        w: 2-D array of log-scores, one row per state and one column per
            sequence position (the table produced by ``predict``).

    Returns:
        A 1-D integer array of length ``w.shape[1]`` holding the chosen state
        index for every position. Ties are resolved in favour of the lowest
        state index. If every candidate at a position is ``-inf`` (no
        feasible path), ``numpy.argmax`` falls back to state 0.
    """
    w = numpy.asarray(w, dtype=numpy.float64)
    N = w.shape[1]

    # Integer dtype: the entries are used as array indices below, and NumPy
    # rejects float indices.
    z = numpy.zeros(N, dtype=numpy.intp)
    if N == 0:
        return z

    # log(0) is intentionally -inf for forbidden transitions.
    with numpy.errstate(divide='ignore'):
        log_A = numpy.log(numpy.asarray(hmm['A'], dtype=numpy.float64))

    z[-1] = numpy.argmax(w[:, -1])

    for n in range(N - 2, -1, -1):
        # argmax instead of testing floating-point sums for exact equality:
        # rounding can make the "matching" predecessor differ in the last
        # bit, which would leave the position unassigned.
        z[n] = numpy.argmax(w[:, n] + log_A[:, z[n + 1]])

    return z
