Skip to content

Commit

Permalink
Add net and cli to maven, include a jbang script
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Sep 16, 2024
1 parent 555494c commit fe44674
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 117 deletions.
28 changes: 17 additions & 11 deletions jlama-cli/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@
<artifactId>jlama-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.github.tjake</groupId>
<artifactId>jlama-native</artifactId>
<version>${project.version}</version>
<classifier>${jni.classifier}</classifier>
</dependency>
<dependency>
<groupId>com.github.tjake</groupId>
<artifactId>jlama-net</artifactId>
Expand All @@ -61,6 +55,18 @@
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.2.0</version> <!-- Pinned explicitly; bump this when upgrading maven-jar-plugin -->
<configuration>
<archive>
<manifest>
<mainClass>com.github.tjake.jlama.cli.JlamaCli</mainClass>
</manifest>
</archive>
</configuration>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.0.2</version>
Expand All @@ -85,7 +91,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.4.1</version>
<version>3.6.0</version>
<executions>
<execution>
<id>shade-cli</id>
Expand All @@ -94,6 +100,10 @@
<goal>shade</goal>
</goals>
<configuration>
<extraJars>
<jar>../jlama-native/target/jlama-native-${project.version}-${jni.classifier}.jar</jar>
</extraJars>
<createDependencyReducedPom>false</createDependencyReducedPom>
<outputFile>${project.basedir}/target/jlama-cli.jar</outputFile>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
Expand All @@ -106,10 +116,6 @@
</transformers>
<filters>
<filter>
<!--
Shading signed JARs will fail without this.
http://stackoverflow.com/questions/999489/invalid-signature-file-when-attempting-to-run-a-jar
-->
<artifact>*:*</artifact>
<excludes>
<exclude>META-INF/*.SF</exclude>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import picocli.CommandLine;
import picocli.CommandLine.*;

@Command(name = "jlama", mixinStandardHelpOptions = true, requiredOptionMarker = '*', usageHelpAutoWidth = true, sortOptions = true)
@Command(name = "jlama", mixinStandardHelpOptions = true, requiredOptionMarker = '*',
usageHelpAutoWidth = true, sortOptions = true,
description = "Jlama is a modern LLM inference engine for Java!\n\n" +
"Quantized models are maintained at https://hf.co/tjake\n")
public class JlamaCli implements Runnable {
static {
System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import me.tongfei.progressbar.ProgressBarStyle;
import picocli.CommandLine;

@CommandLine.Command(name = "download", description = "Downloads the specified model")
@CommandLine.Command(name = "download", description = "Downloads a HuggingFace model - use owner/name format")
public class DownloadCommand extends JlamaCli {
@CommandLine.Option(names = { "-d",
"--model-directory" }, description = "The directory to download the model to (default: ${DEFAULT-VALUE})", defaultValue = "models")
Expand Down Expand Up @@ -71,6 +71,7 @@ public void run() {
modelDirectory.getAbsolutePath(),
Optional.ofNullable(owner),
name,
false,
Optional.ofNullable(URLEncoder.encode(branch)),
Optional.ofNullable(authToken),
Optional.of((n, c, t) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q5ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;
import com.github.tjake.jlama.util.HttpSupport;
import com.github.tjake.jlama.util.Pair;
import com.github.tjake.jlama.util.TriConsumer;
import com.google.common.base.Preconditions;
Expand Down Expand Up @@ -307,7 +308,7 @@ public static File maybeDownloadModel(String modelDir, String fullModelName) thr
name = parts[1];
}

return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, Optional.empty(), Optional.empty(), Optional.empty());
return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, false, Optional.empty(), Optional.empty(), Optional.empty());
}

/**
Expand All @@ -326,16 +327,17 @@ public static File maybeDownloadModel(
String modelDir,
Optional<String> modelOwner,
String modelName,
boolean metadataOnly,
Optional<String> optionalBranch,
Optional<String> optionalAuthHeader,
Optional<TriConsumer<String, Long, Long>> optionalProgressReporter
) throws IOException {
String hfModel = modelOwner.map(mo -> mo + "/" + modelName).orElse(modelName);
InputStream modelInfoStream = getResponse(
InputStream modelInfoStream = HttpSupport.getResponse(
"https://huggingface.co/api/models/" + hfModel + "/tree/" + optionalBranch.orElse("main"),
optionalAuthHeader
).left;
String modelInfo = readInputStream(modelInfoStream);
String modelInfo = HttpSupport.readInputStream(modelInfoStream);

if (modelInfo == null) {
throw new IOException("No valid model found or trying to access a restricted model (please include correct access token)");
Expand Down Expand Up @@ -369,7 +371,7 @@ public static File maybeDownloadModel(
Files.createDirectories(localModelDir);

for (String currFile : tensorFiles) {
downloadFile(hfModel, currFile, optionalBranch, optionalAuthHeader, localModelDir.resolve(currFile), optionalProgressReporter);
HttpSupport.downloadFile(hfModel, currFile, optionalBranch, optionalAuthHeader, localModelDir.resolve(currFile), optionalProgressReporter);
}

return localModelDir.toFile();
Expand All @@ -389,97 +391,4 @@ private static List<String> parseFileList(String modelInfo) throws IOException {

return fileList;
}

/**
 * Performs an HTTP GET against {@code urlString} and hands back the response body.
 *
 * @param urlString          absolute URL to request
 * @param optionalAuthHeader token sent as {@code Authorization: Bearer <token>} when present
 * @return a pair of the response body stream and the content length reported by the connection
 * @throws IOException if the connection fails or the response status is anything other than 200
 */
