|
34 | 34 | from tensorflow.python.ops import nn |
35 | 35 | from tensorflow.python.ops import state_ops |
36 | 36 | from tensorflow.python.ops import variable_scope |
| 37 | +from tensorflow.python.ops import weights_broadcast_ops |
37 | 38 | from tensorflow.python.util.deprecation import deprecated |
38 | 39 |
|
39 | 40 |
|
@@ -651,7 +652,7 @@ def _streaming_confusion_matrix_at_thresholds( |
651 | 652 | label_is_neg = math_ops.logical_not(label_is_pos) |
652 | 653 |
|
653 | 654 | if weights is not None: |
654 | | - broadcast_weights = _broadcast_weights( |
| 655 | + broadcast_weights = weights_broadcast_ops.broadcast_weights( |
655 | 656 | math_ops.to_float(weights), predictions) |
656 | 657 | weights_tiled = array_ops.tile(array_ops.reshape( |
657 | 658 | broadcast_weights, [1, -1]), [num_thresholds, 1]) |
@@ -1924,7 +1925,7 @@ def streaming_covariance(predictions, |
1924 | 1925 | weighted_predictions = predictions |
1925 | 1926 | weighted_labels = labels |
1926 | 1927 | else: |
1927 | | - weights = _broadcast_weights(weights, labels) |
| 1928 | + weights = weights_broadcast_ops.broadcast_weights(weights, labels) |
1928 | 1929 | batch_count = math_ops.reduce_sum(weights) # n_B in eqn |
1929 | 1930 | weighted_predictions = math_ops.multiply(predictions, weights) |
1930 | 1931 | weighted_labels = math_ops.multiply(labels, weights) |
@@ -2051,7 +2052,7 @@ def streaming_pearson_correlation(predictions, |
2051 | 2052 | # Broadcast weights here to avoid duplicate broadcasting in each call to |
2052 | 2053 | # `streaming_covariance`. |
2053 | 2054 | if weights is not None: |
2054 | | - weights = _broadcast_weights(weights, labels) |
| 2055 | + weights = weights_broadcast_ops.broadcast_weights(weights, labels) |
2055 | 2056 | cov, update_cov = streaming_covariance( |
2056 | 2057 | predictions, labels, weights=weights, name='covariance') |
2057 | 2058 | var_predictions, update_var_predictions = streaming_covariance( |
|
0 commit comments