"""matrix.py

The discussion about matrix indexing has been interminable and for
the most part pretty pointless IMO. However, it does point out one
thing: the interaction between the matrix and array classes is still
pretty clunky despite a fair amount of effort trying to make them
interoperate.

Here's one possible alternative approach to enabling matrix operations
within the context of numpy arrays. I started with a few requirements:
    1. Matrices should be proper subclasses of arrays in the sense
       that one should be able to use matrices wherever one uses
       arrays with no change in the result.
    2. Indexing into matrices should produce row or column vectors
       where appropriate.
    3. There should be some syntax for matrix multiplication. I
       know that there are other operations defined on the matrix
       class, but I don't consider them worth the headache of
       having two class hierarchies.

This is in part inspired by some comments by Bill Spotz, but after
listening to his later comments, I don't think that this is really
what he had in mind. So, please don't blame him for any deficiencies
you find herein.

The usual caveats apply: this is just a sketch, it's probably buggy,
use at your own risk, etc. However, you should be able to save this
mail as "matrix.py" and execute it to run all of the examples
below.

I enforce point #1 by not overriding any ndarray methods except for
__array_finalize__ and __getitem__, and they are overridden in a
way that shouldn't affect code unless it is specifically testing
for types. With that summarily taken care of, let's move on to
point #2: indexing.

First let's create a matrix. Scratch that, let's create an array
of matrices. One thing that this approach does is let you create
arrays of matrices and vectors in a natural way. The matrix below
has a shape of (3,2,2), which corresponds to an array of three 2x2
matrices.

>>> m = matrix([ [[1, 0],[0,1]], [[2, 0],[0,2]], [[3, 0],[0,3]] ])

If we index on the first axis, we are selecting a single matrix, so
we would expect an ndmatrix back.

>>> m[0]
ndmatrix([[1, 0],
       [0, 1]])
      
On the other hand, if we index along the second axis, we select one row
from each matrix and thus expect to get an array of row vectors.
      
>>> m[:,0]
ndrow([[1, 0],
       [2, 0],
       [3, 0]])
      
Finally, if we index along the last axis, we select one column from each matrix and get an array of column vectors.

>>> m[...,0]
ndcolumn([[1, 0],
       [2, 0],
       [3, 0]])

Indexing of row and column vectors works similarly. It's not shown here
and thus will probably turn out to be buggy.

For matrix multiplication, I chose to abuse __call__. Perhaps someday we'll be able
to convince the powers that be to free up another operator, but for the
time being this is better than "*", since that forces matrix and array into separate
camps, and it is arguably more readable than most of the new syntax proposals that I've
seen. Here we multiply our set of three matrices 'm' by a single 2x2 matrix 'n':

>>> n = matrix([[0, 1], [3, 0]])
>>> (m)(n)
ndmatrix([[[0, 1],
        [3, 0]],
<BLANKLINE>
       [[0, 2],
        [6, 0]],
<BLANKLINE>
       [[0, 3],
        [9, 0]]])
       
Things behave similarly when multiplying a matrix by a column vector or a row
vector by a column vector.
       
>>> c = column([3, 5])
>>> (m)(c)
ndcolumn([[ 3,  5],
       [ 6, 10],
       [ 9, 15]])
>>> (n)(c)
ndcolumn([5, 9])

>>> r = row([[2, 1], [3, 4]])
>>> (r)(c)
array([11, 29])
>>> (r[0])(c)
11

Note, however, that you can't (for instance) multiply a column vector by
a row vector:

>>> (c)(r)
Traceback (most recent call last):
  ...
TypeError: Cannot matrix multiply columns with anything


Finally, you'd like to be able to transpose a matrix or vector. Transposing
a vector swaps its type between ndrow and ndcolumn, while transposing a matrix
transposes the last two axes of the array it is embedded in. Note that I use
".t" for this instead of ".T" to avoid interfering with the existing ".T".
(Frankly, if it were up to me, I'd just deprecate .T.)

>>> r
ndrow([[2, 1],
       [3, 4]])
>>> r.t
ndcolumn([[2, 1],
       [3, 4]])
>>> n.t
ndmatrix([[0, 3],
       [1, 0]])
      
That's about it. Enjoy (or not) as you will.

"""
from numpy import *

