@@ -35,6 +35,7 @@ struct GraphLifetime;
35
35
#[ derive( Debug ) ]
36
36
struct GraphImpl {
37
37
inner : * mut tf:: TF_Graph ,
38
+ owned : bool ,
38
39
}
39
40
40
41
unsafe impl Send for GraphImpl { }
@@ -43,8 +44,10 @@ unsafe impl Sync for GraphImpl {}
43
44
impl Drop for GraphImpl {
44
45
/// Graph will be deleted once no more Sessions are referencing it.
45
46
fn drop ( & mut self ) {
46
- unsafe {
47
- tf:: TF_DeleteGraph ( self . inner ) ;
47
+ if self . owned {
48
+ unsafe {
49
+ tf:: TF_DeleteGraph ( self . inner ) ;
50
+ }
48
51
}
49
52
}
50
53
}
@@ -273,7 +276,10 @@ impl Graph {
273
276
pub fn new ( ) -> Graph {
274
277
unsafe {
275
278
Graph {
276
- gimpl : Arc :: new ( GraphImpl { inner : tf:: TF_NewGraph ( ) } ) ,
279
+ gimpl : Arc :: new ( GraphImpl {
280
+ inner : tf:: TF_NewGraph ( ) ,
281
+ owned : true ,
282
+ } ) ,
277
283
lifetime : GraphLifetime ,
278
284
}
279
285
}
@@ -333,6 +339,32 @@ impl Graph {
333
339
}
334
340
}
335
341
342
+ /// Finds a unique operation name. The pattern must contain exactly one
343
+ /// '{}' placeholder to indicate where a unique ID can be inserted, e.g.
344
+ /// 'Add_{}' or 'while_loop_{}/Merge', and the function returns an integer
345
+ /// which, when inserted into the placeholder, yields an operation name
346
+ /// which does not appear in the graph.
347
+ pub ( crate ) fn generate_operation_name ( & self , operation_name_pattern : & str ) -> Result < i64 > {
348
+ let parts: Vec < _ > = operation_name_pattern. split ( "{}" ) . collect ( ) ;
349
+ if parts. len ( ) != 2 {
350
+ return Err ( invalid_arg ! (
351
+ "operation_name_pattern must contain placeholder"
352
+ ) ) ;
353
+ }
354
+ // Can't use format! because its argument must be a string literal.
355
+ let mut i = 0 ;
356
+ loop {
357
+ let name = format ! ( "{}{}{}" , parts[ 0 ] , i, parts[ 1 ] ) ;
358
+ let c_name = CString :: new ( name) ?;
359
+ unsafe {
360
+ if tf:: TF_GraphOperationByName ( self . gimpl . inner , c_name. as_ptr ( ) ) . is_null ( ) {
361
+ return Ok ( i) ;
362
+ }
363
+ }
364
+ i += 1 ;
365
+ }
366
+ }
367
+
336
368
/// Iterates over the operations in the graph.
337
369
pub fn operation_iter ( & self ) -> OperationIter {
338
370
OperationIter {
@@ -717,6 +749,16 @@ impl GraphTrait for Graph {
717
749
fn inner ( & self ) -> * mut tf:: TF_Graph {
718
750
self . gimpl . inner
719
751
}
752
+
753
+ unsafe fn from_c ( inner : * mut tf:: TF_Graph ) -> Self {
754
+ Graph {
755
+ gimpl : Arc :: new ( GraphImpl {
756
+ inner,
757
+ owned : false ,
758
+ } ) ,
759
+ lifetime : GraphLifetime ,
760
+ }
761
+ }
720
762
}
721
763
722
764
////////////////////////
@@ -1523,14 +1565,14 @@ pub struct Output {
1523
1565
}
1524
1566
1525
1567
impl Output {
1526
- fn to_c ( & self ) -> tf:: TF_Output {
1568
+ pub ( crate ) fn to_c ( & self ) -> tf:: TF_Output {
1527
1569
tf:: TF_Output {
1528
1570
oper : self . operation . inner ,
1529
1571
index : self . index ,
1530
1572
}
1531
1573
}
1532
1574
1533
- fn from_c ( graph : & Graph , output : & tf:: TF_Output ) -> Self {
1575
+ pub ( crate ) fn from_c ( graph : & Graph , output : & tf:: TF_Output ) -> Self {
1534
1576
Output {
1535
1577
operation : Operation {
1536
1578
inner : output. oper ,
@@ -2479,4 +2521,17 @@ mod tests {
2479
2521
// We don't want to compare the actual proto because it may change across releases.
2480
2522
assert ! ( g. versions( ) . unwrap( ) . len( ) > 0 ) ;
2481
2523
}
2524
+
2525
+ #[ test]
2526
+ fn graph_generate_operation_name ( ) {
2527
+ let mut g = Graph :: new ( ) ;
2528
+ for i in 0 ..5 {
2529
+ assert_eq ! ( i, g. generate_operation_name( "foo_{}" ) . unwrap( ) ) ;
2530
+ let mut nd = g. new_operation ( "Placeholder" , & format ! ( "foo_{}" , i) )
2531
+ . unwrap ( ) ;
2532
+ nd. set_attr_type ( "dtype" , DataType :: Float ) . unwrap ( ) ;
2533
+ nd. set_attr_shape ( "shape" , & Shape ( Some ( vec ! [ ] ) ) ) . unwrap ( ) ;
2534
+ nd. finish ( ) . unwrap ( ) ;
2535
+ }
2536
+ }
2482
2537
}
0 commit comments