From 0847a0afcd92ba36219addbb906270416e482809 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 25 Oct 2018 13:39:46 +0200 Subject: [PATCH] Replace Numberer by TransitionLookup in the transition systems. TransitionLookup is a wrapper of Numberer that has several additional properties that are useful for transition tables. - It ensures that the identifier 0 is reserved for unknown transitions. - It returns the correct length of the transition table, which includes the special identifier 0. - The table can be either fresh or frozen. A fresh table automatically adds transitions that are not known. A frozen table returns the special identifier 0 when a transition is not known. For future consideration: provide a thaw method as well? --- dpar/Cargo.toml | 1 + dpar/src/guide/tensorflow/guide.rs | 4 +- dpar/src/lib.rs | 3 + dpar/src/models/tensorflow/model.rs | 2 +- dpar/src/numberer.rs | 2 +- dpar/src/system/mod.rs | 2 +- dpar/src/system/trans_system.rs | 191 ++++++++++++++++++++++++++- dpar/src/systems/arc_eager.rs | 15 +-- dpar/src/systems/arc_hybrid.rs | 15 +-- dpar/src/systems/arc_standard.rs | 15 +-- dpar/src/systems/stack_projective.rs | 15 +-- dpar/src/systems/stack_swap.rs | 15 +-- dpar/src/train/hdf5.rs | 2 +- 13 files changed, 228 insertions(+), 54 deletions(-) diff --git a/dpar/Cargo.toml b/dpar/Cargo.toml index 88b0d89..b7f8d11 100644 --- a/dpar/Cargo.toml +++ b/dpar/Cargo.toml @@ -24,3 +24,4 @@ approx = "0.3" flate2 = "1" lazy_static = "0.2" pretty_assertions = "0.5" +serde_yaml = "0.8" diff --git a/dpar/src/guide/tensorflow/guide.rs b/dpar/src/guide/tensorflow/guide.rs index 07865f0..3d34579 100644 --- a/dpar/src/guide/tensorflow/guide.rs +++ b/dpar/src/guide/tensorflow/guide.rs @@ -169,7 +169,7 @@ where { // Invariant: we should have as many predictions as transitions. 
let n_predictions = logits.as_ref().len(); - let n_transitions = self.system.transitions().len() + self.system.transitions().start_at(); + let n_transitions = self.system.transitions().len(); assert_eq!( n_predictions, n_transitions, "Number of transitions ({}) and predictions ({}) are inequal.", @@ -189,7 +189,7 @@ where } } - best.clone() + best.into_owned() } } diff --git a/dpar/src/lib.rs b/dpar/src/lib.rs index df9e2ec..0d30c1a 100644 --- a/dpar/src/lib.rs +++ b/dpar/src/lib.rs @@ -66,3 +66,6 @@ extern crate maplit; #[cfg(test)] #[macro_use] extern crate pretty_assertions; + +#[cfg(test)] +extern crate serde_yaml; diff --git a/dpar/src/models/tensorflow/model.rs b/dpar/src/models/tensorflow/model.rs index a012a26..dc25c65 100644 --- a/dpar/src/models/tensorflow/model.rs +++ b/dpar/src/models/tensorflow/model.rs @@ -413,7 +413,7 @@ where } } - best.clone() + best.into_owned() } /// Compute transition logits from the feature representations of the diff --git a/dpar/src/numberer.rs b/dpar/src/numberer.rs index 2f4cd99..9f14f7c 100644 --- a/dpar/src/numberer.rs +++ b/dpar/src/numberer.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use std::hash::Hash; /// Numberer for categorical values, such as features or class labels. 
-#[derive(Eq, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct Numberer where T: Eq + Hash, diff --git a/dpar/src/system/mod.rs b/dpar/src/system/mod.rs index ae8ebe8..95596c6 100644 --- a/dpar/src/system/mod.rs +++ b/dpar/src/system/mod.rs @@ -11,7 +11,7 @@ mod parser_state; pub use self::parser_state::ParserState; mod trans_system; -pub use self::trans_system::{Transition, TransitionSystem}; +pub use self::trans_system::{Transition, TransitionLookup, TransitionSystem}; pub fn sentence_to_dependencies(sentence: &Sentence) -> Result { let mut dependencies = HashSet::new(); diff --git a/dpar/src/system/trans_system.rs b/dpar/src/system/trans_system.rs index 3fc4e32..43d92af 100644 --- a/dpar/src/system/trans_system.rs +++ b/dpar/src/system/trans_system.rs @@ -1,8 +1,10 @@ +use std::borrow::Cow; +use std::cell::RefCell; use std::fmt::Debug; use std::hash::Hash; use serde::de::DeserializeOwned; -use serde::Serialize; +use serde::{Serialize, Serializer}; use guide::Guide; use numberer::Numberer; @@ -14,8 +16,7 @@ pub trait TransitionSystem { fn is_terminal(state: &ParserState) -> bool; fn oracle(gold_dependencies: &DependencySet) -> Self::Oracle; - fn transitions(&self) -> &Numberer; - fn transitions_mut(&mut self) -> &mut Numberer; + fn transitions(&self) -> &TransitionLookup; } pub trait Transition: Clone + Debug + Eq + Hash + Serialize + DeserializeOwned { @@ -24,3 +25,187 @@ pub trait Transition: Clone + Debug + Eq + Hash + Serialize + DeserializeOwned { fn is_possible(&self, state: &ParserState) -> bool; fn apply(&self, state: &mut ParserState); } + +/// Transition lookup table. +/// +/// Instances of this type provide a transition lookup table. When a fresh +/// table is created, a transition lookup adds a transition to the table. +/// If the table is frozen through serialization or the `freeze` method, +/// the table becomes immutable. 
Lookups of transitions that are not in the +/// table will result in a special index (`0`). +#[derive(Debug, Deserialize, Eq)] +pub enum TransitionLookup +where + T: Eq + Hash, +{ + Fresh(RefCell>), + Frozen(Numberer), +} + +impl PartialEq for TransitionLookup +where + T: Eq + Hash, +{ + fn eq(&self, rhs: &TransitionLookup) -> bool { + use self::TransitionLookup::*; + + // Two TransitionLookups are equal if their numberers + // are equal. + match (self, rhs) { + (Frozen(ln), Frozen(rn)) => ln == rn, + (Fresh(lrc), Fresh(rrc)) => lrc == rrc, + (Frozen(ln), Fresh(rrc)) => ln == &*rrc.borrow(), + (Fresh(lrc), Frozen(rn)) => &*lrc.borrow() == rn, + } + } +} + +impl Serialize for TransitionLookup +where + T: Eq + Hash + Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use self::TransitionLookup::*; + + match self { + Fresh(refcell) => serializer.serialize_newtype_variant( + "TransitionLookup", + 1, + "Frozen", + &*refcell.borrow(), + ), + Frozen(ref numberer) => { + serializer.serialize_newtype_variant("TransitionLookup", 1, "Frozen", numberer) + } + } + } +} + +impl TransitionLookup +where + T: Clone + Eq + Hash, +{ + /// Freeze a transition table. + pub fn freeze(self) -> Self { + use self::TransitionLookup::*; + + match self { + Fresh(cell) => Frozen(cell.into_inner()), + frozen => frozen, + } + } + + /// Length of the transition table. + pub fn len(&self) -> usize { + use self::TransitionLookup::*; + + match self { + Fresh(cell) => cell.borrow().len() + 1, + Frozen(numberer) => numberer.len() + 1, + } + } + + /// Look up a transition. + /// + /// If the transition is not in the table, it is added for a + /// fresh table. Frozen tables will return the special identifier + /// `0` in such cases. 
+ pub fn lookup(&self, t: T) -> usize { + use self::TransitionLookup::*; + + match self { + Fresh(cell) => cell.borrow_mut().add(t), + Frozen(numberer) => numberer.number(&t).unwrap_or(0), + } + } + + /// Get the transition corresponding to an identifier. + /// + /// Fresh tables return copies of transitions, frozen tables references + /// to transitions. `None` is returned when the identifier is unknown + /// or the special identifier `0`. + pub fn value(&self, id: usize) -> Option> { + use self::TransitionLookup::*; + + if id == 0 { + return None; + } + + match self { + Fresh(cell) => cell.borrow().value(id).cloned().map(Cow::Owned), + Frozen(numberer) => numberer.value(id).map(Cow::Borrowed), + } + } +} + +impl Default for TransitionLookup +where + T: Clone + Eq + Hash, +{ + fn default() -> Self { + TransitionLookup::Fresh(RefCell::new(Numberer::new(1))) + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Cow; + + use systems::arc_standard::ArcStandardTransition; + + use super::TransitionLookup; + + #[test] + pub fn transition_lookup() { + use self::ArcStandardTransition::*; + + let fresh = TransitionLookup::default(); + assert_eq!(fresh.lookup(Shift), 1); + assert_eq!(fresh.lookup(LeftArc("foo".to_owned())), 2); + assert_eq!(fresh.lookup(RightArc("bar".to_owned())), 3); + + let frozen = fresh.freeze(); + + assert_eq!(frozen.len(), 4); + + assert_eq!(frozen.lookup(Shift), 1); + assert_eq!(frozen.lookup(LeftArc("foo".to_owned())), 2); + assert_eq!(frozen.lookup(RightArc("bar".to_owned())), 3); + assert_eq!(frozen.lookup(LeftArc("baz".to_owned())), 0); + + assert_eq!(frozen.value(1), Some(Cow::Owned(Shift))); + assert_eq!(frozen.value(2), Some(Cow::Owned(LeftArc("foo".to_owned())))); + assert_eq!( + frozen.value(3), + Some(Cow::Owned(RightArc("bar".to_owned()))) + ); + assert_eq!(frozen.value(0), None); + } + + #[test] + pub fn transition_lookup_serialization_roundtrip() { + use self::ArcStandardTransition::*; + + let fresh = TransitionLookup::default(); + 
assert_eq!(fresh.lookup(Shift), 1); + assert_eq!(fresh.lookup(LeftArc("foo".to_owned())), 2); + assert_eq!(fresh.lookup(RightArc("bar".to_owned())), 3); + + let serialized = + ::serde_yaml::to_string(&fresh).expect("Serialization of transition lookup failed"); + + let frozen: TransitionLookup = ::serde_yaml::from_str(&serialized) + .expect("Deserialization of transition lookup failed"); + + // Check that serialization freezes the lookup table. + if let TransitionLookup::Fresh(_) = frozen { + panic!("Deserialized transition lookup was fresh, should be frozen."); + }; + + // Check that serialization/deserialization roundtrip preserved data. + assert_eq!(fresh, frozen); + } +} diff --git a/dpar/src/systems/arc_eager.rs b/dpar/src/systems/arc_eager.rs index 5250ac7..4abd440 100644 --- a/dpar/src/systems/arc_eager.rs +++ b/dpar/src/systems/arc_eager.rs @@ -1,19 +1,20 @@ use std::collections::HashMap; use guide::Guide; -use numberer::Numberer; -use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem}; +use system::{ + Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem, +}; use systems::util::dep_head_mapping; #[derive(Eq, PartialEq, Serialize, Deserialize)] pub struct ArcEagerSystem { - transitions: Numberer, + transitions: TransitionLookup, } impl ArcEagerSystem { pub fn new() -> Self { ArcEagerSystem { - transitions: Numberer::new(0), + transitions: TransitionLookup::default(), } } } @@ -39,13 +40,9 @@ impl TransitionSystem for ArcEagerSystem { ArcEagerOracle::new(gold_dependencies) } - fn transitions(&self) -> &Numberer { + fn transitions(&self) -> &TransitionLookup { &self.transitions } - - fn transitions_mut(&mut self) -> &mut Numberer { - &mut self.transitions - } } #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] diff --git a/dpar/src/systems/arc_hybrid.rs b/dpar/src/systems/arc_hybrid.rs index 13d86fc..8e2dc54 100644 --- a/dpar/src/systems/arc_hybrid.rs +++ 
b/dpar/src/systems/arc_hybrid.rs @@ -1,8 +1,9 @@ use std::collections::HashMap; use guide::Guide; -use numberer::Numberer; -use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem}; +use system::{ + Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem, +}; use systems::util::dep_head_mapping; @@ -12,13 +13,13 @@ use systems::util::dep_head_mapping; /// Dependency Parsers, 2011. #[derive(Eq, PartialEq, Serialize, Deserialize)] pub struct ArcHybridSystem { - transitions: Numberer, + transitions: TransitionLookup, } impl ArcHybridSystem { pub fn new() -> Self { ArcHybridSystem { - transitions: Numberer::new(0), + transitions: TransitionLookup::default(), } } } @@ -41,13 +42,9 @@ impl TransitionSystem for ArcHybridSystem { ArcHybridOracle::new(gold_dependencies) } - fn transitions(&self) -> &Numberer { + fn transitions(&self) -> &TransitionLookup { &self.transitions } - - fn transitions_mut(&mut self) -> &mut Numberer { - &mut self.transitions - } } /// Stack-projective transition. 
diff --git a/dpar/src/systems/arc_standard.rs b/dpar/src/systems/arc_standard.rs index e94c07f..8406169 100644 --- a/dpar/src/systems/arc_standard.rs +++ b/dpar/src/systems/arc_standard.rs @@ -1,20 +1,21 @@ use std::collections::HashMap; use guide::Guide; -use numberer::Numberer; -use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem}; +use system::{ + Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem, +}; use systems::util::dep_head_mapping; #[derive(Eq, PartialEq, Serialize, Deserialize)] pub struct ArcStandardSystem { - transitions: Numberer, + transitions: TransitionLookup, } impl ArcStandardSystem { pub fn new() -> Self { ArcStandardSystem { - transitions: Numberer::new(0), + transitions: TransitionLookup::default(), } } } @@ -37,13 +38,9 @@ impl TransitionSystem for ArcStandardSystem { ArcStandardOracle::new(gold_dependencies) } - fn transitions(&self) -> &Numberer { + fn transitions(&self) -> &TransitionLookup { &self.transitions } - - fn transitions_mut(&mut self) -> &mut Numberer { - &mut self.transitions - } } /// Arc-standard transition. 
diff --git a/dpar/src/systems/stack_projective.rs b/dpar/src/systems/stack_projective.rs index 1ac2c48..045dc36 100644 --- a/dpar/src/systems/stack_projective.rs +++ b/dpar/src/systems/stack_projective.rs @@ -1,20 +1,21 @@ use std::collections::HashMap; use guide::Guide; -use numberer::Numberer; -use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem}; +use system::{ + Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem, +}; use systems::util::dep_head_mapping; #[derive(Eq, PartialEq, Serialize, Deserialize)] pub struct StackProjectiveSystem { - transitions: Numberer, + transitions: TransitionLookup, } impl StackProjectiveSystem { pub fn new() -> Self { StackProjectiveSystem { - transitions: Numberer::new(0), + transitions: TransitionLookup::default(), } } } @@ -37,13 +38,9 @@ impl TransitionSystem for StackProjectiveSystem { StackProjectiveOracle::new(gold_dependencies) } - fn transitions(&self) -> &Numberer { + fn transitions(&self) -> &TransitionLookup { &self.transitions } - - fn transitions_mut(&mut self) -> &mut Numberer { - &mut self.transitions - } } /// Stack-projective transition. diff --git a/dpar/src/systems/stack_swap.rs b/dpar/src/systems/stack_swap.rs index 46a6f63..1b913b6 100644 --- a/dpar/src/systems/stack_swap.rs +++ b/dpar/src/systems/stack_swap.rs @@ -5,8 +5,9 @@ use petgraph::visit::Dfs; use petgraph::{Directed, Graph}; use guide::Guide; -use numberer::Numberer; -use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem}; +use system::{ + Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem, +}; use systems::util::dep_head_mapping; /// The stack-swap transition system for non-projective parsing. 
@@ -19,13 +20,13 @@ use systems::util::dep_head_mapping; /// Joakim Nivre, Non-projective dependency parsing in expected linear time, 2009 #[derive(Eq, PartialEq, Serialize, Deserialize)] pub struct StackSwapSystem { - transitions: Numberer, + transitions: TransitionLookup, } impl StackSwapSystem { pub fn new() -> Self { StackSwapSystem { - transitions: Numberer::new(0), + transitions: TransitionLookup::default(), } } } @@ -48,13 +49,9 @@ impl TransitionSystem for StackSwapSystem { StackSwapOracle::new(gold_dependencies) } - fn transitions(&self) -> &Numberer { + fn transitions(&self) -> &TransitionLookup { &self.transitions } - - fn transitions_mut(&mut self) -> &mut Numberer { - &mut self.transitions - } } /// Stack-projective transition. diff --git a/dpar/src/train/hdf5.rs b/dpar/src/train/hdf5.rs index b01df6b..1e7ba50 100644 --- a/dpar/src/train/hdf5.rs +++ b/dpar/src/train/hdf5.rs @@ -48,7 +48,7 @@ where T: TransitionSystem, { fn collect(&mut self, t: &T::Transition, state: &ParserState) -> Result<()> { - let label = self.transition_system.transitions_mut().add(t.clone()); + let label = self.transition_system.transitions().lookup(t.clone()); let v = self.vectorizer.realize(state); self.writer.write(label, v) }