Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
JsMVA: TMVA's method rewriter; 'overloaded' methods: Factory::TrainAl…
…lMethods, Factory::Factory, Factory::BookMethods
  • Loading branch information
qati committed Sep 19, 2016
commit f12c04ed25bee5dce80e255708576a2c4aaae0da
95 changes: 87 additions & 8 deletions bindings/pyroot/JsMVA/python/JsMVA/Factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from threading import Thread
import time
from string import Template
import types


## Getting method object from factory
Expand Down Expand Up @@ -438,9 +439,9 @@ def clicked(b):
container = widgets.HBox([label,treeSelector, drawTree])
display(container)

## Rewrites the TMVA::Factory::TrainAllMethods function. This function provides interactive training.
## Rewrite function for TMVA::Factory::TrainAllMethods. This function provides interactive training.
# @param fac the factory object pointer
def __TrainAllMethods(fac):
def ChangeTrainAllMethods(fac):
clear_output()
#stop button
button = """
Expand Down Expand Up @@ -591,9 +592,87 @@ def exit_supported(mn):
t.join()
return


ROOT.TMVA.MethodBase.GetInteractiveTrainingError._threaded = True
ROOT.TMVA.MethodBase.ExitFromTraining._threaded = True
ROOT.TMVA.MethodBase.TrainingEnded._threaded = True
ROOT.TMVA.MethodBase.TrainMethod._threaded = True
ROOT.TMVA.Factory.TrainAllMethods = __TrainAllMethods
## Get's special parameters from kwargs and converts to positional parameter
def __ConvertKwargsToArgs(positionalArgumentsToNamed, *args, **kwargs):
# args[0] = self
args = list(args)
idx = 0
PositionalArgsEnded = False
for argName in positionalArgumentsToNamed:
if not PositionalArgsEnded:
if argName in kwargs:
if (idx+1)!=len(args):
raise AttributeError
PositionalArgsEnded = True
else:
idx += 1
if PositionalArgsEnded and argName not in kwargs:
raise AttributeError
if argName in kwargs:
args.append(kwargs[argName])
del kwargs[argName]
args = tuple(args)
return (args, kwargs)

## Converts object to TMVA style option string
def __ProcessParameters(optStringStartIndex, *args, **kwargs):
originalFunction = None
if optStringStartIndex!=-10:
originalFunction = kwargs["originalFunction"]
del kwargs["originalFunction"]
OptionStringPassed = False
if (len(args)-1) == optStringStartIndex:
opt = args[optStringStartIndex] + ":"
tmp = list(args)
del tmp[optStringStartIndex]
args = tuple(tmp)
OptionStringPassed = True
else:
opt = ""
for key in kwargs:
if type(kwargs[key]) == types.BooleanType:
if kwargs[key] == True:
opt += key + ":"
else:
opt += "!" + key + ":"
else:
opt += key + "=" + str(kwargs[key]) + ":"
tmp = list(args)
if OptionStringPassed or len(kwargs)>0:
tmp.append( opt[:-1] )
return ( originalFunction, tuple(tmp) )

## Rewrite the constructor of TMVA::Factory
def ChangeCallOriginal__init__(*args, **kwargs):
try:
args, kwargs = __ConvertKwargsToArgs(["JobName", "TargetFile"], *args, **kwargs)
except AttributeError:
try:
args, kwargs = __ConvertKwargsToArgs(["JobName"], *args, **kwargs)
except AttributeError:
raise AttributeError
originalFunction, args = __ProcessParameters(3, *args, **kwargs)
return originalFunction(*args)

## Rewrite the constructor of TMVA::Factory
def ChangeCallOriginalBookMethod(*args, **kwargs):
compositeOpts = False
composite = False
if "Composite" in kwargs:
composite = kwargs["Composite"]
del kwargs["Composite"]
if "CompositeOptions" in kwargs:
compositeOpts = kwargs["CompositeOptions"]
del kwargs["CompositeOptions"]
args, kwargs = __ConvertKwargsToArgs(["DataLoader", "Method", "MethodTitle"], *args, **kwargs)
originalFunction, args = __ProcessParameters(4, *args, **kwargs)
if composite!=False:
args = list(args)
args.append(composite)
args = tuple(args)
if compositeOpts!=False:
o, compositeOptStr = __ProcessParameters(-10, **compositeOpts)
args = list(args)
args.append(compositeOptStr[0])
args = tuple(args)
return originalFunction(*args)
32 changes: 30 additions & 2 deletions bindings/pyroot/JsMVA/python/JsMVA/JPyInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
# to TMVA::Factory, TMVA::DataLoader
class functions:

## Threaded functions
ThreadedFunctions = {
"MethodBase": ["GetInteractiveTrainingError", "ExitFromTraining", "TrainingEnded", "TrainMethod",
"GetMaxIter", "GetCurrentIter"]
}

## The method inserter function
# @param target which class to insert
# @param source module which contains the methods to insert
Expand All @@ -28,6 +34,24 @@ def __register(target, source, *args):
continue
setattr(target, arg, getattr(source, arg))

## This method change TMVA methods with new methods
# @param target which class to insert
# @param source module which contains the methods to insert
# @param args list of methods to insert
@staticmethod
def __changeMethod(target, source, *args):
def rewriter(originalFunction, newFunction):
def wrapper(*args, **kwargs):
kwargs["originalFunction"] = originalFunction
return newFunction(*args, **kwargs)
return wrapper
for arg in args:
if arg.find("CallOriginal")!=-1:
originalName = arg.replace("Change", "").replace("CallOriginal", "")
setattr(target, originalName, rewriter(getattr(target, originalName), getattr(source, arg)))
else:
setattr(target, arg.replace("Change", ""), getattr(source, arg))

## The method removes inserted functions from class
# @param target from which class to remove functions
# @param args list of methods to remove
Expand All @@ -54,6 +78,11 @@ def __getMethods(module, selector):
def register():
functions.__register(ROOT.TMVA.DataLoader, DataLoader, *functions.__getMethods(DataLoader, "Draw"))
functions.__register(ROOT.TMVA.Factory, Factory, *functions.__getMethods(Factory, "Draw"))
functions.__changeMethod(ROOT.TMVA.Factory, Factory, *functions.__getMethods(Factory, "Change"))
functions.__changeMethod(ROOT.TMVA.DataLoader, DataLoader, *functions.__getMethods(DataLoader, "Change"))
for key in functions.ThreadedFunctions:
for func in functions.ThreadedFunctions[key]:
setattr(getattr(getattr(ROOT.TMVA, key), func), "_threaded", True)

## This function will remove all functions which name contains "Draw" from TMVA.DataLoader and TMVA.Factory
# if the function was inserted from DataLoader and Factory modules
Expand All @@ -65,9 +94,8 @@ def unregister():

## Class for creating the output scripts and inserting them to cell output
class JsDraw:
#__jsMVASourceDir = "https://rawgit.com/qati/GSOC16/master/src/js"
## String containing the link to JavaScript files
__jsMVASourceDir = "http://localhost:8888/notebooks/code/GSOC/wd/src/js"
__jsMVASourceDir = "https://rawgit.com/qati/GSOC16/master/src/js"

## Drawing are sizes
jsCanvasWidth = 800
Expand Down
3 changes: 1 addition & 2 deletions bindings/pyroot/JsMVA/python/JsMVA/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ def loadExtensions():
extMgr = ExtensionManager(ip)
extMgr.load_extension("JsMVA.JsMVAMagic")


loadExtensions();
loadExtensions()