Skip to content

Commit 15ee7b7

Browse files
committed
Add OutputName
1 parent ab12c39 commit 15ee7b7

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

src/graph.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@ use std;
1818
use std::ffi::CStr;
1919
use std::ffi::CString;
2020
use std::ffi::NulError;
21+
use std::fmt;
22+
use std::fmt::Display;
23+
use std::fmt::Formatter;
2124
use std::os::raw::c_void as std_c_void;
2225
use std::ptr;
2326
use std::slice;
27+
use std::str::FromStr;
2428
use std::str::Utf8Error;
2529
use std::sync::Arc;
2630
use tensorflow_sys as tf;
@@ -1724,6 +1728,50 @@ impl Output {
17241728
})
17251729
}
17261730
}
1731+
1732+
/// Returns the name of this output.
1733+
pub fn name(&self) -> Result<OutputName> {
1734+
Ok(OutputName {
1735+
name: self.operation.name()?,
1736+
index: self.index,
1737+
})
1738+
}
1739+
}
1740+
1741+
////////////////////////
1742+
1743+
/// Names a specific Output in the graph.
1744+
#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
1745+
pub struct OutputName {
1746+
/// Name of the operation the edge connects to.
1747+
pub name: String,
1748+
1749+
/// Index into either the outputs of the operation.
1750+
pub index: c_int,
1751+
}
1752+
1753+
impl FromStr for OutputName {
1754+
type Err = Status;
1755+
fn from_str(s: &str) -> Result<Self> {
1756+
let splits: Vec<_> = s.split(':').collect();
1757+
if splits.len() != 2 {
1758+
return Err(Status::new_set_lossy(
1759+
Code::InvalidArgument,
1760+
"Name must contain exactly one colon (':')",
1761+
));
1762+
}
1763+
let index = splits[1].parse::<c_int>()?;
1764+
Ok(Self {
1765+
name: splits[0].to_string(),
1766+
index,
1767+
})
1768+
}
1769+
}
1770+
1771+
impl Display for OutputName {
1772+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1773+
write!(f, "{}:{}", self.name, self.index)
1774+
}
17271775
}
17281776

17291777
////////////////////////
@@ -2847,4 +2895,26 @@ mod tests {
28472895
assert_eq!(consumers[0].0.name().unwrap(), "y");
28482896
assert_eq!(consumers[0].1, 0);
28492897
}
2898+
2899+
#[test]
2900+
fn output_name() {
2901+
assert_eq!(
2902+
"foo:1".parse::<OutputName>().unwrap(),
2903+
OutputName {
2904+
name: "foo".to_string(),
2905+
index: 1
2906+
}
2907+
);
2908+
assert_eq!(
2909+
OutputName {
2910+
name: "foo".to_string(),
2911+
index: 1
2912+
}
2913+
.to_string(),
2914+
"foo:1"
2915+
);
2916+
assert!("foo".parse::<OutputName>().is_err());
2917+
assert!("foo:bar".parse::<OutputName>().is_err());
2918+
assert!("foo:0:1".parse::<OutputName>().is_err());
2919+
}
28502920
}

src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use std::fmt::Display;
3232
use std::fmt::Formatter;
3333
use std::marker::PhantomData;
3434
use std::mem;
35+
use std::num::ParseIntError;
3536
use std::ops::Deref;
3637
use std::ops::DerefMut;
3738
use std::ops::Drop;
@@ -528,6 +529,12 @@ impl From<IntoStringError> for Status {
528529
}
529530
}
530531

532+
impl From<ParseIntError> for Status {
533+
fn from(e: ParseIntError) -> Self {
534+
invalid_arg!("Error parsing an integer: {}", e.description())
535+
}
536+
}
537+
531538
impl Error for Status {
532539
fn description(&self) -> &str {
533540
unsafe {

0 commit comments

Comments
 (0)