#include "Python.h"
#include "structmember.h"

#include "numpy/arrayobject.h"

#include "clip_imp.c"

/* Pass x, cast to PyArrayObject*, through PyArray_Return. */
#define _ARET(x) PyArray_Return((PyArrayObject *)(x))
/* Debug helper: reads the ob_refcnt field of x viewed as a PyObject. */
#define GET_REF_COUNT(x) ( ((PyObject*)(x))->ob_refcnt )

/* Functions callable from Python (they are listed in the method table). */
static PyObject* PyArray_MyClip_GlobalMeth(PyObject *dummy, PyObject *args);
static PyObject *array_my_clip(PyArrayObject *self, PyObject *args, PyObject *kwds);

/*
 * Internal helpers, NOT supposed to be callable from Python: they take
 * already-parsed C arguments instead of an argument tuple.
 */
static PyObject* PyArray_FastClip(PyArrayObject *input, PyObject *min, 
        PyObject *max, PyArrayObject *out);
static PyObject* PyArray_NumericFastClip(PyArrayObject *in, PyObject *min, 
        PyObject *max, PyArrayObject *out);

/* Not python specific functions */

/*
 * Method table handed to Py_InitModule by init_fast_clip. Each entry is
 * {python name, C function, calling convention flags, docstring}; both
 * functions are currently exported without a docstring (last field NULL).
 */
static PyMethodDef mymethods[] = {
    /* myclip: positional arguments only */
    {"myclip", PyArray_MyClip_GlobalMeth, METH_VARARGS, NULL}, 
    /* mykwclip: positional and keyword arguments */
    {"mykwclip", (PyCFunction)array_my_clip, METH_VARARGS | METH_KEYWORDS, NULL}, 
    {NULL, NULL, 0, NULL} /* Sentinel */
};

PyMODINIT_FUNC init_fast_clip(void);

PyMODINIT_FUNC
init_fast_clip(void)
{
    (void)Py_InitModule("_fast_clip", mymethods);
    import_array();
}

/*
 * Python-callable clip accepting keyword arguments (exported as "mykwclip").
 *
 * Arguments, positional or by keyword:
 *   in   input array; must be an ndarray instance (or subclass)
 *   min  lower bound, any object
 *   max  upper bound, any object
 *   out  optional output array, or None for "no output array"
 *
 * self is the object the function is bound to and is not used.
 *
 * Returns the result of PyArray_FastClip passed through PyArray_Return,
 * or NULL with an exception set when the arguments cannot be parsed or the
 * clip itself fails.
 *
 * Notes:
 *  - The input is parsed with "O!" against PyArray_Type rather than with
 *    PyArray_OutputConverter: that converter maps None to a NULL pointer,
 *    which would then be dereferenced by PyArray_FastClip, and it reports
 *    errors as if the argument were an output array.
 *  - "out" sits after '|' in the format so that it may be omitted; it then
 *    keeps its NULL default.
 *  - "in" is a reserved word in Python, so from Python code that keyword can
 *    only be supplied through a ** mapping. The name is kept for
 *    compatibility.
 */
static PyObject *
array_my_clip(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
    PyObject *in_o = NULL;
    PyObject *min = NULL, *max = NULL;
    PyObject *res;
    PyArrayObject *out = NULL;
    static char *kwlist[] = {"in", "min", "max", "out", NULL};

    (void)self;

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|O&", kwlist,
                        &PyArray_Type, &in_o,
                        &min, &max,
                        PyArray_OutputConverter, &out)) {
        return NULL;
    }

    res = PyArray_FastClip((PyArrayObject *)in_o, min, max, out);
    if (res == NULL) {
        return NULL;
    }

    return _ARET(res);
}


/*
 * Python-callable clip taking positional arguments only (exported as
 * "myclip").
 *
 * Expects exactly three inputs in args: the input array, the minimum value
 * and the maximum value. The input has to be an ndarray instance; min and
 * max are accepted as arbitrary objects here and are interpreted further
 * down by PyArray_FastClip.
 *
 * No output array is passed on, and the value produced by PyArray_FastClip
 * is returned unchanged. NULL is returned when argument parsing fails.
 *
 * NOTE(review): an earlier version of this comment said that only native
 * double is handled and that every other case is passed down to the numpy
 * clip function. The dispatch is not done in this function; check
 * PyArray_FastClip for the current behaviour.
 */
static PyObject* PyArray_MyClip_GlobalMeth(PyObject *dummy, PyObject *args)
{
    PyObject *arr = NULL;
    PyObject *lower = NULL;
    PyObject *upper = NULL;
    int parsed;

    (void)dummy;

    parsed = PyArg_ParseTuple(args, "O!OO", &PyArray_Type,
                              &arr, &lower, &upper);
    if (parsed) {
        return PyArray_FastClip((PyArrayObject*)arr, lower, upper,
                                (PyArrayObject*)NULL);
    }

    return NULL;
}

/*
 * Dispatch a clip request according to the common type of its operands.
 *
 * The common type number of in, min and max is computed with
 * PyArray_ObjectType. If it is a number type the request is handled by
 * PyArray_NumericFastClip, otherwise it is forwarded to PyArray_Clip. The
 * value of the selected function, NULL included, is returned as is.
 *
 * Notes kept from the original author on the behaviour of the old clip
 * (not verifiable from this file):
 *  - Endianness:
 *      - the returned array has the same endianness as the input;
 *      - this holds even if min and/or max have non native endianness.
 *  - If out is NULL, a copy of the input is made. Otherwise the given out
 *    is used (which may fail).
 *  - If min and/or max is not a scalar, it has to have the same shape as
 *    the input.
 *
 * Open problems:
 *  - Complex case ?
 *  - What if the input type is "smaller" than the type of min or max ? The
 *    current clip does not upcast, we do.
 *
 * TODO:
 *  - check that the refcount is right around the PyArray_Clip call
 *  - custom implementation of the non numeric case (can the current clip
 *    really handle those ?)
 */
