-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fixed columnrepetition bug to run in small matrices, can choose cross…
… entropy reduction
- Loading branch information
Showing
16 changed files
with
221 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
mklab.JGNN.adhoc.parsers.LayeredBuilder | ||
features = config: 2.0 | ||
hidden = config: 8.0 | ||
reg = config: 0.005 | ||
classes = config: 2.0 | ||
reduced = config: 6.0 | ||
2hidden = config: 16.0 | ||
hiddenReduced = config: 48.0 | ||
A = var: null | ||
h0 = var: null | ||
_tmp11 = param DenseMatrix (2hidden 16, hidden 8): [-0.19328476877863923,-0.6816842687001173,0.33084330660685063,0.29012383747022974,-0.008856823759693731,0.3243209298674903,0.03630835867502023,-0.36893022633353045,0.3852823246760141,0.1675140168304929,0.34749646804722034,-0.043510359020801116,-0.5480261005662305,-0.36506518491627604,0.4618722893431745,-0.31670662910836095,-0.06766959139063297,3.964965286227935E-4,0.42952042800161827,-0.6955929308960008,-0.6286510760021554,-0.0965623432938279,-0.18299829835981102,-0.05956609007597157,-0.2920302998132937,-0.2547563636041248,0.15096554826682856,0.3775004405962659,-0.8995334446683099,0.06333771747725044,0.65460552575879,-0.3186670866531967,-0.5309119132179063,0.394784224886659,0.29509546796445657,-0.4690491776981738,-0.10572405664194967,0.27624178602476357,0.019287597474623407,-0.08221736520461895,0.3752314288377563,-0.14388855480059862,0.5031819562803397,-0.39231949747503947,0.030226728241452908,0.539230918192968,-0.1866028170582539,-0.054814542702866335,-0.037001880635509805,-0.05866626266145604,0.6570208369054049,0.4037472351461765,-0.7434327464304322,-0.1800452245270526,0.40850049372980474,0.22084528373229415,0.01669375687488634,-0.01769463510961822,0.48369486751642693,-0.4513516083571352,-0.18643307652802016,0.47531322785792024,-0.15418322605338194,0.08869303689826054,0.6496841665481397,0.12057361296404132,-0.10087202695777087,0.20077395712411783,0.12298342734788265,0.4463359109741554,0.595268214487788,0.3198579540060096,0.15003772422143247,0.3036177944924227,0.39176025176862367,0.4385176746516399,0.29267786236453575,-0.538075342791701,0.18692588008664582,-0.22874264781324213,-0.2956689694666771,0.08507352063026218,0.16405172192265624,-0.16627517419517734,0.7367038620405035,0.23692065568506798,-0.39340620243769026,-0.1218118529198394,0.34270271396790364,0.10611754538241865,0.23511719104140957,0.06257205623691434,-0.026719449424170108,0.13262598343728552,0.3321072492887023,-0.322213838336707,-0.09058229994288335,-0.03864904977375033,0.18830357531855324,0.035130801296273106,-0.40433428753295697,0.34087491252629254,0.8426791513172762,0.07823534114592788,0.017292176905372218,-0.07825012073294113,-0.551117338968398,0.5990794942026233,-0.3498034727233782,-0.15620124985299463,-0.1313686261579107,-0.17445929115062758,0.05423176395512194,0.3305972373060727,-0.07433345988400326,0.02211377023555965,0.5169316243904939,-0.35626175262079174,0.2605377618616539,0.040877884464697034,0.6011922447538958,0.10590649114094869,0.04866151867585637,1.0498653811003313,-0.4831934782954236,0.2155992503280004,-0.4449982638493458,-0.15205818353654899] | ||
_tmp12 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] | ||
_tmp17 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] | ||
_tmp16 = param DenseMatrix (2hidden 16, hidden 8): [0.5021034187023461,0.13868567394933753,-0.4319463012253357,-0.01957776716822693,-0.6094145749292609,0.1439419292273619,-0.10279569620728086,-0.24391336506490913,-0.07060120291020784,0.365303228749883,0.46701090217314395,-0.07945034851001198,0.18256005736919872,0.523490149473091,0.3783930271864008,-0.5348058052156716,0.248710923254586,-0.18060630411319933,-0.9064084134026116,-0.5047861771258582,-0.21041070643734053,-0.2508198800475379,-0.19029786209348837,-0.34694159663184004,0.02890813945657403,0.11926584949928104,0.008029066336753142,-0.30160347629852446,0.4142657749089172,-0.10276450006376554,0.21668328634092998,0.7794994419864524,0.31912230822442206,-0.7896441157875099,0.0669323910449409,0.005079087575026865,0.3871876711708779,-0.027795613673683106,0.5195156321578849,0.1451244649129374,-0.1234044285044108,-0.3910978872449653,0.19433235371428767,-0.484786863765534,-0.06500045233351805,-0.13535973107007787,-0.001400944670307145,0.14315247913721696,-0.44489526404547725,-0.7113623734549364,-0.3537344028937658,-0.2826613161956493,0.38693511145728415,-0.5540629852227372,-0.26564026626370785,-0.7528758353657103,-0.5030481946916305,0.7361802029086209,0.048380119415574255,0.19193865286784526,-0.18572427292244453,-0.1070739962101043,-0.4667238425266632,0.1724481017787257,0.07907853007683265,0.07794419563775512,-0.2258137337359118,0.7440959006326245,-0.684988546540687,0.17920646820148134,0.46240699117307765,0.4320721338608193,-0.37346377979218376,-0.4607102416897734,-0.5562339107003924,-1.1335825147833678,-0.6780782600923088,0.3572746140653733,-0.41292103314568673,-0.31905109650720387,-0.25370387856853094,-0.480925487249198,-0.3151921264430039,0.07776796917965208,-0.1788986631998994,0.08475066278957165,-0.5237351973530295,0.4193834593785406,0.03681254346412109,-0.584427852645633,-0.02896556963429372,-0.1299053244734533,-0.2218704270185149,0.24571193054996324,0.42048425749103463,-0.8880484715017992,0.31700958719137495,-0.4581902441160126,-0.30034222408803624,-0.7713140669541653,0.1906079322091254,0.5614853908408053,-0.12311847026365601,0.094618305065478,0.16756886278939304,0.5190453004025772,0.1461756848410729,0.10340340085403356,0.009543912257454283,-0.4945736675559478,0.18117986153395202,0.4416916735062306,0.2303877229303233,0.22360404940516052,-0.6112281541522248,-0.3049268148884489,-0.21320808341784314,0.23185142098280917,0.4502900379269799,-0.165791378027787,0.7864019252894919,0.2813364624331278,0.9474109407516412,0.12430889772546948,-0.2172150595120388,0.4876508580337231,0.344365758524931,0.36338942310374867] | ||
_tmp20 = param DenseMatrix (hidden 8, hidden 8): [0.6814404680132916,0.25361885765312314,0.10256163018756725,0.1922851650348111,-0.05410077006150943,-0.38549110279625004,0.5004202987313823,-0.2324453468875155,-0.22099505283361742,0.06551162436227614,0.5271888950214714,0.22805732984591637,0.42595558449507714,-0.2595469927668719,0.41457441383114,-0.40693167679055886,0.39803040041408844,0.3440959433142813,-0.5753750505438064,0.06237296287559791,-0.4211115226435094,-0.1763906243200833,1.1985422471982727,-0.07676527261740763,0.2326178151531441,0.35704961284054426,-0.26729485975443884,0.4095283664505398,0.18168483029293744,-0.2932381742502261,0.8414330706272171,0.37977048763315235,0.5062335030466366,0.9869523173389951,1.6838631648078104,0.1263031934267131,-0.5478652955503843,-3.035840047741802E-4,-0.3174070977338061,0.7191699495163361,-0.07943020358021734,0.21984277828430565,0.8033588782047216,-0.1861673333583821,0.5596593401061914,-0.35031592936519707,-0.37717583187856096,0.23602515813879557,0.0449456664499197,-0.47412881293932263,0.21589515054272587,-0.20568191860829724,0.2901169807397064,0.4150953875668683,0.03210280126721354,-0.7381724464055942,0.08744218760930367,1.657249410635567,0.38200285347323787,0.41650869362747445,0.12939056231065268,0.40376690602764725,-0.047337732760709926,0.6966554415998744] | ||
_tmp21 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] | ||
_tmp3 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] | ||
_tmp2 = param DenseMatrix (features 2, hidden 8): [0.0962596204088395,-0.028884691451251224,-0.028116488038815862,0.3577563699462411,-0.1915495424716903,-0.5879573411284648,0.16687000433532928,0.18406857056345563,-0.5892489716766497,0.9247104404685634,-0.06459238390322171,-1.1421568977936178,-0.722566667961998,-1.3796258912305088,-0.06280695024528939,1.7517702074586412] | ||
_tmp23 = param DenseMatrix (hiddenReduced 48, classes 2): [0.40938433617236525,-0.4987199441662107,0.26203992378101376,-0.43133379215734263,0.6709105283816391,-0.9814661346740309,0.539095205750881,-0.4207697065158933,0.0240229334618446,-0.14223548752189802,0.4438842753898452,-0.49120390072985726,0.17395173314865078,0.06471401007521176,0.4478543276829584,0.43432759563978957,0.3447420155314258,-1.1253578954761723,0.4993227096710841,0.30042014773726416,0.6795526379343243,-0.5195903531075249,0.4621000372201286,-0.7689859292393859,0.9106708726562932,-0.6086682072036363,-0.49443734786718163,-2.0015702142041314,0.2759331578478036,-0.9644772192516243,-0.14250917516719866,-5.04944040825525,1.8578946969831234,-2.689873422142453,17.819588902507228,2.3524952524944034,0.607415470001217,-0.005342721841655263,0.2345666718277672,-0.6435856594935998,0.09465516271962526,-1.028924483552736,0.5494581031684685,0.004150240845003633,0.09245869079258418,-1.159928537758163,0.5395232430817755,-0.8876905496583974,-0.43672870130761726,-0.6001793288795612,-0.2209758164529676,0.3771985686838059,-0.5264231473231888,1.2086668405130667,-0.3677004028561642,-0.23961520350564466,-0.1312633323701847,0.20028985803474514,-0.5997756663497473,1.1861726582546483,-0.5171240051903958,-0.36876941825191734,-0.08069918314199832,-0.6008323309697636,0.2923273123036548,1.4251798900771944,-0.3047689840101401,-0.11336172298411515,-0.5004652657880019,-0.25323032177253096,-0.04182412745906058,0.7914127746669496,-0.7185649111348198,0.4124581236471469,0.8006830216814292,1.4727565356504557,-0.98152228243357,0.6496030752864841,0.2086967713555983,5.221403415918512,-2.005888671728391,2.2603888220693245,-17.575851034929514,-2.500026725809498,-0.24434030852837782,-0.33232526683552893,0.3851370120960596,1.4858460025078009,-0.421678210018014,0.7016615436437713,-0.578971814276199,0.029161238557184666,-0.06956312718637947,1.0065065640325745,-0.5451395936226792,1.0762157054167099] | ||
_tmp5 = param DenseMatrix (hidden 8, hidden 8): [-0.30954935594367816,0.12116378595232227,-0.306386786291067,-0.5220864721581343,-0.003000295537359033,-0.37628978694357235,-0.2695812643245417,0.6700734618145359,0.20645751601752704,-0.05260439493400403,0.17754381848591488,-0.17735111718384539,0.167779299276608,0.27389750680988467,0.23764308545771937,1.6708105311783883,0.06728854514002967,-0.44839216787805547,0.13632607512869982,0.37582277585252005,-0.28216295775529515,-0.5678759903551267,-0.35640357469197825,0.3560222333683681,-0.06393612839774328,-0.43344284383245607,-0.3986854818366697,0.26732202379329006,0.34858445673548166,0.01968688359473008,0.7123181505558769,-0.6411211857602117,-0.271134106235963,0.2491391197457063,0.0939090724621519,-0.7099186482668389,0.4388952099959059,0.8554823355249708,-1.166165039369413,1.1988479095710083,-0.2320232541561487,-0.08166507448918221,-0.41733625408668373,-0.6475038987362935,-0.4785141909120537,-0.6513421097528718,0.7196243595671512,0.0627929089576402,0.08109030285230427,-2.3186007809859356E-4,0.06453881213468693,0.7580057257515331,-0.12080130803301743,-0.17046938190024785,0.28216868068309153,0.005635712488808225,-0.45095495159758453,0.13584676111699664,-0.73411011182003,-0.22715114851512125,-0.5224650791751124,0.8293987840564809,0.021505674769915967,-0.2903500417602574] | ||
_tmp6 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0] | ||
edgeSrc = from A | ||
edgeDst = to A | ||
_tmp1 = h0 @ _tmp2 | ||
_tmp0 = _tmp1 + _tmp3 | ||
h1 = relu _tmp0 | ||
_tmp4 = h1 @ _tmp5 | ||
h2 = _tmp4 + _tmp6 | ||
_tmp7 = h2 [ edgeSrc ] | ||
_tmp8 = h2 [ edgeDst ] | ||
message2 = _tmp7 | _tmp8 | ||
_tmp10 = message2 @ _tmp11 | ||
_tmp9 = _tmp10 + _tmp12 | ||
transformed2 = relu _tmp9 | ||
received2 = reduce ( transformed2 , A ) | ||
_tmp14 = received2 | h2 | ||
_tmp15 = _tmp14 @ _tmp16 | ||
_tmp13 = _tmp15 + _tmp17 | ||
i2 = relu _tmp13 | ||
_tmp19 = i2 @ _tmp20 | ||
_tmp18 = _tmp19 + _tmp21 | ||
h3 = relu _tmp18 | ||
z3 = sort ( h3 , reduced ) | ||
_tmp22 = h3 [ z3 ] | ||
h4 = reshape ( _tmp22 , 1 , hiddenReduced ) | ||
h5 = h4 @ _tmp23 | ||
h6 = softmax ( h5 , row ) | ||
|
||
return h6 |
99 changes: 99 additions & 0 deletions
99
JGNN/src/examples/graphClassification/MessageSortPooling.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package graphClassification; | ||
|
||
import java.util.Arrays; | ||
|
||
import mklab.JGNN.adhoc.ModelBuilder; | ||
import mklab.JGNN.adhoc.parsers.LayeredBuilder; | ||
import mklab.JGNN.core.Matrix; | ||
import mklab.JGNN.core.Tensor; | ||
import mklab.JGNN.core.ThreadPool; | ||
import mklab.JGNN.nn.Loss; | ||
import mklab.JGNN.nn.Model; | ||
import mklab.JGNN.nn.initializers.XavierNormal; | ||
import mklab.JGNN.nn.loss.CategoricalCrossEntropy; | ||
import mklab.JGNN.nn.optimizers.Adam; | ||
import mklab.JGNN.nn.optimizers.BatchOptimizer; | ||
|
||
/** | ||
* | ||
* @author github.com/gavalian | ||
* @author Emmanouil Krasanakis | ||
*/ | ||
public class MessageSortPooling { | ||
|
||
public static void main(String[] args){ | ||
long reduced = 5; // input graphs need to have at least that many nodes, lower values decrease accuracy | ||
long hidden = 8; // since this library does not use GPU parallelization, many latent dims reduce speed | ||
|
||
ModelBuilder builder = new LayeredBuilder() | ||
.var("A") | ||
.config("features", 1) | ||
.config("classes", 2) | ||
//.config("reduced", reduced) | ||
.config("hidden", hidden) | ||
.config("2hidden", 2*hidden) | ||
.config("reg", 0.005) | ||
.operation("edgeSrc = from(A)") | ||
.operation("edgeDst = to(A)") | ||
.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))") | ||
.layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)") | ||
|
||
// message passing layer (make it as complex as needed) | ||
.operation("message{l}=h{l}[edgeSrc] | h{l}[edgeDst]") | ||
.operation("transformed{l}=relu(message{l}@matrix(2hidden, hidden, reg)+vector(hidden))") | ||
.operation("received{l}=reduce(transformed{l}, A)") | ||
.operation("i{l}=relu((received{l} | h{l})@matrix(2hidden, hidden, reg)+vector(hidden))") | ||
.layer("h{l+1}=relu(i{l}@matrix(hidden, hidden, reg)+vector(hidden))") | ||
|
||
// this would be the sort pooling | ||
/*.config("hiddenReduced", hidden*reduced) // reduced * (previous layer's output size) | ||
.operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within next step's gather, so we need to create this intermediate variable | ||
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") // | ||
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)") | ||
.layer("h{l+1}=softmax(h{l}, row)")*/ | ||
|
||
// the following two layers implement the sum pooling | ||
.layer("h{l+1}=sum(h{l}@matrix(hidden, classes)+vector(classes), row)") | ||
.layer("h{l+1}=softmax(h{l}, row)") | ||
|
||
.out("h{l}"); | ||
|
||
TrajectoryData dtrain = new TrajectoryData(800); | ||
TrajectoryData dtest = new TrajectoryData(200); | ||
|
||
Model model = builder.getModel().init(new XavierNormal()); | ||
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01)); | ||
Loss loss = new CategoricalCrossEntropy(); | ||
for(int epoch=0; epoch<600; epoch++) { | ||
// gradient update over all graphs | ||
for(int graphId=0; graphId<dtrain.graphs.size(); graphId++) { | ||
int graphIdentifier = graphId; | ||
// each gradient calculation into a new thread pool task | ||
ThreadPool.getInstance().submit(new Runnable() { | ||
@Override | ||
public void run() { | ||
//System.out.println(dtrain.graphs.get(graphIdentifier).sum()); | ||
Matrix adjacency = dtrain.graphs.get(graphIdentifier); | ||
Matrix features= dtrain.features.get(graphIdentifier); | ||
Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow(); | ||
model.train(loss, optimizer, | ||
Arrays.asList(features, adjacency), | ||
Arrays.asList(graphLabel)); | ||
} | ||
}); | ||
} | ||
ThreadPool.getInstance().waitForConclusion(); // wait for all gradients to compute | ||
optimizer.updateAll(); // apply gradients on model parameters | ||
|
||
double acc = 0.0; | ||
for(int graphId=0; graphId<dtest.graphs.size(); graphId++) { | ||
Matrix adjacency = dtest.graphs.get(graphId); | ||
Matrix features= dtest.features.get(graphId); | ||
Tensor graphLabel = dtest.labels.get(graphId); | ||
if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax()) | ||
acc += 1; | ||
} | ||
System.out.println("iter = " + epoch + " " + acc/dtest.graphs.size()); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.