Commit aafe7af

Add examples/xor.rs for a neural net trained entirely in Rust
1 parent 2b8637a commit aafe7af

File tree: 4 files changed (+268 −11 lines)
  Cargo.toml · examples/xor.rs · src/train.rs · test-all

Cargo.toml

Lines changed: 4 additions & 0 deletions

@@ -48,3 +48,7 @@ name = "regression_savedmodel"
 
 [[example]]
 name = "regression_checkpoint"
+
+[[example]]
+name = "xor"
+required-features = ["experimental_training"]
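
The new `[[example]]` entry gates the example behind the `experimental_training` feature, so Cargo skips it unless that feature is enabled. Assuming a local checkout of the crate, it can be run with the same invocation the test-all script gains below:

    cargo run --features=experimental_training --example xor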

examples/xor.rs

Lines changed: 151 additions & 0 deletions (new file)

use std::error::Error;
use std::result::Result;
use tensorflow::ops;
use tensorflow::train::AdadeltaOptimizer;
use tensorflow::train::MinimizeOptions;
use tensorflow::train::Optimizer;
use tensorflow::Code;
use tensorflow::DataType;
use tensorflow::Output;
use tensorflow::Scope;
use tensorflow::Session;
use tensorflow::SessionOptions;
use tensorflow::SessionRunArgs;
use tensorflow::Shape;
use tensorflow::Status;
use tensorflow::Tensor;
use tensorflow::Variable;

// Helper for building a layer.
//
// `activation` is a function which takes a tensor and applies an activation
// function such as tanh.
//
// Returns variables created and the layer output.
fn layer<O1: Into<Output>>(
    input: O1,
    input_size: u64,
    output_size: u64,
    activation: &dyn Fn(Output, &mut Scope) -> Result<Output, Status>,
    scope: &mut Scope,
) -> Result<(Vec<Variable>, Output), Status> {
    let mut scope = scope.new_sub_scope("layer");
    let scope = &mut scope;
    let w_shape = ops::constant(&[input_size as i64, output_size as i64][..], scope)?;
    let w = Variable::builder()
        .initial_value(ops::random_normal(w_shape, scope)?)
        .data_type(DataType::Float)
        .shape(Shape::from(&[input_size, output_size][..]))
        .build(&mut scope.with_op_name("w"))?;
    let b = Variable::builder()
        .const_initial_value(Tensor::<f32>::new(&[output_size]))
        .build(&mut scope.with_op_name("b"))?;
    Ok((
        vec![w.clone(), b.clone()],
        activation(
            ops::add(
                ops::mat_mul(input, w.output().clone(), scope)?,
                b.output().clone(),
                scope,
            )?
            .into(),
            scope,
        )?,
    ))
}

fn main() -> Result<(), Box<Error>> {
    // ================
    // Build the model.
    // ================
    let mut scope = Scope::new_root_scope();
    let scope = &mut scope;
    // Size of the hidden layer.
    // This is far more than is necessary, but makes it train more reliably.
    let hidden_size: u64 = 8;
    let input = ops::Placeholder::new()
        .data_type(DataType::Float)
        .shape(Shape::from(&[1u64, 2][..]))
        .build(&mut scope.with_op_name("input"))?;
    let label = ops::Placeholder::new()
        .data_type(DataType::Float)
        .shape(Shape::from(&[1u64][..]))
        .build(&mut scope.with_op_name("label"))?;
    // Hidden layer.
    let (vars1, layer1) = layer(
        input.clone(),
        2,
        hidden_size,
        &|x, scope| Ok(ops::tanh(x, scope)?.into()),
        scope,
    )?;
    // Output layer.
    let (vars2, layer2) = layer(layer1.clone(), hidden_size, 1, &|x, _| Ok(x), scope)?;
    let error = ops::subtract(layer2.clone(), label.clone(), scope)?;
    let error_squared = ops::multiply(error.clone(), error, scope)?;
    let mut optimizer = AdadeltaOptimizer::new();
    optimizer.set_learning_rate(ops::constant(1.0f32, scope)?);
    let mut variables = Vec::new();
    variables.extend(vars1);
    variables.extend(vars2);
    let (minimizer_vars, minimize) = optimizer.minimize(
        scope,
        error_squared.clone().into(),
        MinimizeOptions::default().with_variables(&variables),
    )?;

    // =========================
    // Initialize the variables.
    // =========================
    let options = SessionOptions::new();
    let g = scope.graph_mut();
    let session = Session::new(&options, &g)?;
    let mut run_args = SessionRunArgs::new();
    // Initialize variables we defined.
    for var in &variables {
        run_args.add_target(&var.initializer());
    }
    // Initialize variables the optimizer defined.
    for var in &minimizer_vars {
        run_args.add_target(&var.initializer());
    }
    session.run(&mut run_args)?;

    // ================
    // Train the model.
    // ================
    let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
    let mut label_tensor = Tensor::<f32>::new(&[1]);
    // Helper that generates a training example from an integer, trains on that
    // example, and returns the error.
    let mut train = |i| -> Result<f32, Box<Error>> {
        input_tensor[0] = (i & 1) as f32;
        input_tensor[1] = ((i >> 1) & 1) as f32;
        label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32;
        let mut run_args = SessionRunArgs::new();
        run_args.add_target(&minimize);
        let error_squared_fetch = run_args.request_fetch(&error_squared, 0);
        run_args.add_feed(&input, 0, &input_tensor);
        run_args.add_feed(&label, 0, &label_tensor);
        session.run(&mut run_args)?;
        Ok(run_args.fetch::<f32>(error_squared_fetch)?[0])
    };
    for i in 0..10000 {
        train(i)?;
    }

    // ===================
    // Evaluate the model.
    // ===================
    for i in 0..4 {
        let error = train(i)?;
        println!("Error: {}", error);
        if error > 0.1 {
            return Err(Box::new(Status::new_set(
                Code::Internal,
                &format!("Error too high: {}", error),
            )?));
        }
    }
    Ok(())
}
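
The `layer` helper above is the piece most likely to be reused: it creates the weight and bias variables and applies the caller-supplied activation. As a hypothetical sketch (not part of this commit), a second hidden layer could be added inside `main` by chaining the same helper; the names `hidden1`, `hidden2`, and `vars3` are illustrative only:

    // Hypothetical variation: two tanh hidden layers chained via `layer`.
    let (vars1, hidden1) = layer(
        input.clone(),
        2,
        hidden_size,
        &|x, scope| Ok(ops::tanh(x, scope)?.into()),
        scope,
    )?;
    let (vars2, hidden2) = layer(
        hidden1.clone(),
        hidden_size,
        hidden_size,
        &|x, scope| Ok(ops::tanh(x, scope)?.into()),
        scope,
    )?;
    // Linear output layer, as in the example itself.
    let (vars3, output_layer) = layer(hidden2.clone(), hidden_size, 1, &|x, _| Ok(x), scope)?;
    // All three variable vectors would then be handed to the optimizer via
    // MinimizeOptions::default().with_variables(...).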

src/train.rs

Lines changed: 112 additions & 11 deletions

@@ -131,8 +131,10 @@ pub struct GradientDescentOptimizer {
 
 impl GradientDescentOptimizer {
     /// Creates a new optimizer with the given learning rate.
-    pub fn new(learning_rate: Output) -> Self {
-        Self { learning_rate }
+    pub fn new<T: Into<Output>>(learning_rate: T) -> Self {
+        Self {
+            learning_rate: learning_rate.into(),
+        }
     }
 }
 
@@ -216,15 +218,9 @@ fn create_zeros_slot(
     dtype: Option<DataType>,
 ) -> Result<Variable> {
     let dtype = dtype.unwrap_or_else(|| primary.dtype);
-    // TODO: use standard op
-    let zeros = {
-        let name = scope.get_unique_name_for_op("ZerosLike");
-        let mut graph = scope.graph_mut();
-        let mut nd = graph.new_operation("ZerosLike", &name)?;
-        nd.add_input(primary.output.clone());
-        nd.add_control_input(&primary.initializer);
-        nd.finish()?
-    };
+    let zeros = ops::ZerosLike::new()
+        .add_control_input(primary.initializer.clone())
+        .build(primary.output.clone(), scope)?;
     Variable::builder()
         .initial_value(zeros)
         .shape(primary.shape.clone())
@@ -276,9 +272,13 @@ impl Optimizer for AdadeltaOptimizer {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::ops;
+    use crate::Scope;
     use crate::Session;
     use crate::SessionOptions;
     use crate::SessionRunArgs;
+    use crate::Shape;
+    use crate::Tensor;
 
     #[test]
     fn simple_gradient_descent() {
@@ -403,4 +403,105 @@ mod tests {
             x_output[0]
         );
     }
+
+    #[test]
+    fn xor_nn() {
+        let mut scope = Scope::new_root_scope();
+        let scope = &mut scope;
+        let hidden_size: u64 = 4;
+        let input = ops::Placeholder::new()
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[1u64, 2][..]))
+            .build(&mut scope.with_op_name("input"))
+            .unwrap();
+        let label = ops::Placeholder::new()
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[1u64][..]))
+            .build(&mut scope.with_op_name("label"))
+            .unwrap();
+        let w_shape = ops::constant(&[2, hidden_size as i64][..], scope).unwrap();
+        let w_init = ops::random_normal(w_shape, scope).unwrap();
+        let w = Variable::builder()
+            .initial_value(w_init)
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[2, hidden_size][..]))
+            .build(&mut scope.with_op_name("w"))
+            .unwrap();
+        let b = Variable::builder()
+            .const_initial_value(Tensor::<f32>::new(&[hidden_size]))
+            .build(&mut scope.with_op_name("b"))
+            .unwrap();
+        let layer1a = ops::MatMul::new()
+            .build(input.clone(), w.output.clone(), scope)
+            .unwrap();
+        let layer1b = ops::Add::new()
+            .build(layer1a, b.output.clone(), scope)
+            .unwrap();
+        let layer1 = ops::Tanh::new().build(layer1b, scope).unwrap();
+        let w2_shape = ops::constant(&[hidden_size as i64, 1][..], scope).unwrap();
+        let w2_init = ops::random_normal(w2_shape, scope).unwrap();
+        let w2 = Variable::builder()
+            .initial_value(w2_init)
+            .data_type(DataType::Float)
+            .shape(Shape::from(&[hidden_size, 1][..]))
+            .build(&mut scope.with_op_name("w2"))
+            .unwrap();
+        let b2 = Variable::builder()
+            .const_initial_value(Tensor::<f32>::new(&[1]))
+            .build(&mut scope.with_op_name("b2"))
+            .unwrap();
+        let layer2a = ops::mat_mul(layer1, w2.output.clone(), scope).unwrap();
+        let layer2b = ops::add(layer2a, b2.output.clone(), scope).unwrap();
+        let layer2 = layer2b;
+        let error = ops::subtract(layer2.clone(), label.clone(), scope).unwrap();
+        let error_squared = ops::multiply(error.clone(), error, scope).unwrap();
+        let sgd = GradientDescentOptimizer {
+            learning_rate: Output {
+                operation: ops::constant(0.1f32, scope).unwrap(),
+                index: 0,
+            },
+        };
+        let variables = vec![w.clone(), b.clone(), w2.clone(), b2.clone()];
+        let (minimizer_vars, minimize) = sgd
+            .minimize(
+                scope,
+                error_squared.clone().into(),
+                MinimizeOptions::default().with_variables(&variables),
+            )
+            .unwrap();
+        let options = SessionOptions::new();
+        let g = scope.graph_mut();
+        let session = Session::new(&options, &g).unwrap();
+
+        let mut run_args = SessionRunArgs::new();
+        for var in &variables {
+            run_args.add_target(&var.initializer);
+        }
+        for var in &minimizer_vars {
+            run_args.add_target(&var.initializer);
+        }
+        session.run(&mut run_args).unwrap();
+
+        let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
+        let mut label_tensor = Tensor::<f32>::new(&[1]);
+        let mut train = |i| {
+            input_tensor[0] = (i & 1) as f32;
+            input_tensor[1] = ((i >> 1) & 1) as f32;
+            label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32;
+            let mut run_args = SessionRunArgs::new();
+            run_args.add_target(&minimize);
+            let error_squared_fetch = run_args.request_fetch(&error_squared, 0);
+            run_args.add_feed(&input, 0, &input_tensor);
+            run_args.add_feed(&label, 0, &label_tensor);
+            session.run(&mut run_args).unwrap();
+            run_args.fetch::<f32>(error_squared_fetch).unwrap()[0]
+        };
+        for i in 0..1000 {
+            train(i);
+        }
+        for i in 0..4 {
+            let error = train(i);
+            assert!(error < 0.01, "error = {}", error);
+        }
+    }
 }
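
A related change in this commit is that `GradientDescentOptimizer::new` is now generic over `T: Into<Output>`, so a learning-rate constant can be handed to it directly rather than building an `Output` by hand as the `xor_nn` test does. A minimal sketch, assuming (as the example above already relies on) that the result of `ops::constant` converts into an `Output`:

    // Hypothetical usage of the new generic constructor; the hand-built
    // Output { operation, index: 0 } in the test would no longer be needed.
    let sgd = GradientDescentOptimizer::new(ops::constant(0.1f32, scope)?);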

test-all

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ cargo test -vv -j 2 --features tensorflow_unstable
 cargo test -vv -j 2 --features experimental_training
 cargo test -vv -j 2 --features tensorflow_unstable,experimental_training
 cargo run --example regression
+cargo run --features=experimental_training --example xor
 cargo run --features tensorflow_unstable --example expressions
 cargo doc -vv --features tensorflow_unstable,experimental_training
 # TODO(#66): Re-enable: (cd tensorflow-sys && cargo test -vv -j 1)
