|
1 | 1 | package com.sparklingpandas.sparklingml.feature |
2 | 2 |
|
| 3 | +import java.lang.reflect.Modifier |
| 4 | +import java.io.PrintWriter |
| 5 | + |
| 6 | +import scala.collection.JavaConverters._ |
| 7 | +import scala.collection.mutable.StringBuilder |
| 8 | + |
| 9 | +import org.reflections.Reflections |
| 10 | + |
| 11 | + |
3 | 12 | import org.apache.spark.annotation.DeveloperApi |
4 | 13 |
|
5 | 14 | import org.apache.lucene.analysis.Analyzer |
| 15 | +import org.apache.lucene.analysis.CharArraySet |
6 | 16 |
|
7 | | -@DeveloperApi |
8 | | -object LuceneAnalyzerGenerators { |
9 | | - def generate() = { |
10 | | - import org.reflections.Reflections |
11 | | - import collection.JavaConverters._ |
12 | | - import scala.reflect.runtime.universe._ |
| 17 | + |
/**
 * Code generator for LuceneAnalyzers (LuceneAnalyzers.scala). Run with
 * {{{
 * build/sbt "runMain com.sparklingpandas.sparklingml.feature.LuceneAnalyzerGenerators"
 * }}}
 *
 * Scans the `org.apache.lucene` package for concrete `Analyzer` subclasses and emits,
 * for each one, a transformer wrapper class plus a matching test class.
 */
private[sparklingpandas] object LuceneAnalyzerGenerators {

  /**
   * Regenerates the transformer and test sources, overwriting the target files.
   * The target paths are relative to the working directory.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    val (testCode, transformerCode) = generate()
    val testCodeFile =
      "src/test/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzersTests.scala"
    val transformerCodeFile =
      "src/main/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzers.scala"
    val header =
      """/*
        | * Licensed to the Apache Software Foundation (ASF) under one or more
        | * contributor license agreements.  See the NOTICE file distributed with
        | * this work for additional information regarding copyright ownership.
        | * The ASF licenses this file to You under the Apache License, Version 2.0
        | * (the "License"); you may not use this file except in compliance with
        | * the License.  You may obtain a copy of the License at
        | *
        | *    http://www.apache.org/licenses/LICENSE-2.0
        | *
        | * Unless required by applicable law or agreed to in writing, software
        | * distributed under the License is distributed on an "AS IS" BASIS,
        | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        | * See the License for the specific language governing permissions and
        | * limitations under the License.
        | */
        |
        |package com.sparklingpandas.sparklingml.feature
        |
        |import org.apache.spark.ml.param._
        |
        |import org.apache.lucene.analysis.Analyzer
        |
        |import com.sparklingpandas.sparklingml.param._
        |
        |// DO NOT MODIFY THIS FILE! It was auto generated by LuceneAnalyzerGenerators
        |
        """.stripMargin('|')
    List((testCode, testCodeFile), (transformerCode, transformerCodeFile)).foreach {
      case (code: String, file: String) =>
        // Explicit charset so the generated sources do not depend on the platform default.
        val writer = new PrintWriter(file, "UTF-8")
        try {
          writer.write(header)
          writer.write(code)
          // PrintWriter records I/O failures instead of throwing, so surface them here.
          if (writer.checkError()) {
            throw new java.io.IOException(s"Failed to write generated code to $file")
          }
        } finally {
          // Always release the file handle, even when a write fails.
          writer.close()
        }
    }
  }

  /**
   * Generates code for every concrete Lucene analyzer found under `org.apache.lucene`.
   *
   * @return a pair of (test code, transformer code), each the concatenation of the
   *         per-class snippets in class-name order.
   */
  def generate(): (String, String) = {
    val reflections = new Reflections("org.apache.lucene")
    val generalAnalyzers =
      reflections.getSubTypesOf(classOf[org.apache.lucene.analysis.Analyzer]).asScala
    val concreteAnalyzers =
      generalAnalyzers.toSeq.filter(cls => !Modifier.isAbstract(cls.getModifiers))
    // A bit of a hack but strip out the factories and such
    val relevantAnalyzers = concreteAnalyzers.filter { cls =>
      val name = cls.getName
      !(name.contains("$") || name.contains("Factory"))
    }
    // The reflection scan yields a hash set; sort so regenerating produces a stable file.
    val generated = relevantAnalyzers.sortBy(_.getName).map(cls => generateForClass(cls))
    val testCode = new StringBuilder()
    val transformerCode = new StringBuilder()
    generated.foreach { case (test, transform) =>
      testCode ++= test
      transformerCode ++= transform
    }
    (testCode.toString, transformerCode.toString)
  }

  /**
   * Generates the test and transformer snippets for a single analyzer class.
   *
   * Three outcomes are possible:
   *  - a stopword-aware transformer, when the class has a "Stopword" base class, a public
   *    constructor taking a single `CharArraySet`, and a public zero-arg constructor;
   *  - a configuration-free transformer, when the class has a public zero-arg constructor;
   *  - otherwise no test code and only a comment noting the missing zero-arg constructor.
   *
   * @param cls the analyzer class to wrap
   * @return a pair of (test code, transformer code)
   */
  def generateForClass(cls: Class[_]): (String, String) = {
    import scala.reflect.runtime.universe._
    val rm = scala.reflect.runtime.currentMirror

    val clsSymbol = rm.classSymbol(cls)
    val clsType = clsSymbol.toType
    val clsFullName = clsSymbol.fullName
    val clsShortName = clsSymbol.name.toString
    val constructors = clsType.members.collect {
      case m: MethodSymbol if m.isConstructor && m.isPublic => m
    }
    // Once we have the debug version constructorParametersLists should be useful
    val constructorParametersLists = constructors.map(_.paramLists).toList
    val constructorParametersSizes =
      constructorParametersLists.map(paramLists => paramLists.headOption.fold(0)(_.size))
    // We do this in Java as well since some of the Scala reflection magic returns private
    // constructors even though it is filtered for public. See CustomAnalyzer for an example.
    val publicJavaReflectionConstructors =
      cls.getConstructors().toList.filter(ctor => Modifier.isPublic(ctor.getModifiers()))
    val constructorParameterTypes = publicJavaReflectionConstructors.map(_.getParameterTypes())
    val javaConstructorParametersSizes = constructorParameterTypes.map(_.length)
    // Lucene isn't built with -parameters by default, so constructor parameter names are
    // not available through reflection; we would need a local build with it enabled to
    // generate code with the right parameters.
    // https://docs.oracle.com/javase/tutorial/reflect/member/methodparameterreflection.html
    // For now only the zero-arg and stopword constructors are supported.

    // Both reflection views must agree that a zero-arg constructor is available.
    val hasZeroArgConstructor =
      constructorParametersSizes.contains(0) && javaConstructorParametersSizes.contains(0)

    // Special case for handling stopword analyzers.
    // Normally we'd do a check with <:< but the Lucene types have multiple
    // StopwordAnalyzerBase's that don't inherit from each other.
    val isStopWordAnalyzer = clsType.baseClasses.exists(_.asClass.fullName.contains("Stopword"))

    // Public constructors whose parameters are all CharArraySets.
    val charsetConstructors =
      constructorParameterTypes.filter(_.forall(_ == classOf[CharArraySet]))
    val takesStopwords = charsetConstructors.exists(_.length == 1)

    // The stopword template uses both the single-CharArraySet constructor and the zero-arg
    // constructor, so require both; otherwise the generated code would not compile.
    if (isStopWordAnalyzer && takesStopwords && hasZeroArgConstructor) {
      // Warn when there are public constructors with parameters we do not expose.
      val includeWarning = javaConstructorParametersSizes.exists(_ > 1)
      val warning = if (includeWarning) {
        """
          | * There are additional parameters which can not yet be controlled through this API
          | * See https://github.com/sparklingpandas/sparklingml/issues/3
          """.stripMargin('|')
      } else {
        ""
      }
      stopwordAnalyzerCode(clsShortName, clsFullName, warning)
    } else if (hasZeroArgConstructor) {
      zeroArgAnalyzerCode(clsShortName, clsFullName)
    } else {
      ("", s"""
        |/// There is no default zero arg constructor for ${clsFullName}
        """.stripMargin('|'))
    }
  }

  /** Builds (test code, transformer code) for an analyzer that accepts custom stopwords. */
  private def stopwordAnalyzerCode(
      clsShortName: String,
      clsFullName: String,
      warning: String): (String, String) = {
    val testCode =
      s"""
        |/**
        | * A super simple test
        | */
        | class ${clsShortName}LuceneTest extends LuceneStopwordTransformerTest {}
        """.stripMargin('|')
    val code =
      s"""
        |/**
        | * A basic Transformer based on ${clsFullName}. Supports configuring stopwords.${warning}
        | */
        |
        |class ${clsShortName}Lucene extends LuceneTransformer with HasStopwords with HasStopwordCase {
        |  def buildAnalyzer(): Analyzer = {
        |    // In the future we can use getDefaultStopWords here to allow people to control
        |    // the snowball stemmer distinctly from the stopwords.
        |    // but that is a TODO for later.
        |    if (isSet(stopwords)) {
        |      new ${clsFullName}(
        |        LuceneHelpers.wordstoCharArraySet($$(stopwords), !$$(stopwordCase)))
        |    } else {
        |      new ${clsFullName}()
        |    }
        |  }
        |}
        """.stripMargin('|')
    (testCode, code)
  }

  /** Builds (test code, transformer code) for an analyzer built via its zero-arg constructor. */
  private def zeroArgAnalyzerCode(clsShortName: String, clsFullName: String): (String, String) = {
    val testCode =
      s"""
        |/**
        | * A super simple test
        | */
        | class ${clsShortName}LuceneTest extends LuceneTransformerTest {}
        """.stripMargin('|')
    val code =
      s"""
        |/**
        | * A basic Transformer based on ${clsFullName} - does not support
        | * any configuration properties.
        | * See https://github.com/sparklingpandas/sparklingml/issues/3 & LuceneAnalyzerGenerators
        | * for details.
        | */
        |
        |class ${clsShortName}Lucene extends LuceneTransformer {
        |  def buildAnalyzer(): Analyzer = {
        |    new ${clsFullName}()
        |  }
        |}
        """.stripMargin('|')
    (testCode, code)
  }
}
0 commit comments