Skip to content

Commit 0896615

Browse files
committed
Initial part of generating the lucene anaylzers works
1 parent c5b6628 commit 0896615

File tree

3 files changed

+1050
-17
lines changed

3 files changed

+1050
-17
lines changed

src/main/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzer.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.sparklingpandas.sparklingml.feature
22

33
import org.apache.spark.annotation.DeveloperApi
44
import org.apache.spark.ml.UnaryTransformer
5+
import org.apache.spark.ml.util.Identifiable
56
import org.apache.spark.sql.Dataset
67
import org.apache.spark.sql.types._
78

@@ -15,6 +16,8 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
1516
@DeveloperApi
1617
trait LuceneTransformer extends UnaryTransformer[String, Array[String], LuceneTransformer] {
1718

19+
override val uid = Identifiable.randomUID(this.getClass.getName)
20+
1821
// Implement this function to construct an analyzer based on the provided settings.
1922
def buildAnalyzer(): Analyzer
2023

Lines changed: 184 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,195 @@
11
package com.sparklingpandas.sparklingml.feature
22

3+
import java.lang.reflect.Modifier
4+
import java.io.PrintWriter
5+
6+
import scala.collection.JavaConverters._
7+
import scala.collection.mutable.StringBuilder
8+
9+
import org.reflections.Reflections
10+
11+
312
import org.apache.spark.annotation.DeveloperApi
413

514
import org.apache.lucene.analysis.Analyzer
15+
import org.apache.lucene.analysis.CharArraySet
616

7-
@DeveloperApi
8-
object LuceneAnalyzerGenerators {
9-
def generate() = {
10-
import org.reflections.Reflections
11-
import collection.JavaConverters._
12-
import scala.reflect.runtime.universe._
17+
18+
/**
19+
* Code generator for LuceneAnalyzers (LuceneAnalyzers.scala). Run with
20+
* {{{
21+
* build/sbt "runMain com.sparklingpandas.sparklingml.feature.LuceneAnalyzerGenerators"
22+
* }}}
23+
*/
24+
private[sparklingpandas] object LuceneAnalyzerGenerators {
25+
26+
def main(args: Array[String]): Unit = {
27+
val (testCode, transformerCode) = generate()
28+
val testCodeFile = "src/test/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzersTests.scala"
29+
val transformerCodeFile = "src/main/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzers.scala"
30+
val header =
31+
"""/*
32+
| * Licensed to the Apache Software Foundation (ASF) under one or more
33+
| * contributor license agreements. See the NOTICE file distributed with
34+
| * this work for additional information regarding copyright ownership.
35+
| * The ASF licenses this file to You under the Apache License, Version 2.0
36+
| * (the "License"); you may not use this file except in compliance with
37+
| * the License. You may obtain a copy of the License at
38+
| *
39+
| * http://www.apache.org/licenses/LICENSE-2.0
40+
| *
41+
| * Unless required by applicable law or agreed to in writing, software
42+
| * distributed under the License is distributed on an "AS IS" BASIS,
43+
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44+
| * See the License for the specific language governing permissions and
45+
| * limitations under the License.
46+
| */
47+
|
48+
|package com.sparklingpandas.sparklingml.feature
49+
|
50+
|import org.apache.spark.ml.param._
51+
|
52+
|import org.apache.lucene.analysis.Analyzer
53+
|
54+
|import com.sparklingpandas.sparklingml.param._
55+
|
56+
|// DO NOT MODIFY THIS FILE! It was auto generated by LuceneAnalyzerGenerators
57+
|
58+
""".stripMargin('|')
59+
List((testCode, testCodeFile), (transformerCode, transformerCodeFile)).foreach {
60+
case (code: String, file: String) =>
61+
val writer = new PrintWriter(file)
62+
writer.write(header)
63+
writer.write(code)
64+
writer.close()
65+
}
66+
}
67+
68+
def generate(): (String, String) = {
1369
val reflections = new Reflections("org.apache.lucene");
14-
val analyzers = reflections.getSubTypesOf(classOf[org.apache.lucene.analysis.Analyzer])
70+
val generalAnalyzers = reflections.getSubTypesOf(classOf[org.apache.lucene.analysis.Analyzer]).asScala
71+
val concreteAnalyzers = generalAnalyzers.filter(cls => !Modifier.isAbstract(cls.getModifiers))
72+
// A bit of a hack but strip out the factories and such
73+
val relevantAnalyzers = concreteAnalyzers.filter(cls =>
74+
!(cls.toString.contains("$") || cls.toString.contains("Factory")))
75+
val generated = relevantAnalyzers.map{ cls =>
76+
generateForClass(cls)
77+
}
78+
val testCode = new StringBuilder()
79+
val transformerCode = new StringBuilder()
80+
generated.foreach{case (test, transform) =>
81+
testCode ++= test
82+
transformerCode ++= transform
83+
}
84+
(testCode.toString, transformerCode.toString)
85+
}
86+
87+
def generateForClass(cls: Class[_]): (String, String) = {
88+
import scala.reflect.runtime.universe._
1589
val rm = scala.reflect.runtime.currentMirror
16-
val generated = analyzers.asScala.map{ cls =>
17-
val constructors = rm.classSymbol(cls).toType.members.collect{
18-
case m: MethodSymbol if m.isConstructor && m.isPublic => m }
19-
// Since this isn't built with -parameters by default :(
20-
// we'd need a local version built with it to auto generate
21-
// the code here with the right parameters.
22-
// https://docs.oracle.com/javase/tutorial/reflect/member/methodparameterreflection.html
23-
// For now we could dump the class names and go from their
24-
// or we could play a game of pin the field on the constructor.
25-
// local build sounds like the best plan, lets do that l8r
90+
91+
val clsSymbol = rm.classSymbol(cls)
92+
val clsType = clsSymbol.toType
93+
val clsFullName = clsSymbol.fullName
94+
val clsShortName = clsSymbol.name.toString
95+
val constructors = clsType.members.collect{
96+
case m: MethodSymbol if m.isConstructor && m.isPublic => m }
97+
// Once we have the debug version constructorParametersLists should be useful
98+
val constructorParametersLists = constructors.map(_.paramLists).toList
99+
val constructorParametersSizes = constructorParametersLists.map(_(0).size)
100+
val javaReflectionConstructors = cls.getConstructors().toList
101+
val publicJavaReflectionConstructors = javaReflectionConstructors.filter(cls => Modifier.isPublic(cls.getModifiers()))
102+
val constructorParameterTypes = publicJavaReflectionConstructors.map(_.getParameterTypes())
103+
// We do this in Java as well since some of the scala reflection magic returns private
104+
// constructors even though its filtered for public. See CustomAnalyzer for an example.
105+
val javaConstructorParametersSizes = constructorParameterTypes.map(_.size)
106+
// Since this isn't built with -parameters by default :(
107+
// we'd need a local version built with it to auto generate
108+
// the code here with the right parameters.
109+
// https://docs.oracle.com/javase/tutorial/reflect/member/methodparameterreflection.html
110+
// For now we could dump the class names and go from their
111+
// or we could play a game of pin the field on the constructor.
112+
// local build sounds like the best plan, lets do that l8r
113+
114+
// Special case for handling stopword analyzers
115+
val baseClasses = clsType.baseClasses
116+
// Normally we'd do a checks with <:< but the Lucene types have multiple
117+
// StopwordAnalyzerBase's that don't inherit from eachother.
118+
val isStopWordAnalyzer = baseClasses.exists(_.asClass.fullName.contains("Stopword"))
119+
120+
val charsetConstructors = constructorParameterTypes.filter(! _.exists(_ != classOf[CharArraySet]))
121+
val charsetConstructorSizes = charsetConstructors.map(_.size)
122+
123+
// If it is a stop word analyzer and has a constructor with two charsets then it takes
124+
// the stopwords as a parameter.
125+
if (isStopWordAnalyzer && charsetConstructorSizes.contains(1)) {
126+
// If there are more parameters
127+
val includeWarning = constructorParametersSizes.exists(_ > 1)
128+
val warning = if (includeWarning) {
129+
s"""
130+
| * There are additional parameters which can not yet be controlled through this API
131+
| * See https://github.com/sparklingpandas/sparklingml/issues/3
132+
""".stripMargin('|')
133+
} else {
134+
""
135+
}
136+
val testCode =
137+
s"""
138+
|/**
139+
| * A super simple test
140+
| */
141+
| class ${clsShortName}LuceneTest extends LuceneStopwordTransformerTest {}
142+
""".stripMargin('|')
143+
val code =
144+
s"""
145+
|/**
146+
| * A basic Transformer based on ${clsFullName}. Supports configuring stopwords.${warning}
147+
| */
148+
|
149+
|class ${clsShortName}Lucene extends LuceneTransformer with HasStopwords with HasStopwordCase {
150+
| def buildAnalyzer(): Analyzer = {
151+
| // In the future we can use getDefaultStopWords here to allow people to control
152+
| // the snowball stemmer distinctly from the stopwords.
153+
| // but that is a TODO for later.
154+
| if (isSet(stopwords)) {
155+
| new ${clsFullName}(
156+
| LuceneHelpers.wordstoCharArraySet($$(stopwords), !$$(stopwordCase)))
157+
| } else {
158+
| new ${clsFullName}()
159+
| }
160+
| }
161+
|}
162+
""".stripMargin('|')
163+
(testCode, code)
164+
} else if (constructorParametersSizes.contains(0) &&
165+
javaConstructorParametersSizes.contains(0)) {
166+
val testCode =
167+
s"""
168+
|/**
169+
| * A super simple test
170+
| */
171+
| class ${clsShortName}LuceneTest extends LuceneTransformerTest {}
172+
""".stripMargin('|')
173+
val code =
174+
s"""
175+
|/**
176+
| * A basic Transformer based on ${clsFullName} - does not support
177+
| * any configuration properties.
178+
| * See https://github.com/sparklingpandas/sparklingml/issues/3 & LuceneAnalyzerGenerators
179+
| * for details.
180+
| */
181+
|
182+
|class ${clsShortName}Lucene extends LuceneTransformer {
183+
| def buildAnalyzer(): Analyzer = {
184+
| new ${clsFullName}()
185+
| }
186+
|}
187+
""".stripMargin('|')
188+
(testCode, code)
189+
} else {
190+
("", s"""
191+
|/// There is no default zero arg constructor for ${clsFullName}
192+
""".stripMargin('|'))
26193
}
27194
}
28195
}

0 commit comments

Comments
 (0)