[scikit-learn] Question about the Library of “sklearn.neural_network.BernoulliRBM” that Creates Highly Correlated Features.

Masanari Kondo masa.kondo5 at gmail.com
Thu Jul 27 16:16:06 EDT 2017


Dear all,

I’m using the sklearn library to generate new features of a dataset
using a Restricted Boltzmann Machine (RBM,
sklearn.neural_network.BernoulliRBM). I use the following environment:

python 3.5.0
numpy==1.11.1
scikit-learn==0.18


I have already tried a large number of iterations (n_iter=6000) and a
low learning rate (0.0001) on the full training data (373 samples).
However, the new features generated by the RBM are all highly
correlated. Can anyone explain why this happens?


Below is an MWE (20 samples, n_iter=10):


import numpy as np
from sklearn.neural_network import BernoulliRBM

# train data
train_data = np.array(
[[0.0326086956522,0.0,0.0,0.0200400801603,0.0674157303371,0.000805152979066,0.00200803212851,0.243243243243,0.0123456790123,0.55,0.0233428760185,0.0,0.0,0.0,0.444444444,0.0,0.0,0.157556270138,0.0188679245283,0.0983652512615],
[0.0108695652174,0.2,0.0,0.00200400801603,0.0112359550562,0.0,0.0,0.027027027027,0.0123456790123,1.0,0.00154151068047,0.0,0.0,1.0,1.0,0.0,0.0,0.0289389067571,0.0,0.0],
[0.0869565217391,0.0,0.152542372881,0.0260521042084,0.0749063670412,0.00322061191626,0.0180722891566,0.108108108108,0.0987654320988,0.4,0.022241796961,0.2,0.0909090909091,0.0,0.40625,0.0,0.0,0.053054662388,0.0188679245283,0.129097937384],
[0.0326086956522,0.2,0.0847457627119,0.0140280561122,0.0149812734082,0.000268384326355,0.0120481927711,0.027027027027,0.0246913580247,0.25,0.00352345298392,1.0,0.0,0.75,0.555555556,0.0,0.0,0.0192926045047,0.0188679245283,0.0983652512615],
[0.0978260869565,0.0,0.0,0.0100200400802,0.0711610486891,0.00214707461084,0.00803212851406,0.027027027027,0.111111111111,0.265625,0.0262056815679,1.0,0.0,0.0,0.518518519,0.0,0.0,0.0568060021635,0.0566037735849,0.213107498008],
[0.0760869565217,0.8,0.0,0.0180360721443,0.0936329588015,0.0,0.0120481927711,0.0810810810811,0.0864197530864,0.3333333335,0.0561550319313,0.0,0.0,0.863636364,0.342857143,0.5,0.333333333333,0.168121267841,0.169811320755,0.463705037033],
[0.0978260869565,1.0,0.0,0.0100200400802,0.063670411985,0.00697799248524,0.0,0.135135135135,0.0740740740741,0.4166666665,0.0156353226162,0.0,0.0,0.949367089,0.333333333,0.25,0.266666666667,0.0316184351626,0.0566037735849,0.163932249402],
[0.0326086956522,0.2,0.0,0.0380761523046,0.0374531835206,0.000805152979066,0.0281124497992,0.135135135135,0.037037037037,1.0,0.00836820083682,0.0,0.0,0.923076923,0.583333333,0.0,0.0,0.0562700964881,0.0188679245283,0.0491752486057],
[0.0108695652174,0.0,0.0,0.0200400801603,0.00374531835206,0.0,0.0160642570281,0.0540540540541,0.0123456790123,1.0,0.000220215811495,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0188679245283,0.147540499867],
[0.217391304348,0.0,0.0,0.0140280561122,0.295880149813,0.0365002683843,0.0100401606426,0.135135135135,0.123456790123,0.4487534625,0.183880202599,1.0,0.0909090909091,0.0,0.19375,0.0,0.0,0.191961414822,0.188679245283,0.287703974741],
[0.0652173913043,0.0,0.0,0.0160320641283,0.0224719101124,0.00402576489533,0.0140562248996,0.027027027027,0.0740740740741,1.0,0.00132129486897,0.0,0.0,0.0,0.444444444,0.0,0.0,0.0,0.0188679245283,0.147540499867],
[0.0326086956522,0.6,0.0,0.0100200400802,0.0411985018727,0.000268384326355,0.00200803212851,0.108108108108,0.0123456790123,0.25,0.00902884827131,1.0,0.0909090909091,0.971428571,0.75,0.25,0.133333333333,0.0594855305401,0.0566037735849,0.147540499867],
[0.119565217391,0.2,0.0,0.0140280561122,0.0973782771536,0.0,0.0100401606426,0.0540540540541,0.135802469136,0.29,0.0398590618806,1.0,0.0,0.529411765,0.409090909,0.0,0.0,0.0723472668927,0.0188679245283,0.107306205553],
[0.0326086956522,0.2,0.0,0.0100200400802,0.0262172284644,0.000268384326355,0.00200803212851,0.108108108108,0.037037037037,0.25,0.00638625853336,1.0,0.0,0.818181818,0.666666667,0.0,0.0,0.0401929260499,0.0188679245283,0.0983652512615],
[0.173913043478,0.4,0.0,0.0300601202405,0.243445692884,0.020397208803,0.0,0.405405405405,0.16049382716,0.46,0.106364236952,1.0,0.0,0.725490196,0.311111111,0.0,0.0,0.136254019315,0.169811320755,0.230532031043],
[0.163043478261,0.4,0.0,0.0180360721443,0.153558052434,0.0,0.0,0.243243243243,0.185185185185,0.3392857145,0.044924025545,1.0,0.0909090909091,0.725490196,0.225,0.25,0.133333333333,0.0594855305401,0.0377358490566,0.226223848446],
[0.152173913043,0.6,0.0508474576271,0.0220440881764,0.10861423221,0.0228126677402,0.00602409638554,0.216216216216,0.135802469136,0.2884615385,0.0237833076415,1.0,0.0909090909091,0.759259259,0.321428571,0.0,0.0,0.0316949931128,0.0754716981132,0.189692820679],
[0.29347826087,0.4,0.0,0.0160320641283,0.378277153558,0.0421363392378,0.0100401606426,0.0810810810811,0.185185185185,0.4123931625,0.283197533583,0.888888889,0.0909090909091,0.294117647,0.183760684,0.25,0.466666666667,0.220078599537,0.0754716981132,0.163932249402],
[0.0326086956522,0.0,0.0,0.00400801603206,0.0112359550562,0.000805152979066,0.00401606425703,0.0,0.037037037037,0.75,0.000880863245981,0.0,0.0,0.0,0.666666667,0.0,0.0,0.0,0.0188679245283,0.147540499867],
[0.597826086957,0.4,0.135593220339,0.0400801603206,0.397003745318,0.352388620505,0.0160642570281,0.324324324324,0.111111111111,0.4782763535,0.249504514424,1.0,0.181818181818,0.406593407,0.195454545,0.0,0.0,0.0922537270084,0.188679245283,0.273613857004]]
)


# define the RBM model
random_state = 200
model = BernoulliRBM(n_components=10, n_iter=10, random_state=random_state)

# fit the RBM and transform the training data into RBM features
# Each row corresponds to one training sample, each column to one RBM feature (hidden unit).
RBM_feature_data = model.fit_transform(train_data)

print(RBM_feature_data)
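
To make "highly correlated" concrete, the following is a minimal sketch of one way the pairwise correlation between feature columns could be measured. It is an illustration under stated assumptions, not part of the original MWE: the helper name `offdiag_abs_correlation` is hypothetical, Pearson correlation is assumed to be the measure of interest, and `demo` is a small stand-in matrix rather than real RBM output.

```python
import numpy as np


def offdiag_abs_correlation(features):
    """Absolute Pearson correlations between all distinct pairs of columns.

    Hypothetical helper for illustration only.
    """
    features = np.asarray(features, dtype=float)
    # Correlation is undefined for constant columns, so drop them first.
    keep = features.std(axis=0) > 0
    corr = np.atleast_2d(np.corrcoef(features[:, keep], rowvar=False))
    # Use the strict upper triangle so that each pair is counted once.
    rows, cols = np.triu_indices_from(corr, k=1)
    return np.abs(corr[rows, cols])


# Stand-in data (NOT real RBM output): column 1 is exactly 2 * column 0,
# while column 2 is only loosely related to the other two.
demo = np.array([
    [0.1, 0.2, 0.9],
    [0.2, 0.4, 0.1],
    [0.3, 0.6, 0.5],
    [0.4, 0.8, 0.3],
])

pairwise = offdiag_abs_correlation(demo)
print("max |r|:", pairwise.max())
print("mean |r|:", pairwise.mean())
```

Calling `offdiag_abs_correlation(RBM_feature_data)` on the matrix printed above would give one number per pair of RBM features; values close to 1 for most pairs would correspond to the behaviour described in this message.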



Thank you!

Masanari Kondo