Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Proposal on Prediction interface cleanup #170

Open
wants to merge 2 commits into
base: 0.6.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.spotify.zoltar.Predictors;
import com.spotify.zoltar.featran.FeatranExtractFns;
import com.spotify.zoltar.metrics.PredictorMetrics;
import com.spotify.zoltar.tf.TensorFlowLoader;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
Expand All @@ -44,14 +45,17 @@ public final class IrisPredictor {

public static Predictor<Iris, Long> create(final ModelConfig modelConfig,
final PredictorMetrics metrics) throws IOException {

final FeatureSpec<Iris> irisFeatureSpec = IrisFeaturesSpec.irisFeaturesSpec();
final String settings = new String(Files.readAllBytes(Paths.get(modelConfig.settingsUri())));
final ExtractFn<Iris, Example> extractFn =
FeatranExtractFns.example(irisFeatureSpec, settings);
final TensorFlowLoader modelLoader = TensorFlowLoader
.create(modelConfig.modelUri().toString(), modelConfig.modelLoaderExecutor());

final String[] ops = new String[]{"linear/head/predictions/class_ids"};
return Predictors.tensorFlow(
modelConfig.modelUri().toString(),
modelLoader,
extractFn,
tensors -> tensors.get(ops[0]).longValue()[0],
ops,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import com.typesafe.config.Config;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

/**
* Immutable object that contains model / feature extraction configuration properties.
Expand All @@ -37,6 +39,11 @@ public abstract class ModelConfig {
/** settings URI path. */
public abstract URI settingsUri();

public Executor modelLoaderExecutor() {
final int threads = Runtime.getRuntime().availableProcessors();
return Executors.newFixedThreadPool(threads);
}

/**
* Creates a {@link ModelConfig} create a {@link Config}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.spotify.zoltar.Predictor;
import com.spotify.zoltar.PredictorBuilder;
import com.spotify.zoltar.Predictors;
import com.spotify.zoltar.loaders.Preloader;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletionStage;
Expand All @@ -43,9 +42,7 @@ class BatchPredictorExample implements Predictor<List<Integer>, List<Float>> {
private PredictorBuilder<DummyModel, List<Integer>, List<Float>, List<Float>> predictorBuilder;

BatchPredictorExample() {
final ModelLoader<DummyModel> modelLoader = ModelLoader
.lift(DummyModel::new)
.with(Preloader.preload(Duration.ofMinutes(1)));
final ModelLoader<DummyModel> modelLoader = ModelLoader.loaded(new DummyModel());

final BatchExtractFn<Integer, Float> batchExtractFn =
BatchExtractFn.lift((Function<Integer, Float>) input -> (float) input / 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import com.spotify.zoltar.PredictorBuilder;
import com.spotify.zoltar.Predictors;
import com.spotify.zoltar.Vector;
import com.spotify.zoltar.loaders.Preloader;
import com.spotify.zoltar.metrics.FeatureExtractorMetrics;
import com.spotify.zoltar.metrics.Instrumentations;
import com.spotify.zoltar.metrics.PredictFnMetrics;
Expand Down Expand Up @@ -165,9 +164,7 @@ public void extraction(final List<Vector<Integer, Float>> vectors) {
}

CustomMetricsExample(final SemanticMetricRegistry metricRegistry, final MetricId metricId) {
final ModelLoader<DummyModel> modelLoader = ModelLoader
.lift(DummyModel::new)
.with(Preloader.preload(Duration.ofMinutes(1)));
final ModelLoader<DummyModel> modelLoader = ModelLoader.loaded(new DummyModel());
final ExtractFn<Integer, Float> extractFn = ExtractFn.lift(input -> (float) input / 10);
final PredictFn<DummyModel, Integer, Float, Float> predictFn = (model, vectors) -> {
return vectors.stream()
Expand Down
69 changes: 47 additions & 22 deletions zoltar-api/src/main/java/com/spotify/zoltar/Models.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
import com.spotify.zoltar.tf.TensorFlowLoader;
import com.spotify.zoltar.tf.TensorFlowModel;
import com.spotify.zoltar.xgboost.XGBoostLoader;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
import org.tensorflow.Graph;
import org.tensorflow.framework.ConfigProto;
Expand All @@ -45,9 +48,10 @@ private Models() {
*
* @param modelUri should point to serialized XGBoost model file, can be a URI to a local
* filesystem, resource, GCS etc.
* @param executor the executor to use for asynchronous execution.
*/
public static XGBoostLoader xgboost(final String modelUri) {
return XGBoostLoader.create(modelUri);
public static XGBoostLoader xgboost(final String modelUri, final Executor executor) {
return XGBoostLoader.create(modelUri, executor);
}

/**
Expand All @@ -56,9 +60,12 @@ public static XGBoostLoader xgboost(final String modelUri) {
* @param id model id @{link Model.Id}.
* @param modelUri should point to serialized XGBoost model file, can be a URI to a local
* filesystem, resource, GCS etc.
* @param executor the executor to use for asynchronous execution.
*/
public static XGBoostLoader xgboost(final Model.Id id, final String modelUri) {
return XGBoostLoader.create(id, modelUri);
public static XGBoostLoader xgboost(final Model.Id id,
final String modelUri,
final Executor executor) {
return XGBoostLoader.create(id, modelUri, executor);
}

/**
Expand All @@ -67,9 +74,10 @@ public static XGBoostLoader xgboost(final Model.Id id, final String modelUri) {
* @param modelUri should point to a directory of the saved TensorFlow {@link
* org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource,
* GCS etc.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowLoader tensorFlow(final String modelUri) {
return TensorFlowLoader.create(modelUri);
public static TensorFlowLoader tensorFlow(final String modelUri, final Executor executor) {
return TensorFlowLoader.create(modelUri, executor);
}

/**
Expand All @@ -79,9 +87,12 @@ public static TensorFlowLoader tensorFlow(final String modelUri) {
* @param modelUri should point to a directory of the saved TensorFlow {@link
* org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource,
* GCS etc.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowLoader tensorFlow(final Model.Id id, final String modelUri) {
return TensorFlowLoader.create(id, modelUri);
public static TensorFlowLoader tensorFlow(final Model.Id id,
final String modelUri,
final Executor executor) {
return TensorFlowLoader.create(id, modelUri, executor);
}

/**
Expand All @@ -91,10 +102,12 @@ public static TensorFlowLoader tensorFlow(final Model.Id id, final String modelU
* org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource,
* GCS etc.
* @param options TensorFlow options, see {@link TensorFlowModel.Options}.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowLoader tensorFlow(final String modelUri,
final TensorFlowModel.Options options) {
return TensorFlowLoader.create(modelUri, options);
final TensorFlowModel.Options options,
final Executor executor) {
return TensorFlowLoader.create(modelUri, options, executor);
}

/**
Expand All @@ -105,11 +118,13 @@ public static TensorFlowLoader tensorFlow(final String modelUri,
* org.tensorflow.SavedModelBundle}, can be a URI to a local filesystem, resource,
* GCS etc.
* @param options TensorFlow options, see {@link TensorFlowModel.Options}.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowLoader tensorFlow(final Model.Id id,
final String modelUri,
final TensorFlowModel.Options options) {
return TensorFlowLoader.create(id, modelUri, options);
final TensorFlowModel.Options options,
final Executor executor) {
return TensorFlowLoader.create(id, modelUri, options, executor);
}

/**
Expand All @@ -119,12 +134,14 @@ public static TensorFlowLoader tensorFlow(final Model.Id id,
* local filesystem, resource, GCS etc.
* @param config optional TensorFlow {@link ConfigProto} config.
* @param prefix optional prefix that will be prepended to names in the graph.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowGraphLoader tensorFlowGraph(
final String modelUri,
@Nullable final ConfigProto config,
@Nullable final String prefix) {
return TensorFlowGraphLoader.create(modelUri, config, prefix);
@Nullable final String prefix,
final Executor executor) {
return TensorFlowGraphLoader.create(modelUri, config, prefix, executor);
}

/**
Expand All @@ -135,13 +152,15 @@ public static TensorFlowGraphLoader tensorFlowGraph(
* local filesystem, resource, GCS etc.
* @param config optional TensorFlow {@link ConfigProto} config.
* @param prefix optional prefix that will be prepended to names in the graph.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowGraphLoader tensorFlowGraph(
final Model.Id id,
final String modelUri,
@Nullable final ConfigProto config,
@Nullable final String prefix) {
return TensorFlowGraphLoader.create(id, modelUri, config, prefix);
@Nullable final String prefix,
final Executor executor) {
return TensorFlowGraphLoader.create(id, modelUri, config, prefix, executor);
}

/**
Expand All @@ -150,12 +169,14 @@ public static TensorFlowGraphLoader tensorFlowGraph(
* @param graphDef byte array representing the TensorFlow {@link Graph} definition.
* @param config optional TensorFlow {@link ConfigProto} config.
* @param prefix optional prefix that will be prepended to names in the graph.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowGraphLoader tensorFlowGraph(
final byte[] graphDef,
@Nullable final ConfigProto config,
@Nullable final String prefix) {
return TensorFlowGraphLoader.create(graphDef, config, prefix);
@Nullable final String prefix,
final Executor executor) {
return TensorFlowGraphLoader.create(graphDef, config, prefix, executor);
}

/**
Expand All @@ -165,13 +186,15 @@ public static TensorFlowGraphLoader tensorFlowGraph(
* @param graphDef byte array representing the TensorFlow {@link Graph} definition.
* @param config optional TensorFlow {@link ConfigProto} config.
* @param prefix optional prefix that will be prepended to names in the graph.
* @param executor the executor to use for asynchronous execution.
*/
public static TensorFlowGraphLoader tensorFlowGraph(
final Model.Id id,
final byte[] graphDef,
@Nullable final ConfigProto config,
@Nullable final String prefix) {
return TensorFlowGraphLoader.create(id, graphDef, config, prefix);
@Nullable final String prefix,
final Executor executor) {
return TensorFlowGraphLoader.create(id, graphDef, config, prefix, executor);
}

/**
Expand All @@ -180,7 +203,8 @@ public static TensorFlowGraphLoader tensorFlowGraph(
* @param id model id. Id needs to be in the following format:
* <code>projects/$projectId/models/$modelId/version/$versionId</code>
*/
public static MlEngineLoader mlEngine(final Model.Id id) {
public static MlEngineLoader mlEngine(final Model.Id id)
throws IOException, GeneralSecurityException {
return MlEngineLoader.create(id);
}

Expand All @@ -193,7 +217,8 @@ public static MlEngineLoader mlEngine(final Model.Id id) {
*/
public static MlEngineLoader mlEngine(final String projectId,
final String modelId,
final String versionId) {
final String versionId)
throws IOException, GeneralSecurityException {
return MlEngineLoader.create(projectId, modelId, versionId);
}
}
Loading