Skip to content

Commit db7191c

Browse files
authored
Merge pull request tensorflow#232 from adamcrume/saved-model
Add support for creating saved models
2 parents 7fb10b1 + 1672f35 commit db7191c

25 files changed

+13787
-2
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ tensorflow-sys = { version = "0.17.0", path = "tensorflow-sys" }
1818
byteorder = "1.3.2"
1919
crc = "1.8.1"
2020
half = "1.3.0"
21+
# This is used internally but not intended to be exposed through the API.
22+
protobuf = "=2.8.0"
2123

2224
[dev-dependencies]
2325
random = "0.12.2"

RELEASING.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
## Pre-release
2+
3+
1. Generate protos
4+
1. Run `cd tensorflow-proto-codegen; cargo run -- $PATH_TO_TENSORFLOW $PWD/../src/protos`
5+
1. Update Cargo.toml to ensure version of protobuf exactly equals version of protoc_rust used
6+
1. Commit and push changes
7+
18
## Releasing
29

310
1. Check out a clean copy. Note that `cargo publish` packages up untracked files. Use `--allow-dirty` at your peril.

examples/xor.rs

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,32 @@
1+
use std::env;
12
use std::error::Error;
3+
use std::fs;
4+
use std::io::ErrorKind;
5+
use std::path::Path;
26
use std::result::Result;
37
use tensorflow::ops;
48
use tensorflow::train::AdadeltaOptimizer;
59
use tensorflow::train::MinimizeOptions;
610
use tensorflow::train::Optimizer;
711
use tensorflow::Code;
812
use tensorflow::DataType;
13+
use tensorflow::Graph;
914
use tensorflow::Output;
15+
use tensorflow::OutputName;
16+
use tensorflow::SavedModelBundle;
1017
use tensorflow::Scope;
1118
use tensorflow::Session;
1219
use tensorflow::SessionOptions;
1320
use tensorflow::SessionRunArgs;
1421
use tensorflow::Shape;
22+
use tensorflow::SignatureDef;
1523
use tensorflow::Status;
1624
use tensorflow::Tensor;
25+
use tensorflow::TensorInfo;
1726
use tensorflow::Variable;
27+
use tensorflow::REGRESS_INPUTS;
28+
use tensorflow::REGRESS_METHOD_NAME;
29+
use tensorflow::REGRESS_OUTPUTS;
1830

1931
// Helper for building a layer.
2032
//
@@ -54,7 +66,7 @@ fn layer<O1: Into<Output>>(
5466
))
5567
}
5668

