From 6536a29d723892e6984165f63da9377b369e28c8 Mon Sep 17 00:00:00 2001 From: Alex Soto Date: Wed, 16 Oct 2024 10:34:00 +0200 Subject: [PATCH] Overloads download method to specify progress --- .../github/tjake/jlama/safetensors/SafeTensorSupport.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java index c12267c..4caefb4 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/SafeTensorSupport.java @@ -310,7 +310,7 @@ public void write(byte[] b, int off, int len) throws IOException { return qPath; } - public static File maybeDownloadModel(String modelDir, String fullModelName) throws IOException { + public static File maybeDownloadModel(String modelDir, String fullModelName, TriConsumer progressReporter) throws IOException { String[] parts = fullModelName.split("/"); if (parts.length == 0 || parts.length > 2) { throw new IllegalArgumentException("Model must be in the form owner/name"); @@ -327,7 +327,11 @@ public static File maybeDownloadModel(String modelDir, String fullModelName) thr name = parts[1]; } - return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, true, Optional.empty(), Optional.empty(), Optional.empty()); + return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, true, Optional.empty(), Optional.empty(), Optional.ofNullable(progressReporter)); + } + + public static File maybeDownloadModel(String modelDir, String fullModelName) throws IOException { + return maybeDownloadModel(modelDir, fullModelName, null); } public static Path constructLocalModelPath(String modelDir, String owner, String modelName) {