private static Pair<InputStream, Long> getResponse(String urlString, Optional<String> optionalAuthHeader) throws IOException {
    HttpURLConnection http = (HttpURLConnection) new URL(urlString).openConnection();
    http.setRequestMethod("GET");

    // Only authenticated requests carry the bearer token.
    if (optionalAuthHeader.isPresent()) {
        http.setRequestProperty("Authorization", "Bearer " + optionalAuthHeader.get());
    }

    int status = http.getResponseCode();
    if (status != HttpURLConnection.HTTP_OK) {
        // Anything but 200 is reported to the caller as a failure.
        throw new IOException("HTTP response code: " + status + " for URL: " + urlString);
    }

    return Pair.of(http.getInputStream(), http.getContentLengthLong());
}

/**
 * Reads the whole stream as text, line by line.
 *
 * <p>Every line read is followed by {@link System#lineSeparator()} in the result, so the
 * original line terminators are normalised to the platform separator.
 *
 * @param inStream stream to read; may be {@code null}
 * @return the text content, or {@code null} when {@code inStream} is {@code null}
 * @throws IOException if reading from the stream fails
 */
private static String readInputStream(InputStream inStream) throws IOException {
    if (inStream == null) {
        return null;
    }

    StringBuilder text = new StringBuilder();
    BufferedReader lines = new BufferedReader(new InputStreamReader(inStream));
    for (String line = lines.readLine(); line != null; line = lines.readLine()) {
        text.append(line).append(System.lineSeparator());
    }

    return text.toString();
}

/**
 * Downloads a single file of a HuggingFace model repository to {@code outputPath}.
 *
 * <p>The file is fetched from
 * {@code https://huggingface.co/<hfModel>/resolve/<branch>/<currFile>}, where the branch
 * falls back to {@code "main"}. If a local file of the same length as the reported content
 * length already exists, nothing is copied. Otherwise the copy runs asynchronously while the
 * calling thread feeds the optional progress consumer.
 *
 * @param hfModel                  model identifier used in the URL path
 * @param currFile                 name of the file inside the repository
 * @param optionalBranch           branch to resolve against; {@code "main"} when empty
 * @param optionalAuthHeader       bearer token forwarded to {@code getResponse}
 * @param outputPath               local destination; an existing file is replaced
 * @param optionalProgressConsumer receives (file name, bytes read so far, total bytes)
 * @throws IOException if the request fails or the copy does not complete
 */
