Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpar/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ approx = "0.3"
flate2 = "1"
lazy_static = "0.2"
pretty_assertions = "0.5"
serde_yaml = "0.8"
4 changes: 2 additions & 2 deletions dpar/src/guide/tensorflow/guide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ where
{
// Invariant: we should have as many predictions as transitions.
let n_predictions = logits.as_ref().len();
let n_transitions = self.system.transitions().len() + self.system.transitions().start_at();
let n_transitions = self.system.transitions().len();
assert_eq!(
n_predictions, n_transitions,
"Number of transitions ({}) and predictions ({}) are inequal.",
Expand All @@ -189,7 +189,7 @@ where
}
}

best.clone()
best.into_owned()
}
}

Expand Down
3 changes: 3 additions & 0 deletions dpar/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@ extern crate maplit;
#[cfg(test)]
#[macro_use]
extern crate pretty_assertions;

#[cfg(test)]
extern crate serde_yaml;
2 changes: 1 addition & 1 deletion dpar/src/models/tensorflow/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ where
}
}

best.clone()
best.into_owned()
}

/// Compute transition logits from the feature representations of the
Expand Down
2 changes: 1 addition & 1 deletion dpar/src/numberer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::HashMap;
use std::hash::Hash;

/// Numberer for categorical values, such as features or class labels.
#[derive(Eq, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Numberer<T>
where
T: Eq + Hash,
Expand Down
2 changes: 1 addition & 1 deletion dpar/src/system/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod parser_state;
pub use self::parser_state::ParserState;

mod trans_system;
pub use self::trans_system::{Transition, TransitionSystem};
pub use self::trans_system::{Transition, TransitionLookup, TransitionSystem};

pub fn sentence_to_dependencies(sentence: &Sentence) -> Result<DependencySet> {
let mut dependencies = HashSet::new();
Expand Down
191 changes: 188 additions & 3 deletions dpar/src/system/trans_system.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::borrow::Cow;
use std::cell::RefCell;
use std::fmt::Debug;
use std::hash::Hash;

use serde::de::DeserializeOwned;
use serde::Serialize;
use serde::{Serialize, Serializer};

use guide::Guide;
use numberer::Numberer;
Expand All @@ -14,8 +16,7 @@ pub trait TransitionSystem {

fn is_terminal(state: &ParserState) -> bool;
fn oracle(gold_dependencies: &DependencySet) -> Self::Oracle;
fn transitions(&self) -> &Numberer<Self::Transition>;
fn transitions_mut(&mut self) -> &mut Numberer<Self::Transition>;
fn transitions(&self) -> &TransitionLookup<Self::Transition>;
}

pub trait Transition: Clone + Debug + Eq + Hash + Serialize + DeserializeOwned {
Expand All @@ -24,3 +25,187 @@ pub trait Transition: Clone + Debug + Eq + Hash + Serialize + DeserializeOwned {
fn is_possible(&self, state: &ParserState) -> bool;
fn apply(&self, state: &mut ParserState);
}

/// Transition lookup table.
///
/// Instances of this type provide a transition lookup table. When a fresh
/// table is created, a transition lookup adds a transition to the table.
/// If the table is frozen through serialization or the `freeze` method,
/// the table becomes immutable. Lookups of transitions that are not in the
/// table will result in a special index (`0`).
#[derive(Debug, Deserialize, Eq)]
pub enum TransitionLookup<T>
where
T: Eq + Hash,
{
Fresh(RefCell<Numberer<T>>),
Frozen(Numberer<T>),
}

impl<T> PartialEq for TransitionLookup<T>
where
T: Eq + Hash,
{
fn eq(&self, rhs: &TransitionLookup<T>) -> bool {
use self::TransitionLookup::*;

// Two TransitionLookups are equal if their numberers
// are equal.
match (self, rhs) {
(Frozen(ln), Frozen(rn)) => ln == rn,
(Fresh(lrc), Fresh(rrc)) => lrc == rrc,
(Frozen(ln), Fresh(rrc)) => ln == &*rrc.borrow(),
(Fresh(lrc), Frozen(rn)) => &*lrc.borrow() == rn,
}
}
}

impl<T> Serialize for TransitionLookup<T>
where
T: Eq + Hash + Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use self::TransitionLookup::*;

match self {
Fresh(refcell) => serializer.serialize_newtype_variant(
"TransitionLookup",
1,
"Frozen",
&*refcell.borrow(),
),
Frozen(ref numberer) => {
serializer.serialize_newtype_variant("TransitionLookup", 1, "Frozen", numberer)
}
}
}
}

