diff --git a/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/IrisPredictor.java b/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/IrisPredictor.java index 63547a88..6237c0d3 100644 --- a/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/IrisPredictor.java +++ b/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/IrisPredictor.java @@ -28,6 +28,7 @@ import com.spotify.zoltar.Predictors; import com.spotify.zoltar.featran.FeatranExtractFns; import com.spotify.zoltar.metrics.PredictorMetrics; +import com.spotify.zoltar.tf.TensorFlowLoader; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; @@ -44,14 +45,17 @@ public final class IrisPredictor { public static Predictor create(final ModelConfig modelConfig, final PredictorMetrics metrics) throws IOException { + final FeatureSpec irisFeatureSpec = IrisFeaturesSpec.irisFeaturesSpec(); final String settings = new String(Files.readAllBytes(Paths.get(modelConfig.settingsUri()))); final ExtractFn extractFn = FeatranExtractFns.example(irisFeatureSpec, settings); + final TensorFlowLoader modelLoader = TensorFlowLoader + .create(modelConfig.modelUri().toString(), modelConfig.modelLoaderExecutor()); final String[] ops = new String[]{"linear/head/predictions/class_ids"}; return Predictors.tensorFlow( - modelConfig.modelUri().toString(), + modelLoader, extractFn, tensors -> tensors.get(ops[0]).longValue()[0], ops, diff --git a/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/ModelConfig.java b/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/ModelConfig.java index dbfa9d69..1d65f400 100644 --- a/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/ModelConfig.java +++ b/examples/apollo-service-example/src/main/java/com/spotify/zoltar/examples/apollo/ModelConfig.java @@ -24,6 +24,8 @@ import com.typesafe.config.Config; import java.net.URI; import java.net.URISyntaxException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; /** * Immutable object that contains model / feature extraction configuration properties. @@ -37,6 +39,11 @@ public abstract class ModelConfig { /** settings URI path. */ public abstract URI settingsUri(); + public Executor modelLoaderExecutor() { + final int threads = Runtime.getRuntime().availableProcessors(); + return Executors.newFixedThreadPool(threads); + } + /** * Creates a {@link ModelConfig} create a {@link Config}. */ diff --git a/examples/batch-predictor/src/main/java/com/spotify/zoltar/examples/batch/BatchPredictorExample.java b/examples/batch-predictor/src/main/java/com/spotify/zoltar/examples/batch/BatchPredictorExample.java index 11036390..bb63f781 100644 --- a/examples/batch-predictor/src/main/java/com/spotify/zoltar/examples/batch/BatchPredictorExample.java +++ b/examples/batch-predictor/src/main/java/com/spotify/zoltar/examples/batch/BatchPredictorExample.java @@ -27,7 +27,6 @@ import com.spotify.zoltar.Predictor; import com.spotify.zoltar.PredictorBuilder; import com.spotify.zoltar.Predictors; -import com.spotify.zoltar.loaders.Preloader; import java.time.Duration; import java.util.List; import java.util.concurrent.CompletionStage; @@ -43,9 +42,7 @@ class BatchPredictorExample implements Predictor, List> { private PredictorBuilder, List, List> predictorBuilder; BatchPredictorExample() { - final ModelLoader modelLoader = ModelLoader - .lift(DummyModel::new) - .with(Preloader.preload(Duration.ofMinutes(1))); + final ModelLoader modelLoader = ModelLoader.loaded(new DummyModel()); final BatchExtractFn batchExtractFn = BatchExtractFn.lift((Function) input -> (float) input / 10); diff --git a/examples/custom-metrics/src/main/java/com/spotify/zoltar/examples/metrics/CustomMetricsExample.java b/examples/custom-metrics/src/main/java/com/spotify/zoltar/examples/metrics/CustomMetricsExample.java index 20d5cec5..a8d1f2ec 100644 --- a/examples/custom-metrics/src/main/java/com/spotify/zoltar/examples/metrics/CustomMetricsExample.java +++ b/examples/custom-metrics/src/main/java/com/spotify/zoltar/examples/metrics/CustomMetricsExample.java @@ -37,7 +37,6 @@ import com.spotify.zoltar.PredictorBuilder; import com.spotify.zoltar.Predictors; import com.spotify.zoltar.Vector; -import com.spotify.zoltar.loaders.Preloader; import com.spotify.zoltar.metrics.FeatureExtractorMetrics; import com.spotify.zoltar.metrics.Instrumentations; import com.spotify.zoltar.metrics.PredictFnMetrics; @@ -165,9 +164,7 @@ public void extraction(final List> vectors) { } CustomMetricsExample(final SemanticMetricRegistry metricRegistry, final MetricId metricId) { - final ModelLoader modelLoader = ModelLoader - .lift(DummyModel::new) - .with(Preloader.preload(Duration.ofMinutes(1))); + final ModelLoader modelLoader = ModelLoader.loaded(new DummyModel()); final ExtractFn extractFn = ExtractFn.lift(input -> (float) input / 10); final PredictFn predictFn = (model, vectors) -> { return vectors.stream() diff --git a/zoltar-api/src/main/java/com/spotify/zoltar/Models.java b/zoltar-api/src/main/java/com/spotify/zoltar/Models.java index be1577e6..87dd8ff6 100644 --- a/zoltar-api/src/main/java/com/spotify/zoltar/Models.java +++ b/zoltar-api/src/main/java/com/spotify/zoltar/Models.java @@ -25,6 +25,9 @@ import com.spotify.zoltar.tf.TensorFlowLoader; import com.spotify.zoltar.tf.TensorFlowModel; import com.spotify.zoltar.xgboost.XGBoostLoader; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.util.concurrent.Executor; import javax.annotation.Nullable; import org.tensorflow.Graph; import org.tensorflow.framework.ConfigProto; @@ -45,9 +48,10 @@ private Models() { * * @param modelUri should point to serialized XGBoost model file, can be a URI to a local * filesystem, resource, GCS etc. + * @param executor the executor to use for asynchronous execution. */ - public static XGBoostLoader xgboost(final String modelUri) { - return XGBoostLoader.create(modelUri); + public static XGBoostLoader xgboost(final String modelUri, final Executor executor) { + return XGBoostLoader.create(modelUri, executor); } /** @@ -56,9 +60,12 @@ public static XGBoostLoader xgboost(final String modelUri) { * @param id model id @{link Model.Id}. * @param modelUri should point to serialized XGBoost model file, can be a URI to a local * filesystem, resource, GCS etc. + * @param executor the executor to use for asynchronous execution. */ - public static XGBoostLoader xgboost(final Model.Id id, final String modelUri) { - return XGBoostLoader.create(id, modelUri); + public static XGBoostLoader xgboost(final Model.Id id, + final String modelUri, + final Executor executor) { + return XGBoostLoader.create(id, modelUri, executor); } /** @@ -67,9 +74,10 @@ public static XGBoostLoader xgboost(final Model.Id id, final String modelUri) { * @param modelUri should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. + * @param executor the executor to use for asynchronous execution. */ - public static TensorFlowLoader tensorFlow(final String modelUri) { - return TensorFlowLoader.create(modelUri); + public static TensorFlowLoader tensorFlow(final String modelUri, final Executor executor) { + return TensorFlowLoader.create(modelUri, executor); } /** @@ -79,9 +87,12 @@ public static TensorFlowLoader tensorFlow(final String modelUri) { * @param modelUri should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. + * @param executor the executor to use for asynchronous execution. */ - public static TensorFlowLoader tensorFlow(final Model.Id id, final String modelUri) { - return TensorFlowLoader.create(id, modelUri); + public static TensorFlowLoader tensorFlow(final Model.Id id, + final String modelUri, + final Executor executor) { + return TensorFlowLoader.create(id, modelUri, executor); } /** @@ -91,10 +102,12 @@ public static TensorFlowLoader tensorFlow(final Model.Id id, final String modelU * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. * @param options TensorFlow options, see {@link TensorFlowModel.Options}. + * @param executor the executor to use for asynchronous execution. */ public static TensorFlowLoader tensorFlow(final String modelUri, - final TensorFlowModel.Options options) { - return TensorFlowLoader.create(modelUri, options); + final TensorFlowModel.Options options, + final Executor executor) { + return TensorFlowLoader.create(modelUri, options, executor); } /** @@ -105,11 +118,13 @@ public static TensorFlowLoader tensorFlow(final String modelUri, * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. * @param options TensorFlow options, see {@link TensorFlowModel.Options}. + * @param executor the executor to use for asynchronous execution. */ public static TensorFlowLoader tensorFlow(final Model.Id id, final String modelUri, - final TensorFlowModel.Options options) { - return TensorFlowLoader.create(id, modelUri, options); + final TensorFlowModel.Options options, + final Executor executor) { + return TensorFlowLoader.create(id, modelUri, options, executor); } /** @@ -119,12 +134,14 @@ public static TensorFlowLoader tensorFlow(final Model.Id id, * local filesystem, resource, GCS etc. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ public static TensorFlowGraphLoader tensorFlowGraph( final String modelUri, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return TensorFlowGraphLoader.create(modelUri, config, prefix); + @Nullable final String prefix, + final Executor executor) { + return TensorFlowGraphLoader.create(modelUri, config, prefix, executor); } /** @@ -135,13 +152,15 @@ public static TensorFlowGraphLoader tensorFlowGraph( * local filesystem, resource, GCS etc. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ public static TensorFlowGraphLoader tensorFlowGraph( final Model.Id id, final String modelUri, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return TensorFlowGraphLoader.create(id, modelUri, config, prefix); + @Nullable final String prefix, + final Executor executor) { + return TensorFlowGraphLoader.create(id, modelUri, config, prefix, executor); } /** @@ -150,12 +169,14 @@ public static TensorFlowGraphLoader tensorFlowGraph( * @param graphDef byte array representing the TensorFlow {@link Graph} definition. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ public static TensorFlowGraphLoader tensorFlowGraph( final byte[] graphDef, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return TensorFlowGraphLoader.create(graphDef, config, prefix); + @Nullable final String prefix, + final Executor executor) { + return TensorFlowGraphLoader.create(graphDef, config, prefix, executor); } /** @@ -165,13 +186,15 @@ public static TensorFlowGraphLoader tensorFlowGraph( * @param graphDef byte array representing the TensorFlow {@link Graph} definition. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ public static TensorFlowGraphLoader tensorFlowGraph( final Model.Id id, final byte[] graphDef, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return TensorFlowGraphLoader.create(id, graphDef, config, prefix); + @Nullable final String prefix, + final Executor executor) { + return TensorFlowGraphLoader.create(id, graphDef, config, prefix, executor); } /** @@ -180,7 +203,8 @@ public static TensorFlowGraphLoader tensorFlowGraph( * @param id model id. Id needs to be in the following format: * projects/$projectId/models/$modelId/version/$versionId */ - public static MlEngineLoader mlEngine(final Model.Id id) { + public static MlEngineLoader mlEngine(final Model.Id id) + throws IOException, GeneralSecurityException { return MlEngineLoader.create(id); } @@ -193,7 +217,8 @@ public static MlEngineLoader mlEngine(final Model.Id id) { */ public static MlEngineLoader mlEngine(final String projectId, final String modelId, - final String versionId) { + final String versionId) + throws IOException, GeneralSecurityException { return MlEngineLoader.create(projectId, modelId, versionId); } } diff --git a/zoltar-api/src/main/java/com/spotify/zoltar/Predictors.java b/zoltar-api/src/main/java/com/spotify/zoltar/Predictors.java index 8688348b..b3a68c9d 100644 --- a/zoltar-api/src/main/java/com/spotify/zoltar/Predictors.java +++ b/zoltar-api/src/main/java/com/spotify/zoltar/Predictors.java @@ -26,7 +26,6 @@ import com.spotify.zoltar.metrics.Instrumentations; import com.spotify.zoltar.metrics.PredictorMetrics; import com.spotify.zoltar.tf.JTensor; -import com.spotify.zoltar.tf.TensorFlowLoader; import com.spotify.zoltar.tf.TensorFlowModel; import com.spotify.zoltar.tf.TensorFlowPredictFn; import java.util.Map; @@ -161,8 +160,7 @@ public static , InputT, VectorT, ValueT> PredictorBuilde * @param extractFn a feature extract function to use to transform input into extracted * features. * @param predictFn a prediction function to perform prediction with {@link AsyncPredictFn}. - * @param metrics a predictor metrics implementation - * {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. + * @param metrics a predictor metrics implementation {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. * @param underlying type of the {@link Model}. * @param type of the input to the {@link FeatureExtractor}. * @param type of the output from {@link FeatureExtractor}. @@ -186,8 +184,7 @@ public static , InputT, VectorT, ValueT> PredictorBuilde * @param modelLoader model loader that loads the model to perform prediction on. * @param featureExtractor a feature extractor to use to transform input into extracted features. * @param predictFn a prediction function to perform prediction with {@link PredictFn}. - * @param metrics a predictor metrics implementation - * {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. + * @param metrics a predictor metrics implementation {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. * @param underlying type of the {@link Model}. * @param type of the input to the {@link FeatureExtractor}. * @param type of the output from {@link FeatureExtractor}. @@ -212,8 +209,7 @@ public static , InputT, VectorT, ValueT> PredictorBuilde * @param featureExtractor a feature extractor to use to transform input into extracted features. * @param predictFn a prediction function to perform prediction with {@link * AsyncPredictFn}. - * @param metrics a predictor metrics implementation - * {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. + * @param metrics a predictor metrics implementation {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. * @param underlying type of the {@link Model}. * @param type of the input to the {@link FeatureExtractor}. * @param type of the output from {@link FeatureExtractor}. @@ -233,46 +229,47 @@ public static , InputT, VectorT, ValueT> PredictorBuilde /** * Returns a TensorFlow Predictor. * - * @param modelUri should point to a directory of the saved TensorFlow {@link + * @param modelLoader should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, * resource, GCS etc. * @param extractFn a feature extract function to use to transform input into extracted * features. * @param outTensorExtractor function to extract the output value from a {@link JTensor}. - * @param fetchOps operations to fetch. + * @param fetchOps operations to fetch. * @param type of the input to the {@link FeatureExtractor}. * @param type of the prediction result. */ public static Predictor tensorFlow( - final String modelUri, + final ModelLoader modelLoader, final ExtractFn extractFn, final Function, ValueT> outTensorExtractor, final String... fetchOps) { - return tensorFlow(modelUri, FeatureExtractor.create(extractFn), outTensorExtractor, fetchOps); + return tensorFlow(modelLoader, FeatureExtractor.create(extractFn), outTensorExtractor, + fetchOps); } /** * Returns a TensorFlow Predictor. * - * @param modelUri should point to a directory of the saved TensorFlow {@link + * @param modelLoader should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, * resource, GCS etc. * @param extractFn a feature extract function to use to transform input into extracted * features. * @param outTensorExtractor function to extract the output value from a {@link JTensor}. - * @param fetchOps operations to fetch. + * @param fetchOps operations to fetch. * @param metrics a predictor metrics implementation * {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. * @param type of the input to the {@link FeatureExtractor}. * @param type of the prediction result. */ public static Predictor tensorFlow( - final String modelUri, + final ModelLoader modelLoader, final ExtractFn extractFn, final Function, ValueT> outTensorExtractor, final String[] fetchOps, final PredictorMetrics metrics) { - return tensorFlow(modelUri, + return tensorFlow(modelLoader, FeatureExtractor.create(extractFn), outTensorExtractor, fetchOps, @@ -282,23 +279,21 @@ public static Predictor tensorFlow( /** * Returns a TensorFlow Predictor. * - * @param modelUri should point to a directory of the saved TensorFlow {@link + * @param modelLoader should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, * resource, GCS etc. * @param featureExtractor a feature extractor to use to transform input into extracted * features. * @param outTensorExtractor function to extract the output value from a {@link JTensor}. - * @param fetchOps operations to fetch. + * @param fetchOps operations to fetch. * @param type of the input to the {@link FeatureExtractor}. * @param type of the prediction result. */ public static Predictor tensorFlow( - final String modelUri, + final ModelLoader modelLoader, final FeatureExtractor featureExtractor, final Function, ValueT> outTensorExtractor, final String... fetchOps) { - final ModelLoader modelLoader = - TensorFlowLoader.create(modelUri); final TensorFlowPredictFn predictFn = TensorFlowPredictFn.example(outTensorExtractor, fetchOps); @@ -308,26 +303,24 @@ public static Predictor tensorFlow( /** * Returns a TensorFlow Predictor. * - * @param modelUri should point to a directory of the saved TensorFlow {@link + * @param modelLoader should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, * resource, GCS etc. * @param featureExtractor a feature extractor to use to transform input into extracted * features. * @param outTensorExtractor function to extract the output value from a {@link JTensor}. - * @param fetchOps operations to fetch. + * @param fetchOps operations to fetch. * @param metrics a predictor metrics implementation * {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. * @param type of the input to the {@link FeatureExtractor}. * @param type of the prediction result. */ public static Predictor tensorFlow( - final String modelUri, + final ModelLoader modelLoader, final FeatureExtractor featureExtractor, final Function, ValueT> outTensorExtractor, final String[] fetchOps, final PredictorMetrics metrics) { - final ModelLoader modelLoader = - TensorFlowLoader.create(modelUri); final TensorFlowPredictFn predictFn = TensorFlowPredictFn.example(outTensorExtractor, fetchOps); @@ -338,20 +331,20 @@ public static Predictor tensorFlow( * Returns a TensorFlow Predictor. Assumes feature extraction is embedded in the model via * Tensorflow Transform, so no extractFn is needed and the input type must be Example. * - * @param modelUri should point to a directory of the saved TensorFlow {@link + * @param modelLoader should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, * resource, GCS etc. * @param outTensorExtractor function to extract the output value from a {@link JTensor}. - * @param fetchOps operations to fetch. + * @param fetchOps operations to fetch. * @param metrics a predictor metrics implementation * {@link com.spotify.zoltar.metrics.semantic.SemanticPredictMetrics}. * @param type of the prediction result. */ public static Predictor tensorFlow( - final String modelUri, + final ModelLoader modelLoader, final Function, ValueT> outTensorExtractor, final String[] fetchOps, final PredictorMetrics metrics) { - return tensorFlow(modelUri, ExtractFn.identity(), outTensorExtractor, fetchOps, metrics); + return tensorFlow(modelLoader, ExtractFn.identity(), outTensorExtractor, fetchOps, metrics); } } diff --git a/zoltar-core/src/main/java/com/spotify/zoltar/ModelLoader.java b/zoltar-core/src/main/java/com/spotify/zoltar/ModelLoader.java index 44c007f7..cea93a1a 100644 --- a/zoltar-core/src/main/java/com/spotify/zoltar/ModelLoader.java +++ b/zoltar-core/src/main/java/com/spotify/zoltar/ModelLoader.java @@ -25,6 +25,7 @@ import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.function.Function; @@ -48,19 +49,141 @@ interface ThrowableSupplier> { } /** - * Lifts a supplier into a {@link ModelLoader}. + * PreLoader is model loader that calls {@link ModelLoader#get()} allowing model preloading. + * + * @param Model instance type. + */ + @FunctionalInterface + interface PreLoader> extends ModelLoader { + + /** + * Returns a blocking {@link ModelLoader}. Blocks till the model is loaded or a {@link Duration} + * is met. + * + * @param supplier model supplier. + * @param duration Amount of time that it should wait, if necessary, for model to be loaded. + * @param executor the executor to use for asynchronous execution. + * @param Underlying model instance. + */ + static > PreLoader preload(final ThrowableSupplier supplier, + final Duration duration, + final Executor executor) + throws InterruptedException, ExecutionException, TimeoutException { + return preload(ModelLoader.load(supplier, executor), duration)::get; + } + + /** + * Returns a blocking {@link PreLoader}. Blocks till the model is loaded or a {@link Duration} + * is met. + * + * @param loader model loader. + * @param duration Amount of time that it should wait, if necessary, for model to be loaded. + * @param Underlying model instance. + */ + static > PreLoader preload(final ModelLoader loader, + final Duration duration) + throws InterruptedException, ExecutionException, TimeoutException { + return ModelLoader.loaded(loader.get(duration))::get; + } + + /** + * Returns a blocking {@link PreLoader}. Blocks till the model is loaded or a {@link Duration} + * is met. + * + * @param duration Amount of time that it should wait, if necessary, for model to be loaded. + */ + static > Function, PreLoader> preload( + final Duration duration) { + return loader -> { + try { + return preload(loader, duration); + } catch (final Exception e) { + final CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(e); + + return () -> failed; + } + }; + } + + } + + /** + * ConsLoader is a constant {@link ModelLoader}. + * + * @param Model instance type. + */ + @FunctionalInterface + interface ConsLoader> extends ModelLoader { + + /** + * Creates a {@link ModelLoader} with an already loaded model. + * + * @param model Underlying model instance. + */ + static > ConsLoader cons(final M model) { + final CompletableFuture m = CompletableFuture.completedFuture(model); + return () -> m; + } + + } + + /** + * Creates a {@link ModelLoader} with an already loaded model. + * + * @param model Underlying model instance. + */ + static > ModelLoader loaded(final M model) { + return ConsLoader.cons(model); + } + + /** + * Returns a blocking {@link ModelLoader}. Blocks till the model is loaded or a {@link Duration} + * is met. + * + * @param supplier model supplier. + * @param duration Amount of time that it should wait, if necessary, for model to be loaded. + * @param executor the executor to use for asynchronous execution. + * @param Underlying model instance. + */ + static > PreLoader preload(final ThrowableSupplier supplier, + final Duration duration, + final Executor executor) + throws InterruptedException, ExecutionException, TimeoutException { + return PreLoader.preload(supplier, duration, executor); + } + + /** + * Returns a blocking {@link PreLoader}. Blocks till the model is loaded or a {@link Duration} + * is met. + * + * @param loader model loader. + * @param duration Amount of time that it should wait, if necessary, for model to be loaded. + * @param Underlying model instance. + */ + static > ModelLoader preload(final ModelLoader loader, + final Duration duration) + throws InterruptedException, ExecutionException, TimeoutException { + return PreLoader.preload(loader, duration); + } + + /** + * Create a {@link ModelLoader} that loads the supplied model asynchronously. * * @param supplier model supplier. - * @param Underlying model instance. + * @param Underlying model instance. */ - static > ModelLoader lift(final ThrowableSupplier supplier) { - return () -> CompletableFuture.supplyAsync(() -> { + static > ModelLoader load(final ThrowableSupplier supplier, + final Executor executor) { + final CompletableFuture future = CompletableFuture.supplyAsync(() -> { try { return supplier.get(); } catch (final Exception e) { throw new CompletionException(e); } - }); + }, executor); + + return () -> future; } /** diff --git a/zoltar-core/src/main/java/com/spotify/zoltar/Predict.java b/zoltar-core/src/main/java/com/spotify/zoltar/Predict.java new file mode 100644 index 00000000..920e4877 --- /dev/null +++ b/zoltar-core/src/main/java/com/spotify/zoltar/Predict.java @@ -0,0 +1,163 @@ +/*- + * -\-\- + * zoltar-core + * -- + * Copyright (C) 2016 - 2018 Spotify AB + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +package com.spotify.zoltar; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +/** + * Prediction functional interface. Allows to define prediction function via lambda. + * + * @param underlying type of the {@link Model}. + * @param type of the feature extraction input. + * @param type of the feature extraction output. + * @param type of the prediction output. + */ +@FunctionalInterface +interface Predict, InputT, VectorT, ValueT> { + + /** + * ConsPredict is a constant {@link Predict}. + * + * @param underlying type of the {@link Model}. + * @param type of the feature extraction input. + * @param type of the feature extraction output. + * @param type of the prediction output. + */ + @FunctionalInterface + interface ConsPredict, InputT, VectorT, ValueT> + extends Predict { + + /** + * Creates a {@link Predict} with precomputed predictions. + */ + static , I, V, O> ConsPredict cons( + final List> predictions) { + final CompletableFuture>> p = + CompletableFuture.completedFuture(predictions); + return (m, v) -> p; + } + } + + /** + * ConsPredict is a constant {@link Predict}. + * + * @param underlying type of the {@link Model}. + * @param type of the feature extraction input. + * @param type of the feature extraction output. + * @param type of the prediction output. + */ + @FunctionalInterface + interface TimeoutPredict, InputT, VectorT, ValueT> + extends Predict { + + /** + * Creates a {@link Predict} that will timeout if time exceeds {@link Duration}. + */ + static , I, V, O> TimeoutPredict timeout( + final Predict predict, + final Duration duration, + final Executor executor) { + return Predict.predict((m, v) -> { + try { + return predict + .apply(m, v) + .toCompletableFuture() + .get(duration.toMillis(), TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, executor)::apply; + } + } + + /** + * Synchronous prediction functional interface. Allows to define prediction function via lambda. + * + * @param type of the feature extraction input. + * @param type of the feature extraction output. + * @param type of the prediction output. + */ + @FunctionalInterface + interface PredictFn, InputT, VectorT, ValueT> { + + /** + * The functional interface. Your function/lambda takes model and features after extractions as + * input, should perform a prediction and return the predictions. + * + * @param model model to perform prediction on. + * @param vectors extracted features. + * @return predictions ({@link Prediction}). + */ + List> apply(ModelT model, List> vectors); + } + + /** + * Creates a {@link Predict} with precomputed predictions. + */ + static , I, V, O> Predict predicted( + final List> predictions) { + return ConsPredict.cons(predictions); + } + + /** + * Creates a {@link Predict} that will execute the supplied Function asynchronously. + */ + static , I, V, O> Predict predict( + final PredictFn fn, + final Executor executor) { + return (model, vectors) -> CompletableFuture.supplyAsync(() -> { + return fn.apply(model, vectors); + }, executor); + } + + /** + * Creates a {@link Predict} that will timeout if time exceeds {@link Duration}. + */ + static , I, V, O> Predict predict( + final Predict predict, + final Duration duration, + final Executor executor) { + return TimeoutPredict.timeout(predict, duration, executor); + } + + /** + * The functional interface. Your function/lambda takes model and features after extractions as + * input, should perform a prediction and return the predictions. + * + * @param model model to perform prediction on. + * @param vectors extracted features. + * @return predictions ({@link Prediction}). + */ + CompletionStage>> apply( + ModelT model, + List> vectors); + + default > C with( + final Function, C> fn) { + return fn.apply(this); + } +} diff --git a/zoltar-core/src/main/java/com/spotify/zoltar/loaders/ModelMemoizer.java b/zoltar-core/src/main/java/com/spotify/zoltar/loaders/ModelMemoizer.java deleted file mode 100644 index 1e0925d8..00000000 --- a/zoltar-core/src/main/java/com/spotify/zoltar/loaders/ModelMemoizer.java +++ /dev/null @@ -1,66 +0,0 @@ -/*- - * -\-\- - * zoltar-core - * -- - * Copyright (C) 2016 - 2018 Spotify AB - * -- - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * -/-/- - */ - -package com.spotify.zoltar.loaders; - -import com.spotify.zoltar.Model; -import com.spotify.zoltar.ModelLoader; -import java.util.Objects; -import java.util.concurrent.CompletionStage; -import java.util.concurrent.atomic.AtomicReference; - -/** - * Memoizes the result of the supplied {@link ModelLoader}. - * - * @param Model instance type. - */ -@FunctionalInterface -public interface ModelMemoizer> extends ModelLoader { - - /** - * Creates a memoized model loader. - * - * @param loader ModelLoader to be memoized. - * @param Model instance type. - * @return Memoized loader. - */ - static > ModelMemoizer memoize(final ModelLoader loader) { - // AtomicReference can be updated atomically and offers volatile write/read semantics. - // However, in this case it's just being used as a container for the CompletionStage. - final AtomicReference> value = new AtomicReference<>(); - return () -> { - CompletionStage val = value.get(); - if (val == null) { - // we want to avoid .get() being called several times from the different threads - // because it can be very expensive. - synchronized (value) { - val = value.get(); - if (val == null) { - val = Objects.requireNonNull(loader.get()); - value.set(val); - } - } - } - - return val; - }; - } - -} diff --git a/zoltar-core/src/main/java/com/spotify/zoltar/loaders/Preloader.java b/zoltar-core/src/main/java/com/spotify/zoltar/loaders/Preloader.java deleted file mode 100644 index 71f7953c..00000000 --- a/zoltar-core/src/main/java/com/spotify/zoltar/loaders/Preloader.java +++ /dev/null @@ -1,78 +0,0 @@ -/*- - * -\-\- - * zoltar-core - * -- - * Copyright (C) 2016 - 2018 Spotify AB - * -- - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * -/-/- - */ - -package com.spotify.zoltar.loaders; - -import com.spotify.zoltar.Model; -import com.spotify.zoltar.ModelLoader; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; -import java.util.function.Function; - -/** - * Preloader is model loader that calls {@link ModelLoader#get()} allowing model preloading. - * - * @param Model instance type. - */ -@FunctionalInterface -public interface Preloader> extends ModelLoader { - - /** - * Returns a blocking {@link Preloader}. Blocks at create time till the model is loaded. - */ - static > Function, Preloader> preload() { - return preload(Duration.ofDays(Integer.MAX_VALUE)); - } - - /** - * Returns a blocking {@link Preloader}. Blocks till the model is loaded or a {@link Duration} is - * met. - * - * @param duration Amount of time that it should wait, if necessary, for model to be loaded. - */ - static > Function, Preloader> preload( - final Duration duration) { - return loader -> { - CompletionStage model; - try { - model = CompletableFuture.completedFuture(loader.get(duration)); - } catch (final Exception e) { - final CompletableFuture failed = new CompletableFuture<>(); - failed.completeExceptionally(e); - - model = failed; - } - - final CompletionStage finalModel = model; - return () -> finalModel; - }; - } - - /** - * Returns a asynchronous {@link Preloader}. - */ - static > Function, Preloader> preloadAsync() { - return loader -> { - final CompletionStage model = loader.get(); - return () -> model; - }; - } -} diff --git a/zoltar-mlengine/src/main/java/com/spotify/zoltar/mlengine/MlEngineLoader.java b/zoltar-mlengine/src/main/java/com/spotify/zoltar/mlengine/MlEngineLoader.java index 22ac43d1..31686ceb 100644 --- a/zoltar-mlengine/src/main/java/com/spotify/zoltar/mlengine/MlEngineLoader.java +++ b/zoltar-mlengine/src/main/java/com/spotify/zoltar/mlengine/MlEngineLoader.java @@ -22,10 +22,11 @@ import com.spotify.zoltar.Model; import com.spotify.zoltar.ModelLoader; -import com.spotify.zoltar.loaders.ModelMemoizer; +import java.io.IOException; +import java.security.GeneralSecurityException; /** - * {@link MlEngineLoader} loader. This loader is composed with {@link ModelMemoizer}. + * {@link MlEngineLoader} loader. */ @FunctionalInterface public interface MlEngineLoader extends ModelLoader { @@ -36,8 +37,9 @@ public interface MlEngineLoader extends ModelLoader { * @param projectId Google project id. * @param modelId model id. */ - static MlEngineLoader create(final String projectId, final String modelId) { - return create(() -> MlEngineModel.create(projectId, modelId)); + static MlEngineLoader create(final String projectId, final String modelId) + throws IOException, GeneralSecurityException { + return ModelLoader.loaded(MlEngineModel.create(projectId, modelId))::get; } /** @@ -49,8 +51,9 @@ static MlEngineLoader create(final String projectId, final String modelId) { */ static MlEngineLoader create(final String projectId, final String modelId, - final String versionId) { - return create(() -> MlEngineModel.create(projectId, modelId, versionId)); + final String versionId) + throws IOException, GeneralSecurityException { + return ModelLoader.loaded(MlEngineModel.create(projectId, modelId, versionId))::get; } /** @@ -61,21 +64,8 @@ static MlEngineLoader create(final String projectId, * "projects/{PROJECT_ID}/models/{MODEL_ID}/versions/{MODEL_VERSION}" * */ - static MlEngineLoader create(final Model.Id id) { - return create(() -> MlEngineModel.create(id)); - } - - /** - * Returns a Google Cloud ML Engine model loader. - * - * @param supplier {@link MlEngineModel} supplier. - */ - static MlEngineLoader create(final ThrowableSupplier supplier) { - final ModelLoader loader = ModelLoader - .lift(supplier) - .with(ModelMemoizer::memoize); - - return loader::get; + static MlEngineLoader create(final Model.Id id) throws IOException, GeneralSecurityException { + return ModelLoader.loaded(MlEngineModel.create(id))::get; } } diff --git a/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowGraphLoader.java b/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowGraphLoader.java index 6df5713d..6ba8d85b 100644 --- a/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowGraphLoader.java +++ b/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowGraphLoader.java @@ -22,16 +22,14 @@ import com.spotify.zoltar.Model; import com.spotify.zoltar.ModelLoader; -import com.spotify.zoltar.loaders.ModelMemoizer; -import com.spotify.zoltar.loaders.Preloader; import java.net.URI; +import java.util.concurrent.Executor; import javax.annotation.Nullable; import org.tensorflow.Graph; import org.tensorflow.framework.ConfigProto; /** - * {@link TensorFlowGraphModel} loader. This loader is composed with {@link ModelMemoizer} and - * {@link Preloader}. + * {@link TensorFlowGraphModel} loader. */ @FunctionalInterface public interface TensorFlowGraphLoader extends ModelLoader { @@ -44,11 +42,15 @@ public interface TensorFlowGraphLoader extends ModelLoader * local filesystem, resource, GCS etc. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ static TensorFlowGraphLoader create(final String modelUri, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return create(() -> TensorFlowGraphModel.create(URI.create(modelUri), config, prefix)); + @Nullable final String prefix, + final Executor executor) { + final ThrowableSupplier supplier = + () -> TensorFlowGraphModel.create(URI.create(modelUri), config, prefix); + return create(supplier, executor); } /** @@ -59,12 +61,16 @@ static TensorFlowGraphLoader create(final String modelUri, * local filesystem, resource, GCS etc. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ static TensorFlowGraphLoader create(final Model.Id id, final String modelUri, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return create(() -> TensorFlowGraphModel.create(id, URI.create(modelUri), config, prefix)); + @Nullable final String prefix, + final Executor executor) { + final ThrowableSupplier supplier = + () -> TensorFlowGraphModel.create(id, URI.create(modelUri), config, prefix); + return create(supplier, executor); } /** @@ -73,11 +79,13 @@ static TensorFlowGraphLoader create(final Model.Id id, * @param graphDef byte array representing the TensorFlow {@link Graph} definition. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ static TensorFlowGraphLoader create(final byte[] graphDef, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return create(() -> TensorFlowGraphModel.create(graphDef, config, prefix)); + @Nullable final String prefix, + final Executor executor) { + return create(() -> TensorFlowGraphModel.create(graphDef, config, prefix), executor); } /** @@ -87,26 +95,25 @@ static TensorFlowGraphLoader create(final byte[] graphDef, * @param graphDef byte array representing the TensorFlow {@link Graph} definition. * @param config optional TensorFlow {@link ConfigProto} config. * @param prefix optional prefix that will be prepended to names in the graph. + * @param executor the executor to use for asynchronous execution. */ static TensorFlowGraphLoader create(final Model.Id id, final byte[] graphDef, @Nullable final ConfigProto config, - @Nullable final String prefix) { - return create(() -> TensorFlowGraphModel.create(id, graphDef, config, prefix)); + @Nullable final String prefix, + final Executor executor) { + return create(() -> TensorFlowGraphModel.create(id, graphDef, config, prefix), executor); } /** * Returns a TensorFlow model loader based on a serialized TensorFlow {@link Graph}. * * @param supplier {@link TensorFlowGraphModel} supplier. + * @param executor the executor to use for asynchronous execution. */ - static TensorFlowGraphLoader create(final ThrowableSupplier supplier) { - final ModelLoader loader = ModelLoader - .lift(supplier) - .with(ModelMemoizer::memoize) - .with(Preloader.preloadAsync()); - - return loader::get; + static TensorFlowGraphLoader create(final ThrowableSupplier supplier, + final Executor executor) { + return ModelLoader.load(supplier, executor)::get; } } diff --git a/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowLoader.java b/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowLoader.java index 6f4800dc..f1cf3d23 100644 --- a/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowLoader.java +++ b/zoltar-tensorflow/src/main/java/com/spotify/zoltar/tf/TensorFlowLoader.java @@ -22,13 +22,11 @@ import com.spotify.zoltar.Model; import com.spotify.zoltar.ModelLoader; -import com.spotify.zoltar.loaders.ModelMemoizer; -import com.spotify.zoltar.loaders.Preloader; import java.net.URI; +import java.util.concurrent.Executor; /** - * {@link TensorFlowModel} loader. This loader is composed with {@link ModelMemoizer} and {@link - * Preloader}. + * {@link TensorFlowModel} loader. */ @FunctionalInterface public interface TensorFlowLoader extends ModelLoader { @@ -39,9 +37,10 @@ public interface TensorFlowLoader extends ModelLoader { * @param modelUri should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. + * @param executor the executor to use for asynchronous execution. */ - static TensorFlowLoader create(final String modelUri) { - return create(() -> TensorFlowModel.create(URI.create(modelUri))); + static TensorFlowLoader create(final String modelUri, final Executor executor) { + return create(() -> TensorFlowModel.create(URI.create(modelUri)), executor); } /** @@ -51,9 +50,12 @@ static TensorFlowLoader create(final String modelUri) { * @param modelUri should point to a directory of the saved TensorFlow {@link * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. + * @param executor the executor to use for asynchronous execution. */ - static TensorFlowLoader create(final Model.Id id, final String modelUri) { - return create(() -> TensorFlowModel.create(id, URI.create(modelUri))); + static TensorFlowLoader create(final Model.Id id, + final String modelUri, + final Executor executor) { + return create(() -> TensorFlowModel.create(id, URI.create(modelUri)), executor); } /** @@ -63,10 +65,12 @@ static TensorFlowLoader create(final Model.Id id, final String modelUri) { * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. * @param options TensorFlow options, see {@link TensorFlowModel.Options}. + * @param executor the executor to use for asynchronous execution. */ static TensorFlowLoader create(final String modelUri, - final TensorFlowModel.Options options) { - return create(() -> TensorFlowModel.create(URI.create(modelUri), options)); + final TensorFlowModel.Options options, + final Executor executor) { + return create(() -> TensorFlowModel.create(URI.create(modelUri), options), executor); } /** @@ -77,25 +81,24 @@ static TensorFlowLoader create(final String modelUri, * org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource, * GCS etc. * @param options TensorFlow options, see {@link TensorFlowModel.Options}. + * @param executor the executor to use for asynchronous execution. */ static TensorFlowLoader create(final Model.Id id, final String modelUri, - final TensorFlowModel.Options options) { - return create(() -> TensorFlowModel.create(id, URI.create(modelUri), options)); + final TensorFlowModel.Options options, + final Executor executor) { + return create(() -> TensorFlowModel.create(id, URI.create(modelUri), options), executor); } /** * Returns a TensorFlow model loader based on a saved model. * * @param supplier {@link TensorFlowModel} supplier. + * @param executor the executor to use for asynchronous execution. */ - static TensorFlowLoader create(final ThrowableSupplier supplier) { - final ModelLoader loader = ModelLoader - .lift(supplier) - .with(ModelMemoizer::memoize) - .with(Preloader.preloadAsync()); - - return loader::get; + static TensorFlowLoader create(final ThrowableSupplier supplier, + final Executor executor) { + return ModelLoader.load(supplier, executor)::get; } } diff --git a/zoltar-tests/src/test/java/com/spotify/zoltar/loaders/PreloaderTest.java b/zoltar-tests/src/test/java/com/spotify/zoltar/ModelLoaderTest.java similarity index 60% rename from zoltar-tests/src/test/java/com/spotify/zoltar/loaders/PreloaderTest.java rename to zoltar-tests/src/test/java/com/spotify/zoltar/ModelLoaderTest.java index f052377c..ac93b4ce 100644 --- a/zoltar-tests/src/test/java/com/spotify/zoltar/loaders/PreloaderTest.java +++ b/zoltar-tests/src/test/java/com/spotify/zoltar/ModelLoaderTest.java @@ -18,17 +18,19 @@ * -/-/- */ -package com.spotify.zoltar.loaders; +package com.spotify.zoltar; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; -import com.spotify.zoltar.Model; -import com.spotify.zoltar.ModelLoader; +import com.spotify.zoltar.ModelLoader.PreLoader; import java.time.Duration; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeoutException; import org.junit.Test; -public class PreloaderTest { +public class ModelLoaderTest { static class DummyModel implements Model { @@ -52,35 +54,27 @@ public void close() throws Exception { } @Test - public void preload() { + public void preload() throws InterruptedException, ExecutionException, TimeoutException { final ModelLoader loader = ModelLoader - .lift(DummyModel::new) - .with(Preloader.preload()); - - assertThat(loader.get().toCompletableFuture().isDone(), is(true)); - } - - @Test - public void preloadTimeout() { - final ModelLoader loader = ModelLoader - .lift(() -> { - Thread.sleep(Duration.ofSeconds(10).toMillis()); + .load(() -> { + Thread.sleep(Duration.ofMillis(5).toMillis()); return new DummyModel(); - }) - .with(Preloader.preload(Duration.ZERO)); + }, ForkJoinPool.commonPool()); + + final ModelLoader preloaded = ModelLoader.preload(loader, Duration.ofSeconds(1)); - assertThat(loader.get().toCompletableFuture().isCompletedExceptionally(), is(true)); + assertThat(preloaded.get().toCompletableFuture().isDone(), is(true)); } - @Test - public void preloadAsync() { + @Test(expected = TimeoutException.class) + public void preloadTimeout() throws InterruptedException, ExecutionException, TimeoutException { final ModelLoader loader = ModelLoader - .lift(() -> { + .load(() -> { Thread.sleep(Duration.ofSeconds(10).toMillis()); return new DummyModel(); - }) - .with(Preloader.preloadAsync()); + }, ForkJoinPool.commonPool()); - assertThat(loader.get().toCompletableFuture().isDone(), is(false)); + ModelLoader.preload(loader, Duration.ZERO); } + } diff --git a/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorBuilderTest.java b/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorBuilderTest.java index c828b8c7..4e21a5ef 100644 --- a/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorBuilderTest.java +++ b/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorBuilderTest.java @@ -100,7 +100,7 @@ public IdentityPredictorBuilder with( @Test public void identityDecoration() throws ExecutionException, InterruptedException { final ModelLoader loader = - ModelLoader.lift(DummyModel::new); + ModelLoader.loaded(new DummyModel()); final ExtractFn extractFn = ExtractFn.lift(input -> (float) input / 10); final PredictFn predictFn = (model, vectors) -> { return vectors.stream() diff --git a/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorTest.java b/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorTest.java index 694e7abe..cfd501a9 100644 --- a/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorTest.java +++ b/zoltar-tests/src/test/java/com/spotify/zoltar/PredictorTest.java @@ -25,7 +25,6 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import com.spotify.zoltar.FeatureExtractFns.ExtractFn; import com.spotify.zoltar.FeatureExtractFns.ExtractFn; import com.spotify.zoltar.PredictFns.AsyncPredictFn; import com.spotify.zoltar.PredictFns.PredictFn; @@ -71,7 +70,7 @@ public void timeout() { }; try { - final ModelLoader loader = ModelLoader.lift(DummyModel::new); + final ModelLoader loader = ModelLoader.loaded(new DummyModel()); DefaultPredictorBuilder .create(loader, extractFn, predictFn) .predictor() @@ -92,7 +91,7 @@ public void empty() throws InterruptedException, ExecutionException, TimeoutExce final AsyncPredictFn predictFn = (model, vectors) -> CompletableFuture.completedFuture(Collections.emptyList()); - final ModelLoader loader = ModelLoader.lift(DummyModel::new); + final ModelLoader loader = ModelLoader.loaded(new DummyModel()); DefaultPredictorBuilder .create(loader, extractFn, predictFn) .predictor() @@ -111,7 +110,7 @@ public void nonEmpty() throws InterruptedException, ExecutionException, TimeoutE .collect(Collectors.toList()); }; - final ModelLoader loader = ModelLoader.lift(DummyModel::new); + final ModelLoader loader = ModelLoader.loaded(new DummyModel()); final List> predictions = DefaultPredictorBuilder .create(loader, extractFn, predictFn) diff --git a/zoltar-tests/src/test/java/com/spotify/zoltar/loaders/ModelMemoizerTest.java b/zoltar-tests/src/test/java/com/spotify/zoltar/loaders/ModelMemoizerTest.java deleted file mode 100644 index 133e68e0..00000000 --- a/zoltar-tests/src/test/java/com/spotify/zoltar/loaders/ModelMemoizerTest.java +++ /dev/null @@ -1,75 +0,0 @@ -/*- - * -\-\- - * zoltar-core - * -- - * Copyright (C) 2016 - 2018 Spotify AB - * -- - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * -/-/- - */ - -package com.spotify.zoltar.loaders; - -import static org.hamcrest.core.Is.is; -import static org.junit.Assert.assertThat; - -import com.spotify.zoltar.Model; -import com.spotify.zoltar.ModelLoader; -import java.time.Duration; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Test; - -public class ModelMemoizerTest { - - static class DummyModel implements Model { - private final AtomicInteger inc; - - public DummyModel() { - inc = new AtomicInteger(); - } - - @Override - public Id id() { - return Id.create("dummy"); - } - - @Override - public Object instance() { - inc.getAndIncrement(); - return null; - } - - @Override - public void close() throws Exception { - - } - - public int getIncrementValue() { - return inc.get(); - } - } - - @Test - public void memoize() throws InterruptedException, ExecutionException, TimeoutException { - final ModelLoader loader = - ModelLoader.lift(DummyModel::new).with(ModelMemoizer::memoize); - - final Duration duration = Duration.ofMillis(1000); - loader.get(duration).instance(); - loader.get(duration).instance(); - - assertThat(loader.get(duration).getIncrementValue(), is(2)); - } -} diff --git a/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowGraphModelTest.java b/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowGraphModelTest.java index f3c3f944..2662fdfb 100644 --- a/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowGraphModelTest.java +++ b/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowGraphModelTest.java @@ -25,6 +25,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; +import com.google.common.util.concurrent.MoreExecutors; import com.spotify.featran.java.JFeatureSpec; import com.spotify.featran.transformers.Identity; import com.spotify.zoltar.FeatureExtractFns.ExtractFn; @@ -88,8 +89,11 @@ private Path createADummyTFGraph() throws IOException { @Test public void testDefaultId() throws IOException, ExecutionException, InterruptedException { final Path graphFile = createADummyTFGraph(); - final ModelLoader model = - TensorFlowGraphLoader.create(graphFile.toString(), null, null); + final ModelLoader model = TensorFlowGraphLoader.create( + graphFile.toString(), + null, + null, + MoreExecutors.directExecutor()); final TensorFlowGraphModel tensorFlowModel = model.get().toCompletableFuture().get(); @@ -99,8 +103,12 @@ public void testDefaultId() throws IOException, ExecutionException, InterruptedE @Test public void testCustomId() throws IOException, ExecutionException, InterruptedException { final Path graphFile = createADummyTFGraph(); - final ModelLoader model = - TensorFlowGraphLoader.create(Id.create("dummy"), graphFile.toString(), null, null); + final ModelLoader model = TensorFlowGraphLoader.create( + Id.create("dummy"), + graphFile.toString(), + null, + null, + MoreExecutors.directExecutor()); final TensorFlowGraphModel tensorFlowModel = model.get().toCompletableFuture().get(); @@ -162,7 +170,7 @@ public void testModelInference() throws Exception { StandardCharsets.UTF_8); final ModelLoader tfModel = - TensorFlowGraphLoader.create(graphFile.toString(), null, null); + TensorFlowGraphLoader.create(graphFile.toString(), null, null, MoreExecutors.directExecutor()); final PredictFn predictFn = (model, vectors) -> vectors.stream() diff --git a/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowModelTest.java b/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowModelTest.java index 54340aa0..e6584763 100644 --- a/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowModelTest.java +++ b/zoltar-tests/src/test/java/com/spotify/zoltar/tf/TensorFlowModelTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertThat; import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.MoreExecutors; import com.spotify.zoltar.FeatureExtractFns.ExtractFn; import com.spotify.zoltar.IrisFeaturesSpec; import com.spotify.zoltar.IrisFeaturesSpec.Iris; @@ -63,9 +64,11 @@ public static Predictor getTFIrisPredictor() throws Exception { StandardCharsets.UTF_8); final ExtractFn extractFn = FeatranExtractFns .example(IrisFeaturesSpec.irisFeaturesSpec(), settings); + final TensorFlowLoader modelLoader = TensorFlowLoader + .create(modelUri, MoreExecutors.directExecutor()); final String op = "linear/head/predictions/class_ids"; - return Predictors.tensorFlow(modelUri, + return Predictors.tensorFlow(modelLoader, extractFn, tensors -> tensors.get(op).longValue()[0], op); @@ -75,7 +78,8 @@ public static Predictor getTFIrisPredictor() throws Exception { public void testDefaultId() throws URISyntaxException, ExecutionException, InterruptedException { final URI trainedModelUri = TensorFlowModelTest.class.getResource("/trained_model").toURI(); - final ModelLoader model = TensorFlowLoader.create(trainedModelUri.toString()); + final ModelLoader model = + TensorFlowLoader.create(trainedModelUri.toString(), MoreExecutors.directExecutor()); final TensorFlowModel tensorFlowModel = model.get().toCompletableFuture().get(); @@ -85,8 +89,10 @@ public void testDefaultId() @Test public void testCustomId() throws URISyntaxException, ExecutionException, InterruptedException { final URI trainedModelUri = TensorFlowModelTest.class.getResource("/trained_model").toURI(); - final ModelLoader model = - TensorFlowLoader.create(Id.create("dummy"), trainedModelUri.toString()); + final ModelLoader model = TensorFlowLoader.create( + Id.create("dummy"), + trainedModelUri.toString(), + MoreExecutors.directExecutor()); final TensorFlowModel tensorFlowModel = model.get().toCompletableFuture().get(); diff --git a/zoltar-tests/src/test/java/com/spotify/zoltar/xgboost/XGBoostModelTest.java b/zoltar-tests/src/test/java/com/spotify/zoltar/xgboost/XGBoostModelTest.java index 84453733..c28e31c8 100644 --- a/zoltar-tests/src/test/java/com/spotify/zoltar/xgboost/XGBoostModelTest.java +++ b/zoltar-tests/src/test/java/com/spotify/zoltar/xgboost/XGBoostModelTest.java @@ -25,6 +25,7 @@ import static org.junit.Assert.assertTrue; import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.MoreExecutors; import com.spotify.futures.CompletableFutures; import com.spotify.zoltar.FeatureExtractFns.ExtractFn; import com.spotify.zoltar.IrisFeaturesSpec; @@ -83,7 +84,8 @@ public static Predictor getXGBoostIrisPredictor() throws Exception { final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8); - final XGBoostLoader model = XGBoostLoader.create(trainedModelUri.toString()); + final XGBoostLoader model = + XGBoostLoader.create(trainedModelUri.toString(), MoreExecutors.directExecutor()); final ExtractFn extractFn = FeatranExtractFns .labeledPoints(IrisFeaturesSpec.irisFeaturesSpec(), settings); @@ -96,7 +98,8 @@ public static Predictor getXGBoostIrisPredictor() throws Exception { @Test public void testDefaultId() throws URISyntaxException, ExecutionException, InterruptedException { final URI trainedModelUri = XGBoostModelTest.class.getResource("/iris.model").toURI(); - final XGBoostLoader model = XGBoostLoader.create(trainedModelUri.toString()); + final XGBoostLoader model = + XGBoostLoader.create(trainedModelUri.toString(), MoreExecutors.directExecutor()); final XGBoostModel xgBoostModel = model.get().toCompletableFuture().get(); @@ -106,7 +109,10 @@ public void testDefaultId() throws URISyntaxException, ExecutionException, Inter @Test public void testCustomId() throws URISyntaxException, ExecutionException, InterruptedException { final URI trainedModelUri = XGBoostModelTest.class.getResource("/iris.model").toURI(); - final XGBoostLoader model = XGBoostLoader.create(Id.create("dummy"), trainedModelUri.toString()); + final XGBoostLoader model = XGBoostLoader.create( + Id.create("dummy"), + trainedModelUri.toString(), + MoreExecutors.directExecutor()); final XGBoostModel xgBoostModel = model.get().toCompletableFuture().get(); diff --git a/zoltar-xgboost/src/main/java/com/spotify/zoltar/xgboost/XGBoostLoader.java b/zoltar-xgboost/src/main/java/com/spotify/zoltar/xgboost/XGBoostLoader.java index e3c69286..3777634b 100644 --- a/zoltar-xgboost/src/main/java/com/spotify/zoltar/xgboost/XGBoostLoader.java +++ b/zoltar-xgboost/src/main/java/com/spotify/zoltar/xgboost/XGBoostLoader.java @@ -22,13 +22,11 @@ import com.spotify.zoltar.Model; import com.spotify.zoltar.ModelLoader; -import com.spotify.zoltar.loaders.ModelMemoizer; -import com.spotify.zoltar.loaders.Preloader; import java.net.URI; +import java.util.concurrent.Executor; /** - * {@link XGBoostModel} loader. This loader is composed with {@link ModelMemoizer} and {@link - * Preloader}. + * {@link XGBoostModel} loader. */ @FunctionalInterface @SuppressWarnings("checkstyle:AbbreviationAsWordInName") @@ -40,8 +38,8 @@ public interface XGBoostLoader extends ModelLoader { * @param modelUri should point to serialized XGBoost model file, can be a URI to a local * filesystem, resource, GCS etc. */ - static XGBoostLoader create(final String modelUri) { - return create(() -> XGBoostModel.create(URI.create(modelUri))); + static XGBoostLoader create(final String modelUri, final Executor executor) { + return create(() -> XGBoostModel.create(URI.create(modelUri)), executor); } @@ -52,8 +50,8 @@ static XGBoostLoader create(final String modelUri) { * @param modelUri should point to serialized XGBoost model file, can be a URI to a local * filesystem, resource, GCS etc. */ - static XGBoostLoader create(final Model.Id id, final String modelUri) { - return create(() -> XGBoostModel.create(id, URI.create(modelUri))); + static XGBoostLoader create(final Model.Id id, final String modelUri, final Executor executor) { + return create(() -> XGBoostModel.create(id, URI.create(modelUri)), executor); } /** @@ -61,13 +59,9 @@ static XGBoostLoader create(final Model.Id id, final String modelUri) { * * @param supplier {@link XGBoostModel} supplier. */ - static XGBoostLoader create(final ThrowableSupplier supplier) { - final ModelLoader loader = ModelLoader - .lift(supplier) - .with(ModelMemoizer::memoize) - .with(Preloader.preloadAsync()); - - return loader::get; + static XGBoostLoader create(final ThrowableSupplier supplier, + final Executor executor) { + return ModelLoader.load(supplier, executor)::get; } }