[Numpy-discussion] subclassing matrix

Colin J. Williams cjw at sympatico.ca
Sat Jan 12 19:43:56 EST 2008


Basilisk96 wrote:
> On Jan 12, 1:36 am, "Timothy Hochberg" <tim.hochb... at ieee.org> wrote:
>> I believe that you need to look at __array_finalize__ and __array_priority__
>> (and there may be one other thing as well, I can't remember; it's late).
>> Search for __array_finalize__ and that will probably help get you started.
>>
> 
> Well sonovagun!
> I removed the hack.
> Then just by setting __array_priority__ = 20.0 in the class body,
> things are magically working almost as I expect. I say "almost"
> because of this custom method:
> 
>     def cross(self, other):
>         """Cross product of this vector and another vector"""
>         return _N.cross(self, other, axis=0)
> 
> That call to numpy.cross returns a numpy.ndarray. Unless I do return
> Vector(_N.cross(self, other, axis=0)), I get problems downstream.
> 
> When is __array_finalize__ called? By adding some print traces, I can
> see it's called every time an array is modified in any way i.e.,
> reshaped, transposed, etc., and also during operations like u+v, u-v,
> A*u. But it's not called during the call to numpy.cross. Why?
> 
> Cheers,
> -Basilisk96
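Basilisk's two questions can be checked with a small experiment. The sketch below is not from the thread: `TracedVector` and its `calls` log are hypothetical names, and it subclasses plain `ndarray` rather than `matrix` to stay short. `__array_finalize__` runs whenever NumPy itself creates an instance of the subclass, for example through view casting, views such as `reshape`, or ufunc results such as `u + v`. In recent NumPy versions, `numpy.cross` appears to convert its arguments with `asarray` and build its result as a plain ndarray. No subclass instance is ever created, so the hook never runs. Re-wrapping the result, as Basilisk does with `Vector(...)`, is one remedy. Here it is done with `.view()`.

```python
import numpy as np


class TracedVector(np.ndarray):
    """Hypothetical ndarray subclass that logs its __array_finalize__ calls."""

    calls = []  # class-level log: type name of `obj` for each call

    def __new__(cls, data):
        # View casting: reinterpret a plain float array as this subclass.
        return np.asarray(data, dtype=float).view(cls)

    def __array_finalize__(self, obj):
        # NumPy calls this whenever it creates a TracedVector itself:
        # view casting, views such as reshape/transpose, and ufunc results.
        TracedVector.calls.append(type(obj).__name__)

    def cross(self, other):
        """Cross product, re-wrapped because numpy.cross returns a plain ndarray."""
        return np.cross(self, other).view(type(self))


u = TracedVector([1.0, 0.0, 0.0])
v = TracedVector([0.0, 1.0, 0.0])

TracedVector.calls.clear()
s = u + v               # ufunc result: __array_finalize__ runs
r = u.reshape(3, 1)     # new view: __array_finalize__ runs
print(TracedVector.calls)

raw = np.cross(u, v)    # computed on plain ndarrays inside numpy.cross
w = u.cross(v)          # same values, wrapped back into TracedVector
print(type(raw).__name__, type(w).__name__, w)
```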

This may help.  It is based on your initial script.

The Vectors are stored as columns (Mx1 matrices) but presented as rows.

This adds a complication which is not resolved: a row Vector, such as the result of u.T * A, cannot be displayed with str().
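One possible way around the row/column complication, sketched under the assumption that only the display needs to change, is to flatten the data before formatting. A 3x1 column and a 1x3 row then print identically. `format_vector` is a hypothetical helper, not part of the script. It relies only on `numpy.asarray` and `ndarray.ravel`, and uses `numpy.matrix` directly to mirror the shapes Vector produces.

```python
import numpy as np

FMT_VECTOR_DEFAULT = "%+.5f"  # same default as in vector.py


def format_vector(m, fmt=FMT_VECTOR_DEFAULT):
    """Format a 1xN or Nx1 matrix as '(x, y, z)', whatever its orientation.

    Hypothetical helper: flattening removes the dependence on shape that
    makes a row Vector fail to print.
    """
    values = np.asarray(m).ravel().tolist()
    return "(" + ", ".join(fmt % x for x in values) + ")"


col = np.matrix("1; 2; 3")   # 3x1, the shape Vector stores
row = col.T                  # 1x3, like the result of u.T * A
print(format_vector(col))    # → (+1.00000, +2.00000, +3.00000)
print(format_vector(row))    # → (+1.00000, +2.00000, +3.00000)
```

Inside the class, `__str__` could return `format_vector(self, getattr(self, "Format", FMT_VECTOR_DEFAULT))` instead of indexing `self.T.tolist()[0]`.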

Colin W.

#---------- vector.py
from __future__ import print_function

import numpy as _N
import math as _M

#default tolerance for equality tests
TOL_EQ = 1e-6
#default format for pretty-printing Vector instances
FMT_VECTOR_DEFAULT = "%+.5f"

class Vector(_N.matrix):
    """
    2D/3D vector class that supports numpy matrix operations and more.

    Examples:
        u = Vector([1,2,3])
        v = Vector('3 4 5')
        w = Vector([1, 2])
    """
    def __new__(cls, data="0. 0. 0.", dtype=_N.float64):
        """
        Subclass instance constructor.

            If data is not specified, a zero Vector is constructed.
            The constructor always returns a Vector instance, stored as
            a 3x1 column; two-component input gets a zero third component.
            The instance gets a customizable Format attribute, which
            controls the printing precision.
        """
        ret = _N.matrix(data, dtype)
##        ret = super(Vector, cls).__new__(cls, data, dtype=dtype)

##        #promote the instance to cls type.
##        ret.__class__ = cls
        assert ret.size in (2, 3), 'Vector must have either two or three components'
        if ret.shape[0] == 1:
            ret = ret.T                   # store row input as a column
        assert ret.shape == (ret.shape[0], 1), 'could not express Vector as a Mx1 matrix'
        if ret.shape[0] == 2:
            ret = _N.vstack((ret, 0.))    # append a zero z component
        # reinterpret the data as a Vector; this calls __array_finalize__
        ret = ret.view(cls)
        # set Format after the final rebinding of ret, so it is not lost
        ret.Format = FMT_VECTOR_DEFAULT
        return ret

    def __str__(self):
        fmt = getattr(self, "Format", FMT_VECTOR_DEFAULT)
        fmt = ', '.join([fmt]*3)
        # expects a 3x1 column; for a 1x3 row Vector, self.T.tolist()[0]
        # holds a single value and the % formatting raises TypeError
        return ''.join(["(", fmt, ")"]) % tuple(self.T.tolist()[0])

    def __repr__(self):
        fmt = ', '.join(['%s']*3)
        return ''.join(["%s([", fmt, "])"]) % tuple(
            [self.__class__.__name__] + self.T.tolist()[0])

    def __mul__(self, mult):
        ''' self * multiplicand '''
        if isinstance(mult, _N.matrix):
            return _N.dot(self, mult)
        else:
            raise TypeError('multiplicand must be a Vector or a matrix')

    def __rmul__(self, mult):
        ''' multiplier * self '''
        if isinstance(mult, _N.matrix):
            return Vector(_N.dot(mult, self))
        else:
            raise TypeError('multiplier must be a Vector or a matrix')

    #### the remaining methods are Vector-specific math operations,
    #### including the X,Y,Z properties...

if __name__ == '__main__':
    u = Vector('1 2 3')
    print(str(u))
    print(repr(u))
    A = _N.matrix('2 0 0; 0 2 0; 0 0 2')
    print(A)
    p = A * u
    print(p)
    print(p.__class__)
    q = u.T * A
    try:
        print(q)
    except TypeError:
        print("we don't allow for the display of row vectors")
    print(q.A, q.T)
    print(q.__class__)
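In the listing above, Format is attached only inside `__new__`. Vectors that NumPy derives from an existing one (for example `u.T`) therefore do not carry it, and `__str__` has to fall back on `getattr` with a default. The pattern described in NumPy's subclassing documentation is to copy such an attribute in `__array_finalize__`, the hook Timothy Hochberg pointed to. The sketch below assumes Python 3. `FormattedVector` is a hypothetical stand-in for Vector, and it calls `matrix.__array_finalize__` through `super()` so that numpy.matrix's own handling still runs. Recent NumPy releases emit a PendingDeprecationWarning when numpy.matrix is used; the same hook works on a plain ndarray subclass.

```python
import numpy as np

FMT_VECTOR_DEFAULT = "%+.5f"  # same default as in vector.py


class FormattedVector(np.matrix):
    """Hypothetical matrix subclass whose Format attribute survives views."""

    def __new__(cls, data, fmt=FMT_VECTOR_DEFAULT):
        obj = np.matrix(data, dtype=np.float64).view(cls)
        obj.Format = fmt  # set after .view(), which already ran __array_finalize__
        return obj

    def __array_finalize__(self, obj):
        # Keep numpy.matrix's own bookkeeping (it keeps results 2-D).
        super().__array_finalize__(obj)
        # Inherit Format from the object this one was derived from; a plain
        # ndarray (or None) has no Format, so use the default.
        self.Format = getattr(obj, "Format", FMT_VECTOR_DEFAULT)


v = FormattedVector([[1.0], [2.0], [3.0]], fmt="%.2f")
row = v.T       # transposed view: Format is copied from v
tail = v[1:]    # sliced view: Format is copied from v
plain = np.asarray(v).view(FormattedVector)  # source has no Format
print(row.Format, tail.Format, plain.Format)
```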



