[Scipy-svn] r2058 - in trunk/Lib/sandbox/svm: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Jul 9 19:03:49 EDT 2006


Author: fullung
Date: 2006-07-09 18:03:35 -0500 (Sun, 09 Jul 2006)
New Revision: 2058

Modified:
   trunk/Lib/sandbox/svm/classification.py
   trunk/Lib/sandbox/svm/model.py
   trunk/Lib/sandbox/svm/tests/test_classification.py
   trunk/Lib/sandbox/svm/tests/test_oneclass.py
   trunk/Lib/sandbox/svm/tests/test_regression.py
Log:
Classification test and some minor fixes.


Modified: trunk/Lib/sandbox/svm/classification.py
===================================================================
--- trunk/Lib/sandbox/svm/classification.py	2006-07-08 09:29:48 UTC (rev 2057)
+++ trunk/Lib/sandbox/svm/classification.py	2006-07-09 23:03:35 UTC (rev 2058)
@@ -39,7 +39,7 @@
         """
         def p(x):
             xptr = cast(x.ctypes.data, POINTER(libsvm.svm_node))
-            return libsvm.svm_predict(self.model, xptr)
+            return int(libsvm.svm_predict(self.model, xptr))
         return map(p, dataset.data)
 
     def predict_values(self, dataset):

Modified: trunk/Lib/sandbox/svm/model.py
===================================================================
--- trunk/Lib/sandbox/svm/model.py	2006-07-08 09:29:48 UTC (rev 2057)
+++ trunk/Lib/sandbox/svm/model.py	2006-07-09 23:03:35 UTC (rev 2058)
@@ -54,10 +54,6 @@
 
     def fit(self, dataset):
         # XXX don't poke around in dataset's internals
-
-        # no reference to the svm_problem is kept because a svm_model
-        # only requires some parameters and the support vectors chosen
-        # from the dataset
         problem = libsvm.svm_problem()
         problem.l = len(dataset.data)
         y = (c_double*problem.l)()
@@ -71,10 +67,15 @@
 
         model = libsvm.svm_train(problem, self.param)
 
-        # XXX because libsvm only does a shallow copy of the
-        # svm_parameter into the model, we have to make sure that a
-        # reference to weight labels and weights are kept somewhere
+        # weight parameters are no longer required, so remove references
+        # to them, as the data they point to might disappear when this
+        # object is deallocated
+        model.contents.param.nr_weight = 0
+        model.contents.param.weight = None
+        model.contents.param.weight_label = None
 
+        # results keep a reference to the dataset because the svm_model
+        # refers to some of its vectors as the support vectors
         return self.Results(model, dataset)
 
     def _check_problem_param(self, problem, param):

Modified: trunk/Lib/sandbox/svm/tests/test_classification.py
===================================================================
--- trunk/Lib/sandbox/svm/tests/test_classification.py	2006-07-08 09:29:48 UTC (rev 2057)
+++ trunk/Lib/sandbox/svm/tests/test_classification.py	2006-07-09 23:03:35 UTC (rev 2058)
@@ -17,7 +17,7 @@
         Model(Kernel, 1.0, weights)
         model = Model(Kernel, cost=1.0, weights=weights)
 
-    def check_c_train(self):
+    def check_c_basics(self):
         labels = [0, 1, 1, 2]
         x = [N.array([0, 0]),
              N.array([0, 1]),
@@ -33,6 +33,45 @@
         results.predict(testdata)
         results.predict_values(testdata)
 
+    def check_c_more(self):
+        labels = [0, 1, 1, 2]
+        x = [N.array([0, 0]),
+             N.array([0, 1]),
+             N.array([1, 0]),
+             N.array([1, 1])]
+        traindata = LibSvmClassificationDataSet(zip(labels, x))
+        cost = 10.0
+        weights = [(1, 10.0)]
+        testdata = LibSvmTestDataSet(x)
+
+        kernels = [
+            LinearKernel(),
+            PolynomialKernel(3, traindata.gamma, 0.0),
+            RBFKernel(traindata.gamma)
+            ]
+        expected_rhos = [
+            [-0.999349, -1.0, -3.0],
+            [0.375, -1.0, -1.153547],
+            [0.671181, 0.0, -0.671133]
+            ]
+        expected_errors = [0, 1, 0]
+
+        for kernel, expected_rho, expected_error in \
+            zip(kernels, expected_rhos, expected_errors):
+            model = LibSvmCClassificationModel(kernel, cost, weights)
+            results = model.fit(traindata)
+
+            self.assertEqual(results.labels, [0, 1, 2])
+            #self.assertEqual(model.nSV, [1, 2, 1])
+
+            # XXX decimal=4 to suppress slight differences in values
+            # calculated for rho on Windows with MSVC 7.1 and on
+            # Fedora Core 4 with GCC 4.0.0.
+            assert_array_almost_equal(results.rho, expected_rho, decimal=4)
+
+            predictions = N.array(results.predict(testdata))
+            self.assertEqual(N.sum(predictions != labels), expected_error)
+
     def check_nu_train(self):
         pass
 

Modified: trunk/Lib/sandbox/svm/tests/test_oneclass.py
===================================================================
--- trunk/Lib/sandbox/svm/tests/test_oneclass.py	2006-07-08 09:29:48 UTC (rev 2057)
+++ trunk/Lib/sandbox/svm/tests/test_oneclass.py	2006-07-09 23:03:35 UTC (rev 2058)
@@ -18,11 +18,11 @@
              N.array([0, 1]),
              N.array([1, 0]),
              N.array([1, 1])]
-        dataset = LibSvmOneClassDataSet(x)
+        traindata = LibSvmOneClassDataSet(x)
         
         Model = LibSvmOneClassModel
         model = Model(LinearKernel())
-        results = model.fit(dataset)
+        results = model.fit(traindata)
 
         testdata = LibSvmTestDataSet(x)
         results.predict(testdata)

Modified: trunk/Lib/sandbox/svm/tests/test_regression.py
===================================================================
--- trunk/Lib/sandbox/svm/tests/test_regression.py	2006-07-08 09:29:48 UTC (rev 2057)
+++ trunk/Lib/sandbox/svm/tests/test_regression.py	2006-07-09 23:03:35 UTC (rev 2058)
@@ -22,11 +22,11 @@
              N.array([0, 1]),
              N.array([1, 0]),
              N.array([1, 1])]
-        dataset = LibSvmRegressionDataSet(zip(y, x))
+        traindata = LibSvmRegressionDataSet(zip(y, x))
 
         Model = LibSvmEpsilonRegressionModel
         model = Model(LinearKernel())
-        results = model.fit(dataset)
+        results = model.fit(traindata)
 
         testdata = LibSvmTestDataSet(x)
         results.predict(testdata)




More information about the Scipy-svn mailing list