@@ -131,8 +131,10 @@ pub struct GradientDescentOptimizer {
 
 impl GradientDescentOptimizer {
     /// Creates a new optimizer with the given learning rate.
-    pub fn new(learning_rate: Output) -> Self {
-        Self { learning_rate }
+    pub fn new<T: Into<Output>>(learning_rate: T) -> Self {
+        Self {
+            learning_rate: learning_rate.into(),
+        }
     }
 }
 
@@ -216,15 +218,9 @@ fn create_zeros_slot(
     dtype: Option<DataType>,
 ) -> Result<Variable> {
     let dtype = dtype.unwrap_or_else(|| primary.dtype);
-    // TODO: use standard op
-    let zeros = {
-        let name = scope.get_unique_name_for_op("ZerosLike");
-        let mut graph = scope.graph_mut();
-        let mut nd = graph.new_operation("ZerosLike", &name)?;
-        nd.add_input(primary.output.clone());
-        nd.add_control_input(&primary.initializer);
-        nd.finish()?
-    };
+    let zeros = ops::ZerosLike::new()
+        .add_control_input(primary.initializer.clone())
+        .build(primary.output.clone(), scope)?;
     Variable::builder()
         .initial_value(zeros)
         .shape(primary.shape.clone())
@@ -276,9 +272,13 @@ impl Optimizer for AdadeltaOptimizer {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::ops;
+    use crate::Scope;
     use crate::Session;
     use crate::SessionOptions;
     use crate::SessionRunArgs;
+    use crate::Shape;
+    use crate::Tensor;
 
     #[test]
     fn simple_gradient_descent() {
@@ -403,4 +403,112 @@ mod tests {
             x_output[0]
         );
     }
+
+    #[test]
+    fn xor_nn() {
+        let mut scope = Scope::new_root_scope();
+        let scope = &mut scope;
+        let hidden_size: u64 = 4;
+        // Feed one XOR example at a time: a 1x2 input and a single label.
+        let input = ops::Placeholder::new()
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[1u64, 2][..]))
+            .build(&mut scope.with_op_name("input"))
+            .unwrap();
+        let label = ops::Placeholder::new()
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[1u64][..]))
+            .build(&mut scope.with_op_name("label"))
+            .unwrap();
+        // Hidden layer: 2 -> hidden_size, with a tanh activation.
+        let w_shape = ops::constant(&[2, hidden_size as i64][..], scope).unwrap();
+        let w_init = ops::random_normal(w_shape, scope).unwrap();
+        let w = Variable::builder()
+            .initial_value(w_init)
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[2, hidden_size][..]))
+            .build(&mut scope.with_op_name("w"))
+            .unwrap();
+        let b = Variable::builder()
+            .const_initial_value(Tensor::<f32>::new(&[hidden_size]))
+            .build(&mut scope.with_op_name("b"))
+            .unwrap();
+        let layer1a = ops::MatMul::new()
+            .build(input.clone(), w.output.clone(), scope)
+            .unwrap();
+        let layer1b = ops::Add::new()
+            .build(layer1a, b.output.clone(), scope)
+            .unwrap();
+        let layer1 = ops::Tanh::new().build(layer1b, scope).unwrap();
+        // Output layer: hidden_size -> 1, left linear.
+        let w2_shape = ops::constant(&[hidden_size as i64, 1][..], scope).unwrap();
+        let w2_init = ops::random_normal(w2_shape, scope).unwrap();
+        let w2 = Variable::builder()
+            .initial_value(w2_init)
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[hidden_size, 1][..]))
+            .build(&mut scope.with_op_name("w2"))
+            .unwrap();
+        let b2 = Variable::builder()
+            .const_initial_value(Tensor::<f32>::new(&[1]))
+            .build(&mut scope.with_op_name("b2"))
+            .unwrap();
+        let layer2a = ops::mat_mul(layer1, w2.output.clone(), scope).unwrap();
+        let layer2b = ops::add(layer2a, b2.output.clone(), scope).unwrap();
+        let layer2 = layer2b;
+        // Squared error between the prediction and the label.
+        let error = ops::subtract(layer2.clone(), label.clone(), scope).unwrap();
+        let error_squared = ops::multiply(error.clone(), error, scope).unwrap();
+        let sgd = GradientDescentOptimizer {
+            learning_rate: Output {
+                operation: ops::constant(0.1f32, scope).unwrap(),
+                index: 0,
+            },
+        };
+        let variables = vec![w.clone(), b.clone(), w2.clone(), b2.clone()];
+        let (minimizer_vars, minimize) = sgd
+            .minimize(
+                scope,
+                error_squared.clone().into(),
+                MinimizeOptions::default().with_variables(&variables),
+            )
+            .unwrap();
+        let options = SessionOptions::new();
+        let g = scope.graph_mut();
+        let session = Session::new(&options, &g).unwrap();
+
+        // Initialize the model variables and any variables created by the optimizer.
+        let mut run_args = SessionRunArgs::new();
+        for var in &variables {
+            run_args.add_target(&var.initializer);
+        }
+        for var in &minimizer_vars {
+            run_args.add_target(&var.initializer);
+        }
+        session.run(&mut run_args).unwrap();
+
+        let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
+        let mut label_tensor = Tensor::<f32>::new(&[1]);
+        // One training step: feed the i-th XOR pattern and return the squared error.
+        let mut train = |i| {
+            input_tensor[0] = (i & 1) as f32;
+            input_tensor[1] = ((i >> 1) & 1) as f32;
+            label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32;
+            let mut run_args = SessionRunArgs::new();
+            run_args.add_target(&minimize);
+            let error_squared_fetch = run_args.request_fetch(&error_squared, 0);
+            run_args.add_feed(&input, 0, &input_tensor);
+            run_args.add_feed(&label, 0, &label_tensor);
+            session.run(&mut run_args).unwrap();
+            run_args.fetch::<f32>(error_squared_fetch).unwrap()[0]
+        };
+        for i in 0..1000 {
+            train(i);
+        }
+        // After training, every XOR pattern should be predicted with near-zero error.
+        for i in 0..4 {
+            let error = train(i);
+            assert!(error < 0.01, "error = {}", error);
+        }
+    }
 }