57-
fn main() -> Result<(), Box<Error>> {
69+
fn train<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
5870
// ================
5971
// Build the model.
6072
// ================
@@ -94,6 +106,34 @@ fn main() -> Result<(), Box<Error>> {
94106
MinimizeOptions::default().with_variables(&variables),
95107
)?;
96108

109+
let mut all_vars = variables.clone();
110+
all_vars.extend_from_slice(&minimizer_vars);
111+
let mut builder = tensorflow::SavedModelBuilder::new();
112+
builder
113+
.add_collection("train", &all_vars)
114+
.add_tag("serve")
115+
.add_tag("train")
116+
.add_signature(REGRESS_METHOD_NAME, {
117+
let mut def = SignatureDef::new(REGRESS_METHOD_NAME.to_string());
118+
def.add_input_info(
119+
REGRESS_INPUTS.to_string(),
120+
TensorInfo::new(
121+
DataType::Float,
122+
Shape::from(None),
123+
OutputName {
124+
name: input.name()?,
125+
index: 0,
126+
},
127+
),
128+
);
129+
def.add_output_info(
130+
REGRESS_OUTPUTS.to_string(),
131+
TensorInfo::new(DataType::Float, Shape::from(None), layer2.name()?),
132+
);
133+
def
134+
});
135+
let saved_model_saver = builder.inject(scope)?;
136+
97137
// =========================
98138
// Initialize the variables.
99139
// =========================
@@ -118,7 +158,7 @@ fn main() -> Result<(), Box<Error>> {
118158
let mut label_tensor = Tensor::<f32>::new(&[1]);
119159
// Helper that generates a training example from an integer, trains on that
120160
// example, and returns the error.
121-
let mut train = |i| -> Result<f32, Box<Error>> {
161+
let mut train = |i| -> Result<f32, Box<dyn Error>> {
122162
input_tensor[0] = (i & 1) as f32;
123163
input_tensor[1] = ((i >> 1) & 1) as f32;
124164
label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32;
@@ -134,6 +174,11 @@ fn main() -> Result<(), Box<Error>> {
134174
train(i)?;
135175
}
136176

177+
// ================
178+
// Save the model.
179+
// ================
180+
saved_model_saver.save(&session, &g, &save_dir)?;
181+
137182
// ===================
138183
// Evaluate the model.
139184
// ===================
@@ -149,3 +194,70 @@ fn main() -> Result<(), Box<Error>> {
149194
}
150195
Ok(())
151196
}
197+
198+
fn eval<P: AsRef<Path>>(save_dir: P) -> Result<(), Box<dyn Error>> {
199+
let mut graph = Graph::new();
200+
let bundle = SavedModelBundle::load(
201+
&SessionOptions::new(),
202+
&["serve", "train"],
203+
&mut graph,
204+
save_dir,
205+
)?;
206+
let session = &bundle.session;
207+
let signature = bundle.meta_graph_def().get_signature(REGRESS_METHOD_NAME)?;
208+
let input_info = signature.get_input(REGRESS_INPUTS)?;
209+
let output_info = signature.get_output(REGRESS_OUTPUTS)?;
210+
let input_op = graph.operation_by_name_required(&input_info.name().name)?;
211+
let output_op = graph.operation_by_name_required(&output_info.name().name)?;
212+
213+
let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
214+
for i in 0..4 {
215+
input_tensor[0] = (i & 1) as f32;
216+
input_tensor[1] = ((i >> 1) & 1) as f32;
217+
let expected = ((i & 1) ^ ((i >> 1) & 1)) as f32;
218+
let mut run_args = SessionRunArgs::new();
219+
run_args.add_feed(&input_op, input_info.name().index, &input_tensor);
220+
let output_fetch = run_args.request_fetch(&output_op, output_info.name().index);
221+
session.run(&mut run_args)?;
222+
let output = run_args.fetch::<f32>(output_fetch)?[0];
223+
let error = (output - expected) * (output - expected);
224+
println!("Error: {}", error);
225+
if error > 0.1 {
226+
return Err(Box::new(Status::new_set(
227+
Code::Internal,
228+
&format!("Error too high: {}", error),
229+
)?));
230+
}
231+
}
232+
233+
Ok(())
234+
}
235+
236+
fn main() -> Result<(), Box<dyn Error>> {
237+
let mut dir = env::temp_dir();
238+
dir.push("tf-rust-example-xor-saved-model");
239+
let mut dir2 = env::temp_dir();
240+
dir2.push("tf-rust-example-xor-saved-model2");
241+
match fs::remove_dir_all(&dir) {
242+
Err(e) => {
243+
if e.kind() != ErrorKind::NotFound {
244+
return Err(Box::new(e));
245+
}
246+
}
247+
Ok(_) => (),
248+
}
249+
match fs::remove_dir_all(&dir2) {
250+
Err(e) => {
251+
if e.kind() != ErrorKind::NotFound {
252+
return Err(Box::new(e));
253+
}
254+
}
255+
Ok(_) => (),
256+
}
257+
train(&dir)?;
258+
// Ensure that the saved model works even when moved.
259+
// Users do not need to do this; this is purely for testing purposes.
260+
fs::rename(&dir, &dir2)?;
261+
eval(&dir2)?;
262+
Ok(())
263+
}

src/graph.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,13 @@ use std;
1818
use std::ffi::CStr;
1919
use std::ffi::CString;
2020
use std::ffi::NulError;
21+
use std::fmt;
22+
use std::fmt::Display;
23+
use std::fmt::Formatter;
2124
use std::os::raw::c_void as std_c_void;
2225
use std::ptr;
2326
use std::slice;
27+
use std::str::FromStr;
2428
use std::str::Utf8Error;
2529
use std::sync::Arc;
2630
use tensorflow_sys as tf;
@@ -1724,6 +1728,50 @@ impl Output {
17241728
})
17251729
}
17261730
}
1731+
1732+
/// Returns the name of this output.
1733+
pub fn name(&self) -> Result<OutputName> {
1734+
Ok(OutputName {
1735+
name: self.operation.name()?,
1736+
index: self.index,
1737+
})
1738+
}
1739+
}
1740+
1741+
////////////////////////
1742+
1743+
/// Names a specific Output in the graph, in the string form
/// `"operation_name:index"`.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Default)]
pub struct OutputName {
    /// Name of the operation the edge connects to.
    pub name: String,

    /// Index into the outputs of the operation.
    pub index: c_int,
}
1752+
1753+
impl FromStr for OutputName {
1754+
type Err = Status;
1755+
fn from_str(s: &str) -> Result<Self> {
1756+
let splits: Vec<_> = s.split(':').collect();
1757+
if splits.len() != 2 {
1758+
return Err(Status::new_set_lossy(
1759+
Code::InvalidArgument,
1760+
"Name must contain exactly one colon (':')",
1761+
));
1762+
}
1763+
let index = splits[1].parse::<c_int>()?;
1764+
Ok(Self {
1765+
name: splits[0].to_string(),
1766+
index,
1767+
})
1768+
}
1769+
}
1770+
1771+
impl Display for OutputName {
    /// Formats as `"operation_name:index"`, the inverse of the
    /// `FromStr` parse.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}:{}", self.name, self.index)
    }
}
17281776

17291777
////////////////////////
@@ -2847,4 +2895,26 @@ mod tests {
28472895
assert_eq!(consumers[0].0.name().unwrap(), "y");
28482896
assert_eq!(consumers[0].1, 0);
28492897
}
2898+
2899+
#[test]
fn output_name() {
    // Round-trip: parse a well-formed name, then re-display it.
    let parsed: OutputName = "foo:1".parse().unwrap();
    assert_eq!(
        parsed,
        OutputName {
            name: "foo".to_string(),
            index: 1
        }
    );
    assert_eq!(parsed.to_string(), "foo:1");
    // Malformed names: no colon, non-numeric index, too many colons.
    for bad in &["foo", "foo:bar", "foo:0:1"] {
        assert!(bad.parse::<OutputName>().is_err());
    }
}
28502920
}

0 commit comments

Comments
 (0)