Commit d9d8abd

sync from master

2 parents: 7dcb503 + 67a254c

956 files changed: +6448 / -2612 lines


.github/workflows/build_and_test.yml

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ jobs:
             catalyst, hive-thriftserver
           - >-
             streaming, sql-kafka-0-10, streaming-kafka-0-10,
-            mllib-local, mllib,
+            mllib-local, mllib-common, mllib,
             yarn, mesos, kubernetes, hadoop-cloud, spark-ganglia-lgpl,
             connect, protobuf
   # Here, we split Hive and SQL tests into some of slow ones and the rest of them.

common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ public KVTypeInfo(Class<?> type) {
       KVIndex idx = m.getAnnotation(KVIndex.class);
       if (idx != null) {
         checkIndex(idx, indices);
-        Preconditions.checkArgument(m.getParameterTypes().length == 0,
+        Preconditions.checkArgument(m.getParameterCount() == 0,
           "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName());
         m.setAccessible(true);
         indices.put(idx.value(), idx);
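
The only change here replaces `getParameterTypes().length` with `getParameterCount()`, which reports the same arity without allocating a defensive copy of the parameter-type array on each call. A small Scala sketch of the two reflection calls (the `Example` class is hypothetical, written only for this illustration):

```scala
import java.lang.reflect.Method

object ParameterCountDemo {

  // Hypothetical class whose methods we inspect via reflection.
  class Example {
    def id(): String = "a"
    def scaled(factor: Double): Double = factor * 2
  }

  def main(args: Array[String]): Unit = {
    val methods: Array[Method] = classOf[Example].getDeclaredMethods
    methods.foreach { m =>
      // getParameterCount returns the arity directly; getParameterTypes
      // builds a copy of the parameter-type array just to read its length.
      assert(m.getParameterCount == m.getParameterTypes.length)
      println(s"${m.getName}: ${m.getParameterCount} parameter(s)")
    }
  }
}
```
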

connector/connect/client/jvm/pom.xml

Lines changed: 12 additions & 0 deletions
@@ -62,6 +62,18 @@
         </exclusion>
       </exclusions>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-mllib-common_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
     <dependency>
       <groupId>com.google.protobuf</groupId>
       <artifactId>protobuf-java</artifactId>
Lines changed: 97 additions & 0 deletions (new file: org.apache.spark.ml.Estimator)
@@ -0,0 +1,97 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import scala.annotation.varargs

import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.sql.Dataset

/**
 * Abstract class for estimators that fit models to data.
 */
abstract class Estimator[M <: Model[M]] extends PipelineStage {

  /**
   * Fits a single model to the input data with optional parameters.
   *
   * @param dataset
   *   input dataset
   * @param firstParamPair
   *   the first param pair, overrides embedded params
   * @param otherParamPairs
   *   other param pairs. These values override any specified in this Estimator's embedded
   *   ParamMap.
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  @varargs
  def fit(
      dataset: Dataset[_],
      firstParamPair: ParamPair[_],
      otherParamPairs: ParamPair[_]*): M = {
    val map = new ParamMap()
      .put(firstParamPair)
      .put(otherParamPairs: _*)
    fit(dataset, map)
  }

  /**
   * Fits a single model to the input data with provided parameter map.
   *
   * @param dataset
   *   input dataset
   * @param paramMap
   *   Parameter map. These values override any specified in this Estimator's embedded ParamMap.
   * @return
   *   fitted model
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
    copy(paramMap).fit(dataset)
  }

  /**
   * Fits a model to the input data.
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_]): M

  /**
   * Fits multiple models to the input data with multiple sets of parameters. The default
   * implementation uses a for loop on each parameter map. Subclasses could override this to
   * optimize multi-model training.
   *
   * @param dataset
   *   input dataset
   * @param paramMaps
   *   An array of parameter maps. These values override any specified in this Estimator's
   *   embedded ParamMap.
   * @return
   *   fitted models, matching the input parameter maps
   */
  @Since("3.5.0")
  def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] = {
    paramMaps.map(fit(dataset, _))
  }

  @Since("3.5.0")
  override def copy(extra: ParamMap): Estimator[M]
}
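
Both the varargs overload and the multi-model overload above reduce to the single `fit(dataset, paramMap)` entry point. The following dependency-free sketch mirrors that overload pattern in miniature; the `Mini*` names are illustrative stand-ins invented for this example, not classes from this commit or from Spark, and the real class additionally routes through `copy(paramMap)`:

```scala
object EstimatorPatternSketch {

  final case class MiniParamPair(name: String, value: Any)

  final class MiniParamMap {
    private val values = scala.collection.mutable.Map.empty[String, Any]

    def put(pairs: MiniParamPair*): MiniParamMap = {
      pairs.foreach(p => values(p.name) = p.value)
      this
    }

    def getOrElse(name: String, default: Any): Any = values.getOrElse(name, default)
  }

  // The "model" here just records which regParam it was trained with.
  final case class MiniModel(regParam: Double)

  class MiniEstimator {

    // Core entry point: a single param map, with per-call values overriding defaults.
    def fit(data: Seq[Double], paramMap: MiniParamMap): MiniModel =
      MiniModel(paramMap.getOrElse("regParam", 0.0).asInstanceOf[Double])

    // Varargs overload, mirroring fit(dataset, firstParamPair, otherParamPairs*):
    // fold the pairs into one map, then delegate.
    def fit(data: Seq[Double], first: MiniParamPair, rest: MiniParamPair*): MiniModel =
      fit(data, new MiniParamMap().put(first).put(rest: _*))

    // Multi-model overload, mirroring fit(dataset, paramMaps): one model per map.
    def fit(data: Seq[Double], paramMaps: Seq[MiniParamMap]): Seq[MiniModel] =
      paramMaps.map(fit(data, _))
  }

  def main(args: Array[String]): Unit = {
    val est = new MiniEstimator
    val data = Seq(1.0, 2.0, 3.0)

    println(est.fit(data, MiniParamPair("regParam", 0.1)))
    println(est.fit(data, Seq(
      new MiniParamMap().put(MiniParamPair("regParam", 0.0)),
      new MiniParamMap().put(MiniParamPair("regParam", 0.5)))))
  }
}
```

Concrete estimators only need to implement the single-dataset `fit`; the parameter-pair and multi-map overloads come for free from the abstract class.
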
Lines changed: 53 additions & 0 deletions (new file: org.apache.spark.ml.Model)
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.ParamMap

/**
 * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
 *
 * @tparam M
 *   model type
 */
abstract class Model[M <: Model[M]] extends Transformer {

  /**
   * The parent estimator that produced this model.
   * @note
   *   For ensembles' component Models, this value can be null.
   */
  @transient var parent: Estimator[M] = _

  /**
   * Sets the parent of this model (Java API).
   */
  @Since("3.5.0")
  def setParent(parent: Estimator[M]): M = {
    this.parent = parent
    this.asInstanceOf[M]
  }

  /** Indicates whether this [[Model]] has a corresponding parent. */
  @Since("3.5.0")
  def hasParent: Boolean = parent != null

  @Since("3.5.0")
  override def copy(extra: ParamMap): M
}
Lines changed: 63 additions & 0 deletions (new file: org.apache.spark.ml.PipelineStage)
@@ -0,0 +1,63 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.sql.types.StructType

/**
 * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
 */
abstract class PipelineStage extends Params with Logging {

  /**
   * Check transform validity and derive the output schema from the input schema.
   *
   * We check validity for interactions between parameters during `transformSchema` and raise an
   * exception if any parameter value is invalid. Parameter value checks which do not depend on
   * other parameters are handled by `Param.validate()`.
   *
   * Typical implementation should first conduct verification on schema change and parameter
   * validity, including complex parameter interaction checks.
   */
  def transformSchema(schema: StructType): StructType

  /**
   * :: DeveloperApi ::
   *
   * Derives the output schema from the input schema and parameters, optionally with logging.
   *
   * This should be optimistic. If it is unclear whether the schema will be valid, then it should
   * be assumed valid until proven otherwise.
   */
  @DeveloperApi
  protected def transformSchema(schema: StructType, logging: Boolean): StructType = {
    if (logging) {
      logDebug(s"Input schema: ${schema.json}")
    }
    val outputSchema = transformSchema(schema)
    if (logging) {
      logDebug(s"Expected output schema: ${outputSchema.json}")
    }
    outputSchema
  }

  override def copy(extra: ParamMap): PipelineStage
}
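
The protected `transformSchema(schema, logging)` helper simply wraps the abstract single-argument `transformSchema` with debug logging. A typical implementation of the abstract method validates the input columns and derives the output schema from them; here is a minimal sketch using only `org.apache.spark.sql.types` (the column names and the standalone `deriveOutputSchema` helper are hypothetical and written for illustration, not taken from this commit):

```scala
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object TransformSchemaSketch {

  // Hypothetical check-and-derive step: require a numeric input column,
  // then append the output column to the schema.
  def deriveOutputSchema(schema: StructType, inputCol: String, outputCol: String): StructType = {
    require(schema.fieldNames.contains(inputCol), s"Input column '$inputCol' is missing.")
    require(schema(inputCol).dataType == DoubleType,
      s"Column '$inputCol' must be DoubleType but is ${schema(inputCol).dataType}.")
    require(!schema.fieldNames.contains(outputCol), s"Output column '$outputCol' already exists.")
    schema.add(StructField(outputCol, DoubleType, nullable = false))
  }

  def main(args: Array[String]): Unit = {
    val input = StructType(Seq(StructField("feature", DoubleType, nullable = false)))
    val output = deriveOutputSchema(input, inputCol = "feature", outputCol = "prediction")
    println(output.treeString) // shows both "feature" and "prediction"
  }
}
```

A concrete stage would run this kind of check inside its own `transformSchema`, so that a pipeline can fail fast on schema problems before touching any data.
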
