Skip to content

Commit 2b2b5f5

Browse files
committed
Add enough implementation for a smoke test with Rust types
1 parent a8bdf88 commit 2b2b5f5

File tree

2 files changed

+194
-11
lines changed

2 files changed

+194
-11
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ version = "0.1.0"
44
authors = ["Adam Crume <[email protected]>"]
55

66
[dependencies]
7+
libc = "0.1"
78
libtensorflow-sys = { path = "libtensorflow-sys" }

src/lib.rs

Lines changed: 193 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,197 @@
1+
extern crate libc;
12
extern crate libtensorflow_sys;
23

3-
#[test]
4-
fn smoke() {
5-
use libtensorflow_sys::*;
6-
7-
unsafe {
8-
let session_options = TF_NewSessionOptions();
9-
let status = TF_NewStatus();
10-
let session = TF_NewSession(session_options, status);
11-
TF_DeleteSession(session, status);
12-
TF_DeleteStatus(status);
13-
TF_DeleteSessionOptions(session_options);
4+
use std::ffi::CStr;
5+
use std::fmt;
6+
use std::fmt::Display;
7+
use std::fmt::Formatter;
8+
use std::ops::Drop;
9+
10+
use libtensorflow_sys as tf;
11+
12+
////////////////////////
13+
14+
fn check_not_null<T>(p: *mut T) -> *mut T {
15+
assert!(!p.is_null());
16+
p
17+
}
18+
19+
////////////////////////
20+
21+
macro_rules! impl_new {
22+
($name: ident, $call:ident) => {
23+
impl $name {
24+
pub fn new() -> Self {
25+
unsafe {
26+
$name {
27+
inner: check_not_null(tf::$call()),
28+
}
29+
}
30+
}
31+
}
32+
}
33+
}
34+
35+
////////////////////////
36+
37+
macro_rules! impl_drop {
38+
($name: ident, $call:ident) => {
39+
impl Drop for $name {
40+
fn drop(&mut self) {
41+
unsafe {
42+
tf::$call(self.inner);
43+
}
44+
}
45+
}
46+
}
47+
}
48+
49+
////////////////////////
50+
51+
macro_rules! c_enum {
52+
($enum_name:ident { $($name:ident = $num:expr),* }) => {
53+
#[derive(PartialEq,Eq,PartialOrd,Ord,Debug)]
54+
pub enum $enum_name {
55+
UnrecognizedEnumValue(::libc::c_uint),
56+
$($name),*
57+
}
58+
59+
impl $enum_name {
60+
#[allow(dead_code)]
61+
fn from_int(value: ::libc::c_uint) -> $enum_name {
62+
match value {
63+
$($num => $enum_name::$name,)*
64+
c => $enum_name::UnrecognizedEnumValue(c),
65+
}
66+
}
67+
68+
#[allow(dead_code)]
69+
fn to_int(&self) -> ::libc::c_uint {
70+
match self {
71+
&$enum_name::UnrecognizedEnumValue(c) => c,
72+
$(&$enum_name::$name => $num),*
73+
}
74+
}
75+
}
76+
77+
impl ::std::fmt::Display for $enum_name {
78+
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
79+
match self {
80+
$(&$enum_name::$name => f.write_str(stringify!($name)),)*
81+
&$enum_name::UnrecognizedEnumValue(c) => write!(f, "UnrecognizedEnumValue({})", c),
82+
}
83+
}
84+
}
85+
};
86+
($enum_name:ident { $($name:ident = $num:expr,)* }) => {
87+
c_enum!($enum_name { $($name = $num),* });
88+
}
89+
}
90+
91+
////////////////////////
92+
93+
c_enum!(Code {
94+
Ok = 0,
95+
Cancelled = 1,
96+
Unknown = 2,
97+
InvalidArgument = 3,
98+
DeadlineExceeded = 4,
99+
NotFound = 5,
100+
AlreadyExists = 6,
101+
PermissionDenied = 7,
102+
ResourceExhausted = 8,
103+
FailedPrecondition = 9,
104+
Aborted = 10,
105+
OutOfRange = 11,
106+
Unimplemented = 12,
107+
Internal = 13,
108+
Unavailable = 14,
109+
DataLoss = 15,
110+
Unauthenticated = 16,
111+
});
112+
113+
////////////////////////
114+
115+
pub struct Status {
116+
inner: *mut tf::TF_Status,
117+
}
118+
119+
impl_new!(Status, TF_NewStatus);
120+
impl_drop!(Status, TF_DeleteStatus);
121+
122+
impl Status {
123+
pub fn code(&self) -> Code {
124+
unsafe {
125+
Code::from_int(tf::TF_GetCode(self.inner))
126+
}
127+
}
128+
}
129+
130+
impl Display for Status {
131+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
132+
unsafe {
133+
try!(write!(f, "{}: ", self.code()));
134+
let msg = match CStr::from_ptr(tf::TF_Message(self.inner)).to_str() {
135+
Ok(s) => s,
136+
Err(_) => "<invalid UTF-8 in message>",
137+
};
138+
f.write_str(msg)
139+
}
140+
}
141+
}
142+
143+
////////////////////////
144+
145+
pub struct SessionOptions {
146+
inner: *mut tf::TF_SessionOptions,
147+
}
148+
149+
impl_new!(SessionOptions, TF_NewSessionOptions);
150+
impl_drop!(SessionOptions, TF_DeleteSessionOptions);
151+
152+
////////////////////////
153+
154+
pub struct Session {
155+
inner: *mut tf::TF_Session,
156+
}
157+
158+
impl Session {
159+
pub fn new(options: &SessionOptions) -> (Option<Self>, Status) {
160+
let status = Status::new();
161+
let inner = unsafe { tf::TF_NewSession(options.inner, status.inner) };
162+
let session = if inner.is_null() {
163+
None
164+
} else {
165+
Some(Session {
166+
inner: inner,
167+
})
168+
};
169+
(session, status)
170+
}
171+
}
172+
173+
impl Drop for Session {
174+
fn drop(&mut self) {
175+
let status = Status::new();
176+
unsafe {
177+
tf::TF_DeleteSession(self.inner, status.inner);
178+
}
179+
// TODO: What do we do with the status?
180+
}
181+
}
182+
183+
////////////////////////
184+
185+
#[cfg(test)]
186+
mod tests {
187+
use super::*;
188+
189+
#[test]
190+
fn smoke() {
191+
let options = SessionOptions::new();
192+
match Session::new(&options) {
193+
(Some(_), status) => assert_eq!(status.code(), Code::Ok),
194+
(None, status) => panic!("Creating session failed with status: {}", status),
195+
}
14196
}
15197
}

0 commit comments

Comments
 (0)