impl<T> TransitionLookup<T>
where
T: Clone + Eq + Hash,
{
/// Freeze a transition table.
pub fn freeze(self) -> Self {
use self::TransitionLookup::*;

match self {
Fresh(cell) => Frozen(cell.into_inner()),
frozen => frozen,
}
}

/// Length of the transition table.
pub fn len(&self) -> usize {
use self::TransitionLookup::*;

match self {
Fresh(cell) => cell.borrow().len() + 1,
Frozen(numberer) => numberer.len() + 1,
}
}

/// Look up a transition.
///
/// If the the transition is not in the table, it is added for a
/// fresh table. Frozen tables will return the special identifier
/// `0` in such cases.
pub fn lookup(&self, t: T) -> usize {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ownership of t is only required in case of a fresh lookup table where t is added to the numberer. Again, when dealing with Transition Enums probably not really a problem. If we should generalize the struct later on to e.g. number strings we might consider adding to_owned or clone to the trait bounds so that we can pass in a &T and only convert it when necessary.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we should generalize the struct later on to e.g. number strings we might consider adding to_owned or clone to the trait bounds so that we can pass in a &T and only convert it when necessary.

I agree that this would be nice. IIRC I decided use moves in Numberer::add to avoid two lookups when the value is absent and HashMaps entry method requires an owned value. But since in transition and feature lookups, the lookups succeed in the vast majority of cases, doing an get (with a reference) + insert when absent would be more economical. There was an RFC to make entry work with borrowed values, but it seems to be stranded:

rust-lang/rfcs#1769

I'll add this as a todo item to the 'one lookup to rule them all'-issue.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also consider moving the numberer / lookup into a separate crate.

For the lemmatizer, I had to implement the same functionality and dealt with similar issues. If there'd be a crate we might spare some people a lot of unnecessary worries and work.

use self::TransitionLookup::*;

match self {
Fresh(cell) => cell.borrow_mut().add(t),
Frozen(numberer) => numberer.number(&t).unwrap_or(0),
}
}

/// Get the transition corresponding to an identifier.
///
/// Fresh tables return copies of transitions, frozen tables references
/// to transitions. `None` is returned when the identifier is unknown
/// or the special identifier `0`.
pub fn value(&self, id: usize) -> Option<Cow<T>> {
use self::TransitionLookup::*;

if id == 0 {
return None;
}

match self {
Fresh(cell) => cell.borrow().value(id).cloned().map(Cow::Owned),
Frozen(numberer) => numberer.value(id).map(Cow::Borrowed),
}
}
}

impl<T> Default for TransitionLookup<T>
where
T: Clone + Eq + Hash,
{
fn default() -> Self {
TransitionLookup::Fresh(RefCell::new(Numberer::new(1)))
}
}

#[cfg(test)]
mod tests {
use std::borrow::Cow;

use systems::arc_standard::ArcStandardTransition;

use super::TransitionLookup;

#[test]
pub fn transition_lookup() {
use self::ArcStandardTransition::*;

let fresh = TransitionLookup::default();
assert_eq!(fresh.lookup(Shift), 1);
assert_eq!(fresh.lookup(LeftArc("foo".to_owned())), 2);
assert_eq!(fresh.lookup(RightArc("bar".to_owned())), 3);

let frozen = fresh.freeze();

assert_eq!(frozen.len(), 4);

assert_eq!(frozen.lookup(Shift), 1);
assert_eq!(frozen.lookup(LeftArc("foo".to_owned())), 2);
assert_eq!(frozen.lookup(RightArc("bar".to_owned())), 3);
assert_eq!(frozen.lookup(LeftArc("baz".to_owned())), 0);

assert_eq!(frozen.value(1), Some(Cow::Owned(Shift)));
assert_eq!(frozen.value(2), Some(Cow::Owned(LeftArc("foo".to_owned()))));
assert_eq!(
frozen.value(3),
Some(Cow::Owned(RightArc("bar".to_owned())))
);
assert_eq!(frozen.value(0), None);
}

#[test]
pub fn transition_lookup_serialization_roundtrip() {
use self::ArcStandardTransition::*;

let fresh = TransitionLookup::default();
assert_eq!(fresh.lookup(Shift), 1);
assert_eq!(fresh.lookup(LeftArc("foo".to_owned())), 2);
assert_eq!(fresh.lookup(RightArc("bar".to_owned())), 3);

let serialized =
::serde_yaml::to_string(&fresh).expect("Serialization of transition lookup failed");

let frozen: TransitionLookup<ArcStandardTransition> = ::serde_yaml::from_str(&serialized)
.expect("Deserialization of transition lookup failed");

// Check that serialization freezes the lookup table.
if let TransitionLookup::Fresh(_) = frozen {
panic!("Deserialized transition lookup was fresh, should be frozen.");
};

// Check that serialization/deserialization roundtrip preserved data.
assert_eq!(fresh, frozen);
}
}
15 changes: 6 additions & 9 deletions dpar/src/systems/arc_eager.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use std::collections::HashMap;

