Prep for next release
tjake committed Oct 21, 2024
1 parent 1edf2b7 commit ea9ac07
Showing 23 changed files with 334 additions and 255 deletions.
@@ -28,10 +28,6 @@
  import org.springframework.boot.SpringBootConfiguration;
  import org.springframework.boot.autoconfigure.SpringBootApplication;
  import org.springframework.boot.builder.SpringApplicationBuilder;
- import org.springframework.boot.web.server.ConfigurableWebServerFactory;
- import org.springframework.boot.web.server.WebServerFactoryCustomizer;
- import org.springframework.context.ApplicationContextInitializer;
- import org.springframework.context.ConfigurableApplicationContext;
  import org.springframework.context.annotation.Bean;
  import org.springframework.context.annotation.Configuration;
  import org.springframework.core.env.ConfigurableEnvironment;
@@ -67,35 +63,32 @@ public void addResourceHandlers(ResourceHandlerRegistry registry) {
public void run() {
try {
          Path modelPath = SimpleBaseCommand.getModel(
-             modelName,
-             modelDirectory,
-             downloadSection.autoDownload,
-             downloadSection.branch,
-             downloadSection.authToken);
+             modelName,
+             modelDirectory,
+             downloadSection.autoDownload,
+             downloadSection.branch,
+             downloadSection.authToken
+         );

          m = loadModel(
-             modelPath.toFile(),
-             workingDirectory,
-             advancedSection.workingMemoryType,
-             advancedSection.workingQuantizationType,
-             Optional.ofNullable(advancedSection.modelQuantization),
-             Optional.ofNullable(advancedSection.threadCount));
+             modelPath.toFile(),
+             workingDirectory,
+             advancedSection.workingMemoryType,
+             advancedSection.workingQuantizationType,
+             Optional.ofNullable(advancedSection.modelQuantization),
+             Optional.ofNullable(advancedSection.threadCount)
+         );

System.out.println("Chat UI: http://localhost:" + port);
System.out.println("OpenAI Chat API: http://localhost:" + port + "/chat/completions");

          // Use SpringApplicationBuilder with ApplicationContextInitializer to set the port dynamically
-         new SpringApplicationBuilder(ApiServiceCommand.class)
-             .initializers(applicationContext -> {
-                 ConfigurableEnvironment environment = applicationContext.getEnvironment();
-                 Map<String, Object> props = new HashMap<>();
-                 props.put("server.port", port); // Set the port here before the server starts
-                 environment.getPropertySources().addFirst(new MapPropertySource("customProps", props));
-             })
-             .properties("logging.level.org.springframework.web", "info")
-             .lazyInitialization(true)
-             .build()
-             .run();
+         new SpringApplicationBuilder(ApiServiceCommand.class).initializers(applicationContext -> {
+             ConfigurableEnvironment environment = applicationContext.getEnvironment();
+             Map<String, Object> props = new HashMap<>();
+             props.put("server.port", port); // Set the port here before the server starts
+             environment.getPropertySources().addFirst(new MapPropertySource("customProps", props));
+         }).properties("logging.level.org.springframework.web", "info").lazyInitialization(true).build().run();

} catch (Exception e) {
e.printStackTrace();
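For context, the rewritten block relies on a Spring Boot pattern: register server.port in the Environment from an initializer so the value is in place before the embedded server binds. A minimal standalone sketch of that pattern (the PortExample class, fixed port, and main method are illustrative, not part of this commit):

    import java.util.HashMap;
    import java.util.Map;

    import org.springframework.boot.autoconfigure.SpringBootApplication;
    import org.springframework.boot.builder.SpringApplicationBuilder;
    import org.springframework.core.env.ConfigurableEnvironment;
    import org.springframework.core.env.MapPropertySource;

    @SpringBootApplication
    public class PortExample {
        public static void main(String[] args) {
            int port = 8080; // illustrative; ApiServiceCommand takes this from its CLI options
            new SpringApplicationBuilder(PortExample.class).initializers(ctx -> {
                ConfigurableEnvironment env = ctx.getEnvironment();
                Map<String, Object> props = new HashMap<>();
                props.put("server.port", port); // must be registered before the server starts
                env.getPropertySources().addFirst(new MapPropertySource("customProps", props));
            }).run(args);
        }
    }

Property sources added with addFirst take precedence, so this value wins over application.properties defaults.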
========================================
@@ -29,10 +29,12 @@ public class QuantizeCommand extends SimpleBaseCommand {
@CommandLine.Parameters(index = "1", arity = "0..1", description = "The output location")
protected Path output;

- @CommandLine.Option(names = { "--quantization" }, paramLabel = "ARG", description = "Model quantization type (default: ${DEFAULT-VALUE})", arity = "1", defaultValue = "Q4")
+ @CommandLine.Option(names = {
+     "--quantization" }, paramLabel = "ARG", description = "Model quantization type (default: ${DEFAULT-VALUE})", arity = "1", defaultValue = "Q4")
protected DType modelQuantization = DType.Q4;

- @CommandLine.Option(names = { "--skip-layer" }, paramLabel = "ARG", description = "Layer name prefix to not quantize (default: ${DEFAULT-VALUE})", defaultValue = "norm")
+ @CommandLine.Option(names = {
+     "--skip-layer" }, paramLabel = "ARG", description = "Layer name prefix to not quantize (default: ${DEFAULT-VALUE})", defaultValue = "norm")
protected String[] skipLayerPrefixes;

@CommandLine.Option(names = { "--drop-layer" }, paramLabel = "ARG", description = "Layer name prefix to drop")
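For reference, the wrapped annotations above are picocli options. A small self-contained sketch of the same style, an enum-typed option whose default is advertised through ${DEFAULT-VALUE} (the demo command, Size enum, and field names are invented):

    import picocli.CommandLine;
    import picocli.CommandLine.Command;
    import picocli.CommandLine.Option;

    @Command(name = "demo", mixinStandardHelpOptions = true)
    public class DemoCommand implements Runnable {
        enum Size { SMALL, LARGE }

        @Option(names = { "--size" }, paramLabel = "ARG", description = "Size (default: ${DEFAULT-VALUE})", arity = "1", defaultValue = "SMALL")
        Size size;

        public void run() {
            System.out.println("size=" + size); // picocli converts the argument string to the enum
        }

        public static void main(String[] args) {
            System.exit(new CommandLine(new DemoCommand()).execute(args));
        }
    }

Running it with --help renders the description with SMALL substituted for ${DEFAULT-VALUE}.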
========================================
@@ -26,7 +26,6 @@

  import com.github.tjake.jlama.safetensors.SafeTensorSupport;
  import com.github.tjake.jlama.util.ProgressReporter;
- import com.github.tjake.jlama.util.TriConsumer;
  import com.google.common.util.concurrent.Uninterruptibles;
  import me.tongfei.progressbar.ProgressBar;
  import me.tongfei.progressbar.ProgressBarBuilder;
@@ -82,7 +81,9 @@ static Optional<ProgressReporter> getProgressConsumer() {

return Optional.of((ProgressReporter) (filename, sizeDownloaded, totalSize) -> {
if (progressRef.get() == null || !progressRef.get().getTaskName().equals(filename)) {
-             ProgressBarBuilder builder = new ProgressBarBuilder().setTaskName(filename).setInitialMax(totalSize).setStyle(ProgressBarStyle.ASCII);
+             ProgressBarBuilder builder = new ProgressBarBuilder().setTaskName(filename)
+                 .setInitialMax(totalSize)
+                 .setStyle(ProgressBarStyle.ASCII);

if (totalSize > 1000000) {
builder.setUnit("MB", 1000000);
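The chain being reformatted is the builder API from me.tongfei.progressbar. A runnable sketch of the same calls, including the MB unit switch from the surrounding code (the file name, total size, and simulated download loop are invented):

    import me.tongfei.progressbar.ProgressBar;
    import me.tongfei.progressbar.ProgressBarBuilder;
    import me.tongfei.progressbar.ProgressBarStyle;

    public class ProgressDemo {
        public static void main(String[] args) throws InterruptedException {
            long totalSize = 5_000_000L; // pretend file size in bytes
            ProgressBarBuilder builder = new ProgressBarBuilder().setTaskName("model.safetensors")
                .setInitialMax(totalSize)
                .setStyle(ProgressBarStyle.ASCII);

            if (totalSize > 1_000_000) {
                builder.setUnit("MB", 1_000_000); // display progress in megabytes
            }

            try (ProgressBar bar = builder.build()) {
                for (long done = 0; done <= totalSize; done += 500_000) {
                    bar.stepTo(done); // absolute progress, mirroring sizeDownloaded above
                    Thread.sleep(50);
                }
            }
        }
    }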
========================================
@@ -29,7 +29,9 @@ public enum Type {
public static float eval(Type t, float x) {
return switch (t) {
case SILU -> (float) (x * (1.0f / (1.0f + FastMath.exp(-x))));
-         case GELU, GELU_PYTORCH_TANH -> (float) (0.5 * x * (1 + FastMath.tanh(FastMath.sqrt(2 / Math.PI) * (x + 0.044715 * FastMath.pow(x, 3)))));
+         case GELU, GELU_PYTORCH_TANH -> (float) (0.5 * x * (1 + FastMath.tanh(
+             FastMath.sqrt(2 / Math.PI) * (x + 0.044715 * FastMath.pow(x, 3))
+         )));
case TANH -> (float) FastMath.tanh(x);
};
}
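The reformatted case computes the tanh approximation of GELU: gelu(x) ≈ 0.5 · x · (1 + tanh(sqrt(2/π) · (x + 0.044715 · x³))). A tiny sanity check of that formula, using java.lang.Math where the code above uses FastMath:

    public class GeluCheck {
        static float gelu(float x) {
            return (float) (0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * Math.pow(x, 3)))));
        }

        public static void main(String[] args) {
            System.out.println(gelu(0.0f)); // 0.0: GELU is zero at the origin
            System.out.println(gelu(1.0f)); // ≈ 0.8412
        }
    }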
========================================
@@ -39,7 +39,7 @@ public static void pchunk(int offset, int length, BiIntConsumer action) {
int splits = Math.min(length, TensorOperationsProvider.get().parallelSplitSize());
int chunkSize = length / splits;
int remainder = 0;

// Non optimal case, just run in parallel
if (splits == 1) {
splits = length;
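For context, pchunk splits the range [offset, offset + length) into chunks and hands each (start, size) pair to a BiIntConsumer; the visible branch degenerates to one element per task when there are fewer elements than available splits. A hedged sketch of that splitting idea, not jlama's actual implementation (the caller-supplied split count and parallel-stream dispatch are assumptions):

    import java.util.stream.IntStream;

    public class ChunkDemo {
        interface BiIntConsumer { void accept(int a, int b); }

        static void pchunk(int offset, int length, int splits, BiIntConsumer action) {
            int chunkSize = length / splits;
            int remainder = length % splits; // the last chunk absorbs any remainder
            IntStream.range(0, splits).parallel().forEach(i -> {
                int start = offset + i * chunkSize;
                int size = (i == splits - 1) ? chunkSize + remainder : chunkSize;
                action.accept(start, size);
            });
        }

        public static void main(String[] args) {
            // Splits 0..9 into (0,3), (3,3), (6,4); output order may vary under parallelism.
            pchunk(0, 10, 3, (start, size) -> System.out.println("start=" + start + " size=" + size));
        }
    }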
========================================
@@ -33,7 +33,6 @@
  import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
  import com.github.tjake.jlama.util.DebugSupport;
  import com.github.tjake.jlama.util.JsonSupport;
- import com.github.tjake.jlama.util.MachineSpec;
  import com.google.common.base.Preconditions;
  import com.google.common.primitives.Ints;

@@ -125,16 +124,18 @@ protected AbstractModel(

      // Check to make sure the model is big enough to support Q4I8 computations
      // If not, fall back to F32
-     if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 &&
-         ((c.embeddingLength/Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize()/Float.SIZE) != 0 ||
-         (c.hiddenLength/Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize()/Float.SIZE) != 0)){
+     if (modelDType == DType.Q4
+         && workingMemoryQType == DType.I8
+         && ((c.embeddingLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0
+             || (c.hiddenLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0)) {
          workingMemoryQType = DType.F32;
      }

      // Check to make sure the model is big enough to support Q4I8 computations
      // If not, fall back to F32
-     if (modelDType == DType.Q4 && workingMemoryQType == DType.I8 &&
-         (c.embeddingLength/Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize()/Float.SIZE) != 0){
+     if (modelDType == DType.Q4
+         && workingMemoryQType == DType.I8
+         && (c.embeddingLength / Q8ByteBufferTensor.BLOCK_SIZE) % (FloatVector.SPECIES_PREFERRED.vectorBitSize() / Float.SIZE) != 0) {
          workingMemoryQType = DType.F32;
      }

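The condition above falls back to F32 working memory when the per-row count of quantization blocks does not divide evenly by the CPU's preferred float lane count, presumably because the vectorized Q4/I8 kernels consume whole vectors of blocks at a time. An illustrative version of the arithmetic (the block size of 32 and the 256-bit vector width are assumed values, not read from the code):

    public class FallbackCheck {
        public static void main(String[] args) {
            int blockSize = 32;                // stand-in for Q8ByteBufferTensor.BLOCK_SIZE
            int floatLanes = 256 / Float.SIZE; // a 256-bit vector holds 8 floats
            int embeddingLength = 4096;

            boolean fallback = (embeddingLength / blockSize) % floatLanes != 0;
            System.out.println(fallback ? "fall back to F32" : "Q4 + I8 is safe"); // safe: 128 % 8 == 0
        }
    }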
========================================
@@ -42,94 +42,103 @@ public class TransformerBlock {
final Optional<LayerNorm> preResponseNorm; // After the residual connection

  public TransformerBlock(
-     AbstractModel model,
-     int layerIndex,
-     LayerNorm preAttentionNorm,
-     CausalSelfAttention attention,
-     LayerNorm postAttentionNorm,
-     FeedForward ffBlock) {
+     AbstractModel model,
+     int layerIndex,
+     LayerNorm preAttentionNorm,
+     CausalSelfAttention attention,
+     LayerNorm postAttentionNorm,
+     FeedForward ffBlock
+ ) {
      this(
-         model,
-         layerIndex,
-         Optional.of(preAttentionNorm),
-         attention,
-         Optional.empty(),
-         Optional.of(postAttentionNorm),
-         ffBlock,
-         Optional.empty(),
-         Optional.empty());
+         model,
+         layerIndex,
+         Optional.of(preAttentionNorm),
+         attention,
+         Optional.empty(),
+         Optional.of(postAttentionNorm),
+         ffBlock,
+         Optional.empty(),
+         Optional.empty()
+     );
  }

  public TransformerBlock(
-     AbstractModel model,
-     int layerIndex,
-     CausalSelfAttention attention,
-     LayerNorm postAttentionNorm,
-     FeedForward ffBlock,
-     LayerNorm postFFNorm) {
+     AbstractModel model,
+     int layerIndex,
+     CausalSelfAttention attention,
+     LayerNorm postAttentionNorm,
+     FeedForward ffBlock,
+     LayerNorm postFFNorm
+ ) {
      this(
-         model,
-         layerIndex,
-         Optional.empty(),
-         attention,
-         Optional.empty(),
-         Optional.of(postAttentionNorm),
-         ffBlock,
-         Optional.empty(),
-         Optional.of(postFFNorm));
+         model,
+         layerIndex,
+         Optional.empty(),
+         attention,
+         Optional.empty(),
+         Optional.of(postAttentionNorm),
+         ffBlock,
+         Optional.empty(),
+         Optional.of(postFFNorm)
+     );
  }

  public TransformerBlock(
-     AbstractModel model,
-     int layerIndex,
-     LayerNorm preAttentionNorm,
-     CausalSelfAttention attention,
-     LayerNorm postAttentionNorm,
-     FeedForward ffBlock,
-     LayerNorm postFFNorm) {
+     AbstractModel model,
+     int layerIndex,
+     LayerNorm preAttentionNorm,
+     CausalSelfAttention attention,
+     LayerNorm postAttentionNorm,
+     FeedForward ffBlock,
+     LayerNorm postFFNorm
+ ) {
      this(
-         model,
-         layerIndex,
-         Optional.of(preAttentionNorm),
-         attention,
-         Optional.empty(),
-         Optional.of(postAttentionNorm),
-         ffBlock,
-         Optional.empty(),
-         Optional.of(postFFNorm));
+         model,
+         layerIndex,
+         Optional.of(preAttentionNorm),
+         attention,
+         Optional.empty(),
+         Optional.of(postAttentionNorm),
+         ffBlock,
+         Optional.empty(),
+         Optional.of(postFFNorm)
+     );
  }

  public TransformerBlock(
-     AbstractModel model,
-     int layerIndex,
-     LayerNorm preAttentionNorm,
-     CausalSelfAttention attention,
-     LayerNorm postAttentionNorm,
-     LayerNorm preFFNorm,
-     FeedForward ffBlock,
-     LayerNorm postFFNorm) {
+     AbstractModel model,
+     int layerIndex,
+     LayerNorm preAttentionNorm,
+     CausalSelfAttention attention,
+     LayerNorm postAttentionNorm,
+     LayerNorm preFFNorm,
+     FeedForward ffBlock,
+     LayerNorm postFFNorm
+ ) {
      this(
-         model,
-         layerIndex,
-         Optional.of(preAttentionNorm),
-         attention,
-         Optional.of(postAttentionNorm),
-         Optional.of(preFFNorm),
-         ffBlock,
-         Optional.of(postFFNorm),
-         Optional.empty());
+         model,
+         layerIndex,
+         Optional.of(preAttentionNorm),
+         attention,
+         Optional.of(postAttentionNorm),
+         Optional.of(preFFNorm),
+         ffBlock,
+         Optional.of(postFFNorm),
+         Optional.empty()
+     );
  }

  protected TransformerBlock(
-     AbstractModel model,
-     int layerIndex,
-     Optional<LayerNorm> preAttentionNorm,
-     CausalSelfAttention attention,
-     Optional<LayerNorm> postAttentionNorm,
-     Optional<LayerNorm> preFFNorm,
-     FeedForward ffBlock,
-     Optional<LayerNorm> postFFNorm,
-     Optional<LayerNorm> preResponseNorm) {
+     AbstractModel model,
+     int layerIndex,
+     Optional<LayerNorm> preAttentionNorm,
+     CausalSelfAttention attention,
+     Optional<LayerNorm> postAttentionNorm,
+     Optional<LayerNorm> preFFNorm,
+     FeedForward ffBlock,
+     Optional<LayerNorm> postFFNorm,
+     Optional<LayerNorm> preResponseNorm
+ ) {

this.model = model;
this.layerIndex = layerIndex;
@@ -147,10 +156,11 @@ public AbstractTensor forward(AbstractTensor embedding, int position, KvBufferCa
}

  public AbstractTensor forward(
-     AbstractTensor embedding,
-     int position,
-     KvBufferCache.KvBuffer kvBuffer,
-     Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
+     AbstractTensor embedding,
+     int position,
+     KvBufferCache.KvBuffer kvBuffer,
+     Optional<Consumer<List<AbstractTensor>>> tensorReducer
+ ) {

debug("input_emb", embedding, layerIndex);

@@ -190,18 +200,19 @@ public AbstractTensor forward(

      // Release any tmp buffers (embedding is released by caller)
      if (lnemb != embedding) lnemb.close();
-     if (lnattn != postAttention) lnattn.close(); else postAttention.close();
-     if (lnpreFF != lnattn) lnpreFF.close(); else lnattn.close();
+     if (lnattn != postAttention) lnattn.close();
+     else postAttention.close();
+     if (lnpreFF != lnattn) lnpreFF.close();
+     else lnattn.close();

return maybeApplyNorm(lnpostFF, preResponseNorm);
}

private AbstractTensor maybeApplyNorm(AbstractTensor tensor, Optional<LayerNorm> norm) {
return norm.map(ln -> {
          AbstractTensor o = ln.forward(tensor);
          tensor.close();
          return o;
      }).orElse(tensor);
}
}
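maybeApplyNorm is a small ownership idiom: apply an optional transform and close the intermediate tensor only when a new one was produced, so nothing is double-closed and nothing leaks. A generic, hedged sketch of the same idea, with all names invented:

    import java.util.Optional;
    import java.util.function.UnaryOperator;

    public class MaybeApply {

        // Toy stand-in for AbstractTensor, just to exercise both branches.
        static class Buf implements AutoCloseable {
            final String tag;
            Buf(String tag) { this.tag = tag; }
            @Override public void close() { System.out.println("closed " + tag); }
        }

        static <T extends AutoCloseable> T maybeApply(T in, Optional<UnaryOperator<T>> op) {
            return op.map(fn -> {
                T out = fn.apply(in);
                try {
                    in.close(); // a new value was produced, so release the input
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
                return out;
            }).orElse(in); // no transform: hand the original back untouched
        }

        public static void main(String[] args) {
            Buf normed = maybeApply(new Buf("emb"), Optional.<UnaryOperator<Buf>>of(b -> new Buf(b.tag + "+norm")));
            System.out.println("got " + normed.tag); // prints "closed emb" then "got emb+norm"
            Buf same = maybeApply(normed, Optional.empty());
            System.out.println("got " + same.tag);   // nothing closed; the same value comes back
        }
    }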

========================================
@@ -15,7 +15,6 @@
*/
package com.github.tjake.jlama.model.gemma;

- import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.model.*;
import com.github.tjake.jlama.model.functions.EmbedInput;
========================================
@@ -56,8 +56,9 @@ public Gemma2Config(
activationFunction,
ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor")),
-     headDim,
-     finalLogitSoftCapping, attnLogitSoftCapping
+     headDim,
+     finalLogitSoftCapping,
+     attnLogitSoftCapping
);
}
}
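The constructor above normalizes optional config.json fields: a null ropeFreqsTheta becomes 10000.0 and a missing rope-scaling map becomes a factor of 1.0. A minimal sketch of that defaulting pattern (the field values here are invented):

    import java.util.Map;

    public class ConfigDefaults {
        public static void main(String[] args) {
            Double ropeFreqsTheta = null;                       // field absent from config.json
            Map<String, String> ropeScaling = Map.of("factor", "8.0");

            double theta = ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta;
            double scale = ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor"));

            System.out.println("theta=" + theta + " scale=" + scale); // theta=10000.0 scale=8.0
        }
    }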
========================================
@@ -15,7 +15,6 @@
*/
package com.github.tjake.jlama.model.gemma2;

- import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.model.*;
import com.github.tjake.jlama.model.functions.EmbedInput;