Skip to content

Commit cb04758

Browse files
committed
Add Tensor type
1 parent be359d8 commit cb04758

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed

src/lib.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,117 @@ pub type Result<T> = std::result::Result<T, Status>;
223223

224224
////////////////////////
225225

226+
trait TensorType: Default + Clone {
227+
// TODO: Use associated constants when/if available
228+
fn data_type() -> DataType;
229+
}
230+
231+
macro_rules! tensor_type {
232+
($rust_type:ident, $tensor_type:ident) => {
233+
impl TensorType for $rust_type {
234+
fn data_type() -> DataType {
235+
DataType::$tensor_type
236+
}
237+
}
238+
}
239+
}
240+
241+
tensor_type!(f32, Float);
242+
tensor_type!(f64, Double);
243+
tensor_type!(i32, Int32);
244+
tensor_type!(u8, UInt8);
245+
tensor_type!(i16, Int16);
246+
tensor_type!(i8, Int8);
247+
// TODO: provide type for String
248+
// TODO: provide type for Complex
249+
tensor_type!(i64, Int64);
250+
tensor_type!(bool, Bool);
251+
// TODO: provide type for QInt8
252+
// TODO: provide type for QUInt8
253+
// TODO: provide type for QInt32
254+
// TODO: provide type for BFloat16
255+
// TODO: provide type for QInt16
256+
// TODO: provide type for QUInt16
257+
258+
////////////////////////
259+
260+
pub struct Tensor<T> {
261+
inner: *mut tf::TF_Tensor,
262+
data: Buffer<T>,
263+
dims: Vec<u64>,
264+
}
265+
266+
extern "C" fn noop_deallocator(_data: *mut ::libc::c_void,
267+
_len: ::libc::size_t,
268+
_arg: *mut ::libc::c_void)-> () {
269+
}
270+
271+
// TODO: Replace with Iterator::product once that's stable
272+
fn product(values: &[u64]) -> u64 {
273+
let mut product = 1;
274+
for v in values.iter() {
275+
product *= *v;
276+
}
277+
product
278+
}
279+
280+
impl<T: TensorType> Tensor<T> {
281+
pub fn new(dims: &[u64]) -> Self {
282+
let total = product(dims);
283+
let data = <Buffer<T>>::new(total as usize);
284+
// Guaranteed safe to unwrap, because the only way for it to fail is for the
285+
// length of the buffer not to match the dimensions, and we created it with
286+
// exactly the right size.
287+
Self::new_with_buffer(dims, data).unwrap()
288+
}
289+
290+
pub fn new_with_buffer(dims: &[u64], data: Buffer<T>) -> Option<Self> {
291+
let total = product(dims);
292+
if total != data.len() as u64 {
293+
return None
294+
}
295+
let inner = unsafe {
296+
tf::TF_NewTensor(T::data_type().to_int(),
297+
dims.as_ptr() as *mut i64,
298+
dims.len() as i32,
299+
data.as_ptr() as *mut libc::c_void,
300+
data.len(),
301+
Some(noop_deallocator),
302+
std::ptr::null_mut())
303+
};
304+
let mut dims_vec = Vec::new();
305+
// TODO: Use extend_from_slice once we're on Rust 1.6
306+
dims_vec.extend(dims.iter());
307+
Some(Tensor {
308+
inner: inner,
309+
data: data,
310+
dims: dims_vec,
311+
})
312+
}
313+
314+
pub fn data(&self) -> &Buffer<T> {
315+
&self.data
316+
}
317+
318+
pub fn data_mut(&mut self) -> &mut Buffer<T> {
319+
&mut self.data
320+
}
321+
322+
pub fn dims(&self) -> &[u64] {
323+
&self.dims
324+
}
325+
}
326+
327+
impl<T> Drop for Tensor<T> {
328+
fn drop(&mut self) {
329+
unsafe {
330+
tf::TF_DeleteTensor(self.inner);
331+
}
332+
}
333+
}
334+
335+
////////////////////////
336+
226337
#[cfg(test)]
227338
mod tests {
228339
use super::*;
@@ -245,4 +356,11 @@ mod tests {
245356
let status = create_session().close();
246357
assert!(status.is_ok());
247358
}
359+
360+
#[test]
361+
fn test_tensor() {
362+
let mut tensor = <Tensor<f32>>::new(&[2, 3]);
363+
assert_eq!(tensor.data().len(), 6);
364+
tensor.data_mut()[0] = 1.0;
365+
}
248366
}

0 commit comments

Comments
 (0)