From b951f09137f88ef050146737a1c09090ced4f87e Mon Sep 17 00:00:00 2001
From: Emil Bols
Date: Mon, 17 Feb 2020 15:48:59 +0100
Subject: [PATCH] migration to djc2

---
 README.md                                   |  69 +++++
 Train/train_DeepFlavour.py                  |  57 +++++
 env.sh                                      |  14 +
 modules/Layers.py                           |   1 +
 modules/Losses.py                           |   3 +
 modules/Metrics.py                          |   3 +
 modules/compiled/Makefile                   |  53 ++++
 modules/compiled/src/c_convert.C            |  35 +++
 .../datastructures/TrainData_deepFlavour.py | 242 ++++++++++++++++++
 modules/datastructures/__init__.py          |  22 ++
 modules/models/__init__.py                  |  17 ++
 modules/models/buildingBlocks.py            | 150 +++++++++++
 modules/models/convolutional.py             |  58 +++++
 13 files changed, 724 insertions(+)
 create mode 100644 README.md
 create mode 100644 Train/train_DeepFlavour.py
 create mode 100644 env.sh
 create mode 100644 modules/Layers.py
 create mode 100644 modules/Losses.py
 create mode 100644 modules/Metrics.py
 create mode 100644 modules/compiled/Makefile
 create mode 100644 modules/compiled/src/c_convert.C
 create mode 100644 modules/datastructures/TrainData_deepFlavour.py
 create mode 100644 modules/datastructures/__init__.py
 create mode 100644 modules/models/__init__.py
 create mode 100644 modules/models/buildingBlocks.py
 create mode 100644 modules/models/convolutional.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..6248413
--- /dev/null
+++ b/README.md
@@ -0,0 +1,69 @@
+DeepJet: Repository for training and evaluation of deep neural networks for Jet identification
+===============================================================================
+
+This package depends on DeepJetCore 2.X (https://github.com/DL4Jets/DeepJetCore).
+
+Usage
+==============
+
+After logging in, source the environment (cd to the package directory first!):
+```
+cd <your working directory>/DeepJet
+source env.sh
+```
+
+
+The preparation for the training consists of the following steps
+====
+
+- define the data structure for the training. The DeepJet data structure is found in the modules directory as the class TrainData_DF.
+
+- convert the root files to the data structure for training using the DeepJetCore tools:
+  ```
+  convertFromSource.py -i /path/to/the/root/ntuple/list_of_root_files.txt -o /output/path/that/needs/some/disk/space -c TrainData_DF
+  ```
+
+  This step can take a while.
+
+- prepare the training file and the model. Please refer to DeepJet/Train/train_DeepFlavour.py
+
+
+Training
+====
+
+Since the training can take a while, it is advised to open a screen session, so that it does not die at logout.
+```
+ssh lxplus.cern.ch
+
+screen
+ssh lxplus7
+```
+Then source the environment and proceed with the training. Detach the screen session with ctrl+a d.
+You can go back to the session by logging in to the machine the session is running on (e.g. lxplus058):
+
+```
+ssh lxplus.cern.ch
+ssh lxplus058
+screen -r
+```
+
+Please close the session when the training is finished.
+
+The training is launched in the following way:
+```
+python train_DeepFlavour.py /path/to/the/output/of/convert/dataCollection.dc
+```
+
+
+Evaluation
+====
+
+After the training has finished, the performance can be evaluated.
+
+```
+predict.py <training output dir>/KERAS_model.h5 <convert output dir>/trainsamples.dc <sample dir>/filelist.txt
+```
+
+This creates output trees with the prediction scores as well as truth information and some kinematic variables.
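For orientation, a minimal sketch of how such an output tree could be inspected. It mirrors the uproot usage in `TrainData_deepFlavour.py` below; the file name is a placeholder, and the branch names are the ones written by `writeOutPrediction`:

```python
import uproot

# "prediction.root" is a placeholder for one of the trees written by predict.py
tree = uproot.open("prediction.root")["tree"]
arrays = tree.arrays(['prob_isB', 'prob_isBB', 'prob_isLeptB', 'isB', 'jet_pt'])
# the three b-like scores are typically summed into a single b-tag discriminator
b_disc = arrays['prob_isB'] + arrays['prob_isBB'] + arrays['prob_isLeptB']
```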
diff --git a/Train/train_DeepFlavour.py b/Train/train_DeepFlavour.py
new file mode 100644
index 0000000..954b5cb
--- /dev/null
+++ b/Train/train_DeepFlavour.py
@@ -0,0 +1,57 @@
+
+#import sys
+#import tensorflow as tf
+#sys.modules["keras"] = tf.keras
+
+from DeepJetCore.training.training_base import training_base
+from DeepJetCore.modeltools import fixLayersContaining, printLayerInfosAndWeights
+
+
+#also does all the command-line parsing
+train = training_base(testrun=False)
+
+newtraining = not train.modelSet()
+#for recovering a training
+if newtraining:
+    from models import model_deepFlavourReference
+
+    train.setModel(model_deepFlavourReference, dropoutRate=0.1, momentum=0.3)
+
+    #train.keras_model=fixLayersContaining(train.keras_model, 'regression', invert=False)
+
+    train.compileModel(learningrate=0.001,
+                       loss='categorical_crossentropy',
+                       metrics=['accuracy'])
+
+    train.train_data.maxFilesOpen = 1
+
+    print(train.keras_model.summary())
+    model, history = train.trainModel(nepochs=1,
+                                      batchsize=10000,
+                                      stop_patience=300,
+                                      lr_factor=0.5,
+                                      lr_patience=-1,
+                                      lr_epsilon=0.0001,
+                                      lr_cooldown=6,
+                                      lr_minimum=0.0001)
+
+    print('fixing input norms...')
+    train.keras_model = fixLayersContaining(train.keras_model, 'input_batchnorm')
+
+train.compileModel(learningrate=0.0001,
+                   loss='categorical_crossentropy',
+                   metrics=['accuracy'])
+
+print(train.keras_model.summary())
+#printLayerInfosAndWeights(train.keras_model)
+
+model, history = train.trainModel(nepochs=65, #sweet spot from looking at the testing plots
+                                  batchsize=10000,
+                                  stop_patience=300,
+                                  lr_factor=0.5,
+                                  lr_patience=-1,
+                                  lr_epsilon=0.0001,
+                                  lr_cooldown=10,
+                                  lr_minimum=0.00001,
+                                  verbose=1, checkperiod=1)
diff --git a/env.sh b/env.sh
new file mode 100644
index 0000000..31e8538
--- /dev/null
+++ b/env.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+export DJSUBPACKAGE=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd -P)
+export DEEPJETCORE_SUBPACKAGE=$DJSUBPACKAGE
+
+cd $DJSUBPACKAGE
+export PYTHONPATH=$DJSUBPACKAGE/modules:$PYTHONPATH
+export PYTHONPATH=$DJSUBPACKAGE/modules/datastructures:$PYTHONPATH
+export PATH=$DJSUBPACKAGE/scripts:$PATH
+
+export LD_LIBRARY_PATH=$DJSUBPACKAGE/modules/compiled:$LD_LIBRARY_PATH
+export PYTHONPATH=$DJSUBPACKAGE/modules/compiled:$PYTHONPATH
+
diff --git a/modules/Layers.py b/modules/Layers.py
new file mode 100644
index 0000000..623e9b1
--- /dev/null
+++ b/modules/Layers.py
@@ -0,0 +1 @@
+global_layers_list = {}
diff --git a/modules/Losses.py b/modules/Losses.py
new file mode 100644
index 0000000..8fb1be0
--- /dev/null
+++ b/modules/Losses.py
@@ -0,0 +1,3 @@
+
+# Define custom losses here and add them to the global_loss_list dict (important!)
+global_loss_list = {}
diff --git a/modules/Metrics.py b/modules/Metrics.py
new file mode 100644
index 0000000..49613ba
--- /dev/null
+++ b/modules/Metrics.py
@@ -0,0 +1,3 @@
+
+# Define custom metrics here and add them to the global_metrics_list dict (important!)
+global_metrics_list = {}
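The Losses.py and Metrics.py stubs above only ship the registration dicts. As an illustration (not part of this patch), a custom loss added to modules/Losses.py would be defined and registered like this, so that DeepJetCore can look it up by name:

```python
# hypothetical addition to modules/Losses.py; global_loss_list is the dict
# defined above in that file
from keras import backend as K

def loss_example_weighted_cce(y_true, y_pred):
    # categorical cross-entropy with a crude upweighting of the first class,
    # purely as an illustration
    weights = K.constant([2., 1., 1., 1., 1., 1.])
    y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
    return -K.sum(weights * y_true * K.log(y_pred), axis=-1)

global_loss_list['loss_example_weighted_cce'] = loss_example_weighted_cce
```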
diff --git a/modules/compiled/Makefile b/modules/compiled/Makefile
new file mode 100644
index 0000000..df51e5d
--- /dev/null
+++ b/modules/compiled/Makefile
@@ -0,0 +1,53 @@
+
+#
+# This file might need some adjustments but should serve as a good basis
+#
+
+PYTHON_INCLUDE = `python-config --includes`
+PYTHON_LIB=`python-config --libs`
+
+ROOTSTUFF=`root-config --libs --glibs --ldflags`
+ROOTCFLAGS=`root-config --cflags`
+
+CPP_FILES := $(wildcard src/*.cpp)
+OBJ_FILES := $(addprefix obj/,$(notdir $(CPP_FILES:.cpp=.o)))
+LD_FLAGS := `root-config --cflags --glibs` -lMathMore -L${DEEPJETCORE}/compiled -ldeepjetcorehelpers -lquicklz
+CC_FLAGS := -fPIC -g -Wall `root-config --cflags`
+CC_FLAGS += -I./interface -I${DEEPJETCORE}/compiled/interface
+DJC_LIB = -L${DEEPJETCORE}/compiled -ldeepjetcorehelpers
+
+
+MODULES := $(wildcard src/*.C)
+MODULES_OBJ_FILES := $(addprefix ./,$(notdir $(MODULES:.C=.o)))
+MODULES_SHARED_LIBS := $(addprefix ./,$(notdir $(MODULES:.C=.so)))
+
+
+all: $(MODULES_SHARED_LIBS) $(patsubst bin/%.cpp, %, $(wildcard bin/*.cpp))
+
+#compile the module helpers if necessary
+#../modules/libsubpackagehelpers.so:
+#	cd ../modules; make; cd -
+
+%: bin/%.cpp $(OBJ_FILES)
+	g++ $(CC_FLAGS) $(LD_FLAGS) $(OBJ_FILES) $< -o $@
+
+
+obj/%.o: src/%.cpp
+	g++ $(CC_FLAGS) -c -o $@ $<
+
+
+#python modules
+
+%.so: %.o
+	g++ -o $(@) -shared -fPIC $(LINUXADD) $< $(ROOTSTUFF) $(PYTHON_LIB) -lboost_python -lboost_numpy $(DJC_LIB)
+
+%.o: src/%.C
+	g++ $(ROOTCFLAGS) -O2 $(CC_FLAGS) -I./interface $(PYTHON_INCLUDE) -fPIC -c -o $(@) $<
+
+
+clean:
+	rm -f obj/*.o obj/*.d *.so
+	rm -f %
diff --git a/modules/compiled/src/c_convert.C b/modules/compiled/src/c_convert.C
new file mode 100644
index 0000000..cbb67d0
--- /dev/null
+++ b/modules/compiled/src/c_convert.C
@@ -0,0 +1,35 @@
+
+#include <boost/python.hpp>
+#include "boost/python/numpy.hpp"
+#include "boost/python/list.hpp"
+#include "boost/python/str.hpp"
+#include <iostream>
+#include <string>
+
+//includes from deepjetcore
+#include "helper.h"
+#include "simpleArray.h"
+
+namespace p = boost::python;
+namespace np = boost::python::numpy;
+
+/*
+ * Example of a python module that will be compiled.
+ * It can be used, e.g. to convert from fully custom input data
+ */
+
+np::ndarray readFirstFeatures(std::string infile){
+
+    auto arr = djc::simpleArray<float>({10,3,4});
+    arr.at(0,2,1) = 5.;//filling some data
+
+    return simpleArrayToNumpy(arr);
+}
+
+BOOST_PYTHON_MODULE(c_convert) {
+    Py_Initialize();
+    np::initialize();
+    p::def("readFirstFeatures", &readFirstFeatures);
+}
+
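Once built with `make` in `modules/compiled` (and with `env.sh` sourced so the resulting `c_convert.so` is on `PYTHONPATH` and `LD_LIBRARY_PATH`), the module above should be usable from Python roughly like this; a sketch, noting that the example function ignores its file argument:

```python
import c_convert

arr = c_convert.readFirstFeatures("anything.root")  # argument unused in the example
print(arr.shape)     # (10, 3, 4)
print(arr[0, 2, 1])  # 5.0, the value filled on the C++ side
```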
diff --git a/modules/datastructures/TrainData_deepFlavour.py b/modules/datastructures/TrainData_deepFlavour.py
new file mode 100644
index 0000000..1c37727
--- /dev/null
+++ b/modules/datastructures/TrainData_deepFlavour.py
@@ -0,0 +1,242 @@
+
+
+from DeepJetCore.TrainData import TrainData, fileTimeOut
+import numpy as np
+
+
+
+class TrainData_DF(TrainData):
+    def __init__(self):
+
+        TrainData.__init__(self)
+
+        self.description = "DeepJet training datastructure"
+
+        self.truth_branches = ['isB','isBB','isGBB','isLeptonicB','isLeptonicB_C','isC','isGCC','isCC','isUD','isS','isG']
+        self.undefTruth = ['isUndefined']
+        self.weightbranchX = 'jet_pt'
+        self.weightbranchY = 'jet_eta'
+        self.remove = True
+        self.referenceclass = 'isB'
+
+        self.weight_binX = np.array([
+            10,25,30,35,40,45,50,60,75,100,
+            125,150,175,200,250,300,400,500,
+            600,2000], dtype=float)
+
+        self.weight_binY = np.array(
+            [-2.5,-2.,-1.5,-1.,-0.5,0.5,1,1.5,2.,2.5],
+            dtype=float
+        )
+
+        self.global_branches = ['jet_pt', 'jet_eta',
+                                'nCpfcand', 'nNpfcand',
+                                'nsv', 'npv',
+                                'TagVarCSV_trackSumJetEtRatio',
+                                'TagVarCSV_trackSumJetDeltaR',
+                                'TagVarCSV_vertexCategory',
+                                'TagVarCSV_trackSip2dValAboveCharm',
+                                'TagVarCSV_trackSip2dSigAboveCharm',
+                                'TagVarCSV_trackSip3dValAboveCharm',
+                                'TagVarCSV_trackSip3dSigAboveCharm',
+                                'TagVarCSV_jetNSelectedTracks',
+                                'TagVarCSV_jetNTracksEtaRel']
+
+        self.cpf_branches = ['Cpfcan_BtagPf_trackEtaRel',
+                             'Cpfcan_BtagPf_trackPtRel',
+                             'Cpfcan_BtagPf_trackPPar',
+                             'Cpfcan_BtagPf_trackDeltaR',
+                             'Cpfcan_BtagPf_trackPParRatio',
+                             'Cpfcan_BtagPf_trackSip2dVal',
+                             'Cpfcan_BtagPf_trackSip2dSig',
+                             'Cpfcan_BtagPf_trackSip3dVal',
+                             'Cpfcan_BtagPf_trackSip3dSig',
+                             'Cpfcan_BtagPf_trackJetDistVal',
+                             'Cpfcan_ptrel',
+                             'Cpfcan_drminsv',
+                             'Cpfcan_VTX_ass',
+                             'Cpfcan_puppiw',
+                             'Cpfcan_chi2',
+                             'Cpfcan_quality']
+        self.n_cpf = 25
+
+        self.npf_branches = ['Npfcan_ptrel','Npfcan_deltaR','Npfcan_isGamma','Npfcan_HadFrac','Npfcan_drminsv','Npfcan_puppiw']
+        self.n_npf = 25
+
+        self.vtx_branches = ['sv_pt','sv_deltaR',
+                             'sv_mass',
+                             'sv_ntracks',
+                             'sv_chi2',
+                             'sv_normchi2',
+                             'sv_dxy',
+                             'sv_dxysig',
+                             'sv_d3d',
+                             'sv_d3dsig',
+                             'sv_costhetasvpv',
+                             'sv_enratio',
+                             ]
+        self.n_vtx = 4
+
+        self.reduced_truth = ['isB','isBB','isLeptonicB','isC','isUDS','isG']
+
+
+    def createWeighterObjects(self, allsourcefiles):
+        #
+        # Calculates the weights needed for flattening the pt/eta spectrum
+        #
+        from DeepJetCore.Weighter import Weighter
+        weighter = Weighter()
+        weighter.undefTruth = self.undefTruth
+        branches = [self.weightbranchX, self.weightbranchY]
+        branches.extend(self.truth_branches)
+
+        if self.remove:
+            weighter.setBinningAndClasses(
+                [self.weight_binX, self.weight_binY],
+                self.weightbranchX, self.weightbranchY,
+                self.truth_branches
+            )
+
+        counter = 0
+        import ROOT
+        from root_numpy import tree2array, root2array
+        if self.remove:
+            for fname in allsourcefiles:
+                fileTimeOut(fname, 120)
+                nparray = root2array(
+                    fname,
+                    treename="deepntuplizer/tree",
+                    stop=None,
+                    branches=branches
+                )
+                weighter.addDistributions(nparray)
+                #del nparray
+                counter = counter + 1
+            weighter.createRemoveProbabilitiesAndWeights(self.referenceclass)
+        return {'weighter': weighter}
+
+    def convertFromSourceFile(self, filename, weighterobjects, istraining):
+        # Function to produce the numpy training arrays from root files
+        from DeepJetCore.Weighter import Weighter
+        from DeepJetCore.stopwatch import stopwatch
+        sw = stopwatch()
+        swall = stopwatch()
+
+        def reduceTruth(uproot_arrays):
+
+            b = uproot_arrays['isB']
+
+            bb = uproot_arrays['isBB']
+            gbb = uproot_arrays['isGBB']
+
+            bl = uproot_arrays['isLeptonicB']
+            blc = uproot_arrays['isLeptonicB_C']
+            lepb = bl + blc
+
+            c = uproot_arrays['isC']
+            cc = uproot_arrays['isCC']
+            gcc = uproot_arrays['isGCC']
+
+            ud = uproot_arrays['isUD']
+            s = uproot_arrays['isS']
+            uds = ud + s
+
+            g = uproot_arrays['isG']
+
+            return np.vstack((b, bb+gbb, lepb, c+cc+gcc, uds, g)).transpose()
+
+        print('reading '+filename)
+
+        import ROOT
+        from root_numpy import tree2array, root2array
+        fileTimeOut(filename, 120) #give eos a minute to recover
+        rfile = ROOT.TFile(filename)
+        tree = rfile.Get("deepntuplizer/tree")
+        self.nsamples = tree.GetEntries()
+
+        # user code, example works with the example 2D images in root format generated by make_example_data
+        from DeepJetCore.preprocessing import MeanNormZeroPad, MeanNormZeroPadParticles
+
+        x_global = MeanNormZeroPad(filename, None,
+                                   [self.global_branches],
+                                   [1], self.nsamples)
+
+        x_cpf = MeanNormZeroPadParticles(filename, None,
+                                         self.cpf_branches,
+                                         self.n_cpf, self.nsamples)
+
+        x_npf = MeanNormZeroPadParticles(filename, None,
+                                         self.npf_branches,
+                                         self.n_npf, self.nsamples)
+
+        x_vtx = MeanNormZeroPadParticles(filename, None,
+                                         self.vtx_branches,
+                                         self.n_vtx, self.nsamples)
+
+        import uproot
+        urfile = uproot.open(filename)["deepntuplizer/tree"]
+        truth_arrays = urfile.arrays(self.truth_branches)
+        truth = reduceTruth(truth_arrays)
+        truth = truth.astype(dtype='float32', order='C') #important, float32 and C-type!
+
+        x_global = x_global.astype(dtype='float32', order='C')
+        x_cpf = x_cpf.astype(dtype='float32', order='C')
+        x_npf = x_npf.astype(dtype='float32', order='C')
+        x_vtx = x_vtx.astype(dtype='float32', order='C')
+
+        if self.remove:
+            b = [self.weightbranchX, self.weightbranchY]
+            b.extend(self.truth_branches)
+            b.extend(self.undefTruth)
+            fileTimeOut(filename, 120)
+            for_remove = root2array(
+                filename,
+                treename="deepntuplizer/tree",
+                stop=None,
+                branches=b
+            )
+            print(weighterobjects)
+            notremoves = weighterobjects['weighter'].createNotRemoveIndices(for_remove)
+            undef = for_remove['isUndefined']
+            notremoves -= undef
+            print('took ', sw.getAndReset(), ' to create remove indices')
+
+        if self.remove:
+            print('remove')
+            x_global = x_global[notremoves > 0]
+            x_cpf = x_cpf[notremoves > 0]
+            x_npf = x_npf[notremoves > 0]
+            x_vtx = x_vtx[notremoves > 0]
+            truth = truth[notremoves > 0]
+
+            newnsamp = x_global.shape[0]
+            print('reduced content to ', int(float(newnsamp)/float(self.nsamples)*100), '%')
+
+        print('remove nans')
+        x_global = np.where(np.isfinite(x_global), x_global, 0)
+        x_cpf = np.where(np.isfinite(x_cpf), x_cpf, 0)
+        x_npf = np.where(np.isfinite(x_npf), x_npf, 0)
+        x_vtx = np.where(np.isfinite(x_vtx), x_vtx, 0)
+
+        return [x_global, x_cpf, x_npf, x_vtx], [truth], []
+
+    ## defines how to write out the prediction
+    def writeOutPrediction(self, predicted, features, truth, weights, outfilename, inputfile):
+        # predicted will be a list
+        from root_numpy import array2root
+        out = np.core.records.fromarrays(np.vstack((predicted[0].transpose(), truth[0].transpose(), features[0][:, 0:2].transpose())),
+                                         names='prob_isB,prob_isBB,prob_isLeptB,prob_isC,prob_isUDS,prob_isG,isB,isBB,isLeptB,isC,isUDS,isG,jet_pt,jet_eta')
+        array2root(out, outfilename, 'tree')
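Normally `convertFromSource.py` drives this class, but for debugging the datastructure can be exercised by hand. A minimal sketch, assuming a DeepJetCore 2.X environment with `env.sh` sourced and `ntuple.root` as a placeholder for one input file:

```python
from datastructures import TrainData_DF

td = TrainData_DF()
# build the pt/eta reweighting info from (here) a single file
weighters = td.createWeighterObjects(['ntuple.root'])
# produce the four zero-padded feature arrays and the reduced six-class truth
features, truth, _ = td.convertFromSourceFile('ntuple.root', weighters, istraining=True)
for name, arr in zip(['global', 'cpf', 'npf', 'vtx'], features):
    print(name, arr.shape)      # e.g. cpf -> (njets, 25, 16)
print('truth', truth[0].shape)  # (njets, 6)
```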
diff --git a/modules/datastructures/__init__.py b/modules/datastructures/__init__.py
new file mode 100644
index 0000000..c6a6e23
--- /dev/null
+++ b/modules/datastructures/__init__.py
@@ -0,0 +1,22 @@
+
+#Make it look like a package
+from glob import glob
+from os import environ
+from os.path import basename, dirname
+from pdb import set_trace
+
+#gather all the files here
+modules = [basename(i.replace('.py','')) for i in glob('%s/[A-Za-z]*.py' % dirname(__file__))]
+__all__ = []
+structure_list = []
+for module_name in modules:
+    module = __import__(module_name, globals(), locals(), [module_name])
+    for model_name in [i for i in dir(module) if 'TrainData' in i]:
+        model = getattr(module, model_name)
+        globals()[model_name] = model
+        locals()[model_name] = model
+        __all__.append(model_name)
+        structure_list.append(model_name)
+
diff --git a/modules/models/__init__.py b/modules/models/__init__.py
new file mode 100644
index 0000000..125b109
--- /dev/null
+++ b/modules/models/__init__.py
@@ -0,0 +1,17 @@
+#Make it look like a package
+from glob import glob
+from os import environ
+from os.path import basename, dirname
+from pdb import set_trace
+
+#gather all the files here
+modules = [basename(i.replace('.py','')) for i in glob('%s/[A-Za-z]*.py' % dirname(__file__))]
+__all__ = []
+for module_name in modules:
+    module = __import__(module_name, globals(), locals(), [module_name])
+    for model_name in [i for i in dir(module) if 'model' in i]:
+        model = getattr(module, model_name)
+        globals()[model_name] = model
+        locals()[model_name] = model
+        __all__.append(model_name)
+
diff --git a/modules/models/buildingBlocks.py b/modules/models/buildingBlocks.py
new file mode 100644
index 0000000..e94bf02
--- /dev/null
+++ b/modules/models/buildingBlocks.py
@@ -0,0 +1,150 @@
+'''
+standardised building blocks for the models
+'''
+from keras.layers import Dense, Dropout, Flatten, Convolution2D, Convolution1D, Lambda, LeakyReLU, Reshape
+#from keras.layers.pooling import MaxPooling2D
+from keras.layers import MaxPool2D
+#from keras.layers.normalization import BatchNormalization
+from keras.layers import BatchNormalization
+from Layers import *
+
+def block_deepFlavourBTVConvolutions(charged, vertices, dropoutRate, active=True, batchnorm=False, batchmomentum=0.6):
+    '''
+    DeepFlavour convolution part for the charged-candidate and vertex inputs only (no neutrals).
+    '''
+    cpf = charged
+    if active:
+        cpf = Convolution1D(64, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv0')(cpf)
+        if batchnorm:
+            cpf = BatchNormalization(momentum=batchmomentum, name='cpf_batchnorm0')(cpf)
+        cpf = Dropout(dropoutRate, name='cpf_dropout0')(cpf)
+        cpf = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv1')(cpf)
+        if batchnorm:
+            cpf = BatchNormalization(momentum=batchmomentum, name='cpf_batchnorm1')(cpf)
+        cpf = Dropout(dropoutRate, name='cpf_dropout1')(cpf)
+        cpf = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv2')(cpf)
+        if batchnorm:
+            cpf = BatchNormalization(momentum=batchmomentum, name='cpf_batchnorm2')(cpf)
+        cpf = Dropout(dropoutRate, name='cpf_dropout2')(cpf)
+        cpf = Convolution1D(8, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv3')(cpf)
+    else:
+        cpf = Convolution1D(1, 1, kernel_initializer='zeros', trainable=False)(cpf)
+
+    vtx = vertices
+    if active:
+        vtx = Convolution1D(64, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv0')(vtx)
+        if batchnorm:
+            vtx = BatchNormalization(momentum=batchmomentum, name='vtx_batchnorm0')(vtx)
+        vtx = Dropout(dropoutRate, name='vtx_dropout0')(vtx)
+        vtx = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv1')(vtx)
+        if batchnorm:
+            vtx = BatchNormalization(momentum=batchmomentum, name='vtx_batchnorm1')(vtx)
+        vtx = Dropout(dropoutRate, name='vtx_dropout1')(vtx)
+        vtx = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv2')(vtx)
+        if batchnorm:
+            vtx = BatchNormalization(momentum=batchmomentum, name='vtx_batchnorm2')(vtx)
+        vtx = Dropout(dropoutRate, name='vtx_dropout2')(vtx)
+        vtx = Convolution1D(8, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv3')(vtx)
+    else:
+        vtx = Convolution1D(1, 1, kernel_initializer='zeros', trainable=False)(vtx)
+
+    return cpf, vtx
+
+
+def block_deepFlavourConvolutions(charged, neutrals, vertices, dropoutRate, active=True, batchnorm=False, batchmomentum=0.6):
+    '''
+    DeepFlavour convolution part: per-candidate 1x1 convolutions for the charged, neutral and vertex inputs.
+    '''
+    cpf = charged
+    if active:
+        cpf = Convolution1D(64, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv0')(cpf)
+        if batchnorm:
+            cpf = BatchNormalization(momentum=batchmomentum, name='cpf_batchnorm0')(cpf)
+        cpf = Dropout(dropoutRate, name='cpf_dropout0')(cpf)
+        cpf = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv1')(cpf)
+        if batchnorm:
+            cpf = BatchNormalization(momentum=batchmomentum, name='cpf_batchnorm1')(cpf)
+        cpf = Dropout(dropoutRate, name='cpf_dropout1')(cpf)
+        cpf = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv2')(cpf)
+        if batchnorm:
+            cpf = BatchNormalization(momentum=batchmomentum, name='cpf_batchnorm2')(cpf)
+        cpf = Dropout(dropoutRate, name='cpf_dropout2')(cpf)
+        cpf = Convolution1D(8, 1, kernel_initializer='lecun_uniform', activation='relu', name='cpf_conv3')(cpf)
+    else:
+        cpf = Convolution1D(1, 1, kernel_initializer='zeros', trainable=False)(cpf)
+
+    npf = neutrals
+    if active:
+        npf = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='npf_conv0')(npf)
+        if batchnorm:
+            npf = BatchNormalization(momentum=batchmomentum, name='npf_batchnorm0')(npf)
+        npf = Dropout(dropoutRate, name='npf_dropout0')(npf)
+        npf = Convolution1D(16, 1, kernel_initializer='lecun_uniform', activation='relu', name='npf_conv1')(npf)
+        if batchnorm:
+            npf = BatchNormalization(momentum=batchmomentum, name='npf_batchnorm1')(npf)
+        npf = Dropout(dropoutRate, name='npf_dropout1')(npf)
+        npf = Convolution1D(4, 1, kernel_initializer='lecun_uniform', activation='relu', name='npf_conv2')(npf)
+    else:
+        npf = Convolution1D(1, 1, kernel_initializer='zeros', trainable=False)(npf)
+
+    vtx = vertices
+    if active:
+        vtx = Convolution1D(64, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv0')(vtx)
+        if batchnorm:
+            vtx = BatchNormalization(momentum=batchmomentum, name='vtx_batchnorm0')(vtx)
+        vtx = Dropout(dropoutRate, name='vtx_dropout0')(vtx)
+        vtx = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv1')(vtx)
+        if batchnorm:
+            vtx = BatchNormalization(momentum=batchmomentum, name='vtx_batchnorm1')(vtx)
+        vtx = Dropout(dropoutRate, name='vtx_dropout1')(vtx)
+        vtx = Convolution1D(32, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv2')(vtx)
+        if batchnorm:
+            vtx = BatchNormalization(momentum=batchmomentum, name='vtx_batchnorm2')(vtx)
+        vtx = Dropout(dropoutRate, name='vtx_dropout2')(vtx)
+        vtx = Convolution1D(8, 1, kernel_initializer='lecun_uniform', activation='relu', name='vtx_conv3')(vtx)
+    else:
+        vtx = Convolution1D(1, 1, kernel_initializer='zeros', trainable=False)(vtx)
+
+    return cpf, npf, vtx
+
+
+def block_deepFlavourDense(x, dropoutRate, active=True, batchnorm=False, batchmomentum=0.6):
+    if active:
+        x = Dense(200, activation='relu', kernel_initializer='lecun_uniform', name='df_dense0')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm0')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout0')(x)
+        x = Dense(100, activation='relu', kernel_initializer='lecun_uniform', name='df_dense1')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm1')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout1')(x)
+        x = Dense(100, activation='relu', kernel_initializer='lecun_uniform', name='df_dense2')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm2')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout2')(x)
+        x = Dense(100, activation='relu', kernel_initializer='lecun_uniform', name='df_dense3')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm3')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout3')(x)
+        x = Dense(100, activation='relu', kernel_initializer='lecun_uniform', name='df_dense4')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm4')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout4')(x)
+        x = Dense(100, activation='relu', kernel_initializer='lecun_uniform', name='df_dense5')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm5')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout5')(x)
+        x = Dense(100, activation='relu', kernel_initializer='lecun_uniform', name='df_dense6')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm6')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout6')(x)
+        x = Dense(100, activation='relu', kernel_initializer='lecun_uniform', name='df_dense7')(x)
+        if batchnorm:
+            x = BatchNormalization(momentum=batchmomentum, name='df_dense_batchnorm7')(x)
+        x = Dropout(dropoutRate, name='df_dense_dropout7')(x)
+
+    else:
+        x = Dense(1, kernel_initializer='zeros', trainable=False, name='df_dense_off')(x)
+
+    return x
+
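As a sanity check, the dense block can be exercised on its own. A sketch (run from `modules/models` so that `buildingBlocks` imports, with an illustrative input width of 265 = 15 globals + 150 cpf + 50 npf + 50 vtx after the LSTM summaries in the model below):

```python
from keras.layers import Input
from keras.models import Model
from buildingBlocks import block_deepFlavourDense

inp = Input(shape=(265,))  # illustrative width, see note above
out = block_deepFlavourDense(inp, dropoutRate=0.1, active=True,
                             batchnorm=True, batchmomentum=0.6)
# one 200-node plus seven 100-node Dense/BatchNorm/Dropout stages
Model(inputs=inp, outputs=out).summary()
```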
diff --git a/modules/models/convolutional.py b/modules/models/convolutional.py
new file mode 100644
index 0000000..e89ee0f
--- /dev/null
+++ b/modules/models/convolutional.py
@@ -0,0 +1,58 @@
+from keras.layers import Dense, Dropout, Flatten, Concatenate, Lambda, Convolution2D, LSTM, Convolution1D, Conv2D, GlobalAveragePooling1D, GlobalMaxPooling1D, TimeDistributed
+from keras.models import Model
+import tensorflow as tf
+from keras.layers import BatchNormalization
+#from keras.layers.normalization import BatchNormalization
+
+from keras import backend as K
+from Layers import *
+from buildingBlocks import block_deepFlavourConvolutions, block_deepFlavourDense
+
+
+def model_deepFlavourReference(Inputs, dropoutRate=0.1, momentum=0.6):
+    """
+    reference 1x1 convolutional model for 'deepFlavour'
+    with recurrent layers and batch normalisation
+    standard dropout rate is 0.1
+    should be trained for flavour prediction first; afterwards, all layers
+    that do not include 'regression' can be fixed and the training repeated
+    focusing on the regression part
+    (check function fixLayersContaining with invert=True)
+    """
+    globalvars = BatchNormalization(momentum=momentum, name='globals_input_batchnorm')(Inputs[0])
+    cpf = BatchNormalization(momentum=momentum, name='cpf_input_batchnorm')(Inputs[1])
+    npf = BatchNormalization(momentum=momentum, name='npf_input_batchnorm')(Inputs[2])
+    vtx = BatchNormalization(momentum=momentum, name='vtx_input_batchnorm')(Inputs[3])
+
+    cpf, npf, vtx = block_deepFlavourConvolutions(charged=cpf,
+                                                  neutrals=npf,
+                                                  vertices=vtx,
+                                                  dropoutRate=dropoutRate,
+                                                  active=True,
+                                                  batchnorm=True, batchmomentum=momentum)
+
+    cpf = LSTM(150, go_backwards=True, implementation=2, name='cpf_lstm')(cpf)
+    cpf = BatchNormalization(momentum=momentum, name='cpflstm_batchnorm')(cpf)
+    cpf = Dropout(dropoutRate)(cpf)
+
+    npf = LSTM(50, go_backwards=True, implementation=2, name='npf_lstm')(npf)
+    npf = BatchNormalization(momentum=momentum, name='npflstm_batchnorm')(npf)
+    npf = Dropout(dropoutRate)(npf)
+
+    vtx = LSTM(50, go_backwards=True, implementation=2, name='vtx_lstm')(vtx)
+    vtx = BatchNormalization(momentum=momentum, name='vtxlstm_batchnorm')(vtx)
+    vtx = Dropout(dropoutRate)(vtx)
+
+    x = Concatenate()([globalvars, cpf, npf, vtx])
+
+    x = block_deepFlavourDense(x, dropoutRate, active=True, batchnorm=True, batchmomentum=momentum)
+
+    flavour_pred = Dense(6, activation='softmax', kernel_initializer='lecun_uniform', name='ID_pred')(x)
+
+    predictions = [flavour_pred]
+    model = Model(inputs=Inputs, outputs=predictions)
+    return model
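In DeepJetCore the input layers are created by `training_base` from the dataCollection, but the model can also be instantiated standalone. A sketch with input shapes matching `TrainData_DF` above (15 global variables; 25 charged candidates x 16 features; 25 neutral candidates x 6 features; 4 vertices x 12 features):

```python
from keras.layers import Input
from models import model_deepFlavourReference

inputs = [Input(shape=(15,)),     # global jet variables
          Input(shape=(25, 16)),  # charged PF candidates
          Input(shape=(25, 6)),   # neutral PF candidates
          Input(shape=(4, 12))]   # secondary vertices
model = model_deepFlavourReference(inputs, dropoutRate=0.1, momentum=0.6)
model.summary()  # ends in the six-class softmax 'ID_pred'
```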