Create multi-release jar that supports Java 20/21/22
tjake committed Aug 5, 2024
1 parent e7dc2a5 commit cc996b3
Showing 38 changed files with 1,017 additions and 919 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ Implements:
* Fast GEMM operations
* Distributed Inference!

Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448)
Jlama requires Java 20 or later and utilizes the new [Vector API](https://openjdk.org/jeps/448)
for faster inference.

## ⭐ Give us a star!
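For context (not from this commit): a minimal sketch of the kind of SIMD kernel the Vector API enables. Names and shapes are illustrative, and incubator code must be run with `--add-modules jdk.incubator.vector`.

```java
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

public class DotProduct {
    static final VectorSpecies<Float> S = FloatVector.SPECIES_PREFERRED;

    // Vectorized dot product: process S.length() lanes per iteration,
    // then finish the remainder with a scalar tail loop.
    static float dot(float[] a, float[] b) {
        FloatVector acc = FloatVector.zero(S);
        int i = 0;
        for (; i < S.loopBound(a.length); i += S.length()) {
            acc = FloatVector.fromArray(S, a, i).fma(FloatVector.fromArray(S, b, i), acc);
        }
        float sum = acc.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) sum += a[i] * b[i];
        return sum;
    }
}
```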
5 changes: 4 additions & 1 deletion jlama-cli/pom.xml
@@ -101,7 +101,10 @@
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>com.github.tjake.jlama.cli.JlamaCli</mainClass>
<manifestEntries>
<Main-Class>com.github.tjake.jlama.cli.JlamaCli</Main-Class>
<Multi-Release>true</Multi-Release>
</manifestEntries>
</transformer>
</transformers>
<filters>
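The switch from `<mainClass>` to explicit `manifestEntries` adds `Multi-Release: true`, which tells the JVM to prefer class files under `META-INF/versions/<N>/` matching the running JDK and to fall back to the base classes otherwise. The resulting shaded jar looks roughly like this (layout illustrative):

```
jlama-cli.jar
├── META-INF/MANIFEST.MF          # Main-Class + Multi-Release: true
├── com/github/tjake/...          # base classes (Java 20 baseline)
├── META-INF/versions/21/...      # overrides compiled against Java 21 APIs
└── META-INF/versions/22/...      # overrides compiled against Java 22 APIs
```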
6 changes: 6 additions & 0 deletions jlama-core/pom.xml
@@ -36,5 +36,11 @@
<artifactId>jinjava</artifactId>
<version>2.7.2</version>
</dependency>

<dependency>
<groupId>net.fellbaum</groupId>
<artifactId>jemoji</artifactId>
<version>1.4.1</version>
</dependency>
</dependencies>
</project>
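jemoji (`net.fellbaum:jemoji`) is a lightweight, pure-Java emoji lookup library; it is added here so the tokenizer can detect emoji without the JDK 21-only `Character.isEmoji`, as shown in the GPT2Tokenizer hunk below.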
@@ -112,29 +112,6 @@ public static float cosineSimilarity(float[] a, float[] b) {
return (float) (dotProduct / (Math.sqrt(aMagnitude) * Math.sqrt(bMagnitude)));
}

public static void l1normalize(AbstractTensor t) {
float[] x = (float[]) t.getArray();
int offset = t.getArrayOffset(0);
long size = t.size();

float sum = 0.0f;
for (int i = offset; i < size; i++) sum += Math.abs(x[i]);

for (int i = offset; i < size; i++) x[i] /= sum;
}

public static void l2normalize(AbstractTensor t) {
float[] x = (float[]) t.getArray();
int offset = t.getArrayOffset(0);
long size = t.size();

float sum = 0.0f;
for (int i = offset; i < size; i++) sum += x[i] * x[i];

double magnitude = Math.sqrt(sum);
for (int i = offset; i < size; i++) x[i] /= magnitude;
}

public static float[] outerProduct(float[] xs, float[] ys) {
int n = xs.length;
int m = ys.length;
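The removed `l1normalize`/`l2normalize` depended on `getArray()`/`getArrayOffset()`, the heap-array accessors this commit deletes from `AbstractTensor` further down. Should the operation be needed again, a sketch against the element-wise accessors could look like this (it assumes a `get(int...)` counterpart to the `set(float, int...)` visible elsewhere in this diff, and a 1-D tensor):

```java
// Hypothetical l2-normalize that avoids raw heap arrays entirely.
static void l2normalize(AbstractTensor<?, ?> t) {
    int n = t.shape().last();                 // shape() accessor assumed
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        float v = t.get(0, i);                // get(int...) assumed
        sum += v * v;
    }
    float inv = (float) (1.0 / Math.sqrt(sum));
    for (int i = 0; i < n; i++) t.set(t.get(0, i) * inv, 0, i);
}
```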
@@ -18,6 +18,8 @@
import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import net.fellbaum.jemoji.EmojiManager;

import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.Map;
@@ -43,7 +45,7 @@ public class GPT2Tokenizer extends BPETokenizer {
// Represent emojis as their badly tokenized strings
codePointsToByteStrings = HashBiMap.create();
for (int j = 9000; j <= 128512; j++) {
if (Character.isEmoji(j)) {
if (EmojiManager.isEmoji(Character.toString(j))) {
byte[] b = Character.toString(j).getBytes(StandardCharsets.UTF_8);
StringBuilder sb = new StringBuilder();
for (int k = 0; k < b.length; k++) {
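`Character.isEmoji(int)` only exists since JDK 21, so it has to go once the baseline drops to Java 20; jemoji's `EmojiManager` gives an equivalent, JDK-independent answer. Usage sketch (values illustrative):

```java
import net.fellbaum.jemoji.EmojiManager;

boolean yes = EmojiManager.isEmoji("😀");                      // true
boolean no  = EmojiManager.isEmoji(Character.toString(0x41));  // false ('A')
```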
@@ -25,6 +25,8 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -77,7 +79,7 @@ public List<String> tokenize(String sentence) {
if (model.addedTokenPattern() != null) {
// Split the sentence into pieces using the added token pattern
// Any non-added token is split into pieces using the pre-tokenizer
String[] pieces = model.addedTokenPattern().splitWithDelimiters(sentence, 0);
String[] pieces = split(model.addedTokenPattern(), sentence, 0, true);
for (String piece : pieces) {
if (!piece.isEmpty()) {
if (model.addedTokens().containsKey(piece)) sentencePieces.add(piece);
@@ -233,4 +235,55 @@ public String decode(long[] ids) {
public Optional<PromptSupport> promptSupport() {
return promptSupport.hasPromptTemplates() ? Optional.of(promptSupport) : Optional.empty();
}


// Splitter for added token pattern (optionally with delimiters)
private String[] split(Pattern p, CharSequence input, int limit, boolean withDelimiters) {
int matchCount = 0;
int index = 0;
boolean matchLimited = limit > 0;
ArrayList<String> matchList = new ArrayList<>();
Matcher m = p.matcher(input);

// Add segments before each match found
while (m.find()) {
if (!matchLimited || matchCount < limit - 1) {
if (index == 0 && index == m.start() && m.start() == m.end()) {
// no empty leading substring included for zero-width match
// at the beginning of the input char sequence.
continue;
}
String match = input.subSequence(index, m.start()).toString();
matchList.add(match);
index = m.end();
if (withDelimiters) {
matchList.add(input.subSequence(m.start(), index).toString());
}
++matchCount;
} else if (matchCount == limit - 1) { // last one
String match = input.subSequence(index, input.length()).toString();
matchList.add(match);
index = m.end();
++matchCount;
}
}

// If no match was found, return the input as the single element
if (index == 0)
return new String[] {input.toString()};

// Add remaining segment
if (!matchLimited || matchCount < limit)
matchList.add(input.subSequence(index, input.length()).toString());

// Construct result
int resultSize = matchList.size();
if (limit == 0) {
while (resultSize > 0 && matchList.get(resultSize-1).isEmpty()) {
resultSize--;
}
}
String[] result = new String[resultSize];
return matchList.subList(0, resultSize).toArray(result);
}
}
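The `split` helper above backports `Pattern#splitWithDelimiters`, which was only introduced in JDK 21 (it replaces the call removed in the earlier hunk). With `withDelimiters = true` the matched separators are interleaved with the segments. Expected behavior, sketched with a hypothetical added token:

```java
// split(Pattern.compile("<bos>"), "<bos>Hello world", 0, true)
//   -> ["", "<bos>", "Hello world"]
// split(Pattern.compile(","), "a,b,,", 0, true)
//   -> ["a", ",", "b", ",", "", ","]    // limit == 0 trims trailing empties
```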
@@ -43,13 +43,12 @@
* This class is abstract because there are multiple implementations
* for different types of data.
**/
public abstract class AbstractTensor<V extends Vector<?>, T extends Number, A> implements AutoCloseable {
public abstract class AbstractTensor<V extends Vector<?>, T extends Number> implements AutoCloseable {
private static final Logger logger = LoggerFactory.getLogger(AbstractTensor.class);

protected final TensorShape shape;
protected final DType dType;
protected final AbstractTensor[] sliceCache;
protected final boolean requiresOffHeapTensor;
protected final Map<String, Object> metadata;
private final int stride;
private volatile TensorCache originCache = null;
@@ -58,7 +57,6 @@ protected AbstractTensor(DType dType, TensorShape shape, boolean cacheSlices) {
Preconditions.checkArgument(shape != null && shape.dims() > 0);
this.dType = dType;
this.shape = shape;
this.requiresOffHeapTensor = TensorOperationsProvider.get().requiresOffHeapTensor();
this.metadata = new HashMap<>();
this.sliceCache = cacheSlices ? new AbstractTensor[shape.first()] : null;
this.stride = shape.first() > 1 && dims() == 2 ? getOffset(1, shape.sparseOffset()) : 0;
@@ -151,12 +149,12 @@ public AbstractTensor slice(boolean cacheInnerSlice, int... dims) {
* Creates a sparse tensor that acts like a dense one but is missing the data outside
* the range of in last dimension.
*/
public AbstractTensor<V, T, A> sparsify(int offset, int length) {
public AbstractTensor<V, T> sparsify(int offset, int length) {
if (shape.isSparse()) return this;

if (length == shape.last()) return this;

AbstractTensor<V, T, A> sparseT = this.make(shape.sparsify(offset, length));
AbstractTensor<V, T> sparseT = this.make(shape.sparsify(offset, length));
int originalLength = shape.last();

int[] cursor = new int[shape.dims()];
@@ -253,10 +251,6 @@ public final DType dType() {
return dType;
}

public abstract A getArray();

public abstract int getArrayOffset(int offset);

public abstract V getVector(VectorSpecies<T> species, int... offset);

public abstract void intoTensor(V vector, int... offset);
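Dropping the array type parameter `A` and the `getArray`/`getArrayOffset` accessors leaves one access path: every implementation loads and stores SIMD vectors through its `MemorySegment`, which is what lets the buffer tensors below allocate off-heap unconditionally. Sketch of the surviving API (indices illustrative; uses `jdk.incubator.vector.ShortVector`):

```java
AbstractTensor<ShortVector, Short> t = /* e.g. a Float16BufferTensor */;
ShortVector v = t.getVector(ShortVector.SPECIES_PREFERRED, row, 0); // load via MemorySegment
t.intoTensor(v, row, 0);                                            // store via MemorySegment
```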
@@ -28,7 +28,7 @@
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorSpecies;

public class BFloat16BufferTensor extends AbstractTensor<ShortVector, Short, short[]> {
public class BFloat16BufferTensor extends AbstractTensor<ShortVector, Short> {

private final ShortBuffer b;
private final String name;
@@ -51,13 +51,10 @@ public BFloat16BufferTensor(int... shape) {
public BFloat16BufferTensor(TensorShape shape) {
super(DType.BF16, shape, true);
this.name = "tmp";
if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(
Ints.checkedCast(size() * dType().size()), UnsafeDirectByteBuffer.CACHE_LINE_SIZE)
.asShortBuffer();
} else {
this.b = ShortBuffer.allocate(Ints.checkedCast(size()));
}
this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(
Ints.checkedCast(size() * dType().size()), UnsafeDirectByteBuffer.CACHE_LINE_SIZE)
.asShortBuffer();

this.segment = MemorySegment.ofBuffer(b);
}

@@ -93,34 +90,17 @@ public void set(float v, int... dims) {
b.put(getOffset(dims), FloatConversions.float32ToBFloat16(v));
}

@Override
public short[] getArray() {
if (b.hasArray()) return b.array();
else throw new UnsupportedOperationException("Can't get array from direct buffer");
}

@Override
public int getArrayOffset(int offset) {
return b.arrayOffset() + offset;
}

@Override
public ShortVector getVector(VectorSpecies<Short> species, int... voffset) {
int offset = getOffset(voffset);
if (!TensorOperationsProvider.get().requiresOffHeapTensor())
return ShortVector.fromArray(species, getArray(), getArrayOffset(offset));
else
return ShortVector.fromMemorySegment(
species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
return ShortVector.fromMemorySegment(species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
}

@Override
public void intoTensor(ShortVector vector, int... aoffset) {
Preconditions.checkArgument(!b.isReadOnly());
int offset = getOffset(aoffset);
if (!TensorOperationsProvider.get().requiresOffHeapTensor())
vector.intoArray(getArray(), getArrayOffset(offset));
else vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
}

@Override
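For reference, bfloat16 is simply the high 16 bits of an IEEE-754 float32, so a conversion like the `FloatConversions.float32ToBFloat16` used above can be sketched as follows (truncating variant; production code typically rounds to nearest even):

```java
// float32 -> bfloat16: keep sign, exponent, and top 7 mantissa bits
static short float32ToBF16(float f) {
    return (short) (Float.floatToRawIntBits(f) >>> 16);
}

// bfloat16 -> float32 is exact: shift back into the high half
static float bf16ToFloat32(short s) {
    return Float.intBitsToFloat((s & 0xFFFF) << 16);
}
```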
@@ -27,7 +27,7 @@
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.VectorSpecies;

public class Float16BufferTensor extends AbstractTensor<ShortVector, Short, short[]> {
public class Float16BufferTensor extends AbstractTensor<ShortVector, Short> {
private final ShortBuffer b;
private final String name;
private final MemorySegment segment;
@@ -49,13 +49,10 @@ public Float16BufferTensor(int... shape) {
public Float16BufferTensor(TensorShape shape) {
super(DType.F16, shape, true);
this.name = "tmp";
if (TensorOperationsProvider.get().requiresOffHeapTensor()) {
this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(
Ints.checkedCast(size() * dType().size()), UnsafeDirectByteBuffer.CACHE_LINE_SIZE)
.asShortBuffer();
} else {
this.b = ShortBuffer.allocate(Ints.checkedCast(size()));
}
this.b = UnsafeDirectByteBuffer.allocateAlignedByteBuffer(
Ints.checkedCast(size() * dType().size()), UnsafeDirectByteBuffer.CACHE_LINE_SIZE)
.asShortBuffer();

this.segment = MemorySegment.ofBuffer(b);
}

@@ -96,34 +93,17 @@ public void set(float v, int... dims) {
b.put(getOffset(dims), Float.floatToFloat16(v));
}

@Override
public short[] getArray() {
if (b.hasArray()) return b.array();
else throw new UnsupportedOperationException("Can't get array from direct buffer");
}

@Override
public int getArrayOffset(int offset) {
return b.arrayOffset() + offset;
}

@Override
public ShortVector getVector(VectorSpecies<Short> species, int... voffset) {
int offset = getOffset(voffset);
if (!TensorOperationsProvider.get().requiresOffHeapTensor())
return ShortVector.fromArray(species, getArray(), getArrayOffset(offset));
else
return ShortVector.fromMemorySegment(
species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
return ShortVector.fromMemorySegment(species, segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
}

@Override
public void intoTensor(ShortVector vector, int... aoffset) {
Preconditions.checkArgument(!b.isReadOnly());
int offset = getOffset(aoffset);
if (!TensorOperationsProvider.get().requiresOffHeapTensor())
vector.intoArray(getArray(), getArrayOffset(offset));
else vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
vector.intoMemorySegment(segment, getMemorySegmentOffset(offset), ByteOrder.LITTLE_ENDIAN);
}

@Override
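The float16 tensor mirrors the bfloat16 change. Worth noting: the `Float.floatToFloat16`/`Float.float16ToFloat` pair this class relies on was added in JDK 20, which lines up with the new Java 20 baseline:

```java
short h = Float.floatToFloat16(1.5f);  // IEEE-754 binary16; available since JDK 20
float f = Float.float16ToFloat(h);     // == 1.5f (1.5 is exactly representable)
```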