[scikit-learn] different sized inputs in call to custom metric in KNN
Stephen O'Neill
soneill5045 at gmail.com
Fri Aug 25 19:30:52 EDT 2017
Hey Gang,
I was wondering if anyone might be able to answer a question about the
sklearn.neighbors.NearestNeighbors class. For reference, I'm on:
Anaconda distribution, Python 2.7.11
sklearn version 0.17.1
I'm subclassing the NearestNeighbors class and using a custom distance
metric, something like the following:
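    import numpy
    from sklearn.neighbors import NearestNeighbors

    class MyModel(NearestNeighbors):
        def __init__(self, some_info):
            # some_info is assumed to hold per-feature scales, shape (3,)
            self.some_info = some_info
            def custom_dist(x, y, info=some_info):
                # Weighted L1 distance between two samples; returns a scalar
                return numpy.sum(numpy.abs(x - y) / info)
            NearestNeighbors.__init__(self, metric=custom_dist)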
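The metric above can also be written as a standalone function, which is easier to test outside NearestNeighbors. This is a minimal sketch with a hypothetical `make_scaled_l1` helper; it assumes the same per-feature weighted L1 form and that `scale` holds one positive value per feature. The added shape check does not make the (10,) inputs go away; it only makes a mismatch report the offending shapes instead of a bare broadcasting error.

```python
import numpy as np


def make_scaled_l1(scale):
    """Build a weighted L1 metric: sum(|x - y| / scale) over features.

    `scale` is assumed to be a 1-D array with one positive entry per
    feature, e.g. shape (3,) for data with 3 columns.
    """
    scale = np.asarray(scale, dtype=float)

    def scaled_l1(x, y):
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        # Fail with an explicit message instead of a bare broadcasting error.
        if x.shape != scale.shape or y.shape != scale.shape:
            raise ValueError(
                "scaled_l1 expected samples of shape %s, got %s and %s"
                % (scale.shape, x.shape, y.shape)
            )
        # Return a plain Python float, since the metric must yield a scalar.
        return float(np.sum(np.abs(x - y) / scale))

    return scaled_l1


dist = make_scaled_l1([1.0, 2.0, 4.0])
print(dist([0.0, 0.0, 0.0], [1.0, 2.0, 4.0]))  # 1/1 + 2/2 + 4/4 = 3.0
```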
So I build a dummy dataset of shape (5000, 3) from some Gaussians, and
later, when I call
    MyModel(some_info).fit(X)
I get the error:
"ValueError: operands could not be broadcast together with shapes (10,) (3,)"
inside my custom_dist function. Naturally I checked with some simple
print x, print y statements inside custom_dist, and sure enough x and y
both have shape (10,), whereas I expect them to have shape (3,), since my
dummy data has 3 columns. (Note that the custom_dist function above is not
the one I'm actually using, but it reproduces the same ValueError.)
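The print-statement check described above can be packaged as a small wrapper that records the shapes a metric is called with, which makes it easy to compare what different `algorithm` settings pass in. `record_shapes` and `l1` are hypothetical names for this sketch, not scikit-learn functions. The last lines reproduce the reported error with NumPy alone: a length-10 difference divided by a length-3 weight vector cannot broadcast.

```python
import numpy as np


def record_shapes(metric, log):
    """Wrap a two-argument metric so each call appends (x.shape, y.shape) to `log`.

    Debugging aid only; keyword arguments are forwarded unchanged.
    """
    def wrapped(x, y, **kwargs):
        log.append((np.shape(x), np.shape(y)))
        return metric(x, y, **kwargs)

    return wrapped


def l1(x, y):
    """Plain L1 distance between two 1-D samples."""
    return float(np.sum(np.abs(np.asarray(x) - np.asarray(y))))


calls = []
logged_l1 = record_shapes(l1, calls)
logged_l1(np.zeros(3), np.ones(3))
print(calls)  # [((3,), (3,))]

# Reproducing the failure: (10,)-shaped inputs against a (3,) weight vector.
x = np.random.random(10)   # stand-in for the unexpected (10,) inputs
weights = np.ones(3)       # one weight per feature, shape (3,)
try:
    np.abs(x - x) / weights
except ValueError as err:
    print(err)             # the same broadcasting ValueError as reported
```

Passing something like `record_shapes(custom_dist, calls)` as `metric=` and inspecting `calls` afterwards is one way to see every input the metric actually receives, including any calls made before your data is touched.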
When I change my call to NearestNeighbors.__init__(self,
metric=custom_dist, algorithm='brute') instead of relying on the default
algorithm='auto', the x and y values passed to custom_dist have shape (3,),
as I would expect. What differs in how the distance metric is used between
algorithm='auto' and algorithm='brute' that would turn a 3-dimensional
sample into a 10-dimensional one? Do the KDTree and/or BallTree classes also
apply the distance metric to tree nodes or something similar? I wasn't able
to figure out where the shape (10,) x and y samples could be coming from.
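One likely source of the (10,) arrays, offered as an assumption to verify against the source of your installed version rather than a confirmed answer: when a tree-based algorithm gets a callable metric, scikit-learn wraps it in a `DistanceMetric` (`PyFuncDistance` in `dist_metrics.pyx`), and older releases sanity-check the callable once at construction by calling it on a random length-10 vector, before any of your data is involved. The brute-force path does not build that wrapper, which would be consistent with what you see under algorithm='brute'. If the real metric has the same per-feature weighted L1 form as the stand-in, a swapped-in alternative that avoids the callable altogether is to rescale the data once and use the built-in Manhattan metric, since |x_i - y_i| / w_i = |x_i / w_i - y_i / w_i| whenever w_i > 0. The NumPy-only sketch below checks that identity; `weighted_l1`, `rescale`, `manhattan`, and the weights `w` are hypothetical.

```python
import numpy as np


def weighted_l1(x, y, w):
    """The stand-in metric from the question: sum(|x - y| / w)."""
    return float(np.sum(np.abs(x - y) / w))


def rescale(X, w):
    """Divide each column of X by its weight; weights must be strictly positive."""
    w = np.asarray(w, dtype=float)
    if np.any(w <= 0):
        raise ValueError("weights must be strictly positive")
    return np.asarray(X, dtype=float) / w  # broadcasts (n, d) / (d,)


def manhattan(x, y):
    """Plain L1 distance, the quantity metric='manhattan' computes."""
    return float(np.sum(np.abs(x - y)))


rng = np.random.RandomState(0)
X = rng.normal(size=(5, 3))       # small stand-in for the (5000, 3) data
w = np.array([0.5, 2.0, 10.0])    # hypothetical per-feature weights
Xs = rescale(X, w)

# Weighted L1 on the raw data equals plain L1 on the rescaled data.
for i in range(len(X)):
    for j in range(len(X)):
        assert np.isclose(weighted_l1(X[i], X[j], w), manhattan(Xs[i], Xs[j]))
print("identity holds on all pairs")
```

In scikit-learn this would look something like `NearestNeighbors(metric='manhattan').fit(rescale(X, w))`, with query points rescaled by the same `w` before calling `kneighbors`. A built-in metric also avoids one Python function call per distance evaluation, which is usually much faster than a callable metric inside a BallTree.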
Thanks in advance!
Best,
Steve O'Neill