_ULT = 1 # Utlimate AKA last
_PEN = 2 # Penultimate AKA second to last

# I'm sure there's a better way to do this and this is probably buggy...
def _reduce(key, ndim):
    if isinstance(key, tuple):
        m = len(key)
        if m == 0:
            return 0
        elif m == 1:
            key = key[0]
        else:
            if Ellipsis in key:
                key = (None,)*(ndim-m) + key
            if len(key) == ndim:
                return isinstance(key[-1], int) * _ULT + isinstance(key[-2], int) * _PEN
            if len(key) == ndim-1:
                return isinstance(key[-1], int) * _PEN
    if isinstance(key, int):
        if ndim == 1:
           return _ULT
        if ndim == 2:
            return _PEN
    return 0

class ndmatrix(ndarray):
    """An ndarray whose last two axes are interpreted as a matrix.

    An ndmatrix of shape (..., i, j) is an array of i-by-j matrices; any
    leading axes index into that array. Integer indexing that removes one of
    the two matrix axes yields an ndrow or ndcolumn, and removing both yields
    a plain ndarray. Calling an ndmatrix with another ndmatrix or with an
    ndcolumn performs matrix multiplication, broadcasting over leading axes.
    """

    def __array_finalize__(self, obj):
        # `obj` is the array this one is derived from (e.g. via view or
        # slicing), or None when constructed directly as ndmatrix(shape).
        # The source array is checked rather than `self` so that __getitem__
        # may briefly hold a lower-dimensional ndmatrix before re-viewing it
        # as ndrow/ndcolumn/ndarray.
        ndim = self.ndim if obj is None else obj.ndim
        if ndim < 2:
            raise ValueError("matrices must have dimension of at least 2")

    def __getitem__(self, key):
        rmask = _reduce(key, self.ndim)
        value = ndarray.__getitem__(self, key)
        if isinstance(value, ndmatrix):
            if rmask == _ULT | _PEN:
                # Both matrix axes indexed away: an array of scalars.
                return value.view(ndarray)
            if rmask == _ULT:
                # Column index fixed: an array of column vectors.
                return value.view(ndcolumn)
            if rmask == _PEN:
                # Row index fixed: an array of row vectors.
                return value.view(ndrow)
        return value

    def __call__(self, other):
        """Matrix-multiply this (array of) matrix by ``other``.

        matrix (..., i, j) x matrix (..., j, k) -> ndmatrix (..., i, k)
        matrix (..., i, j) x column (..., j)    -> ndcolumn (..., i)
        Any other operand raises TypeError.
        """
        # Work on plain ndarray views so the intermediate indexing and
        # reductions bypass the custom __getitem__ of these subclasses.
        if isinstance(other, ndmatrix):
            # (..., i, j, 1) * (..., 1, j, k), summed over j -> (..., i, k)
            lhs = self.view(ndarray)[..., newaxis]
            rhs = other.view(ndarray)[..., newaxis, :, :]
            return (lhs * rhs).sum(axis=-2).view(ndmatrix)
        if isinstance(other, ndcolumn):
            # (..., i, j) * (..., 1, j), summed over j -> (..., i)
            lhs = self.view(ndarray)
            rhs = other.view(ndarray)[..., newaxis, :]
            return (lhs * rhs).sum(axis=-1).view(ndcolumn)
        raise TypeError("Can only matrix multiply matrices by matrices or vectors")

    def transpose_vector(self):
        # Swap the two matrix axes; leading (array) axes are untouched.
        return self.swapaxes(-1, -2)

    t = property(transpose_vector)
       
