diff --git a/.classpath b/.classpath
index f33cd730..fa36f04a 100644
--- a/.classpath
+++ b/.classpath
@@ -11,5 +11,7 @@
+
+
diff --git a/bin/run_mixserv.sh b/bin/run_mixserv.sh
new file mode 100644
index 00000000..686d2d2f
--- /dev/null
+++ b/bin/run_mixserv.sh
@@ -0,0 +1,5 @@
+#!/bin/sh
+
+VMOPTS="-Xmx4g -da -server -XX:+PrintGCDetails -XX:+UseNUMA -XX:+UseParallelGC $VMOPTS"
+
+java ${VMOPTS} -jar hivemall-fat.jar $@
diff --git a/build.properties b/build.properties
index 000f14e6..52b05013 100644
--- a/build.properties
+++ b/build.properties
@@ -2,8 +2,6 @@
# Project property file for Hivemall #
######################################
-version=0.1
-
lib.dir=lib/
src.dir=src/main
test.dir=src/test
@@ -12,7 +10,19 @@ build.dir=${target.dir}/classes
test.build.dir=${target.dir}/test-classes
test.result.dir=${target.dir}/test-results
-## javac: --------------------------------------------
+## MANIFEST: ----------------------------------------
+
+user.name=myui
+#java.version=1.6
+
+project.version=0.3
+project.name=Hivemall
+project.groupId=hivemall
+project.organization.name=AIST, Japan.
+
+jar.mainclass=hivemall.mix.server.MixServer
+
+## javac: -------------------------------------------
javac.source=1.6
javac.target=1.6
@@ -20,12 +30,7 @@ javac.debug=on
javac.debuglevel=lines,source,vars
#bootclasspath=/jre/lib/rt.jar
-## jar: --------------------------------------------
-
-jar.title=Hivemall
-jar.vendor=AIST, Japan.
-
-## javadoc: --------------------------------------------
+## javadoc: -----------------------------------------
javadoc.dstdir=${target.dir}/docs/api
javadoc.title=Hivemall API
diff --git a/build.xml b/build.xml
index 76d4680d..30bf95ea 100644
--- a/build.xml
+++ b/build.xml
@@ -1,5 +1,5 @@
-
+
@@ -38,7 +38,7 @@
-
+
@@ -46,16 +46,57 @@
-
+
+
+
+
+
+
-
-
-
+
+
+
+
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -66,7 +107,7 @@
-
+
diff --git a/lib/jsr305-1.3.9.jar b/lib/jsr305-1.3.9.jar
new file mode 100644
index 00000000..a9afc661
Binary files /dev/null and b/lib/jsr305-1.3.9.jar differ
diff --git a/lib/license/netty.LICENSE b/lib/license/netty.LICENSE
new file mode 100644
index 00000000..d6456956
--- /dev/null
+++ b/lib/license/netty.LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/lib/netty-all-4.0.23.Final.jar b/lib/netty-all-4.0.23.Final.jar
new file mode 100644
index 00000000..0555a164
Binary files /dev/null and b/lib/netty-all-4.0.23.Final.jar differ
diff --git a/lib/source/netty-all-4.0.23.Final-sources.jar b/lib/source/netty-all-4.0.23.Final-sources.jar
new file mode 100644
index 00000000..b04c08cf
Binary files /dev/null and b/lib/source/netty-all-4.0.23.Final-sources.jar differ
diff --git a/pom.xml b/pom.xml
index 91d9b6e5..062e173c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1,95 +1,254 @@
- 4.0.0
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ 4.0.0
- hivemall
- hivemall
- 0.2
- jar
+ hivemall
+ hivemall
+ 0.3
- hivemall
- http://github.com/myui/hivemall
+ Hivemall
+ Scalable Machine Learning Library for Apache Hive
+ https://github.com/myui/hivemall/
+ 2013
+
+ AIST, Japan.
+
-
- UTF-8
-
+
+
+ LGPL 2.1 license
+ http://www.opensource.org/licenses/lgpl-2.1.php
+
+
-
-
- cloudera
- https://repository.cloudera.com/artifactory/cloudera-repos/
-
-
+ jar
+
+ true
+ UTF-8
+
-
-
- org.apache.hadoop
- hadoop-core
- 0.20.2-cdh3u6
- provided
-
-
- org.apache.hive
- hive-exec
- 0.11.0
- provided
-
-
- jetty
- org.mortbay.jetty
-
-
- javax.jdo
- jdo2-api
-
-
-
-
- org.apache.hive
- hive-serde
- 0.11.0
- provided
-
-
- commons-cli
- commons-cli
- 1.2
- provided
-
-
- javax.jdo
- jdo2-api
- 2.3-eb
-
+
+
+ cloudera
+ https://repository.cloudera.com/artifactory/cloudera-repos/
+
+
-
- junit
- junit
- 4.10
- test
-
-
+
-
- target
- target/classes
- ${project.artifactId}-${project.version}
- target/test-classes
- src/main
- src/test
+
+
+ org.apache.hadoop
+ hadoop-core
+ 0.20.2-cdh3u6
+ provided
+
+
+ org.apache.hive
+ hive-exec
+ 0.11.0
+ provided
+
+
+ jetty
+ org.mortbay.jetty
+
+
+ javax.jdo
+ jdo2-api
+
+
+
+
+ org.apache.hive
+ hive-serde
+ 0.11.0
+ provided
+
+
+ commons-cli
+ commons-cli
+ 1.2
+ provided
+
+
+ commons-logging
+ commons-logging
+ 1.0.4
+ provided
+
+
+ javax.jdo
+ jdo2-api
+ 2.3-eb
+ provided
+
+
+ org.apache.hadoop.thirdparty.guava
+ guava
+ r09-jarjar
+ provided
+
-
-
- org.apache.maven.plugins
- maven-compiler-plugin
- 3.1
-
- 1.6
- 1.6
- UTF-8
-
-
-
-
+
+
+ io.netty
+
+ netty-all
+ 4.0.23.Final
+ compile
+
+
+ com.google.code.findbugs
+ jsr305
+ 1.3.9
+ compile
+
+
+
+
+ junit
+ junit
+ 4.10
+ test
+
+
+
+
+
+ target
+ target/classes
+ ${project.artifactId}-${project.version}
+ target/test-classes
+ src/main
+ src/test
+
+
+
+ org.apache.maven.plugins
+ maven-failsafe-plugin
+ 2.17
+
+ ${skipTests}
+
+
+
+ org.codehaus.mojo
+ properties-maven-plugin
+ 1.0-alpha-2
+
+
+ initialize
+
+ read-project-properties
+
+
+
+ ${basedir}/build.properties
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.1
+
+ ${javac.source}
+ ${javac.target}
+ UTF-8
+
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+ 2.5
+
+ ${project.artifactId}-${project.version}
+
+
+ ${jar.mainclass}
+ true
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+ 2.3
+
+
+ jar-with-dependencies
+ package
+
+ shade
+
+
+ ${project.artifactId}-${project.version}-with-dependencies
+ true
+ false
+
+
+ io.netty:netty-all
+ com.google.code.findbugs:jsr305
+
+
+ junit:junit
+ javax.jdo:jdo2-api
+ org.apache.hadoop.thirdparty.guava:guava
+
+
+
+
+
+ ${jar.mainclass}
+ ${project.name}
+ ${project.version}
+ ${project.organization.name}
+
+
+
+
+
+
+
+
+
+ org.sonatype.plugins
+ jarjar-maven-plugin
+ 1.9
+
+
+ package
+
+ jarjar
+
+
+ ${project.build.directory}/${project.artifactId}-${project.version}-fat.jar
+
+ org.apache.hadoop:hadoop-core
+ org.apache.hive:hive-exec
+ org.apache.hive:hive-serde
+ commons-cli:commons-cli
+ commons-logging:commons-logging
+ com.google.code.findbugs:jsr305
+ org.apache.hadoop.thirdparty.guava:guava
+
+
+ junit:junit
+ javax.jdo:jdo2-api
+
+
+
+
+
+
+
+
diff --git a/scripts/ddl/define-all.hive b/scripts/ddl/define-all.hive
index 60548282..8aafc585 100644
--- a/scripts/ddl/define-all.hive
+++ b/scripts/ddl/define-all.hive
@@ -313,12 +313,18 @@ create temporary function sigmoid as 'hivemall.tools.math.SigmodUDF';
drop temporary function taskid;
create temporary function taskid as 'hivemall.tools.mapred.TaskIdUDF';
+drop temporary function jobid;
+create temporary function jobid as 'hivemall.tools.mapred.JobIdUDF';
+
drop temporary function rowid;
create temporary function rowid as 'hivemall.tools.mapred.RowIdUDF';
drop temporary function distcache_gets;
create temporary function distcache_gets as 'hivemall.tools.mapred.DistributedCacheLookupUDF';
+drop temporary function jobconf_gets;
+create temporary function jobconf_gets as 'hivemall.tools.mapred.JobConfGetsUDF';
+
--------------------
-- misc functions --
--------------------
diff --git a/scripts/ddl/define-tools-udf.hive b/scripts/ddl/define-tools-udf.hive
index 24ffde91..65522bd9 100644
--- a/scripts/ddl/define-tools-udf.hive
+++ b/scripts/ddl/define-tools-udf.hive
@@ -68,12 +68,18 @@ create temporary function sigmoid as 'hivemall.tools.math.SigmodUDF';
drop temporary function taskid;
create temporary function taskid as 'hivemall.tools.mapred.TaskIdUDF';
+drop temporary function jobid;
+create temporary function jobid as 'hivemall.tools.mapred.JobIdUDF';
+
drop temporary function rowid;
create temporary function rowid as 'hivemall.tools.mapred.RowIdUDF';
drop temporary function distcache_gets;
create temporary function distcache_gets as 'hivemall.tools.mapred.DistributedCacheLookupUDF';
+drop temporary function jobconf_gets;
+create temporary function jobconf_gets as 'hivemall.tools.mapred.JobConfGetsUDF';
+
----------------------
-- string functions --
----------------------
diff --git a/scripts/misc/emr_hivemall_bootstrap.sh b/scripts/misc/emr_hivemall_bootstrap.sh
index 652ad1c3..ad41fbf3 100644
--- a/scripts/misc/emr_hivemall_bootstrap.sh
+++ b/scripts/misc/emr_hivemall_bootstrap.sh
@@ -2,4 +2,4 @@
mkdir -p /home/hadoop/tmp
wget --no-check-certificate -P /home/hadoop/tmp \
- https://github.com/myui/hivemall/raw/master/target/hivemall.jar https://github.com/myui/hivemall/raw/master/scripts/ddl/define-all.hive
+ https://github.com/myui/hivemall/raw/master/target/hivemall-with-dependencies.jar https://github.com/myui/hivemall/raw/master/scripts/ddl/define-all.hive
diff --git a/src/main/hivemall/LearnerBaseUDTF.java b/src/main/hivemall/LearnerBaseUDTF.java
index 9da30073..615e3d6e 100644
--- a/src/main/hivemall/LearnerBaseUDTF.java
+++ b/src/main/hivemall/LearnerBaseUDTF.java
@@ -25,11 +25,15 @@
import hivemall.io.PredictionModel;
import hivemall.io.SpaceEfficientDenseModel;
import hivemall.io.SparseModel;
+import hivemall.io.SynchronizedModelWrapper;
import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
+import hivemall.mix.MixMessage.MixEventName;
+import hivemall.mix.client.MixClient;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Primitives;
import java.io.BufferedReader;
@@ -42,6 +46,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -59,6 +64,12 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions {
protected boolean dense_model;
protected int model_dims;
protected boolean disable_halffloat;
+ protected String mixConnectInfo;
+ protected String mixSessionName;
+ protected int mixThreshold;
+ protected boolean ssl;
+
+ protected MixClient mixClient;
public LearnerBaseUDTF() {}
@@ -73,6 +84,10 @@ protected Options getOptions() {
opts.addOption("dense", "densemodel", false, "Use dense model or not");
opts.addOption("dims", "feature_dimensions", true, "The dimension of model [default: 16777216 (2^24)]");
opts.addOption("disable_halffloat", false, "Toggle this option to disable the use of SpaceEfficientDenseModel");
+ opts.addOption("mix", "mix_servers", true, "Comma separated list of MIX servers");
+ opts.addOption("mix_session", "mix_session_name", true, "Mix session name [default: ${mapred.job.id}]");
+ opts.addOption("mix_threshold", true, "Threshold to mix local updates in range (0,127] [default: 3]");
+ opts.addOption("ssl", false, "Use SSL for the communication with mix servers");
return opts;
}
@@ -82,6 +97,10 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen
boolean denseModel = false;
int modelDims = -1;
boolean disableHalfFloat = false;
+ String mixConnectInfo = null;
+ String mixSessionName = null;
+ int mixThreshold = -1;
+ boolean ssl = false;
CommandLine cl = null;
if(argOIs.length >= 3) {
@@ -96,33 +115,73 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen
}
disableHalfFloat = cl.hasOption("disable_halffloat");
+
+ mixConnectInfo = cl.getOptionValue("mix");
+ mixSessionName = cl.getOptionValue("mix_session");
+ mixThreshold = Primitives.parseInt(cl.getOptionValue("mix_threshold"), 3);
+ if(mixThreshold > Byte.MAX_VALUE) {
+ throw new UDFArgumentException("mix_threshold must be in range (0,127]: "
+ + mixThreshold);
+ }
+ ssl = cl.hasOption("ssl");
}
this.preloadedModelFile = modelfile;
this.dense_model = denseModel;
this.model_dims = modelDims;
this.disable_halffloat = disableHalfFloat;
+ this.mixConnectInfo = mixConnectInfo;
+ this.mixSessionName = mixSessionName;
+ this.mixThreshold = mixThreshold;
+ this.ssl = ssl;
return cl;
}
protected PredictionModel createModel() {
+ return createModel(null);
+ }
+
+ protected PredictionModel createModel(String label) {
+ PredictionModel model;
+ final boolean useCovar = useCovariance();
if(dense_model) {
- boolean useCovar = useCovariance();
if(disable_halffloat == false && model_dims > 16777216) {
logger.info("Build a space efficient dense model with " + model_dims
+ " initial dimensions" + (useCovar ? " w/ covariances" : ""));
- return new SpaceEfficientDenseModel(model_dims, useCovar);
+ model = new SpaceEfficientDenseModel(model_dims, useCovar);
} else {
logger.info("Build a dense model with initial with " + model_dims
+ " initial dimensions" + (useCovar ? " w/ covariances" : ""));
- return new DenseModel(model_dims, useCovar);
+ model = new DenseModel(model_dims, useCovar);
}
} else {
int initModelSize = getInitialModelSize();
logger.info("Build a sparse model with initial with " + initModelSize
+ " initial dimensions");
- return new SparseModel(initModelSize);
+ model = new SparseModel(initModelSize, useCovar);
+ }
+ if(mixConnectInfo != null) {
+ model.configureClock();
+ model = new SynchronizedModelWrapper(model);
+ MixClient client = configureMixClient(mixConnectInfo, label, model);
+ model.setUpdateHandler(client);
+ this.mixClient = client;
+ }
+ assert (model != null);
+ return model;
+ }
+
+ protected MixClient configureMixClient(String connectURIs, String label, PredictionModel model) {
+ assert (connectURIs != null);
+ assert (model != null);
+ String jobId = (mixSessionName == null) ? MixClient.DUMMY_JOB_ID : mixSessionName;
+ if(label != null) {
+ jobId = jobId + '-' + label;
}
+ MixEventName event = useCovariance() ? MixEventName.argminKLD : MixEventName.average;
+ MixClient client = new MixClient(event, jobId, connectURIs, ssl, mixThreshold, model);
+ logger.info("Successfully configured mix client: " + connectURIs);
+ return client;
}
protected int getInitialModelSize() {
@@ -241,4 +300,13 @@ private static long loadPredictionModel(PredictionModel model, File file, Primit
}
return count;
}
+
+ @Override
+ public void close() throws HiveException {
+ if(mixClient != null) {
+ IOUtils.closeQuietly(mixClient);
+ this.mixClient = null;
+ }
+ }
+
}
diff --git a/src/main/hivemall/UDTFWithOptions.java b/src/main/hivemall/UDTFWithOptions.java
index efa24ee0..f392e68b 100644
--- a/src/main/hivemall/UDTFWithOptions.java
+++ b/src/main/hivemall/UDTFWithOptions.java
@@ -50,7 +50,7 @@ protected final CommandLine parseOptions(String optionValue) throws UDFArgumentE
protected abstract CommandLine processOptions(ObjectInspector[] argOIs)
throws UDFArgumentException;
- protected List parseFeatures(final List> features, final ObjectInspector featureInspector, final boolean parseFeature) {
+ protected final List parseFeatures(final List> features, final ObjectInspector featureInspector, final boolean parseFeature) {
final int numFeatures = features.size();
if(numFeatures == 0) {
return Collections.emptyList();
diff --git a/src/main/hivemall/classifier/AROWClassifierUDTF.java b/src/main/hivemall/classifier/AROWClassifierUDTF.java
index 51ba1a3c..e68d053c 100644
--- a/src/main/hivemall/classifier/AROWClassifierUDTF.java
+++ b/src/main/hivemall/classifier/AROWClassifierUDTF.java
@@ -22,8 +22,8 @@
import hivemall.common.LossFunctions;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.PredictionResult;
-import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
import java.util.List;
@@ -126,13 +126,13 @@ protected void update(final List> features, final float y, final float alpha,
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = model.get(k);
- WeightValue new_w = getNewWeight(old_w, v, y, alpha, beta);
+ IWeightValue old_w = model.get(k);
+ IWeightValue new_w = getNewWeight(old_w, v, y, alpha, beta);
model.set(k, new_w);
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float x, final float y, final float alpha, final float beta) {
+ private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float y, final float alpha, final float beta) {
final float old_w;
final float old_cov;
if(old == null) {
diff --git a/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java
index f28aead2..56c40fc0 100644
--- a/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java
+++ b/src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -25,6 +25,7 @@
import static hivemall.HivemallConstants.STRING_TYPE_NAME;
import hivemall.LearnerBaseUDTF;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.PredictionModel;
import hivemall.io.PredictionResult;
import hivemall.io.WeightValue;
@@ -219,7 +220,7 @@ protected PredictionResult calcScoreAndVariance(List> features) {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = model.get(k);
+ IWeightValue old_w = model.get(k);
if(old_w == null) {
variance += (1.f * v * v);
} else {
@@ -259,7 +260,8 @@ protected void update(final List> features, final float coeff) {
}
@Override
- public void close() throws HiveException {
+ public final void close() throws HiveException {
+ super.close();
if(model != null) {
int numForwarded = 0;
if(useCovariance()) {
@@ -267,7 +269,7 @@ public void close() throws HiveException {
final Object[] forwardMapObj = new Object[3];
final FloatWritable fv = new FloatWritable();
final FloatWritable cov = new FloatWritable();
- final IMapIterator itor = model.entries();
+ final IMapIterator itor = model.entries();
while(itor.next() != -1) {
itor.getValue(probe);
if(!probe.isTouched()) {
@@ -286,7 +288,7 @@ public void close() throws HiveException {
final WeightValue probe = new WeightValue();
final Object[] forwardMapObj = new Object[2];
final FloatWritable fv = new FloatWritable();
- final IMapIterator itor = model.entries();
+ final IMapIterator itor = model.entries();
while(itor.next() != -1) {
itor.getValue(probe);
if(!probe.isTouched()) {
@@ -300,10 +302,11 @@ public void close() throws HiveException {
numForwarded++;
}
}
+ int numMixed = model.getNumMixed();
this.model = null;
- logger.info("Trained a prediction model using " + count
- + " training examples. Forwarded the prediction model of " + numForwarded
- + " rows");
+ logger.info("Trained a prediction model using " + count + " training examples"
+ + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
+ logger.info("Forwarded the prediction model of " + numForwarded + " rows");
}
}
diff --git a/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java b/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java
index f45264cd..b2f3be57 100644
--- a/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java
+++ b/src/main/hivemall/classifier/ConfidenceWeightedUDTF.java
@@ -21,8 +21,8 @@
package hivemall.classifier;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.PredictionResult;
-import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
import hivemall.utils.math.StatsUtils;
@@ -141,13 +141,13 @@ protected void update(final List> features, final float coeff, final float alp
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = model.get(k);
- WeightValue new_w = getNewWeight(old_w, v, coeff, alpha, phi);
+ IWeightValue old_w = model.get(k);
+ IWeightValue new_w = getNewWeight(old_w, v, coeff, alpha, phi);
model.set(k, new_w);
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float x, final float coeff, final float alpha, final float phi) {
+ private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float coeff, final float alpha, final float phi) {
final float old_w, old_cov;
if(old == null) {
old_w = 0.f;
diff --git a/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java b/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java
index e0b8b437..db3eb94c 100644
--- a/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java
+++ b/src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java
@@ -21,8 +21,8 @@
package hivemall.classifier;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.PredictionResult;
-import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
import hivemall.utils.math.StatsUtils;
@@ -251,13 +251,13 @@ protected void update(final List> features, final float y, final float alpha,
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = model.get(k);
- WeightValue new_w = getNewWeight(old_w, v, y, alpha, beta);
+ IWeightValue old_w = model.get(k);
+ IWeightValue new_w = getNewWeight(old_w, v, y, alpha, beta);
model.set(k, new_w);
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float x, final float y, final float alpha, final float beta) {
+ private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float y, final float alpha, final float beta) {
final float old_v;
final float old_cov;
if(old == null) {
diff --git a/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java
index c2f3478d..daa46cfb 100644
--- a/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java
+++ b/src/main/hivemall/classifier/multiclass/MulticlassAROWClassifierUDTF.java
@@ -21,9 +21,9 @@
package hivemall.classifier.multiclass;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.Margin;
import hivemall.io.PredictionModel;
-import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
import java.util.List;
@@ -142,19 +142,19 @@ protected void update(final List> features, final Object actual_label, final O
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_correctclass_w = model2add.get(k);
- WeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true);
+ IWeightValue old_correctclass_w = model2add.get(k);
+ IWeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true);
model2add.set(k, new_correctclass_w);
if(model2sub != null) {
- WeightValue old_wrongclass_w = model2sub.get(k);
- WeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false);
+ IWeightValue old_wrongclass_w = model2sub.get(k);
+ IWeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false);
model2sub.set(k, new_wrongclass_w);
}
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float v, final float alpha, final float beta, final boolean positive) {
+ private static IWeightValue getNewWeight(final IWeightValue old, final float v, final float alpha, final float beta, final boolean positive) {
final float old_v;
final float old_cov;
if(old == null) {
diff --git a/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java
index c77b2bcb..71d538a2 100644
--- a/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java
+++ b/src/main/hivemall/classifier/multiclass/MulticlassConfidenceWeightedUDTF.java
@@ -21,9 +21,9 @@
package hivemall.classifier.multiclass;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.Margin;
import hivemall.io.PredictionModel;
-import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
import hivemall.utils.math.StatsUtils;
@@ -163,19 +163,19 @@ protected void update(List> features, float alpha, Object actual_label, Object
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_correctclass_w = model2add.get(k);
- WeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, phi, true);
+ IWeightValue old_correctclass_w = model2add.get(k);
+ IWeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, phi, true);
model2add.set(k, new_correctclass_w);
if(model2sub != null) {
- WeightValue old_wrongclass_w = model2sub.get(k);
- WeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, phi, false);
+ IWeightValue old_wrongclass_w = model2sub.get(k);
+ IWeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, phi, false);
model2sub.set(k, new_wrongclass_w);
}
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float x, final float alpha, final float phi, final boolean positive) {
+ private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float alpha, final float phi, final boolean positive) {
final float old_w, old_cov;
if(old == null) {
old_w = 0.f;
diff --git a/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java
index 87e3300e..334a570c 100644
--- a/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java
+++ b/src/main/hivemall/classifier/multiclass/MulticlassOnlineClassifierUDTF.java
@@ -26,6 +26,7 @@
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
import hivemall.LearnerBaseUDTF;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.Margin;
import hivemall.io.PredictionModel;
import hivemall.io.PredictionResult;
@@ -319,7 +320,7 @@ protected final PredictionResult calcScoreAndVariance(final PredictionModel mode
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = model.get(k);
+ IWeightValue old_w = model.get(k);
if(old_w == null) {
variance += (1.f * v * v);
} else {
@@ -380,9 +381,11 @@ protected void update(List> features, float coeff, Object actual_label, Object
}
@Override
- public void close() throws HiveException {
+ public final void close() throws HiveException {
+ super.close();
if(label2model != null) {
long numForwarded = 0L;
+ long numMixed = 0L;
if(useCovariance()) {
final WeightValueWithCovar probe = new WeightValueWithCovar();
final Object[] forwardMapObj = new Object[4];
@@ -392,7 +395,8 @@ public void close() throws HiveException {
Object label = entry.getKey();
forwardMapObj[0] = label;
PredictionModel model = entry.getValue();
- IMapIterator itor = model.entries();
+ numMixed += model.getNumMixed();
+ IMapIterator itor = model.entries();
while(itor.next() != -1) {
itor.getValue(probe);
if(!probe.isTouched()) {
@@ -416,7 +420,8 @@ public void close() throws HiveException {
Object label = entry.getKey();
forwardMapObj[0] = label;
PredictionModel model = entry.getValue();
- IMapIterator itor = model.entries();
+ numMixed += model.getNumMixed();
+ IMapIterator itor = model.entries();
while(itor.next() != -1) {
itor.getValue(probe);
if(!probe.isTouched()) {
@@ -432,9 +437,9 @@ public void close() throws HiveException {
}
}
this.label2model = null;
- logger.info("Trained a prediction model using " + count
- + " training examples. Forwarded the prediction model of " + numForwarded
- + " rows");
+ logger.info("Trained a prediction model using " + count + " training examples"
+ + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
+ logger.info("Forwarded the prediction model of " + numForwarded + " rows");
}
}
diff --git a/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java b/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java
index bfc0816e..b80272eb 100644
--- a/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java
+++ b/src/main/hivemall/classifier/multiclass/MulticlassSoftConfidenceWeightedUDTF.java
@@ -21,9 +21,9 @@
package hivemall.classifier.multiclass;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.Margin;
import hivemall.io.PredictionModel;
-import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
import hivemall.utils.math.StatsUtils;
@@ -271,19 +271,19 @@ protected void update(final List> features, final Object actual_label, final O
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_correctclass_w = model2add.get(k);
- WeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true);
+ IWeightValue old_correctclass_w = model2add.get(k);
+ IWeightValue new_correctclass_w = getNewWeight(old_correctclass_w, v, alpha, beta, true);
model2add.set(k, new_correctclass_w);
if(model2sub != null) {
- WeightValue old_wrongclass_w = model2sub.get(k);
- WeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false);
+ IWeightValue old_wrongclass_w = model2sub.get(k);
+ IWeightValue new_wrongclass_w = getNewWeight(old_wrongclass_w, v, alpha, beta, false);
model2sub.set(k, new_wrongclass_w);
}
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float v, final float alpha, final float beta, final boolean positive) {
+ private static IWeightValue getNewWeight(final IWeightValue old, final float v, final float alpha, final float beta, final boolean positive) {
final float old_v;
final float old_cov;
if(old == null) {
diff --git a/src/main/hivemall/io/AbstractPredictionModel.java b/src/main/hivemall/io/AbstractPredictionModel.java
new file mode 100644
index 00000000..54df01e6
--- /dev/null
+++ b/src/main/hivemall/io/AbstractPredictionModel.java
@@ -0,0 +1,99 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.io;
+
+public abstract class AbstractPredictionModel implements PredictionModel {
+ public static final byte BYTE0 = 0;
+
+ protected ModelUpdateHandler handler;
+ protected int numMixed;
+
+ public AbstractPredictionModel() {
+ this.numMixed = 0;
+ }
+
+ @Override
+ public ModelUpdateHandler getUpdateHandler() {
+ return handler;
+ }
+
+ @Override
+ public void setUpdateHandler(ModelUpdateHandler handler) {
+ this.handler = handler;
+ }
+
+ @Override
+ public final int getNumMixed() {
+ return numMixed;
+ }
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ throw new UnsupportedOperationException();
+ }
+
+ protected final void onUpdate(final int feature, final float weight, final float covar, final short clock, final int deltaUpdates) {
+ if(deltaUpdates < 1) {
+ return;
+ }
+ if(handler != null) {
+ final boolean resetDeltaUpdates;
+ try {
+ resetDeltaUpdates = handler.onUpdate(feature, weight, covar, clock, deltaUpdates);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ if(resetDeltaUpdates) {
+ resetDeltaUpdates(feature);
+ }
+ }
+ }
+
+ protected final void onUpdate(final Object feature, final IWeightValue value) {
+ if(handler != null) {
+ if(!value.isTouched()) {
+ return;
+ }
+ final float weight = value.get();
+ final short clock = value.getClock();
+ final int deltaUpdates = value.getDeltaUpdates();
+ final boolean resetDeltaUpdates;
+ if(value.hasCovariance()) {
+ final float covar = value.getCovariance();
+ try {
+ resetDeltaUpdates = handler.onUpdate(feature, weight, covar, clock, deltaUpdates);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ } else {
+ try {
+ resetDeltaUpdates = handler.onUpdate(feature, weight, 1.f, clock, deltaUpdates);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ if(resetDeltaUpdates) {
+ value.setDeltaUpdates(BYTE0);
+ }
+ }
+ }
+
+}
diff --git a/src/main/hivemall/io/DenseModel.java b/src/main/hivemall/io/DenseModel.java
index 8baec931..00141528 100644
--- a/src/main/hivemall/io/DenseModel.java
+++ b/src/main/hivemall/io/DenseModel.java
@@ -31,18 +31,22 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-public final class DenseModel implements PredictionModel {
+public final class DenseModel extends AbstractPredictionModel {
private static final Log logger = LogFactory.getLog(DenseModel.class);
private int size;
private float[] weights;
private float[] covars;
+ private short[] clocks;
+ private byte[] deltaUpdates;
+
public DenseModel(int ndims) {
this(ndims, false);
}
public DenseModel(int ndims, boolean withCovar) {
+ super();
int size = ndims + 1;
this.size = size;
this.weights = new float[size];
@@ -53,6 +57,31 @@ public DenseModel(int ndims, boolean withCovar) {
} else {
this.covars = null;
}
+ this.clocks = null;
+ this.deltaUpdates = null;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return covars != null;
+ }
+
+ @Override
+ public void configureClock() {
+ if(clocks == null) {
+ this.clocks = new short[size];
+ this.deltaUpdates = new byte[size];
+ }
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clocks != null;
+ }
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ deltaUpdates[feature] = 0;
}
private void ensureCapacity(final int index) {
@@ -68,13 +97,17 @@ private void ensureCapacity(final int index) {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, 1f);
}
+ if(clocks != null) {
+ this.clocks = Arrays.copyOf(clocks, newSize);
+ this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
+ }
}
}
@SuppressWarnings("unchecked")
@Override
- public T get(Object feature) {
- int i = HiveUtils.parseInt(feature);
+ public T get(Object feature) {
+ final int i = HiveUtils.parseInt(feature);
if(i >= size) {
return null;
}
@@ -86,15 +119,27 @@ public T get(Object feature) {
}
@Override
- public void set(Object feature, T value) {
+ public void set(Object feature, T value) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
float weight = value.get();
weights[i] = weight;
+ float covar = 1.f;
if(value.hasCovariance()) {
- float covar = value.getCovariance();
+ covar = value.getCovariance();
covars[i] = covar;
}
+ short clock = 0;
+ int delta = 0;
+ if(clocks != null && value.isTouched()) {
+ clock = (short) (clocks[i] + 1);
+ clocks[i] = clock;
+ delta = deltaUpdates[i] + 1;
+ assert (delta > 0) : delta;
+ deltaUpdates[i] = (byte) delta;
+ }
+
+ onUpdate(i, weight, covar, clock, delta);
}
@Override
@@ -116,18 +161,24 @@ public float getCovariance(Object feature) {
}
@Override
- public void setValue(Object feature, float weight) {
+ public void _set(Object feature, float weight, short clock) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
weights[i] = weight;
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ numMixed++;
}
@Override
- public void setValue(Object feature, float weight, float covar) {
+ public void _set(Object feature, float weight, float covar, short clock) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
weights[i] = weight;
covars[i] = covar;
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ numMixed++;
}
@Override
@@ -147,11 +198,11 @@ public boolean contains(Object feature) {
@SuppressWarnings("unchecked")
@Override
- public IMapIterator entries() {
+ public IMapIterator entries() {
return (IMapIterator) new Itr();
}
- private final class Itr implements IMapIterator {
+ private final class Itr implements IMapIterator {
private int cursor;
private final WeightValueWithCovar tmpWeight;
@@ -181,7 +232,7 @@ public Integer getKey() {
}
@Override
- public WeightValue getValue() {
+ public IWeightValue getValue() {
if(covars == null) {
float w = weights[cursor];
WeightValue v = new WeightValue(w);
@@ -197,13 +248,13 @@ public WeightValue getValue() {
}
@Override
- public > void getValue(T probe) {
+ public > void getValue(T probe) {
float w = weights[cursor];
tmpWeight.value = w;
float cov = 1.f;
if(covars != null) {
cov = covars[cursor];
- tmpWeight.covariance = cov;
+ tmpWeight.setCovariance(cov);
}
tmpWeight.setTouched(w != 0.f || cov != 1.f);
probe.copyFrom(tmpWeight);
diff --git a/src/main/hivemall/io/IWeightValue.java b/src/main/hivemall/io/IWeightValue.java
new file mode 100644
index 00000000..957e964e
--- /dev/null
+++ b/src/main/hivemall/io/IWeightValue.java
@@ -0,0 +1,32 @@
+package hivemall.io;
+
+import hivemall.utils.lang.Copyable;
+
+public interface IWeightValue extends Copyable {
+
+ float get();
+
+ void set(float weight);
+
+ boolean hasCovariance();
+
+ float getCovariance();
+
+ void setCovariance(float cov);
+
+ /**
+ * @return whether touched in training or not
+ */
+ boolean isTouched();
+
+ void setTouched(boolean touched);
+
+ short getClock();
+
+ void setClock(short clock);
+
+ byte getDeltaUpdates();
+
+ void setDeltaUpdates(byte deltaUpdates);
+
+}
\ No newline at end of file
diff --git a/src/main/hivemall/io/ModelUpdateHandler.java b/src/main/hivemall/io/ModelUpdateHandler.java
new file mode 100644
index 00000000..0214e99d
--- /dev/null
+++ b/src/main/hivemall/io/ModelUpdateHandler.java
@@ -0,0 +1,36 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.io;
+
+public interface ModelUpdateHandler {
+
+ /**
+ * @param feature
+ * @param weight
+ * @param covar 1.0 by the default
+ * @param clock 0 by the default
+ * @param deltaUpdates
+ * @return reset the deltaUpdates?
+ */
+ boolean onUpdate(Object feature, float weight, float covar, short clock, int deltaUpdates)
+ throws Exception;
+
+}
diff --git a/src/main/hivemall/io/PredictionModel.java b/src/main/hivemall/io/PredictionModel.java
index 09574e67..8f61febf 100644
--- a/src/main/hivemall/io/PredictionModel.java
+++ b/src/main/hivemall/io/PredictionModel.java
@@ -24,22 +24,36 @@
public interface PredictionModel {
- public int size();
+ ModelUpdateHandler getUpdateHandler();
- public boolean contains(Object feature);
+ void setUpdateHandler(ModelUpdateHandler handler);
- public T get(Object feature);
+ int getNumMixed();
- public void set(Object feature, T value);
+ boolean hasCovariance();
- public float getWeight(Object feature);
+ void configureClock();
- public float getCovariance(Object feature);
+ boolean hasClock();
- public void setValue(Object feature, float weight);
+ void resetDeltaUpdates(int feature);
- public void setValue(Object feature, float weight, float covar);
+ int size();
- public IMapIterator entries();
+ boolean contains(Object feature);
-}
+ T get(Object feature);
+
+ void set(Object feature, T value);
+
+ float getWeight(Object feature);
+
+ float getCovariance(Object feature);
+
+ void _set(Object feature, float weight, short clock);
+
+ void _set(Object feature, float weight, float covar, short clock);
+
+ IMapIterator entries();
+
+}
\ No newline at end of file
diff --git a/src/main/hivemall/io/SpaceEfficientDenseModel.java b/src/main/hivemall/io/SpaceEfficientDenseModel.java
index 6ece7ea7..48454f83 100644
--- a/src/main/hivemall/io/SpaceEfficientDenseModel.java
+++ b/src/main/hivemall/io/SpaceEfficientDenseModel.java
@@ -32,18 +32,22 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-public final class SpaceEfficientDenseModel implements PredictionModel {
+public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
private static final Log logger = LogFactory.getLog(SpaceEfficientDenseModel.class);
private int size;
private short[] weights;
private short[] covars;
+ private short[] clocks;
+ private byte[] deltaUpdates;
+
public SpaceEfficientDenseModel(int ndims) {
this(ndims, false);
}
public SpaceEfficientDenseModel(int ndims, boolean withCovar) {
+ super();
int size = ndims + 1;
this.size = size;
this.weights = new short[size];
@@ -54,6 +58,31 @@ public SpaceEfficientDenseModel(int ndims, boolean withCovar) {
} else {
this.covars = null;
}
+ this.clocks = null;
+ this.deltaUpdates = null;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return covars != null;
+ }
+
+ @Override
+ public void configureClock() {
+ if(clocks == null) {
+ this.clocks = new short[size];
+ this.deltaUpdates = new byte[size];
+ }
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clocks != null;
+ }
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ deltaUpdates[feature] = 0;
}
private float getWeight(final int i) {
@@ -94,13 +123,17 @@ private void ensureCapacity(final int index) {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE);
}
+ if(clocks != null) {
+ this.clocks = Arrays.copyOf(clocks, newSize);
+ this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
+ }
}
}
@SuppressWarnings("unchecked")
@Override
- public T get(Object feature) {
- int i = HiveUtils.parseInt(feature);
+ public T get(Object feature) {
+ final int i = HiveUtils.parseInt(feature);
if(i >= size) {
return null;
}
@@ -112,15 +145,27 @@ public T get(Object feature) {
}
@Override
- public void set(Object feature, T value) {
+ public void set(Object feature, T value) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
float weight = value.get();
setWeight(i, weight);
+ float covar = 1.f;
if(value.hasCovariance()) {
- float covar = value.getCovariance();
+ covar = value.getCovariance();
setCovar(i, covar);
}
+ short clock = 0;
+ int delta = 0;
+ if(clocks != null && value.isTouched()) {
+ clock = (short) (clocks[i] + 1);
+ clocks[i] = clock;
+ delta = deltaUpdates[i] + 1;
+ assert (delta > 0) : delta;
+ deltaUpdates[i] = (byte) delta;
+ }
+
+ onUpdate(i, weight, covar, clock, delta);
}
@Override
@@ -142,18 +187,24 @@ public float getCovariance(Object feature) {
}
@Override
- public void setValue(Object feature, float weight) {
+ public void _set(Object feature, float weight, short clock) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
setWeight(i, weight);
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ numMixed++;
}
@Override
- public void setValue(Object feature, float weight, float covar) {
+ public void _set(Object feature, float weight, float covar, short clock) {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
setWeight(i, weight);
setCovar(i, covar);
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ numMixed++;
}
@Override
@@ -173,11 +224,11 @@ public boolean contains(Object feature) {
@SuppressWarnings("unchecked")
@Override
- public IMapIterator entries() {
+ public IMapIterator entries() {
return (IMapIterator) new Itr();
}
- private final class Itr implements IMapIterator {
+ private final class Itr implements IMapIterator {
private int cursor;
private final WeightValueWithCovar tmpWeight;
@@ -207,7 +258,7 @@ public Integer getKey() {
}
@Override
- public WeightValue getValue() {
+ public IWeightValue getValue() {
if(covars == null) {
float w = getWeight(cursor);
WeightValue v = new WeightValue(w);
@@ -223,13 +274,13 @@ public WeightValue getValue() {
}
@Override
- public > void getValue(T probe) {
+ public > void getValue(T probe) {
float w = getWeight(cursor);
tmpWeight.value = w;
float cov = 1.f;
if(covars != null) {
cov = getCovar(cursor);
- tmpWeight.covariance = cov;
+ tmpWeight.setCovariance(cov);
}
tmpWeight.setTouched(w != 0.f || cov != 1.f);
probe.copyFrom(tmpWeight);
diff --git a/src/main/hivemall/io/SparseModel.java b/src/main/hivemall/io/SparseModel.java
index 4b20b781..c3e7c30a 100644
--- a/src/main/hivemall/io/SparseModel.java
+++ b/src/main/hivemall/io/SparseModel.java
@@ -20,53 +20,118 @@
*/
package hivemall.io;
-import hivemall.io.WeightValue.WeightValueWithCovar;
+import hivemall.io.WeightValueWithClock.WeightValueWithCovarClock;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.collections.OpenHashMap;
-public final class SparseModel implements PredictionModel {
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
- private final OpenHashMap weights;
+public final class SparseModel extends AbstractPredictionModel {
+ private static final Log logger = LogFactory.getLog(SparseModel.class);
- public SparseModel() {
- this(16384);
+ private final OpenHashMap weights;
+ private final boolean hasCovar;
+ private boolean clockEnabled;
+
+ public SparseModel(int size, boolean hasCovar) {
+ super();
+ this.weights = new OpenHashMap(size);
+ this.hasCovar = hasCovar;
+ this.clockEnabled = false;
}
- public SparseModel(int size) {
- this.weights = new OpenHashMap(size);
+ @Override
+ public boolean hasCovariance() {
+ return hasCovar;
+ }
+
+ @Override
+ public void configureClock() {
+ this.clockEnabled = true;
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clockEnabled;
}
@SuppressWarnings("unchecked")
@Override
- public T get(Object feature) {
+ public T get(final Object feature) {
return (T) weights.get(feature);
}
@Override
- public void set(Object feature, T value) {
- weights.put(feature, value);
+ public void set(final Object feature, final T value) {
+ assert (feature != null);
+ assert (value != null);
+
+ final IWeightValue wrapperValue = wrapIfRequired(value);
+
+ if(clockEnabled && value.isTouched()) {
+ IWeightValue old = weights.get(feature);
+ if(old != null) {
+ short newclock = (short) (old.getClock() + (short) 1);
+ wrapperValue.setClock(newclock);
+ int newDelta = old.getDeltaUpdates() + 1;
+ wrapperValue.setDeltaUpdates((byte) newDelta);
+ }
+ }
+ weights.put(feature, wrapperValue);
+
+ onUpdate(feature, wrapperValue);
+ }
+
+ private IWeightValue wrapIfRequired(final IWeightValue value) {
+ if(clockEnabled) {
+ if(value.hasCovariance()) {
+ return new WeightValueWithCovarClock(value);
+ } else {
+ return new WeightValueWithClock(value);
+ }
+ } else {
+ return value;
+ }
}
@Override
- public float getWeight(Object feature) {
- WeightValue v = weights.get(feature);
- return v == null ? 0.f : v.value;
+ public float getWeight(final Object feature) {
+ IWeightValue v = weights.get(feature);
+ return v == null ? 0.f : v.get();
}
@Override
- public float getCovariance(Object feature) {
- WeightValueWithCovar v = (WeightValueWithCovar) weights.get(feature);
- return v == null ? 1.f : v.covariance;
+ public float getCovariance(final Object feature) {
+ IWeightValue v = weights.get(feature);
+ return v == null ? 1.f : v.getCovariance();
}
@Override
- public void setValue(Object feature, float weight) {
- weights.put(feature, new WeightValue(weight));
+ public void _set(final Object feature, final float weight, final short clock) {
+ final IWeightValue w = weights.get(feature);
+ if(w == null) {
+ logger.warn("Previous weight not found: " + feature);
+ throw new IllegalStateException("Previous weight not found " + feature);
+ }
+ w.set(weight);
+ w.setClock(clock);
+ w.setDeltaUpdates(BYTE0);
+ numMixed++;
}
@Override
- public void setValue(Object feature, float weight, float covar) {
- weights.put(feature, new WeightValueWithCovar(weight, covar));
+ public void _set(final Object feature, final float weight, final float covar, final short clock) {
+ final IWeightValue w = weights.get(feature);
+ if(w == null) {
+ logger.warn("Previous weight not found: " + feature);
+ throw new IllegalStateException("Previous weight not found: " + feature);
+ }
+ w.set(weight);
+ w.setCovariance(covar);
+ w.setClock(clock);
+ w.setDeltaUpdates(BYTE0);
+ numMixed++;
}
@Override
@@ -75,13 +140,13 @@ public int size() {
}
@Override
- public boolean contains(Object feature) {
+ public boolean contains(final Object feature) {
return weights.containsKey(feature);
}
@SuppressWarnings("unchecked")
@Override
- public IMapIterator entries() {
+ public IMapIterator entries() {
return (IMapIterator) weights.entries();
}
diff --git a/src/main/hivemall/io/SynchronizedModelWrapper.java b/src/main/hivemall/io/SynchronizedModelWrapper.java
new file mode 100644
index 00000000..4335d09a
--- /dev/null
+++ b/src/main/hivemall/io/SynchronizedModelWrapper.java
@@ -0,0 +1,152 @@
+package hivemall.io;
+
+import hivemall.utils.collections.IMapIterator;
+
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+public final class SynchronizedModelWrapper implements PredictionModel {
+
+ private final PredictionModel model;
+ private final Lock lock;
+
+ public SynchronizedModelWrapper(PredictionModel model) {
+ this.model = model;
+ this.lock = new ReentrantLock();
+ }
+
+ // ------------------------------------------------------------
+ // Non-synchronized methods with care
+
+ public PredictionModel getModel() {
+ return model;
+ }
+
+ @Override
+ public ModelUpdateHandler getUpdateHandler() {
+ return model.getUpdateHandler();
+ }
+
+ @Override
+ public void setUpdateHandler(ModelUpdateHandler handler) {
+ model.setUpdateHandler(handler);
+ }
+
+ @Override
+ public int getNumMixed() {
+ return model.getNumMixed();
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return model.hasCovariance();
+ }
+
+ @Override
+ public void configureClock() {
+ model.configureClock();
+ }
+
+ @Override
+ public boolean hasClock() {
+ return model.hasClock();
+ }
+
+ @Override
+ public IMapIterator entries() {
+ return model.entries();
+ }
+
+ // ------------------------------------------------------------
+ // The below is synchronized methods
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ try {
+ lock.lock();
+ model.resetDeltaUpdates(feature);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public int size() {
+ try {
+ lock.lock();
+ return model.size();
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public boolean contains(Object feature) {
+ try {
+ lock.lock();
+ return model.contains(feature);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public T get(Object feature) {
+ try {
+ lock.lock();
+ return model.get(feature);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public void set(Object feature, T value) {
+ try {
+ lock.lock();
+ model.set(feature, value);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public float getWeight(Object feature) {
+ try {
+ lock.lock();
+ return model.getWeight(feature);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public float getCovariance(Object feature) {
+ try {
+ lock.lock();
+ return model.getCovariance(feature);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public void _set(Object feature, float weight, short clock) {
+ try {
+ lock.lock();
+ model._set(feature, weight, clock);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
+ public void _set(Object feature, float weight, float covar, short clock) {
+ try {
+ lock.lock();
+ model._set(feature, weight, covar, clock);
+ } finally {
+ lock.unlock();
+ }
+ }
+}
diff --git a/src/main/hivemall/io/WeightValue.java b/src/main/hivemall/io/WeightValue.java
index 5f2488f1..f3b1e6d9 100644
--- a/src/main/hivemall/io/WeightValue.java
+++ b/src/main/hivemall/io/WeightValue.java
@@ -20,13 +20,9 @@
*/
package hivemall.io;
-import hivemall.utils.lang.Copyable;
-
-public class WeightValue implements Copyable {
+public class WeightValue implements IWeightValue {
protected float value;
-
- /** Is touched in training */
protected boolean touched;
public WeightValue() {}
@@ -40,22 +36,27 @@ public WeightValue(float weight, boolean touched) {
this.touched = touched;
}
- public float get() {
+ @Override
+ public final float get() {
return value;
}
- public void set(float weight) {
+ @Override
+ public final void set(float weight) {
this.value = weight;
}
+ @Override
public boolean hasCovariance() {
return false;
}
+ @Override
public float getCovariance() {
throw new UnsupportedOperationException();
}
+ @Override
public void setCovariance(float cov) {
throw new UnsupportedOperationException();
}
@@ -63,24 +64,46 @@ public void setCovariance(float cov) {
/**
* @return whether touched in training or not
*/
- public boolean isTouched() {
+ @Override
+ public final boolean isTouched() {
return touched;
}
- public void setTouched(boolean touched) {
+ @Override
+ public final void setTouched(boolean touched) {
this.touched = touched;
}
@Override
- public void copyTo(WeightValue another) {
- another.value = this.value;
- another.touched = this.touched;
+ public final short getClock() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public final void setClock(short clock) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public final byte getDeltaUpdates() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public final void setDeltaUpdates(byte deltaUpdates) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void copyTo(IWeightValue another) {
+ another.set(value);
+ another.setTouched(touched);
}
@Override
- public void copyFrom(WeightValue another) {
- this.value = another.value;
- this.touched = another.touched;
+ public void copyFrom(IWeightValue another) {
+ this.value = another.get();
+ this.touched = another.isTouched();
}
@Override
@@ -91,7 +114,7 @@ public String toString() {
public static final class WeightValueWithCovar extends WeightValue {
public static final float DEFAULT_COVAR = 1.f;
- float covariance;
+ private float covariance;
public WeightValueWithCovar() {
super();
@@ -122,15 +145,15 @@ public void setCovariance(float cov) {
}
@Override
- public void copyTo(WeightValue another) {
+ public void copyTo(IWeightValue another) {
super.copyTo(another);
- ((WeightValueWithCovar) another).covariance = this.covariance;
+ another.setCovariance(covariance);
}
@Override
- public void copyFrom(WeightValue another) {
+ public void copyFrom(IWeightValue another) {
super.copyFrom(another);
- this.covariance = ((WeightValueWithCovar) another).covariance;
+ this.covariance = another.getCovariance();
}
@Override
diff --git a/src/main/hivemall/io/WeightValueWithClock.java b/src/main/hivemall/io/WeightValueWithClock.java
new file mode 100644
index 00000000..b38379c3
--- /dev/null
+++ b/src/main/hivemall/io/WeightValueWithClock.java
@@ -0,0 +1,156 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.io;
+
+public class WeightValueWithClock implements IWeightValue {
+
+ protected float value;
+ protected short clock;
+ protected byte deltaUpdates;
+
+ public WeightValueWithClock(IWeightValue src) {
+ this.value = src.get();
+ if(src.isTouched()) {
+ this.clock = 1;
+ this.deltaUpdates = 1;
+ } else {
+ this.clock = 0;
+ this.deltaUpdates = 0;
+ }
+ }
+
+ public final float get() {
+ return value;
+ }
+
+ public final void set(float weight) {
+ this.value = weight;
+ }
+
+ public boolean hasCovariance() {
+ return false;
+ }
+
+ public float getCovariance() {
+ throw new UnsupportedOperationException();
+ }
+
+ public void setCovariance(float cov) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * @return whether touched in training or not
+ */
+ public final boolean isTouched() {
+ return deltaUpdates > 0;
+ }
+
+ @Override
+ public void setTouched(boolean touched) {
+ throw new UnsupportedOperationException("WeightValueWithClock#setTouched should not be called");
+ }
+
+ public final short getClock() {
+ return clock;
+ }
+
+ public final void setClock(short clock) {
+ this.clock = clock;
+ }
+
+ public final byte getDeltaUpdates() {
+ return deltaUpdates;
+ }
+
+ public final void setDeltaUpdates(byte deltaUpdates) {
+ if(deltaUpdates < 0) {
+ throw new IllegalArgumentException("deltaUpdates is less than 0: " + deltaUpdates);
+ }
+ this.deltaUpdates = deltaUpdates;
+ }
+
+ @Override
+ public void copyTo(IWeightValue another) {
+ another.set(value);
+ another.setClock(clock);
+ another.setDeltaUpdates(deltaUpdates);
+ }
+
+ @Override
+ public void copyFrom(IWeightValue another) {
+ this.value = another.get();
+ this.clock = another.getClock();
+ this.deltaUpdates = another.getDeltaUpdates();
+ }
+
+ @Override
+ public String toString() {
+ return "WeightValueWithClock [value=" + value + ", clock=" + clock + ", deltaUpdates="
+ + deltaUpdates + "]";
+ }
+
+ public static final class WeightValueWithCovarClock extends WeightValueWithClock {
+ public static final float DEFAULT_COVAR = 1.f;
+
+ private float covariance;
+
+ public WeightValueWithCovarClock(IWeightValue src) {
+ super(src);
+ this.covariance = src.getCovariance();
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return true;
+ }
+
+ @Override
+ public float getCovariance() {
+ return covariance;
+ }
+
+ @Override
+ public void setCovariance(float cov) {
+ this.covariance = cov;
+ }
+
+ @Override
+ public void copyTo(IWeightValue another) {
+ super.copyTo(another);
+ another.setCovariance(covariance);
+ }
+
+ @Override
+ public void copyFrom(IWeightValue another) {
+ super.copyFrom(another);
+ this.covariance = another.getCovariance();
+ }
+
+ @Override
+ public String toString() {
+ return "WeightValueWithCovar [value=" + value + ", clock=" + clock + ", deltaUpdates="
+ + deltaUpdates + ", covariance=" + covariance + "]";
+ }
+
+ }
+
+}
\ No newline at end of file
diff --git a/src/main/hivemall/mix/MixMessage.java b/src/main/hivemall/mix/MixMessage.java
new file mode 100644
index 00000000..586864c6
--- /dev/null
+++ b/src/main/hivemall/mix/MixMessage.java
@@ -0,0 +1,155 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
+public final class MixMessage implements Externalizable {
+
+ private MixEventName event;
+ private Object feature;
+ private float weight;
+ private float covariance;
+ private short clock;
+ private int deltaUpdates;
+
+ private String groupID;
+
+ public MixMessage() {} // for Externalizable
+
+ public MixMessage(MixEventName event, Object feature, float weight, short clock, int deltaUpdates) {
+ this(event, feature, weight, 0.f, clock, deltaUpdates);
+ }
+
+ public MixMessage(MixEventName event, Object feature, float weight, float covariance, short clock, int deltaUpdates) {
+ if(feature == null) {
+ throw new IllegalArgumentException("feature is null");
+ }
+ if(deltaUpdates < 0 || deltaUpdates > Byte.MAX_VALUE) {
+ throw new IllegalArgumentException("Illegal deletaUpdates: " + deltaUpdates);
+ }
+ this.event = event;
+ this.feature = feature;
+ this.weight = weight;
+ this.covariance = covariance;
+ this.clock = clock;
+ this.deltaUpdates = deltaUpdates;
+ }
+
+ public enum MixEventName {
+ average((byte) 1), argminKLD((byte) 2), closeGroup((byte) 3);
+
+ private final byte id;
+
+ MixEventName(byte id) {
+ this.id = id;
+ }
+
+ public byte getID() {
+ return id;
+ }
+
+ public static MixEventName resolve(int b) {
+ switch(b) {
+ case 1:
+ return average;
+ case 2:
+ return argminKLD;
+ default:
+ throw new IllegalArgumentException("Illegal ID: " + b);
+ }
+ }
+ }
+
+ public MixEventName getEvent() {
+ return event;
+ }
+
+ public Object getFeature() {
+ return feature;
+ }
+
+ public float getWeight() {
+ return weight;
+ }
+
+ public float getCovariance() {
+ return covariance;
+ }
+
+ public short getClock() {
+ return clock;
+ }
+
+ public int getDeltaUpdates() {
+ return deltaUpdates;
+ }
+
+ public String getGroupID() {
+ return groupID;
+ }
+
+ public void setGroupID(String groupID) {
+ this.groupID = groupID;
+ }
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeByte(event.getID());
+ out.writeObject(feature);
+ out.writeFloat(weight);
+ out.writeFloat(covariance);
+ out.writeShort(clock);
+ out.writeInt(deltaUpdates);
+ if(groupID == null) {
+ out.writeBoolean(false);
+ } else {
+ out.writeBoolean(true);
+ out.writeUTF(groupID);
+ }
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ byte id = in.readByte();
+ this.event = MixEventName.resolve(id);
+ this.feature = in.readObject();
+ this.weight = in.readFloat();
+ this.covariance = in.readFloat();
+ this.clock = in.readShort();
+ this.deltaUpdates = in.readInt();
+ boolean hasGroupID = in.readBoolean();
+ if(hasGroupID) {
+ this.groupID = in.readUTF();
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "MixMessage [event=" + event + ", groupID=" + groupID + ", feature=" + feature
+ + ", weight=" + weight + ", covariance=" + covariance + ", clock=" + clock
+ + ", deltaUpdates=" + deltaUpdates + "]";
+ }
+
+}
diff --git a/src/main/hivemall/mix/MixMessageDecoder.java b/src/main/hivemall/mix/MixMessageDecoder.java
new file mode 100644
index 00000000..432b9b67
--- /dev/null
+++ b/src/main/hivemall/mix/MixMessageDecoder.java
@@ -0,0 +1,114 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix;
+
+import static hivemall.mix.MixMessageEncoder.INTEGER_TYPE;
+import static hivemall.mix.MixMessageEncoder.INT_WRITABLE_TYPE;
+import static hivemall.mix.MixMessageEncoder.LONG_WRITABLE_TYPE;
+import static hivemall.mix.MixMessageEncoder.STRING_TYPE;
+import static hivemall.mix.MixMessageEncoder.TEXT_TYPE;
+import hivemall.mix.MixMessage.MixEventName;
+import hivemall.utils.lang.StringUtils;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+
+public final class MixMessageDecoder extends LengthFieldBasedFrameDecoder {
+
+ public MixMessageDecoder() {
+ super(1048576/* 1MiB */, 0, 4, 0, 4);
+ }
+
+ @Override
+ protected MixMessage decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
+ ByteBuf frame = (ByteBuf) super.decode(ctx, in);
+ if(frame == null) {
+ return null;
+ }
+
+ byte b = frame.readByte();
+ MixEventName event = MixEventName.resolve(b);
+ Object feature = decodeObject(frame);
+ float weight = frame.readFloat();
+ float covariance = frame.readFloat();
+ short clock = frame.readShort();
+ int deltaUpdates = frame.readInt();
+ String groupID = readString(frame);
+
+ MixMessage msg = new MixMessage(event, feature, weight, covariance, clock, deltaUpdates);
+ msg.setGroupID(groupID);
+ return msg;
+ }
+
+ private static Object decodeObject(final ByteBuf in) throws IOException {
+ final byte type = in.readByte();
+ switch(type) {
+ case INTEGER_TYPE: {
+ int i = in.readInt();
+ return Integer.valueOf(i);
+ }
+ case TEXT_TYPE: {
+ int length = in.readInt();
+ byte[] b = new byte[length];
+ in.readBytes(b, 0, length);
+ Text t = new Text(b);
+ return t;
+ }
+ case STRING_TYPE: {
+ return readString(in);
+ }
+ case INT_WRITABLE_TYPE: {
+ int i = in.readInt();
+ return new IntWritable(i);
+ }
+ case LONG_WRITABLE_TYPE: {
+ long l = in.readLong();
+ return new LongWritable(l);
+ }
+ default:
+ break;
+ }
+ throw new IllegalStateException("Illegal type: " + type);
+ }
+
+ private static String readString(final ByteBuf in) {
+ int length = in.readInt();
+ if(length == -1) {
+ return null;
+ }
+ byte[] b = new byte[length];
+ in.readBytes(b, 0, length);
+ String s = StringUtils.toString(b);
+ return s;
+ }
+
+ @Override
+ protected ByteBuf extractFrame(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) {
+ return buffer.slice(index, length);
+ }
+
+}
\ No newline at end of file
diff --git a/src/main/hivemall/mix/MixMessageEncoder.java b/src/main/hivemall/mix/MixMessageEncoder.java
new file mode 100644
index 00000000..2729fe79
--- /dev/null
+++ b/src/main/hivemall/mix/MixMessageEncoder.java
@@ -0,0 +1,120 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix;
+
+import hivemall.mix.MixMessage.MixEventName;
+import hivemall.utils.lang.StringUtils;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToByteEncoder;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+
+public final class MixMessageEncoder extends MessageToByteEncoder {
+ private static final byte[] LENGTH_PLACEHOLDER = new byte[4];
+
+ static final byte INTEGER_TYPE = 1;
+ static final byte TEXT_TYPE = 2;
+ static final byte STRING_TYPE = 3;
+ static final byte INT_WRITABLE_TYPE = 4;
+ static final byte LONG_WRITABLE_TYPE = 5;
+
+ public MixMessageEncoder() {
+ super(MixMessage.class, true);
+ }
+
+ @Override
+ protected void encode(ChannelHandlerContext ctx, MixMessage msg, ByteBuf out) throws Exception {
+ int startIdx = out.writerIndex();
+ out.writeBytes(LENGTH_PLACEHOLDER);
+
+ MixEventName event = msg.getEvent();
+ byte b = event.getID();
+ out.writeByte(b);
+
+ Object feature = msg.getFeature();
+ encodeObject(feature, out);
+
+ float weight = msg.getWeight();
+ out.writeFloat(weight);
+
+ float covariance = msg.getCovariance();
+ out.writeFloat(covariance);
+
+ short clock = msg.getClock();
+ out.writeShort(clock);
+
+ int deltaUpdates = msg.getDeltaUpdates();
+ out.writeInt(deltaUpdates);
+
+ String groupId = msg.getGroupID();
+ writeString(groupId, out);
+
+ int endIdx = out.writerIndex();
+ out.setInt(startIdx, endIdx - startIdx - 4);
+ }
+
+ private static void encodeObject(final Object obj, final ByteBuf buf) throws IOException {
+ assert (obj != null);
+ if(obj instanceof Integer) {
+ Integer i = (Integer) obj;
+ buf.writeByte(INTEGER_TYPE);
+ buf.writeInt(i.intValue());
+ } else if(obj instanceof Text) {
+ Text t = (Text) obj;
+ byte[] b = t.getBytes();
+ int length = t.getLength();
+ buf.writeByte(TEXT_TYPE);
+ buf.writeInt(length);
+ buf.writeBytes(b, 0, length);
+ } else if(obj instanceof String) {
+ String s = (String) obj;
+ buf.writeByte(STRING_TYPE);
+ writeString(s, buf);
+ } else if(obj instanceof IntWritable) {
+ IntWritable i = (IntWritable) obj;
+ buf.writeByte(INT_WRITABLE_TYPE);
+ buf.writeInt(i.get());
+ } else if(obj instanceof LongWritable) {
+ LongWritable l = (LongWritable) obj;
+ buf.writeByte(LONG_WRITABLE_TYPE);
+ buf.writeLong(l.get());
+ } else {
+ throw new IllegalStateException("Unexpected type: " + obj.getClass().getName());
+ }
+ }
+
+ private static void writeString(final String s, final ByteBuf buf) {
+ if(s == null) {
+ buf.writeInt(-1);
+ return;
+ }
+ byte[] b = StringUtils.getBytes(s);
+ int length = b.length;
+ buf.writeInt(length);
+ buf.writeBytes(b, 0, length);
+ }
+
+}
diff --git a/src/main/hivemall/mix/NodeInfo.java b/src/main/hivemall/mix/NodeInfo.java
new file mode 100644
index 00000000..7d686c2a
--- /dev/null
+++ b/src/main/hivemall/mix/NodeInfo.java
@@ -0,0 +1,79 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix;
+
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+
+public final class NodeInfo {
+
+ private final InetAddress addr;
+ private final int port;
+
+ public NodeInfo(InetAddress addr, int port) {
+ if(addr == null) {
+ throw new IllegalArgumentException("addr is null");
+ }
+ this.addr = addr;
+ this.port = port;
+ }
+
+ public NodeInfo(InetSocketAddress sockAddr) {
+ this.addr = sockAddr.getAddress();
+ this.port = sockAddr.getPort();
+ }
+
+ public InetAddress getAddress() {
+ return addr;
+ }
+
+ public int getPort() {
+ return port;
+ }
+
+ public SocketAddress getSocketAddress() {
+ return new InetSocketAddress(addr, port);
+ }
+
+ @Override
+ public int hashCode() {
+ return addr.hashCode() + port;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if(obj == this) {
+ return true;
+ }
+ if(obj instanceof NodeInfo) {
+ NodeInfo other = (NodeInfo) obj;
+ return addr.equals(other.addr) && (port == other.port);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return addr.toString() + ":" + port;
+ }
+
+}
diff --git a/src/main/hivemall/mix/client/MixClient.java b/src/main/hivemall/mix/client/MixClient.java
new file mode 100644
index 00000000..3a4cb0e3
--- /dev/null
+++ b/src/main/hivemall/mix/client/MixClient.java
@@ -0,0 +1,160 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.client;
+
+import hivemall.io.ModelUpdateHandler;
+import hivemall.io.PredictionModel;
+import hivemall.mix.MixMessage;
+import hivemall.mix.MixMessage.MixEventName;
+import hivemall.mix.NodeInfo;
+import hivemall.utils.hadoop.HadoopUtils;
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.util.HashMap;
+import java.util.Map;
+
+import javax.annotation.CheckForNull;
+import javax.annotation.Nonnull;
+import javax.net.ssl.SSLException;
+
+public final class MixClient implements ModelUpdateHandler, Closeable {
+ public static final String DUMMY_JOB_ID = "__DUMMY_JOB_ID__";
+
+ private final MixEventName event;
+ private String groupID;
+ private final boolean ssl;
+ private final int mixThreshold;
+ private final MixRequestRouter router;
+ private final MixClientHandler msgHandler;
+ private final Map channelMap;
+
+ private boolean initialized = false;
+ private EventLoopGroup workers;
+
+ public MixClient(@Nonnull MixEventName event, @CheckForNull String groupID, @Nonnull String connectURIs, boolean ssl, int mixThreshold, @Nonnull PredictionModel model) {
+ if(groupID == null) {
+ throw new IllegalArgumentException("groupID is null");
+ }
+ if(mixThreshold < 1 || mixThreshold > Byte.MAX_VALUE) {
+ throw new IllegalArgumentException("Invalid mixThreshold: " + mixThreshold);
+ }
+ this.event = event;
+ this.groupID = groupID;
+ this.router = new MixRequestRouter(connectURIs);
+ this.ssl = ssl;
+ this.mixThreshold = mixThreshold;
+ this.msgHandler = new MixClientHandler(model);
+ this.channelMap = new HashMap();
+ }
+
+ private void initialize() throws Exception {
+ EventLoopGroup workerGroup = new NioEventLoopGroup();
+ NodeInfo[] serverNodes = router.getAllNodes();
+ for(NodeInfo node : serverNodes) {
+ Bootstrap b = new Bootstrap();
+ configureBootstrap(b, workerGroup, node);
+ }
+ this.workers = workerGroup;
+ this.initialized = true;
+ }
+
+ private void configureBootstrap(Bootstrap b, EventLoopGroup workerGroup, NodeInfo server)
+ throws SSLException, InterruptedException {
+ // Configure SSL.
+ final SslContext sslCtx;
+ if(ssl) {
+ sslCtx = SslContext.newClientContext(InsecureTrustManagerFactory.INSTANCE);
+ } else {
+ sslCtx = null;
+ }
+
+ b.group(workerGroup);
+ b.option(ChannelOption.SO_KEEPALIVE, true);
+ b.option(ChannelOption.TCP_NODELAY, true);
+ b.channel(NioSocketChannel.class);
+ b.handler(new MixClientInitializer(msgHandler, sslCtx));
+
+ SocketAddress remoteAddr = server.getSocketAddress();
+ ChannelFuture channelFuture = b.connect(remoteAddr).sync();
+ Channel channel = channelFuture.channel();
+
+ channelMap.put(server, channel);
+ }
+
+ @Override
+ public boolean onUpdate(Object feature, float weight, float covar, short clock, int deltaUpdates)
+ throws Exception {
+ assert (deltaUpdates > 0) : deltaUpdates;
+ if(deltaUpdates < mixThreshold) {
+ return false; // avoid mixing
+ }
+
+ if(!initialized) {
+ replaceGroupIDIfRequired();
+ initialize(); // initialize connections to mix servers
+ }
+
+ MixMessage msg = new MixMessage(event, feature, weight, covar, clock, deltaUpdates);
+ msg.setGroupID(groupID);
+
+ NodeInfo server = router.selectNode(msg);
+ Channel ch = channelMap.get(server);
+ if(!ch.isActive()) {// reconnect
+ SocketAddress remoteAddr = server.getSocketAddress();
+ ch.connect(remoteAddr).sync();
+ }
+
+ //ch.writeAndFlush(msg).sync();
+ ch.writeAndFlush(msg); // send asynchronously in the background
+ return true;
+ }
+
+ private void replaceGroupIDIfRequired() {
+ if(groupID.startsWith(DUMMY_JOB_ID)) {
+ String jobId = HadoopUtils.getJobId();
+ this.groupID = groupID.replace(DUMMY_JOB_ID, jobId);
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if(workers != null) {
+ for(Channel ch : channelMap.values()) {
+ ch.close();
+ }
+ channelMap.clear();
+ workers.shutdownGracefully();
+ this.workers = null;
+ }
+ }
+
+}
diff --git a/src/main/hivemall/mix/client/MixClientHandler.java b/src/main/hivemall/mix/client/MixClientHandler.java
new file mode 100644
index 00000000..672f949b
--- /dev/null
+++ b/src/main/hivemall/mix/client/MixClientHandler.java
@@ -0,0 +1,57 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.client;
+
+import hivemall.io.PredictionModel;
+import hivemall.mix.MixMessage;
+import io.netty.channel.ChannelHandler.Sharable;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+
+@Sharable
+public final class MixClientHandler extends SimpleChannelInboundHandler {
+
+ private final PredictionModel model;
+ private final boolean hasCovar;
+
+ public MixClientHandler(PredictionModel model) {
+ super();
+ if(model == null) {
+ throw new IllegalArgumentException("model is null");
+ }
+ this.model = model;
+ this.hasCovar = model.hasCovariance();
+ }
+
+ @Override
+ protected void channelRead0(ChannelHandlerContext ctx, MixMessage msg) throws Exception {
+ Object feature = msg.getFeature();
+ float weight = msg.getWeight();
+ short clock = msg.getClock();
+ if(hasCovar) {
+ float covar = msg.getCovariance();
+ model._set(feature, weight, covar, clock);
+ } else {
+ model._set(feature, weight, clock);
+ }
+ }
+
+}
diff --git a/src/main/hivemall/mix/client/MixClientInitializer.java b/src/main/hivemall/mix/client/MixClientInitializer.java
new file mode 100644
index 00000000..a4208dd5
--- /dev/null
+++ b/src/main/hivemall/mix/client/MixClientInitializer.java
@@ -0,0 +1,57 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.client;
+
+import hivemall.mix.MixMessageDecoder;
+import hivemall.mix.MixMessageEncoder;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslContext;
+
+public final class MixClientInitializer extends ChannelInitializer {
+
+ private final MixClientHandler responseHandler;
+ private final SslContext sslCtx;
+
+ public MixClientInitializer(MixClientHandler msgHandler, SslContext sslCtx) {
+ if(msgHandler == null) {
+ throw new IllegalArgumentException();
+ }
+ this.responseHandler = msgHandler;
+ this.sslCtx = sslCtx;
+ }
+
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ ChannelPipeline pipeline = ch.pipeline();
+ if(sslCtx != null) {
+ pipeline.addLast(sslCtx.newHandler(ch.alloc()));
+ }
+
+ //ObjectEncoder encoder = new ObjectEncoder();
+ //ObjectDecoder decoder = new ObjectDecoder(ClassResolvers.cacheDisabled(null));
+ MixMessageEncoder encoder = new MixMessageEncoder();
+ MixMessageDecoder decoder = new MixMessageDecoder();
+ pipeline.addLast(encoder, decoder, responseHandler);
+ }
+
+}
diff --git a/src/main/hivemall/mix/client/MixRequestRouter.java b/src/main/hivemall/mix/client/MixRequestRouter.java
new file mode 100644
index 00000000..7d331e1b
--- /dev/null
+++ b/src/main/hivemall/mix/client/MixRequestRouter.java
@@ -0,0 +1,65 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.client;
+
+import hivemall.mix.MixMessage;
+import hivemall.mix.NodeInfo;
+import hivemall.mix.server.MixServer;
+import hivemall.utils.net.NetUtils;
+
+import java.net.InetSocketAddress;
+
+public final class MixRequestRouter {
+
+ private final int numNodes;
+ private final NodeInfo[] nodes;
+
+ public MixRequestRouter(String connectInfo) {
+ if(connectInfo == null) {
+ throw new IllegalArgumentException();
+ }
+ String[] endpoints = connectInfo.split("\\s*,\\s*");
+ final int numEndpoints = endpoints.length;
+ if(numEndpoints < 1) {
+ throw new IllegalArgumentException("Invalid connectInfo: " + connectInfo);
+ }
+ this.numNodes = numEndpoints;
+ NodeInfo[] nodes = new NodeInfo[numEndpoints];
+ for(int i = 0; i < numEndpoints; i++) {
+ InetSocketAddress addr = NetUtils.getInetSocketAddress(endpoints[i], MixServer.DEFAULT_PORT);
+ nodes[i] = new NodeInfo(addr);
+ }
+ this.nodes = nodes;
+ }
+
+ public NodeInfo[] getAllNodes() {
+ return nodes;
+ }
+
+ public NodeInfo selectNode(MixMessage msg) {
+ assert (msg != null);
+ Object feature = msg.getFeature();
+ int hashcode = feature.hashCode();
+ int index = Math.abs(hashcode) % numNodes;
+ return nodes[index];
+ }
+
+}
diff --git a/src/main/hivemall/mix/server/MixServer.java b/src/main/hivemall/mix/server/MixServer.java
new file mode 100644
index 00000000..0b6c7879
--- /dev/null
+++ b/src/main/hivemall/mix/server/MixServer.java
@@ -0,0 +1,147 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.server;
+
+import hivemall.mix.store.SessionStore;
+import hivemall.mix.store.SessionStore.IdleSessionSweeper;
+import hivemall.utils.lang.CommandLineUtils;
+import hivemall.utils.lang.Primitives;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.logging.LogLevel;
+import io.netty.handler.logging.LoggingHandler;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.util.SelfSignedCertificate;
+
+import java.security.cert.CertificateException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import javax.annotation.Nonnull;
+import javax.net.ssl.SSLException;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+
+public final class MixServer implements Runnable {
+ public static final int DEFAULT_PORT = 11212;
+
+ private final int port;
+ private final boolean ssl;
+ private final float scale;
+ private final short syncThreshold;
+ private final long sessionTTLinSec;
+ private final long sweepIntervalInSec;
+
+ public MixServer(CommandLine cl) {
+ this.port = Primitives.parseInt(cl.getOptionValue("port"), DEFAULT_PORT);
+ this.ssl = cl.hasOption("ssl");
+ this.scale = Primitives.parseFloat(cl.getOptionValue("scale"), 1.f);
+ this.syncThreshold = Primitives.parseShort(cl.getOptionValue("sync"), (short) 30);
+ this.sessionTTLinSec = Primitives.parseLong(cl.getOptionValue("ttl"), 120L);
+ this.sweepIntervalInSec = Primitives.parseLong(cl.getOptionValue("sweep"), 60L);
+ }
+
+ public static void main(String[] args) {
+ Options opts = getOptions();
+ CommandLine cl = CommandLineUtils.parseOptions(args, opts);
+ new MixServer(cl).run();
+ }
+
+ static Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("p", "port", true, "port number of the mix server [default: 11212]");
+ opts.addOption("ssl", false, "Use SSL for the mix communication [default: false]");
+ opts.addOption("scale", "scalemodel", true, "Scale values of prediction models to avoid overflow [default: 1.0 (no-scale)]");
+ opts.addOption("sync", "sync_threshold", true, "Synchronization threshold using clock difference [default: 30]");
+ opts.addOption("ttl", "session_ttl", true, "The TTL in sec that an idle session lives [default: 120 sec]");
+ opts.addOption("sweep", "session_sweep_interval", true, "The interval in sec that the session expiry thread runs [default: 60 sec]");
+ return opts;
+ }
+
+ @Override
+ public void run() {
+ try {
+ start();
+ } catch (CertificateException e) {
+ e.printStackTrace();
+ } catch (SSLException e) {
+ e.printStackTrace();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+ public void start() throws CertificateException, SSLException, InterruptedException {
+ // Configure SSL.
+ final SslContext sslCtx;
+ if(ssl) {
+ SelfSignedCertificate ssc = new SelfSignedCertificate();
+ sslCtx = SslContext.newServerContext(ssc.certificate(), ssc.privateKey());
+ } else {
+ sslCtx = null;
+ }
+
+ SessionStore sessionStore = new SessionStore();
+ MixServerHandler msgHandler = new MixServerHandler(sessionStore, syncThreshold, scale);
+ MixServerInitializer initializer = new MixServerInitializer(msgHandler, sslCtx);
+ Runnable cleanSessionTask = new IdleSessionSweeper(sessionStore, sessionTTLinSec * 1000L);
+
+ final ScheduledExecutorService idleSessionChecker = Executors.newScheduledThreadPool(1);
+ try {
+ idleSessionChecker.scheduleAtFixedRate(cleanSessionTask, sessionTTLinSec + 10L, sweepIntervalInSec, TimeUnit.SECONDS);
+ acceptConnections(initializer, port);
+ } finally {
+ idleSessionChecker.shutdownNow();
+ }
+ }
+
+ private static void acceptConnections(@Nonnull MixServerInitializer initializer, int port)
+ throws InterruptedException {
+ final EventLoopGroup bossGroup = new NioEventLoopGroup(1);
+ final EventLoopGroup workerGroup = new NioEventLoopGroup();
+ try {
+ ServerBootstrap b = new ServerBootstrap();
+ b.option(ChannelOption.SO_KEEPALIVE, true);
+ b.group(bossGroup, workerGroup);
+ b.channel(NioServerSocketChannel.class);
+ b.handler(new LoggingHandler(LogLevel.INFO));
+ b.childHandler(initializer);
+
+ // Bind and start to accept incoming connections.
+ ChannelFuture f = b.bind(port).sync();
+
+ // Wait until the server socket is closed.
+ // In this example, this does not happen, but you can do that to gracefully
+ // shut down your server.
+ f.channel().closeFuture().sync();
+ } finally {
+ workerGroup.shutdownGracefully();
+ bossGroup.shutdownGracefully();
+ }
+ }
+
+}
diff --git a/src/main/hivemall/mix/server/MixServerHandler.java b/src/main/hivemall/mix/server/MixServerHandler.java
new file mode 100644
index 00000000..950ca212
--- /dev/null
+++ b/src/main/hivemall/mix/server/MixServerHandler.java
@@ -0,0 +1,127 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.server;
+
+import hivemall.mix.MixMessage;
+import hivemall.mix.MixMessage.MixEventName;
+import hivemall.mix.store.PartialArgminKLD;
+import hivemall.mix.store.PartialAverage;
+import hivemall.mix.store.PartialResult;
+import hivemall.mix.store.SessionStore;
+import io.netty.channel.ChannelHandler.Sharable;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+
+import java.util.concurrent.ConcurrentMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+@Sharable
+public final class MixServerHandler extends SimpleChannelInboundHandler {
+
+ @Nonnull
+ private final SessionStore sessionStore;
+ private final int syncThreshold;
+ private final float scale;
+
+ public MixServerHandler(@Nonnull SessionStore sessionStore, @Nonnegative int syncThreshold, @Nonnegative float scale) {
+ super();
+ this.sessionStore = sessionStore;
+ this.syncThreshold = syncThreshold;
+ this.scale = scale;
+ }
+
+ @Override
+ protected void channelRead0(ChannelHandlerContext ctx, MixMessage msg) throws Exception {
+ final MixEventName event = msg.getEvent();
+ switch(event) {
+ case average:
+ case argminKLD:
+ PartialResult partial = getPartialResult(msg);
+ mix(ctx, msg, partial);
+ break;
+ default:
+ throw new IllegalStateException("Unexpected event: " + event);
+ }
+ }
+
+ @Nonnull
+ private PartialResult getPartialResult(@Nonnull MixMessage msg) {
+ String groupID = msg.getGroupID();
+ if(groupID == null) {
+ throw new IllegalStateException("JobID is not set in the request message");
+ }
+ ConcurrentMap map = sessionStore.get(groupID);
+
+ Object feature = msg.getFeature();
+ PartialResult partial = map.get(feature);
+ if(partial == null) {
+ final MixEventName event = msg.getEvent();
+ switch(event) {
+ case average:
+ partial = new PartialAverage(scale);
+ break;
+ case argminKLD:
+ partial = new PartialArgminKLD(scale);
+ break;
+ default:
+ throw new IllegalStateException("Unexpected event: " + event);
+ }
+ PartialResult existing = map.putIfAbsent(feature, partial);
+ if(existing != null) {
+ partial = existing;
+ }
+ }
+ return partial;
+ }
+
+ private void mix(final ChannelHandlerContext ctx, final MixMessage requestMsg, final PartialResult partial) {
+ MixEventName event = requestMsg.getEvent();
+ Object feature = requestMsg.getFeature();
+ float weight = requestMsg.getWeight();
+ float covar = requestMsg.getCovariance();
+ short clock = requestMsg.getClock();
+ int deltaUpdates = requestMsg.getDeltaUpdates();
+
+ MixMessage responseMsg = null;
+ try {
+ partial.lock();
+
+ int diffClock = partial.diffClock(clock);
+ partial.add(weight, covar, clock, deltaUpdates);
+
+ if(diffClock >= syncThreshold) {// sync model if clock DIFF is above threshold
+ float averagedWeight = partial.getWeight();
+ float minCovar = partial.getMinCovariance();
+ short totalClock = partial.getClock();
+ responseMsg = new MixMessage(event, feature, averagedWeight, minCovar, totalClock, 0 /* deltaUpdates */);
+ }
+ } finally {
+ partial.unlock();
+ }
+
+ if(responseMsg != null) {
+ ctx.writeAndFlush(responseMsg);
+ }
+ }
+
+}
diff --git a/src/main/hivemall/mix/server/MixServerInitializer.java b/src/main/hivemall/mix/server/MixServerInitializer.java
new file mode 100644
index 00000000..2c2bd786
--- /dev/null
+++ b/src/main/hivemall/mix/server/MixServerInitializer.java
@@ -0,0 +1,57 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.server;
+
+import hivemall.mix.MixMessageDecoder;
+import hivemall.mix.MixMessageEncoder;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslContext;
+
+public final class MixServerInitializer extends ChannelInitializer {
+
+ private final MixServerHandler requestHandler;
+ private final SslContext sslCtx;
+
+ public MixServerInitializer(MixServerHandler msgHandler, SslContext sslCtx) {
+ if(msgHandler == null) {
+ throw new IllegalArgumentException();
+ }
+ this.requestHandler = msgHandler;
+ this.sslCtx = sslCtx;
+ }
+
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ ChannelPipeline pipeline = ch.pipeline();
+ if(sslCtx != null) {
+ pipeline.addLast(sslCtx.newHandler(ch.alloc()));
+ }
+
+ //ObjectEncoder encoder = new ObjectEncoder();
+ //ObjectDecoder decoder = new ObjectDecoder(4194304, ClassResolvers.cacheDisabled(null));
+ MixMessageEncoder encoder = new MixMessageEncoder();
+ MixMessageDecoder decoder = new MixMessageDecoder();
+ pipeline.addLast(decoder, encoder, requestHandler);
+ }
+
+}
diff --git a/src/main/hivemall/mix/store/PartialArgminKLD.java b/src/main/hivemall/mix/store/PartialArgminKLD.java
new file mode 100644
index 00000000..c11fb605
--- /dev/null
+++ b/src/main/hivemall/mix/store/PartialArgminKLD.java
@@ -0,0 +1,63 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.store;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.concurrent.GuardedBy;
+
+public final class PartialArgminKLD extends PartialResult {
+
+ private final float scale;
+
+ @GuardedBy("lock()")
+ private float sum_mean_div_covar;
+ @GuardedBy("lock()")
+ private float sum_inv_covar;
+
+ public PartialArgminKLD() {
+ this(1.f); // no scaling
+ }
+
+ public PartialArgminKLD(@Nonnegative float scale) {
+ super();
+ this.scale = scale;
+ this.sum_mean_div_covar = 0.f;
+ this.sum_inv_covar = 0.f;
+ }
+
+ @Override
+ public void add(float localWeight, float covar, short clock, int deltaUpdates) {
+ addWeight(localWeight, covar);
+ setMinCovariance(covar);
+ incrClock(clock);
+ }
+
+ protected void addWeight(float localWeight, float covar) {
+ this.sum_mean_div_covar += (localWeight / covar) / scale;
+ this.sum_inv_covar += (1.f / covar) / scale;
+ }
+
+ @Override
+ public float getWeight() {
+ return sum_mean_div_covar / sum_inv_covar;
+ }
+
+}
diff --git a/src/main/hivemall/mix/store/PartialAverage.java b/src/main/hivemall/mix/store/PartialAverage.java
new file mode 100644
index 00000000..069614d6
--- /dev/null
+++ b/src/main/hivemall/mix/store/PartialAverage.java
@@ -0,0 +1,65 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.store;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.concurrent.GuardedBy;
+
+public final class PartialAverage extends PartialResult {
+ public static final float DEFAULT_SCALE = 10;
+
+ private final float scale;
+
+ @GuardedBy("lock()")
+ private float scaledSumWeights;
+ @GuardedBy("lock()")
+ private short totalUpdates;
+
+ public PartialAverage() {
+ this(1.f); // no scaling
+ }
+
+ public PartialAverage(float scale) {
+ super();
+ this.scale = scale;
+ this.scaledSumWeights = 0.f;
+ this.totalUpdates = 0;
+ }
+
+ @Override
+ public void add(float localWeight, float covar, short clock, @Nonnegative int deltaUpdates) {
+ addWeight(localWeight, deltaUpdates);
+ setMinCovariance(covar);
+ incrClock(clock);
+ }
+
+ protected void addWeight(float localWeight, int deltaUpdates) {
+ scaledSumWeights += ((localWeight / scale) * deltaUpdates);
+ totalUpdates += deltaUpdates; // not deltaUpdates is in range (0,127]
+ assert (totalUpdates > 0) : totalUpdates;
+ }
+
+ @Override
+ public float getWeight() {
+ return (scaledSumWeights / totalUpdates) * scale;
+ }
+
+}
diff --git a/src/main/hivemall/mix/store/PartialResult.java b/src/main/hivemall/mix/store/PartialResult.java
new file mode 100644
index 00000000..713df960
--- /dev/null
+++ b/src/main/hivemall/mix/store/PartialResult.java
@@ -0,0 +1,75 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.store;
+
+import hivemall.utils.lock.Lock;
+import hivemall.utils.lock.TTASLock;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.concurrent.GuardedBy;
+
+public abstract class PartialResult {
+
+ private final Lock lock;
+
+ @GuardedBy("lock()")
+ protected float minCovariance;
+ @GuardedBy("lock()")
+ protected short totalClock;
+
+ public PartialResult() {
+ this.lock = new TTASLock();
+ }
+
+ public final void lock() {
+ lock.lock();
+ }
+
+ public final void unlock() {
+ lock.unlock();
+ }
+
+ public abstract void add(float localWeight, float covar, short clock, @Nonnegative int deltaUpdates);
+
+ public abstract float getWeight();
+
+ public final float getMinCovariance() {
+ return minCovariance;
+ }
+
+ protected final void setMinCovariance(float covar) {
+ this.minCovariance = Math.max(minCovariance, covar);
+ }
+
+ public final short getClock() {
+ return totalClock;
+ }
+
+ protected final void incrClock(short clock) {
+ totalClock += clock;
+ }
+
+ public final int diffClock(short clock) {
+ short diff = (short) (totalClock - clock);
+ return diff < 0 ? -diff : diff;
+ }
+
+}
diff --git a/src/main/hivemall/mix/store/SessionObject.java b/src/main/hivemall/mix/store/SessionObject.java
new file mode 100644
index 00000000..e1438ffb
--- /dev/null
+++ b/src/main/hivemall/mix/store/SessionObject.java
@@ -0,0 +1,58 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.store;
+
+import java.util.concurrent.ConcurrentMap;
+
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.ThreadSafe;
+
+@ThreadSafe
+public final class SessionObject {
+
+ @Nonnull
+ private final ConcurrentMap object;
+ private volatile long lastAccessed; // being accessed by multiple threads
+
+ public SessionObject(@Nonnull ConcurrentMap obj) {
+ if(obj == null) {
+ throw new IllegalArgumentException("obj is null");
+ }
+ this.object = obj;
+ }
+
+ @Nonnull
+ public ConcurrentMap get() {
+ return object;
+ }
+
+ /**
+ * @return last accessed time in msec
+ */
+ public long getLastAccessed() {
+ return lastAccessed;
+ }
+
+ public void touch() {
+ this.lastAccessed = System.currentTimeMillis();
+ }
+
+}
diff --git a/src/main/hivemall/mix/store/SessionStore.java b/src/main/hivemall/mix/store/SessionStore.java
new file mode 100644
index 00000000..c66504d9
--- /dev/null
+++ b/src/main/hivemall/mix/store/SessionStore.java
@@ -0,0 +1,74 @@
+package hivemall.mix.store;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.ThreadSafe;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+@ThreadSafe
+public final class SessionStore {
+ private static final Log logger = LogFactory.getLog(SessionStore.class);
+
+ private static final int EXPECTED_MODEL_SIZE = 4194305; /* 2^22+1=4194304+1=4194305 */
+
+ private final ConcurrentMap sessions;
+
+ public SessionStore() {
+ this.sessions = new ConcurrentHashMap();
+ }
+
+ @Nonnull
+ private ConcurrentMap getSessions() {
+ return sessions;
+ }
+
+ @Nonnull
+ public ConcurrentMap get(@Nonnull String groupID) {
+ SessionObject sessionObj = sessions.get(groupID);
+ if(sessionObj == null) {
+ ConcurrentMap map = new ConcurrentHashMap(EXPECTED_MODEL_SIZE);
+ sessionObj = new SessionObject(map);
+ SessionObject existing = sessions.putIfAbsent(groupID, sessionObj);
+ if(existing != null) {
+ sessionObj = existing;
+ }
+ }
+ ConcurrentMap map = sessionObj.get();
+ sessionObj.touch();
+ return map;
+ }
+
+ @ThreadSafe
+ public static final class IdleSessionSweeper implements Runnable {
+
+ private final ConcurrentMap sessions;
+ private final long ttl;
+
+ public IdleSessionSweeper(@Nonnull SessionStore sessionStore, @Nonnegative long ttlInMillis) {
+ this.sessions = sessionStore.getSessions();
+ this.ttl = ttlInMillis;
+ }
+
+ public void run() {
+ for(Map.Entry e : sessions.entrySet()) {
+ SessionObject sessionObj = e.getValue();
+ long lastAccessed = sessionObj.getLastAccessed();
+ long elapsedTime = System.currentTimeMillis() - lastAccessed;
+ if(elapsedTime > ttl) {
+ String key = e.getKey();
+ assert (key != null);
+ sessions.remove(key);
+ logger.info("Removed an idle session group: " + key);
+ }
+ }
+
+ }
+ }
+
+}
diff --git a/src/main/hivemall/regression/AROWRegressionUDTF.java b/src/main/hivemall/regression/AROWRegressionUDTF.java
index 910a6e67..09d41b8e 100644
--- a/src/main/hivemall/regression/AROWRegressionUDTF.java
+++ b/src/main/hivemall/regression/AROWRegressionUDTF.java
@@ -23,8 +23,8 @@
import hivemall.common.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.PredictionResult;
-import hivemall.io.WeightValue;
import hivemall.io.WeightValue.WeightValueWithCovar;
import java.util.Collection;
@@ -122,13 +122,13 @@ protected void update(final Collection> features, final float coeff, final flo
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = model.get(k);
- WeightValue new_w = getNewWeight(old_w, v, coeff, beta);
+ IWeightValue old_w = model.get(k);
+ IWeightValue new_w = getNewWeight(old_w, v, coeff, beta);
model.set(k, new_w);
}
}
- private static WeightValue getNewWeight(final WeightValue old, final float x, final float coeff, final float beta) {
+ private static IWeightValue getNewWeight(final IWeightValue old, final float x, final float coeff, final float beta) {
final float old_w;
final float old_cov;
if(old == null) {
diff --git a/src/main/hivemall/regression/OnlineRegressionUDTF.java b/src/main/hivemall/regression/OnlineRegressionUDTF.java
index e1caeb81..f70f9099 100644
--- a/src/main/hivemall/regression/OnlineRegressionUDTF.java
+++ b/src/main/hivemall/regression/OnlineRegressionUDTF.java
@@ -25,6 +25,7 @@
import static hivemall.HivemallConstants.STRING_TYPE_NAME;
import hivemall.LearnerBaseUDTF;
import hivemall.io.FeatureValue;
+import hivemall.io.IWeightValue;
import hivemall.io.PredictionModel;
import hivemall.io.PredictionResult;
import hivemall.io.WeightValue;
@@ -215,7 +216,7 @@ protected PredictionResult calcScoreAndVariance(Collection> features) {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = model.get(k);
+ IWeightValue old_w = model.get(k);
if(old_w == null) {
variance += (1.f * v * v);
} else {
@@ -260,7 +261,8 @@ protected void update(Collection> features, float coeff) {
}
@Override
- public void close() throws HiveException {
+ public final void close() throws HiveException {
+ super.close();
if(model != null) {
int numForwarded = 0;
if(useCovariance()) {
@@ -268,7 +270,7 @@ public void close() throws HiveException {
final Object[] forwardMapObj = new Object[3];
final FloatWritable fv = new FloatWritable();
final FloatWritable cov = new FloatWritable();
- final IMapIterator itor = model.entries();
+ final IMapIterator itor = model.entries();
while(itor.next() != -1) {
itor.getValue(probe);
if(!probe.isTouched()) {
@@ -287,7 +289,7 @@ public void close() throws HiveException {
final WeightValue probe = new WeightValue();
final Object[] forwardMapObj = new Object[2];
final FloatWritable fv = new FloatWritable();
- final IMapIterator itor = model.entries();
+ final IMapIterator itor = model.entries();
while(itor.next() != -1) {
itor.getValue(probe);
if(!probe.isTouched()) {
@@ -301,10 +303,11 @@ public void close() throws HiveException {
numForwarded++;
}
}
+ int numMixed = model.getNumMixed();
this.model = null;
- logger.info("Trained a prediction model using " + count
- + " training examples. Forwarded the prediction model of " + numForwarded
- + " rows");
+ logger.info("Trained a prediction model using " + count + " training examples"
+ + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
+ logger.info("Forwarded the prediction model of " + numForwarded + " rows");
}
}
diff --git a/src/main/hivemall/tools/mapred/JobConfGetsUDF.java b/src/main/hivemall/tools/mapred/JobConfGetsUDF.java
new file mode 100644
index 00000000..dcb3b53d
--- /dev/null
+++ b/src/main/hivemall/tools/mapred/JobConfGetsUDF.java
@@ -0,0 +1,61 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.tools.mapred;
+
+import static hivemall.utils.hadoop.WritableUtils.val;
+import hivemall.utils.hadoop.HadoopUtils;
+
+import javax.annotation.Nullable;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.MapredContext;
+import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
+import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+
+/**
+ * @since Hive 0.12.0
+ */
+@Description(name = "jobconf_gets", value = "_FUNC_() - Returns the value from JobConf")
+@UDFType(deterministic = false, stateful = true)
+public class JobConfGetsUDF extends UDF {
+
+ public Text evaluate() {
+ return evaluate(null);
+ }
+
+ public Text evaluate(@Nullable final String regexKey) {
+ MapredContext ctx = MapredContextAccessor.get();
+ if(ctx == null) {
+ throw new IllegalStateException("MapredContext is not set");
+ }
+ JobConf jobconf = ctx.getJobConf();
+ if(jobconf == null) {
+ throw new IllegalStateException("JobConf is not set");
+ }
+
+ String dumped = HadoopUtils.toString(jobconf, regexKey);
+ return val(dumped);
+ }
+
+}
diff --git a/src/main/hivemall/tools/mapred/JobIdUDF.java b/src/main/hivemall/tools/mapred/JobIdUDF.java
new file mode 100644
index 00000000..cb98d4a7
--- /dev/null
+++ b/src/main/hivemall/tools/mapred/JobIdUDF.java
@@ -0,0 +1,42 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.tools.mapred;
+
+import static hivemall.utils.hadoop.WritableUtils.val;
+import hivemall.utils.hadoop.HadoopUtils;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.io.Text;
+
+@Description(name = "jobid", value = "_FUNC_() - Returns the value of mapred.task.partition")
+@UDFType(deterministic = false, stateful = true)
+public class JobIdUDF extends UDF {
+
+ /**
+ * @since Hive 0.12.0
+ */
+ public Text evaluate() {
+ return val(HadoopUtils.getJobId());
+ }
+
+}
diff --git a/src/main/hivemall/utils/hadoop/HadoopUtils.java b/src/main/hivemall/utils/hadoop/HadoopUtils.java
index af163a9c..81f7f69d 100644
--- a/src/main/hivemall/utils/hadoop/HadoopUtils.java
+++ b/src/main/hivemall/utils/hadoop/HadoopUtils.java
@@ -28,6 +28,11 @@
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URI;
+import java.util.Iterator;
+import java.util.Map.Entry;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
@@ -39,6 +44,8 @@
import org.apache.hadoop.io.compress.CompressionInputStream;
import org.apache.hadoop.io.compress.Decompressor;
import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.JobID;
+import org.apache.hadoop.mapred.TaskID;
public final class HadoopUtils {
@@ -69,7 +76,7 @@ public static BufferedReader getBufferedReader(File file, MapredContext context)
}
}
- private static class BufferedReaderExt extends BufferedReader {
+ private static final class BufferedReaderExt extends BufferedReader {
private Decompressor decompressor;
@@ -89,14 +96,93 @@ public void close() throws IOException {
}
- public static int getTaskId() {
+ @Nonnull
+ public static String getJobId() {
MapredContext ctx = MapredContextAccessor.get();
+ if(ctx == null) {
+ throw new IllegalStateException("MapredContext is not set");
+ }
JobConf conf = ctx.getJobConf();
- int taskid = conf.getInt("mapred.task.partition", -1);
+ if(conf == null) {
+ throw new IllegalStateException("JobConf is not set");
+ }
+ String jobId = conf.get("mapred.job.id");
+ if(jobId == null) {
+ jobId = conf.get("mapreduce.job.id");
+ if(jobId == null) {
+ String queryId = conf.get("hive.query.id");
+ if(queryId != null) {
+ return queryId;
+ }
+ String taskidStr = conf.get("mapred.task.id");
+ if(taskidStr == null) {
+ throw new IllegalStateException("Cannot resolve jobId: " + toString(conf));
+ }
+ jobId = getJobIdFromTaskId(taskidStr);
+ }
+ }
+ return jobId;
+ }
+
+ @Nonnull
+ public static String getJobIdFromTaskId(@Nonnull String taskidStr) {
+ if(!taskidStr.startsWith("task_")) {// workaround for Tez
+ taskidStr = taskidStr.replace("task", "task_");
+ taskidStr = taskidStr.substring(0, taskidStr.lastIndexOf('_'));
+ }
+ TaskID taskId = TaskID.forName(taskidStr);
+ JobID jobId = taskId.getJobID();
+ return jobId.toString();
+ }
+
+ public static int getTaskId() {
+ MapredContext ctx = MapredContextAccessor.get();
+ if(ctx == null) {
+ throw new IllegalStateException("MapredContext is not set");
+ }
+ JobConf jobconf = ctx.getJobConf();
+ if(jobconf == null) {
+ throw new IllegalStateException("JobConf is not set");
+ }
+ int taskid = jobconf.getInt("mapred.task.partition", -1);
if(taskid == -1) {
- throw new IllegalStateException("mapred.task.partition is not set");
+ taskid = jobconf.getInt("mapreduce.task.partition", -1);
+ if(taskid == -1) {
+ throw new IllegalStateException("Both mapred.task.partition and mapreduce.task.partition are not set: "
+ + toString(jobconf));
+ }
}
return taskid;
}
+ @Nonnull
+ public static String toString(@Nonnull JobConf jobconf) {
+ return toString(jobconf, null);
+ }
+
+ @Nonnull
+ public static String toString(@Nonnull JobConf jobconf, @Nullable String regexKey) {
+ final Iterator> itor = jobconf.iterator();
+ boolean hasNext = itor.hasNext();
+ if(!hasNext) {
+ return "";
+ }
+ final StringBuilder buf = new StringBuilder(1024);
+ do {
+ Entry e = itor.next();
+ hasNext = itor.hasNext();
+ String k = e.getKey();
+ if(k == null) {
+ continue;
+ }
+ if(regexKey == null || k.matches(regexKey)) {
+ String v = e.getValue();
+ buf.append(k).append('=').append(v);
+ if(hasNext) {
+ buf.append(',');
+ }
+ }
+ } while(hasNext);
+ return buf.toString();
+ }
}
diff --git a/src/main/hivemall/utils/io/IOUtils.java b/src/main/hivemall/utils/io/IOUtils.java
new file mode 100644
index 00000000..f94b241a
--- /dev/null
+++ b/src/main/hivemall/utils/io/IOUtils.java
@@ -0,0 +1,40 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.utils.io;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+public final class IOUtils {
+
+ private IOUtils() {}
+
+ public static void closeQuietly(final Closeable channel) {
+ if(channel != null) {
+ try {
+ channel.close();
+ } catch (IOException e) {
+ ;
+ }
+ }
+ }
+
+}
diff --git a/src/main/hivemall/utils/lang/Primitives.java b/src/main/hivemall/utils/lang/Primitives.java
index 07ecd5d6..78236e96 100644
--- a/src/main/hivemall/utils/lang/Primitives.java
+++ b/src/main/hivemall/utils/lang/Primitives.java
@@ -24,29 +24,52 @@ public final class Primitives {
private Primitives() {}
- public static int parseInt(String s, int defaultValue) {
+ public static int toUnsignedShort(final short v) {
+ return v & 0xFFFF; // convert to range 0-65535 from -32768-32767.
+ }
+
+ public static short parseShort(final String s, final short defaultValue) {
+ if(s == null) {
+ return defaultValue;
+ }
+ return Short.parseShort(s);
+ }
+
+ public static int parseInt(final String s, final int defaultValue) {
if(s == null) {
return defaultValue;
}
return Integer.parseInt(s);
}
- public static float parseFloat(String s, float defaultValue) {
+ public static long parseLong(final String s, final long defaultValue) {
+ if(s == null) {
+ return defaultValue;
+ }
+ return Long.parseLong(s);
+ }
+
+ public static float parseFloat(final String s, final float defaultValue) {
if(s == null) {
return defaultValue;
}
return Float.parseFloat(s);
}
- public static boolean parseBoolean(String s, boolean defaultValue) {
+ public static boolean parseBoolean(final String s, final boolean defaultValue) {
if(s == null) {
return defaultValue;
}
return Boolean.parseBoolean(s);
}
-
- public static int compare(int x, int y) {
+
+ public static int compare(final int x, final int y) {
return (x < y) ? -1 : ((x == y) ? 0 : 1);
}
+ public static void putChar(final byte[] b, final int off, final char val) {
+ b[off + 1] = (byte) (val >>> 0);
+ b[off] = (byte) (val >>> 8);
+ }
+
}
diff --git a/src/main/hivemall/utils/lang/StringUtils.java b/src/main/hivemall/utils/lang/StringUtils.java
new file mode 100644
index 00000000..11ee3da1
--- /dev/null
+++ b/src/main/hivemall/utils/lang/StringUtils.java
@@ -0,0 +1,50 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.utils.lang;
+
+public final class StringUtils {
+
+ private StringUtils() {}
+
+ public static byte[] getBytes(final String s) {
+ final int len = s.length();
+ final byte[] b = new byte[len * 2];
+ for(int i = 0; i < len; i++) {
+ Primitives.putChar(b, i * 2, s.charAt(i));
+ }
+ return b;
+ }
+
+ public static String toString(byte[] b) {
+ return toString(b, 0, b.length);
+ }
+
+ public static String toString(byte[] b, int off, int len) {
+ final int clen = len >>> 1;
+ final char[] c = new char[clen];
+ for(int i = 0; i < clen; i++) {
+ final int j = off + (i << 1);
+ c[i] = (char) ((b[j + 1] & 0xFF) + ((b[j + 0]) << 8));
+ }
+ return new String(c);
+ }
+
+}
diff --git a/src/main/hivemall/utils/lock/Lock.java b/src/main/hivemall/utils/lock/Lock.java
new file mode 100644
index 00000000..c867fc30
--- /dev/null
+++ b/src/main/hivemall/utils/lock/Lock.java
@@ -0,0 +1,33 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.utils.lock;
+
+public interface Lock {
+
+ void lock();
+
+ boolean tryLock();
+
+ void unlock();
+
+ boolean isLocked();
+
+}
diff --git a/src/main/hivemall/utils/lock/TTASLock.java b/src/main/hivemall/utils/lock/TTASLock.java
new file mode 100644
index 00000000..bb5fb482
--- /dev/null
+++ b/src/main/hivemall/utils/lock/TTASLock.java
@@ -0,0 +1,65 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.utils.lock;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public final class TTASLock implements Lock {
+
+ private final AtomicBoolean state;
+
+ public TTASLock() {
+ this(false);
+ }
+
+ public TTASLock(boolean locked) {
+ this.state = new AtomicBoolean(locked);
+ }
+
+ @Override
+ public void lock() {
+ while(true) {
+ while(state.get())
+ ; // wait until the lock free
+ if(!state.getAndSet(true)) { // now try to acquire the lock
+ return;
+ }
+ }
+ }
+
+ @Override
+ public boolean tryLock() {
+ if(state.get()) {
+ return false;
+ }
+ return !state.getAndSet(true);
+ }
+
+ @Override
+ public void unlock() {
+ state.set(false);
+ }
+
+ @Override
+ public boolean isLocked() {
+ return state.get();
+ }
+}
diff --git a/src/main/hivemall/utils/net/NetUtils.java b/src/main/hivemall/utils/net/NetUtils.java
new file mode 100644
index 00000000..935968c9
--- /dev/null
+++ b/src/main/hivemall/utils/net/NetUtils.java
@@ -0,0 +1,57 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.utils.net;
+
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
+
+public final class NetUtils {
+
+ private NetUtils() {}
+
+ public static InetSocketAddress getInetSocketAddress(String endpointURI, int defaultPort) {
+ final int pos = endpointURI.indexOf(':');
+ if(pos == -1) {
+ InetAddress addr = getInetAddress(endpointURI);
+ return new InetSocketAddress(addr, defaultPort);
+ } else {
+ String host = endpointURI.substring(0, pos);
+ InetAddress addr = getInetAddress(host);
+ String portStr = endpointURI.substring(pos + 1);
+ int port = Integer.parseInt(portStr);
+ return new InetSocketAddress(addr, port);
+ }
+ }
+
+ public static InetAddress getInetAddress(final String addressOrName) {
+ try {
+ return InetAddress.getByName(addressOrName);
+ } catch (UnknownHostException e) {
+ throw new IllegalArgumentException("Cannot find InetAddress: " + addressOrName);
+ }
+ }
+
+ public static boolean isIPAddress(final String ip) {
+ return ip.matches("^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$");
+ }
+
+}
diff --git a/src/test/hivemall/common/FeatureValueTest.java b/src/test/hivemall/io/FeatureValueTest.java
similarity index 98%
rename from src/test/hivemall/common/FeatureValueTest.java
rename to src/test/hivemall/io/FeatureValueTest.java
index f9bc0757..4be2b06e 100644
--- a/src/test/hivemall/common/FeatureValueTest.java
+++ b/src/test/hivemall/io/FeatureValueTest.java
@@ -18,7 +18,7 @@
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
-package hivemall.common;
+package hivemall.io;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
diff --git a/src/test/hivemall/common/SpaceEfficientDenseModelTest.java b/src/test/hivemall/io/SpaceEfficientDenseModelTest.java
similarity index 72%
rename from src/test/hivemall/common/SpaceEfficientDenseModelTest.java
rename to src/test/hivemall/io/SpaceEfficientDenseModelTest.java
index 41eedaaa..24b53b0d 100644
--- a/src/test/hivemall/common/SpaceEfficientDenseModelTest.java
+++ b/src/test/hivemall/io/SpaceEfficientDenseModelTest.java
@@ -1,9 +1,6 @@
-package hivemall.common;
+package hivemall.io;
import static junit.framework.Assert.assertEquals;
-import hivemall.io.DenseModel;
-import hivemall.io.SpaceEfficientDenseModel;
-import hivemall.io.WeightValue;
import hivemall.utils.collections.IMapIterator;
import java.util.Random;
@@ -17,19 +14,22 @@ public void testGetSet() {
final int size = 1 << 12;
final SpaceEfficientDenseModel model1 = new SpaceEfficientDenseModel(size);
+ //model1.configureClock();
final DenseModel model2 = new DenseModel(size);
+ //model2.configureClock();
final Random rand = new Random();
for(int t = 0; t < 1000; t++) {
int i = rand.nextInt(size);
- float w = 65520f * rand.nextFloat();
- model1.setValue(i, w);
- model2.setValue(i, w);
+ float f = 65520f * rand.nextFloat();
+ IWeightValue w = new WeightValue(f);
+ model1.set(i, w);
+ model2.set(i, w);
}
assertEquals(model2.size(), model1.size());
- IMapIterator itor = model1.entries();
+ IMapIterator itor = model1.entries();
while(itor.next() != -1) {
int k = itor.getKey();
float expected = itor.getValue().get();
diff --git a/src/test/hivemall/mix/client/MixRequestRouterTest.java b/src/test/hivemall/mix/client/MixRequestRouterTest.java
new file mode 100644
index 00000000..449c1b83
--- /dev/null
+++ b/src/test/hivemall/mix/client/MixRequestRouterTest.java
@@ -0,0 +1,37 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.client;
+
+import hivemall.mix.NodeInfo;
+import junit.framework.Assert;
+
+import org.junit.Test;
+
+public class MixRequestRouterTest {
+
+ @Test
+ public void test() {
+ MixRequestRouter router = new MixRequestRouter("dm01.hpcc.jp:11212,yahoo.co.jp:11212,google.com");
+ NodeInfo[] nodes = router.getAllNodes();
+ Assert.assertEquals(3, nodes.length);
+ }
+
+}
diff --git a/src/test/hivemall/mix/server/MixServerTest.java b/src/test/hivemall/mix/server/MixServerTest.java
new file mode 100644
index 00000000..c747e7ad
--- /dev/null
+++ b/src/test/hivemall/mix/server/MixServerTest.java
@@ -0,0 +1,238 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.mix.server;
+
+import hivemall.io.DenseModel;
+import hivemall.io.PredictionModel;
+import hivemall.io.SparseModel;
+import hivemall.io.WeightValue;
+import hivemall.mix.MixMessage.MixEventName;
+import hivemall.mix.client.MixClient;
+import hivemall.utils.io.IOUtils;
+import hivemall.utils.lang.CommandLineUtils;
+
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import junit.framework.Assert;
+
+import org.apache.commons.cli.CommandLine;
+import org.junit.Test;
+
+public class MixServerTest {
+
+ @Test
+ public void testSimpleScenario() throws InterruptedException {
+ CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11212",
+ "-sync_threshold", "3" }, MixServer.getOptions());
+ MixServer server = new MixServer(cl);
+ ExecutorService serverExec = Executors.newSingleThreadExecutor();
+ serverExec.submit(server);
+
+ PredictionModel model = new DenseModel(16777216, false);
+ model.configureClock();
+ MixClient client = new MixClient(MixEventName.average, "testSimpleScenario", "localhost:11212", false, 2, model);
+ model.setUpdateHandler(client);
+
+ final Random rand = new Random(43);
+ for(int i = 0; i < 100000; i++) {
+ Integer feature = Integer.valueOf(rand.nextInt(100));
+ float weight = (float) rand.nextGaussian();
+ model.set(feature, new WeightValue(weight));
+ }
+
+ Thread.sleep(5 * 1000);
+ int numMixed = model.getNumMixed();
+ //System.out.println("number of mix events: " + numMixed);
+ Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0);
+
+ IOUtils.closeQuietly(client);
+ serverExec.shutdown();
+ }
+
+ @Test
+ public void testSSL() throws InterruptedException {
+ CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11213",
+ "-sync_threshold", "3", "-ssl" }, MixServer.getOptions());
+ MixServer server = new MixServer(cl);
+ ExecutorService serverExec = Executors.newSingleThreadExecutor();
+ serverExec.submit(server);
+
+ PredictionModel model = new DenseModel(16777216, false);
+ model.configureClock();
+ MixClient client = new MixClient(MixEventName.average, "testSSL", "localhost:11213", true, 2, model);
+ model.setUpdateHandler(client);
+
+ final Random rand = new Random(43);
+ for(int i = 0; i < 100000; i++) {
+ Integer feature = Integer.valueOf(rand.nextInt(100));
+ float weight = (float) rand.nextGaussian();
+ model.set(feature, new WeightValue(weight));
+ }
+
+ Thread.sleep(5 * 1000);
+ int numMixed = model.getNumMixed();
+ //System.out.println("number of mix events: " + numMixed);
+ Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0);
+
+ IOUtils.closeQuietly(client);
+ serverExec.shutdown();
+ }
+
+ @Test
+ public void testMultipleClients() throws InterruptedException {
+ CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11214",
+ "-sync_threshold", "3" }, MixServer.getOptions());
+ MixServer server = new MixServer(cl);
+ ExecutorService serverExec = Executors.newSingleThreadExecutor();
+ serverExec.submit(server);
+
+ Thread.sleep(500);// slight delay to boot a server
+
+ final int numClients = 5;
+ final ExecutorService clientsExec = Executors.newCachedThreadPool();
+ for(int i = 0; i < numClients; i++) {
+ clientsExec.submit(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ invokeClient("testMultipleClients", 11214);
+ } catch (InterruptedException e) {
+ Assert.fail(e.getMessage());
+ }
+ }
+ });
+ }
+ clientsExec.awaitTermination(10, TimeUnit.SECONDS);
+ clientsExec.shutdown();
+ serverExec.shutdown();
+ }
+
+ private static void invokeClient(String groupId, int serverPort) throws InterruptedException {
+ PredictionModel model = new DenseModel(16777216, false);
+ model.configureClock();
+ MixClient client = new MixClient(MixEventName.average, groupId, "localhost:" + serverPort, false, 2, model);
+ model.setUpdateHandler(client);
+
+ final Random rand = new Random(43);
+ for(int i = 0; i < 100000; i++) {
+ Integer feature = Integer.valueOf(rand.nextInt(100));
+ float weight = (float) rand.nextGaussian();
+ model.set(feature, new WeightValue(weight));
+ }
+
+ Thread.sleep(5 * 1000);
+
+ int numMixed = model.getNumMixed();
+ //System.out.println("number of mix events: " + numMixed);
+ Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0);
+
+ IOUtils.closeQuietly(client);
+ }
+
+ @Test
+ public void test2ClientsZeroOneSparseModel() throws InterruptedException {
+ CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11215",
+ "-sync_threshold", "30" }, MixServer.getOptions());
+ MixServer server = new MixServer(cl);
+ ExecutorService serverExec = Executors.newSingleThreadExecutor();
+ serverExec.submit(server);
+
+ Thread.sleep(500);// slight delay to boot a server
+
+ final ExecutorService clientsExec = Executors.newCachedThreadPool();
+ for(int i = 0; i < 2; i++) {
+ clientsExec.submit(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ invokeClient01("test2ClientsZeroOne", 11215, false);
+ } catch (InterruptedException e) {
+ Assert.fail(e.getMessage());
+ }
+ }
+ });
+ }
+ clientsExec.awaitTermination(30, TimeUnit.SECONDS);
+ clientsExec.shutdown();
+ serverExec.shutdown();
+ }
+
+ @Test
+ public void test2ClientsZeroOneDenseModel() throws InterruptedException {
+ CommandLine cl = CommandLineUtils.parseOptions(new String[] { "-port", "11215",
+ "-sync_threshold", "30" }, MixServer.getOptions());
+ MixServer server = new MixServer(cl);
+ ExecutorService serverExec = Executors.newSingleThreadExecutor();
+ serverExec.submit(server);
+
+ Thread.sleep(500);// slight delay to boot a server
+
+ final ExecutorService clientsExec = Executors.newCachedThreadPool();
+ for(int i = 0; i < 2; i++) {
+ clientsExec.submit(new Runnable() {
+ @Override
+ public void run() {
+ try {
+ invokeClient01("test2ClientsZeroOne", 11215, true);
+ } catch (InterruptedException e) {
+ Assert.fail(e.getMessage());
+ }
+ }
+ });
+ }
+ clientsExec.awaitTermination(30, TimeUnit.SECONDS);
+ clientsExec.shutdown();
+ serverExec.shutdown();
+ }
+
+ private static void invokeClient01(String groupId, int serverPort, boolean denseModel)
+ throws InterruptedException {
+ PredictionModel model = denseModel ? new DenseModel(100, false)
+ : new SparseModel(100, false);
+ model.configureClock();
+ MixClient client = new MixClient(MixEventName.average, groupId, "localhost:" + serverPort, false, 3, model);
+ model.setUpdateHandler(client);
+
+ final Random rand = new Random(43);
+ for(int i = 0; i < 1000000; i++) {
+ Integer feature = Integer.valueOf(rand.nextInt(100));
+ float weight = rand.nextFloat() >= 0.5f ? 1.f : 0.f;
+ model.set(feature, new WeightValue(weight));
+ }
+
+ Thread.sleep(5 * 1000);
+
+ int numMixed = model.getNumMixed();
+ //System.out.println("number of mix events: " + numMixed);
+ Assert.assertTrue("number of mix events: " + numMixed, numMixed > 0);
+
+ for(int i = 0; i < 100; i++) {
+ float w = model.getWeight(i);
+ Assert.assertEquals(0.5f, w, 0.1f);
+ }
+
+ IOUtils.closeQuietly(client);
+ }
+
+}
diff --git a/src/test/hivemall/utils/collections/OpenHashMapTest.java b/src/test/hivemall/utils/collections/OpenHashMapTest.java
index cc5d69aa..f6ce0087 100644
--- a/src/test/hivemall/utils/collections/OpenHashMapTest.java
+++ b/src/test/hivemall/utils/collections/OpenHashMapTest.java
@@ -33,7 +33,7 @@ public class OpenHashMapTest {
@Test
public void testPutAndGet() {
Map map = new OpenHashMap(16384);
- final int numEntries = 10000000;
+ final int numEntries = 1000000;
for(int i = 0; i < numEntries; i++) {
map.put(Integer.toString(i), i);
}
diff --git a/src/test/hivemall/utils/hadoop/HadoopUtilsTest.java b/src/test/hivemall/utils/hadoop/HadoopUtilsTest.java
new file mode 100644
index 00000000..e604b761
--- /dev/null
+++ b/src/test/hivemall/utils/hadoop/HadoopUtilsTest.java
@@ -0,0 +1,35 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2013-2014
+ * National Institute of Advanced Industrial Science and Technology (AIST)
+ * Registration Number: H25PRO-1520
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+package hivemall.utils.hadoop;
+
+import junit.framework.Assert;
+
+import org.junit.Test;
+
+public class HadoopUtilsTest {
+
+ @Test
+ public void testGetJobIdFromTaskId() {
+ String actual = HadoopUtils.getJobIdFromTaskId("task1407733647643_0188_m_000000_0");
+ Assert.assertEquals("job_1407733647643_0188", actual);
+ }
+
+}
diff --git a/target/hivemall-fat.jar b/target/hivemall-fat.jar
new file mode 100644
index 00000000..0fd9a61f
Binary files /dev/null and b/target/hivemall-fat.jar differ
diff --git a/target/hivemall-with-dependencies.jar b/target/hivemall-with-dependencies.jar
new file mode 100644
index 00000000..8ff806ef
Binary files /dev/null and b/target/hivemall-with-dependencies.jar differ
diff --git a/target/hivemall.jar b/target/hivemall.jar
index 8d1a48be..34f5150b 100644
Binary files a/target/hivemall.jar and b/target/hivemall.jar differ