import numpy as np
import TensorTools as tt

def Nabla(k, sig, tau, i, j):

    r"""Evaluate the mixed second-order difference \nabla k(\sigma, \tau)[i, j].
        See pg. 21 for the definition of this notation.

        Input: A kernel k taking one element of sig and one element of tau.
                Indexable sequences sig, tau.
                Indices i, j with i+1 and j+1 valid for sig and tau resp.
        Output: k(sig[i+1], tau[j+1]) + k(sig[i], tau[j])
                - k(sig[i], tau[j+1]) - k(sig[i+1], tau[j])."""

    # The four corners of the index cell [i, i+1] x [j, j+1], evaluated in
    # the same order in which they enter the signed sum below.
    cornerHiHi = k(sig[i + 1], tau[j + 1])
    cornerLoLo = k(sig[i], tau[j])
    cornerLoHi = k(sig[i], tau[j + 1])
    cornerHiLo = k(sig[i + 1], tau[j])

    # Diagonal corners count positively, off-diagonal corners negatively.
    return cornerHiHi + cornerLoLo - cornerLoHi - cornerHiLo


def getDiffMatrix(sig, tau, k):

    r"""Build the matrix of mixed second-order differences of k.

    Input: Sequences sig, tau of length L+1, L'+1 resp., and primary
    kernel k.
    Output: An (L x L') float array K s.t.
        K[i,j] = k(sig[i+1], tau[j+1]) + k(sig[i], tau[j])
                 - k(sig[i], tau[j+1]) - k(sig[i+1], tau[j]),
    i.e. the same value that Nabla(k, sig, tau, i, j) yields.

    Each pair (sig[a], tau[b]) is passed to k exactly once, so k is called
    (L+1)*(L'+1) times instead of 4*L*L' times. k must return a value that
    converts to a float.
    Raises: ValueError (from np.zeros) if sig or tau is empty."""

    L = len(sig) - 1
    LPrime = len(tau) - 1

    # Allocate first: a negative dimension (empty input) raises right here.
    K = np.zeros((L, LPrime))

    # With a single-element sequence there is nothing to difference, so
    # return the empty result without evaluating k at all.
    if L == 0 or LPrime == 0:
        return K

    # Gram matrix G[a, b] = k(sig[a], tau[b]). Elements are read by integer
    # index so that any container supporting len() and [] keeps working.
    G = np.array(
        [[k(sig[a], tau[b]) for b in range(LPrime + 1)] for a in range(L + 1)],
        dtype=float,
    )

    # Shifted views of G give all four corners of every cell at once; the
    # terms are combined in the same order as in the formula above.
    K[:, :] = G[1:, 1:] + G[:-1, :-1] - G[:-1, 1:] - G[1:, :-1]

    return K


def seqKernEval(sig, tau, k, M, d=1):

    r"""Evaluate the truncated sequentialised kernel k_M^+ (\sigma, \tau).

        Input: Ordered sequences \sigma, \tau \in X^+ of lengths L+1, L'+1
                resp.
                A kernel k : X^+ \times X^+ \to \mathbb{R} to sequentialise.
                A cut-off degree M (M >= 1; indexing the first layer fails
                otherwise).
                d is accepted but not used by this function.
        Output: k_M^+ (\sigma, \tau), as the sequentialisation of k."""

    # (L x L') array with increments[i, j] = Nabla(k, sig, tau, i, j).
    increments = getDiffMatrix(sig, tau, k)

    # One (L x L') layer per degree; layers[m] holds degree m+1 because
    # Python indices count from 0.
    layers = np.zeros((M, len(sig) - 1, len(tau) - 1))

    # Algorithm 3, line 4: the first layer is the increment matrix itself.
    layers[0] = increments

    # Algorithm 3, lines 5-8: build each layer from the previous one.
    for level in range(1, M):

        # Two-dimensional running sum of the previous layer
        # (A[m, \boxplus, \boxplus] in the algorithm's notation).
        runningSum = layers[level - 1].cumsum(axis=0).cumsum(axis=1)

        # Q[+1, +1]: presumably the running sum moved by one position along
        # both axes -- the exact padding is defined by
        # TensorTools.shiftByPlusOne; confirm there.
        shifted = tt.shiftByPlusOne(runningSum, [0, 1])

        # Element-wise product K \dot (1 + Q[+1, +1]).
        layers[level] = np.multiply(increments, 1 + shifted)

    # Algorithm 3, line 9: R = 1 + A[M | \Sigma, \Sigma], i.e. one plus the
    # sum of all entries of the last layer.
    return 1 + layers[M - 1].sum()