private static void downloadFile(
String hfModel,
String currFile,
Optional<String> optionalBranch,
Optional<String> optionalAuthHeader,
Path outputPath,
Optional<TriConsumer<String, Long, Long>> optionalProgressConsumer
) throws IOException {
try {
Pair<InputStream, Long> stream = getResponse(
"https://huggingface.co/" + hfModel + "/resolve/" + optionalBranch.orElse("main") + "/" + currFile,
optionalAuthHeader
);

// Wrap the body so the progress loop below can observe how many bytes were consumed.
CountingInputStream inStream = new CountingInputStream(stream.left);

long totalBytes = stream.right;

// Size-only comparison: a local file with the same length is treated as already downloaded.
// NOTE(review): the response stream is not closed on this early return.
if (outputPath.toFile().exists() && outputPath.toFile().length() == totalBytes) {
logger.debug("File already exists: {}", outputPath);
return;
}

// Without a progress consumer, fall back to plain log lines.
if (optionalProgressConsumer.isEmpty()) logger.info("Downloading file: {}", outputPath);

optionalProgressConsumer.ifPresent(p -> p.accept(currFile, 0L, totalBytes));

// Run the copy asynchronously so this thread stays free to report progress.
CompletableFuture<Long> result = CompletableFuture.supplyAsync(() -> {
try {
return Files.copy(inStream, outputPath, StandardCopyOption.REPLACE_EXISTING);
} catch (IOException e) {
throw new RuntimeException(e);
}
});

optionalProgressConsumer.ifPresent(p -> {
// NOTE(review): this loop spins without pausing until the copy completes.
while (!result.isDone()) {
p.accept(currFile, inStream.getCount(), totalBytes);
}

// On failure report the bytes actually read; on success report completion.
if (result.isCompletedExceptionally()) p.accept(currFile, inStream.getCount(), totalBytes);
else p.accept(currFile, totalBytes, totalBytes);
});

// Surface any failure of the asynchronous copy as an IOException.
try {
result.get();
} catch (Throwable e) {
throw new IOException("Failed to download file: " + currFile, e);
}

if (optionalProgressConsumer.isEmpty() && !result.isCompletedExceptionally()) logger.info("Downloaded file: {}", outputPath);
} catch (IOException e) {
throw e;
}
}
}
113 changes: 113 additions & 0 deletions jlama-core/src/main/java/com/github/tjake/jlama/util/HttpSupport.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package com.github.tjake.jlama.util;

import com.google.common.io.CountingInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * Static HTTP helpers for fetching model metadata and model files from the HuggingFace hub.
 *
 * <p>All methods are stateless; the only shared member is the logger.
 */
public class HttpSupport {
    public static final Logger logger = LoggerFactory.getLogger(HttpSupport.class);

    /** Longest time the progress loop waits for the copy before emitting the next update. */
    private static final long PROGRESS_POLL_MILLIS = 100L;

    /**
     * Performs an HTTP GET and returns the response body together with its reported length.
     *
     * <p>The caller owns the returned stream and must close it.
     *
     * @param urlString          absolute URL to request
     * @param optionalAuthHeader token sent as {@code Authorization: Bearer <token>} when present
     * @return the response body stream and the content length reported by the connection
     *         ({@code -1} when the server does not report one)
     * @throws IOException if the connection fails or the response status is not 200
     */
    public static Pair<InputStream, Long> getResponse(String urlString, Optional<String> optionalAuthHeader)
            throws IOException {
        URL url = new URL(urlString);
        HttpURLConnection connection = (HttpURLConnection) url.openConnection();

        connection.setRequestMethod("GET");
        optionalAuthHeader.ifPresent(
                authHeader -> connection.setRequestProperty("Authorization", "Bearer " + authHeader));

        int responseCode = connection.getResponseCode();
        if (responseCode != HttpURLConnection.HTTP_OK) {
            // Nobody will read this response, so release the connection before failing.
            connection.disconnect();
            throw new IOException("HTTP response code: " + responseCode + " for URL: " + urlString);
        }

        return Pair.of(connection.getInputStream(), connection.getContentLengthLong());
    }

    /**
     * Reads the whole stream as UTF-8 text and closes it.
     *
     * <p>Every line read is followed by {@link System#lineSeparator()} in the result.
     *
     * @param inStream stream to read and close; may be {@code null}
     * @return the text content, or {@code null} when {@code inStream} is {@code null}
     * @throws IOException if reading from the stream fails
     */
    public static String readInputStream(InputStream inStream) throws IOException {
        if (inStream == null) {
            return null;
        }

        StringBuilder stringBuilder = new StringBuilder();

        // Decode explicitly as UTF-8 instead of the platform default charset, and make sure
        // the reader (and with it the underlying stream) is always released.
        try (BufferedReader inReader = new BufferedReader(new InputStreamReader(inStream, StandardCharsets.UTF_8))) {
            String currLine;
            while ((currLine = inReader.readLine()) != null) {
                stringBuilder.append(currLine);
                stringBuilder.append(System.lineSeparator());
            }
        }

        return stringBuilder.toString();
    }

