Skip to content

Commit 9b5010a

Browse files
authored
Merge pull request tensorflow#160 from adamcrume/master
Add support for while loops
2 parents 657b9c8 + 3dac89f commit 9b5010a

File tree

3 files changed

+365
-7
lines changed

3 files changed

+365
-7
lines changed

src/graph.rs

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct GraphLifetime;
3535
#[derive(Debug)]
3636
struct GraphImpl {
3737
inner: *mut tf::TF_Graph,
38+
owned: bool,
3839
}
3940

4041
unsafe impl Send for GraphImpl {}
@@ -43,8 +44,10 @@ unsafe impl Sync for GraphImpl {}
4344
impl Drop for GraphImpl {
4445
/// Graph will be deleted once no more Sessions are referencing it.
4546
fn drop(&mut self) {
46-
unsafe {
47-
tf::TF_DeleteGraph(self.inner);
47+
if self.owned {
48+
unsafe {
49+
tf::TF_DeleteGraph(self.inner);
50+
}
4851
}
4952
}
5053
}
@@ -273,7 +276,10 @@ impl Graph {
273276
pub fn new() -> Graph {
274277
unsafe {
275278
Graph {
276-
gimpl: Arc::new(GraphImpl { inner: tf::TF_NewGraph() }),
279+
gimpl: Arc::new(GraphImpl {
280+
inner: tf::TF_NewGraph(),
281+
owned: true,
282+
}),
277283
lifetime: GraphLifetime,
278284
}
279285
}
@@ -333,6 +339,32 @@ impl Graph {
333339
}
334340
}
335341

342+
/// Finds a unique operation name. The pattern must contain exactly one
343+
/// '{}' placeholder to indicate where a unique ID can be inserted, e.g.
344+
/// 'Add_{}' or 'while_loop_{}/Merge', and the function returns an integer
345+
/// which, when inserted into the placeholder, yields an operation name
346+
/// which does not appear in the graph.
347+
pub(crate) fn generate_operation_name(&self, operation_name_pattern: &str) -> Result<i64> {
348+
let parts: Vec<_> = operation_name_pattern.split("{}").collect();
349+
if parts.len() != 2 {
350+
return Err(invalid_arg!(
351+
"operation_name_pattern must contain placeholder"
352+
));
353+
}
354+
// Can't use format! because its argument must be a string literal.
355+
let mut i = 0;
356+
loop {
357+
let name = format!("{}{}{}", parts[0], i, parts[1]);
358+
let c_name = CString::new(name)?;
359+
unsafe {
360+
if tf::TF_GraphOperationByName(self.gimpl.inner, c_name.as_ptr()).is_null() {
361+
return Ok(i);
362+
}
363+
}
364+
i += 1;
365+
}
366+
}
367+
336368
/// Iterates over the operations in the graph.
337369
pub fn operation_iter(&self) -> OperationIter {
338370
OperationIter {
@@ -717,6 +749,16 @@ impl GraphTrait for Graph {
717749
fn inner(&self) -> *mut tf::TF_Graph {
718750
self.gimpl.inner
719751
}
752+
753+
unsafe fn from_c(inner: *mut tf::TF_Graph) -> Self {
754+
Graph {
755+
gimpl: Arc::new(GraphImpl {
756+
inner,
757+
owned: false,
758+
}),
759+
lifetime: GraphLifetime,
760+
}
761+
}
720762
}
721763

722764
////////////////////////
@@ -1523,14 +1565,14 @@ pub struct Output {
15231565
}
15241566

15251567
impl Output {
1526-
fn to_c(&self) -> tf::TF_Output {
1568+
pub(crate) fn to_c(&self) -> tf::TF_Output {
15271569
tf::TF_Output {
15281570
oper: self.operation.inner,
15291571
index: self.index,
15301572
}
15311573
}
15321574

1533-
fn from_c(graph: &Graph, output: &tf::TF_Output) -> Self {
1575+
pub(crate) fn from_c(graph: &Graph, output: &tf::TF_Output) -> Self {
15341576
Output {
15351577
operation: Operation {
15361578
inner: output.oper,
@@ -2479,4 +2521,17 @@ mod tests {
24792521
// We don't want to compare the actual proto because it may change across releases.
24802522
assert!(g.versions().unwrap().len() > 0);
24812523
}
2524+
2525+
#[test]
2526+
fn graph_generate_operation_name() {
2527+
let mut g = Graph::new();
2528+
for i in 0..5 {
2529+
assert_eq!(i, g.generate_operation_name("foo_{}").unwrap());
2530+
let mut nd = g.new_operation("Placeholder", &format!("foo_{}", i))
2531+
.unwrap();
2532+
nd.set_attr_type("dtype", DataType::Float).unwrap();
2533+
nd.set_attr_shape("shape", &Shape(Some(vec![]))).unwrap();
2534+
nd.finish().unwrap();
2535+
}
2536+
}
24822537
}

src/lib.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ use std::str::Utf8Error;
4747
/// Will panic if `msg` contains an embedded 0 byte.
4848
macro_rules! invalid_arg {
4949
($fmt:expr) => {
50-
Status::new_set(Code::InvalidArgument, $fmt).unwrap()
50+
::Status::new_set(::Code::InvalidArgument, $fmt).unwrap()
5151
};
5252
($fmt:expr, $($arg:tt)*) => ({
5353
let msg = format!($fmt, $($arg)*);
54-
Status::new_set(Code::InvalidArgument, &msg).unwrap()
54+
::Status::new_set(::Code::InvalidArgument, &msg).unwrap()
5555
});
5656
}
5757

@@ -1240,6 +1240,8 @@ impl Library {
12401240

12411241
////////////////////////
12421242

1243+
// TODO: Replace these with pub(crate)
1244+
12431245
/// This exposes Buffer behavior without making it public.
12441246
trait BufferTrait {
12451247
fn is_owned(&self) -> bool;
@@ -1251,6 +1253,7 @@ trait BufferTrait {
12511253
/// This exposes Graph behavior without making it public.
12521254
trait GraphTrait {
12531255
fn inner(&self) -> *mut tf::TF_Graph;
1256+
unsafe fn from_c(inner: *mut tf::TF_Graph) -> Self;
12541257
}
12551258

12561259

@@ -1323,6 +1326,11 @@ impl Display for Shape {
13231326

13241327
////////////////////////
13251328

1329+
mod while_loop;
1330+
pub use while_loop::*;
1331+
1332+
////////////////////////
1333+
13261334
#[cfg(test)]
13271335
mod tests {
13281336
use super::*;

0 commit comments

Comments
 (0)