def hiSeqKernEval(sig, tau, k, M, D):

    r"""Evaluate the higher-order sequentialised kernel k_{(D), M}^+.

        Input: Ordered sequences \sigma, \tau \in X^+ of length L+1, L'+1
                resp.
                A kernel k : X^+ \times X^+ \to \mathbb{R} to sequentialise.
                A cut-off degree M.
                An approximation order D, D <= M.
        Output: k_{(D), M}^+ (\sigma, \tau), as the sequentialisation of k."""

    # Layer 0 of the table is never written and stays all-zero, acting as the
    # base case of the recursion, so one extra layer is needed. Without this
    # offset seqKernEval(..., M, ...) would equal hiSeqKernEval(..., M+1, ...).
    M = M + 1

    # (L x L') array with increments[i, j] = Nabla(k, sig, tau, i, j).
    increments = getDiffMatrix(sig, tau, k)

    # (M x D x D x L x L') table indexed as [degree, d, d', i, j].
    numRows = len(sig) - 1
    numCols = len(tau) - 1
    table = np.zeros((M, D, D, numRows, numCols))

    for m in range(1, M):

        # View of the previous degree's (D x D x L x L') slab.
        prev = table[m - 1]

        # Line 5: orders above the current degree are not visited.
        orderCap = min(D, m)

        # A[m-1|\Sigma, \Sigma|\boxplus, \boxplus]: sum over both order
        # axes, then a running sum along both position axes.
        boxBoth = prev.sum(axis=(0, 1)).cumsum(axis=0).cumsum(axis=1)
        # ... moved by +1 along both position axes (semantics defined by
        # TensorTools.shiftByPlusOne).
        boxBoth = tt.shiftByPlusOne(boxBoth, [0, 1])

        # Line 6.
        table[m, 0, 0] = np.multiply(increments, 1 + boxBoth)

        for d in range(1, orderCap):

            invD = 1.0 / d

            # A[m-1|d-1, \Sigma|:, \boxplus +1]: sum over the second order
            # axis, running sum along the tau axis only, then shift it.
            boxTau = prev[d - 1].sum(axis=0).cumsum(axis=1)
            boxTau = tt.shiftByPlusOne(boxTau, [1])

            # Line 8.
            table[m, d, 0] += invD * np.multiply(increments, boxTau)

            # A[m-1|\Sigma, d-1|\boxplus +1, :]: sum over the first order
            # axis, running sum along the sig axis only, then shift it.
            boxSig = prev[:, d - 1].sum(axis=0).cumsum(axis=0)
            boxSig = tt.shiftByPlusOne(boxSig, [0])

            # Line 9.
            table[m, 0, d] += invD * np.multiply(increments, boxSig)

            for dPrime in range(1, orderCap):

                # Line 11: purely diagonal step, no running sums involved.
                table[m, d, dPrime] += (1.0 / (d * dPrime)) \
                    * np.multiply(increments, prev[d - 1, dPrime - 1])

    # Line 15: one plus the sum of every entry of the last layer.
    return 1 + table[M - 1].sum()


def hiSeqKernelXY(X, Y, k, M, D):
    r"""Computes sequential cross-kernel matrices using higher-order alg.

    Input: Collections X, Y whose entries along the first axis (X[n], Y[n'])
            are the sequences to compare.
            A kernel k to sequentialise.
            A cut-off degree M and an approximation order D, both passed
            unchanged to hiSeqKernEval.
    Output: An (N x N') array KSeq with
            KSeq[n, n'] = hiSeqKernEval(X[n], Y[n'], k, M, D),
            where N, N' are the sizes of the first axis of X, Y resp."""

    # The row counts get their own names: binding the number of rows of Y
    # to M would overwrite the cut-off degree and hand the wrong degree to
    # hiSeqKernEval.
    numX = np.shape(X)[0]
    numY = np.shape(Y)[0]

    KSeq = np.zeros((numX, numY))

    for rowX in range(numX):
        for rowY in range(numY):
            KSeq[rowX, rowY] = hiSeqKernEval(X[rowX], Y[rowY], k, M, D)

    return KSeq
