import numpy as np

from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.utils.multiclass import unique_labels

# class SqizeSVC(BaseEstimator, ClassifierMixin):
#
#     def __init__(self, C=1.0, kernel='rbf', degree=3, gamma=1,
#                  coef0=0.0, shrinking=True, probability=False,
#                  tol=1e-3, cache_size=200, class_weight=None,
#                  verbose=False, max_iter=-1, decision_function_shape=None,
#                  random_state=None, cut_ord_pair=(2,1)):
#         self.C = C
#         self.kernel = kernel
#         self.degree = degree
#         self.gamma = gamma
#         self.coef0 = coef0
#         self.shrinking = shrinking
#         self.probability = probability
#         self.tol = tol
#         self.cache_size = cache_size
#         self.class_weight = class_weight
#         self.verbose = verbose
#         self.max_iter = max_iter
#         self.decision_function_shape = decision_function_shape
#         self.random_state = random_state
#         self.cut_ord_pair = cut_ord_pair
#
#
#     def fit(self, X, y):
#
#         def kPolynom(x,y):
#             return (self.coef0+self.gamma*np.inner(x,y))**self.degree
#         def kGauss(x,y):
#             return np.exp(-self.gamma*np.sum(np.square(x-y)))
#         def kLinear(x,y):
#             return np.inner(x,y)
#         def kSigmoid(x,y):
#             return np.tanh(self.gamma*np.inner(x,y) +self.coef0)
#
#         def kernselect(kername):
#                 switcher = {
#                     'linear': kPolynom,
#                     'rbf': kGauss,
#                     'sigmoid': kLinear,
#                     'poly': kSigmoid,
#                         }
#                 return switcher.get(kername, "nothing")
#
#         cut_off = self.cut_ord_pair[0]
#         order = self.cut_ord_pair[1]
#
#         from SeqKernel import hiSeqKernelXY
#
#         def seq_kernel(x,y):
#             return hiSeqKernelXY(x,y,kernselect(self.kernel),cut_off,order)
#         self.seq_kernel = seq_kernel
#
#         self.svc_ = SVC(C=self.C, kernel=self.seq_kernel, degree=self.degree, gamma=self.gamma,
#                      coef0=self.coef0, shrinking=self.shrinking, probability=self.probability,
#                      tol=self.tol, cache_size=self.cache_size, class_weight=self.class_weight,
#                      verbose=self.verbose, max_iter=self.max_iter, decision_function_shape=self.decision_function_shape,
#                      random_state=self.random_state)
#
#         self.svc_.fit(X, y)
#
#         return self
#
#     def predict(self, X):
#         return self.svc_.predict(X)


# class SqizeSVC(BaseEstimator, ClassifierMixin):
#
#     def __init__(self, C=1.0, kernel='rbf', degree=3, gamma=1,
#                  coef0=0.0, shrinking=True, probability=False,
#                  tol=1e-3, cache_size=200, class_weight=None,
#                  verbose=False, max_iter=-1, decision_function_shape=None,
#                  random_state=None, cut_ord_pair=(2,1)):
#         self.C = C
#         self.kernel = kernel
#         self.degree = degree
#         self.gamma = gamma
#         self.coef0 = coef0
#         self.shrinking = shrinking
#         self.probability = probability
#         self.tol = tol
#         self.cache_size = cache_size
#         self.class_weight = class_weight
#         self.verbose = verbose
#         self.max_iter = max_iter
#         self.decision_function_shape = decision_function_shape
#         self.random_state = random_state
#         self.cut_ord_pair = cut_ord_pair
#
#
#     def fit(self, X, y):
#
#         def kPolynom(x,y):
#             return (self.coef0+self.gamma*np.inner(x,y))**self.degree
#         def kGauss(x,y):
#             return np.exp(-self.gamma*np.sum(np.square(x-y)))
#         def kLinear(x,y):
#             return np.inner(x,y)
#         def kSigmoid(x,y):
#             return np.tanh(self.gamma*np.inner(x,y) +self.coef0)
#
#         def kernselect(kername):
#                 switcher = {
#                     'linear': kPolynom,
#                     'rbf': kGauss,
#                     'sigmoid': kLinear,
#                     'poly': kSigmoid,
#                         }
#                 return switcher.get(kername, "nothing")
#
#         cut_off = self.cut_ord_pair[0]
#         order = self.cut_ord_pair[1]
#
#         from SeqKernel import hiSeqKernEval
#
#         def getGram(X):
#             gram_matrix = np.zeros((X.shape[0], X.shape[0]))
#             for row1ind in range(X.shape[0]):
#                 for row2ind in range(X.shape[0]):
#                     gram_matrix[row1ind,row2ind] = \
#                     hiSeqKernEval(X[row1ind],X[row2ind],kernselect(self.kernel),\
#                     cut_off,order)
#             return gram_matrix
#
#         self.getGram = getGram
#
#         self.svc_ = SVC(C=self.C, kernel='precomputed', degree=self.degree, \
#                     gamma=self.gamma,coef0=self.coef0, \
#                     shrinking=self.shrinking, probability=self.probability,
#                     tol=self.tol, cache_size=self.cache_size, \
#                     class_weight=self.class_weight,verbose=self.verbose, \
#                     max_iter=self.max_iter, \
#                     decision_function_shape=self.decision_function_shape,\
#                     random_state=self.random_state)
#
#         self.svc_.fit(self.getGram(X), y)
#
#         return self
#
#     def predict(self, X):
#         return self.svc_.predict(self.getGram(X))