static PyObject* PyArray_FastClip(PyArrayObject *in, PyObject *min, 
        PyObject *max, PyArrayObject *out)
{
    int common_type;

    /* Fold the three operands, one after the other, into one type number */
    common_type = PyArray_ObjectType((PyObject*)in, 0);
    common_type = PyArray_ObjectType((PyObject*)min, common_type);
    common_type = PyArray_ObjectType((PyObject*)max, common_type);

    if (PyTypeNum_ISNUMBER(common_type)) {
        /* Numeric case */
        return PyArray_NumericFastClip(in, min, max, out);
    }

    /* Non numeric case */
    return PyArray_Clip(in, min, max, out);
}

/*
 * Clip a numeric array, using the fast scalar implementation when possible.
 *
 *  in    input array
 *  min   lower bound, any object convertible to an array of the common type
 *  max   upper bound, same requirements as min
 *  out   optional output array, may be NULL
 *
 * Returns the clipped result, or NULL with an exception set on failure.
 *
 * The fast path is taken only when all of the following hold:
 *  - no output array was given,
 *  - the input is not byte swapped and is aligned,
 *  - min and max are both scalars,
 *  - the common type of in, min and max is not complex.
 * In that case the input is copied into a fresh working array of the common
 * type, which is clipped in place by numeric_native_scalar_generic_clip and
 * returned. Every other case is passed to the old implementation,
 * PyArray_Clip.
 *
 * All temporaries are released through the single "done" label, on success
 * as well as on failure. Each owned pointer starts as NULL and is reset to
 * NULL once its ownership has been handed over.
 *
 * Before expanding this function, some cases need to be clarified:
 *  - if any input is not of native endianness, what to do ?
 *  - if the inputs do not all have the same type, what to do ?
 *  - complex case: what to do ?
 *  - why not have an inplace clip ?
 */
static PyObject* PyArray_NumericFastClip(PyArrayObject *in, PyObject *min, 
        PyObject *max, PyArrayObject *out)
{
    PyArray_Descr   *ndescr = NULL;
    PyArrayObject   *min_a = NULL, *max_a = NULL, *w_a = NULL;
    PyObject        *ret = NULL;

    int is_scalar, is_in_native, is_in_aligned, is_real;
    int typenum, st, flags;

    is_in_native    = PyArray_ISNOTSWAPPED(in);
    is_in_aligned   = PyArray_ISALIGNED(in);

    /* Common type of the three operands, used for the working buffer */
    typenum = 0;
    typenum = PyArray_ObjectType((PyObject*)in, typenum);
    typenum = PyArray_ObjectType((PyObject*)min, typenum);
    typenum = PyArray_ObjectType((PyObject*)max, typenum);

    /*
     * Get min and max as numpy arrays. They are only used to find out
     * whether the bounds are scalars and as arguments of the fast
     * implementation; shape compatibility of non scalar bounds is not
     * checked here.
     */
    min_a   = (PyArrayObject *)
        PyArray_FromObject(min, typenum, 0, 0);
    if (min_a == NULL) {
        PyErr_SetString(PyExc_TypeError,
                        "Error while converting min to value,"\
                        " sorry");
        goto done;
    }

    max_a   = (PyArrayObject *)
        PyArray_ContiguousFromAny(max, typenum, 0, 0);
    if (max_a == NULL) {
        PyErr_SetString(PyExc_TypeError,
                        "Error while converting max to value,"\
                        " sorry");
        goto done;
    }

    is_scalar   = PyArray_CheckScalar(min_a) && PyArray_CheckScalar(max_a);
    is_real     = !PyTypeNum_ISCOMPLEX(typenum);

    if (!(out == NULL && is_in_native && is_in_aligned
                && is_scalar && is_real)) {
        /*
         * Not handled by the fast implementation for now: pass it to the
         * old one. min_a and max_a are released at "done"; returning
         * directly from here would leak them.
         */
        ret = PyArray_Clip(in, min, max, out);
        goto done;
    }

    /*
     * Create a working buffer of type typenum
     */
    ndescr  = PyArray_DescrFromType(typenum);
    if (ndescr == NULL) {
        goto done;
    }

    /*
     * Creating a working array; as it is a copy, we can ask for all the
     * nice properties. PyArray_FromArray steals a reference to the
     * descriptor, hence the extra INCREF: our own reference is dropped at
     * "done".
     */
    flags   = NPY_ENSURECOPY | NPY_IN_ARRAY | NPY_CONTIGUOUS | NPY_NOTSWAPPED;
    Py_INCREF(ndescr);
    w_a     = (PyArrayObject*)PyArray_FromArray(in, ndescr, flags);
    if (w_a == NULL) {
        goto done;
    }

    /*
     * Numeric, scalar case: clip the working buffer in place
     */
    st  = numeric_native_scalar_generic_clip(w_a, min_a, max_a);
    if (st != 0) {
        /* Never return NULL to the interpreter without an exception set */
        if (!PyErr_Occurred()) {
            PyErr_SetString(PyExc_RuntimeError,
                            "fast clip implementation failed");
        }
        goto done;
    }

    /* The working buffer is the result: hand its reference to the caller */
    ret = (PyObject*)w_a;
    w_a = NULL;

done:
    Py_XDECREF(w_a);
    Py_XDECREF(ndescr);
    Py_XDECREF(max_a);
    Py_XDECREF(min_a);

    return ret;
}
