[Scipy-svn] r2104 - in trunk/Lib/sandbox/svm: . tests

scipy-svn at scipy.org
Fri Jul 14 17:55:52 EDT 2006


Author: fullung
Date: 2006-07-14 16:55:34 -0500 (Fri, 14 Jul 2006)
New Revision: 2104

Added:
   trunk/Lib/sandbox/svm/tests/test_precomputed.py
Modified:
   trunk/Lib/sandbox/svm/classification.py
   trunk/Lib/sandbox/svm/dataset.py
   trunk/Lib/sandbox/svm/libsvm.py
   trunk/Lib/sandbox/svm/model.py
   trunk/Lib/sandbox/svm/regression.py
   trunk/Lib/sandbox/svm/tests/test_all.py
Log:
Precomputed model training.

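In outline, the training workflow this revision enables looks like the sketch below, modelled on the new test_precomputed.py added here; the kernel choice, array shapes and random data are illustrative only:

    import numpy as N

    from svm.dataset import LibSvmRegressionDataSet
    from svm.kernel import LinearKernel
    from svm.regression import LibSvmEpsilonRegressionModel

    kernel = LinearKernel()

    # the large, constant part of the training data: precompute its
    # Gram matrix once
    y1 = N.random.randn(50)
    x1 = N.random.randn(len(y1), 10)
    pcdata = LibSvmRegressionDataSet(zip(y1, x1)).precompute(kernel)

    # the small part that differs per model: combine it with the
    # precomputed data and train on the result
    y2 = N.random.randn(5)
    x2 = N.random.randn(len(y2), 10)
    newdata = LibSvmRegressionDataSet(zip(y2, x2))

    model = LibSvmEpsilonRegressionModel(kernel)
    results = model.fit(pcdata.combine(newdata))
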

Modified: trunk/Lib/sandbox/svm/classification.py
===================================================================
--- trunk/Lib/sandbox/svm/classification.py	2006-07-14 21:21:55 UTC (rev 2103)
+++ trunk/Lib/sandbox/svm/classification.py	2006-07-14 21:55:34 UTC (rev 2104)
@@ -130,7 +130,7 @@
         This function returns the percentage of data that was
         classified correctly over all the experiments.
         """
-        problem = dataset.create_svm_problem()
+        problem = dataset._create_svm_problem()
         target = N.empty((len(dataset.data),), dtype=N.float64)
         tp = cast(target.ctypes.data, POINTER(c_double))
         libsvm.svm_cross_validation(problem, self.param, nr_fold, tp)

Modified: trunk/Lib/sandbox/svm/dataset.py
===================================================================
--- trunk/Lib/sandbox/svm/dataset.py	2006-07-14 21:21:55 UTC (rev 2103)
+++ trunk/Lib/sandbox/svm/dataset.py	2006-07-14 21:55:34 UTC (rev 2104)
@@ -1,4 +1,3 @@
-from ctypes import c_double, POINTER, cast
 import numpy as N
 
 import libsvm
@@ -24,18 +23,13 @@
     def precompute(self, kernel):
         return LibSvmPrecomputedDataSet(kernel, self.data)
 
-    def create_svm_problem(self):
-        problem = libsvm.svm_problem()
-        problem.l = len(self.data)
-        y = (c_double*problem.l)()
-        x = (POINTER(libsvm.svm_node)*problem.l)()
-        for i, (yi, xi) in enumerate(self.data):
-            y[i] = yi
-            x[i] = cast(xi.ctypes.data, POINTER(libsvm.svm_node))
-        problem.x = x
-        problem.y = y
-        return problem
+    def _create_svm_problem(self):
+        return libsvm.create_svm_problem(self.data)
 
+    def _update_svm_parameter(self, param):
+        # XXX we can handle gamma=None here
+        pass
+
 class LibSvmPrecomputedDataSet:
     def __init__(self, kernel, origdata=None):
         self.kernel = kernel
@@ -126,6 +120,12 @@
         newdataset.grammat = newgrammat
         return newdataset
 
+    def _create_svm_problem(self):
+        return libsvm.create_svm_problem(self.data)
+
+    def _update_svm_parameter(self, param):
+        param.kernel_type = libsvm.PRECOMPUTED
+
 class LibSvmRegressionDataSet(LibSvmDataSet):
     def __init__(self, origdata):
         data = map(lambda x: (x[0], convert_to_svm_node(x[1])), origdata)

Modified: trunk/Lib/sandbox/svm/libsvm.py
===================================================================
--- trunk/Lib/sandbox/svm/libsvm.py	2006-07-14 21:21:55 UTC (rev 2103)
+++ trunk/Lib/sandbox/svm/libsvm.py	2006-07-14 21:55:34 UTC (rev 2104)
@@ -1,7 +1,7 @@
 import inspect
 
+from ctypes import *
 import numpy as N
-from ctypes import c_int, c_double, POINTER, Structure, c_char_p
 
 _libsvm = N.ctypes_load_library('libsvm_', __file__)
 
@@ -124,6 +124,18 @@
     func.argtypes = argtypes
     inspect.currentframe().f_locals[f] = func
 
+def create_svm_problem(data):
+    problem = svm_problem()
+    problem.l = len(data)
+    y = (c_double*problem.l)()
+    x = (POINTER(svm_node)*problem.l)()
+    for i, (yi, xi) in enumerate(data):
+        y[i] = yi
+        x[i] = cast(xi.ctypes.data, POINTER(svm_node))
+    problem.x = x
+    problem.y = y
+    return problem
+
 __all__ = [
     'svm_node_dtype',
     'C_SVC',

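The new create_svm_problem helper centralizes the ctypes packing that classification.py and regression.py previously did inline. With it, the cross-validation path in those modules reduces to roughly the following sketch (the function name cross_validate_sketch is invented for illustration; nr_fold is whatever fold count the caller supplies):

    from ctypes import POINTER, c_double, cast
    import numpy as N

    import libsvm

    def cross_validate_sketch(dataset, param, nr_fold):
        # pack the (label, svm_node array) pairs into a ctypes svm_problem
        problem = dataset._create_svm_problem()
        # libsvm writes one predicted target per training sample into
        # this buffer
        target = N.empty((len(dataset.data),), dtype=N.float64)
        tp = cast(target.ctypes.data, POINTER(c_double))
        libsvm.svm_cross_validation(problem, param, nr_fold, tp)
        return target
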
Modified: trunk/Lib/sandbox/svm/model.py
===================================================================
--- trunk/Lib/sandbox/svm/model.py	2006-07-14 21:21:55 UTC (rev 2103)
+++ trunk/Lib/sandbox/svm/model.py	2006-07-14 21:55:34 UTC (rev 2104)
@@ -1,4 +1,4 @@
-from ctypes import *
+from ctypes import POINTER, c_double, c_int
 
 from kernel import *
 import libsvm
@@ -47,9 +47,10 @@
         self.param = param
 
     def fit(self, dataset):
-        problem = dataset.create_svm_problem()
+        problem = dataset._create_svm_problem()
+        dataset._update_svm_parameter(self.param)
+        self._check_problem_param(problem, self.param)
 
-        self._check_problem_param(problem, self.param)
         model = libsvm.svm_train(problem, self.param)
 
         # weights are no longer required, so remove them as the
@@ -58,8 +59,12 @@
         model.contents.param.weight = c_double_null_ptr
         model.contents.param.weight_label = c_int_null_ptr
 
-        # results keep a refence to the dataset because the svm_model
-        # refers to some of its vectors as the support vectors
+        # results keep a reference to the dataset because the
+        # svm_model refers to some of its vectors as the support
+        # vectors
+        # XXX we can hide an id in the end of record marker so that we
+        # can figure out which support vectors to keep references to
+        # even when not using precomputed kernels
         return self.Results(model, dataset)
 
     def _check_problem_param(self, problem, param):

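The reference-keeping described in the comment matters because svm_train does not copy the support vectors: the pointers inside the returned svm_model alias memory owned by the dataset's arrays. Note also that _update_svm_parameter now runs before _check_problem_param, so the check sees the final kernel_type (PRECOMPUTED when a precomputed dataset is passed in). A minimal caller-side illustration of the guarantee this gives (variable names are only for illustration):

    # hypothetical caller-side view of what fit() guarantees
    results = model.fit(dataset)
    del dataset
    # the Results object kept its own reference to the dataset, so the
    # arrays backing the support vectors stay alive and the svm_model's
    # pointers into them remain valid
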
Modified: trunk/Lib/sandbox/svm/regression.py
===================================================================
--- trunk/Lib/sandbox/svm/regression.py	2006-07-14 21:21:55 UTC (rev 2103)
+++ trunk/Lib/sandbox/svm/regression.py	2006-07-14 21:55:34 UTC (rev 2104)
@@ -66,7 +66,7 @@
         error and the squared correlation coefficient.
         """
 
-        problem = dataset.create_svm_problem()
+        problem = dataset._create_svm_problem()
         target = N.empty((len(dataset.data),), dtype=N.float64)
         tp = cast(target.ctypes.data, POINTER(c_double))
         libsvm.svm_cross_validation(problem, self.param, nr_fold, tp)

Modified: trunk/Lib/sandbox/svm/tests/test_all.py
===================================================================
--- trunk/Lib/sandbox/svm/tests/test_all.py	2006-07-14 21:21:55 UTC (rev 2103)
+++ trunk/Lib/sandbox/svm/tests/test_all.py	2006-07-14 21:55:34 UTC (rev 2104)
@@ -3,6 +3,7 @@
 from test_dataset import *
 from test_oneclass import *
 from test_libsvm import *
+from test_precomputed import *
 
 if __name__ == '__main__':
     NumpyTest().run()

Added: trunk/Lib/sandbox/svm/tests/test_precomputed.py
===================================================================
--- trunk/Lib/sandbox/svm/tests/test_precomputed.py	2006-07-14 21:21:55 UTC (rev 2103)
+++ trunk/Lib/sandbox/svm/tests/test_precomputed.py	2006-07-14 21:55:34 UTC (rev 2104)
@@ -0,0 +1,45 @@
+from numpy.testing import *
+import numpy as N
+
+set_local_path('../..')
+from svm.regression import *
+from svm.dataset import *
+from svm.kernel import LinearKernel
+restore_path()
+
+class test_precomputed(NumpyTestCase):
+    def check_precomputed(self):
+        kernel = LinearKernel()
+
+        # this dataset remains constant
+        y1 = N.random.randn(50)
+        x1 = N.random.randn(len(y1), 10)
+        data1 = LibSvmRegressionDataSet(zip(y1, x1))
+        pcdata1 = data1.precompute(kernel)
+
+        # in a typical problem, this dataset would be smaller than the
+        # part that remains constant and would differ for each model
+        y2 = N.random.randn(5)
+        x2 = N.random.randn(len(y2), x1.shape[1])
+        data2 = LibSvmRegressionDataSet(zip(y2, x2))
+
+        pcdata12 = pcdata1.combine(data2)
+        model = LibSvmEpsilonRegressionModel(kernel)
+        results = model.fit(pcdata12)
+
+        # reference model, calculated without involving the
+        # precomputed Gram matrix
+        refy = N.concatenate([y1, y2])
+        refx = N.vstack([x1, x2])
+        refdata = LibSvmRegressionDataSet(zip(refy, refx))
+        model = LibSvmEpsilonRegressionModel(kernel)
+        refresults = model.fit(refdata)
+
+        self.assertAlmostEqual(results.rho, refresults.rho)
+        assert_array_almost_equal(results.sv_coef, refresults.sv_coef)
+
+        # XXX sigmas don't match yet. need to find out why.
+        #self.assertAlmostEqual(results.sigma, refresults.sigma)
+
+if __name__ == '__main__':
+    NumpyTest().run()



