@@ -18,9 +18,13 @@ use std;
18
18
use std:: ffi:: CStr ;
19
19
use std:: ffi:: CString ;
20
20
use std:: ffi:: NulError ;
21
+ use std:: fmt;
22
+ use std:: fmt:: Display ;
23
+ use std:: fmt:: Formatter ;
21
24
use std:: os:: raw:: c_void as std_c_void;
22
25
use std:: ptr;
23
26
use std:: slice;
27
+ use std:: str:: FromStr ;
24
28
use std:: str:: Utf8Error ;
25
29
use std:: sync:: Arc ;
26
30
use tensorflow_sys as tf;
@@ -1724,6 +1728,50 @@ impl Output {
1724
1728
} )
1725
1729
}
1726
1730
}
1731
+
1732
+ /// Returns the name of this output.
1733
+ pub fn name ( & self ) -> Result < OutputName > {
1734
+ Ok ( OutputName {
1735
+ name : self . operation . name ( ) ?,
1736
+ index : self . index ,
1737
+ } )
1738
+ }
1739
+ }
1740
+
1741
+ ////////////////////////
1742
+
1743
+ /// Names a specific Output in the graph.
1744
+ #[ derive( Clone , PartialEq , Eq , Hash , Debug , Default ) ]
1745
+ pub struct OutputName {
1746
+ /// Name of the operation the edge connects to.
1747
+ pub name : String ,
1748
+
1749
+ /// Index into either the outputs of the operation.
1750
+ pub index : c_int ,
1751
+ }
1752
+
1753
+ impl FromStr for OutputName {
1754
+ type Err = Status ;
1755
+ fn from_str ( s : & str ) -> Result < Self > {
1756
+ let splits: Vec < _ > = s. split ( ':' ) . collect ( ) ;
1757
+ if splits. len ( ) != 2 {
1758
+ return Err ( Status :: new_set_lossy (
1759
+ Code :: InvalidArgument ,
1760
+ "Name must contain exactly one colon (':')" ,
1761
+ ) ) ;
1762
+ }
1763
+ let index = splits[ 1 ] . parse :: < c_int > ( ) ?;
1764
+ Ok ( Self {
1765
+ name : splits[ 0 ] . to_string ( ) ,
1766
+ index,
1767
+ } )
1768
+ }
1769
+ }
1770
+
1771
+ impl Display for OutputName {
1772
+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> fmt:: Result {
1773
+ write ! ( f, "{}:{}" , self . name, self . index)
1774
+ }
1727
1775
}
1728
1776
1729
1777
////////////////////////
@@ -2847,4 +2895,26 @@ mod tests {
2847
2895
assert_eq ! ( consumers[ 0 ] . 0 . name( ) . unwrap( ) , "y" ) ;
2848
2896
assert_eq ! ( consumers[ 0 ] . 1 , 0 ) ;
2849
2897
}
2898
+
2899
+ #[ test]
2900
+ fn output_name ( ) {
2901
+ assert_eq ! (
2902
+ "foo:1" . parse:: <OutputName >( ) . unwrap( ) ,
2903
+ OutputName {
2904
+ name: "foo" . to_string( ) ,
2905
+ index: 1
2906
+ }
2907
+ ) ;
2908
+ assert_eq ! (
2909
+ OutputName {
2910
+ name: "foo" . to_string( ) ,
2911
+ index: 1
2912
+ }
2913
+ . to_string( ) ,
2914
+ "foo:1"
2915
+ ) ;
2916
+ assert ! ( "foo" . parse:: <OutputName >( ) . is_err( ) ) ;
2917
+ assert ! ( "foo:bar" . parse:: <OutputName >( ) . is_err( ) ) ;
2918
+ assert ! ( "foo:0:1" . parse:: <OutputName >( ) . is_err( ) ) ;
2919
+ }
2850
2920
}
0 commit comments