class Sqizer(BaseEstimator, TransformerMixin):
    """Turn rows into a sequence-kernel Gram matrix for a precomputed SVC.

    ``fit`` memorises the training rows.  ``transform`` evaluates
    ``SeqKernel.hiSeqKernEval`` between every row it is given and every
    memorised training row, so the result has shape
    ``(n_samples, n_training_samples)`` -- what ``SVC(kernel='precomputed')``
    expects both in ``fit`` (training rows against themselves) and in
    ``predict`` (new rows against the training rows).

    Parameters
    ----------
    C : float, default=1.0
        Not used by the transformer; kept so existing parameter grids and
        ``set_params`` calls keep working.
    kernel : {'linear', 'poly', 'rbf', 'sigmoid'}, default='rbf'
        Name of the pointwise base kernel handed to ``hiSeqKernEval``.
    degree : int, default=3
        Exponent of the 'poly' base kernel.
    gamma : float, default=1
        Scale factor of the 'poly', 'rbf' and 'sigmoid' base kernels.
    coef0 : float, default=0.0
        Offset of the 'poly' and 'sigmoid' base kernels.
    cut_ord_pair : tuple, default=(2, 1)
        ``(cut_off, order)``; forwarded to ``hiSeqKernEval`` as its 4th and
        5th arguments.
    """

    def __init__(self, C=1.0, kernel='rbf', degree=3, gamma=1,
                 coef0=0.0, cut_ord_pair=(2, 1)):
        # scikit-learn convention: store constructor arguments unchanged so
        # get_params / set_params / clone (used by GridSearchCV) work.
        self.C = C
        self.kernel = kernel
        self.degree = degree
        self.gamma = gamma
        self.coef0 = coef0
        self.cut_ord_pair = cut_ord_pair

    def fit(self, X, y=None):
        """Store the training rows used as the right-hand side of the Gram matrix.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training rows; kept as ``self.X_``.
        y : array-like of shape (n_samples,), optional
            Labels; only validated and recorded as ``self.y_`` and
            ``self.classes_``.  May be omitted.

        Returns
        -------
        self
        """
        if y is None:
            # A transformer must be fittable without labels, but check_X_y
            # rejects y=None, so validate X alone.
            X = check_array(X)
            self.y_ = None
        else:
            X, y = check_X_y(X, y)
            # Store the classes seen during fit.
            self.classes_ = unique_labels(y)
            self.y_ = np.array(y)
        self.X_ = np.array(X)
        return self

    def _base_kernel(self):
        """Return the pointwise kernel function selected by ``self.kernel``.

        Raises
        ------
        ValueError
            If ``self.kernel`` is not one of the supported names.
        """
        gamma, coef0, degree = self.gamma, self.coef0, self.degree

        def k_linear(x, y):
            return np.inner(x, y)

        def k_poly(x, y):
            return (coef0 + gamma * np.inner(x, y)) ** degree

        def k_rbf(x, y):
            return np.exp(-gamma * np.sum(np.square(x - y)))

        def k_sigmoid(x, y):
            return np.tanh(gamma * np.inner(x, y) + coef0)

        kernels = {
            'linear': k_linear,
            'poly': k_poly,
            'rbf': k_rbf,
            'sigmoid': k_sigmoid,
        }
        try:
            return kernels[self.kernel]
        except KeyError:
            raise ValueError(
                "Unknown kernel %r; expected one of %s"
                % (self.kernel, sorted(kernels)))

    def transform(self, X):
        """Return the sequence-kernel Gram matrix of X against the fitted rows.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        ndarray of shape (n_samples, n_fitted_samples)
            Entry ``[i, j]`` is ``hiSeqKernEval(X[i], X_[j], base_kernel,
            cut_off, order)``.

        Raises
        ------
        NotFittedError
            If ``fit`` has not been called.
        ValueError
            If ``kernel`` is not a supported name.
        """
        check_is_fitted(self, 'X_')
        X = np.asarray(check_array(X))
        base_kernel = self._base_kernel()
        cut_off = self.cut_ord_pair[0]
        order = self.cut_ord_pair[1]

        # Project-local module; imported lazily as in the original code.
        from SeqKernel import hiSeqKernEval

        train = self.X_
        gram_matrix = np.zeros((X.shape[0], train.shape[0]))
        for i, row in enumerate(X):
            for j, train_row in enumerate(train):
                gram_matrix[i, j] = hiSeqKernEval(
                    row, train_row, base_kernel, cut_off, order)
        return gram_matrix


