diff --git a/src/main/scala/epic/lexicon/SimpleLexicon.scala b/src/main/scala/epic/lexicon/SimpleLexicon.scala index e9433132..3d2f369a 100644 --- a/src/main/scala/epic/lexicon/SimpleLexicon.scala +++ b/src/main/scala/epic/lexicon/SimpleLexicon.scala @@ -6,7 +6,7 @@ import breeze.linalg._ import breeze.util.Index import scala.collection.immutable.BitSet -import scala.collection.mutable +import scala.collection.mutable.Map /** * A simple lexicon that thresholds to decide when to open up the rare word to all (open) tags @@ -17,62 +17,59 @@ import scala.collection.mutable * observed tag set. */ @SerialVersionUID(1L) -class SimpleLexicon[L, W](val labelIndex: Index[L], - wordTagCounts: Counter2[L, W, Double], - openTagThreshold: Int = 50, - closedWordThreshold: Int= 10) extends Lexicon[L, W] with Serializable { - private val wordCounts:Counter[W, Double] = sum(wordTagCounts, Axis._0) - private val labelCounts:Counter[L, Double] = sum(wordTagCounts, Axis._1) - - private val byWord: mutable.Map[W, Set[Int]] = mutable.Map.empty[W, Set[Int]] ++ wordTagCounts.keySet.groupBy(_._2).mapValues(_.map(pair => labelIndex(pair._1)).toSet) - - private val openTags: Set[Int] = { - val set = labelCounts.keysIterator.filter(l => wordTagCounts(l, ::).size > openTagThreshold).toSet.map((l:L) => labelIndex(l)) - if(set.isEmpty) BitSet.empty ++ (0 until labelIndex.size) - else set - } - - - - for( (w,v) <- wordCounts.iterator if v < closedWordThreshold) { - byWord.get(w) match { - case None => byWord(w) = openTags - case Some(set) => byWord(w) = set ++ openTags - } - } - +class SimpleLexicon[L, W]( + val labelIndex: Index[L], + wordTagCounts: Counter2[L, W, Double], + openTagThreshold: Int = 50, + closedWordThreshold: Int= 10 +) extends Lexicon[L, W] with Serializable { + private val wordCounts: Counter[W, Double] = sum(wordTagCounts, Axis._0) + private val labelCounts: Counter[L, Double] = sum(wordTagCounts, Axis._1) + private val byWord: Map[W, Set[Int]] = Map.empty[W, Set[Int]] ++ + wordTagCounts.keySet.groupBy(_._2).mapValues(_.map(pair => labelIndex(pair._1)).toSet) + + private val openTags: Set[Int] = Option( + labelCounts.keysIterator.collect { case l if wordTagCounts(l, ::).size > openTagThreshold => labelIndex(l) }.toSet + ).filter(_.nonEmpty).getOrElse( + BitSet.empty ++ (0 until labelIndex.size) + ) + + for((w,v) <- wordCounts.iterator if v < closedWordThreshold) + byWord(w) = byWord.get(w).fold(openTags)( _ ++ openTags) def allowedTags(w: W): Set[Int] = byWord.getOrElse(w, openTags) - def anchor(w: IndexedSeq[W]):Anchoring = new Anchoring { + def anchor(w: IndexedSeq[W]): Anchoring = new Anchoring { def length = w.length - val x = Array.tabulate(w.length)(pos =>byWord.getOrElse(w(pos), openTags)) + val x = Array.tabulate(w.length)(pos => byWord.getOrElse(w(pos), openTags)) def allowedTags(pos: Int): Set[Int] = x(pos) } - @throws(classOf[ObjectStreamException]) - private def writeReplace():Object = { + private def writeReplace(): Object = new SimpleLexicon.SerializedForm(labelIndex, wordTagCounts, openTagThreshold, closedWordThreshold) - } override def morePermissive: Lexicon[L, W] = new SimpleLexicon(labelIndex, wordTagCounts, openTagThreshold, 1000000) } object SimpleLexicon { + @SerialVersionUID(1L) - private class SerializedForm[L, W](labelIndex: Index[L], wordTagCounts: Counter2[L, W, Double], openTagThreshold: Int, closedWordThreshold: Int) extends Serializable { + private class SerializedForm[L, W]( + labelIndex: Index[L], + wordTagCounts: Counter2[L, W, Double], + openTagThreshold: Int, + closedWordThreshold: Int + ) extends Serializable { @throws(classOf[ObjectStreamException]) - private def readResolve():Object = { - try { - Class.forName("breeze.linalg.Counter$Impl") - new SimpleLexicon(labelIndex, wordTagCounts, openTagThreshold, closedWordThreshold) - } catch { - case ex => - ex.printStackTrace() - throw ex - - } + private def readResolve(): Object = try { + Class.forName("breeze.linalg.Counter$Impl") + new SimpleLexicon(labelIndex, wordTagCounts, openTagThreshold, closedWordThreshold) + } catch { + case ex: Throwable => + ex.printStackTrace() + throw ex } } + }