Skip to content

Commit 0db8543

Browse files
committed
Return Result<(), Status> instead of Status
Makes it easier to use standard patterns like the try! macro.
1 parent 7aa45bb commit 0db8543

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/lib.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ impl Status {
176176
self.code() == Code::Ok
177177
}
178178

179+
fn as_result(self) -> Result<()> {
180+
if self.is_ok() {
181+
Ok(())
182+
} else {
183+
Err(self)
184+
}
185+
}
186+
179187
/// Sets the code and message.
180188
pub fn set(&mut self, code: Code, msg: &str) -> std::result::Result<(), NulError> {
181189
let message = try!(CString::new(msg)).as_ptr();
@@ -281,25 +289,25 @@ impl Session {
281289
}
282290

283291
/// Closes the session.
284-
pub fn close(&mut self) -> Status {
292+
pub fn close(&mut self) -> Result<()> {
285293
let status = Status::new();
286294
unsafe {
287295
tf::TF_CloseSession(self.inner, status.inner);
288296
}
289-
status
297+
status.as_result()
290298
}
291299

292300
/// Treat `proto` as a serialized `GraphDef` and add the nodes in that `GraphDef` to the graph for the session.
293-
pub fn extend_graph(&mut self, proto: &[u8]) -> Status {
301+
pub fn extend_graph(&mut self, proto: &[u8]) -> Result<()> {
294302
let status = Status::new();
295303
unsafe {
296304
tf::TF_ExtendGraph(self.inner, proto.as_ptr() as *const raw::c_void, proto.len(), status.inner);
297305
}
298-
status
306+
status.as_result()
299307
}
300308

301309
/// Runs the graph, feeding the inputs and then fetching the outputs requested in the step.
302-
pub fn run(&mut self, step: &mut Step) -> Status {
310+
pub fn run(&mut self, step: &mut Step) -> Result<()> {
303311
// Copy the input tensors because TF_Run consumes them.
304312
let mut input_tensors = Vec::with_capacity(step.input_tensors.len());
305313
for &input_tensor in &step.input_tensors {
@@ -339,7 +347,7 @@ impl Session {
339347
std::ptr::null_mut(),
340348
status.inner);
341349
};
342-
status
350+
status.as_result()
343351
}
344352
}
345353

@@ -709,8 +717,7 @@ mod tests {
709717
let mut step = Step::new();
710718
step.add_input("x:0", &x).unwrap();
711719
let output_ix = step.request_output("y:0").unwrap();
712-
let result = session.run(&mut step);
713-
assert!(result.is_ok(), "{}", result);
720+
session.run(&mut step).unwrap();
714721
let output_tensor = step.take_output::<f32>(output_ix).unwrap();
715722
let data = output_tensor.data();
716723
assert_eq!(data.len(), 2);

0 commit comments

Comments
 (0)