Skip to content

Commit 66ea133

Browse files
committed
Add in an association function
1 parent 194c231 commit 66ea133

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

sklearn/tree/_criterion.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ cdef class Criterion:
6262
cdef double node_impurity(self) nogil
6363
cdef void children_impurity(self, double* impurity_left,
6464
double* impurity_right) nogil
65+
# Writes one multiplier per child through the two pointers; in
# _criterion.pyx, impurity_improvement scales the right and left child
# impurities by these values.
# NOTE(review): the parameters are named impurity_* but carry the
# multipliers, not impurities.
cdef void children_af(self, double* impurity_left,
66+
double* impurity_right) nogil
6567
cdef void node_value(self, double* dest) nogil
6668
cdef double impurity_improvement(self, double impurity) nogil
6769
cdef double proxy_impurity_improvement(self) nogil

sklearn/tree/_criterion.pyx

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ cdef class Criterion:
142142

143143
pass
144144

145+
cdef void children_af(self, double* impurity_left,
                      double* impurity_right) nogil:
    """Placeholder for the association factor of the children nodes.

    impurity_improvement multiplies each child's impurity by the value
    stored here. The default therefore stores the neutral factor 1.0 so
    that criteria which do not override this method keep their plain
    weighted impurity improvement. Leaving the outputs untouched would
    make impurity_improvement read uninitialised local doubles.

    Parameters
    ----------
    impurity_left : double pointer
        The memory address to save the association factor of the
        left node
    impurity_right : double pointer
        The memory address to save the association factor of the
        right node
    """

    impurity_left[0] = 1.0
    impurity_right[0] = 1.0
148+
145149
cdef void node_value(self, double* dest) nogil:
146150
"""Placeholder for storing the node value.
147151
@@ -176,6 +180,7 @@ cdef class Criterion:
176180

177181
cdef double impurity_improvement(self, double impurity) nogil:
178182
"""Compute the improvement in impurity
183+
gain
179184
180185
This method computes the improvement in impurity when a split occurs.
181186
The weighted impurity improvement equation is the following:
@@ -199,14 +204,17 @@ cdef class Criterion:
199204

200205
cdef double impurity_left
201206
cdef double impurity_right
207+
cdef double af_left
208+
cdef double af_right
202209

203210
self.children_impurity(&impurity_left, &impurity_right)
211+
self.children_af(&af_left,&af_right)
204212

205213
return ((self.weighted_n_node_samples / self.weighted_n_samples) *
206214
(impurity - (self.weighted_n_right /
207-
self.weighted_n_node_samples * impurity_right)
215+
self.weighted_n_node_samples * impurity_right * af_right)
208216
- (self.weighted_n_left /
209-
self.weighted_n_node_samples * impurity_left)))
217+
self.weighted_n_node_samples * impurity_left * af_left)))
210218

211219

212220
cdef class ClassificationCriterion(Criterion):
@@ -597,6 +605,53 @@ cdef class Entropy(ClassificationCriterion):
597605
impurity_left[0] = entropy_left / self.n_outputs
598606
impurity_right[0] = entropy_right / self.n_outputs
599607

608+
cdef void children_af(self, double* impurity_left,
                      double* impurity_right) nogil:
    """Evaluate the association factor in children nodes

    i.e. |2 * p - 1| for the left child (samples[start:pos]) and for
    the right child (samples[pos:end]), where p is a class proportion
    inside the child. The per-output factors are averaged over the
    outputs.

    Within one output the factor comes from the last class that has a
    positive weighted count. For two classes |2 * p - 1| is the same
    for both classes, so the choice does not matter.
    NOTE(review): with more than two classes the result depends on the
    class order -- confirm the intended multi-class definition.

    A child with no positive class count gets a factor of 0.0.

    Parameters
    ----------
    impurity_left : double pointer
        The memory address to save the association factor of the
        left node
    impurity_right : double pointer
        The memory address to save the association factor of the
        right node
    """

    cdef SIZE_t* n_classes = self.n_classes
    cdef double* sum_left = self.sum_left
    cdef double* sum_right = self.sum_right
    cdef double af_left_total = 0.0
    cdef double af_right_total = 0.0
    cdef double af_left_k
    cdef double af_right_k
    cdef double count_k
    cdef SIZE_t k
    cdef SIZE_t c

    for k in range(self.n_outputs):
        # One factor per output, so the division by n_outputs below is
        # a real average instead of a rescaling of the last output.
        af_left_k = 0.0
        af_right_k = 0.0

        for c in range(n_classes[k]):
            count_k = sum_left[c]
            if count_k > 0.0:
                count_k /= self.weighted_n_left
                af_left_k = fabs(2.0 * count_k - 1.0)

            count_k = sum_right[c]
            if count_k > 0.0:
                count_k /= self.weighted_n_right
                af_right_k = fabs(2.0 * count_k - 1.0)

        af_left_total += af_left_k
        af_right_total += af_right_k

        # Advance to the class counts of the next output.
        sum_left += self.sum_stride
        sum_right += self.sum_stride

    impurity_left[0] = af_left_total / self.n_outputs
    impurity_right[0] = af_right_total / self.n_outputs
600655

601656
cdef class Gini(ClassificationCriterion):
602657
"""Gini Index impurity criterion.

0 commit comments

Comments
 (0)