[Scipy-svn] r2099 - in trunk/Lib/sandbox/svm: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Thu Jul 13 10:45:33 EDT 2006
Author: fullung
Date: 2006-07-13 09:45:21 -0500 (Thu, 13 Jul 2006)
New Revision: 2099
Modified:
trunk/Lib/sandbox/svm/dataset.py
trunk/Lib/sandbox/svm/tests/test_dataset.py
Log:
Precompute with any kernel.
Modified: trunk/Lib/sandbox/svm/dataset.py
===================================================================
--- trunk/Lib/sandbox/svm/dataset.py 2006-07-13 14:26:39 UTC (rev 2098)
+++ trunk/Lib/sandbox/svm/dataset.py 2006-07-13 14:45:21 UTC (rev 2099)
@@ -20,11 +20,13 @@
return 1.0 / maxlen
gamma = property(getgamma, 'Gamma parameter for RBF kernel')
- def precompute(self):
- return LibSvmPrecomputedDataSet(self.data)
+ def precompute(self, kernel):
+ return LibSvmPrecomputedDataSet(kernel, self.data)
class LibSvmPrecomputedDataSet:
- def __init__(self, origdata):
+ def __init__(self, kernel, origdata):
+ self.kernel = kernel
+
# XXX look at using a list of vectors instead of a matrix when
# the size of the precomputed dataset gets huge. This should
# avoid problems with heap fragmentation, especially on
@@ -40,7 +42,7 @@
for j, (y2, x2) in enumerate(origdata[i:]):
# Gram matrix is symmetric, so calculate dot product
# once and store it in both required locations
- z = svm_node_dot(x1, x2)
+ z = kernel(x1, x2, svm_node_dot)
# fix index so we assign to the right place
j += i
grammat[i, j+1]['value'] = z
Modified: trunk/Lib/sandbox/svm/tests/test_dataset.py
===================================================================
--- trunk/Lib/sandbox/svm/tests/test_dataset.py 2006-07-13 14:26:39 UTC (rev 2098)
+++ trunk/Lib/sandbox/svm/tests/test_dataset.py 2006-07-13 14:45:21 UTC (rev 2099)
@@ -4,6 +4,7 @@
set_local_path('../..')
from svm.dataset import *
+from svm.kernel import *
from svm.dataset import convert_to_svm_node, svm_node_dot
from svm.libsvm import svm_node_dtype
@@ -73,15 +74,28 @@
self.assertAlmostEqual(svm_node_dot(x, y), 4.)
class test_precomputed_dataset(NumpyTestCase):
- def check_foo(self):
- y = N.random.randn(50)
- x = N.random.randn(len(y), 1)
- expected_dotprods = N.dot(x, N.transpose(x))
+ def check_precompute(self):
+ degree, gamma, coef0 = 4, 3.0, 2.0
+ kernels = [
+ LinearKernel(),
+ PolynomialKernel(degree, gamma, coef0),
+ RBFKernel(gamma),
+ SigmoidKernel(gamma, coef0)
+ ]
+ y = N.random.randn(20)
+ x = N.random.randn(len(y), 10)
origdata = LibSvmRegressionDataSet(zip(y, x))
- # get a new dataset containing the precomputed data
- pcdata = origdata.precompute()
- actual_dotprods = pcdata.grammat[:,1:-1]['value']
- assert_array_almost_equal(actual_dotprods, expected_dotprods)
+ for kernel in kernels:
+ # calculate expected Gram matrix
+ expt_grammat = N.empty((len(y),)*2)
+ for i, xi in enumerate(x):
+ for j, xj in enumerate(x):
+ expt_grammat[i, j] = kernel(xi, xj, N.dot)
+ # get a new dataset containing the precomputed data
+ pcdata = origdata.precompute(kernel)
+ actual_grammat = pcdata.grammat[:,1:-1]['value']
+ assert_array_almost_equal(actual_grammat, expt_grammat)
+
if __name__ == '__main__':
NumpyTest().run()
More information about the Scipy-svn mailing list