Skip to content

Commit 70f18f2

Browse files
authored
lstm/gru: use Borrow<Path> instead of &var_store::Path (LaurentMazare#444)
1 parent 7660fb8 commit 70f18f2

1 file changed

Lines changed: 15 additions & 2 deletions

File tree

src/nn/rnn.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Recurrent Neural Networks
22
use crate::{Device, Kind, Tensor};
3+
use std::borrow::Borrow;
34

45
/// Trait for Recurrent Neural Networks.
56
#[allow(clippy::upper_case_acronyms)]
@@ -86,7 +87,13 @@ pub struct LSTM {
8687
}
8788

8889
/// Creates a LSTM layer.
89-
pub fn lstm(vs: &super::var_store::Path, in_dim: i64, hidden_dim: i64, c: RNNConfig) -> LSTM {
90+
pub fn lstm<'a, T: Borrow<super::Path<'a>>>(
91+
vs: T,
92+
in_dim: i64,
93+
hidden_dim: i64,
94+
c: RNNConfig,
95+
) -> LSTM {
96+
let vs = vs.borrow();
9097
let num_directions = if c.bidirectional { 2 } else { 1 };
9198
let gate_dim = 4 * hidden_dim;
9299
let mut flat_weights = vec![];
@@ -186,7 +193,13 @@ pub struct GRU {
186193
}
187194

188195
/// Creates a new GRU layer.
189-
pub fn gru(vs: &super::var_store::Path, in_dim: i64, hidden_dim: i64, c: RNNConfig) -> GRU {
196+
pub fn gru<'a, T: Borrow<super::Path<'a>>>(
197+
vs: T,
198+
in_dim: i64,
199+
hidden_dim: i64,
200+
c: RNNConfig,
201+
) -> GRU {
202+
let vs = vs.borrow();
190203
let num_directions = if c.bidirectional { 2 } else { 1 };
191204
let gate_dim = 3 * hidden_dim;
192205
let mut flat_weights = vec![];

0 commit comments

Comments
 (0)