Nearest Neighbor Classifier -OTDA

Hello, I am following these two papers: "Optimal Transport for Domain Adaptation" by Courty et al. in TPAMI [1] and "Large-Scale Optimal Transport and Mapping Estimation" by Seguy et al. in ICLR [2]. The papers say that they use a nearest neighbor classifier as the baseline, and I am confused about it. As in [1], I took 1800 USPS and 2000 MNIST samples and, using the OTDA examples, computed the transported source samples. I am stuck on the classification step after this: do I now compute a 2000 x 1800 distance matrix between target and source and use nearest neighbors? Here is the code that I am using. Can you please tell me how to proceed from here?

#%% experiment with domain adaptation
import ot
import numpy as np
import sklearn
import matplotlib.pyplot as plt
from scipy.io import loadmat

#%% data
u_m = loadmat('USPS_vs_MNIST.mat')
Xs = u_m['X_src']
Ys = u_m['Y_src']
Xt = u_m['X_tar']
Yt = u_m['Y_tar']

#%% DA
ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1)
ot_sinkhorn_un.fit(Xs=Xs.T, Xt=Xt.T)
transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs.T).T

#%% nearest neighbor classifier
from sklearn.metrics import confusion_matrix as CM
import scipy
M = scipy.spatial.distance.cdist(Xt.T, transp_Xs_sinkhorn_un.T, 'sqeuclidean')  # not correct
a = np.argmin(M, axis=1)
labels_test = Yt[a]
conf = CM(Yt, labels_test)
acc = np.sum(np.diag(conf)) / np.sum(conf)
acc

Ah, I am super embarrassed about posting this here. I was able to figure it out by simply using sklearn. I wish there were a delete option.

from sklearn.neighbors import KNeighborsClassifier
neigh = KNeighborsClassifier(n_neighbors=1)
neigh.fit(Xs, Ys)
Acc_source = neigh.score(Xs, Ys)
neigh.fit(transp_Xs_sinkhorn_un, Ys)
Acc_source_da = neigh.score(transp_Xs_sinkhorn_un, Ys)
Acc_target = neigh.score(Xt, Yt)

Thanks and apologies
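For anyone landing on this thread with the same question, here is a minimal sketch of the distance-matrix version of the 1-NN baseline. It is only an illustration under stated assumptions, not the code used in [1] or [2]: the data is synthetic (the .mat file from the thread is not available here), arrays are laid out as (n_samples, n_features), and the helper one_nn_predict is a hypothetical name introduced for this example. The point it shows is that the argmin indices run over the (transported) source samples, so the predicted labels are looked up in the source labels and then compared against the target labels.

```python
import numpy as np

def one_nn_predict(X_train, y_train, X_test):
    """Give each test row the label of its nearest training row (hypothetical helper)."""
    # Squared Euclidean distance matrix, shape (n_test, n_train).
    d2 = (
        (X_test ** 2).sum(axis=1)[:, None]
        - 2.0 * X_test @ X_train.T
        + (X_train ** 2).sum(axis=1)[None, :]
    )
    nearest = np.argmin(d2, axis=1)  # indices into the TRAINING (source) set
    return y_train[nearest]          # so labels come from the source labels

# Synthetic stand-in for the transported source and the target (assumption:
# two well-separated classes; real USPS/MNIST data will behave differently).
rng = np.random.default_rng(0)
n_src, n_tgt, dim = 60, 80, 5
centers = np.array([[0.0] * dim, [6.0] * dim])
ys = np.arange(n_src) % 2  # source labels, 1-D
yt = np.arange(n_tgt) % 2  # target labels, used only for scoring
transp_Xs = centers[ys] + rng.normal(size=(n_src, dim))
Xt = centers[yt] + rng.normal(size=(n_tgt, dim))

pred = one_nn_predict(transp_Xs, ys, Xt)
acc_target = np.mean(pred == yt)
print("target accuracy:", acc_target)
```

Fitting KNeighborsClassifier(n_neighbors=1) on the transported source with the source labels and scoring it on the target should give the same number, up to tie-breaking. Two practical notes: labels loaded with loadmat usually arrive as 2-D column vectors, so a ravel() may be needed before indexing or fitting, and scoring a 1-NN classifier on its own training points is trivially perfect unless identical points carry different labels, so the target score is the informative one.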

participants (1)
-
Kowshik Thopalli