[Scipy-svn] r2099 - in trunk/Lib/sandbox/svm: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Thu Jul 13 10:45:33 EDT 2006


Author: fullung
Date: 2006-07-13 09:45:21 -0500 (Thu, 13 Jul 2006)
New Revision: 2099

Modified:
   trunk/Lib/sandbox/svm/dataset.py
   trunk/Lib/sandbox/svm/tests/test_dataset.py
Log:
Allow precomputing the Gram matrix with any kernel, not only the linear dot product.


Modified: trunk/Lib/sandbox/svm/dataset.py
===================================================================
--- trunk/Lib/sandbox/svm/dataset.py	2006-07-13 14:26:39 UTC (rev 2098)
+++ trunk/Lib/sandbox/svm/dataset.py	2006-07-13 14:45:21 UTC (rev 2099)
@@ -20,11 +20,13 @@
         return 1.0 / maxlen
     gamma = property(getgamma, 'Gamma parameter for RBF kernel')
 
-    def precompute(self):
-        return LibSvmPrecomputedDataSet(self.data)
+    def precompute(self, kernel):
+        return LibSvmPrecomputedDataSet(kernel, self.data)
 
 class LibSvmPrecomputedDataSet:
-    def __init__(self, origdata):
+    def __init__(self, kernel, origdata):
+        self.kernel = kernel
+
         # XXX look at using a list of vectors instead of a matrix when
         # the size of the precomputed dataset gets huge. This should
         # avoid problems with heap fragmentation, especially on
@@ -40,7 +42,7 @@
             for j, (y2, x2) in enumerate(origdata[i:]):
                 # Gram matrix is symmetric, so calculate dot product
                 # once and store it in both required locations
-                z = svm_node_dot(x1, x2)
+                z = kernel(x1, x2, svm_node_dot)
                 # fix index so we assign to the right place
                 j += i
                 grammat[i, j+1]['value'] = z

Modified: trunk/Lib/sandbox/svm/tests/test_dataset.py
===================================================================
--- trunk/Lib/sandbox/svm/tests/test_dataset.py	2006-07-13 14:26:39 UTC (rev 2098)
+++ trunk/Lib/sandbox/svm/tests/test_dataset.py	2006-07-13 14:45:21 UTC (rev 2099)
@@ -4,6 +4,7 @@
 set_local_path('../..')
 
 from svm.dataset import *
+from svm.kernel import *
 from svm.dataset import convert_to_svm_node, svm_node_dot
 from svm.libsvm import svm_node_dtype
 
@@ -73,15 +74,28 @@
         self.assertAlmostEqual(svm_node_dot(x, y), 4.)
 
 class test_precomputed_dataset(NumpyTestCase):
-    def check_foo(self):
-        y = N.random.randn(50)
-        x = N.random.randn(len(y), 1)
-        expected_dotprods = N.dot(x, N.transpose(x))
+    def check_precompute(self):
+        degree, gamma, coef0 = 4, 3.0, 2.0
+        kernels = [
+            LinearKernel(),
+            PolynomialKernel(degree, gamma, coef0),
+            RBFKernel(gamma),
+            SigmoidKernel(gamma, coef0)
+            ]
+        y = N.random.randn(20)
+        x = N.random.randn(len(y), 10)
         origdata = LibSvmRegressionDataSet(zip(y, x))
-        # get a new dataset containing the precomputed data
-        pcdata = origdata.precompute()
-        actual_dotprods = pcdata.grammat[:,1:-1]['value']
-        assert_array_almost_equal(actual_dotprods, expected_dotprods)
 
+        for kernel in kernels:
+            # calculate expected Gram matrix
+            expt_grammat = N.empty((len(y),)*2)
+            for i, xi in enumerate(x):
+                for j, xj in enumerate(x):
+                    expt_grammat[i, j] = kernel(xi, xj, N.dot)
+            # get a new dataset containing the precomputed data
+            pcdata = origdata.precompute(kernel)
+            actual_grammat = pcdata.grammat[:,1:-1]['value']
+            assert_array_almost_equal(actual_grammat, expt_grammat)
+
 if __name__ == '__main__':
     NumpyTest().run()




More information about the Scipy-svn mailing list