def simplexGenerator(var_list):
    """List the pairs (a_i, a_j), i <= j, of *var_list* sorted high-to-low.

    The input does not need to be pre-sorted: the values are first arranged
    in decreasing order a_0 >= a_1 >= ... >= a_{n-1}.  The pairs are then
    emitted row by row (the upper triangle of the pairing table, diagonal
    included), so every pair satisfies first >= second.

    For example, [1, 3, 2] gives
    [(3, 3), (3, 2), (3, 1), (2, 2), (2, 1), (1, 1)] (as NumPy scalars).
    """
    descending = np.sort(var_list)[::-1]
    return [(first, second)
            for start, first in enumerate(descending)
            for second in descending[start:]]


from sklearn.pipeline import Pipeline

# Sequence-kernel SVM: the 'Sqizer' step turns raw rows into a Gram matrix and
# the 'svc' step consumes it through kernel='precomputed'.  The step name
# 'Sqizer' is what 'Sqizer__<param>' keys in a parameter grid refer to.
SeqSVCpipeline = Pipeline(steps=[
    ('Sqizer', Sqizer()),
    ('svc', SVC(kernel='precomputed')),
])

def findOptSVM(data, cut_ord_dom=np.logspace(1,100,num=5)):
    """Grid-search the Sqizer ``cut_ord_pair`` of ``SeqSVCpipeline``.

    Parameters
    ----------
    data : array-like, 2-D
        Column 0 holds the labels; the remaining columns are the features.
    cut_ord_dom : sequence of numbers
        Candidate values.  Every pair ``(a, b)`` with ``a >= b`` produced by
        ``simplexGenerator`` is tried as ``Sqizer__cut_ord_pair``.

    Returns
    -------
    GridSearchCV
        The search object, fitted on ``data`` (scoring='f1', all cores).
    """
    # NOTE(review): the default np.logspace(1, 100, num=5) yields values from
    # 1e1 up to 1e100, which look far too large for cut-off/order parameters
    # (the Sqizer default is (2, 1)); confirm the intended grid.
    data = np.asarray(data)
    X = data[:, 1:]
    y = data[:, 0]
    params = dict(Sqizer__cut_ord_pair=simplexGenerator(cut_ord_dom))
    # The 'Sqizer__' prefix addresses the 'Sqizer' step of SeqSVCpipeline, so
    # the pipeline is the estimator to search over.
    clf = GridSearchCV(SeqSVCpipeline, param_grid=params, n_jobs=-1,
                       scoring='f1')
    clf.fit(X, y)
    return clf
