[Numpy-discussion] sparse array data
Francesc Alted
francesc at continuum.io
Wed May 2 16:53:47 EDT 2012
On 5/2/12 11:16 AM, Wolfgang Kerzendorf wrote:
> Hi all,
>
> I'm currently writing a code that needs three-dimensional data (for the physicists: its dimensions are atom, ion, level). The problem is that not all combinations exist (a sparse array). Sparse matrices in scipy only deal with two dimensions. The operations that I need to do on those are running functions like exp(item/constant) on all of the items. I also want to sum them up along the last dimension. What's the best way to make a class that takes this kind of data and does the required operations fast? Maybe some physicists have implemented these things already. Any thoughts?
Curiously enough, I have recently been discussing with Travis O. about
how to represent sparse matrices with complete generality. One of the
possibilities is to use what Travis calls "synthetic dimensions". The
idea behind it is easy: use a table with as many columns as dimensions,
and add another one for the actual values of the array. For a 3-D
sparse array, this looks like:
dim0 | dim1 | dim2 | value
==========================
   0 |    0 |    0 | val0
   0 |   10 |  100 | val1
  20 |    5 |  202 | val2
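To make the layout concrete, here is how such a table can be sketched with a plain NumPy structured array (the val0..val2 entries are replaced with made-up numbers purely for illustration; any table-oriented package can store the same thing):

```python
import numpy as np

# One column per dimension plus one for the values: the
# "synthetic dimensions" layout for a 3-D sparse array.
coo_dtype = np.dtype([('dim0', np.uint32), ('dim1', np.uint32),
                      ('dim2', np.uint32), ('value', np.float64)])

# The three entries from the table above, with illustrative values.
table = np.array([( 0,  0,   0, 1.5),
                  ( 0, 10, 100, 2.5),
                  (20,  5, 202, 3.5)], dtype=coo_dtype)

# Only the nonzero entries are stored; a full sum is just a
# reduction over the value column.
total = table['value'].sum()
```

Only the stored (nonzero) entries occupy memory, regardless of how large the nominal shape of the array is.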
You can use any package that deals with tables for implementing such a
thing. I'm going to quickly describe a raw implementation of this on
top of carray [1], not only because I'm the author, but also because it
adapts well to the needs you describe.
[1] https://github.com/FrancescAlted/carray
Let's start with a small array with shape (2,5). We are going to use a
dense array for this, mainly for comparison purposes with typical NumPy
arrays, but of course the logic behind this can be extended to
multidimensional sparse arrays with complete generality.
In [1]: import carray as ca
In [2]: import numpy as np
In [3]: syn_dtype = np.dtype([('dim0', np.uint32), ('dim1', np.uint32),
('value', np.float64)])
In [4]: N = 10
In [6]: ct = ca.fromiter(((i/2, i%2, i*i) for i in xrange(N)),
dtype=syn_dtype, count=N)
In [7]: ct
Out[7]:
ctable((10,), |V16) nbytes: 160; cbytes: 12.00 KB; ratio: 0.01
cparams := cparams(clevel=5, shuffle=True)
[(0, 0, 0.0) (0, 1, 1.0) (1, 0, 4.0) (1, 1, 9.0) (2, 0, 16.0) (2, 1, 25.0)
(3, 0, 36.0) (3, 1, 49.0) (4, 0, 64.0) (4, 1, 81.0)]
Okay, we have our small array. Now, let's apply a function to the
values (in this case, log()):
In [8]: ct['value'][:] = ct.eval('log(value)')
In [9]: ct
Out[9]:
ctable((10,), |V16) nbytes: 160; cbytes: 12.00 KB; ratio: 0.01
cparams := cparams(clevel=5, shuffle=True)
[(0, 0, -inf) (0, 1, 0.0) (1, 0, 1.3862943611198906)
(1, 1, 2.1972245773362196) (2, 0, 2.772588722239781)
(2, 1, 3.2188758248682006) (3, 0, 3.58351893845611)
(3, 1, 3.8918202981106265) (4, 0, 4.1588830833596715)
(4, 1, 4.394449154672439)]
carray uses numexpr behind the scenes, so these operations are very
fast. Also, for functions not supported inside numexpr, carray can also
make use of the ones in NumPy (although these are typically not as
efficient).
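For comparison, the same elementwise update can be done on a plain NumPy structured array with an ordinary ufunc (this is only a sketch of the fallback path, without numexpr; the values start at i = 1 here just to avoid log(0) = -inf):

```python
import numpy as np

dt = np.dtype([('dim0', np.uint32), ('dim1', np.uint32),
               ('value', np.float64)])

# Same (i//2, i%2, i*i) pattern as the ctable example, i = 1..9.
t = np.fromiter(((i // 2, i % 2, float(i * i)) for i in range(1, 10)),
                dtype=dt, count=9)

# Vectorized elementwise function on the value column, in place.
t['value'] = np.log(t['value'])
```

The ufunc only ever touches the stored entries, which is exactly what you want for functions like exp(item/constant) on a sparse structure (note that functions with f(0) != 0, like exp itself, implicitly change the meaning of the entries that are *not* stored).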
Let's see how to do sums along different axes. For this, we will use
the selection capabilities of the ctable object. Let's do the sum
along the last axis first:
In [10]: [ sum(row.value for row in ct.where('(dim0==%d)' % (i,))) for i
in range(N/2) ]
Out[10]:
[-inf,
3.58351893845611,
5.991464547107982,
7.475339236566736,
8.55333223803211]
So, it is just a matter of summing over dim1 while keeping dim0 fixed.
One can check that the results are the same as with NumPy:
In [11]: t = np.fromiter((np.log(i*i) for i in xrange(N)),
dtype='f8').reshape(N/2,2)
In [12]: t.sum(axis=1)
Out[12]: array([ -inf, 3.58351894, 5.99146455, 7.47533924,
8.55333224])
Summing over the leading dimension means keeping dim1 fixed:
In [13]: [ sum(row.value for row in ct.where('(dim1==%d)' % (i,))) for i
in range(2) ]
Out[13]: [-inf, 13.702369854987484]
and again, this is the same as using the `axis=0` parameter:
In [14]: t.sum(axis=0)
Out[14]: array([ -inf, 13.70236985])
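Incidentally, both grouped sums can also be computed without a Python loop on a plain structured array, using np.bincount with the coordinate column as the grouping key (a vectorized sketch of the same reductions, here on the raw i*i values, i.e. before the log step):

```python
import numpy as np

dt = np.dtype([('dim0', np.uint32), ('dim1', np.uint32),
               ('value', np.float64)])
t = np.fromiter(((i // 2, i % 2, float(i * i)) for i in range(10)),
                dtype=dt, count=10)

# Summing over dim1 means grouping by dim0, and vice versa;
# bincount adds up the weights that share each coordinate.
sum_last = np.bincount(t['dim0'], weights=t['value'])  # like axis=1
sum_lead = np.bincount(t['dim1'], weights=t['value'])  # like axis=0
```

bincount assumes the coordinates are small non-negative integers; for huge or gappy coordinate ranges you would group some other way (e.g. sort plus np.add.reduceat).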
Summing everything is, as expected, the easiest:
In [15]: sum(row.value for row in ct.iter())
Out[15]: -inf
In [16]: t.sum()
Out[16]: -inf
Of course, the case for more dimensions requires a bit more complexity,
but nothing fancy (this is left as an exercise for the reader ;). If
you are going to use this in your package, you may want to create
wrappers that expose the different operations more conveniently.
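Such a wrapper could look roughly like the following (the class name SynSparse and its methods are my own invention, not part of carray; this 2-D sketch keeps the coordinates and values in plain NumPy arrays):

```python
import numpy as np

class SynSparse(object):
    """Hypothetical convenience wrapper around a coordinate/value
    table for a 2-D sparse array."""

    def __init__(self, coords, values):
        # coords: (npoints, 2) integer array; values: (npoints,)
        self.coords = np.asarray(coords, dtype=np.intp)
        self.values = np.asarray(values, dtype=np.float64)

    def apply(self, func):
        """Apply an elementwise function to the stored values."""
        self.values = func(self.values)

    def sum(self, axis=None):
        if axis is None:
            return self.values.sum()
        # Summing *over* `axis` means grouping by the other
        # coordinate column (2-D only in this sketch).
        keep = 1 - axis
        return np.bincount(self.coords[:, keep], weights=self.values)

s = SynSparse([[i // 2, i % 2] for i in range(10)],
              [float(i * i) for i in range(10)])
```

Generalizing to N dimensions mostly means replacing the single `keep` column with a flattened key built from all the remaining coordinates.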
Finally, you should note that I used 4-byte integers for representing
the dimensions. If this is not enough, you can use 8-byte integers
too. As the carray objects are compressed by default, this usually
doesn't take a lot of space. For example, for an array with 1 million
elements (N = 1000000):
In [31]: ct = ca.fromiter(((i/2, i%2, i*i) for i in xrange(N)),
dtype=syn_dtype, count=N)
In [32]: ct
Out[32]:
ctable((1000000,), |V16) nbytes: 15.26 MB; cbytes: 1.76 MB; ratio: 8.67
cparams := cparams(clevel=5, shuffle=True)
[(0, 0, 0.0), (0, 1, 1.0), (1, 0, 4.0), ..., (499998, 1,
999994000009.0), (499999, 0, 999996000004.0), (499999, 1, 999998000001.0)]
That is saying that the ctable object requires just 1.76 MB (compare
this with the 8 MB required by the equivalent dense NumPy array).
One drawback of this approach is that it is generally much slower
than using a dense array representation:
In [30]: time [ sum(row.value for row in ct.where('(dim1==%d)' % (i,)))
for i in range(2) ]
CPU times: user 1.80 s, sys: 0.00 s, total: 1.81 s
Wall time: 1.81 s
Out[30]: [1.666661666665056e+17, 1.6666666666667674e+17]
In [33]: t = np.fromiter((i*i for i in xrange(N)),
dtype='f8').reshape(N/2,2)
In [34]: time t.sum(axis=0)
CPU times: user 0.01 s, sys: 0.00 s, total: 0.01 s
Wall time: 0.01 s
Out[34]: array([ 1.66666167e+17, 1.66666667e+17])
Probably, implementing a native sum() operation on top of ctable objects
would help improve performance here. Alternatively, you could
accelerate these operations by using the Table object in PyTables [2]
and indexing the dimensions for getting much improved speed for
accessing elements in big sparse arrays. Using a table in a relational
database (indexed for dimensions) could be an option too.
[2] https://github.com/PyTables/PyTables
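As a taste of the relational-database option, here is a minimal sketch with the standard-library sqlite3 module: store the (dim0, dim1, value) rows, index the dimension columns, and let SQL do the grouped sums (same 2-D toy data as before; table and index names are made up):

```python
import sqlite3

con = sqlite3.connect(':memory:')
con.execute('CREATE TABLE sparse (dim0 INTEGER, dim1 INTEGER, value REAL)')
con.executemany('INSERT INTO sparse VALUES (?, ?, ?)',
                ((i // 2, i % 2, float(i * i)) for i in range(10)))

# Indexing the dimension columns is what makes point and range
# lookups fast on big sparse arrays.
con.execute('CREATE INDEX idx_dims ON sparse (dim0, dim1)')

# Sum over the last dimension: group by dim0.
rows = con.execute('SELECT dim0, SUM(value) FROM sparse '
                   'GROUP BY dim0 ORDER BY dim0').fetchall()
```

The same GROUP BY trick gives you the sum over any subset of dimensions, which is handy once you move beyond two or three of them.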
Hope this helps,
--
Francesc Alted