Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Commit

Permalink
Refactor Predictor (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
regadas authored Feb 14, 2019
1 parent 9d9a030 commit b49544d
Show file tree
Hide file tree
Showing 22 changed files with 438 additions and 659 deletions.
2 changes: 1 addition & 1 deletion docs/src/paradox/modules/metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Attach it to the @javadoc[PredictorBuilder](com.spotify.zoltar.PredictorBuilder)

@@snip [PredictorMetrics](../../../../examples/custom-metrics/src/main/java/com/spotify/zoltar/examples/metrics/CustomMetricsExample.java) { #PredictorMetrics }

@@snip [PredictorBuilderWithMetrics](../../../../examples/custom-metrics/src/main/java/com/spotify/zoltar/examples/metrics/CustomMetricsExample.java) { #PredictorBuilderWithMetrics }
@@snip [PredictorWithMetrics](../../../../examples/custom-metrics/src/main/java/com/spotify/zoltar/examples/metrics/CustomMetricsExample.java) { #PredictorWithMetrics }

## Example

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import okio.ByteString;

import org.tensorflow.example.Example;

import com.typesafe.config.Config;

import com.spotify.apollo.Environment;
Expand All @@ -36,6 +38,7 @@
import com.spotify.zoltar.Predictor;
import com.spotify.zoltar.metrics.PredictorMetrics;
import com.spotify.zoltar.metrics.semantic.SemanticPredictorMetrics;
import com.spotify.zoltar.tf.TensorFlowModel;

/** Application entry point. */
public class App {
Expand All @@ -61,7 +64,7 @@ static void configure(final Environment environment) {
throw new RuntimeException(e.getMessage());
}

final Predictor<Iris, Long> predictor;
final Predictor<TensorFlowModel, Iris, Example, Long> predictor;
try {
final ModelConfig irisModelConfig = ModelConfig.from(config.getConfig("iris"));
predictor = IrisPredictor.create(irisModelConfig, metrics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import java.util.stream.Stream;

import okio.ByteString;

import org.tensorflow.example.Example;

import scala.Option;

import com.google.common.collect.ImmutableMap;
Expand All @@ -35,6 +38,7 @@
import com.spotify.zoltar.IrisFeaturesSpec.Iris;
import com.spotify.zoltar.Prediction;
import com.spotify.zoltar.Predictor;
import com.spotify.zoltar.tf.TensorFlowModel;

/** Route endpoints. */
final class IrisPredictionHandler {
Expand All @@ -45,13 +49,14 @@ final class IrisPredictionHandler {
1L, "Iris-versicolor",
2L, "Iris-virginica");

private IrisPredictionHandler(final Predictor<Iris, Long> predictor) {
private IrisPredictionHandler(final Predictor<TensorFlowModel, Iris, Example, Long> predictor) {
this.predictor = predictor;
}

private final Predictor<Iris, Long> predictor;
private final Predictor<TensorFlowModel, Iris, Example, Long> predictor;

static IrisPredictionHandler create(final Predictor<Iris, Long> predictor) {
static IrisPredictionHandler create(
final Predictor<TensorFlowModel, Iris, Example, Long> predictor) {
return new IrisPredictionHandler(predictor);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
import com.spotify.zoltar.featran.FeatranExtractFns;
import com.spotify.zoltar.metrics.PredictorMetrics;
import com.spotify.zoltar.tf.TensorFlowLoader;
import com.spotify.zoltar.tf.TensorFlowModel;

/** Iris prediction meat and potatoes. */
public final class IrisPredictor {

/** Configure Iris prediction, should be called at the service startup/configuration stage. */
public static Predictor<Iris, Long> create(
public static Predictor<TensorFlowModel, Iris, Example, Long> create(
final ModelConfig modelConfig, final PredictorMetrics metrics) throws IOException {

final FeatureSpec<Iris> irisFeatureSpec = IrisFeaturesSpec.irisFeaturesSpec();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,24 @@
*/
package com.spotify.zoltar.examples.batch;

import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.spotify.zoltar.FeatureExtractFns.BatchExtractFn;
import com.spotify.zoltar.FeatureExtractor;
import com.spotify.zoltar.ModelLoader;
import com.spotify.zoltar.PredictFns.AsyncPredictFn;
import com.spotify.zoltar.PredictFns.PredictFn;
import com.spotify.zoltar.Prediction;
import com.spotify.zoltar.Predictor;
import com.spotify.zoltar.PredictorBuilder;
import com.spotify.zoltar.Predictors;

/** Example showing a batch predictor. */
class BatchPredictorExample implements Predictor<List<Integer>, List<Float>> {
class BatchPredictorExample
implements Predictor<DummyModel, List<Integer>, List<Float>, List<Float>> {

private PredictorBuilder<DummyModel, List<Integer>, List<Float>, List<Float>> predictorBuilder;
private Predictor<DummyModel, List<Integer>, List<Float>, List<Float>> predictor;

BatchPredictorExample() {
final ModelLoader<DummyModel> modelLoader = ModelLoader.loaded(new DummyModel());
Expand All @@ -55,15 +54,22 @@ class BatchPredictorExample implements Predictor<List<Integer>, List<Float>> {
.collect(Collectors.toList());
};

// We build the PredictorBuilder as usual
predictorBuilder = Predictors.newBuilder(modelLoader, batchExtractFn, predictFn);
// We build the Predictor as usual
predictor = Predictors.create(modelLoader, batchExtractFn, predictFn);
}

@Override
public CompletionStage<List<Prediction<List<Integer>, List<Float>>>> predict(
final ScheduledExecutorService scheduler,
final Duration timeout,
final List<Integer>... input) {
return predictorBuilder.predictor().predict(scheduler, timeout, input);
public ModelLoader<DummyModel> modelLoader() {
return predictor.modelLoader();
}

@Override
public FeatureExtractor<DummyModel, List<Integer>, List<Float>> featureExtractor() {
return predictor.featureExtractor();
}

@Override
public AsyncPredictFn<DummyModel, List<Integer>, List<Float>, List<Float>> predictFn() {
return predictor.predictFn();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
*/
package com.spotify.zoltar.examples.metrics;

import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;

import com.codahale.metrics.Counter;
Expand All @@ -33,11 +30,10 @@
import com.spotify.zoltar.FeatureExtractor;
import com.spotify.zoltar.Model;
import com.spotify.zoltar.ModelLoader;
import com.spotify.zoltar.PredictFns.AsyncPredictFn;
import com.spotify.zoltar.PredictFns.PredictFn;
import com.spotify.zoltar.Prediction;
import com.spotify.zoltar.Predictor;
import com.spotify.zoltar.PredictorBuilder;
import com.spotify.zoltar.Predictors;
import com.spotify.zoltar.Vector;
import com.spotify.zoltar.metrics.FeatureExtractorMetrics;
import com.spotify.zoltar.metrics.Instrumentations;
Expand All @@ -48,9 +44,9 @@
import com.spotify.zoltar.metrics.semantic.SemanticPredictorMetrics;

/** Example showing how to add custom metrics to a Predictor. */
class CustomMetricsExample implements Predictor<Integer, Float> {
class CustomMetricsExample implements Predictor<DummyModel, Integer, Float, Float> {

private PredictorBuilder<DummyModel, Integer, Float, Float> predictorBuilder;
private Predictor<DummyModel, Integer, Float, Float> predictor;

/** Define a class containing all the additional metrics we want to register. */
@AutoValue
Expand Down Expand Up @@ -160,10 +156,8 @@ public void extraction(final List<Vector<Integer, Float>> vectors) {
.map(vector -> Prediction.create(vector.input(), vector.value() * 2))
.collect(Collectors.toList());
};
final FeatureExtractor<DummyModel, Integer, Float> featureExtractor =
FeatureExtractor.create(extractFn);

// We build the PredictorBuilder as usual, compose with the built-in metrics, and then compose
// We build the Predictor as usual, compose with the built-in metrics, and then compose
// with our custom metrics.
// #PredictorMetrics
final PredictorMetrics<Integer, Float, Float> predictorMetrics =
Expand All @@ -173,16 +167,30 @@ public void extraction(final List<Vector<Integer, Float>> vectors) {
final PredictorMetrics<Integer, Float, Float> customMetrics =
CustomPredictorMetrics.create(metricRegistry, metricId);

predictorBuilder =
// #PredictorBuilderWithMetrics
Predictors.newBuilder(modelLoader, featureExtractor, predictFn, predictorMetrics)
// #PredictorBuilderWithMetrics
// #PredictorWithMetrics
predictor =
Predictor.<DummyModel, Integer, Float, Float>builder()
.modelLoader(modelLoader)
.featureExtractFn(extractFn)
.predictFn(predictFn)
.build()
.with(Instrumentations.predictor(predictorMetrics))
.with(Instrumentations.predictor(customMetrics));
// #PredictorWithMetrics
}

@Override
public CompletionStage<List<Prediction<Integer, Float>>> predict(
final ScheduledExecutorService scheduler, final Duration timeout, final Integer... input) {
return predictorBuilder.predictor().predict(scheduler, timeout, input);
public ModelLoader<DummyModel> modelLoader() {
return predictor.modelLoader();
}

@Override
public FeatureExtractor<DummyModel, Integer, Float> featureExtractor() {
return predictor.featureExtractor();
}

@Override
public AsyncPredictFn<DummyModel, Integer, Float, Float> predictFn() {
return predictor.predictFn();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand All @@ -39,23 +36,29 @@
import com.spotify.featran.FeatureSpec;
import com.spotify.futures.CompletableFutures;
import com.spotify.zoltar.FeatureExtractFns.ExtractFn;
import com.spotify.zoltar.FeatureExtractor;
import com.spotify.zoltar.IrisFeaturesSpec;
import com.spotify.zoltar.IrisFeaturesSpec.Iris;
import com.spotify.zoltar.ModelLoader;
import com.spotify.zoltar.Models;
import com.spotify.zoltar.PredictFns.AsyncPredictFn;
import com.spotify.zoltar.Prediction;
import com.spotify.zoltar.Predictor;
import com.spotify.zoltar.Predictors;
import com.spotify.zoltar.featran.FeatranExtractFns;
import com.spotify.zoltar.mlengine.MlEngineLoader;
import com.spotify.zoltar.mlengine.MlEngineModel;
import com.spotify.zoltar.mlengine.MlEnginePredictException;
import com.spotify.zoltar.mlengine.MlEnginePredictFn;

/** Cloud Machine Learning Engine online prediction example. */
public final class MlEnginePredictorExample implements Predictor<Iris, Integer> {
public final class MlEnginePredictorExample
implements Predictor<MlEngineModel, Iris, Example, Integer> {

private final Predictor<Iris, Integer> predictor;
private final Predictor<MlEngineModel, Iris, Example, Integer> predictor;

private MlEnginePredictorExample(final Predictor<Iris, Integer> predictor) {
private MlEnginePredictorExample(
final Predictor<MlEngineModel, Iris, Example, Integer> predictor) {
this.predictor = predictor;
}

Expand Down Expand Up @@ -111,16 +114,25 @@ public static MlEnginePredictorExample create(
return CompletableFutures.allAsList(predictions);
};

final Predictor<Iris, Integer> predictor =
Predictors.newBuilder(mlEngineLoader, extractFn, predictFn).predictor();
final Predictor<MlEngineModel, Iris, Example, Integer> predictor =
Predictors.create(mlEngineLoader, extractFn, predictFn);

return new MlEnginePredictorExample(predictor);
}

@Override
public CompletionStage<List<Prediction<Iris, Integer>>> predict(
final ScheduledExecutorService scheduler, final Duration timeout, final Iris... input) {
return predictor.predict(scheduler, timeout, input);
public ModelLoader<MlEngineModel> modelLoader() {
return predictor.modelLoader();
}

@Override
public FeatureExtractor<MlEngineModel, Iris, Example> featureExtractor() {
return predictor.featureExtractor();
}

@Override
public AsyncPredictFn<MlEngineModel, Iris, Example, Integer> predictFn() {
return predictor.predictFn();
}

@AutoValue
Expand Down
Loading

0 comments on commit b49544d

Please sign in to comment.