class ndcolumn(ndarray):
    """An ndarray whose last axis is interpreted as a column vector.

    An ndcolumn of shape (..., n) is an array of length-n column vectors; any
    leading axes index into that array. Columns cannot be used as the left
    operand of a matrix product.
    """

    def __array_finalize__(self, obj):
        # `obj` is the array this one is derived from, or None when the class
        # is constructed directly (e.g. ndcolumn((3,))); check `self` then
        # instead of failing on None.ndim.
        ndim = self.ndim if obj is None else obj.ndim
        if ndim < 1:
            raise ValueError("vectors must have dimension of at least 1")

    def __getitem__(self, key):
        rmask = _reduce(key, self.ndim)
        value = ndarray.__getitem__(self, key)
        # ndarray.__getitem__ on an ndcolumn yields an ndcolumn (or a scalar).
        # If the vector (last) axis was indexed away, what remains is an
        # array of scalars, so drop the column type.
        if isinstance(value, ndcolumn) and rmask & _ULT:
            return value.view(ndarray)
        return value

    def __call__(self, other):
        raise TypeError("Cannot matrix multiply columns with anything")

    def transpose_vector(self):
        # Same data, reinterpreted as row vectors.
        return self.view(ndrow)

    t = property(transpose_vector)

class ndrow(ndarray):
    """An ndarray whose last axis is interpreted as a row vector.

    An ndrow of shape (..., n) is an array of length-n row vectors; any
    leading axes index into that array. Calling an ndrow with an ndcolumn or
    an ndmatrix performs the corresponding matrix product.
    """

    def __array_finalize__(self, obj):
        # `obj` is the array this one is derived from, or None when the class
        # is constructed directly (e.g. ndrow((3,))); check `self` then
        # instead of failing on None.ndim.
        ndim = self.ndim if obj is None else obj.ndim
        if ndim < 1:
            raise ValueError("vectors must have dimension of at least 1")

    def __getitem__(self, key):
        rmask = _reduce(key, self.ndim)
        value = ndarray.__getitem__(self, key)
        # ndarray.__getitem__ on an ndrow yields an ndrow (or a scalar).
        # If the vector (last) axis was indexed away, what remains is an
        # array of scalars, so drop the row type.
        if isinstance(value, ndrow) and rmask & _ULT:
            return value.view(ndarray)
        return value

    def __call__(self, other):
        """Matrix-multiply this row (or array of rows) by ``other``.

        row (..., j) x column (..., j)    -> dot product over j, as a plain
                                             ndarray (a numpy scalar when
                                             both operands are 1-D)
        row (..., j) x matrix (..., j, k) -> ndrow (..., k)
        Any other operand raises TypeError.
        """
        # Work on plain ndarray views so the intermediate indexing and
        # reductions bypass the custom __getitem__ of these subclasses.
        if isinstance(other, ndcolumn):
            lhs = self.view(ndarray)
            rhs = other.view(ndarray)
            return (lhs * rhs).sum(axis=-1)
        if isinstance(other, ndmatrix):
            # (..., j, 1) * (..., j, k), summed over j -> (..., k)
            lhs = self.view(ndarray)[..., newaxis]
            rhs = other.view(ndarray)
            return (lhs * rhs).sum(axis=-2).view(ndrow)
        raise TypeError("Can only matrix multiply rows with matrices or columns")

    def transpose_vector(self):
        # Same data, reinterpreted as column vectors.
        return self.view(ndcolumn)

    t = property(transpose_vector)



def matrix(*args, **kwargs):
    """Create an ndmatrix.

    Every argument is handed straight to numpy's ``array``; the array it
    builds is then viewed (not copied again) as an ``ndmatrix``.
    """
    return array(*args, **kwargs).view(ndmatrix)
   
def column(*args, **kwargs):
    """Create an ndcolumn.

    Every argument is handed straight to numpy's ``array``; the array it
    builds is then viewed (not copied again) as an ``ndcolumn``.
    """
    return array(*args, **kwargs).view(ndcolumn)
   
def row(*args, **kwargs):
    """Create an ndrow.

    Every argument is handed straight to numpy's ``array``; the array it
    builds is then viewed (not copied again) as an ``ndrow``.
    """
    return array(*args, **kwargs).view(ndrow)
   
   
if __name__ == "__main__":
    # Run the examples in the module docstring. testmod() with no argument
    # tests the __main__ module itself, so this works whatever the file is
    # named and does not import a second copy of the module.
    import doctest

    doctest.testmod()