[scikit-learn] different sized inputs in call to custom metric in KNN

Stephen O'Neill soneill5045 at gmail.com
Fri Aug 25 19:30:52 EDT 2017

Hey Gang,

I was wondering if anyone might be able to answer a question about the
sklearn.neighbors.NearestNeighbors class.  For reference, I'm on:
anaconda distribution python 2.7.11
sklearn version 0.17.1

I'm subclassing the NearestNeighbors class and using a custom distance
metric, something like the following:

class MyModel(NearestNeighbors):
    def __init__(self, some_info):
        def custom_dist(x, y, info=some_info):
            # scaled L1 distance; returns a single scalar
            return numpy.sum(numpy.abs(x - y) / info)

        NearestNeighbors.__init__(self, metric=custom_dist)

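In case it helps anyone reproduce my setup, here is a self-contained version of the sketch above, with toy Gaussian data (fewer rows than my real set) and a placeholder scalar standing in for some_info:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

class MyModel(NearestNeighbors):
    def __init__(self, some_info):
        # store the constructor arg so sklearn's get_params()
        # introspection still works on the subclass
        self.some_info = some_info

        def custom_dist(x, y, info=some_info):
            # scaled L1 distance; must return a single scalar
            return np.sum(np.abs(x - y) / info)

        NearestNeighbors.__init__(self, metric=custom_dist)

rng = np.random.RandomState(0)
X = rng.randn(500, 3)  # stand-in for my (5000, 3) Gaussian dummy data

model = MyModel(some_info=2.0)  # placeholder value, not my real some_info
model.fit(X)

# X[0] is in the fitted set, so its nearest neighbor is itself at distance 0
dist, ind = model.kneighbors(X[:1], n_neighbors=3)
```

Note that a scalar scale divides element-wise whatever length x and y come in with; it's the array-valued scale in my real code that hits the broadcasting error described below.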
So I build a dummy dataset from some Gaussians with shape (5000, 3); then later, when I call it, I get the error:
"ValueError: operands could not be broadcast together with shapes (10,)"

inside my custom_dist function.  Naturally I checked with some simple
print x / print y statements inside custom_dist, and sure enough the
shapes of x and y are both (10,), whereas I expected them to be
(3,), since my dummy data has 3 columns.  (The custom_dist function
written above is not exactly what I'm really using, but it does
reproduce the same ValueError.)

When I pass algorithm='brute' to the NearestNeighbors.__init__ call
instead of the default algorithm='auto', the x and y values that get
passed to my custom_dist have shape (3,), as I would expect.  What is
different about how the distance metric is used between
algorithm='auto' and algorithm='brute' that would turn a 3-dimensional
sample into a 10-dimensional one?  Do the KDTree and/or BallTree classes
also apply the distance metric to tree nodes or something?  I wasn't able to
figure out where the shape (10,) x and y samples could be coming from.
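In the meantime, brute force does work end to end for me.  A minimal standalone version of that workaround (toy data and a made-up shape-(3,) scale, not my real some_info):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

info = np.array([1.0, 2.0, 3.0])  # made-up per-column scale

def custom_dist(x, y, info=info):
    # under algorithm='brute', x and y arrive as single rows of
    # shape (3,), so the element-wise division broadcasts cleanly
    return np.sum(np.abs(x - y) / info)

rng = np.random.RandomState(0)
X = rng.randn(200, 3)

nn = NearestNeighbors(metric=custom_dist, algorithm='brute')
nn.fit(X)

# X[0] is in the fitted set, so its nearest neighbor is itself at distance 0
dist, ind = nn.kneighbors(X[:1], n_neighbors=2)
```

With brute force the metric is only ever applied to pairs of data rows, which is consistent with the (3,) shapes I see there.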

Thanks in advance!

Steve O'Neill