diff --git a/.classpath b/.classpath index f33cd730..fa36f04a 100644 --- a/.classpath +++ b/.classpath @@ -11,5 +11,7 @@ + + diff --git a/bin/run_mixserv.sh b/bin/run_mixserv.sh new file mode 100644 index 00000000..686d2d2f --- /dev/null +++ b/bin/run_mixserv.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +VMOPTS="-Xmx4g -da -server -XX:+PrintGCDetails -XX:+UseNUMA -XX:+UseParallelGC $VMOPTS" + +java ${VMOPTS} -jar hivemall-fat.jar $@ diff --git a/build.properties b/build.properties index 000f14e6..52b05013 100644 --- a/build.properties +++ b/build.properties @@ -2,8 +2,6 @@ # Project property file for Hivemall # ###################################### -version=0.1 - lib.dir=lib/ src.dir=src/main test.dir=src/test @@ -12,7 +10,19 @@ build.dir=${target.dir}/classes test.build.dir=${target.dir}/test-classes test.result.dir=${target.dir}/test-results -## javac: -------------------------------------------- +## MANIFEST: ---------------------------------------- + +user.name=myui +#java.version=1.6 + +project.version=0.3 +project.name=Hivemall +project.groupId=hivemall +project.organization.name=AIST, Japan. + +jar.mainclass=hivemall.mix.server.MixServer + +## javac: ------------------------------------------- javac.source=1.6 javac.target=1.6 @@ -20,12 +30,7 @@ javac.debug=on javac.debuglevel=lines,source,vars #bootclasspath=/jre/lib/rt.jar -## jar: -------------------------------------------- - -jar.title=Hivemall -jar.vendor=AIST, Japan. - -## javadoc: -------------------------------------------- +## javadoc: ----------------------------------------- javadoc.dstdir=${target.dir}/docs/api javadoc.title=Hivemall API diff --git a/build.xml b/build.xml index 76d4680d..30bf95ea 100644 --- a/build.xml +++ b/build.xml @@ -1,5 +1,5 @@ - + @@ -38,7 +38,7 @@ - + @@ -46,16 +46,57 @@ - + + + + + + - - - + + + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -66,7 +107,7 @@ - + diff --git a/lib/jsr305-1.3.9.jar b/lib/jsr305-1.3.9.jar new file mode 100644 index 00000000..a9afc661 Binary files /dev/null and b/lib/jsr305-1.3.9.jar differ diff --git a/lib/license/netty.LICENSE b/lib/license/netty.LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/lib/license/netty.LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/lib/netty-all-4.0.23.Final.jar b/lib/netty-all-4.0.23.Final.jar new file mode 100644 index 00000000..0555a164 Binary files /dev/null and b/lib/netty-all-4.0.23.Final.jar differ diff --git a/lib/source/netty-all-4.0.23.Final-sources.jar b/lib/source/netty-all-4.0.23.Final-sources.jar new file mode 100644 index 00000000..b04c08cf Binary files /dev/null and b/lib/source/netty-all-4.0.23.Final-sources.jar differ diff --git a/pom.xml b/pom.xml index 91d9b6e5..062e173c 100644 --- a/pom.xml +++ b/pom.xml @@ -1,95 +1,254 @@ - 4.0.0 + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + 4.0.0 - hivemall - hivemall - 0.2 - jar + hivemall + hivemall + 0.3 - hivemall - http://github.com/myui/hivemall + Hivemall + Scalable Machine Learning Library for Apache Hive + https://github.com/myui/hivemall/ + 2013 + + AIST, Japan. + - - UTF-8 - + + + LGPL 2.1 license + http://www.opensource.org/licenses/lgpl-2.1.php + + - - - cloudera - https://repository.cloudera.com/artifactory/cloudera-repos/ - - + jar + + true + UTF-8 + - - - org.apache.hadoop - hadoop-core - 0.20.2-cdh3u6 - provided - - - org.apache.hive - hive-exec - 0.11.0 - provided - - - jetty - org.mortbay.jetty - - - javax.jdo - jdo2-api - - - - - org.apache.hive - hive-serde - 0.11.0 - provided - - - commons-cli - commons-cli - 1.2 - provided - - - javax.jdo - jdo2-api - 2.3-eb - + + + cloudera + https://repository.cloudera.com/artifactory/cloudera-repos/ + + - - junit - junit - 4.10 - test - - + - - target - target/classes - ${project.artifactId}-${project.version} - target/test-classes - src/main - src/test + + + org.apache.hadoop + hadoop-core + 0.20.2-cdh3u6 + provided + + + org.apache.hive + hive-exec + 0.11.0 + provided + + + jetty + org.mortbay.jetty + + + javax.jdo + jdo2-api + + + + + org.apache.hive + hive-serde + 0.11.0 + provided + + + commons-cli + commons-cli + 1.2 + provided + + + commons-logging + commons-logging + 1.0.4 + provided + + + javax.jdo + jdo2-api + 2.3-eb + provided + + + org.apache.hadoop.thirdparty.guava + guava + r09-jarjar + provided + - - - org.apache.maven.plugins - maven-compiler-plugin - 3.1 - - 1.6 - 1.6 - UTF-8 - - - - + + + io.netty + + netty-all + 4.0.23.Final + compile + + + com.google.code.findbugs + jsr305 + 1.3.9 + compile + + + + + junit + junit + 4.10 + test + + + + + + target + target/classes + ${project.artifactId}-${project.version} + target/test-classes + src/main + src/test + + + + org.apache.maven.plugins + maven-failsafe-plugin + 2.17 + + ${skipTests} + + + + org.codehaus.mojo + properties-maven-plugin + 1.0-alpha-2 + + + initialize + + read-project-properties + + + + ${basedir}/build.properties + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + ${javac.source} + ${javac.target} + UTF-8 + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.5 + + ${project.artifactId}-${project.version} + + + ${jar.mainclass} + true + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 2.3 + + + jar-with-dependencies + package + + shade + + + ${project.artifactId}-${project.version}-with-dependencies + true + false + + + io.netty:netty-all + com.google.code.findbugs:jsr305 + + + junit:junit + javax.jdo:jdo2-api + org.apache.hadoop.thirdparty.guava:guava + + + + + + ${jar.mainclass} + ${project.name} + ${project.version} + ${project.organization.name} + + + + + + + + + + org.sonatype.plugins + jarjar-maven-plugin + 1.9 + + + package + + jarjar + + + ${project.build.directory}/${project.artifactId}-${project.version}-fat.jar + + org.apache.hadoop:hadoop-core + org.apache.hive:hive-exec + org.apache.hive:hive-serde + commons-cli:commons-cli + commons-logging:commons-logging + com.google.code.findbugs:jsr305 + org.apache.hadoop.thirdparty.guava:guava + + + junit:junit + javax.jdo:jdo2-api + + + + + + + + diff --git a/scripts/ddl/define-all.hive b/scripts/ddl/define-all.hive index 60548282..8aafc585 100644 --- a/scripts/ddl/define-all.hive +++ b/scripts/ddl/define-all.hive @@ -313,12 +313,18 @@ create temporary function sigmoid as 'hivemall.tools.math.SigmodUDF'; drop temporary function taskid; create temporary function taskid as 'hivemall.tools.mapred.TaskIdUDF'; +drop temporary function jobid; +create temporary function jobid as 'hivemall.tools.mapred.JobIdUDF'; + drop temporary function rowid; create temporary function rowid as 'hivemall.tools.mapred.RowIdUDF'; drop temporary function distcache_gets; create temporary function distcache_gets as 'hivemall.tools.mapred.DistributedCacheLookupUDF'; +drop temporary function jobconf_gets; +create temporary function jobconf_gets as 'hivemall.tools.mapred.JobConfGetsUDF'; + -------------------- -- misc functions -- -------------------- diff --git a/scripts/ddl/define-tools-udf.hive b/scripts/ddl/define-tools-udf.hive index 24ffde91..65522bd9 100644 --- a/scripts/ddl/define-tools-udf.hive +++ b/scripts/ddl/define-tools-udf.hive @@ -68,12 +68,18 @@ create temporary function sigmoid as 'hivemall.tools.math.SigmodUDF'; drop temporary function taskid; create temporary function taskid as 'hivemall.tools.mapred.TaskIdUDF'; +drop temporary function jobid; +create temporary function jobid as 'hivemall.tools.mapred.JobIdUDF'; + drop temporary function rowid; create temporary function rowid as 'hivemall.tools.mapred.RowIdUDF'; drop temporary function distcache_gets; create temporary function distcache_gets as 'hivemall.tools.mapred.DistributedCacheLookupUDF'; +drop temporary function jobconf_gets; +create temporary function jobconf_gets as 'hivemall.tools.mapred.JobConfGetsUDF'; + ---------------------- -- string functions -- ---------------------- diff --git a/scripts/misc/emr_hivemall_bootstrap.sh b/scripts/misc/emr_hivemall_bootstrap.sh index 652ad1c3..ad41fbf3 100644 --- a/scripts/misc/emr_hivemall_bootstrap.sh +++ b/scripts/misc/emr_hivemall_bootstrap.sh @@ -2,4 +2,4 @@ mkdir -p /home/hadoop/tmp wget --no-check-certificate -P /home/hadoop/tmp \ - https://github.com/myui/hivemall/raw/master/target/hivemall.jar https://github.com/myui/hivemall/raw/master/scripts/ddl/define-all.hive + https://github.com/myui/hivemall/raw/master/target/hivemall-with-dependencies.jar https://github.com/myui/hivemall/raw/master/scripts/ddl/define-all.hive diff --git a/src/main/hivemall/LearnerBaseUDTF.java b/src/main/hivemall/LearnerBaseUDTF.java index 9da30073..615e3d6e 100644 --- a/src/main/hivemall/LearnerBaseUDTF.java +++ b/src/main/hivemall/LearnerBaseUDTF.java @@ -25,11 +25,15 @@ import hivemall.io.PredictionModel; import hivemall.io.SpaceEfficientDenseModel; import hivemall.io.SparseModel; +import hivemall.io.SynchronizedModelWrapper; import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; +import hivemall.mix.MixMessage.MixEventName; +import hivemall.mix.client.MixClient; import hivemall.utils.datetime.StopWatch; import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.IOUtils; import hivemall.utils.lang.Primitives; import java.io.BufferedReader; @@ -42,6 +46,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -59,6 +64,12 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions { protected boolean dense_model; protected int model_dims; protected boolean disable_halffloat; + protected String mixConnectInfo; + protected String mixSessionName; + protected int mixThreshold; + protected boolean ssl; + + protected MixClient mixClient; public LearnerBaseUDTF() {} @@ -73,6 +84,10 @@ protected Options getOptions() { opts.addOption("dense", "densemodel", false, "Use dense model or not"); opts.addOption("dims", "feature_dimensions", true, "The dimension of model [default: 16777216 (2^24)]"); opts.addOption("disable_halffloat", false, "Toggle this option to disable the use of SpaceEfficientDenseModel"); + opts.addOption("mix", "mix_servers", true, "Comma separated list of MIX servers"); + opts.addOption("mix_session", "mix_session_name", true, "Mix session name [default: ${mapred.job.id}]"); + opts.addOption("mix_threshold", true, "Threshold to mix local updates in range (0,127] [default: 3]"); + opts.addOption("ssl", false, "Use SSL for the communication with mix servers"); return opts; } @@ -82,6 +97,10 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen boolean denseModel = false; int modelDims = -1; boolean disableHalfFloat = false; + String mixConnectInfo = null; + String mixSessionName = null; + int mixThreshold = -1; + boolean ssl = false; CommandLine cl = null; if(argOIs.length >= 3) { @@ -96,33 +115,73 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen } disableHalfFloat = cl.hasOption("disable_halffloat"); + + mixConnectInfo = cl.getOptionValue("mix"); + mixSessionName = cl.getOptionValue("mix_session"); + mixThreshold = Primitives.parseInt(cl.getOptionValue("mix_threshold"), 3); + if(mixThreshold > Byte.MAX_VALUE) { + throw new UDFArgumentException("mix_threshold must be in range (0,127]: " + + mixThreshold); + } + ssl = cl.hasOption("ssl"); } this.preloadedModelFile = modelfile; this.dense_model = denseModel; this.model_dims = modelDims; this.disable_halffloat = disableHalfFloat; + this.mixConnectInfo = mixConnectInfo; + this.mixSessionName = mixSessionName; + this.mixThreshold = mixThreshold; + this.ssl = ssl; return cl; } protected PredictionModel createModel() { + return createModel(null); + } + + protected PredictionModel createModel(String label) { + PredictionModel model; + final boolean useCovar = useCovariance(); if(dense_model) { - boolean useCovar = useCovariance(); if(disable_halffloat == false && model_dims > 16777216) { logger.info("Build a space efficient dense model with " + model_dims + " initial dimensions" + (useCovar ? " w/ covariances" : "")); - return new SpaceEfficientDenseModel(model_dims, useCovar); + model = new SpaceEfficientDenseModel(model_dims, useCovar); } else { logger.info("Build a dense model with initial with " + model_dims + " initial dimensions" + (useCovar ? " w/ covariances" : "")); - return new DenseModel(model_dims, useCovar); + model = new DenseModel(model_dims, useCovar); } } else { int initModelSize = getInitialModelSize(); logger.info("Build a sparse model with initial with " + initModelSize + " initial dimensions"); - return new SparseModel(initModelSize); + model = new SparseModel(initModelSize, useCovar); + } + if(mixConnectInfo != null) { + model.configureClock(); + model = new SynchronizedModelWrapper(model); + MixClient client = configureMixClient(mixConnectInfo, label, model); + model.setUpdateHandler(client); + this.mixClient = client; + } + assert (model != null); + return model; + } + + protected MixClient configureMixClient(String connectURIs, String label, PredictionModel model) { + assert (connectURIs != null); + assert (model != null); + String jobId = (mixSessionName == null) ? MixClient.DUMMY_JOB_ID : mixSessionName; + if(label != null) { + jobId = jobId + '-' + label; } + MixEventName event = useCovariance() ? MixEventName.argminKLD : MixEventName.average; + MixClient client = new MixClient(event, jobId, connectURIs, ssl, mixThreshold, model); + logger.info("Successfully configured mix client: " + connectURIs); + return client; } protected int getInitialModelSize() { @@ -241,4 +300,13 @@ private static long loadPredictionModel(PredictionModel model, File file, Primit } return count; } + + @Override + public void close() throws HiveException { + if(mixClient != null) { + IOUtils.closeQuietly(mixClient); + this.mixClient = null; + } + } + } diff --git a/src/main/hivemall/UDTFWithOptions.java b/src/main/hivemall/UDTFWithOptions.java index efa24ee0..f392e68b 100644 --- a/src/main/hivemall/UDTFWithOptions.java +++ b/src/main/hivemall/UDTFWithOptions.java @@ -50,7 +50,7 @@ protected final CommandLine parseOptions(String optionValue) throws UDFArgumentE protected abstract CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException; - protected List parseFeatures(final List features, final ObjectInspector featureInspector, final boolean parseFeature) { + protected final List parseFeatures(final List features, final ObjectInspector featureInspector, final boolean parseFeature) { final int numFeatures = features.size(); if(numFeatures == 0) { return Collections.emptyList(); diff --git a/src/main/hivemall/classifier/AROWClassifierUDTF.java b/src/main/hivemall/classifier/AROWClassifierUDTF.java index 51ba1a3c..e68d053c 100644 --- a/src/main/hivemall/classifier/AROWClassifierUDTF.java +++ b/src/main/hivemall/classifier/AROWClassifierUDTF.java @@ -22,8 +22,8 @@ import hivemall.common.LossFunctions; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.PredictionResult; -import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; import java.util.List; @@ -126,13 +126,13 @@ protected void update(final List features, final float y, final float alpha, k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_w = model.get(k); - WeightValue new_w = getNewWeight(old_w, v, y, alpha, beta); + IWeightValue old_w = model.get(k); + IWeightValue new_w = getNewWeight(old_w, v, y, alpha, beta); model.set(k, new_w); } } - private static WeightValue getNewWeight(final WeightValue old, final float x, final float y, final float alpha, final float beta) { + private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float y, final float alpha, final float beta) { final float old_w; final float old_cov; if(old == null) { diff --git a/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java index f28aead2..56c40fc0 100644 --- a/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java +++ b/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java @@ -25,6 +25,7 @@ import static hivemall.HivemallConstants.STRING_TYPE_NAME; import hivemall.LearnerBaseUDTF; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.PredictionModel; import hivemall.io.PredictionResult; import hivemall.io.WeightValue; @@ -219,7 +220,7 @@ protected PredictionResult calcScoreAndVariance(List features) { k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_w = model.get(k); + IWeightValue old_w = model.get(k); if(old_w == null) { variance += (1.f * v * v); } else { @@ -259,7 +260,8 @@ protected void update(final List features, final float coeff) { } @Override - public void close() throws HiveException { + public final void close() throws HiveException { + super.close(); if(model != null) { int numForwarded = 0; if(useCovariance()) { @@ -267,7 +269,7 @@ public void close() throws HiveException { final Object[] forwardMapObj = new Object[3]; final FloatWritable fv = new FloatWritable(); final FloatWritable cov = new FloatWritable(); - final IMapIterator itor = model.entries(); + final IMapIterator itor = model.entries(); while(itor.next() != -1) { itor.getValue(probe); if(!probe.isTouched()) { @@ -286,7 +288,7 @@ public void close() throws HiveException { final WeightValue probe = new WeightValue(); final Object[] forwardMapObj = new Object[2]; final FloatWritable fv = new FloatWritable(); - final IMapIterator itor = model.entries(); + final IMapIterator itor = model.entries(); while(itor.next() != -1) { itor.getValue(probe); if(!probe.isTouched()) { @@ -300,10 +302,11 @@ public void close() throws HiveException { numForwarded++; } } + int numMixed = model.getNumMixed(); this.model = null; - logger.info("Trained a prediction model using " + count - + " training examples. Forwarded the prediction model of " + numForwarded - + " rows"); + logger.info("Trained a prediction model using " + count + " training examples" + + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); + logger.info("Forwarded the prediction model of " + numForwarded + " rows"); } } diff --git a/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java b/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java index f45264cd..b2f3be57 100644 --- a/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java +++ b/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java @@ -21,8 +21,8 @@ package hivemall.classifier; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.PredictionResult; -import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; import hivemall.utils.math.StatsUtils; @@ -141,13 +141,13 @@ protected void update(final List features, final float coeff, final float alp k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_w = model.get(k); - WeightValue new_w = getNewWeight(old_w, v, coeff, alpha, phi); + IWeightValue old_w = model.get(k); + IWeightValue new_w = getNewWeight(old_w, v, coeff, alpha, phi); model.set(k, new_w); } } - private static WeightValue getNewWeight(final WeightValue old, final float x, final float coeff, final float alpha, final float phi) { + private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float coeff, final float alpha, final float phi) { final float old_w, old_cov; if(old == null) { old_w = 0.f; diff --git a/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java b/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java index e0b8b437..db3eb94c 100644 --- a/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java +++ b/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java @@ -21,8 +21,8 @@ package hivemall.classifier; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.PredictionResult; -import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; import hivemall.utils.math.StatsUtils; @@ -251,13 +251,13 @@ protected void update(final List features, final float y, final float alpha, k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_w = model.get(k); - WeightValue new_w = getNewWeight(old_w, v, y, alpha, beta); + IWeightValue old_w = model.get(k); + IWeightValue new_w = getNewWeight(old_w, v, y, alpha, beta); model.set(k, new_w); } } - private static WeightValue getNewWeight(final WeightValue old, final float x, final float y, final float alpha, final float beta) { + private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float y, final float alpha, final float beta) { final float old_v; final float old_cov; if(old == null) { diff --git a/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java index c2f3478d..daa46cfb 100644 --- a/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java +++ b/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java @@ -21,9 +21,9 @@ package hivemall.classifier.multiclass; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.Margin; import hivemall.io.PredictionModel; -import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; import java.util.List; @@ -142,19 +142,19 @@ protected void update(final List features, final Object actual_label, final O k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_correctclass_w = model2add.get(k); - WeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true); + IWeightValue old_correctclass_w = model2add.get(k); + IWeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true); model2add.set(k, new_correctclass_w); if(model2sub != null) { - WeightValue old_wrongclass_w = model2sub.get(k); - WeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false); + IWeightValue old_wrongclass_w = model2sub.get(k); + IWeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false); model2sub.set(k, new_wrongclass_w); } } } - private static WeightValue getNewWeight(final WeightValue old, final float v, final float alpha, final float beta, final boolean positive) { + private static IWeightValue getNewWeight(final IWeightValue old, final float v, final float alpha, final float beta, final boolean positive) { final float old_v; final float old_cov; if(old == null) { diff --git a/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java index c77b2bcb..71d538a2 100644 --- a/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java +++ b/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java @@ -21,9 +21,9 @@ package hivemall.classifier.multiclass; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.Margin; import hivemall.io.PredictionModel; -import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; import hivemall.utils.math.StatsUtils; @@ -163,19 +163,19 @@ protected void update(List features, float alpha, Object actual_label, Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_correctclass_w = model2add.get(k); - WeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, phi, true); + IWeightValue old_correctclass_w = model2add.get(k); + IWeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, phi, true); model2add.set(k, new_correctclass_w); if(model2sub != null) { - WeightValue old_wrongclass_w = model2sub.get(k); - WeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, phi, false); + IWeightValue old_wrongclass_w = model2sub.get(k); + IWeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, phi, false); model2sub.set(k, new_wrongclass_w); } } } - private static WeightValue getNewWeight(final WeightValue old, final float x, final float alpha, final float phi, final boolean positive) { + private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float alpha, final float phi, final boolean positive) { final float old_w, old_cov; if(old == null) { old_w = 0.f; diff --git a/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java index 87e3300e..334a570c 100644 --- a/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java +++ b/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java @@ -26,6 +26,7 @@ import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector; import hivemall.LearnerBaseUDTF; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.Margin; import hivemall.io.PredictionModel; import hivemall.io.PredictionResult; @@ -319,7 +320,7 @@ protected final PredictionResult calcScoreAndVariance(final PredictionModel mode k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_w = model.get(k); + IWeightValue old_w = model.get(k); if(old_w == null) { variance += (1.f * v * v); } else { @@ -380,9 +381,11 @@ protected void update(List features, float coeff, Object actual_label, Object } @Override - public void close() throws HiveException { + public final void close() throws HiveException { + super.close(); if(label2model != null) { long numForwarded = 0L; + long numMixed = 0L; if(useCovariance()) { final WeightValueWithCovar probe = new WeightValueWithCovar(); final Object[] forwardMapObj = new Object[4]; @@ -392,7 +395,8 @@ public void close() throws HiveException { Object label = entry.getKey(); forwardMapObj[0] = label; PredictionModel model = entry.getValue(); - IMapIterator itor = model.entries(); + numMixed += model.getNumMixed(); + IMapIterator itor = model.entries(); while(itor.next() != -1) { itor.getValue(probe); if(!probe.isTouched()) { @@ -416,7 +420,8 @@ public void close() throws HiveException { Object label = entry.getKey(); forwardMapObj[0] = label; PredictionModel model = entry.getValue(); - IMapIterator itor = model.entries(); + numMixed += model.getNumMixed(); + IMapIterator itor = model.entries(); while(itor.next() != -1) { itor.getValue(probe); if(!probe.isTouched()) { @@ -432,9 +437,9 @@ public void close() throws HiveException { } } this.label2model = null; - logger.info("Trained a prediction model using " + count - + " training examples. Forwarded the prediction model of " + numForwarded - + " rows"); + logger.info("Trained a prediction model using " + count + " training examples" + + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); + logger.info("Forwarded the prediction model of " + numForwarded + " rows"); } } diff --git a/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java index bfc0816e..b80272eb 100644 --- a/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java +++ b/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java @@ -21,9 +21,9 @@ package hivemall.classifier.multiclass; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.Margin; import hivemall.io.PredictionModel; -import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; import hivemall.utils.math.StatsUtils; @@ -271,19 +271,19 @@ protected void update(final List features, final Object actual_label, final O k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_correctclass_w = model2add.get(k); - WeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true); + IWeightValue old_correctclass_w = model2add.get(k); + IWeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true); model2add.set(k, new_correctclass_w); if(model2sub != null) { - WeightValue old_wrongclass_w = model2sub.get(k); - WeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false); + IWeightValue old_wrongclass_w = model2sub.get(k); + IWeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false); model2sub.set(k, new_wrongclass_w); } } } - private static WeightValue getNewWeight(final WeightValue old, final float v, final float alpha, final float beta, final boolean positive) { + private static IWeightValue getNewWeight(final IWeightValue old, final float v, final float alpha, final float beta, final boolean positive) { final float old_v; final float old_cov; if(old == null) { diff --git a/src/main/hivemall/io/AbstractPredictionModel.java b/src/main/hivemall/io/AbstractPredictionModel.java new file mode 100644 index 00000000..54df01e6 --- /dev/null +++ b/src/main/hivemall/io/AbstractPredictionModel.java @@ -0,0 +1,99 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.io; + +public abstract class AbstractPredictionModel implements PredictionModel { + public static final byte BYTE0 = 0; + + protected ModelUpdateHandler handler; + protected int numMixed; + + public AbstractPredictionModel() { + this.numMixed = 0; + } + + @Override + public ModelUpdateHandler getUpdateHandler() { + return handler; + } + + @Override + public void setUpdateHandler(ModelUpdateHandler handler) { + this.handler = handler; + } + + @Override + public final int getNumMixed() { + return numMixed; + } + + @Override + public void resetDeltaUpdates(int feature) { + throw new UnsupportedOperationException(); + } + + protected final void onUpdate(final int feature, final float weight, final float covar, final short clock, final int deltaUpdates) { + if(deltaUpdates < 1) { + return; + } + if(handler != null) { + final boolean resetDeltaUpdates; + try { + resetDeltaUpdates = handler.onUpdate(feature, weight, covar, clock, deltaUpdates); + } catch (Exception e) { + throw new RuntimeException(e); + } + if(resetDeltaUpdates) { + resetDeltaUpdates(feature); + } + } + } + + protected final void onUpdate(final Object feature, final IWeightValue value) { + if(handler != null) { + if(!value.isTouched()) { + return; + } + final float weight = value.get(); + final short clock = value.getClock(); + final int deltaUpdates = value.getDeltaUpdates(); + final boolean resetDeltaUpdates; + if(value.hasCovariance()) { + final float covar = value.getCovariance(); + try { + resetDeltaUpdates = handler.onUpdate(feature, weight, covar, clock, deltaUpdates); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + try { + resetDeltaUpdates = handler.onUpdate(feature, weight, 1.f, clock, deltaUpdates); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + if(resetDeltaUpdates) { + value.setDeltaUpdates(BYTE0); + } + } + } + +} diff --git a/src/main/hivemall/io/DenseModel.java b/src/main/hivemall/io/DenseModel.java index 8baec931..00141528 100644 --- a/src/main/hivemall/io/DenseModel.java +++ b/src/main/hivemall/io/DenseModel.java @@ -31,18 +31,22 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -public final class DenseModel implements PredictionModel { +public final class DenseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(DenseModel.class); private int size; private float[] weights; private float[] covars; + private short[] clocks; + private byte[] deltaUpdates; + public DenseModel(int ndims) { this(ndims, false); } public DenseModel(int ndims, boolean withCovar) { + super(); int size = ndims + 1; this.size = size; this.weights = new float[size]; @@ -53,6 +57,31 @@ public DenseModel(int ndims, boolean withCovar) { } else { this.covars = null; } + this.clocks = null; + this.deltaUpdates = null; + } + + @Override + public boolean hasCovariance() { + return covars != null; + } + + @Override + public void configureClock() { + if(clocks == null) { + this.clocks = new short[size]; + this.deltaUpdates = new byte[size]; + } + } + + @Override + public boolean hasClock() { + return clocks != null; + } + + @Override + public void resetDeltaUpdates(int feature) { + deltaUpdates[feature] = 0; } private void ensureCapacity(final int index) { @@ -68,13 +97,17 @@ private void ensureCapacity(final int index) { this.covars = Arrays.copyOf(covars, newSize); Arrays.fill(covars, oldSize, newSize, 1f); } + if(clocks != null) { + this.clocks = Arrays.copyOf(clocks, newSize); + this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); + } } } @SuppressWarnings("unchecked") @Override - public T get(Object feature) { - int i = HiveUtils.parseInt(feature); + public T get(Object feature) { + final int i = HiveUtils.parseInt(feature); if(i >= size) { return null; } @@ -86,15 +119,27 @@ public T get(Object feature) { } @Override - public void set(Object feature, T value) { + public void set(Object feature, T value) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); float weight = value.get(); weights[i] = weight; + float covar = 1.f; if(value.hasCovariance()) { - float covar = value.getCovariance(); + covar = value.getCovariance(); covars[i] = covar; } + short clock = 0; + int delta = 0; + if(clocks != null && value.isTouched()) { + clock = (short) (clocks[i] + 1); + clocks[i] = clock; + delta = deltaUpdates[i] + 1; + assert (delta > 0) : delta; + deltaUpdates[i] = (byte) delta; + } + + onUpdate(i, weight, covar, clock, delta); } @Override @@ -116,18 +161,24 @@ public float getCovariance(Object feature) { } @Override - public void setValue(Object feature, float weight) { + public void _set(Object feature, float weight, short clock) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); weights[i] = weight; + clocks[i] = clock; + deltaUpdates[i] = 0; + numMixed++; } @Override - public void setValue(Object feature, float weight, float covar) { + public void _set(Object feature, float weight, float covar, short clock) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); weights[i] = weight; covars[i] = covar; + clocks[i] = clock; + deltaUpdates[i] = 0; + numMixed++; } @Override @@ -147,11 +198,11 @@ public boolean contains(Object feature) { @SuppressWarnings("unchecked") @Override - public IMapIterator entries() { + public IMapIterator entries() { return (IMapIterator) new Itr(); } - private final class Itr implements IMapIterator { + private final class Itr implements IMapIterator { private int cursor; private final WeightValueWithCovar tmpWeight; @@ -181,7 +232,7 @@ public Integer getKey() { } @Override - public WeightValue getValue() { + public IWeightValue getValue() { if(covars == null) { float w = weights[cursor]; WeightValue v = new WeightValue(w); @@ -197,13 +248,13 @@ public WeightValue getValue() { } @Override - public > void getValue(T probe) { + public > void getValue(T probe) { float w = weights[cursor]; tmpWeight.value = w; float cov = 1.f; if(covars != null) { cov = covars[cursor]; - tmpWeight.covariance = cov; + tmpWeight.setCovariance(cov); } tmpWeight.setTouched(w != 0.f || cov != 1.f); probe.copyFrom(tmpWeight); diff --git a/src/main/hivemall/io/IWeightValue.java b/src/main/hivemall/io/IWeightValue.java new file mode 100644 index 00000000..957e964e --- /dev/null +++ b/src/main/hivemall/io/IWeightValue.java @@ -0,0 +1,32 @@ +package hivemall.io; + +import hivemall.utils.lang.Copyable; + +public interface IWeightValue extends Copyable { + + float get(); + + void set(float weight); + + boolean hasCovariance(); + + float getCovariance(); + + void setCovariance(float cov); + + /** + * @return whether touched in training or not + */ + boolean isTouched(); + + void setTouched(boolean touched); + + short getClock(); + + void setClock(short clock); + + byte getDeltaUpdates(); + + void setDeltaUpdates(byte deltaUpdates); + +} \ No newline at end of file diff --git a/src/main/hivemall/io/ModelUpdateHandler.java b/src/main/hivemall/io/ModelUpdateHandler.java new file mode 100644 index 00000000..0214e99d --- /dev/null +++ b/src/main/hivemall/io/ModelUpdateHandler.java @@ -0,0 +1,36 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.io; + +public interface ModelUpdateHandler { + + /** + * @param feature + * @param weight + * @param covar 1.0 by the default + * @param clock 0 by the default + * @param deltaUpdates + * @return reset the deltaUpdates? + */ + boolean onUpdate(Object feature, float weight, float covar, short clock, int deltaUpdates) + throws Exception; + +} diff --git a/src/main/hivemall/io/PredictionModel.java b/src/main/hivemall/io/PredictionModel.java index 09574e67..8f61febf 100644 --- a/src/main/hivemall/io/PredictionModel.java +++ b/src/main/hivemall/io/PredictionModel.java @@ -24,22 +24,36 @@ public interface PredictionModel { - public int size(); + ModelUpdateHandler getUpdateHandler(); - public boolean contains(Object feature); + void setUpdateHandler(ModelUpdateHandler handler); - public T get(Object feature); + int getNumMixed(); - public void set(Object feature, T value); + boolean hasCovariance(); - public float getWeight(Object feature); + void configureClock(); - public float getCovariance(Object feature); + boolean hasClock(); - public void setValue(Object feature, float weight); + void resetDeltaUpdates(int feature); - public void setValue(Object feature, float weight, float covar); + int size(); - public IMapIterator entries(); + boolean contains(Object feature); -} + T get(Object feature); + + void set(Object feature, T value); + + float getWeight(Object feature); + + float getCovariance(Object feature); + + void _set(Object feature, float weight, short clock); + + void _set(Object feature, float weight, float covar, short clock); + + IMapIterator entries(); + +} \ No newline at end of file diff --git a/src/main/hivemall/io/SpaceEfficientDenseModel.java b/src/main/hivemall/io/SpaceEfficientDenseModel.java index 6ece7ea7..48454f83 100644 --- a/src/main/hivemall/io/SpaceEfficientDenseModel.java +++ b/src/main/hivemall/io/SpaceEfficientDenseModel.java @@ -32,18 +32,22 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -public final class SpaceEfficientDenseModel implements PredictionModel { +public final class SpaceEfficientDenseModel extends AbstractPredictionModel { private static final Log logger = LogFactory.getLog(SpaceEfficientDenseModel.class); private int size; private short[] weights; private short[] covars; + private short[] clocks; + private byte[] deltaUpdates; + public SpaceEfficientDenseModel(int ndims) { this(ndims, false); } public SpaceEfficientDenseModel(int ndims, boolean withCovar) { + super(); int size = ndims + 1; this.size = size; this.weights = new short[size]; @@ -54,6 +58,31 @@ public SpaceEfficientDenseModel(int ndims, boolean withCovar) { } else { this.covars = null; } + this.clocks = null; + this.deltaUpdates = null; + } + + @Override + public boolean hasCovariance() { + return covars != null; + } + + @Override + public void configureClock() { + if(clocks == null) { + this.clocks = new short[size]; + this.deltaUpdates = new byte[size]; + } + } + + @Override + public boolean hasClock() { + return clocks != null; + } + + @Override + public void resetDeltaUpdates(int feature) { + deltaUpdates[feature] = 0; } private float getWeight(final int i) { @@ -94,13 +123,17 @@ private void ensureCapacity(final int index) { this.covars = Arrays.copyOf(covars, newSize); Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE); } + if(clocks != null) { + this.clocks = Arrays.copyOf(clocks, newSize); + this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize); + } } } @SuppressWarnings("unchecked") @Override - public T get(Object feature) { - int i = HiveUtils.parseInt(feature); + public T get(Object feature) { + final int i = HiveUtils.parseInt(feature); if(i >= size) { return null; } @@ -112,15 +145,27 @@ public T get(Object feature) { } @Override - public void set(Object feature, T value) { + public void set(Object feature, T value) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); float weight = value.get(); setWeight(i, weight); + float covar = 1.f; if(value.hasCovariance()) { - float covar = value.getCovariance(); + covar = value.getCovariance(); setCovar(i, covar); } + short clock = 0; + int delta = 0; + if(clocks != null && value.isTouched()) { + clock = (short) (clocks[i] + 1); + clocks[i] = clock; + delta = deltaUpdates[i] + 1; + assert (delta > 0) : delta; + deltaUpdates[i] = (byte) delta; + } + + onUpdate(i, weight, covar, clock, delta); } @Override @@ -142,18 +187,24 @@ public float getCovariance(Object feature) { } @Override - public void setValue(Object feature, float weight) { + public void _set(Object feature, float weight, short clock) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); setWeight(i, weight); + clocks[i] = clock; + deltaUpdates[i] = 0; + numMixed++; } @Override - public void setValue(Object feature, float weight, float covar) { + public void _set(Object feature, float weight, float covar, short clock) { int i = HiveUtils.parseInt(feature); ensureCapacity(i); setWeight(i, weight); setCovar(i, covar); + clocks[i] = clock; + deltaUpdates[i] = 0; + numMixed++; } @Override @@ -173,11 +224,11 @@ public boolean contains(Object feature) { @SuppressWarnings("unchecked") @Override - public IMapIterator entries() { + public IMapIterator entries() { return (IMapIterator) new Itr(); } - private final class Itr implements IMapIterator { + private final class Itr implements IMapIterator { private int cursor; private final WeightValueWithCovar tmpWeight; @@ -207,7 +258,7 @@ public Integer getKey() { } @Override - public WeightValue getValue() { + public IWeightValue getValue() { if(covars == null) { float w = getWeight(cursor); WeightValue v = new WeightValue(w); @@ -223,13 +274,13 @@ public WeightValue getValue() { } @Override - public > void getValue(T probe) { + public > void getValue(T probe) { float w = getWeight(cursor); tmpWeight.value = w; float cov = 1.f; if(covars != null) { cov = getCovar(cursor); - tmpWeight.covariance = cov; + tmpWeight.setCovariance(cov); } tmpWeight.setTouched(w != 0.f || cov != 1.f); probe.copyFrom(tmpWeight); diff --git a/src/main/hivemall/io/SparseModel.java b/src/main/hivemall/io/SparseModel.java index 4b20b781..c3e7c30a 100644 --- a/src/main/hivemall/io/SparseModel.java +++ b/src/main/hivemall/io/SparseModel.java @@ -20,53 +20,118 @@ */ package hivemall.io; -import hivemall.io.WeightValue.WeightValueWithCovar; +import hivemall.io.WeightValueWithClock.WeightValueWithCovarClock; import hivemall.utils.collections.IMapIterator; import hivemall.utils.collections.OpenHashMap; -public final class SparseModel implements PredictionModel { +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; - private final OpenHashMap weights; +public final class SparseModel extends AbstractPredictionModel { + private static final Log logger = LogFactory.getLog(SparseModel.class); - public SparseModel() { - this(16384); + private final OpenHashMap weights; + private final boolean hasCovar; + private boolean clockEnabled; + + public SparseModel(int size, boolean hasCovar) { + super(); + this.weights = new OpenHashMap(size); + this.hasCovar = hasCovar; + this.clockEnabled = false; } - public SparseModel(int size) { - this.weights = new OpenHashMap(size); + @Override + public boolean hasCovariance() { + return hasCovar; + } + + @Override + public void configureClock() { + this.clockEnabled = true; + } + + @Override + public boolean hasClock() { + return clockEnabled; } @SuppressWarnings("unchecked") @Override - public T get(Object feature) { + public T get(final Object feature) { return (T) weights.get(feature); } @Override - public void set(Object feature, T value) { - weights.put(feature, value); + public void set(final Object feature, final T value) { + assert (feature != null); + assert (value != null); + + final IWeightValue wrapperValue = wrapIfRequired(value); + + if(clockEnabled && value.isTouched()) { + IWeightValue old = weights.get(feature); + if(old != null) { + short newclock = (short) (old.getClock() + (short) 1); + wrapperValue.setClock(newclock); + int newDelta = old.getDeltaUpdates() + 1; + wrapperValue.setDeltaUpdates((byte) newDelta); + } + } + weights.put(feature, wrapperValue); + + onUpdate(feature, wrapperValue); + } + + private IWeightValue wrapIfRequired(final IWeightValue value) { + if(clockEnabled) { + if(value.hasCovariance()) { + return new WeightValueWithCovarClock(value); + } else { + return new WeightValueWithClock(value); + } + } else { + return value; + } } @Override - public float getWeight(Object feature) { - WeightValue v = weights.get(feature); - return v == null ? 0.f : v.value; + public float getWeight(final Object feature) { + IWeightValue v = weights.get(feature); + return v == null ? 0.f : v.get(); } @Override - public float getCovariance(Object feature) { - WeightValueWithCovar v = (WeightValueWithCovar) weights.get(feature); - return v == null ? 1.f : v.covariance; + public float getCovariance(final Object feature) { + IWeightValue v = weights.get(feature); + return v == null ? 1.f : v.getCovariance(); } @Override - public void setValue(Object feature, float weight) { - weights.put(feature, new WeightValue(weight)); + public void _set(final Object feature, final float weight, final short clock) { + final IWeightValue w = weights.get(feature); + if(w == null) { + logger.warn("Previous weight not found: " + feature); + throw new IllegalStateException("Previous weight not found " + feature); + } + w.set(weight); + w.setClock(clock); + w.setDeltaUpdates(BYTE0); + numMixed++; } @Override - public void setValue(Object feature, float weight, float covar) { - weights.put(feature, new WeightValueWithCovar(weight, covar)); + public void _set(final Object feature, final float weight, final float covar, final short clock) { + final IWeightValue w = weights.get(feature); + if(w == null) { + logger.warn("Previous weight not found: " + feature); + throw new IllegalStateException("Previous weight not found: " + feature); + } + w.set(weight); + w.setCovariance(covar); + w.setClock(clock); + w.setDeltaUpdates(BYTE0); + numMixed++; } @Override @@ -75,13 +140,13 @@ public int size() { } @Override - public boolean contains(Object feature) { + public boolean contains(final Object feature) { return weights.containsKey(feature); } @SuppressWarnings("unchecked") @Override - public IMapIterator entries() { + public IMapIterator entries() { return (IMapIterator) weights.entries(); } diff --git a/src/main/hivemall/io/SynchronizedModelWrapper.java b/src/main/hivemall/io/SynchronizedModelWrapper.java new file mode 100644 index 00000000..4335d09a --- /dev/null +++ b/src/main/hivemall/io/SynchronizedModelWrapper.java @@ -0,0 +1,152 @@ +package hivemall.io; + +import hivemall.utils.collections.IMapIterator; + +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +public final class SynchronizedModelWrapper implements PredictionModel { + + private final PredictionModel model; + private final Lock lock; + + public SynchronizedModelWrapper(PredictionModel model) { + this.model = model; + this.lock = new ReentrantLock(); + } + + // ------------------------------------------------------------ + // Non-synchronized methods with care + + public PredictionModel getModel() { + return model; + } + + @Override + public ModelUpdateHandler getUpdateHandler() { + return model.getUpdateHandler(); + } + + @Override + public void setUpdateHandler(ModelUpdateHandler handler) { + model.setUpdateHandler(handler); + } + + @Override + public int getNumMixed() { + return model.getNumMixed(); + } + + @Override + public boolean hasCovariance() { + return model.hasCovariance(); + } + + @Override + public void configureClock() { + model.configureClock(); + } + + @Override + public boolean hasClock() { + return model.hasClock(); + } + + @Override + public IMapIterator entries() { + return model.entries(); + } + + // ------------------------------------------------------------ + // The below is synchronized methods + + @Override + public void resetDeltaUpdates(int feature) { + try { + lock.lock(); + model.resetDeltaUpdates(feature); + } finally { + lock.unlock(); + } + } + + @Override + public int size() { + try { + lock.lock(); + return model.size(); + } finally { + lock.unlock(); + } + } + + @Override + public boolean contains(Object feature) { + try { + lock.lock(); + return model.contains(feature); + } finally { + lock.unlock(); + } + } + + @Override + public T get(Object feature) { + try { + lock.lock(); + return model.get(feature); + } finally { + lock.unlock(); + } + } + + @Override + public void set(Object feature, T value) { + try { + lock.lock(); + model.set(feature, value); + } finally { + lock.unlock(); + } + } + + @Override + public float getWeight(Object feature) { + try { + lock.lock(); + return model.getWeight(feature); + } finally { + lock.unlock(); + } + } + + @Override + public float getCovariance(Object feature) { + try { + lock.lock(); + return model.getCovariance(feature); + } finally { + lock.unlock(); + } + } + + @Override + public void _set(Object feature, float weight, short clock) { + try { + lock.lock(); + model._set(feature, weight, clock); + } finally { + lock.unlock(); + } + } + + @Override + public void _set(Object feature, float weight, float covar, short clock) { + try { + lock.lock(); + model._set(feature, weight, covar, clock); + } finally { + lock.unlock(); + } + } +} diff --git a/src/main/hivemall/io/WeightValue.java b/src/main/hivemall/io/WeightValue.java index 5f2488f1..f3b1e6d9 100644 --- a/src/main/hivemall/io/WeightValue.java +++ b/src/main/hivemall/io/WeightValue.java @@ -20,13 +20,9 @@ */ package hivemall.io; -import hivemall.utils.lang.Copyable; - -public class WeightValue implements Copyable { +public class WeightValue implements IWeightValue { protected float value; - - /** Is touched in training */ protected boolean touched; public WeightValue() {} @@ -40,22 +36,27 @@ public WeightValue(float weight, boolean touched) { this.touched = touched; } - public float get() { + @Override + public final float get() { return value; } - public void set(float weight) { + @Override + public final void set(float weight) { this.value = weight; } + @Override public boolean hasCovariance() { return false; } + @Override public float getCovariance() { throw new UnsupportedOperationException(); } + @Override public void setCovariance(float cov) { throw new UnsupportedOperationException(); } @@ -63,24 +64,46 @@ public void setCovariance(float cov) { /** * @return whether touched in training or not */ - public boolean isTouched() { + @Override + public final boolean isTouched() { return touched; } - public void setTouched(boolean touched) { + @Override + public final void setTouched(boolean touched) { this.touched = touched; } @Override - public void copyTo(WeightValue another) { - another.value = this.value; - another.touched = this.touched; + public final short getClock() { + throw new UnsupportedOperationException(); + } + + @Override + public final void setClock(short clock) { + throw new UnsupportedOperationException(); + } + + @Override + public final byte getDeltaUpdates() { + throw new UnsupportedOperationException(); + } + + @Override + public final void setDeltaUpdates(byte deltaUpdates) { + throw new UnsupportedOperationException(); + } + + @Override + public void copyTo(IWeightValue another) { + another.set(value); + another.setTouched(touched); } @Override - public void copyFrom(WeightValue another) { - this.value = another.value; - this.touched = another.touched; + public void copyFrom(IWeightValue another) { + this.value = another.get(); + this.touched = another.isTouched(); } @Override @@ -91,7 +114,7 @@ public String toString() { public static final class WeightValueWithCovar extends WeightValue { public static final float DEFAULT_COVAR = 1.f; - float covariance; + private float covariance; public WeightValueWithCovar() { super(); @@ -122,15 +145,15 @@ public void setCovariance(float cov) { } @Override - public void copyTo(WeightValue another) { + public void copyTo(IWeightValue another) { super.copyTo(another); - ((WeightValueWithCovar) another).covariance = this.covariance; + another.setCovariance(covariance); } @Override - public void copyFrom(WeightValue another) { + public void copyFrom(IWeightValue another) { super.copyFrom(another); - this.covariance = ((WeightValueWithCovar) another).covariance; + this.covariance = another.getCovariance(); } @Override diff --git a/src/main/hivemall/io/WeightValueWithClock.java b/src/main/hivemall/io/WeightValueWithClock.java new file mode 100644 index 00000000..b38379c3 --- /dev/null +++ b/src/main/hivemall/io/WeightValueWithClock.java @@ -0,0 +1,156 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.io; + +public class WeightValueWithClock implements IWeightValue { + + protected float value; + protected short clock; + protected byte deltaUpdates; + + public WeightValueWithClock(IWeightValue src) { + this.value = src.get(); + if(src.isTouched()) { + this.clock = 1; + this.deltaUpdates = 1; + } else { + this.clock = 0; + this.deltaUpdates = 0; + } + } + + public final float get() { + return value; + } + + public final void set(float weight) { + this.value = weight; + } + + public boolean hasCovariance() { + return false; + } + + public float getCovariance() { + throw new UnsupportedOperationException(); + } + + public void setCovariance(float cov) { + throw new UnsupportedOperationException(); + } + + /** + * @return whether touched in training or not + */ + public final boolean isTouched() { + return deltaUpdates > 0; + } + + @Override + public void setTouched(boolean touched) { + throw new UnsupportedOperationException("WeightValueWithClock#setTouched should not be called"); + } + + public final short getClock() { + return clock; + } + + public final void setClock(short clock) { + this.clock = clock; + } + + public final byte getDeltaUpdates() { + return deltaUpdates; + } + + public final void setDeltaUpdates(byte deltaUpdates) { + if(deltaUpdates < 0) { + throw new IllegalArgumentException("deltaUpdates is less than 0: " + deltaUpdates); + } + this.deltaUpdates = deltaUpdates; + } + + @Override + public void copyTo(IWeightValue another) { + another.set(value); + another.setClock(clock); + another.setDeltaUpdates(deltaUpdates); + } + + @Override + public void copyFrom(IWeightValue another) { + this.value = another.get(); + this.clock = another.getClock(); + this.deltaUpdates = another.getDeltaUpdates(); + } + + @Override + public String toString() { + return "WeightValueWithClock [value=" + value + ", clock=" + clock + ", deltaUpdates=" + + deltaUpdates + "]"; + } + + public static final class WeightValueWithCovarClock extends WeightValueWithClock { + public static final float DEFAULT_COVAR = 1.f; + + private float covariance; + + public WeightValueWithCovarClock(IWeightValue src) { + super(src); + this.covariance = src.getCovariance(); + } + + @Override + public boolean hasCovariance() { + return true; + } + + @Override + public float getCovariance() { + return covariance; + } + + @Override + public void setCovariance(float cov) { + this.covariance = cov; + } + + @Override + public void copyTo(IWeightValue another) { + super.copyTo(another); + another.setCovariance(covariance); + } + + @Override + public void copyFrom(IWeightValue another) { + super.copyFrom(another); + this.covariance = another.getCovariance(); + } + + @Override + public String toString() { + return "WeightValueWithCovar [value=" + value + ", clock=" + clock + ", deltaUpdates=" + + deltaUpdates + ", covariance=" + covariance + "]"; + } + + } + +} \ No newline at end of file diff --git a/src/main/hivemall/mix/MixMessage.java b/src/main/hivemall/mix/MixMessage.java new file mode 100644 index 00000000..586864c6 --- /dev/null +++ b/src/main/hivemall/mix/MixMessage.java @@ -0,0 +1,155 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +public final class MixMessage implements Externalizable { + + private MixEventName event; + private Object feature; + private float weight; + private float covariance; + private short clock; + private int deltaUpdates; + + private String groupID; + + public MixMessage() {} // for Externalizable + + public MixMessage(MixEventName event, Object feature, float weight, short clock, int deltaUpdates) { + this(event, feature, weight, 0.f, clock, deltaUpdates); + } + + public MixMessage(MixEventName event, Object feature, float weight, float covariance, short clock, int deltaUpdates) { + if(feature == null) { + throw new IllegalArgumentException("feature is null"); + } + if(deltaUpdates < 0 || deltaUpdates > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Illegal deletaUpdates: " + deltaUpdates); + } + this.event = event; + this.feature = feature; + this.weight = weight; + this.covariance = covariance; + this.clock = clock; + this.deltaUpdates = deltaUpdates; + } + + public enum MixEventName { + average((byte) 1), argminKLD((byte) 2), closeGroup((byte) 3); + + private final byte id; + + MixEventName(byte id) { + this.id = id; + } + + public byte getID() { + return id; + } + + public static MixEventName resolve(int b) { + switch(b) { + case 1: + return average; + case 2: + return argminKLD; + default: + throw new IllegalArgumentException("Illegal ID: " + b); + } + } + } + + public MixEventName getEvent() { + return event; + } + + public Object getFeature() { + return feature; + } + + public float getWeight() { + return weight; + } + + public float getCovariance() { + return covariance; + } + + public short getClock() { + return clock; + } + + public int getDeltaUpdates() { + return deltaUpdates; + } + + public String getGroupID() { + return groupID; + } + + public void setGroupID(String groupID) { + this.groupID = groupID; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeByte(event.getID()); + out.writeObject(feature); + out.writeFloat(weight); + out.writeFloat(covariance); + out.writeShort(clock); + out.writeInt(deltaUpdates); + if(groupID == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeUTF(groupID); + } + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + byte id = in.readByte(); + this.event = MixEventName.resolve(id); + this.feature = in.readObject(); + this.weight = in.readFloat(); + this.covariance = in.readFloat(); + this.clock = in.readShort(); + this.deltaUpdates = in.readInt(); + boolean hasGroupID = in.readBoolean(); + if(hasGroupID) { + this.groupID = in.readUTF(); + } + } + + @Override + public String toString() { + return "MixMessage [event=" + event + ", groupID=" + groupID + ", feature=" + feature + + ", weight=" + weight + ", covariance=" + covariance + ", clock=" + clock + + ", deltaUpdates=" + deltaUpdates + "]"; + } + +} diff --git a/src/main/hivemall/mix/MixMessageDecoder.java b/src/main/hivemall/mix/MixMessageDecoder.java new file mode 100644 index 00000000..432b9b67 --- /dev/null +++ b/src/main/hivemall/mix/MixMessageDecoder.java @@ -0,0 +1,114 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix; + +import static hivemall.mix.MixMessageEncoder.INTEGER_TYPE; +import static hivemall.mix.MixMessageEncoder.INT_WRITABLE_TYPE; +import static hivemall.mix.MixMessageEncoder.LONG_WRITABLE_TYPE; +import static hivemall.mix.MixMessageEncoder.STRING_TYPE; +import static hivemall.mix.MixMessageEncoder.TEXT_TYPE; +import hivemall.mix.MixMessage.MixEventName; +import hivemall.utils.lang.StringUtils; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +import java.io.IOException; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; + +public final class MixMessageDecoder extends LengthFieldBasedFrameDecoder { + + public MixMessageDecoder() { + super(1048576/* 1MiB */, 0, 4, 0, 4); + } + + @Override + protected MixMessage decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception { + ByteBuf frame = (ByteBuf) super.decode(ctx, in); + if(frame == null) { + return null; + } + + byte b = frame.readByte(); + MixEventName event = MixEventName.resolve(b); + Object feature = decodeObject(frame); + float weight = frame.readFloat(); + float covariance = frame.readFloat(); + short clock = frame.readShort(); + int deltaUpdates = frame.readInt(); + String groupID = readString(frame); + + MixMessage msg = new MixMessage(event, feature, weight, covariance, clock, deltaUpdates); + msg.setGroupID(groupID); + return msg; + } + + private static Object decodeObject(final ByteBuf in) throws IOException { + final byte type = in.readByte(); + switch(type) { + case INTEGER_TYPE: { + int i = in.readInt(); + return Integer.valueOf(i); + } + case TEXT_TYPE: { + int length = in.readInt(); + byte[] b = new byte[length]; + in.readBytes(b, 0, length); + Text t = new Text(b); + return t; + } + case STRING_TYPE: { + return readString(in); + } + case INT_WRITABLE_TYPE: { + int i = in.readInt(); + return new IntWritable(i); + } + case LONG_WRITABLE_TYPE: { + long l = in.readLong(); + return new LongWritable(l); + } + default: + break; + } + throw new IllegalStateException("Illegal type: " + type); + } + + private static String readString(final ByteBuf in) { + int length = in.readInt(); + if(length == -1) { + return null; + } + byte[] b = new byte[length]; + in.readBytes(b, 0, length); + String s = StringUtils.toString(b); + return s; + } + + @Override + protected ByteBuf extractFrame(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) { + return buffer.slice(index, length); + } + +} \ No newline at end of file diff --git a/src/main/hivemall/mix/MixMessageEncoder.java b/src/main/hivemall/mix/MixMessageEncoder.java new file mode 100644 index 00000000..2729fe79 --- /dev/null +++ b/src/main/hivemall/mix/MixMessageEncoder.java @@ -0,0 +1,120 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix; + +import hivemall.mix.MixMessage.MixEventName; +import hivemall.utils.lang.StringUtils; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToByteEncoder; + +import java.io.IOException; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; + +public final class MixMessageEncoder extends MessageToByteEncoder { + private static final byte[] LENGTH_PLACEHOLDER = new byte[4]; + + static final byte INTEGER_TYPE = 1; + static final byte TEXT_TYPE = 2; + static final byte STRING_TYPE = 3; + static final byte INT_WRITABLE_TYPE = 4; + static final byte LONG_WRITABLE_TYPE = 5; + + public MixMessageEncoder() { + super(MixMessage.class, true); + } + + @Override + protected void encode(ChannelHandlerContext ctx, MixMessage msg, ByteBuf out) throws Exception { + int startIdx = out.writerIndex(); + out.writeBytes(LENGTH_PLACEHOLDER); + + MixEventName event = msg.getEvent(); + byte b = event.getID(); + out.writeByte(b); + + Object feature = msg.getFeature(); + encodeObject(feature, out); + + float weight = msg.getWeight(); + out.writeFloat(weight); + + float covariance = msg.getCovariance(); + out.writeFloat(covariance); + + short clock = msg.getClock(); + out.writeShort(clock); + + int deltaUpdates = msg.getDeltaUpdates(); + out.writeInt(deltaUpdates); + + String groupId = msg.getGroupID(); + writeString(groupId, out); + + int endIdx = out.writerIndex(); + out.setInt(startIdx, endIdx - startIdx - 4); + } + + private static void encodeObject(final Object obj, final ByteBuf buf) throws IOException { + assert (obj != null); + if(obj instanceof Integer) { + Integer i = (Integer) obj; + buf.writeByte(INTEGER_TYPE); + buf.writeInt(i.intValue()); + } else if(obj instanceof Text) { + Text t = (Text) obj; + byte[] b = t.getBytes(); + int length = t.getLength(); + buf.writeByte(TEXT_TYPE); + buf.writeInt(length); + buf.writeBytes(b, 0, length); + } else if(obj instanceof String) { + String s = (String) obj; + buf.writeByte(STRING_TYPE); + writeString(s, buf); + } else if(obj instanceof IntWritable) { + IntWritable i = (IntWritable) obj; + buf.writeByte(INT_WRITABLE_TYPE); + buf.writeInt(i.get()); + } else if(obj instanceof LongWritable) { + LongWritable l = (LongWritable) obj; + buf.writeByte(LONG_WRITABLE_TYPE); + buf.writeLong(l.get()); + } else { + throw new IllegalStateException("Unexpected type: " + obj.getClass().getName()); + } + } + + private static void writeString(final String s, final ByteBuf buf) { + if(s == null) { + buf.writeInt(-1); + return; + } + byte[] b = StringUtils.getBytes(s); + int length = b.length; + buf.writeInt(length); + buf.writeBytes(b, 0, length); + } + +} diff --git a/src/main/hivemall/mix/NodeInfo.java b/src/main/hivemall/mix/NodeInfo.java new file mode 100644 index 00000000..7d686c2a --- /dev/null +++ b/src/main/hivemall/mix/NodeInfo.java @@ -0,0 +1,79 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; + +public final class NodeInfo { + + private final InetAddress addr; + private final int port; + + public NodeInfo(InetAddress addr, int port) { + if(addr == null) { + throw new IllegalArgumentException("addr is null"); + } + this.addr = addr; + this.port = port; + } + + public NodeInfo(InetSocketAddress sockAddr) { + this.addr = sockAddr.getAddress(); + this.port = sockAddr.getPort(); + } + + public InetAddress getAddress() { + return addr; + } + + public int getPort() { + return port; + } + + public SocketAddress getSocketAddress() { + return new InetSocketAddress(addr, port); + } + + @Override + public int hashCode() { + return addr.hashCode() + port; + } + + @Override + public boolean equals(Object obj) { + if(obj == this) { + return true; + } + if(obj instanceof NodeInfo) { + NodeInfo other = (NodeInfo) obj; + return addr.equals(other.addr) && (port == other.port); + } + return false; + } + + @Override + public String toString() { + return addr.toString() + ":" + port; + } + +} diff --git a/src/main/hivemall/mix/client/MixClient.java b/src/main/hivemall/mix/client/MixClient.java new file mode 100644 index 00000000..3a4cb0e3 --- /dev/null +++ b/src/main/hivemall/mix/client/MixClient.java @@ -0,0 +1,160 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.client; + +import hivemall.io.ModelUpdateHandler; +import hivemall.io.PredictionModel; +import hivemall.mix.MixMessage; +import hivemall.mix.MixMessage.MixEventName; +import hivemall.mix.NodeInfo; +import hivemall.utils.hadoop.HadoopUtils; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; + +import java.io.Closeable; +import java.io.IOException; +import java.net.SocketAddress; +import java.util.HashMap; +import java.util.Map; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnull; +import javax.net.ssl.SSLException; + +public final class MixClient implements ModelUpdateHandler, Closeable { + public static final String DUMMY_JOB_ID = "__DUMMY_JOB_ID__"; + + private final MixEventName event; + private String groupID; + private final boolean ssl; + private final int mixThreshold; + private final MixRequestRouter router; + private final MixClientHandler msgHandler; + private final Map channelMap; + + private boolean initialized = false; + private EventLoopGroup workers; + + public MixClient(@Nonnull MixEventName event, @CheckForNull String groupID, @Nonnull String connectURIs, boolean ssl, int mixThreshold, @Nonnull PredictionModel model) { + if(groupID == null) { + throw new IllegalArgumentException("groupID is null"); + } + if(mixThreshold < 1 || mixThreshold > Byte.MAX_VALUE) { + throw new IllegalArgumentException("Invalid mixThreshold: " + mixThreshold); + } + this.event = event; + this.groupID = groupID; + this.router = new MixRequestRouter(connectURIs); + this.ssl = ssl; + this.mixThreshold = mixThreshold; + this.msgHandler = new MixClientHandler(model); + this.channelMap = new HashMap(); + } + + private void initialize() throws Exception { + EventLoopGroup workerGroup = new NioEventLoopGroup(); + NodeInfo[] serverNodes = router.getAllNodes(); + for(NodeInfo node : serverNodes) { + Bootstrap b = new Bootstrap(); + configureBootstrap(b, workerGroup, node); + } + this.workers = workerGroup; + this.initialized = true; + } + + private void configureBootstrap(Bootstrap b, EventLoopGroup workerGroup, NodeInfo server) + throws SSLException, InterruptedException { + // Configure SSL. + final SslContext sslCtx; + if(ssl) { + sslCtx = SslContext.newClientContext(InsecureTrustManagerFactory.INSTANCE); + } else { + sslCtx = null; + } + + b.group(workerGroup); + b.option(ChannelOption.SO_KEEPALIVE, true); + b.option(ChannelOption.TCP_NODELAY, true); + b.channel(NioSocketChannel.class); + b.handler(new MixClientInitializer(msgHandler, sslCtx)); + + SocketAddress remoteAddr = server.getSocketAddress(); + ChannelFuture channelFuture = b.connect(remoteAddr).sync(); + Channel channel = channelFuture.channel(); + + channelMap.put(server, channel); + } + + @Override + public boolean onUpdate(Object feature, float weight, float covar, short clock, int deltaUpdates) + throws Exception { + assert (deltaUpdates > 0) : deltaUpdates; + if(deltaUpdates < mixThreshold) { + return false; // avoid mixing + } + + if(!initialized) { + replaceGroupIDIfRequired(); + initialize(); // initialize connections to mix servers + } + + MixMessage msg = new MixMessage(event, feature, weight, covar, clock, deltaUpdates); + msg.setGroupID(groupID); + + NodeInfo server = router.selectNode(msg); + Channel ch = channelMap.get(server); + if(!ch.isActive()) {// reconnect + SocketAddress remoteAddr = server.getSocketAddress(); + ch.connect(remoteAddr).sync(); + } + + //ch.writeAndFlush(msg).sync(); + ch.writeAndFlush(msg); // send asynchronously in the background + return true; + } + + private void replaceGroupIDIfRequired() { + if(groupID.startsWith(DUMMY_JOB_ID)) { + String jobId = HadoopUtils.getJobId(); + this.groupID = groupID.replace(DUMMY_JOB_ID, jobId); + } + } + + @Override + public void close() throws IOException { + if(workers != null) { + for(Channel ch : channelMap.values()) { + ch.close(); + } + channelMap.clear(); + workers.shutdownGracefully(); + this.workers = null; + } + } + +} diff --git a/src/main/hivemall/mix/client/MixClientHandler.java b/src/main/hivemall/mix/client/MixClientHandler.java new file mode 100644 index 00000000..672f949b --- /dev/null +++ b/src/main/hivemall/mix/client/MixClientHandler.java @@ -0,0 +1,57 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.client; + +import hivemall.io.PredictionModel; +import hivemall.mix.MixMessage; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; + +@Sharable +public final class MixClientHandler extends SimpleChannelInboundHandler { + + private final PredictionModel model; + private final boolean hasCovar; + + public MixClientHandler(PredictionModel model) { + super(); + if(model == null) { + throw new IllegalArgumentException("model is null"); + } + this.model = model; + this.hasCovar = model.hasCovariance(); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, MixMessage msg) throws Exception { + Object feature = msg.getFeature(); + float weight = msg.getWeight(); + short clock = msg.getClock(); + if(hasCovar) { + float covar = msg.getCovariance(); + model._set(feature, weight, covar, clock); + } else { + model._set(feature, weight, clock); + } + } + +} diff --git a/src/main/hivemall/mix/client/MixClientInitializer.java b/src/main/hivemall/mix/client/MixClientInitializer.java new file mode 100644 index 00000000..a4208dd5 --- /dev/null +++ b/src/main/hivemall/mix/client/MixClientInitializer.java @@ -0,0 +1,57 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.client; + +import hivemall.mix.MixMessageDecoder; +import hivemall.mix.MixMessageEncoder; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.ssl.SslContext; + +public final class MixClientInitializer extends ChannelInitializer { + + private final MixClientHandler responseHandler; + private final SslContext sslCtx; + + public MixClientInitializer(MixClientHandler msgHandler, SslContext sslCtx) { + if(msgHandler == null) { + throw new IllegalArgumentException(); + } + this.responseHandler = msgHandler; + this.sslCtx = sslCtx; + } + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + if(sslCtx != null) { + pipeline.addLast(sslCtx.newHandler(ch.alloc())); + } + + //ObjectEncoder encoder = new ObjectEncoder(); + //ObjectDecoder decoder = new ObjectDecoder(ClassResolvers.cacheDisabled(null)); + MixMessageEncoder encoder = new MixMessageEncoder(); + MixMessageDecoder decoder = new MixMessageDecoder(); + pipeline.addLast(encoder, decoder, responseHandler); + } + +} diff --git a/src/main/hivemall/mix/client/MixRequestRouter.java b/src/main/hivemall/mix/client/MixRequestRouter.java new file mode 100644 index 00000000..7d331e1b --- /dev/null +++ b/src/main/hivemall/mix/client/MixRequestRouter.java @@ -0,0 +1,65 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.client; + +import hivemall.mix.MixMessage; +import hivemall.mix.NodeInfo; +import hivemall.mix.server.MixServer; +import hivemall.utils.net.NetUtils; + +import java.net.InetSocketAddress; + +public final class MixRequestRouter { + + private final int numNodes; + private final NodeInfo[] nodes; + + public MixRequestRouter(String connectInfo) { + if(connectInfo == null) { + throw new IllegalArgumentException(); + } + String[] endpoints = connectInfo.split("\\s*,\\s*"); + final int numEndpoints = endpoints.length; + if(numEndpoints < 1) { + throw new IllegalArgumentException("Invalid connectInfo: " + connectInfo); + } + this.numNodes = numEndpoints; + NodeInfo[] nodes = new NodeInfo[numEndpoints]; + for(int i = 0; i < numEndpoints; i++) { + InetSocketAddress addr = NetUtils.getInetSocketAddress(endpoints[i], MixServer.DEFAULT_PORT); + nodes[i] = new NodeInfo(addr); + } + this.nodes = nodes; + } + + public NodeInfo[] getAllNodes() { + return nodes; + } + + public NodeInfo selectNode(MixMessage msg) { + assert (msg != null); + Object feature = msg.getFeature(); + int hashcode = feature.hashCode(); + int index = Math.abs(hashcode) % numNodes; + return nodes[index]; + } + +} diff --git a/src/main/hivemall/mix/server/MixServer.java b/src/main/hivemall/mix/server/MixServer.java new file mode 100644 index 00000000..0b6c7879 --- /dev/null +++ b/src/main/hivemall/mix/server/MixServer.java @@ -0,0 +1,147 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.server; + +import hivemall.mix.store.SessionStore; +import hivemall.mix.store.SessionStore.IdleSessionSweeper; +import hivemall.utils.lang.CommandLineUtils; +import hivemall.utils.lang.Primitives; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.util.SelfSignedCertificate; + +import java.security.cert.CertificateException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import javax.annotation.Nonnull; +import javax.net.ssl.SSLException; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; + +public final class MixServer implements Runnable { + public static final int DEFAULT_PORT = 11212; + + private final int port; + private final boolean ssl; + private final float scale; + private final short syncThreshold; + private final long sessionTTLinSec; + private final long sweepIntervalInSec; + + public MixServer(CommandLine cl) { + this.port = Primitives.parseInt(cl.getOptionValue("port"), DEFAULT_PORT); + this.ssl = cl.hasOption("ssl"); + this.scale = Primitives.parseFloat(cl.getOptionValue("scale"), 1.f); + this.syncThreshold = Primitives.parseShort(cl.getOptionValue("sync"), (short) 30); + this.sessionTTLinSec = Primitives.parseLong(cl.getOptionValue("ttl"), 120L); + this.sweepIntervalInSec = Primitives.parseLong(cl.getOptionValue("sweep"), 60L); + } + + public static void main(String[] args) { + Options opts = getOptions(); + CommandLine cl = CommandLineUtils.parseOptions(args, opts); + new MixServer(cl).run(); + } + + static Options getOptions() { + Options opts = new Options(); + opts.addOption("p", "port", true, "port number of the mix server [default: 11212]"); + opts.addOption("ssl", false, "Use SSL for the mix communication [default: false]"); + opts.addOption("scale", "scalemodel", true, "Scale values of prediction models to avoid overflow [default: 1.0 (no-scale)]"); + opts.addOption("sync", "sync_threshold", true, "Synchronization threshold using clock difference [default: 30]"); + opts.addOption("ttl", "session_ttl", true, "The TTL in sec that an idle session lives [default: 120 sec]"); + opts.addOption("sweep", "session_sweep_interval", true, "The interval in sec that the session expiry thread runs [default: 60 sec]"); + return opts; + } + + @Override + public void run() { + try { + start(); + } catch (CertificateException e) { + e.printStackTrace(); + } catch (SSLException e) { + e.printStackTrace(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + public void start() throws CertificateException, SSLException, InterruptedException { + // Configure SSL. + final SslContext sslCtx; + if(ssl) { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + sslCtx = SslContext.newServerContext(ssc.certificate(), ssc.privateKey()); + } else { + sslCtx = null; + } + + SessionStore sessionStore = new SessionStore(); + MixServerHandler msgHandler = new MixServerHandler(sessionStore, syncThreshold, scale); + MixServerInitializer initializer = new MixServerInitializer(msgHandler, sslCtx); + Runnable cleanSessionTask = new IdleSessionSweeper(sessionStore, sessionTTLinSec * 1000L); + + final ScheduledExecutorService idleSessionChecker = Executors.newScheduledThreadPool(1); + try { + idleSessionChecker.scheduleAtFixedRate(cleanSessionTask, sessionTTLinSec + 10L, sweepIntervalInSec, TimeUnit.SECONDS); + acceptConnections(initializer, port); + } finally { + idleSessionChecker.shutdownNow(); + } + } + + private static void acceptConnections(@Nonnull MixServerInitializer initializer, int port) + throws InterruptedException { + final EventLoopGroup bossGroup = new NioEventLoopGroup(1); + final EventLoopGroup workerGroup = new NioEventLoopGroup(); + try { + ServerBootstrap b = new ServerBootstrap(); + b.option(ChannelOption.SO_KEEPALIVE, true); + b.group(bossGroup, workerGroup); + b.channel(NioServerSocketChannel.class); + b.handler(new LoggingHandler(LogLevel.INFO)); + b.childHandler(initializer); + + // Bind and start to accept incoming connections. + ChannelFuture f = b.bind(port).sync(); + + // Wait until the server socket is closed. + // In this example, this does not happen, but you can do that to gracefully + // shut down your server. + f.channel().closeFuture().sync(); + } finally { + workerGroup.shutdownGracefully(); + bossGroup.shutdownGracefully(); + } + } + +} diff --git a/src/main/hivemall/mix/server/MixServerHandler.java b/src/main/hivemall/mix/server/MixServerHandler.java new file mode 100644 index 00000000..950ca212 --- /dev/null +++ b/src/main/hivemall/mix/server/MixServerHandler.java @@ -0,0 +1,127 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.server; + +import hivemall.mix.MixMessage; +import hivemall.mix.MixMessage.MixEventName; +import hivemall.mix.store.PartialArgminKLD; +import hivemall.mix.store.PartialAverage; +import hivemall.mix.store.PartialResult; +import hivemall.mix.store.SessionStore; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; + +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +@Sharable +public final class MixServerHandler extends SimpleChannelInboundHandler { + + @Nonnull + private final SessionStore sessionStore; + private final int syncThreshold; + private final float scale; + + public MixServerHandler(@Nonnull SessionStore sessionStore, @Nonnegative int syncThreshold, @Nonnegative float scale) { + super(); + this.sessionStore = sessionStore; + this.syncThreshold = syncThreshold; + this.scale = scale; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, MixMessage msg) throws Exception { + final MixEventName event = msg.getEvent(); + switch(event) { + case average: + case argminKLD: + PartialResult partial = getPartialResult(msg); + mix(ctx, msg, partial); + break; + default: + throw new IllegalStateException("Unexpected event: " + event); + } + } + + @Nonnull + private PartialResult getPartialResult(@Nonnull MixMessage msg) { + String groupID = msg.getGroupID(); + if(groupID == null) { + throw new IllegalStateException("JobID is not set in the request message"); + } + ConcurrentMap map = sessionStore.get(groupID); + + Object feature = msg.getFeature(); + PartialResult partial = map.get(feature); + if(partial == null) { + final MixEventName event = msg.getEvent(); + switch(event) { + case average: + partial = new PartialAverage(scale); + break; + case argminKLD: + partial = new PartialArgminKLD(scale); + break; + default: + throw new IllegalStateException("Unexpected event: " + event); + } + PartialResult existing = map.putIfAbsent(feature, partial); + if(existing != null) { + partial = existing; + } + } + return partial; + } + + private void mix(final ChannelHandlerContext ctx, final MixMessage requestMsg, final PartialResult partial) { + MixEventName event = requestMsg.getEvent(); + Object feature = requestMsg.getFeature(); + float weight = requestMsg.getWeight(); + float covar = requestMsg.getCovariance(); + short clock = requestMsg.getClock(); + int deltaUpdates = requestMsg.getDeltaUpdates(); + + MixMessage responseMsg = null; + try { + partial.lock(); + + int diffClock = partial.diffClock(clock); + partial.add(weight, covar, clock, deltaUpdates); + + if(diffClock >= syncThreshold) {// sync model if clock DIFF is above threshold + float averagedWeight = partial.getWeight(); + float minCovar = partial.getMinCovariance(); + short totalClock = partial.getClock(); + responseMsg = new MixMessage(event, feature, averagedWeight, minCovar, totalClock, 0 /* deltaUpdates */); + } + } finally { + partial.unlock(); + } + + if(responseMsg != null) { + ctx.writeAndFlush(responseMsg); + } + } + +} diff --git a/src/main/hivemall/mix/server/MixServerInitializer.java b/src/main/hivemall/mix/server/MixServerInitializer.java new file mode 100644 index 00000000..2c2bd786 --- /dev/null +++ b/src/main/hivemall/mix/server/MixServerInitializer.java @@ -0,0 +1,57 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.server; + +import hivemall.mix.MixMessageDecoder; +import hivemall.mix.MixMessageEncoder; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.ssl.SslContext; + +public final class MixServerInitializer extends ChannelInitializer { + + private final MixServerHandler requestHandler; + private final SslContext sslCtx; + + public MixServerInitializer(MixServerHandler msgHandler, SslContext sslCtx) { + if(msgHandler == null) { + throw new IllegalArgumentException(); + } + this.requestHandler = msgHandler; + this.sslCtx = sslCtx; + } + + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + if(sslCtx != null) { + pipeline.addLast(sslCtx.newHandler(ch.alloc())); + } + + //ObjectEncoder encoder = new ObjectEncoder(); + //ObjectDecoder decoder = new ObjectDecoder(4194304, ClassResolvers.cacheDisabled(null)); + MixMessageEncoder encoder = new MixMessageEncoder(); + MixMessageDecoder decoder = new MixMessageDecoder(); + pipeline.addLast(decoder, encoder, requestHandler); + } + +} diff --git a/src/main/hivemall/mix/store/PartialArgminKLD.java b/src/main/hivemall/mix/store/PartialArgminKLD.java new file mode 100644 index 00000000..c11fb605 --- /dev/null +++ b/src/main/hivemall/mix/store/PartialArgminKLD.java @@ -0,0 +1,63 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.store; + +import javax.annotation.Nonnegative; +import javax.annotation.concurrent.GuardedBy; + +public final class PartialArgminKLD extends PartialResult { + + private final float scale; + + @GuardedBy("lock()") + private float sum_mean_div_covar; + @GuardedBy("lock()") + private float sum_inv_covar; + + public PartialArgminKLD() { + this(1.f); // no scaling + } + + public PartialArgminKLD(@Nonnegative float scale) { + super(); + this.scale = scale; + this.sum_mean_div_covar = 0.f; + this.sum_inv_covar = 0.f; + } + + @Override + public void add(float localWeight, float covar, short clock, int deltaUpdates) { + addWeight(localWeight, covar); + setMinCovariance(covar); + incrClock(clock); + } + + protected void addWeight(float localWeight, float covar) { + this.sum_mean_div_covar += (localWeight / covar) / scale; + this.sum_inv_covar += (1.f / covar) / scale; + } + + @Override + public float getWeight() { + return sum_mean_div_covar / sum_inv_covar; + } + +} diff --git a/src/main/hivemall/mix/store/PartialAverage.java b/src/main/hivemall/mix/store/PartialAverage.java new file mode 100644 index 00000000..069614d6 --- /dev/null +++ b/src/main/hivemall/mix/store/PartialAverage.java @@ -0,0 +1,65 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.store; + +import javax.annotation.Nonnegative; +import javax.annotation.concurrent.GuardedBy; + +public final class PartialAverage extends PartialResult { + public static final float DEFAULT_SCALE = 10; + + private final float scale; + + @GuardedBy("lock()") + private float scaledSumWeights; + @GuardedBy("lock()") + private short totalUpdates; + + public PartialAverage() { + this(1.f); // no scaling + } + + public PartialAverage(float scale) { + super(); + this.scale = scale; + this.scaledSumWeights = 0.f; + this.totalUpdates = 0; + } + + @Override + public void add(float localWeight, float covar, short clock, @Nonnegative int deltaUpdates) { + addWeight(localWeight, deltaUpdates); + setMinCovariance(covar); + incrClock(clock); + } + + protected void addWeight(float localWeight, int deltaUpdates) { + scaledSumWeights += ((localWeight / scale) * deltaUpdates); + totalUpdates += deltaUpdates; // not deltaUpdates is in range (0,127] + assert (totalUpdates > 0) : totalUpdates; + } + + @Override + public float getWeight() { + return (scaledSumWeights / totalUpdates) * scale; + } + +} diff --git a/src/main/hivemall/mix/store/PartialResult.java b/src/main/hivemall/mix/store/PartialResult.java new file mode 100644 index 00000000..713df960 --- /dev/null +++ b/src/main/hivemall/mix/store/PartialResult.java @@ -0,0 +1,75 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.store; + +import hivemall.utils.lock.Lock; +import hivemall.utils.lock.TTASLock; + +import javax.annotation.Nonnegative; +import javax.annotation.concurrent.GuardedBy; + +public abstract class PartialResult { + + private final Lock lock; + + @GuardedBy("lock()") + protected float minCovariance; + @GuardedBy("lock()") + protected short totalClock; + + public PartialResult() { + this.lock = new TTASLock(); + } + + public final void lock() { + lock.lock(); + } + + public final void unlock() { + lock.unlock(); + } + + public abstract void add(float localWeight, float covar, short clock, @Nonnegative int deltaUpdates); + + public abstract float getWeight(); + + public final float getMinCovariance() { + return minCovariance; + } + + protected final void setMinCovariance(float covar) { + this.minCovariance = Math.max(minCovariance, covar); + } + + public final short getClock() { + return totalClock; + } + + protected final void incrClock(short clock) { + totalClock += clock; + } + + public final int diffClock(short clock) { + short diff = (short) (totalClock - clock); + return diff < 0 ? -diff : diff; + } + +} diff --git a/src/main/hivemall/mix/store/SessionObject.java b/src/main/hivemall/mix/store/SessionObject.java new file mode 100644 index 00000000..e1438ffb --- /dev/null +++ b/src/main/hivemall/mix/store/SessionObject.java @@ -0,0 +1,58 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.store; + +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nonnull; +import javax.annotation.concurrent.ThreadSafe; + +@ThreadSafe +public final class SessionObject { + + @Nonnull + private final ConcurrentMap object; + private volatile long lastAccessed; // being accessed by multiple threads + + public SessionObject(@Nonnull ConcurrentMap obj) { + if(obj == null) { + throw new IllegalArgumentException("obj is null"); + } + this.object = obj; + } + + @Nonnull + public ConcurrentMap get() { + return object; + } + + /** + * @return last accessed time in msec + */ + public long getLastAccessed() { + return lastAccessed; + } + + public void touch() { + this.lastAccessed = System.currentTimeMillis(); + } + +} diff --git a/src/main/hivemall/mix/store/SessionStore.java b/src/main/hivemall/mix/store/SessionStore.java new file mode 100644 index 00000000..c66504d9 --- /dev/null +++ b/src/main/hivemall/mix/store/SessionStore.java @@ -0,0 +1,74 @@ +package hivemall.mix.store; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.concurrent.ThreadSafe; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +@ThreadSafe +public final class SessionStore { + private static final Log logger = LogFactory.getLog(SessionStore.class); + + private static final int EXPECTED_MODEL_SIZE = 4194305; /* 2^22+1=4194304+1=4194305 */ + + private final ConcurrentMap sessions; + + public SessionStore() { + this.sessions = new ConcurrentHashMap(); + } + + @Nonnull + private ConcurrentMap getSessions() { + return sessions; + } + + @Nonnull + public ConcurrentMap get(@Nonnull String groupID) { + SessionObject sessionObj = sessions.get(groupID); + if(sessionObj == null) { + ConcurrentMap map = new ConcurrentHashMap(EXPECTED_MODEL_SIZE); + sessionObj = new SessionObject(map); + SessionObject existing = sessions.putIfAbsent(groupID, sessionObj); + if(existing != null) { + sessionObj = existing; + } + } + ConcurrentMap map = sessionObj.get(); + sessionObj.touch(); + return map; + } + + @ThreadSafe + public static final class IdleSessionSweeper implements Runnable { + + private final ConcurrentMap sessions; + private final long ttl; + + public IdleSessionSweeper(@Nonnull SessionStore sessionStore, @Nonnegative long ttlInMillis) { + this.sessions = sessionStore.getSessions(); + this.ttl = ttlInMillis; + } + + public void run() { + for(Map.Entry e : sessions.entrySet()) { + SessionObject sessionObj = e.getValue(); + long lastAccessed = sessionObj.getLastAccessed(); + long elapsedTime = System.currentTimeMillis() - lastAccessed; + if(elapsedTime > ttl) { + String key = e.getKey(); + assert (key != null); + sessions.remove(key); + logger.info("Removed an idle session group: " + key); + } + } + + } + } + +} diff --git a/src/main/hivemall/regression/AROWRegressionUDTF.java b/src/main/hivemall/regression/AROWRegressionUDTF.java index 910a6e67..09d41b8e 100644 --- a/src/main/hivemall/regression/AROWRegressionUDTF.java +++ b/src/main/hivemall/regression/AROWRegressionUDTF.java @@ -23,8 +23,8 @@ import hivemall.common.LossFunctions; import hivemall.common.OnlineVariance; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.PredictionResult; -import hivemall.io.WeightValue; import hivemall.io.WeightValue.WeightValueWithCovar; import java.util.Collection; @@ -122,13 +122,13 @@ protected void update(final Collection features, final float coeff, final flo k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_w = model.get(k); - WeightValue new_w = getNewWeight(old_w, v, coeff, beta); + IWeightValue old_w = model.get(k); + IWeightValue new_w = getNewWeight(old_w, v, coeff, beta); model.set(k, new_w); } } - private static WeightValue getNewWeight(final WeightValue old, final float x, final float coeff, final float beta) { + private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float coeff, final float beta) { final float old_w; final float old_cov; if(old == null) { diff --git a/src/main/hivemall/regression/OnlineRegressionUDTF.java b/src/main/hivemall/regression/OnlineRegressionUDTF.java index e1caeb81..f70f9099 100644 --- a/src/main/hivemall/regression/OnlineRegressionUDTF.java +++ b/src/main/hivemall/regression/OnlineRegressionUDTF.java @@ -25,6 +25,7 @@ import static hivemall.HivemallConstants.STRING_TYPE_NAME; import hivemall.LearnerBaseUDTF; import hivemall.io.FeatureValue; +import hivemall.io.IWeightValue; import hivemall.io.PredictionModel; import hivemall.io.PredictionResult; import hivemall.io.WeightValue; @@ -215,7 +216,7 @@ protected PredictionResult calcScoreAndVariance(Collection features) { k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector); v = 1.f; } - WeightValue old_w = model.get(k); + IWeightValue old_w = model.get(k); if(old_w == null) { variance += (1.f * v * v); } else { @@ -260,7 +261,8 @@ protected void update(Collection features, float coeff) { } @Override - public void close() throws HiveException { + public final void close() throws HiveException { + super.close(); if(model != null) { int numForwarded = 0; if(useCovariance()) { @@ -268,7 +270,7 @@ public void close() throws HiveException { final Object[] forwardMapObj = new Object[3]; final FloatWritable fv = new FloatWritable(); final FloatWritable cov = new FloatWritable(); - final IMapIterator itor = model.entries(); + final IMapIterator itor = model.entries(); while(itor.next() != -1) { itor.getValue(probe); if(!probe.isTouched()) { @@ -287,7 +289,7 @@ public void close() throws HiveException { final WeightValue probe = new WeightValue(); final Object[] forwardMapObj = new Object[2]; final FloatWritable fv = new FloatWritable(); - final IMapIterator itor = model.entries(); + final IMapIterator itor = model.entries(); while(itor.next() != -1) { itor.getValue(probe); if(!probe.isTouched()) { @@ -301,10 +303,11 @@ public void close() throws HiveException { numForwarded++; } } + int numMixed = model.getNumMixed(); this.model = null; - logger.info("Trained a prediction model using " + count - + " training examples. Forwarded the prediction model of " + numForwarded - + " rows"); + logger.info("Trained a prediction model using " + count + " training examples" + + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : "")); + logger.info("Forwarded the prediction model of " + numForwarded + " rows"); } } diff --git a/src/main/hivemall/tools/mapred/JobConfGetsUDF.java b/src/main/hivemall/tools/mapred/JobConfGetsUDF.java new file mode 100644 index 00000000..dcb3b53d --- /dev/null +++ b/src/main/hivemall/tools/mapred/JobConfGetsUDF.java @@ -0,0 +1,61 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.tools.mapred; + +import static hivemall.utils.hadoop.WritableUtils.val; +import hivemall.utils.hadoop.HadoopUtils; + +import javax.annotation.Nullable; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.MapredContext; +import org.apache.hadoop.hive.ql.exec.MapredContextAccessor; +import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.JobConf; + +/** + * @since Hive 0.12.0 + */ +@Description(name = "jobconf_gets", value = "_FUNC_() - Returns the value from JobConf") +@UDFType(deterministic = false, stateful = true) +public class JobConfGetsUDF extends UDF { + + public Text evaluate() { + return evaluate(null); + } + + public Text evaluate(@Nullable final String regexKey) { + MapredContext ctx = MapredContextAccessor.get(); + if(ctx == null) { + throw new IllegalStateException("MapredContext is not set"); + } + JobConf jobconf = ctx.getJobConf(); + if(jobconf == null) { + throw new IllegalStateException("JobConf is not set"); + } + + String dumped = HadoopUtils.toString(jobconf, regexKey); + return val(dumped); + } + +} diff --git a/src/main/hivemall/tools/mapred/JobIdUDF.java b/src/main/hivemall/tools/mapred/JobIdUDF.java new file mode 100644 index 00000000..cb98d4a7 --- /dev/null +++ b/src/main/hivemall/tools/mapred/JobIdUDF.java @@ -0,0 +1,42 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.tools.mapred; + +import static hivemall.utils.hadoop.WritableUtils.val; +import hivemall.utils.hadoop.HadoopUtils; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.io.Text; + +@Description(name = "jobid", value = "_FUNC_() - Returns the value of mapred.task.partition") +@UDFType(deterministic = false, stateful = true) +public class JobIdUDF extends UDF { + + /** + * @since Hive 0.12.0 + */ + public Text evaluate() { + return val(HadoopUtils.getJobId()); + } + +} diff --git a/src/main/hivemall/utils/hadoop/HadoopUtils.java b/src/main/hivemall/utils/hadoop/HadoopUtils.java index af163a9c..81f7f69d 100644 --- a/src/main/hivemall/utils/hadoop/HadoopUtils.java +++ b/src/main/hivemall/utils/hadoop/HadoopUtils.java @@ -28,6 +28,11 @@ import java.io.InputStreamReader; import java.io.Reader; import java.net.URI; +import java.util.Iterator; +import java.util.Map.Entry; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -39,6 +44,8 @@ import org.apache.hadoop.io.compress.CompressionInputStream; import org.apache.hadoop.io.compress.Decompressor; import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.JobID; +import org.apache.hadoop.mapred.TaskID; public final class HadoopUtils { @@ -69,7 +76,7 @@ public static BufferedReader getBufferedReader(File file, MapredContext context) } } - private static class BufferedReaderExt extends BufferedReader { + private static final class BufferedReaderExt extends BufferedReader { private Decompressor decompressor; @@ -89,14 +96,93 @@ public void close() throws IOException { } - public static int getTaskId() { + @Nonnull + public static String getJobId() { MapredContext ctx = MapredContextAccessor.get(); + if(ctx == null) { + throw new IllegalStateException("MapredContext is not set"); + } JobConf conf = ctx.getJobConf(); - int taskid = conf.getInt("mapred.task.partition", -1); + if(conf == null) { + throw new IllegalStateException("JobConf is not set"); + } + String jobId = conf.get("mapred.job.id"); + if(jobId == null) { + jobId = conf.get("mapreduce.job.id"); + if(jobId == null) { + String queryId = conf.get("hive.query.id"); + if(queryId != null) { + return queryId; + } + String taskidStr = conf.get("mapred.task.id"); + if(taskidStr == null) { + throw new IllegalStateException("Cannot resolve jobId: " + toString(conf)); + } + jobId = getJobIdFromTaskId(taskidStr); + } + } + return jobId; + } + + @Nonnull + public static String getJobIdFromTaskId(@Nonnull String taskidStr) { + if(!taskidStr.startsWith("task_")) {// workaround for Tez + taskidStr = taskidStr.replace("task", "task_"); + taskidStr = taskidStr.substring(0, taskidStr.lastIndexOf('_')); + } + TaskID taskId = TaskID.forName(taskidStr); + JobID jobId = taskId.getJobID(); + return jobId.toString(); + } + + public static int getTaskId() { + MapredContext ctx = MapredContextAccessor.get(); + if(ctx == null) { + throw new IllegalStateException("MapredContext is not set"); + } + JobConf jobconf = ctx.getJobConf(); + if(jobconf == null) { + throw new IllegalStateException("JobConf is not set"); + } + int taskid = jobconf.getInt("mapred.task.partition", -1); if(taskid == -1) { - throw new IllegalStateException("mapred.task.partition is not set"); + taskid = jobconf.getInt("mapreduce.task.partition", -1); + if(taskid == -1) { + throw new IllegalStateException("Both mapred.task.partition and mapreduce.task.partition are not set: " + + toString(jobconf)); + } } return taskid; } + @Nonnull + public static String toString(@Nonnull JobConf jobconf) { + return toString(jobconf, null); + } + + @Nonnull + public static String toString(@Nonnull JobConf jobconf, @Nullable String regexKey) { + final Iterator> itor = jobconf.iterator(); + boolean hasNext = itor.hasNext(); + if(!hasNext) { + return ""; + } + final StringBuilder buf = new StringBuilder(1024); + do { + Entry e = itor.next(); + hasNext = itor.hasNext(); + String k = e.getKey(); + if(k == null) { + continue; + } + if(regexKey == null || k.matches(regexKey)) { + String v = e.getValue(); + buf.append(k).append('=').append(v); + if(hasNext) { + buf.append(','); + } + } + } while(hasNext); + return buf.toString(); + } } diff --git a/src/main/hivemall/utils/io/IOUtils.java b/src/main/hivemall/utils/io/IOUtils.java new file mode 100644 index 00000000..f94b241a --- /dev/null +++ b/src/main/hivemall/utils/io/IOUtils.java @@ -0,0 +1,40 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.utils.io; + +import java.io.Closeable; +import java.io.IOException; + +public final class IOUtils { + + private IOUtils() {} + + public static void closeQuietly(final Closeable channel) { + if(channel != null) { + try { + channel.close(); + } catch (IOException e) { + ; + } + } + } + +} diff --git a/src/main/hivemall/utils/lang/Primitives.java b/src/main/hivemall/utils/lang/Primitives.java index 07ecd5d6..78236e96 100644 --- a/src/main/hivemall/utils/lang/Primitives.java +++ b/src/main/hivemall/utils/lang/Primitives.java @@ -24,29 +24,52 @@ public final class Primitives { private Primitives() {} - public static int parseInt(String s, int defaultValue) { + public static int toUnsignedShort(final short v) { + return v & 0xFFFF; // convert to range 0-65535 from -32768-32767. + } + + public static short parseShort(final String s, final short defaultValue) { + if(s == null) { + return defaultValue; + } + return Short.parseShort(s); + } + + public static int parseInt(final String s, final int defaultValue) { if(s == null) { return defaultValue; } return Integer.parseInt(s); } - public static float parseFloat(String s, float defaultValue) { + public static long parseLong(final String s, final long defaultValue) { + if(s == null) { + return defaultValue; + } + return Long.parseLong(s); + } + + public static float parseFloat(final String s, final float defaultValue) { if(s == null) { return defaultValue; } return Float.parseFloat(s); } - public static boolean parseBoolean(String s, boolean defaultValue) { + public static boolean parseBoolean(final String s, final boolean defaultValue) { if(s == null) { return defaultValue; } return Boolean.parseBoolean(s); } - - public static int compare(int x, int y) { + + public static int compare(final int x, final int y) { return (x < y) ? -1 : ((x == y) ? 0 : 1); } + public static void putChar(final byte[] b, final int off, final char val) { + b[off + 1] = (byte) (val >>> 0); + b[off] = (byte) (val >>> 8); + } + } diff --git a/src/main/hivemall/utils/lang/StringUtils.java b/src/main/hivemall/utils/lang/StringUtils.java new file mode 100644 index 00000000..11ee3da1 --- /dev/null +++ b/src/main/hivemall/utils/lang/StringUtils.java @@ -0,0 +1,50 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.utils.lang; + +public final class StringUtils { + + private StringUtils() {} + + public static byte[] getBytes(final String s) { + final int len = s.length(); + final byte[] b = new byte[len * 2]; + for(int i = 0; i < len; i++) { + Primitives.putChar(b, i * 2, s.charAt(i)); + } + return b; + } + + public static String toString(byte[] b) { + return toString(b, 0, b.length); + } + + public static String toString(byte[] b, int off, int len) { + final int clen = len >>> 1; + final char[] c = new char[clen]; + for(int i = 0; i < clen; i++) { + final int j = off + (i << 1); + c[i] = (char) ((b[j + 1] & 0xFF) + ((b[j + 0]) << 8)); + } + return new String(c); + } + +} diff --git a/src/main/hivemall/utils/lock/Lock.java b/src/main/hivemall/utils/lock/Lock.java new file mode 100644 index 00000000..c867fc30 --- /dev/null +++ b/src/main/hivemall/utils/lock/Lock.java @@ -0,0 +1,33 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.utils.lock; + +public interface Lock { + + void lock(); + + boolean tryLock(); + + void unlock(); + + boolean isLocked(); + +} diff --git a/src/main/hivemall/utils/lock/TTASLock.java b/src/main/hivemall/utils/lock/TTASLock.java new file mode 100644 index 00000000..bb5fb482 --- /dev/null +++ b/src/main/hivemall/utils/lock/TTASLock.java @@ -0,0 +1,65 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.utils.lock; + +import java.util.concurrent.atomic.AtomicBoolean; + +public final class TTASLock implements Lock { + + private final AtomicBoolean state; + + public TTASLock() { + this(false); + } + + public TTASLock(boolean locked) { + this.state = new AtomicBoolean(locked); + } + + @Override + public void lock() { + while(true) { + while(state.get()) + ; // wait until the lock free + if(!state.getAndSet(true)) { // now try to acquire the lock + return; + } + } + } + + @Override + public boolean tryLock() { + if(state.get()) { + return false; + } + return !state.getAndSet(true); + } + + @Override + public void unlock() { + state.set(false); + } + + @Override + public boolean isLocked() { + return state.get(); + } +} diff --git a/src/main/hivemall/utils/net/NetUtils.java b/src/main/hivemall/utils/net/NetUtils.java new file mode 100644 index 00000000..935968c9 --- /dev/null +++ b/src/main/hivemall/utils/net/NetUtils.java @@ -0,0 +1,57 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.utils.net; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; + +public final class NetUtils { + + private NetUtils() {} + + public static InetSocketAddress getInetSocketAddress(String endpointURI, int defaultPort) { + final int pos = endpointURI.indexOf(':'); + if(pos == -1) { + InetAddress addr = getInetAddress(endpointURI); + return new InetSocketAddress(addr, defaultPort); + } else { + String host = endpointURI.substring(0, pos); + InetAddress addr = getInetAddress(host); + String portStr = endpointURI.substring(pos + 1); + int port = Integer.parseInt(portStr); + return new InetSocketAddress(addr, port); + } + } + + public static InetAddress getInetAddress(final String addressOrName) { + try { + return InetAddress.getByName(addressOrName); + } catch (UnknownHostException e) { + throw new IllegalArgumentException("Cannot find InetAddress: " + addressOrName); + } + } + + public static boolean isIPAddress(final String ip) { + return ip.matches("^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$"); + } + +} diff --git a/src/test/hivemall/common/FeatureValueTest.java b/src/test/hivemall/io/FeatureValueTest.java similarity index 98% rename from src/test/hivemall/common/FeatureValueTest.java rename to src/test/hivemall/io/FeatureValueTest.java index f9bc0757..4be2b06e 100644 --- a/src/test/hivemall/common/FeatureValueTest.java +++ b/src/test/hivemall/io/FeatureValueTest.java @@ -18,7 +18,7 @@ * License along with this library; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ -package hivemall.common; +package hivemall.io; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; diff --git a/src/test/hivemall/common/SpaceEfficientDenseModelTest.java b/src/test/hivemall/io/SpaceEfficientDenseModelTest.java similarity index 72% rename from src/test/hivemall/common/SpaceEfficientDenseModelTest.java rename to src/test/hivemall/io/SpaceEfficientDenseModelTest.java index 41eedaaa..24b53b0d 100644 --- a/src/test/hivemall/common/SpaceEfficientDenseModelTest.java +++ b/src/test/hivemall/io/SpaceEfficientDenseModelTest.java @@ -1,9 +1,6 @@ -package hivemall.common; +package hivemall.io; import static junit.framework.Assert.assertEquals; -import hivemall.io.DenseModel; -import hivemall.io.SpaceEfficientDenseModel; -import hivemall.io.WeightValue; import hivemall.utils.collections.IMapIterator; import java.util.Random; @@ -17,19 +14,22 @@ public void testGetSet() { final int size = 1 << 12; final SpaceEfficientDenseModel model1 = new SpaceEfficientDenseModel(size); + //model1.configureClock(); final DenseModel model2 = new DenseModel(size); + //model2.configureClock(); final Random rand = new Random(); for(int t = 0; t < 1000; t++) { int i = rand.nextInt(size); - float w = 65520f * rand.nextFloat(); - model1.setValue(i, w); - model2.setValue(i, w); + float f = 65520f * rand.nextFloat(); + IWeightValue w = new WeightValue(f); + model1.set(i, w); + model2.set(i, w); } assertEquals(model2.size(), model1.size()); - IMapIterator itor = model1.entries(); + IMapIterator itor = model1.entries(); while(itor.next() != -1) { int k = itor.getKey(); float expected = itor.getValue().get(); diff --git a/src/test/hivemall/mix/client/MixRequestRouterTest.java b/src/test/hivemall/mix/client/MixRequestRouterTest.java new file mode 100644 index 00000000..449c1b83 --- /dev/null +++ b/src/test/hivemall/mix/client/MixRequestRouterTest.java @@ -0,0 +1,37 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.client; + +import hivemall.mix.NodeInfo; +import junit.framework.Assert; + +import org.junit.Test; + +public class MixRequestRouterTest { + + @Test + public void test() { + MixRequestRouter router = new MixRequestRouter("dm01.hpcc.jp:11212,yahoo.co.jp:11212,google.com"); + NodeInfo[] nodes = router.getAllNodes(); + Assert.assertEquals(3, nodes.length); + } + +} diff --git a/src/test/hivemall/mix/server/MixServerTest.java b/src/test/hivemall/mix/server/MixServerTest.java new file mode 100644 index 00000000..c747e7ad --- /dev/null +++ b/src/test/hivemall/mix/server/MixServerTest.java @@ -0,0 +1,238 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.mix.server; + +import hivemall.io.DenseModel; +import hivemall.io.PredictionModel; +import hivemall.io.SparseModel; +import hivemall.io.WeightValue; +import hivemall.mix.MixMessage.MixEventName; +import hivemall.mix.client.MixClient; +import hivemall.utils.io.IOUtils; +import hivemall.utils.lang.CommandLineUtils; + +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import junit.framework.Assert; + +import org.apache.commons.cli.CommandLine; +import org.junit.Test; + +public class MixServerTest { + + @Test + public void testSimpleScenario() throws InterruptedException { + CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11212", + "-sync_threshold", "3" }, MixServer.getOptions()); + MixServer server = new MixServer(cl); + ExecutorService serverExec = Executors.newSingleThreadExecutor(); + serverExec.submit(server); + + PredictionModel model = new DenseModel(16777216, false); + model.configureClock(); + MixClient client = new MixClient(MixEventName.average, "testSimpleScenario", "localhost:11212", false, 2, model); + model.setUpdateHandler(client); + + final Random rand = new Random(43); + for(int i = 0; i < 100000; i++) { + Integer feature = Integer.valueOf(rand.nextInt(100)); + float weight = (float) rand.nextGaussian(); + model.set(feature, new WeightValue(weight)); + } + + Thread.sleep(5 * 1000); + int numMixed = model.getNumMixed(); + //System.out.println("number of mix events: " + numMixed); + Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0); + + IOUtils.closeQuietly(client); + serverExec.shutdown(); + } + + @Test + public void testSSL() throws InterruptedException { + CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11213", + "-sync_threshold", "3", "-ssl" }, MixServer.getOptions()); + MixServer server = new MixServer(cl); + ExecutorService serverExec = Executors.newSingleThreadExecutor(); + serverExec.submit(server); + + PredictionModel model = new DenseModel(16777216, false); + model.configureClock(); + MixClient client = new MixClient(MixEventName.average, "testSSL", "localhost:11213", true, 2, model); + model.setUpdateHandler(client); + + final Random rand = new Random(43); + for(int i = 0; i < 100000; i++) { + Integer feature = Integer.valueOf(rand.nextInt(100)); + float weight = (float) rand.nextGaussian(); + model.set(feature, new WeightValue(weight)); + } + + Thread.sleep(5 * 1000); + int numMixed = model.getNumMixed(); + //System.out.println("number of mix events: " + numMixed); + Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0); + + IOUtils.closeQuietly(client); + serverExec.shutdown(); + } + + @Test + public void testMultipleClients() throws InterruptedException { + CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11214", + "-sync_threshold", "3" }, MixServer.getOptions()); + MixServer server = new MixServer(cl); + ExecutorService serverExec = Executors.newSingleThreadExecutor(); + serverExec.submit(server); + + Thread.sleep(500);// slight delay to boot a server + + final int numClients = 5; + final ExecutorService clientsExec = Executors.newCachedThreadPool(); + for(int i = 0; i < numClients; i++) { + clientsExec.submit(new Runnable() { + @Override + public void run() { + try { + invokeClient("testMultipleClients", 11214); + } catch (InterruptedException e) { + Assert.fail(e.getMessage()); + } + } + }); + } + clientsExec.awaitTermination(10, TimeUnit.SECONDS); + clientsExec.shutdown(); + serverExec.shutdown(); + } + + private static void invokeClient(String groupId, int serverPort) throws InterruptedException { + PredictionModel model = new DenseModel(16777216, false); + model.configureClock(); + MixClient client = new MixClient(MixEventName.average, groupId, "localhost:" + serverPort, false, 2, model); + model.setUpdateHandler(client); + + final Random rand = new Random(43); + for(int i = 0; i < 100000; i++) { + Integer feature = Integer.valueOf(rand.nextInt(100)); + float weight = (float) rand.nextGaussian(); + model.set(feature, new WeightValue(weight)); + } + + Thread.sleep(5 * 1000); + + int numMixed = model.getNumMixed(); + //System.out.println("number of mix events: " + numMixed); + Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0); + + IOUtils.closeQuietly(client); + } + + @Test + public void test2ClientsZeroOneSparseModel() throws InterruptedException { + CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11215", + "-sync_threshold", "30" }, MixServer.getOptions()); + MixServer server = new MixServer(cl); + ExecutorService serverExec = Executors.newSingleThreadExecutor(); + serverExec.submit(server); + + Thread.sleep(500);// slight delay to boot a server + + final ExecutorService clientsExec = Executors.newCachedThreadPool(); + for(int i = 0; i < 2; i++) { + clientsExec.submit(new Runnable() { + @Override + public void run() { + try { + invokeClient01("test2ClientsZeroOne", 11215, false); + } catch (InterruptedException e) { + Assert.fail(e.getMessage()); + } + } + }); + } + clientsExec.awaitTermination(30, TimeUnit.SECONDS); + clientsExec.shutdown(); + serverExec.shutdown(); + } + + @Test + public void test2ClientsZeroOneDenseModel() throws InterruptedException { + CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11215", + "-sync_threshold", "30" }, MixServer.getOptions()); + MixServer server = new MixServer(cl); + ExecutorService serverExec = Executors.newSingleThreadExecutor(); + serverExec.submit(server); + + Thread.sleep(500);// slight delay to boot a server + + final ExecutorService clientsExec = Executors.newCachedThreadPool(); + for(int i = 0; i < 2; i++) { + clientsExec.submit(new Runnable() { + @Override + public void run() { + try { + invokeClient01("test2ClientsZeroOne", 11215, true); + } catch (InterruptedException e) { + Assert.fail(e.getMessage()); + } + } + }); + } + clientsExec.awaitTermination(30, TimeUnit.SECONDS); + clientsExec.shutdown(); + serverExec.shutdown(); + } + + private static void invokeClient01(String groupId, int serverPort, boolean denseModel) + throws InterruptedException { + PredictionModel model = denseModel ? new DenseModel(100, false) + : new SparseModel(100, false); + model.configureClock(); + MixClient client = new MixClient(MixEventName.average, groupId, "localhost:" + serverPort, false, 3, model); + model.setUpdateHandler(client); + + final Random rand = new Random(43); + for(int i = 0; i < 1000000; i++) { + Integer feature = Integer.valueOf(rand.nextInt(100)); + float weight = rand.nextFloat() >= 0.5f ? 1.f : 0.f; + model.set(feature, new WeightValue(weight)); + } + + Thread.sleep(5 * 1000); + + int numMixed = model.getNumMixed(); + //System.out.println("number of mix events: " + numMixed); + Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0); + + for(int i = 0; i < 100; i++) { + float w = model.getWeight(i); + Assert.assertEquals(0.5f, w, 0.1f); + } + + IOUtils.closeQuietly(client); + } + +} diff --git a/src/test/hivemall/utils/collections/OpenHashMapTest.java b/src/test/hivemall/utils/collections/OpenHashMapTest.java index cc5d69aa..f6ce0087 100644 --- a/src/test/hivemall/utils/collections/OpenHashMapTest.java +++ b/src/test/hivemall/utils/collections/OpenHashMapTest.java @@ -33,7 +33,7 @@ public class OpenHashMapTest { @Test public void testPutAndGet() { Map map = new OpenHashMap(16384); - final int numEntries = 10000000; + final int numEntries = 1000000; for(int i = 0; i < numEntries; i++) { map.put(Integer.toString(i), i); } diff --git a/src/test/hivemall/utils/hadoop/HadoopUtilsTest.java b/src/test/hivemall/utils/hadoop/HadoopUtilsTest.java new file mode 100644 index 00000000..e604b761 --- /dev/null +++ b/src/test/hivemall/utils/hadoop/HadoopUtilsTest.java @@ -0,0 +1,35 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2013-2014 + * National Institute of Advanced Industrial Science and Technology (AIST) + * Registration Number: H25PRO-1520 + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ +package hivemall.utils.hadoop; + +import junit.framework.Assert; + +import org.junit.Test; + +public class HadoopUtilsTest { + + @Test + public void testGetJobIdFromTaskId() { + String actual = HadoopUtils.getJobIdFromTaskId("task1407733647643_0188_m_000000_0"); + Assert.assertEquals("job_1407733647643_0188", actual); + } + +} diff --git a/target/hivemall-fat.jar b/target/hivemall-fat.jar new file mode 100644 index 00000000..0fd9a61f Binary files /dev/null and b/target/hivemall-fat.jar differ diff --git a/target/hivemall-with-dependencies.jar b/target/hivemall-with-dependencies.jar new file mode 100644 index 00000000..8ff806ef Binary files /dev/null and b/target/hivemall-with-dependencies.jar differ diff --git a/target/hivemall.jar b/target/hivemall.jar index 8d1a48be..34f5150b 100644 Binary files a/target/hivemall.jar and b/target/hivemall.jar differ