Skip to content

Commit a51e3dd

Browse files
committed
Add docs
1 parent db27396 commit a51e3dd

File tree

2 files changed

+68
-10
lines changed

2 files changed

+68
-10
lines changed

src/buffer.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ use std::ops::RangeTo;
1515
use std::slice;
1616

1717
/// Fixed-length heap-allocated vector.
18-
/// This is basically a Box<[T]>, except that that type can't actually be constructed.
19-
/// Furthermore, [T; N] can't be constructed if N is not a compile-time constant.
18+
/// This is basically a `Box<[T]>`, except that that type can't actually be constructed.
19+
/// Furthermore, `[T; N]` can't be constructed if N is not a compile-time constant.
2020
#[derive(Debug)]
2121
pub struct Buffer<T> {
2222
ptr: *mut T,
@@ -25,6 +25,9 @@ pub struct Buffer<T> {
2525
}
2626

2727
impl<T: Default> Buffer<T> {
28+
/// Creates a new buffer initialized to zeros.
29+
///
30+
/// `len` is the number of elements.
2831
pub fn new(len: usize) -> Self {
2932
let mut b = unsafe {
3033
Buffer::new_uninitialized(len)
@@ -38,6 +41,10 @@ impl<T: Default> Buffer<T> {
3841
}
3942

4043
impl<T> Buffer<T> {
44+
/// Creates a new uninitialized buffer.
45+
///
46+
/// `len` is the number of elements.
47+
/// The caller is responsible for initializing the data.
4148
pub unsafe fn new_uninitialized(len: usize) -> Self {
4249
let elem_size = mem::size_of::<T>();
4350
let alloc_size = len * elem_size;
@@ -54,6 +61,10 @@ impl<T> Buffer<T> {
5461
}
5562
}
5663

64+
/// Creates a buffer from data owned by the C API.
65+
///
66+
/// `len` is the number of elements.
67+
/// The underlying data is *not* freed when the buffer is destroyed.
5768
pub unsafe fn from_ptr(ptr: *mut T, len: usize) -> Self {
5869
Buffer {
5970
ptr: ptr,

src/lib.rs

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// -*- indent-tabs-mode:nil; tab-width:2; -*-
2+
//! This crate provides Rust bindings for the [TensorFlow](https://www.tensorflow.org) machine learning library.
23
#![cfg(feature = "tensorflow_unstable")]
34

45
extern crate libc;
@@ -30,8 +31,9 @@ fn check_not_null<T>(p: *mut T) -> *mut T {
3031
////////////////////////
3132

3233
macro_rules! impl_new {
33-
($name: ident, $call:ident) => {
34+
($name: ident, $call:ident, $doc:expr) => {
3435
impl $name {
36+
#[doc = $doc]
3537
pub fn new() -> Self {
3638
unsafe {
3739
$name {
@@ -60,7 +62,8 @@ macro_rules! impl_drop {
6062
////////////////////////
6163

6264
macro_rules! c_enum {
63-
($enum_name:ident { $($name:ident = $num:expr),* }) => {
65+
($doc:expr, $enum_name:ident { $($name:ident = $num:expr),* }) => {
66+
#[doc = $doc]
6467
#[derive(PartialEq,Eq,PartialOrd,Ord,Debug)]
6568
pub enum $enum_name {
6669
UnrecognizedEnumValue(raw::c_uint),
@@ -94,14 +97,14 @@ macro_rules! c_enum {
9497
}
9598
}
9699
};
97-
($enum_name:ident { $($name:ident = $num:expr,)* }) => {
98-
c_enum!($enum_name { $($name = $num),* });
100+
($doc:expr, $enum_name:ident { $($name:ident = $num:expr,)* }) => {
101+
c_enum!($doc, $enum_name { $($name = $num),* });
99102
}
100103
}
101104

102105
////////////////////////
103106

104-
c_enum!(Code {
107+
c_enum!("Error values that can be returned.", Code {
105108
Ok = 0,
106109
Cancelled = 1,
107110
Unknown = 2,
@@ -123,7 +126,7 @@ c_enum!(Code {
123126

124127
////////////////////////
125128

126-
c_enum!(DataType {
129+
c_enum!("Type of a single tensor element.", DataType {
127130
Float = 1,
128131
Double = 2,
129132
Int32 = 3,
@@ -144,30 +147,35 @@ c_enum!(DataType {
144147

145148
////////////////////////
146149

150+
/// Holds error information. It either has an OK code, or else an error code with an associated error message.
147151
pub struct Status {
148152
inner: *mut tf::TF_Status,
149153
}
150154

151-
impl_new!(Status, TF_NewStatus);
155+
impl_new!(Status, TF_NewStatus, "Creates a status with `Code::Ok` and no message.");
152156
impl_drop!(Status, TF_DeleteStatus);
153157

154158
impl Status {
159+
/// Creates a status and sets its code and message.
155160
pub fn new_set(code: Code, msg: &str) -> std::result::Result<Status, NulError> {
156161
let mut status = Status::new();
157162
try!(status.set(code, msg));
158163
Ok(status)
159164
}
160165

166+
/// Returns the status's code.
161167
pub fn code(&self) -> Code {
162168
unsafe {
163169
Code::from_int(tf::TF_GetCode(self.inner) as u32)
164170
}
165171
}
166172

173+
/// Returns true if the status's code is `Code::Ok`.
167174
pub fn is_ok(&self) -> bool {
168175
self.code() == Code::Ok
169176
}
170177

178+
/// Sets the code and message.
171179
pub fn set(&mut self, code: Code, msg: &str) -> std::result::Result<(), NulError> {
172180
let message = try!(CString::new(msg)).as_ptr();
173181
unsafe {
@@ -208,11 +216,20 @@ impl Debug for Status {
208216

209217
////////////////////////
210218

219+
/// Options that can be passed during session creation.
211220
pub struct SessionOptions {
212221
inner: *mut tf::TF_SessionOptions,
213222
}
214223

215224
impl SessionOptions {
225+
/// Set the target.
226+
///
227+
/// `target` can be empty, a single entry, or a comma separated list of entries.
228+
/// Each entry is in one of the following formats :
229+
///
230+
/// - "local"
231+
/// - ip:port
232+
/// - host:port
216233
pub fn set_target(&mut self, target: &str) -> std::result::Result<(), NulError> {
217234
let cstr = try!(CString::new(target));
218235
unsafe {
@@ -221,6 +238,10 @@ impl SessionOptions {
221238
Ok(())
222239
}
223240

241+
/// Set the config.
242+
///
243+
/// `config` should be a serialized brain.ConfigProto proto.
244+
/// Returns an error if config was not parsed successfully as a ConfigProto.
224245
pub fn set_config(&mut self, config: &[u8]) -> Result<()> {
225246
let status = Status::new();
226247
unsafe {
@@ -234,16 +255,18 @@ impl SessionOptions {
234255
}
235256
}
236257

237-
impl_new!(SessionOptions, TF_NewSessionOptions);
258+
impl_new!(SessionOptions, TF_NewSessionOptions, "Creates a blank set of options.");
238259
impl_drop!(SessionOptions, TF_DeleteSessionOptions);
239260

240261
////////////////////////
241262

263+
/// Manages a single graph and execution.
242264
pub struct Session {
243265
inner: *mut tf::TF_Session,
244266
}
245267

246268
impl Session {
269+
/// Creates a session.
247270
pub fn new(options: &SessionOptions) -> Result<Self> {
248271
let status = Status::new();
249272
let inner = unsafe { tf::TF_NewSession(options.inner, status.inner) };
@@ -256,6 +279,7 @@ impl Session {
256279
}
257280
}
258281

282+
/// Closes the session.
259283
pub fn close(&mut self) -> Status {
260284
let status = Status::new();
261285
unsafe {
@@ -264,6 +288,7 @@ impl Session {
264288
status
265289
}
266290

291+
/// Treat `proto` as a serialized `GraphDef` and add the nodes in that `GraphDef` to the graph for the session.
267292
pub fn extend_graph(&mut self, proto: &[u8]) -> Status {
268293
let status = Status::new();
269294
unsafe {
@@ -285,12 +310,15 @@ impl Drop for Session {
285310

286311
////////////////////////
287312

313+
/// Convenience type for `Result` with `Status` as the error type.
288314
pub type Result<T> = std::result::Result<T, Status>;
289315

290316
////////////////////////
291317

318+
/// A Rust type that maps to a `DataType`.
292319
pub trait TensorType: Default + Clone {
293320
// TODO: Use associated constants when/if available
321+
/// Returns the DataType that corresponds to this type.
294322
fn data_type() -> DataType;
295323
}
296324

@@ -323,6 +351,18 @@ tensor_type!(bool, Bool);
323351

324352
////////////////////////
325353

354+
/// Holds a multi-dimensional array of elements of a single data type.
355+
///
356+
/// For all types other than strings, the data buffer stores elements
357+
/// in row major order. E.g. if data is treated as a vector of `T`:
358+
///
359+
/// ```text
360+
/// element 0: index (0, ..., 0)
361+
/// element 1: index (0, ..., 1)
362+
/// ...
363+
/// ```
364+
///
365+
/// The layout for strings is currently undefined.
326366
pub struct Tensor<T> {
327367
inner: *mut tf::TF_Tensor,
328368
data: Buffer<T>,
@@ -344,6 +384,9 @@ fn product(values: &[u64]) -> u64 {
344384
}
345385

346386
impl<T: TensorType> Tensor<T> {
387+
/// Creates a new tensor.
388+
///
389+
/// The data is initialized to zeros.
347390
pub fn new(dims: &[u64]) -> Self {
348391
let total = product(dims);
349392
let data = <Buffer<T>>::new(total as usize);
@@ -353,6 +396,7 @@ impl<T: TensorType> Tensor<T> {
353396
Self::new_with_buffer(dims, data).unwrap()
354397
}
355398

399+
/// Creates a new tensor from existing data.
356400
pub fn new_with_buffer(dims: &[u64], data: Buffer<T>) -> Option<Self> {
357401
let total = product(dims);
358402
if total != data.len() as u64 {
@@ -377,14 +421,17 @@ impl<T: TensorType> Tensor<T> {
377421
})
378422
}
379423

424+
/// Returns the tensor's data.
380425
pub fn data(&self) -> &Buffer<T> {
381426
&self.data
382427
}
383428

429+
/// Returns the tensor's data.
384430
pub fn data_mut(&mut self) -> &mut Buffer<T> {
385431
&mut self.data
386432
}
387433

434+
/// Returns the tensor's dimensions.
388435
pub fn dims(&self) -> &[u64] {
389436
&self.dims
390437
}

0 commit comments

Comments
 (0)