File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11//! Recurrent Neural Networks
22use crate :: { Device , Kind , Tensor } ;
3+ use std:: borrow:: Borrow ;
34
45/// Trait for Recurrent Neural Networks.
56#[ allow( clippy:: upper_case_acronyms) ]
@@ -86,7 +87,13 @@ pub struct LSTM {
8687}
8788
8889/// Creates a LSTM layer.
89- pub fn lstm ( vs : & super :: var_store:: Path , in_dim : i64 , hidden_dim : i64 , c : RNNConfig ) -> LSTM {
90+ pub fn lstm < ' a , T : Borrow < super :: Path < ' a > > > (
91+ vs : T ,
92+ in_dim : i64 ,
93+ hidden_dim : i64 ,
94+ c : RNNConfig ,
95+ ) -> LSTM {
96+ let vs = vs. borrow ( ) ;
9097 let num_directions = if c. bidirectional { 2 } else { 1 } ;
9198 let gate_dim = 4 * hidden_dim;
9299 let mut flat_weights = vec ! [ ] ;
@@ -186,7 +193,13 @@ pub struct GRU {
186193}
187194
188195/// Creates a new GRU layer.
189- pub fn gru ( vs : & super :: var_store:: Path , in_dim : i64 , hidden_dim : i64 , c : RNNConfig ) -> GRU {
196+ pub fn gru < ' a , T : Borrow < super :: Path < ' a > > > (
197+ vs : T ,
198+ in_dim : i64 ,
199+ hidden_dim : i64 ,
200+ c : RNNConfig ,
201+ ) -> GRU {
202+ let vs = vs. borrow ( ) ;
190203 let num_directions = if c. bidirectional { 2 } else { 1 } ;
191204 let gate_dim = 3 * hidden_dim;
192205 let mut flat_weights = vec ! [ ] ;
You can’t perform that action at this time.
0 commit comments