    /**
     * Downloads a single file of a HuggingFace model repository to {@code outputPath}.
     *
     * <p>The file is fetched from
     * {@code https://huggingface.co/<hfModel>/resolve/<branch>/<currFile>}, where the branch
     * falls back to {@code "main"}. If a local file whose length equals the reported content
     * length already exists, nothing is copied. Otherwise the copy runs asynchronously while
     * the calling thread periodically feeds the optional progress consumer. The response
     * stream is closed on every exit path.
     *
     * @param hfModel                  model identifier used in the URL path
     * @param currFile                 name of the file inside the repository
     * @param optionalBranch           branch to resolve against; {@code "main"} when empty
     * @param optionalAuthHeader       bearer token forwarded to {@link #getResponse}
     * @param outputPath               local destination; an existing file is replaced
     * @param optionalProgressConsumer receives (file name, bytes read so far, total bytes)
     * @throws IOException if the request fails, the copy fails, or the calling thread is
     *                     interrupted while waiting for the copy
     */
    public static void downloadFile(
            String hfModel,
            String currFile,
            Optional<String> optionalBranch,
            Optional<String> optionalAuthHeader,
            Path outputPath,
            Optional<TriConsumer<String, Long, Long>> optionalProgressConsumer)
            throws IOException {

        Pair<InputStream, Long> stream = getResponse(
                "https://huggingface.co/" + hfModel + "/resolve/" + optionalBranch.orElse("main") + "/" + currFile,
                optionalAuthHeader);

        long totalBytes = stream.right;

        // try-with-resources guarantees the HTTP body is released on the early return,
        // on success and on failure alike.
        try (CountingInputStream inStream = new CountingInputStream(stream.left)) {
            // Size-only comparison: a local file of the same length counts as already downloaded.
            if (outputPath.toFile().exists() && outputPath.toFile().length() == totalBytes) {
                logger.debug("File already exists: {}", outputPath);
                return;
            }

            if (optionalProgressConsumer.isEmpty()) {
                logger.info("Downloading file: {}", outputPath);
            }

            optionalProgressConsumer.ifPresent(p -> p.accept(currFile, 0L, totalBytes));

            // Copy asynchronously so this thread stays free to report progress.
            CompletableFuture<Long> result = CompletableFuture.supplyAsync(() -> {
                try {
                    return Files.copy(inStream, outputPath, StandardCopyOption.REPLACE_EXISTING);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            });

            optionalProgressConsumer.ifPresent(
                    p -> reportProgressUntilDone(p, result, currFile, inStream, totalBytes));

            try {
                result.get();
            } catch (InterruptedException e) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new IOException("Failed to download file: " + currFile, e);
            } catch (ExecutionException | RuntimeException e) {
                throw new IOException("Failed to download file: " + currFile, e);
            }

            if (optionalProgressConsumer.isEmpty()) {
                logger.info("Downloaded file: {}", outputPath);
            }
        }
    }

    /**
     * Feeds {@code progress} with the current byte count until {@code copy} completes, then
     * emits one final update. Waits on the future between updates instead of spinning.
     */
    private static void reportProgressUntilDone(
            TriConsumer<String, Long, Long> progress,
            CompletableFuture<Long> copy,
            String currFile,
            CountingInputStream inStream,
            long totalBytes) {
        while (!copy.isDone()) {
            progress.accept(currFile, inStream.getCount(), totalBytes);
            try {
                // Returns as soon as the copy finishes, so completion is never delayed.
                copy.get(PROGRESS_POLL_MILLIS, TimeUnit.MILLISECONDS);
            } catch (TimeoutException stillCopying) {
                // Expected while the transfer is running; loop to emit the next update.
            } catch (ExecutionException failed) {
                // The caller surfaces the failure; just stop polling.
                break;
            } catch (InterruptedException interrupted) {
                // Restore the flag; the caller's own wait will observe it and fail the download.
                Thread.currentThread().interrupt();
                return;
            }
        }

        if (copy.isCompletedExceptionally()) {
            progress.accept(currFile, inStream.getCount(), totalBytes);
        } else {
            progress.accept(currFile, totalBytes, totalBytes);
        }
    }
}
7 changes: 0 additions & 7 deletions jlama-net/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@
<artifactId>jlama-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.github.tjake</groupId>
<artifactId>jlama-native</artifactId>
<version>${project.version}</version>
<classifier>${jni.classifier}</classifier>
</dependency>

<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
Expand Down
2 changes: 2 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@
<modules>
<module>jlama-core</module>
<module>jlama-native</module>
<module>jlama-net</module>
<module>jlama-cli</module>
</modules>
<distributionManagement>
<snapshotRepository>
Expand Down

0 comments on commit fe44674

Please sign in to comment.