@@ -142,6 +142,10 @@ cdef class Criterion:
142142
143143 pass
144144
145+ cdef void children_af(self , double * impurity_left,
146+ double * impurity_right) nogil:
147+ pass
148+
145149 cdef void node_value(self , double * dest) nogil:
146150 """ Placeholder for storing the node value.
147151
@@ -176,6 +180,7 @@ cdef class Criterion:
176180
177181 cdef double impurity_improvement(self , double impurity) nogil:
178182 """ Compute the improvement in impurity
183+ gain
179184
180185 This method computes the improvement in impurity when a split occurs.
181186 The weighted impurity improvement equation is the following:
@@ -199,14 +204,17 @@ cdef class Criterion:
199204
200205 cdef double impurity_left
201206 cdef double impurity_right
207+ cdef double af_left
208+ cdef double af_right
202209
203210 self .children_impurity(& impurity_left, & impurity_right)
211+ self .children_af(& af_left,& af_right)
204212
205213 return ((self .weighted_n_node_samples / self .weighted_n_samples) *
206214 (impurity - (self .weighted_n_right /
207- self .weighted_n_node_samples * impurity_right)
215+ self .weighted_n_node_samples * impurity_right * af_right )
208216 - (self .weighted_n_left /
209- self .weighted_n_node_samples * impurity_left)))
217+ self .weighted_n_node_samples * impurity_left * af_left )))
210218
211219
212220cdef class ClassificationCriterion(Criterion):
@@ -597,6 +605,53 @@ cdef class Entropy(ClassificationCriterion):
597605 impurity_left[0 ] = entropy_left / self .n_outputs
598606 impurity_right[0 ] = entropy_right / self .n_outputs
599607
608+ cdef void children_af(self , double * impurity_left,
609+ double * impurity_right) nogil:
610+ """ Evaluate the impurity in children nodes
611+
612+ i.e. the impurity of the left child (samples[start:pos]) and the
613+ impurity the right child (samples[pos:end]).
614+
615+ Parameters
616+ ----------
617+ impurity_left : double pointer
618+ The memory address to save the impurity of the left node
619+ impurity_right : double pointer
620+ The memory address to save the impurity of the right node
621+ """
622+
623+ cdef SIZE_t* n_classes = self .n_classes
624+ cdef double * sum_left = self .sum_left
625+ cdef double * sum_right = self .sum_right
626+ cdef double entropy_left = 0.0
627+ cdef double entropy_right = 0.0
628+ cdef double af_left = 0.0
629+ cdef double af_right = 0.0
630+ cdef double count_k
631+ cdef double two = 2
632+ cdef double one = 1
633+ cdef SIZE_t k
634+ cdef SIZE_t c
635+
636+ for k in range (self .n_outputs):
637+ for c in range (n_classes[k]):
638+ count_k = sum_left[c]
639+ if count_k > 0.0 :
640+ count_k /= self .weighted_n_left
641+ entropy_left -= count_k * log(count_k)
642+ af_left = fabs(count_k* two - one)
643+
644+ count_k = sum_right[c]
645+ if count_k > 0.0 :
646+ count_k /= self .weighted_n_right
647+ entropy_right -= count_k * log(count_k)
648+ af_right = fabs(count_k* two - one)
649+
650+ sum_left += self .sum_stride
651+ sum_right += self .sum_stride
652+
653+ impurity_left[0 ] = af_left / self .n_outputs
654+ impurity_right[0 ] = af_right / self .n_outputs
600655
601656cdef class Gini(ClassificationCriterion):
602657 """ Gini Index impurity criterion.
0 commit comments