[Scipy-svn] r2248 - trunk/Lib/io
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Oct 9 09:47:08 EDT 2006
Author: matthew.brett at gmail.com
Date: 2006-10-09 08:47:02 -0500 (Mon, 09 Oct 2006)
New Revision: 2248
Modified:
trunk/Lib/io/mio5.py
trunk/Lib/io/miobase.py
Log:
More progress on implementing Mat5 write
Modified: trunk/Lib/io/mio5.py
===================================================================
--- trunk/Lib/io/mio5.py 2006-10-08 02:12:06 UTC (rev 2247)
+++ trunk/Lib/io/mio5.py 2006-10-09 13:47:02 UTC (rev 2248)
@@ -524,8 +524,23 @@
class Mat5MatrixWriter(MatStreamWriter):
+ mat_tag = zeros((), mdtypes_template['tag_full'])
+ mat_tag['mdtype'] = miMATRIX
+
+ def __init__(self, file_stream, arr, name, is_global=False):
+ super(Mat5MatrixWriter, self).__init__(file_stream, arr, name)
+ self.is_global = is_global
+
+ def write_dtype(self, arr):
+ self.file_stream.write(arr.tostring)
+
+ def write_element(self, arr):
+ # check if small element works - do it
+ # write tag, data
+ pass
+
def write_header(self, mclass,
- is_global,
+ is_global=False,
is_complex=False,
is_logical=False,
nzmax=0):
@@ -534,21 +549,27 @@
@is_global - True if matrix is global
@is_complex - True if matrix is complex
@is_logical - True if matrix is logical
+ @nzmax - max non-zero elements for sparse arrays
'''
- dims = self.arr.shape
- header = empty((), mdtypes_template['header'])
- M = not ByteOrder.little_endian
- O = 0
- header['mopt'] = (M * 1000 +
- O * 100 +
- P * 10 +
- T)
- header['mrows'] = dims[0]
- header['ncols'] = dims[1]
- header['imagf'] = imagf
- header['namlen'] = len(self.name) + 1
- self.write_bytes(header)
- self.write_string(self.name + '\0')
+ self._mat_tag_pos = self.file_stream.tell()
+ self.write_dtype(self.mat_tag)
+ # write array flags (complex, global, logical, class, nzmax)
+ af = zeros((), mdtypes_template['array_flags'])
+ af['data_type'] = miUINT32
+ af['byte_count'] = 8
+ flags = is_complex << 3 | is_global << 2 | is_logical << 1
+ af['flags_class'] = mclass | flags << 8
+ af['nzmax'] = nzmax
+ self.write_dtype(af)
+ self.write_element(array(self.arr.shape, dtype='i4'))
+ self.write_element(self.name)
+
+ def update_matrix_tag(self):
+ curr_pos = self.file_stream.tell()
+ self.file_stream.seek(self._mat_tag_pos)
+ self.mat_tag['byte_count'] = curr_pos - self._mat_tag_pos - 8
+ self.write_dtype(self.mat_tag)
+ self.file_stream.seek(curr_pos)
def write(self):
assert False, 'Not implemented'
@@ -559,10 +580,6 @@
def write(self):
# identify matlab type for array
# make at least 2d
- # write miMATRIX tag
- # write array flags (complex, global, logical, class, nzmax)
- # dimensions
- # array name
# maybe downcast array to smaller matlab type
# write real
# write imaginary
@@ -611,75 +628,84 @@
T=mxSPARSE_CLASS,
dims=ijd.shape)
self.write_bytes(ijd)
-
-
-def matrix_writer_factory(stream, arr, name, unicode_strings=False, is_global=False):
- ''' Factory function to return matrix writer given variable to write
- @stream - file or file-like stream to write to
- @arr - array to write
- @name - name in matlab (TM) workspace
- '''
- if have_sparse:
- if scipy.sparse.issparse(arr):
- return Mat5SparseWriter(stream, arr, name, is_global)
- arr = array(arr)
- if arr.dtype.hasobject:
- types, arr_type = classify_mobjects(arr)
- if arr_type == 'c':
- return Mat5CellWriter(stream, arr, name, is_global, types)
- elif arr_type == 's':
- return Mat5StructWriter(stream, arr, name, is_global)
- elif arr_type == 'o':
- return Mat5ObjectWriter(stream, arr, name, is_global)
- if arr.dtype.kind in ('U', 'S'):
- if unicode_strings:
- return Mat5UniCharWriter(stream, arr, name, is_global)
+
+
+class Mat5WriterGetter(object):
+ ''' Wraps stream and options, provides methods for getting Writer objects '''
+ def __init__(self, stream, unicode_strings):
+ self.stream = stream
+ self.unicode_strings = unicode_strings
+
+ def rewind(self):
+ self.stream.seek(0)
+
+ def matrix_writer_factory(self, arr, name, is_global=False):
+ ''' Factory function to return matrix writer given variable to write
+ @stream - file or file-like stream to write to
+ @arr - array to write
+ @name - name in matlab (TM) workspace
+ '''
+ if have_sparse:
+ if scipy.sparse.issparse(arr):
+ return Mat5SparseWriter(self.stream, arr, name, is_global)
+ arr = array(arr)
+ if arr.dtype.hasobject:
+ types, arr_type = classify_mobjects(arr)
+ if arr_type == 'c':
+ return Mat5CellWriter(self.stream, arr, name, is_global, types)
+ elif arr_type == 's':
+ return Mat5StructWriter(self.stream, arr, name, is_global)
+ elif arr_type == 'o':
+ return Mat5ObjectWriter(self.stream, arr, name, is_global)
+ if arr.dtype.kind in ('U', 'S'):
+ if self.unicode_strings:
+ return Mat5UniCharWriter(self.stream, arr, name, is_global)
+ else:
+ return Mat5IntCharWriter(self.stream, arr, name, is_global)
else:
- return Mat5IntCharWriter(stream, arr, name, is_global)
- else:
- return Mat5NumericWriter(stream, arr, name, is_global)
+ return Mat5NumericWriter(self.stream, arr, name, is_global)
-def classify_mobjects(objarr):
- ''' Function to classify objects passed for writing
- returns
- types - S1 array of same shape as objarr with codes for each object
- i - invalid object
- a - ndarray
- s - matlab struct
- o - matlab object
- arr_type - one of
- c - cell array
- s - struct array
- o - object array
- '''
- N = objarr.size
- types = empty((N,), dtype='S1')
- types[:] = 'i'
- type_set = set()
- flato = objarr.flat
- for i in range(N):
- obj = flato[i]
- if isinstance(obj, ndarray):
- types[i] = 'a'
- continue
- try:
- fns = tuple(obj._fieldnames)
- except AttributeError:
- continue
- try:
- cn = obj._classname
- except AttributeError:
- types[i] = 's'
- type_set.add(fns)
- continue
- types[i] = 'o'
- type_set.add((cn, fns))
- arr_type = 'c'
- if len(set(types))==1 and len(type_set) == 1:
- arr_type = types[0]
- return types.reshape(objarr.shape), arr_type
-
-
+ def classify_mobjects(self, objarr):
+ ''' Function to classify objects passed for writing
+ returns
+ types - S1 array of same shape as objarr with codes for each object
+ i - invalid object
+ a - ndarray
+ s - matlab struct
+ o - matlab object
+ arr_type - one of
+ c - cell array
+ s - struct array
+ o - object array
+ '''
+ N = objarr.size
+ types = empty((N,), dtype='S1')
+ types[:] = 'i'
+ type_set = set()
+ flato = objarr.flat
+ for i in range(N):
+ obj = flato[i]
+ if isinstance(obj, ndarray):
+ types[i] = 'a'
+ continue
+ try:
+ fns = tuple(obj._fieldnames)
+ except AttributeError:
+ continue
+ try:
+ cn = obj._classname
+ except AttributeError:
+ types[i] = 's'
+ type_set.add(fns)
+ continue
+ types[i] = 'o'
+ type_set.add((cn, fns))
+ arr_type = 'c'
+ if len(set(types))==1 and len(type_set) == 1:
+ arr_type = types[0]
+ return types.reshape(objarr.shape), arr_type
+
+
class MatFile5Writer(MatFileWriter):
''' Class for writing mat5 files '''
def __init__(self, file_stream,
@@ -688,22 +714,32 @@
global_vars=None):
super(MatFile5Writer, self).__init__(file_stream)
self.do_compression = do_compression
- self.unicode_strings = unicode_strings
if global_vars:
self.global_vars = global_vars
else:
self.global_vars = []
+ self.writer_getter = Mat5WriterGetter(
+ StringIO,
+ unicode_strings)
+
+ def get_unicode_strings(self):
+ return self.write_getter.unicode_strings
+ def set_unicode_strings(self, unicode_strings):
+ self.writer_getter.unicode_strings = unicode_strings
+ unicode_strings = property(get_unicode_strings,
+ set_unicode_strings,
+ None,
+ 'get/set unicode strings property')
def put_variables(self, mdict):
for name, var in mdict.items():
is_global = name in self.global_vars
- stream = StringIO()
- matrix_writer_factory(stream,
- var,
- name,
- is_global,
- self.unicode_strings,
- ).write()
+ self.writer_getter.rewind()
+ self.writer_getter.matrix_writer_factory(
+ var,
+ name,
+ is_global,
+ ).write()
if self.do_compression:
str = zlib.compress(stream.getvalue())
tag = empty((), mdtypes_template['tag_full'])
Modified: trunk/Lib/io/miobase.py
===================================================================
--- trunk/Lib/io/miobase.py 2006-10-08 02:12:06 UTC (rev 2247)
+++ trunk/Lib/io/miobase.py 2006-10-09 13:47:02 UTC (rev 2248)
@@ -391,7 +391,27 @@
self.dt_dict = dt_dict
self.rtol = rtol
self.atol = atol
-
+
+ def eps(self, dt):
+ ''' Calculate machine precision for datatype
+
+ Machine precision defined as difference between X and smallest
+ encodable number greater than X, where X is usually 1.
+
+ Input can be datatype, in which case X=1, or X.
+ '''
+ try:
+ dt = dtype(dt)
+ start = array(1, dt)
+ except TypeError:
+ start = array(dt)
+ dt = start.dtype
+ two = array(2, dt)
+ e = start.copy()
+ while (e / two + start) > start:
+ e = e / two
+ return e
+
def default_dt_dict(self):
d_dict = {}
for sc_type in ('complex','float'):
@@ -474,10 +494,9 @@
def downcast_complex(self, arr):
# can we downcast to float?
- flts = self.storage_criterion(arr.dtype.itemsize / 2,
- ('f'),
- lambda x, y: x <=y)[0]
- test_arr = arr.astype(flt)
+ fts = self.dt_arrs['float']
+ flts = flts[flts['storage'] <= arr.dtype.itemsize]
+ test_arr = arr.astype(flt[0]['type'])
if allclose(arr, test_arr, self.rtol, self.atol):
return self.downcast_float(test_arr)
# try downcasting to another complex type
More information about the Scipy-svn
mailing list