Skip to content

Commit

Permalink
Add quantize command to cli and use aligned offheap memory
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Nov 4, 2023
1 parent 56637f0 commit 4cccd14
Show file tree
Hide file tree
Showing 18 changed files with 513 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import com.github.tjake.jlama.cli.commands.ChatCommand;
import com.github.tjake.jlama.cli.commands.CompleteCommand;
import com.github.tjake.jlama.cli.commands.QuantizeCommand;
import com.github.tjake.jlama.cli.commands.ServeCommand;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport.ModelType;
Expand All @@ -30,6 +31,7 @@ public class JlamaCli implements Runnable {

public static void main(String[] args) {
CommandLine cli = new CommandLine(new JlamaCli());
cli.addSubcommand("quantize", new QuantizeCommand());
cli.addSubcommand("chat", new ChatCommand());
cli.addSubcommand("complete", new CompleteCommand());
cli.addSubcommand("serve", new ServeCommand());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.github.tjake.jlama.cli.commands;


import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;

import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import picocli.CommandLine;

@CommandLine.Command(name = "quantize", description = "Quantize the specified model")
public class QuantizeCommand extends BaseCommand {

    /** Optional destination; when absent the quantizer derives a sibling directory of the model. */
    @CommandLine.Parameters(index = "1", arity = "0..1", description = "The output location")
    protected Path output;

    /** Target data type for the quantized tensors (e.g. Q4, I8). */
    @CommandLine.Option(names = { "-q", "--quantization"}, description = "Model quantization type", arity = "1")
    protected DType modelQuantization;

    /** Tensor-name prefixes to leave at their original precision. */
    @CommandLine.Option(names = {"-s", "--skip-layer"}, description = "Layer name prefix to not quantize")
    protected String[] skipLayerPrefixes;

    /**
     * Validates the CLI arguments and runs {@link SafeTensorSupport#quantizeModel}.
     * Exits with status 1 on bad arguments, 2 on an I/O failure during quantization.
     */
    @Override
    public void run() {
        // `model` (positional index 0) is declared on BaseCommand — presumably a File; confirmed by usage below
        if (!model.exists()) {
            System.err.println("Model location does not exist: " + model);
            System.exit(1);
        }

        File baseDir = model.isFile() ? model.getParentFile() : model;

        if (!baseDir.isDirectory()) {
            System.err.println("Model directory does not exist: " + baseDir);
            System.exit(1);
        }

        // Fail fast with a clear message instead of an NPE deep inside the quantizer
        if (modelQuantization == null) {
            System.err.println("Missing required option: -q/--quantization");
            System.exit(1);
        }

        try {
            Path out = SafeTensorSupport.quantizeModel(baseDir.toPath(), modelQuantization, skipLayerPrefixes, Optional.ofNullable(output));

            System.out.println("Quantized model written to: " + out);
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(2);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class SafeTensorIndex implements WeightLoader, AutoCloseable {
public static final String SINGLE_MODEL_NAME = "model.safetensors";
public static final String MODEL_INDEX_JSON = "model.safetensors.index.json";

private final Map<String, Object> metadata;
private final Map<String, String> metadata;

// Map from weight name to file name (this is what's in the JSON file)
private final Map<String, String> weightFileMap;
Expand Down Expand Up @@ -136,6 +136,33 @@ private Map<List<Long>, List<String>> computeMmapSplits(Map<String, TensorInfo>
return splits;
}

/**
 * Jackson creator used when deserializing {@code model.safetensors.index.json}.
 *
 * @param metadata      the index's "metadata" section (string keys and values)
 * @param weightFileMap the "weight_map" section: tensor name -> containing .safetensors file
 */
@JsonCreator
SafeTensorIndex(@JsonProperty("metadata") Map<String, String> metadata,
@JsonProperty("weight_map") Map<String, String> weightFileMap) {
// Defensive immutable copies; ImmutableMap.copyOf throws NPE if either JSON section is absent
this.metadata = ImmutableMap.copyOf(metadata);
this.weightFileMap = ImmutableMap.copyOf(weightFileMap);
}

/** Returns the (immutable) metadata section of the safetensors index. */
@Override
public Map<String, String> metadata() {
return metadata;
}

/**
 * Builds the name -> {@link TensorInfo} index covering every weight this loader knows about.
 *
 * @return a new mutable map from tensor name to its TensorInfo
 * @throws NoSuchElementException if a name in the weight map has no loaded Weights entry
 */
@Override
public Map<String, TensorInfo> tensorInfoMap() {
    Map<String, TensorInfo> tensorInfoMap = new HashMap<>();
    // Iterate entries directly instead of keySet + get(): one hash lookup per name, not two
    for (Map.Entry<String, Weights> e : weightMap.entrySet()) {
        Weights w = e.getValue();
        if (w == null)
            throw new NoSuchElementException(e.getKey());

        tensorInfoMap.put(e.getKey(), w.tensorInfoMap().get(e.getKey()));
    }

    return tensorInfoMap;
}

@Override
public AbstractTensor load(String name) {
Weights w = weightMap.get(name);
if (w == null)
Expand All @@ -150,13 +177,6 @@ public DType getModelDType() {
return weightMap.values().iterator().next().getModelDType();
}

// NOTE(review): this is the pre-commit version of the @JsonCreator constructor
// (metadata values typed Object), shown by the diff as removed; the commit replaces
// it with the Map<String, String> form earlier in the file. Kept here only as diff residue.
@JsonCreator
SafeTensorIndex(@JsonProperty("metadata") Map<String, Object> metadata,
@JsonProperty("weight_map") Map<String, String> weightFileMap) {
this.metadata = ImmutableMap.copyOf(metadata);
this.weightFileMap = ImmutableMap.copyOf(weightFileMap);
}

@Override
public void close() throws Exception {
weightMap.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,37 @@
import com.fasterxml.jackson.databind.type.MapType;
import com.github.tjake.jlama.model.ModelSupport.ModelType;
import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.Q4ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q5ByteBufferTensor;
import com.github.tjake.jlama.tensor.Q8ByteBufferTensor;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;

import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SafeTensorSupport {
private static final Logger logger = LoggerFactory.getLogger(SafeTensorSupport.class);
private static final ObjectMapper om = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
private static final MapType metadataTypeReference = om.getTypeFactory().constructMapType(Map.class, String.class, String.class);

public static Map<String, TensorInfo> readTensorInfoMap(ByteBuffer buf, Optional<Map<String, String>> saveMetadata) {
long headerLength = buf.order() == ByteOrder.BIG_ENDIAN ? Long.reverseBytes(buf.getLong()) : buf.getLong();
buf = buf.order(ByteOrder.LITTLE_ENDIAN);
long headerLength = buf.getLong();
byte[] header = new byte[Ints.checkedCast(headerLength)];
buf.get(header);

Expand Down Expand Up @@ -76,7 +88,7 @@ public static WeightLoader loadWeights(File baseDir) throws IOException {
if (Files.exists(Paths.get(baseDir.getAbsolutePath(), SafeTensorIndex.SINGLE_MODEL_NAME)))
return SafeTensorIndex.loadSingleFile(baseDir.toPath(), SafeTensorIndex.SINGLE_MODEL_NAME);

throw new IllegalArgumentException("No safetensors model found in: " + baseDir);
throw new IllegalArgumentException("No safetensor model found in: " + baseDir);
}

public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException {
Expand All @@ -89,4 +101,100 @@ public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException {

return om.treeToValue(rootNode.get("model"), TokenizerModel.class);
}

/**
 * Quantizes every tensor of a safetensors model to the requested {@link DType} and writes a
 * new single-file model directory.
 *
 * <p>Tensor data is first streamed into a temp file; the final {@code model.safetensors} is
 * then written as a little-endian u64 header length, the JSON tensor index, and the tensor
 * bytes. {@code config.json} and {@code tokenizer.json} are copied over unchanged.
 *
 * @param modelRoot         directory containing the source safetensors model
 * @param modelQuantization target data type for quantized tensors
 * @param skipLayerPrefixes tensor-name prefixes to leave unquantized (may be null)
 * @param outputRoot        optional explicit output directory; defaults to a sibling of
 *                          {@code modelRoot} named {@code <modelRoot>-jlama-<dtype>}
 * @return path of the directory containing the quantized model
 * @throws IOException on any read/write failure
 */
public static Path quantizeModel(Path modelRoot, DType modelQuantization, String[] skipLayerPrefixes, Optional<Path> outputRoot) throws IOException {
    File tmp = File.createTempFile("safe", "tensor");
    tmp.deleteOnExit();
    Map<String, Object> writtenInfo = new HashMap<>();

    WeightLoader wl = SafeTensorSupport.loadWeights(modelRoot.toFile());
    try {
        try (RandomAccessFile raf = new RandomAccessFile(tmp, "rw")) {
            Map<String, TensorInfo> tensors = wl.tensorInfoMap();

            for (Map.Entry<String, TensorInfo> e : tensors.entrySet()) {
                try (AbstractTensor tr = wl.load(e.getKey())) {
                    AbstractTensor t = shouldSkipQuantization(e.getKey(), skipLayerPrefixes)
                            ? tr
                            : tr.quantize(modelQuantization);
                    writeTensor(e.getKey(), t, raf.getChannel(), writtenInfo);
                }
            }
        }
    } finally {
        // WeightLoader extends AutoCloseable — don't leak the source model's resources
        try {
            wl.close();
        } catch (Exception e) {
            logger.warn("Error closing weight loader", e);
        }
    }

    // Now create the output location
    String baseDirName = modelRoot.getName(modelRoot.getNameCount() - 1).toString();
    Path parentPath = modelRoot.getParent();

    Path qPath = outputRoot.orElseGet(() -> Paths.get(parentPath.toString(), baseDirName + "-jlama-" + modelQuantization.name()));
    File qDir = qPath.toFile();
    if (!qDir.mkdirs() && !qDir.isDirectory())
        throw new IOException("Could not create output directory: " + qDir);

    // Copy config.json and tokenizer.json; REPLACE_EXISTING so a re-run doesn't fail
    Files.copy(modelRoot.resolve("config.json"), qPath.resolve("config.json"), java.nio.file.StandardCopyOption.REPLACE_EXISTING);
    Files.copy(modelRoot.resolve("tokenizer.json"), qPath.resolve("tokenizer.json"), java.nio.file.StandardCopyOption.REPLACE_EXISTING);

    try (RandomAccessFile raf = new RandomAccessFile(qPath.resolve("model.safetensors").toFile(), "rw")) {
        FileChannel chan = raf.getChannel();

        // safetensors header: little-endian u64 length followed by the JSON tensor index
        byte[] header = om.writeValueAsBytes(writtenInfo);
        logger.debug("pos = {}", chan.position());
        byte[] hsize = new byte[Long.BYTES];
        ByteBuffer.wrap(hsize).order(ByteOrder.LITTLE_ENDIAN).putLong(header.length);
        raf.write(hsize);
        logger.debug("pos = {}", chan.position());
        raf.write(header);
        logger.debug("pos = {}", chan.position());

        // Append the tensor data from the temp file
        Files.copy(tmp.toPath(), new OutputStream() {
            @Override
            public void write(int b) throws IOException {
                raf.write(b);
            }

            @Override
            public void write(byte[] b) throws IOException {
                raf.write(b);
            }

            @Override
            public void write(byte[] b, int off, int len) throws IOException {
                raf.write(b, off, len);
            }
        });
    }

    // Best-effort cleanup of the temp file; deleteOnExit() remains as a backstop
    tmp.delete();

    return qPath;
}

/** Returns true when the tensor name matches one of the user-supplied skip prefixes. */
private static boolean shouldSkipQuantization(String name, String[] skipLayerPrefixes) {
    if (skipLayerPrefixes != null) {
        for (String prefix : skipLayerPrefixes) {
            if (name.startsWith(prefix))
                return true;
        }
    }
    return false;
}

/**
 * Serializes one tensor (plus its quantization block scales, where the dtype has them)
 * and records the resulting TensorInfo entries in {@code writtenInfo}.
 */
private static void writeTensor(String name, AbstractTensor t, FileChannel out, Map<String, Object> writtenInfo) throws IOException {
    switch (t.dType()) {
        case F32:
        case BF16:
        case F16:
            writtenInfo.put(name, t.save(out));
            break;
        case Q4:
            writtenInfo.put(name, t.save(out));
            writtenInfo.put(name + ".qb", ((Q4ByteBufferTensor) t).getBlockF().save(out));
            break;
        case Q5:
            writtenInfo.put(name, t.save(out));
            writtenInfo.put(name + ".qb", ((Q5ByteBufferTensor) t).getBlockF().save(out));
            //FIXME: Need to add the remaining 5th-bit planes before Q5 can round-trip
            throw new UnsupportedOperationException("TODO");
        case I8:
            writtenInfo.put(name, t.save(out));
            writtenInfo.put(name + ".qb", ((Q8ByteBufferTensor) t).getBlockF().save(out));
            break;
        default:
            throw new UnsupportedOperationException("" + t.dType() + " not implemented");
    }
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
package com.github.tjake.jlama.safetensors;

import java.util.Map;

import com.github.tjake.jlama.tensor.AbstractTensor;

public interface WeightLoader {
public interface WeightLoader extends AutoCloseable {

Map<String, String> metadata();

Map<String, TensorInfo> tensorInfoMap();

AbstractTensor load(String name);

DType getModelDType();
Expand Down
Loading

0 comments on commit 4cccd14

Please sign in to comment.