@@ -223,6 +223,117 @@ pub type Result<T> = std::result::Result<T, Status>;
223223
224224////////////////////////
225225
226+ trait TensorType : Default + Clone {
227+ // TODO: Use associated constants when/if available
228+ fn data_type ( ) -> DataType ;
229+ }
230+
231+ macro_rules! tensor_type {
232+ ( $rust_type: ident, $tensor_type: ident) => {
233+ impl TensorType for $rust_type {
234+ fn data_type( ) -> DataType {
235+ DataType :: $tensor_type
236+ }
237+ }
238+ }
239+ }
240+
241+ tensor_type ! ( f32 , Float ) ;
242+ tensor_type ! ( f64 , Double ) ;
243+ tensor_type ! ( i32 , Int32 ) ;
244+ tensor_type ! ( u8 , UInt8 ) ;
245+ tensor_type ! ( i16 , Int16 ) ;
246+ tensor_type ! ( i8 , Int8 ) ;
247+ // TODO: provide type for String
248+ // TODO: provide type for Complex
249+ tensor_type ! ( i64 , Int64 ) ;
250+ tensor_type ! ( bool , Bool ) ;
251+ // TODO: provide type for QInt8
252+ // TODO: provide type for QUInt8
253+ // TODO: provide type for QInt32
254+ // TODO: provide type for BFloat16
255+ // TODO: provide type for QInt16
256+ // TODO: provide type for QUInt16
257+
258+ ////////////////////////
259+
260+ pub struct Tensor < T > {
261+ inner : * mut tf:: TF_Tensor ,
262+ data : Buffer < T > ,
263+ dims : Vec < u64 > ,
264+ }
265+
266+ extern "C" fn noop_deallocator ( _data : * mut :: libc:: c_void ,
267+ _len : :: libc:: size_t ,
268+ _arg : * mut :: libc:: c_void ) -> ( ) {
269+ }
270+
271+ // TODO: Replace with Iterator::product once that's stable
272+ fn product ( values : & [ u64 ] ) -> u64 {
273+ let mut product = 1 ;
274+ for v in values. iter ( ) {
275+ product *= * v;
276+ }
277+ product
278+ }
279+
280+ impl < T : TensorType > Tensor < T > {
281+ pub fn new ( dims : & [ u64 ] ) -> Self {
282+ let total = product ( dims) ;
283+ let data = <Buffer < T > >:: new ( total as usize ) ;
284+ // Guaranteed safe to unwrap, because the only way for it to fail is for the
285+ // length of the buffer not to match the dimensions, and we created it with
286+ // exactly the right size.
287+ Self :: new_with_buffer ( dims, data) . unwrap ( )
288+ }
289+
290+ pub fn new_with_buffer ( dims : & [ u64 ] , data : Buffer < T > ) -> Option < Self > {
291+ let total = product ( dims) ;
292+ if total != data. len ( ) as u64 {
293+ return None
294+ }
295+ let inner = unsafe {
296+ tf:: TF_NewTensor ( T :: data_type ( ) . to_int ( ) ,
297+ dims. as_ptr ( ) as * mut i64 ,
298+ dims. len ( ) as i32 ,
299+ data. as_ptr ( ) as * mut libc:: c_void ,
300+ data. len ( ) ,
301+ Some ( noop_deallocator) ,
302+ std:: ptr:: null_mut ( ) )
303+ } ;
304+ let mut dims_vec = Vec :: new ( ) ;
305+ // TODO: Use extend_from_slice once we're on Rust 1.6
306+ dims_vec. extend ( dims. iter ( ) ) ;
307+ Some ( Tensor {
308+ inner : inner,
309+ data : data,
310+ dims : dims_vec,
311+ } )
312+ }
313+
314+ pub fn data ( & self ) -> & Buffer < T > {
315+ & self . data
316+ }
317+
318+ pub fn data_mut ( & mut self ) -> & mut Buffer < T > {
319+ & mut self . data
320+ }
321+
322+ pub fn dims ( & self ) -> & [ u64 ] {
323+ & self . dims
324+ }
325+ }
326+
327+ impl < T > Drop for Tensor < T > {
328+ fn drop ( & mut self ) {
329+ unsafe {
330+ tf:: TF_DeleteTensor ( self . inner ) ;
331+ }
332+ }
333+ }
334+
335+ ////////////////////////
336+
226337#[ cfg( test) ]
227338mod tests {
228339 use super :: * ;
@@ -245,4 +356,11 @@ mod tests {
245356 let status = create_session ( ) . close ( ) ;
246357 assert ! ( status. is_ok( ) ) ;
247358 }
359+
360+ #[ test]
361+ fn test_tensor ( ) {
362+ let mut tensor = <Tensor < f32 > >:: new ( & [ 2 , 3 ] ) ;
363+ assert_eq ! ( tensor. data( ) . len( ) , 6 ) ;
364+ tensor. data_mut ( ) [ 0 ] = 1.0 ;
365+ }
248366}
0 commit comments