use guide::Guide;
use numberer::Numberer;
use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem};
use system::{
Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem,
};
use systems::util::dep_head_mapping;

#[derive(Eq, PartialEq, Serialize, Deserialize)]
pub struct ArcEagerSystem {
transitions: Numberer<ArcEagerTransition>,
transitions: TransitionLookup<ArcEagerTransition>,
}

impl ArcEagerSystem {
pub fn new() -> Self {
ArcEagerSystem {
transitions: Numberer::new(0),
transitions: TransitionLookup::default(),
}
}
}
Expand All @@ -39,13 +40,9 @@ impl TransitionSystem for ArcEagerSystem {
ArcEagerOracle::new(gold_dependencies)
}

fn transitions(&self) -> &Numberer<Self::Transition> {
fn transitions(&self) -> &TransitionLookup<Self::Transition> {
&self.transitions
}

fn transitions_mut(&mut self) -> &mut Numberer<Self::Transition> {
&mut self.transitions
}
}

#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
Expand Down
15 changes: 6 additions & 9 deletions dpar/src/systems/arc_hybrid.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::HashMap;

use guide::Guide;
use numberer::Numberer;
use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem};
use system::{
Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem,
};

use systems::util::dep_head_mapping;

Expand All @@ -12,13 +13,13 @@ use systems::util::dep_head_mapping;
/// Dependency Parsers, 2011.
#[derive(Eq, PartialEq, Serialize, Deserialize)]
pub struct ArcHybridSystem {
transitions: Numberer<ArcHybridTransition>,
transitions: TransitionLookup<ArcHybridTransition>,
}

impl ArcHybridSystem {
pub fn new() -> Self {
ArcHybridSystem {
transitions: Numberer::new(0),
transitions: TransitionLookup::default(),
}
}
}
Expand All @@ -41,13 +42,9 @@ impl TransitionSystem for ArcHybridSystem {
ArcHybridOracle::new(gold_dependencies)
}

fn transitions(&self) -> &Numberer<Self::Transition> {
fn transitions(&self) -> &TransitionLookup<Self::Transition> {
&self.transitions
}

fn transitions_mut(&mut self) -> &mut Numberer<Self::Transition> {
&mut self.transitions
}
}

/// Stack-projective transition.
Expand Down
15 changes: 6 additions & 9 deletions dpar/src/systems/arc_standard.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use std::collections::HashMap;

use guide::Guide;
use numberer::Numberer;
use system::{Dependency, DependencySet, ParserState, Transition, TransitionSystem};
use system::{
Dependency, DependencySet, ParserState, Transition, TransitionLookup, TransitionSystem,
};

use systems::util::dep_head_mapping;

#[derive(Eq, PartialEq, Serialize, Deserialize)]
pub struct ArcStandardSystem {
transitions: Numberer<ArcStandardTransition>,
transitions: TransitionLookup<ArcStandardTransition>,
}

impl ArcStandardSystem {
pub fn new() -> Self {
ArcStandardSystem {
transitions: Numberer::new(0),
transitions: TransitionLookup::default(),
}
}
}
Expand All @@ -37,13 +38,9 @@ impl TransitionSystem for ArcStandardSystem {
ArcStandardOracle::new(gold_dependencies)
}

fn transitions(&self) -> &Numberer<Self::Transition> {
fn transitions(&self) -> &TransitionLookup<Self::Transition> {
&self.transitions
}

fn transitions_mut(&mut self) -> &mut Numberer<Self::Transition> {
&mut self.transitions
}
}

/// Arc-standard transition.
Expand Down
Loading