1
+ use std:: env;
1
2
use std:: error:: Error ;
3
+ use std:: fs;
4
+ use std:: io:: ErrorKind ;
5
+ use std:: path:: Path ;
2
6
use std:: result:: Result ;
3
7
use tensorflow:: ops;
4
8
use tensorflow:: train:: AdadeltaOptimizer ;
5
9
use tensorflow:: train:: MinimizeOptions ;
6
10
use tensorflow:: train:: Optimizer ;
7
11
use tensorflow:: Code ;
8
12
use tensorflow:: DataType ;
13
+ use tensorflow:: Graph ;
9
14
use tensorflow:: Output ;
15
+ use tensorflow:: OutputName ;
16
+ use tensorflow:: SavedModelBundle ;
10
17
use tensorflow:: Scope ;
11
18
use tensorflow:: Session ;
12
19
use tensorflow:: SessionOptions ;
13
20
use tensorflow:: SessionRunArgs ;
14
21
use tensorflow:: Shape ;
22
+ use tensorflow:: SignatureDef ;
15
23
use tensorflow:: Status ;
16
24
use tensorflow:: Tensor ;
25
+ use tensorflow:: TensorInfo ;
17
26
use tensorflow:: Variable ;
27
+ use tensorflow:: REGRESS_INPUTS ;
28
+ use tensorflow:: REGRESS_METHOD_NAME ;
29
+ use tensorflow:: REGRESS_OUTPUTS ;
18
30
19
31
// Helper for building a layer.
20
32
//
@@ -54,7 +66,7 @@ fn layer<O1: Into<Output>>(
54
66
) )
55
67
}
56
68
57
- fn main ( ) -> Result < ( ) , Box < Error > > {
69
+ fn train < P : AsRef < Path > > ( save_dir : P ) -> Result < ( ) , Box < dyn Error > > {
58
70
// ================
59
71
// Build the model.
60
72
// ================
@@ -94,6 +106,34 @@ fn main() -> Result<(), Box<Error>> {
94
106
MinimizeOptions :: default ( ) . with_variables ( & variables) ,
95
107
) ?;
96
108
109
+ let mut all_vars = variables. clone ( ) ;
110
+ all_vars. extend_from_slice ( & minimizer_vars) ;
111
+ let mut builder = tensorflow:: SavedModelBuilder :: new ( ) ;
112
+ builder
113
+ . add_collection ( "train" , & all_vars)
114
+ . add_tag ( "serve" )
115
+ . add_tag ( "train" )
116
+ . add_signature ( REGRESS_METHOD_NAME , {
117
+ let mut def = SignatureDef :: new ( REGRESS_METHOD_NAME . to_string ( ) ) ;
118
+ def. add_input_info (
119
+ REGRESS_INPUTS . to_string ( ) ,
120
+ TensorInfo :: new (
121
+ DataType :: Float ,
122
+ Shape :: from ( None ) ,
123
+ OutputName {
124
+ name : input. name ( ) ?,
125
+ index : 0 ,
126
+ } ,
127
+ ) ,
128
+ ) ;
129
+ def. add_output_info (
130
+ REGRESS_OUTPUTS . to_string ( ) ,
131
+ TensorInfo :: new ( DataType :: Float , Shape :: from ( None ) , layer2. name ( ) ?) ,
132
+ ) ;
133
+ def
134
+ } ) ;
135
+ let saved_model_saver = builder. inject ( scope) ?;
136
+
97
137
// =========================
98
138
// Initialize the variables.
99
139
// =========================
@@ -118,7 +158,7 @@ fn main() -> Result<(), Box<Error>> {
118
158
let mut label_tensor = Tensor :: < f32 > :: new ( & [ 1 ] ) ;
119
159
// Helper that generates a training example from an integer, trains on that
120
160
// example, and returns the error.
121
- let mut train = |i| -> Result < f32 , Box < Error > > {
161
+ let mut train = |i| -> Result < f32 , Box < dyn Error > > {
122
162
input_tensor[ 0 ] = ( i & 1 ) as f32 ;
123
163
input_tensor[ 1 ] = ( ( i >> 1 ) & 1 ) as f32 ;
124
164
label_tensor[ 0 ] = ( ( i & 1 ) ^ ( ( i >> 1 ) & 1 ) ) as f32 ;
@@ -134,6 +174,11 @@ fn main() -> Result<(), Box<Error>> {
134
174
train ( i) ?;
135
175
}
136
176
177
+ // ================
178
+ // Save the model.
179
+ // ================
180
+ saved_model_saver. save ( & session, & g, & save_dir) ?;
181
+
137
182
// ===================
138
183
// Evaluate the model.
139
184
// ===================
@@ -149,3 +194,70 @@ fn main() -> Result<(), Box<Error>> {
149
194
}
150
195
Ok ( ( ) )
151
196
}
197
+
198
+ fn eval < P : AsRef < Path > > ( save_dir : P ) -> Result < ( ) , Box < dyn Error > > {
199
+ let mut graph = Graph :: new ( ) ;
200
+ let bundle = SavedModelBundle :: load (
201
+ & SessionOptions :: new ( ) ,
202
+ & [ "serve" , "train" ] ,
203
+ & mut graph,
204
+ save_dir,
205
+ ) ?;
206
+ let session = & bundle. session ;
207
+ let signature = bundle. meta_graph_def ( ) . get_signature ( REGRESS_METHOD_NAME ) ?;
208
+ let input_info = signature. get_input ( REGRESS_INPUTS ) ?;
209
+ let output_info = signature. get_output ( REGRESS_OUTPUTS ) ?;
210
+ let input_op = graph. operation_by_name_required ( & input_info. name ( ) . name ) ?;
211
+ let output_op = graph. operation_by_name_required ( & output_info. name ( ) . name ) ?;
212
+
213
+ let mut input_tensor = Tensor :: < f32 > :: new ( & [ 1 , 2 ] ) ;
214
+ for i in 0 ..4 {
215
+ input_tensor[ 0 ] = ( i & 1 ) as f32 ;
216
+ input_tensor[ 1 ] = ( ( i >> 1 ) & 1 ) as f32 ;
217
+ let expected = ( ( i & 1 ) ^ ( ( i >> 1 ) & 1 ) ) as f32 ;
218
+ let mut run_args = SessionRunArgs :: new ( ) ;
219
+ run_args. add_feed ( & input_op, input_info. name ( ) . index , & input_tensor) ;
220
+ let output_fetch = run_args. request_fetch ( & output_op, output_info. name ( ) . index ) ;
221
+ session. run ( & mut run_args) ?;
222
+ let output = run_args. fetch :: < f32 > ( output_fetch) ?[ 0 ] ;
223
+ let error = ( output - expected) * ( output - expected) ;
224
+ println ! ( "Error: {}" , error) ;
225
+ if error > 0.1 {
226
+ return Err ( Box :: new ( Status :: new_set (
227
+ Code :: Internal ,
228
+ & format ! ( "Error too high: {}" , error) ,
229
+ ) ?) ) ;
230
+ }
231
+ }
232
+
233
+ Ok ( ( ) )
234
+ }
235
+
236
+ fn main ( ) -> Result < ( ) , Box < dyn Error > > {
237
+ let mut dir = env:: temp_dir ( ) ;
238
+ dir. push ( "tf-rust-example-xor-saved-model" ) ;
239
+ let mut dir2 = env:: temp_dir ( ) ;
240
+ dir2. push ( "tf-rust-example-xor-saved-model2" ) ;
241
+ match fs:: remove_dir_all ( & dir) {
242
+ Err ( e) => {
243
+ if e. kind ( ) != ErrorKind :: NotFound {
244
+ return Err ( Box :: new ( e) ) ;
245
+ }
246
+ }
247
+ Ok ( _) => ( ) ,
248
+ }
249
+ match fs:: remove_dir_all ( & dir2) {
250
+ Err ( e) => {
251
+ if e. kind ( ) != ErrorKind :: NotFound {
252
+ return Err ( Box :: new ( e) ) ;
253
+ }
254
+ }
255
+ Ok ( _) => ( ) ,
256
+ }
257
+ train ( & dir) ?;
258
+ // Ensure that the saved model works even when moved.
259
+ // Users do not need to do this; this is purely for testing purposes.
260
+ fs:: rename ( & dir, & dir2) ?;
261
+ eval ( & dir2) ?;
262
+ Ok ( ( ) )
263
+ }
0 commit comments