[scikit-learn] Understanding sklearn.tree._tree.value object
Mon Oct 8 17:31:43 EDT 2018
Hi Pranav,
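You're getting that output because the first column of your binarized y
contains a single value (1). That value is the only class for that
output, so it becomes the "first" class, and its count (7) lands in the
first position of the row you're interpreting. For the other two columns
the classes are 0 and 1, in that order, so each row reads
[count of 0, count of 1].
To see this more clearly, try the following code: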
>>> from sklearn.preprocessing import MultiLabelBinarizer
>>> from sklearn.tree import DecisionTreeClassifier
>>>
>>> X = [[2, 51], [3, 20], [5, 30], [7, 1], [20, 46], [25, 25], [45, 70]]
>>> Y = [[2,3],[1,2,3],[1,2,3],[1,2],[1,2],[1],[1]]
>>>
>>> y = MultiLabelBinarizer().fit_transform(Y) + 40
>>> y[0, 1] = 0
>>>
>>> clf = DecisionTreeClassifier().fit(X, y)
>>> print(clf.tree_.value)
[[[1. 6. 0.]
  [1. 2. 4.]
  [4. 3. 0.]]

 [[1. 2. 0.]
  [1. 0. 2.]
  [0. 3. 0.]]

 [[0. 2. 0.]
  [0. 0. 2.]
  [0. 2. 0.]]

 [[1. 0. 0.]
  [1. 0. 0.]
  [0. 1. 0.]]

 [[0. 4. 0.]
  [0. 2. 2.]
  [4. 0. 0.]]

 [[0. 2. 0.]
  [0. 0. 2.]
  [2. 0. 0.]]

 [[0. 2. 0.]
  [0. 2. 0.]
  [2. 0. 0.]]]
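
To check the same thing on the original data, a minimal sketch along
these lines prints clf.classes_ next to the root node of tree_.value.
Each row of a node belongs to one output column, each slot in the row
belongs to one class of that output (in sorted order), and rows are
zero-padded to the largest number of classes found in any output. The
names classes_per_output, root and fractions, and the use of
random_state=0, are my own choices for illustration. The rows are
normalized before printing because, as far as I know, newer
scikit-learn releases (around 1.4) store fractions in tree_.value
rather than raw counts, so treat the absolute numbers as
version-dependent.

```python
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.tree import DecisionTreeClassifier

X = [[2, 51], [3, 20], [5, 30], [7, 1], [20, 46], [25, 25], [45, 70]]
Y = [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2], [1, 2], [1], [1]]

y = MultiLabelBinarizer().fit_transform(Y)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# For multi-output problems, classes_ holds one array of labels per
# output column, each in sorted order.
classes_per_output = [[int(c) for c in labels] for labels in clf.classes_]

# Root node: shape (n_outputs, max_n_classes).
root = clf.tree_.value[0]

for k, labels in enumerate(classes_per_output):
    # Only the first len(labels) slots of row k are meaningful;
    # any remaining slots are zero padding.
    row = root[k, :len(labels)]
    # Normalize so the result is the same whether the installed
    # version stores counts or fractions (an assumption, see above).
    fractions = row / row.sum()
    print(f"output {k}: classes {labels} -> fractions {np.round(fractions, 3)}")
```

Output 0 has the single class 1, so all of its weight sits in slot 0 of
its row, which is where the [7., 0.] in the original question comes
from.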
On Mon, 8 Oct 2018 at 20:53 Pranav Ashok <pranavashok at gmail.com> wrote:
> I have a multi-class multi-label decision tree learnt using the
> DecisionTreeClassifier class. The input looks as follows:
>
> X = [[2, 51], [3, 20], [5, 30], [7, 1], [20, 46], [25, 25], [45, 70]]
> Y = [[1,2,3],[1,2,3],[1,2,3],[1,2],[1,2],[1],[1]]
>
> I have used MultiLabelBinarizer to convert Y into
>
> [[1 1 1]
>  [1 1 1]
>  [1 1 1]
>  [1 1 0]
>  [1 1 0]
>  [1 0 0]
>  [1 0 0]]
>
>
> After training, _tree.value looks as follows:
>
> array([[[7., 0.],
>         [2., 5.],
>         [4., 3.]],
>
>        [[3., 0.],
>         [0., 3.],
>         [0., 3.]],
>
>        [[4., 0.],
>         [2., 2.],
>         [4., 0.]],
>
>        [[2., 0.],
>         [0., 2.],
>         [2., 0.]],
>
>        [[2., 0.],
>         [2., 0.],
>         [2., 0.]]])
>
> I had the impression that the value array contains, for each node, a
> list of lists [[n_1, y_1], [n_2, y_2], [n_3, y_3]] such that n_i is the
> number of samples disagreeing with class i and y_i is the number of
> samples agreeing with class i. But after seeing this output, that
> interpretation does not make sense.
>
> For example, the root node has the value [[7,0],[2,5],[4,3]]. According
> to my interpretation, this would mean that 7 samples disagree with
> class 1; 2 disagree with class 2 and 5 agree with class 2; 4 disagree
> with class 3 and 3 agree with class 3.
>
> According to the input dataset, this is wrong.
>
> Could someone please help me understand the semantics of _tree.value for multi-label DTs?
>
