Added BaseSystemMLClassifier and updated the classifier to use the new MLContext
Niketan Pansare committed Aug 9, 2016
commit 21e91c7dc6bbe0ea7314e262c890b222c677935f
142 changes: 83 additions & 59 deletions src/main/java/org/apache/sysml/api/MLContext.java
@@ -65,6 +65,7 @@
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
@@ -476,25 +477,6 @@ public void registerInput(String varName, RDD<String> rdd, String format, long r
 		registerInput(varName, rdd.toJavaRDD().mapToPair(new ConvertStringToLongTextPair()), format, rlen, clen, nnz, null);
 	}
 
-	public void registerInput(String varName, MatrixBlock mb) throws DMLRuntimeException {
-		MatrixCharacteristics mc = new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, mb.getNonZeros());
-		registerInput(varName, mb, mc);
-	}
-
-	public void registerInput(String varName, MatrixBlock mb, MatrixCharacteristics mc) throws DMLRuntimeException {
-		if(_variables == null)
-			_variables = new LocalVariableMap();
-		if(_inVarnames == null)
-			_inVarnames = new ArrayList<String>();
-
-		MatrixObject mo = new MatrixObject(ValueType.DOUBLE, "temp", new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
-		mo.acquireModify(mb);
-		mo.release();
-		_variables.put(varName, mo);
-		_inVarnames.add(varName);
-		checkIfRegisteringInputAllowed();
-	}
-
 	// All CSV related methods call this ... It provides access to dimensions, nnz, file properties.
 	private void registerInput(String varName, JavaPairRDD<LongWritable, Text> textOrCsv_rdd, String format, long rlen, long clen, long nnz, FileFormatProperties props) throws DMLRuntimeException {
 		if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
@@ -618,6 +600,24 @@ public void registerInput(String varName, JavaPairRDD<MatrixIndexes,MatrixBlock>
 		checkIfRegisteringInputAllowed();
 	}
 
+	public void registerInput(String varName, MatrixBlock mb) throws DMLRuntimeException {
+		MatrixCharacteristics mc = new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, mb.getNonZeros());
+		registerInput(varName, mb, mc);
+	}
+
+	public void registerInput(String varName, MatrixBlock mb, MatrixCharacteristics mc) throws DMLRuntimeException {
+		if(_variables == null)
+			_variables = new LocalVariableMap();
+		if(_inVarnames == null)
+			_inVarnames = new ArrayList<String>();
+		MatrixObject mo = new MatrixObject(ValueType.DOUBLE, "temp", new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+		mo.acquireModify(mb);
+		mo.release();
+		_variables.put(varName, mo);
+		_inVarnames.add(varName);
+		checkIfRegisteringInputAllowed();
+	}
+
 	// =============================================================================================
 
 	/**
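For context, a minimal caller-side sketch of the relocated overload above (not part of the commit). It assumes an already-configured org.apache.sysml.api.MLContext named `ml` on a SPARK or HYBRID_SPARK runtime; the script path is hypothetical:

```java
// Register an in-memory 2x2 identity matrix as DML input "X".
// Block sizes default to OptimizerUtils.DEFAULT_BLOCKSIZE inside
// registerInput(varName, mb), as shown in the hunk above.
MatrixBlock mb = new MatrixBlock(2, 2, false); // 2x2, dense
mb.allocateDenseBlock();
mb.quickSetValue(0, 0, 1.0);
mb.quickSetValue(1, 1, 1.0);
mb.recomputeNonZeros(); // keep the nnz metadata consistent

ml.registerInput("X", mb);  // dimensions and nnz come from the block itself
ml.registerOutput("Y");
MLOutput out = ml.execute("scripts/eye.dml"); // hypothetical DML script
```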
@@ -1240,56 +1240,80 @@ private MLOutput compileAndExecuteScript(String dmlScriptFilePath, String [] arg
 	 * @throws ParseException
 	 */
 	private synchronized MLOutput compileAndExecuteScript(String dmlScriptFilePath, String [] args, boolean isFile, boolean isNamedArgument, boolean isPyDML, String configFilePath) throws IOException, DMLException {
-		// Set active MLContext.
-		_activeMLContext = this;
-
-		if(_monitorUtils != null) {
-			_monitorUtils.resetMonitoringData();
-		}
-
-		if(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK) {
-
-			// Depending on whether registerInput/registerOutput was called initialize the variables
-			String[] inputs; String[] outputs;
-			if(_inVarnames != null) {
-				inputs = _inVarnames.toArray(new String[0]);
-			}
-			else {
-				inputs = new String[0];
-			}
-			if(_outVarnames != null) {
-				outputs = _outVarnames.toArray(new String[0]);
-			}
-			else {
-				outputs = new String[0];
-			}
-			Map<String, MatrixCharacteristics> outMetadata = new HashMap<String, MatrixCharacteristics>();
-
-			Map<String, String> argVals = DMLScript.createArgumentsMap(isNamedArgument, args);
-
-			// Run the DML script
-			ExecutionContext ec = executeUsingSimplifiedCompilationChain(dmlScriptFilePath, isFile, argVals, isPyDML, inputs, outputs, _variables, configFilePath);
-
-			// Now collect the output
-			if(_outVarnames != null) {
-				if(_variables == null) {
-					throw new DMLRuntimeException("The symbol table returned after executing the script is empty");
-				}
-
-				for( String ovar : _outVarnames ) {
-					if( _variables.keySet().contains(ovar) ) {
-						outMetadata.put(ovar, ec.getMatrixCharacteristics(ovar)); // For converting output to dataframe
-					}
-					else {
-						throw new DMLException("Error: The variable " + ovar + " is not available as output after the execution of the DMLScript.");
-					}
-				}
-			}
-
-			return new MLOutput(_variables, ec, outMetadata);
-		}
-		else {
-			throw new DMLRuntimeException("Unsupported runtime:" + DMLScript.rtplatform.name());
-		}
+		try {
+			if(getActiveMLContext() != null) {
+				throw new DMLRuntimeException("SystemML (and hence by definition MLContext) does not support parallel execute() calls from same or different MLContexts. "
+						+ "As a temporary fix, please do explicit synchronization, i.e. synchronized(MLContext.class) { ml.execute(...) } ");
+			}
+
+			// Set active MLContext.
+			_activeMLContext = this;
+
+			if(_monitorUtils != null) {
+				_monitorUtils.resetMonitoringData();
+			}
+
+			if(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK) {
+
+				Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> retVal = null;
+
+				// Depending on whether registerInput/registerOutput was called initialize the variables
+				String[] inputs; String[] outputs;
+				if(_inVarnames != null) {
+					inputs = _inVarnames.toArray(new String[0]);
+				}
+				else {
+					inputs = new String[0];
+				}
+				if(_outVarnames != null) {
+					outputs = _outVarnames.toArray(new String[0]);
+				}
+				else {
+					outputs = new String[0];
+				}
+				Map<String, MatrixCharacteristics> outMetadata = new HashMap<String, MatrixCharacteristics>();
+
+				Map<String, String> argVals = DMLScript.createArgumentsMap(isNamedArgument, args);
+
+				// Run the DML script
+				ExecutionContext ec = executeUsingSimplifiedCompilationChain(dmlScriptFilePath, isFile, argVals, isPyDML, inputs, outputs, _variables, configFilePath);
+
+				// Now collect the output
+				if(_outVarnames != null) {
+					if(_variables == null) {
+						throw new DMLRuntimeException("The symbol table returned after executing the script is empty");
+					}
+
+					for( String ovar : _outVarnames ) {
+						if( _variables.keySet().contains(ovar) ) {
+							if(retVal == null) {
+								retVal = new HashMap<String, JavaPairRDD<MatrixIndexes,MatrixBlock>>();
+							}
+							retVal.put(ovar, ((SparkExecutionContext) ec).getBinaryBlockRDDHandleForVariable(ovar));
+							outMetadata.put(ovar, ec.getMatrixCharacteristics(ovar)); // For converting output to dataframe
+						}
+						else {
+							throw new DMLException("Error: The variable " + ovar + " is not available as output after the execution of the DMLScript.");
+						}
+					}
+				}
+
+				return new MLOutput(retVal, outMetadata);
+			}
+			else {
+				throw new DMLRuntimeException("Unsupported runtime:" + DMLScript.rtplatform.name());
+			}
+		}
+		finally {
+			// Remove global dml config and all thread-local configs
+			// TODO enable cleanup whenever invalid GNMF MLContext test is fixed
+			// (the test is invalid because it assumes that status of previous execute is kept)
+			//ConfigurationManager.setGlobalConfig(new DMLConfig());
+			//ConfigurationManager.clearLocalConfigs();
+
+			// Reset active MLContext.
+			_activeMLContext = null;
+		}
 	}
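Since the guard above rejects overlapping execute() calls, multi-threaded callers need the external locking that the exception message suggests. A sketch of that workaround (not part of the commit), assuming two pre-configured MLContexts `ml1` and `ml2` and hypothetical script paths:

```java
// Serialize execute() calls JVM-wide, per the DMLRuntimeException message:
// synchronized(MLContext.class) { ml.execute(...) }
Runnable job1 = () -> {
    synchronized (MLContext.class) { // only one execute() at a time
        try { ml1.execute("scripts/algo1.dml"); }
        catch (Exception e) { throw new RuntimeException(e); }
    }
};
Runnable job2 = () -> {
    synchronized (MLContext.class) {
        try { ml2.execute("scripts/algo2.dml"); }
        catch (Exception e) { throw new RuntimeException(e); }
    }
};
new Thread(job1).start();
new Thread(job2).start();
```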

@@ -1451,4 +1475,4 @@ public MLMatrix read(SQLContext sqlContext, String filePath, String format) thro
 //		return MLMatrix.createMLMatrix(this, sqlContext, blocks, mc);
 //	}
 
-}
+}
48 changes: 19 additions & 29 deletions src/main/java/org/apache/sysml/api/MLOutput.java
@@ -39,16 +39,14 @@
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.util.UtilFunctions;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 
 import scala.Tuple2;
 
 /**
@@ -57,39 +55,31 @@
  */
 public class MLOutput {
 
-	private LocalVariableMap _variables;
-	private ExecutionContext _ec;
+	Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs;
 	private Map<String, MatrixCharacteristics> _outMetadata = null;
 
-	public MLOutput(LocalVariableMap variables, ExecutionContext ec, Map<String, MatrixCharacteristics> outMetadata) {
-		this._variables = variables;
-		this._ec = ec;
-		this._outMetadata = outMetadata;
-	}
-
 	public MatrixBlock getMatrixBlock(String varName) throws DMLRuntimeException {
-		if( _variables.keySet().contains(varName) ) {
-			MatrixObject mo = _ec.getMatrixObject(varName);
-			MatrixBlock mb = mo.acquireRead();
-			mo.release();
-			return mb;
-		}
-		else {
-			throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
-		}
+		MatrixCharacteristics mc = getMatrixCharacteristics(varName);
+		// The matrix block is always pushed to an RDD and then we do collect
+		// We can later avoid this by returning symbol table rather than "Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs"
+		MatrixBlock mb = SparkExecutionContext.toMatrixBlock(getBinaryBlockedRDD(varName), (int) mc.getRows(), (int) mc.getCols(),
+				mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
+		return mb;
 	}
+
+	public MLOutput(Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> outputs, Map<String, MatrixCharacteristics> outMetadata) {
+		this._outputs = outputs;
+		this._outMetadata = outMetadata;
+	}
 
 	public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
-		if( _variables.keySet().contains(varName) ) {
-			return ((SparkExecutionContext) _ec).getBinaryBlockRDDHandleForVariable(varName);
-		}
-		else {
-			throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
-		}
+		if(_outputs.containsKey(varName)) {
+			return _outputs.get(varName);
+		}
+		throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
 	}
 
 	public MatrixCharacteristics getMatrixCharacteristics(String varName) throws DMLRuntimeException {
-		if(_outMetadata.containsKey(varName)) {
+		if(_outputs.containsKey(varName)) {
 			return _outMetadata.get(varName);
 		}
throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
@@ -255,15 +245,15 @@ public Iterable<Tuple2<Long, Tuple2<Long, Double[]>>> call(Tuple2<MatrixIndexes,
 				int lclen = UtilFunctions.computeBlockSize(clen, blockColIndex, bclen);
 				// ------------------------------------------------------------------
 
-				long startRowIndex = (kv._1.getRowIndex()-1) * bclen;
+				long startRowIndex = (kv._1.getRowIndex()-1) * bclen + 1;
 				MatrixBlock blk = kv._2;
 				ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>> retVal = new ArrayList<Tuple2<Long,Tuple2<Long,Double[]>>>();
 				for(int i = 0; i < lrlen; i++) {
 					Double[] partialRow = new Double[lclen];
 					for(int j = 0; j < lclen; j++) {
 						partialRow[j] = blk.getValue(i, j);
 					}
-					retVal.add(new Tuple2<Long, Tuple2<Long,Double[]>>(startRowIndex + i + 1, new Tuple2<Long,Double[]>(kv._1.getColumnIndex(), partialRow)));
+					retVal.add(new Tuple2<Long, Tuple2<Long,Double[]>>(startRowIndex + i, new Tuple2<Long,Double[]>(kv._1.getColumnIndex(), partialRow)));
 				}
 				return retVal;
 			}
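The two changed lines are a behavior-preserving refactor: the +1 that makes global row ids 1-based moves from the tuple construction into startRowIndex. A quick check (a 1000x1000 block size is assumed here; the multiplier is named bclen in this code, which equals brlen for SystemML's square default blocks):

```java
// Both forms map local row i of block row `rowIndex` to the same
// 1-based global row id.
long rowIndex = 2, blen = 1000; // second block row, 1000x1000 blocks (assumed)
int i = 0;                      // first local row within the block
long oldGlobalRow = (rowIndex - 1) * blen + i + 1;   // old: +1 at tuple creation
long newGlobalRow = ((rowIndex - 1) * blen + 1) + i; // new: +1 in startRowIndex
assert oldGlobalRow == newGlobalRow; // both 1001
```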
@@ -427,4 +417,4 @@ public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> arg0)
 			return RowFactory.create(row);
 		}
 	}
-}
+}
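Taken together, the MLOutput changes decouple results from the ExecutionContext: outputs now travel as binary-block RDDs plus metadata. A consumer sketch (not part of the diff), assuming an MLContext `ml` with input "X" and output "Y" registered and a hypothetical script path:

```java
MLOutput out = ml.execute("scripts/algorithm.dml"); // hypothetical path

// Distributed result: the binary blocks stay on the cluster.
JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = out.getBinaryBlockedRDD("Y");

// Local result: per the comment in the diff, getMatrixBlock() now collects
// the block RDD on the driver, so reserve it for small matrices.
MatrixBlock local = out.getMatrixBlock("Y");
MatrixCharacteristics mc = out.getMatrixCharacteristics("Y");
System.out.println("Y: " + mc.getRows() + " x " + mc.getCols());
```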
src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
@@ -21,6 +21,8 @@
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -97,6 +99,13 @@ public BinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks,
 	public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() {
 		return binaryBlocks;
 	}
+
+	public MatrixBlock getMatrixBlock() throws DMLRuntimeException {
+		MatrixCharacteristics mc = getMatrixCharacteristics();
+		MatrixBlock mb = SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) mc.getCols(),
+				mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
+		return mb;
+	}
 
 	/**
 	 * Obtain the SystemML binary-block matrix characteristics
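The new getMatrixBlock() gives BinaryBlockMatrix the same collect-based escape hatch as MLOutput. A usage sketch under the assumption that the new-API MLResults exposes a getBinaryBlockMatrix accessor; names outside this diff are assumptions:

```java
// Pull output "Y" as distributed blocks, then collect into one driver-side
// MatrixBlock. SparkExecutionContext.toMatrixBlock() performs a collect,
// so this is only safe when Y fits in driver memory.
Script script = ScriptFactory.dmlFromFile("scripts/algorithm.dml"); // assumed
MLResults results = ml.execute(script);                    // new-API MLContext
BinaryBlockMatrix bbm = results.getBinaryBlockMatrix("Y"); // assumed accessor
MatrixBlock local = bbm.getMatrixBlock(); // triggers the Spark collect
System.out.println("Y: " + local.getNumRows() + " x " + local.getNumColumns());
```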
src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -676,7 +676,7 @@ public static double[][] matrixObjectToDoubleMatrix(MatrixObject matrixObject) {
 	 * @return the {@code MatrixObject} converted to a {@code DataFrame}
 	 */
 	public static DataFrame matrixObjectToDataFrame(MatrixObject matrixObject,
-			SparkExecutionContext sparkExecutionContext) {
+			SparkExecutionContext sparkExecutionContext, boolean isVectorDF) {
 		try {
 			@SuppressWarnings("unchecked")
 			JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlockMatrix = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sparkExecutionContext
@@ -686,8 +686,17 @@ public static DataFrame matrixObjectToDataFrame(MatrixObject matrixObject,
 			MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
 			SparkContext sc = activeMLContext.getSparkContext();
 			SQLContext sqlContext = new SQLContext(sc);
-			DataFrame df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics,
-					sqlContext);
+			DataFrame df = null;
+			if(isVectorDF) {
+				df = RDDConverterUtilsExt.binaryBlockToVectorDataFrame(binaryBlockMatrix, matrixCharacteristics,
+						sqlContext);
+			}
+			else {
+				df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics,
+						sqlContext);
+			}
 
 			return df;
 		} catch (DMLRuntimeException e) {
 			throw new MLContextException("DMLRuntimeException while converting matrix object to DataFrame", e);
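A sketch of the widened conversion entry point; `mo` and `sec` stand in for an existing MatrixObject and SparkExecutionContext:

```java
// The new flag selects the converter: true yields a DataFrame of Vector
// rows (one Vector per matrix row), false keeps the previous
// one-double-per-column layout.
DataFrame vectorDF = MLContextConversionUtil.matrixObjectToDataFrame(mo, sec, true);
DataFrame plainDF  = MLContextConversionUtil.matrixObjectToDataFrame(mo, sec, false);
```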