Skip to content

Commit

Permalink
fixed columnrepetition bug to run in small matrices, can choose cross…
Browse files Browse the repository at this point in the history
… entropy reduction
  • Loading branch information
maniospas committed Mar 27, 2024
1 parent 66cf283 commit 8463fb9
Show file tree
Hide file tree
Showing 16 changed files with 221 additions and 22 deletions.
49 changes: 49 additions & 0 deletions JGNN/neural_graph.jggn
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
mklab.JGNN.adhoc.parsers.LayeredBuilder
features = config: 2.0
hidden = config: 8.0
reg = config: 0.005
classes = config: 2.0
reduced = config: 6.0
2hidden = config: 16.0
hiddenReduced = config: 48.0
A = var: null
h0 = var: null
_tmp11 = param DenseMatrix (2hidden 16, hidden 8): [-0.19328476877863923,-0.6816842687001173,0.33084330660685063,0.29012383747022974,-0.008856823759693731,0.3243209298674903,0.03630835867502023,-0.36893022633353045,0.3852823246760141,0.1675140168304929,0.34749646804722034,-0.043510359020801116,-0.5480261005662305,-0.36506518491627604,0.4618722893431745,-0.31670662910836095,-0.06766959139063297,3.964965286227935E-4,0.42952042800161827,-0.6955929308960008,-0.6286510760021554,-0.0965623432938279,-0.18299829835981102,-0.05956609007597157,-0.2920302998132937,-0.2547563636041248,0.15096554826682856,0.3775004405962659,-0.8995334446683099,0.06333771747725044,0.65460552575879,-0.3186670866531967,-0.5309119132179063,0.394784224886659,0.29509546796445657,-0.4690491776981738,-0.10572405664194967,0.27624178602476357,0.019287597474623407,-0.08221736520461895,0.3752314288377563,-0.14388855480059862,0.5031819562803397,-0.39231949747503947,0.030226728241452908,0.539230918192968,-0.1866028170582539,-0.054814542702866335,-0.037001880635509805,-0.05866626266145604,0.6570208369054049,0.4037472351461765,-0.7434327464304322,-0.1800452245270526,0.40850049372980474,0.22084528373229415,0.01669375687488634,-0.01769463510961822,0.48369486751642693,-0.4513516083571352,-0.18643307652802016,0.47531322785792024,-0.15418322605338194,0.08869303689826054,0.6496841665481397,0.12057361296404132,-0.10087202695777087,0.20077395712411783,0.12298342734788265,0.4463359109741554,0.595268214487788,0.3198579540060096,0.15003772422143247,0.3036177944924227,0.39176025176862367,0.4385176746516399,0.29267786236453575,-0.538075342791701,0.18692588008664582,-0.22874264781324213,-0.2956689694666771,0.08507352063026218,0.16405172192265624,-0.16627517419517734,0.7367038620405035,0.23692065568506798,-0.39340620243769026,-0.1218118529198394,0.34270271396790364,0.10611754538241865,0.23511719104140957,0.06257205623691434,-0.026719449424170108,0.13262598343728552,0.3321072492887023,-0.322213838336707,-0.09058229994288335,-0.03864904977375033,0.18830357531855324,0.035130801296273106,-0.40433428753295697,0.34087491252629254,0.8426791513172762,0.07823534114592788,0.017292176905372218,-0.07825012073294113,-0.551117338968398,0.5990794942026233,-0.3498034727233782,-0.15620124985299463,-0.1313686261579107,-0.17445929115062758,0.05423176395512194,0.3305972373060727,-0.07433345988400326,0.02211377023555965,0.5169316243904939,-0.35626175262079174,0.2605377618616539,0.040877884464697034,0.6011922447538958,0.10590649114094869,0.04866151867585637,1.0498653811003313,-0.4831934782954236,0.2155992503280004,-0.4449982638493458,-0.15205818353654899]
_tmp12 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
_tmp17 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
_tmp16 = param DenseMatrix (2hidden 16, hidden 8): [0.5021034187023461,0.13868567394933753,-0.4319463012253357,-0.01957776716822693,-0.6094145749292609,0.1439419292273619,-0.10279569620728086,-0.24391336506490913,-0.07060120291020784,0.365303228749883,0.46701090217314395,-0.07945034851001198,0.18256005736919872,0.523490149473091,0.3783930271864008,-0.5348058052156716,0.248710923254586,-0.18060630411319933,-0.9064084134026116,-0.5047861771258582,-0.21041070643734053,-0.2508198800475379,-0.19029786209348837,-0.34694159663184004,0.02890813945657403,0.11926584949928104,0.008029066336753142,-0.30160347629852446,0.4142657749089172,-0.10276450006376554,0.21668328634092998,0.7794994419864524,0.31912230822442206,-0.7896441157875099,0.0669323910449409,0.005079087575026865,0.3871876711708779,-0.027795613673683106,0.5195156321578849,0.1451244649129374,-0.1234044285044108,-0.3910978872449653,0.19433235371428767,-0.484786863765534,-0.06500045233351805,-0.13535973107007787,-0.001400944670307145,0.14315247913721696,-0.44489526404547725,-0.7113623734549364,-0.3537344028937658,-0.2826613161956493,0.38693511145728415,-0.5540629852227372,-0.26564026626370785,-0.7528758353657103,-0.5030481946916305,0.7361802029086209,0.048380119415574255,0.19193865286784526,-0.18572427292244453,-0.1070739962101043,-0.4667238425266632,0.1724481017787257,0.07907853007683265,0.07794419563775512,-0.2258137337359118,0.7440959006326245,-0.684988546540687,0.17920646820148134,0.46240699117307765,0.4320721338608193,-0.37346377979218376,-0.4607102416897734,-0.5562339107003924,-1.1335825147833678,-0.6780782600923088,0.3572746140653733,-0.41292103314568673,-0.31905109650720387,-0.25370387856853094,-0.480925487249198,-0.3151921264430039,0.07776796917965208,-0.1788986631998994,0.08475066278957165,-0.5237351973530295,0.4193834593785406,0.03681254346412109,-0.584427852645633,-0.02896556963429372,-0.1299053244734533,-0.2218704270185149,0.24571193054996324,0.42048425749103463,-0.8880484715017992,0.31700958719137495,-0.4581902441160126,-0.30034222408803624,-0.7713140669541653,0.1906079322091254,0.5614853908408053,-0.12311847026365601,0.094618305065478,0.16756886278939304,0.5190453004025772,0.1461756848410729,0.10340340085403356,0.009543912257454283,-0.4945736675559478,0.18117986153395202,0.4416916735062306,0.2303877229303233,0.22360404940516052,-0.6112281541522248,-0.3049268148884489,-0.21320808341784314,0.23185142098280917,0.4502900379269799,-0.165791378027787,0.7864019252894919,0.2813364624331278,0.9474109407516412,0.12430889772546948,-0.2172150595120388,0.4876508580337231,0.344365758524931,0.36338942310374867]
_tmp20 = param DenseMatrix (hidden 8, hidden 8): [0.6814404680132916,0.25361885765312314,0.10256163018756725,0.1922851650348111,-0.05410077006150943,-0.38549110279625004,0.5004202987313823,-0.2324453468875155,-0.22099505283361742,0.06551162436227614,0.5271888950214714,0.22805732984591637,0.42595558449507714,-0.2595469927668719,0.41457441383114,-0.40693167679055886,0.39803040041408844,0.3440959433142813,-0.5753750505438064,0.06237296287559791,-0.4211115226435094,-0.1763906243200833,1.1985422471982727,-0.07676527261740763,0.2326178151531441,0.35704961284054426,-0.26729485975443884,0.4095283664505398,0.18168483029293744,-0.2932381742502261,0.8414330706272171,0.37977048763315235,0.5062335030466366,0.9869523173389951,1.6838631648078104,0.1263031934267131,-0.5478652955503843,-3.035840047741802E-4,-0.3174070977338061,0.7191699495163361,-0.07943020358021734,0.21984277828430565,0.8033588782047216,-0.1861673333583821,0.5596593401061914,-0.35031592936519707,-0.37717583187856096,0.23602515813879557,0.0449456664499197,-0.47412881293932263,0.21589515054272587,-0.20568191860829724,0.2901169807397064,0.4150953875668683,0.03210280126721354,-0.7381724464055942,0.08744218760930367,1.657249410635567,0.38200285347323787,0.41650869362747445,0.12939056231065268,0.40376690602764725,-0.047337732760709926,0.6966554415998744]
_tmp21 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
_tmp3 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
_tmp2 = param DenseMatrix (features 2, hidden 8): [0.0962596204088395,-0.028884691451251224,-0.028116488038815862,0.3577563699462411,-0.1915495424716903,-0.5879573411284648,0.16687000433532928,0.18406857056345563,-0.5892489716766497,0.9247104404685634,-0.06459238390322171,-1.1421568977936178,-0.722566667961998,-1.3796258912305088,-0.06280695024528939,1.7517702074586412]
_tmp23 = param DenseMatrix (hiddenReduced 48, classes 2): [0.40938433617236525,-0.4987199441662107,0.26203992378101376,-0.43133379215734263,0.6709105283816391,-0.9814661346740309,0.539095205750881,-0.4207697065158933,0.0240229334618446,-0.14223548752189802,0.4438842753898452,-0.49120390072985726,0.17395173314865078,0.06471401007521176,0.4478543276829584,0.43432759563978957,0.3447420155314258,-1.1253578954761723,0.4993227096710841,0.30042014773726416,0.6795526379343243,-0.5195903531075249,0.4621000372201286,-0.7689859292393859,0.9106708726562932,-0.6086682072036363,-0.49443734786718163,-2.0015702142041314,0.2759331578478036,-0.9644772192516243,-0.14250917516719866,-5.04944040825525,1.8578946969831234,-2.689873422142453,17.819588902507228,2.3524952524944034,0.607415470001217,-0.005342721841655263,0.2345666718277672,-0.6435856594935998,0.09465516271962526,-1.028924483552736,0.5494581031684685,0.004150240845003633,0.09245869079258418,-1.159928537758163,0.5395232430817755,-0.8876905496583974,-0.43672870130761726,-0.6001793288795612,-0.2209758164529676,0.3771985686838059,-0.5264231473231888,1.2086668405130667,-0.3677004028561642,-0.23961520350564466,-0.1312633323701847,0.20028985803474514,-0.5997756663497473,1.1861726582546483,-0.5171240051903958,-0.36876941825191734,-0.08069918314199832,-0.6008323309697636,0.2923273123036548,1.4251798900771944,-0.3047689840101401,-0.11336172298411515,-0.5004652657880019,-0.25323032177253096,-0.04182412745906058,0.7914127746669496,-0.7185649111348198,0.4124581236471469,0.8006830216814292,1.4727565356504557,-0.98152228243357,0.6496030752864841,0.2086967713555983,5.221403415918512,-2.005888671728391,2.2603888220693245,-17.575851034929514,-2.500026725809498,-0.24434030852837782,-0.33232526683552893,0.3851370120960596,1.4858460025078009,-0.421678210018014,0.7016615436437713,-0.578971814276199,0.029161238557184666,-0.06956312718637947,1.0065065640325745,-0.5451395936226792,1.0762157054167099]
_tmp5 = param DenseMatrix (hidden 8, hidden 8): [-0.30954935594367816,0.12116378595232227,-0.306386786291067,-0.5220864721581343,-0.003000295537359033,-0.37628978694357235,-0.2695812643245417,0.6700734618145359,0.20645751601752704,-0.05260439493400403,0.17754381848591488,-0.17735111718384539,0.167779299276608,0.27389750680988467,0.23764308545771937,1.6708105311783883,0.06728854514002967,-0.44839216787805547,0.13632607512869982,0.37582277585252005,-0.28216295775529515,-0.5678759903551267,-0.35640357469197825,0.3560222333683681,-0.06393612839774328,-0.43344284383245607,-0.3986854818366697,0.26732202379329006,0.34858445673548166,0.01968688359473008,0.7123181505558769,-0.6411211857602117,-0.271134106235963,0.2491391197457063,0.0939090724621519,-0.7099186482668389,0.4388952099959059,0.8554823355249708,-1.166165039369413,1.1988479095710083,-0.2320232541561487,-0.08166507448918221,-0.41733625408668373,-0.6475038987362935,-0.4785141909120537,-0.6513421097528718,0.7196243595671512,0.0627929089576402,0.08109030285230427,-2.3186007809859356E-4,0.06453881213468693,0.7580057257515331,-0.12080130803301743,-0.17046938190024785,0.28216868068309153,0.005635712488808225,-0.45095495159758453,0.13584676111699664,-0.73411011182003,-0.22715114851512125,-0.5224650791751124,0.8293987840564809,0.021505674769915967,-0.2903500417602574]
_tmp6 = param Tensor (hidden 8): [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
edgeSrc = from A
edgeDst = to A
_tmp1 = h0 @ _tmp2
_tmp0 = _tmp1 + _tmp3
h1 = relu _tmp0
_tmp4 = h1 @ _tmp5
h2 = _tmp4 + _tmp6
_tmp7 = h2 [ edgeSrc ]
_tmp8 = h2 [ edgeDst ]
message2 = _tmp7 | _tmp8
_tmp10 = message2 @ _tmp11
_tmp9 = _tmp10 + _tmp12
transformed2 = relu _tmp9
received2 = reduce ( transformed2 , A )
_tmp14 = received2 | h2
_tmp15 = _tmp14 @ _tmp16
_tmp13 = _tmp15 + _tmp17
i2 = relu _tmp13
_tmp19 = i2 @ _tmp20
_tmp18 = _tmp19 + _tmp21
h3 = relu _tmp18
z3 = sort ( h3 , reduced )
_tmp22 = h3 [ z3 ]
h4 = reshape ( _tmp22 , 1 , hiddenReduced )
h5 = h4 @ _tmp23
h6 = softmax ( h5 , row )

return h6
99 changes: 99 additions & 0 deletions JGNN/src/examples/graphClassification/MessageSortPooling.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package graphClassification;

import java.util.Arrays;

import mklab.JGNN.adhoc.ModelBuilder;
import mklab.JGNN.adhoc.parsers.LayeredBuilder;
import mklab.JGNN.core.Matrix;
import mklab.JGNN.core.Tensor;
import mklab.JGNN.core.ThreadPool;
import mklab.JGNN.nn.Loss;
import mklab.JGNN.nn.Model;
import mklab.JGNN.nn.initializers.XavierNormal;
import mklab.JGNN.nn.loss.CategoricalCrossEntropy;
import mklab.JGNN.nn.optimizers.Adam;
import mklab.JGNN.nn.optimizers.BatchOptimizer;

/**
*
* @author github.com/gavalian
* @author Emmanouil Krasanakis
*/
public class MessageSortPooling {

public static void main(String[] args){
long reduced = 5; // input graphs need to have at least that many nodes, lower values decrease accuracy
long hidden = 8; // since this library does not use GPU parallelization, many latent dims reduce speed

ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", 1)
.config("classes", 2)
//.config("reduced", reduced)
.config("hidden", hidden)
.config("2hidden", 2*hidden)
.config("reg", 0.005)
.operation("edgeSrc = from(A)")
.operation("edgeDst = to(A)")
.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
.layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)")

// message passing layer (make it as complex as needed)
.operation("message{l}=h{l}[edgeSrc] | h{l}[edgeDst]")
.operation("transformed{l}=relu(message{l}@matrix(2hidden, hidden, reg)+vector(hidden))")
.operation("received{l}=reduce(transformed{l}, A)")
.operation("i{l}=relu((received{l} | h{l})@matrix(2hidden, hidden, reg)+vector(hidden))")
.layer("h{l+1}=relu(i{l}@matrix(hidden, hidden, reg)+vector(hidden))")

// this would be the sort pooling
/*.config("hiddenReduced", hidden*reduced) // reduced * (previous layer's output size)
.operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within next step's gather, so we need to create this intermediate variable
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") //
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
.layer("h{l+1}=softmax(h{l}, row)")*/

// the following two layers implement the sum pooling
.layer("h{l+1}=sum(h{l}@matrix(hidden, classes)+vector(classes), row)")
.layer("h{l+1}=softmax(h{l}, row)")

.out("h{l}");

TrajectoryData dtrain = new TrajectoryData(800);
TrajectoryData dtest = new TrajectoryData(200);

Model model = builder.getModel().init(new XavierNormal());
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
Loss loss = new CategoricalCrossEntropy();
for(int epoch=0; epoch<600; epoch++) {
// gradient update over all graphs
for(int graphId=0; graphId<dtrain.graphs.size(); graphId++) {
int graphIdentifier = graphId;
// each gradient calculation into a new thread pool task
ThreadPool.getInstance().submit(new Runnable() {
@Override
public void run() {
//System.out.println(dtrain.graphs.get(graphIdentifier).sum());
Matrix adjacency = dtrain.graphs.get(graphIdentifier);
Matrix features= dtrain.features.get(graphIdentifier);
Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow();
model.train(loss, optimizer,
Arrays.asList(features, adjacency),
Arrays.asList(graphLabel));
}
});
}
ThreadPool.getInstance().waitForConclusion(); // wait for all gradients to compute
optimizer.updateAll(); // apply gradients on model parameters

double acc = 0.0;
for(int graphId=0; graphId<dtest.graphs.size(); graphId++) {
Matrix adjacency = dtest.graphs.get(graphId);
Matrix features= dtest.features.get(graphId);
Tensor graphLabel = dtest.labels.get(graphId);
if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
acc += 1;
}
System.out.println("iter = " + epoch + " " + acc/dtest.graphs.size());
}
}
}
2 changes: 2 additions & 0 deletions JGNN/src/examples/graphClassification/SortPooling.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static void main(String[] args){
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") //
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
.layer("h{l+1}=softmax(h{l}, row)")
//.layer("h{l+1}=softmax(sum(h{l}@matrix(hiddenReduced, classes), row))")//this is mean pooling to replace the above sort pooling
.out("h{l}");

TrajectoryData dtrain = new TrajectoryData(8000);
Expand All @@ -56,6 +57,7 @@ public static void main(String[] args){
ThreadPool.getInstance().submit(new Runnable() {
@Override
public void run() {
//System.out.println(dtrain.graphs.get(graphIdentifier).sum());
Matrix adjacency = dtrain.graphs.get(graphIdentifier);
Matrix features= dtrain.features.get(graphIdentifier);
Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow();
Expand Down
2 changes: 1 addition & 1 deletion JGNN/src/examples/nodeClassification/APPNP.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static void main(String[] args) throws Exception {
long numClasses = dataset.labels().getCols();
ModelBuilder modelBuilder = new FastBuilder(dataset.graph(), dataset.features())
.config("reg", 0.005)
.config("hidden", 64)
.config("hidden", 8)
.config("classes", numClasses)
.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
.layer("h{l+1}=h{l}@matrix(hidden, classes)+vector(classes)")
Expand Down
8 changes: 4 additions & 4 deletions JGNN/src/examples/nodeClassification/MessagePassing.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
public class MessagePassing {
public static void main(String[] args) throws Exception {
Dataset dataset = new Cora();
dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
dataset.graph().setToSymmetricNormalization();

long numClasses = dataset.labels().getCols();
ModelBuilder modelBuilder = new FastBuilder(dataset.graph(), dataset.features())
Expand All @@ -33,8 +33,8 @@ public static void main(String[] args) throws Exception {
.config("2hidden", (numClasses+2)*2)
.operation("u = from(A)")
.operation("v = to(A)")
.layer("h{l+1}=relu(h{l}@matrix(features, hidden)+vector(hidden))")
.layer("h{l+1}=h{l}@matrix(hidden, hidden)+vector(hidden)")
.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
.layer("h{l+1}=h{l}@matrix(hidden, hidden, reg)+vector(hidden)")
.operation("m{l}_1=h{l}[u] | h{l}[v]")
.operation("m{l}_2=relu(m{l}_1@matrix(2hidden, hidden, reg)+vector(hidden))")
.operation("m{l}_3=m{l}_2@matrix(hidden, hidden, reg)+vector(hidden)")
Expand All @@ -53,7 +53,7 @@ public static void main(String[] args) throws Exception {
ModelTraining trainer = new ModelTraining()
.setOptimizer(new Adam(0.01))
.setEpochs(300)
.setPatience(100)
.setPatience(10)
.setVerbose(true)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(new CategoricalCrossEntropy());
Expand Down
27 changes: 27 additions & 0 deletions JGNN/src/main/java/mklab/JGNN/core/Matrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,33 @@ public Matrix setToSymmetricNormalization() {
return this;
}


/**
* Sets the Matrix to its asymmetrically normalized transformation
* by appropriately adjusting its element values.
* @return <code>this</code> Matrix instance.
* @see #symmetricNormalization()
*/
public Matrix setToASymmetricNormalization() {
HashMap<Long, Double> outDegrees = new HashMap<Long, Double>();
HashMap<Long, Double> inDegrees = new HashMap<Long, Double>();
for(Entry<Long,Long> element : getNonZeroEntries()) {
long row = element.getKey();
long col = element.getValue();
double value = get(row, col);
outDegrees.put(row, outDegrees.getOrDefault(row, 0.)+value);
inDegrees.put(col, inDegrees.getOrDefault(col, 0.)+value);
}
for(Entry<Long,Long> element : getNonZeroEntries()) {
long row = element.getKey();
long col = element.getValue();
double div = inDegrees.get(col);
if(div!=0)
put(row, col, get(row, col)/div);
}
return this;
}

/**
* Retrieves either the given row or column as a trensor.
* @param index The dimension index to access.
Expand Down
Loading

0 comments on commit 8463fb9

Please sign in to comment.