Skip to content

Commit df849b7

Browse files
alanyeeyifeif
authored andcommitted
Update metrics_op.py (tensorflow#12586)
Using core method over private
1 parent 92e7793 commit df849b7

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tensorflow/contrib/metrics/python/ops/metric_ops.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from tensorflow.python.ops import nn
3535
from tensorflow.python.ops import state_ops
3636
from tensorflow.python.ops import variable_scope
37+
from tensorflow.python.ops import weights_broadcast_ops
3738
from tensorflow.python.util.deprecation import deprecated
3839

3940

@@ -651,7 +652,7 @@ def _streaming_confusion_matrix_at_thresholds(
651652
label_is_neg = math_ops.logical_not(label_is_pos)
652653

653654
if weights is not None:
654-
broadcast_weights = _broadcast_weights(
655+
broadcast_weights = weights_broadcast_ops.broadcast_weights(
655656
math_ops.to_float(weights), predictions)
656657
weights_tiled = array_ops.tile(array_ops.reshape(
657658
broadcast_weights, [1, -1]), [num_thresholds, 1])
@@ -1924,7 +1925,7 @@ def streaming_covariance(predictions,
19241925
weighted_predictions = predictions
19251926
weighted_labels = labels
19261927
else:
1927-
weights = _broadcast_weights(weights, labels)
1928+
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
19281929
batch_count = math_ops.reduce_sum(weights) # n_B in eqn
19291930
weighted_predictions = math_ops.multiply(predictions, weights)
19301931
weighted_labels = math_ops.multiply(labels, weights)
@@ -2051,7 +2052,7 @@ def streaming_pearson_correlation(predictions,
20512052
# Broadcast weights here to avoid duplicate broadcasting in each call to
20522053
# `streaming_covariance`.
20532054
if weights is not None:
2054-
weights = _broadcast_weights(weights, labels)
2055+
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
20552056
cov, update_cov = streaming_covariance(
20562057
predictions, labels, weights=weights, name='covariance')
20572058
var_predictions, update_var_predictions = streaming_covariance(

0 commit comments

Comments
 (0)