diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..6313b56
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+* text=auto eol=lf
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..0b527df
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,168 @@
+name: ci
+# Least-privilege token: every scope granted to GITHUB_TOKEN below is read-only.
+permissions:
+ contents: read
+ pull-requests: read
+ actions: read
+# VERSION is the short name of the branch or tag that triggered the run.
+env:
+ VERSION: "${{ github.ref_name }}"
+# Triggers: manual dispatch, pushes to main or v*.*.* tags, and PRs targeting main; docs/markdown/image-only changes are skipped.
+on:
+ workflow_dispatch: { }
+ push:
+ tags:
+ - "v*.*.*"
+ branches:
+ - main
+ paths-ignore:
+ - "docs/**"
+ - "**.md"
+ - "**.mdx"
+ - "**.png"
+ - "**.jpg"
+ pull_request:
+ branches:
+ - main
+ paths-ignore:
+ - "docs/**"
+ - "**.md"
+ - "**.mdx"
+ - "**.png"
+ - "**.jpg"
+# github.head_ref is only set for pull-request events: PR runs share one group per ref, every other event gets a group keyed by its unique run id.
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
+ cancel-in-progress: true  # a newer run in the same group cancels the one in flight
+
+jobs:
+ darwin-metal:  # macOS build with LLAMA_METAL=1; uploads the binary as an artifact on tag refs
+ if: ${{ false }}  # constant false: this job is currently switched off
+ strategy:
+ fail-fast: false
+ matrix:
+ # see https://github.com/actions/runner-images?tab=readme-ov-file#available-images.
+ os: [ macos-13, macos-14 ]  # the Release step names the '-13' image amd64 and any other image arm64
+ runs-on: ${{ matrix.os }}
+ steps:
+ - name: Setup XCode
+ if: ${{ matrix.os == 'macos-13' }}  # Xcode is pinned on the macos-13 image only
+ uses: maxim-lobanov/setup-xcode@v1
+ with:
+ xcode-version: '15.2'
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ submodules: 'recursive'  # .gitmodules registers llama.cpp as a submodule
+ - name: Deps
+ continue-on-error: true  # a failing `brew update` does not fail the job
+ run: |
+ brew update
+ - name: Build
+ run: |
+ make -j LLAMA_METAL=1
+
+ echo "===== info ====="
+ file ./.dist/llama-box
+ - name: Release
+ if: ${{ startsWith(github.ref, 'refs/tags/') }}  # only tag builds publish the artifact
+ uses: actions/upload-artifact@v4
+ with:
+ path: ./.dist/llama-box
+ name: llama-box-darwin-${{ endsWith(matrix.os, '-13') && 'amd64' || 'arm64' }}-metal
+
+ linux-hip:  # Linux ROCm/hipBLAS build, compiled inside the rocm/dev-ubuntu-22.04 container
+   strategy:
+     fail-fast: false
+     matrix:
+       arch: [ amd64 ]
+       version: [ '6.0.2' ]  # tag of the rocm/dev-ubuntu-22.04 image
+   runs-on: ubuntu-22.04
+   steps:
+     - name: Maximize Space  # drop preinstalled SDKs and cached images before pulling the toolchain image
+       run: |
+         sudo rm -rf /usr/share/dotnet
+         sudo rm -rf /usr/local/lib/android
+         sudo rm -rf /opt/ghc
+         sudo rm -rf /opt/hostedtoolcache/CodeQL
+         sudo docker image prune --all --force
+     - name: Clone
+       uses: actions/checkout@v4
+       with:
+         fetch-depth: 0
+         submodules: 'recursive'  # .gitmodules registers llama.cpp as a submodule
+     - name: Setup QEMU
+       uses: docker/setup-qemu-action@v3
+       with:
+         image: tonistiigi/binfmt:qemu-v7.0.0
+         platforms: "arm64"
+     - name: Build  # the whole install+build chain goes through `bash -c` so it runs in the container, not on the runner
+       run: |
+         docker run \
+           --rm \
+           --privileged \
+           --platform linux/${{ matrix.arch }} \
+           --volume $(pwd):/workspace \
+           --workdir /workspace \
+           --env CC=/opt/rocm/llvm/bin/clang \
+           --env CXX=/opt/rocm/llvm/bin/clang++ \
+           --env GPU_TARGETS="gfx803 gfx900 gfx906 gfx908 gfx90a gfx1010 gfx1030 gfx1100 gfx1101 gfx1102" \
+           rocm/dev-ubuntu-22.04:${{ matrix.version }} \
+           bash -c "apt-get update && apt-get install -y build-essential git rocblas-dev hipblas-dev && make -j LLAMA_HIPBLAS=1"
+
+         echo "===== info ====="
+         file ./.dist/llama-box
+     - name: Release
+       if: ${{ startsWith(github.ref, 'refs/tags/') }}  # only tag builds publish the artifact
+       uses: actions/upload-artifact@v4
+       with:
+         path: ./.dist/llama-box
+         name: llama-box-linux-${{ matrix.arch }}-hip
+
+ linux-cuda:  # Linux CUDA build, compiled inside the nvidia/cuda devel container
+   if: ${{ false }}  # constant false: this job is currently switched off
+   strategy:
+     fail-fast: false
+     matrix:
+       arch: [ amd64, arm64 ]  # arm64 is run through the QEMU binfmt handler registered below
+       version: [ '11.7.1' ]  # CUDA version part of the nvidia/cuda image tag
+   runs-on: ubuntu-22.04
+   steps:
+     - name: Maximize Space  # drop preinstalled SDKs and cached images before pulling the toolchain image
+       run: |
+         sudo rm -rf /usr/share/dotnet
+         sudo rm -rf /usr/local/lib/android
+         sudo rm -rf /opt/ghc
+         sudo rm -rf /opt/hostedtoolcache/CodeQL
+         sudo docker image prune --all --force
+     - name: Clone
+       uses: actions/checkout@v4
+       with:
+         fetch-depth: 0
+         submodules: 'recursive'  # .gitmodules registers llama.cpp as a submodule
+     - name: Setup QEMU
+       uses: docker/setup-qemu-action@v3
+       with:
+         image: tonistiigi/binfmt:qemu-v7.0.0
+         platforms: "arm64"
+     - name: Build  # the whole install+build chain goes through `bash -c` so it runs in the container, not on the runner
+       run: |
+         docker run \
+           --rm \
+           --privileged \
+           --platform linux/${{ matrix.arch }} \
+           --volume $(pwd):/workspace \
+           --workdir /workspace \
+           --env CUDA_DOCKER_ARCH=all \
+           nvidia/cuda:${{ matrix.version }}-devel-ubuntu22.04 \
+           bash -c "apt-get update && apt-get install -y build-essential git && make -j LLAMA_CUDA=1"
+
+         echo "===== info ====="
+         file ./.dist/llama-box
+     - name: Release
+       if: ${{ startsWith(github.ref, 'refs/tags/') }}  # only tag builds publish the artifact
+       uses: actions/upload-artifact@v4
+       with:
+         path: ./.dist/llama-box
+         name: llama-box-linux-${{ matrix.arch }}-cuda
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..76bb781
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,34 @@
+# Files
+.DS_Store
+*.o
+*.a
+*.so
+*.gguf
+*.bin
+*.exe
+*.exe~
+*.dll
+*.dylib
+*.log
+*.dot
+*.bat
+*.tmp
+*.metallib
+*.out
+*.swp
+*.swo
+.clang-tidy
+version.cpp
+
+# Directories
+.idea/
+.vscode/
+.vs/
+.build/
+.cache/
+.ccls-cache/
+.direnv/
+.sbin/
+.dist/
+out/
+tmp/
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..81bc6f3
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "llama.cpp"]
+ path = llama.cpp
+ url = https://github.com/ggerganov/llama.cpp
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..c059907
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 The llama-box authors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..9cec98f
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,275 @@
+# Inspired by https://github.com/ggerganov/llama.cpp/blob/61665277afde2add00c0d387acb94ed5feb95917/Makefile.
+
+.SILENT:
+.DEFAULT_GOAL := build
+# Recipes are executed with bash.
+SHELL := /bin/bash
+# MK_DIR: absolute directory of this Makefile, trailing slash removed. MK_FLAGS: words 3..N of MAKEFLAGS, forwarded to the llama.cpp sub-make.
+MK_DIR := $(patsubst %/,%,$(dir $(abspath $(lastword $(MAKEFILE_LIST)))))
+MK_FLAGS:= $(wordlist 3, $(words $(MAKEFLAGS)), $(MAKEFLAGS))
+
+#
+# System flags
+# Detect the host OS / CPU with uname unless UNAME_S, UNAME_P or UNAME_M are already set.
+
+ifndef UNAME_S
+ UNAME_S := $(shell uname -s)
+endif
+ifndef UNAME_P
+ UNAME_P := $(shell uname -p)
+endif
+ifndef UNAME_M
+ UNAME_M := $(shell uname -m)
+endif
+# Replace make's built-in CC / CXX defaults with cc / c++; values from any other origin are left alone.
+ifeq ($(origin CC),default)
+ CC := cc
+endif
+ifeq ($(origin CXX),default)
+ CXX := c++
+endif
+# Prefix both compilers with ccache when `which ccache` finds it, unless LLAMA_NO_CCACHE is defined.
+ifndef LLAMA_NO_CCACHE
+ CCACHE := $(shell which ccache)
+ ifdef CCACHE
+ export CCACHE_SLOPPINESS = time_macros
+ CC := $(CCACHE) $(CC)
+ CXX := $(CCACHE) $(CXX)
+ endif
+endif
+
+## Mac OS + Arm can report x86_64
+## ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
+ifeq ($(UNAME_S),Darwin)
+ ifndef LLAMA_NO_METAL # Metal is enabled by default on macOS
+ LLAMA_METAL := 1
+ endif
+ ifneq ($(UNAME_P),arm)
+ SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
+ ifeq ($(SYSCTL_M),1) # sysctl reports arm64 hardware: only warn, the uname values are not rewritten
+ # UNAME_P := arm
+ # UNAME_M := arm64
+ warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
+ endif
+ endif
+endif
+
+#
+# Compile flags
+# The MK_* base sets are assembled here and merged with command-line overrides further down.
+
+## standard: include paths plus the C11 / C++11 language levels
+MK_CPPFLAGS = -I$(MK_DIR) -I$(MK_DIR)/llama.cpp -I$(MK_DIR)/llama.cpp/common
+MK_CFLAGS = -std=c11 -fPIC
+MK_CXXFLAGS = -std=c++11 -fPIC
+
+## debug or optimization: LLAMA_DEBUG gives -O0 -g, otherwise NDEBUG with -O3 (-Ofast when LLAMA_FAST is set)
+ifdef LLAMA_DEBUG
+  MK_CFLAGS += -O0 -g
+  MK_CXXFLAGS += -O0 -g
+  MK_LDFLAGS += -g
+  ifeq ($(UNAME_S),Linux) # _GLIBCXX_ASSERTIONS is a libstdc++ macro, so it belongs to the Linux toolchain, not Darwin
+    MK_CPPFLAGS += -D_GLIBCXX_ASSERTIONS
+  endif
+else
+  MK_CPPFLAGS += -DNDEBUG
+  ifdef LLAMA_FAST
+    MK_CFLAGS += -Ofast
+    MK_CXXFLAGS += -Ofast
+  else
+    MK_CFLAGS += -O3
+    MK_CXXFLAGS += -O3
+  endif
+endif
+
+## warning: shared warning set, promoted to errors when LLAMA_FATAL_WARNINGS is defined
+MK_CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function \
+  -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int \
+  -Werror=implicit-function-declaration
+MK_CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function \
+  -Wmissing-declarations -Wmissing-noreturn
+ifdef LLAMA_FATAL_WARNINGS
+  MK_CFLAGS += -Werror
+  MK_CXXFLAGS += -Werror
+endif
+
+## os specific
+### thread: POSIX-like systems compile and link with -pthread
+ifneq '' '$(filter $(UNAME_S),Linux Darwin FreeBSD NetBSD OpenBSD Haiku)'
+  MK_CFLAGS += -pthread
+  MK_CXXFLAGS += -pthread
+endif
+### windows: *_NT uname values mark a Windows host; the build rule reads _WIN32 and LWINSOCK2
+ifneq ($(findstring _NT,$(UNAME_S)),)
+  _WIN32 := 1
+  LWINSOCK2 := -lws2_32
+endif
+
+## arch specific: every branch sets the C and the C++ flags together
+ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
+  # Use all CPU extensions that are available, for both C and C++:
+  MK_CFLAGS += -march=native -mtune=native
+  MK_CXXFLAGS += -march=native -mtune=native
+
+  # Use AVX only
+  #MK_CFLAGS += -mfma -mf16c -mavx
+  #MK_CXXFLAGS += -mfma -mf16c -mavx
+
+  # Use SSSE3 only (not the same as SSE3!)
+  #MK_CFLAGS += -mssse3
+  #MK_CXXFLAGS += -mssse3
+endif
+ifneq '' '$(findstring mingw,$(shell $(CC) -dumpmachine))'
+  # The stack is only 16-byte aligned on Windows, so don't let gcc emit aligned moves.
+  # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=54412
+  # https://github.com/ggerganov/llama.cpp/issues/2922
+  MK_CFLAGS += -Xassembler -muse-unaligned-vector-move
+  MK_CXXFLAGS += -Xassembler -muse-unaligned-vector-move
+
+  # Target Windows 8 for PrefetchVirtualMemory
+  MK_CPPFLAGS += -D_WIN32_WINNT=0x602
+endif
+ifneq ($(filter aarch64%,$(UNAME_M)),)
+  # Apple M1, M2, etc.
+  # Raspberry Pi 3, 4, Zero 2 (64-bit)
+  # Nvidia Jetson
+  MK_CFLAGS += -mcpu=native
+  MK_CXXFLAGS += -mcpu=native
+  JETSON_RELEASE_INFO = $(shell jetson_release)
+  ifdef JETSON_RELEASE_INFO
+    ifneq ($(filter TX2%,$(JETSON_RELEASE_INFO)),) # Jetson TX2: switch both compilers to the aarch64 GNU toolchain
+      JETSON_EOL_MODULE_DETECT = 1
+      CC = aarch64-unknown-linux-gnu-gcc
+      CXX = aarch64-unknown-linux-gnu-g++
+    endif
+  endif
+endif
+ifneq ($(filter armv6%,$(UNAME_M)),)
+  # Raspberry Pi 1, Zero
+  MK_CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
+  MK_CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
+endif
+ifneq ($(filter armv7%,$(UNAME_M)),)
+  # Raspberry Pi 2
+  MK_CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
+  MK_CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
+endif
+ifneq ($(filter armv8%,$(UNAME_M)),)
+  # Raspberry Pi 3, 4, Zero 2 (32-bit)
+  MK_CFLAGS += -mfp16-format=ieee -mno-unaligned-access
+  MK_CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access
+endif
+ifneq ($(filter ppc64%,$(UNAME_M)),)
+  POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
+  ifneq (,$(findstring POWER9,$(POWER9_M))) # tune for POWER9 only when /proc/cpuinfo reports it
+    MK_CFLAGS += -mcpu=power9
+    MK_CXXFLAGS += -mcpu=power9
+  endif
+endif
+ifneq ($(filter ppc64le%,$(UNAME_M)),)
+  MK_CFLAGS += -mcpu=powerpc64le
+  MK_CXXFLAGS += -mcpu=powerpc64le
+  CUDA_POWER_ARCH = 1
+endif
+ifneq ($(filter loongarch64%,$(UNAME_M)),)
+  MK_CFLAGS += -mlasx
+  MK_CXXFLAGS += -mlasx
+endif
+ifneq ($(filter riscv64%,$(UNAME_M)),)
+  MK_CFLAGS += -march=rv64gcv -mabi=lp64d
+  MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
+endif
+
+## platform specific: extra link flags for the backend selected through LLAMA_METAL / LLAMA_HIPBLAS
+ifdef LLAMA_METAL
+ MK_LDFLAGS += -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
+endif
+ifdef LLAMA_HIPBLAS
+ ifeq ($(wildcard /opt/rocm),) # no /opt/rocm on this host: ROCM_PATH defaults to /usr
+ ROCM_PATH ?= /usr
+ else
+ ROCM_PATH ?= /opt/rocm
+ endif
+ MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
+ MK_LDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64
+ MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
+endif
+
+## combine build flags with cmdline overrides: MK_* defaults come first for CFLAGS/CXXFLAGS/LDFLAGS, caller-supplied values follow; CPPFLAGS is folded into both compile sets
+override CPPFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS)
+override CFLAGS := $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
+override CXXFLAGS := $(MK_CXXFLAGS) $(CXXFLAGS) $(CPPFLAGS)
+override LDFLAGS := $(MK_LDFLAGS) $(LDFLAGS)
+
+#
+# Helper function
+# Multi-line macros that the clean and build rules expand with $(call ...).
+
+## BUILD_INFO prints out the build info: uname values, the final flag sets and the first line of each compiler's --version output
+define BUILD_INFO
+ @echo "I llama-box build info:"
+ @echo "I UNAME_S: $(UNAME_S)"
+ @echo "I UNAME_P: $(UNAME_P)"
+ @echo "I UNAME_M: $(UNAME_M)"
+ @echo "I CFLAGS: $(CFLAGS)"
+ @echo "I CXXFLAGS: $(CXXFLAGS)"
+ @echo "I LDFLAGS: $(LDFLAGS)"
+ @echo "I CC: $(shell $(CC) --version | head -n 1)"
+ @echo "I CXX: $(shell $(CXX) --version | head -n 1)"
+ @echo
+endef
+
+## GET_OBJ_FILE replaces .c, .cpp, and .cu file endings with .o in the name list passed as its first argument
+define GET_OBJ_FILE
+ $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(patsubst %.cu,%.o,$(1))))
+endef
+
+#
+# Main function
+# version.cpp: scripts/version.sh writes a .tmp file that replaces the target only when the content differs (cmp -s), otherwise the .tmp is removed.
+
+llama-box/version.cpp: $(wildcard .git/index) llama-box/scripts/version.sh
+ @sh $(MK_DIR)/llama-box/scripts/version.sh > $@.tmp
+ @if ! cmp -s $@ $@.tmp; then mv $@.tmp $@; else rm $@.tmp; fi
+
+.PHONY: clean # runs the llama.cpp clean target, then deletes llama-box *.o files and the generated version.cpp
+clean:
+ # first, clean llama-server.
+ @echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ @echo "I cleaning llama.cpp"
+ @echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ifdef LLAMA_METAL # hand the sub-make the same Metal switches that the build target passes
+ make -C $(MK_DIR)/llama.cpp -j $(MK_FLAGS) LLAMA_METAL=1 LLAMA_METAL_EMBED_LIBRARY=1 clean
+else
+ make -C $(MK_DIR)/llama.cpp -j $(MK_FLAGS) LLAMA_NO_METAL=1 clean
+endif
+ # then, clean llama-box.
+ @echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ @echo "I cleaning llama-box"
+ @echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ $(call BUILD_INFO)
+ find $(MK_DIR)/llama-box -type f -name "*.o" -delete
+ rm -f $(MK_DIR)/llama-box/version.cpp
+
+ifeq ($(_WIN32),1) # defined ahead of the rule: a variable assignment between recipe lines would terminate the recipe
+  SUFFIX := .exe
+endif
+.PHONY: build # builds llama.cpp/libllama.a through its own Makefile, then compiles and links .dist/llama-box
+build: llama-box/version.cpp llama-box/main.cpp
+	# first, build llama.cpp as a static library.
+	@echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+	@echo "I building llama.cpp"
+	@echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ifdef LLAMA_METAL
+	make -C $(MK_DIR)/llama.cpp -j $(MK_FLAGS) LLAMA_METAL=1 LLAMA_METAL_EMBED_LIBRARY=1 libllama.a
+else
+	make -C $(MK_DIR)/llama.cpp -j $(MK_FLAGS) LLAMA_NO_METAL=1 libllama.a
+endif
+	# then, build llama-box: compile version.cpp, then link main.cpp and version.o with libllama.a.
+	# libllama.a comes after the objects because the linker scans an archive only at its position on the command line.
+	@echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+	@echo "I building llama-box"
+	@echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+	$(call BUILD_INFO)
+	mkdir -p $(MK_DIR)/.dist
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) && $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) $(call GET_OBJ_FILE, $<) $(MK_DIR)/llama.cpp/libllama.a -o $(MK_DIR)/.dist/llama-box$(SUFFIX) $(LDFLAGS) $(LWINSOCK2)
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..431760a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,197 @@
+# llama-box
+
+[![ci status](https://github.com/thxcode/llama-box/actions/workflows/ci.yml/badge.svg)](https://github.com/thxcode/llama-box/actions/workflows/ci.yml)
+
+LLaMA box is a clean LLM inference server, offered as an alternative
+to [llama-server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server).
+
+## Usage
+
+```shell
+usage: llama-box [options]
+
+general:
+
+ -h, --help, --usage print usage and exit
+ --version show version and build info
+ -m, --model FILE model path (default: models/7B/ggml-model-f16.gguf)
+ -a, --alias NAME model name alias (default: unknown)
+ -s, --seed N RNG seed (default: -1, use random seed for < 0)
+ -t, --threads N number of threads to use during generation (default: 8)
+ -tb, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)
+ -lcs, --lookup-cache-static FILE
+ path to static lookup cache to use for lookup decoding (not updated by generation)
+ -lcd, --lookup-cache-dynamic FILE
+ path to dynamic lookup cache to use for lookup decoding (updated by generation)
+ -c, --ctx-size N size of the prompt context (default: 0, 0 = loaded from model)
+ -n, --predict N number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled)
+ -b, --batch-size N logical maximum batch size (default: 2048)
+ -ub, --ubatch-size N physical maximum batch size (default: 512)
+ --keep N number of tokens to keep from the initial prompt (default: 0, -1 = all)
+ --chunks N max number of chunks to process (default: -1, -1 = all)
+ -fa, --flash-attn enable Flash Attention (default: disabled)
+ --no-escape do not process escape sequences
+ --samplers SAMPLERS samplers that will be used for generation in the order, separated by ';'
+ (default: top_k;tfs_z;typical_p;top_p;min_p;temperature)
+ --sampling-seq SEQUENCE simplified sequence for samplers that will be used (default: kfypmt)
+ --penalize-nl penalize newline tokens (default: false)
+ --temp N temperature (default: 0.8)
+ --top-k N top-k sampling (default: 40, 0 = disabled)
+ --top-p N top-p sampling (default: 0.9, 1.0 = disabled)
+ --min-p N min-p sampling (default: 0.1, 0.0 = disabled)
+ --tfs N tail free sampling, parameter z (default: 1.0, 1.0 = disabled)
+ --typical N locally typical sampling, parameter p (default: 1.0, 1.0 = disabled)
+ --repeat-last-n N last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size)
+ --repeat-penalty N penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled)
+ --presence-penalty N repeat alpha presence penalty (default: 0.0, 0.0 = disabled)
+ --frequency-penalty N repeat alpha frequency penalty (default: 0.0, 0.0 = disabled)
+ --dynatemp-range N dynamic temperature range (default: 0.0, 0.0 = disabled)
+ --dynatemp-exp N dynamic temperature exponent (default: 1.0)
+ --mirostat N use Mirostat sampling.
+ Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.
+ (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
+ --mirostat-lr N Mirostat learning rate, parameter eta (default: 0.1)
+ --mirostat-ent N Mirostat target entropy, parameter tau (default: 5.0)
+ -l --logit-bias TOKEN_ID(+/-)BIAS
+ modifies the likelihood of token appearing in the completion,
+ i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
+ or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'
+ --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '')
+ --grammar-file FILE file to read grammar from
+ -j, --json-schema SCHEMA JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
+ For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead
+ --rope-scaling {none,linear,yarn}
+ RoPE frequency scaling method, defaults to linear unless specified by the model
+ --rope-scale N RoPE context scaling factor, expands context by a factor of N
+ --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)
+ --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N
+ --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)
+ --yarn-ext-factor N YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
+ --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
+ --yarn-beta-fast N YaRN: low correction dim or beta (default: 32.0)
+ --yarn-beta-slow N YaRN: high correction dim or alpha (default: 1.0)
+ -gan, --grp-attn-n N group-attention factor (default: 1)
+ -gaw, --grp-attn-w N group-attention width (default: 512.0)
+ -nkvo, --no-kv-offload disable KV offload
+ -ctk, --cache-type-k TYPE KV cache data type for K (default: f16)
+ -ctv, --cache-type-v TYPE KV cache data type for V (default: f16)
+ -dt, --defrag-thold N KV cache defragmentation threshold (default: -1.0, < 0 - disabled)
+ -np, --parallel N number of parallel sequences to decode (default: 1)
+ -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)
+ --mmproj FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md
+ --mlock force system to keep model in RAM rather than swapping or compressing
+ --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)
+ --numa TYPE attempt optimizations that help on some NUMA systems
+ - distribute: spread execution evenly over all nodes
+ - isolate: only spawn threads on CPUs on the node that execution started on
+ - numactl: use the CPU map provided by numactl
+ if run without this previously, it is recommended to drop the system page cache before using this
+ see https://github.com/ggerganov/llama.cpp/issues/1437
+ --override-kv KEY=TYPE:VALUE
+ advanced option to override model metadata by key. may be specified multiple times.
+ types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false
+ --lora FILE apply LoRA adapter (implies --no-mmap)
+ --lora-scaled FILE SCALE
+ apply LoRA adapter with user defined scaling S (implies --no-mmap)
+ --lora-base FILE optional model to use as a base for the layers modified by the LoRA adapter
+ --control-vector FILE add a control vector
+ --control-vector-scaled FILE SCALE
+ add a control vector with user defined scaling SCALE
+ --control-vector-layer-range START END
+ layer range to apply the control vector(s) to, start and end inclusive
+ -ngl, --gpu-layers N number of layers to store in VRAM
+ -sm, --split-mode SPLIT_MODE how to split the model across multiple GPUs, one of:
+ - none: use one GPU only
+ - layer (default): split layers and KV across GPUs
+ - row: split rows across GPUs
+ -ts, --tensor-split SPLIT fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
+ -mg, --main-gpu N the GPU to use for the model (with split-mode = none),
+ or for intermediate results and KV (with split-mode = row) (default: 0)
+
+server:
+
+ --host HOST ip address to listen (default: 127.0.0.1)
+ --port PORT port to listen (default: 8080)
+ -to --timeout N server read/write timeout in seconds (default: 600)
+ --threads-http N number of threads used to process HTTP requests (default: -1)
+ --system-prompt-file FILE
+ set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications
+ --metrics enable prometheus compatible metrics endpoint (default: disabled)
+ --infill enable infill endpoint (default: disabled)
+ --embeddings enable embedding endpoint (default: disabled)
+ --no-slots disables slots monitoring endpoint (default: enabled)
+ --slot-save-path PATH path to save slot kv cache (default: disabled)
+ --chat-template JINJA_TEMPLATE
+ set custom jinja chat template (default: template taken from model's metadata)
+ only commonly used templates are accepted:
+ https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
+ --chat-template-file FILE
+ set a file to load a custom jinja chat template
+ -sps, --slot-prompt-similarity N
+ how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
+
+ --conn-idle N server connection idle in seconds (default: 60)
+ --conn-keepalive N server connection keep-alive in seconds (default: 15)
+ -tps --tokens-per-second N maximum number of tokens per second (default: 0, 0 = disabled, -1 = try to detect)
+
+logging:
+
+ --log-format {text,json}
+ log output format: json or text (default: json)
+```
+
+## API Endpoints
+
+- **GET** `/health`: Returns the current state of the llama-box.
+ + 503 -> `{"status": "loading model"}` if the model is still being loaded.
+ + 500 -> `{"status": "error"}` if the model failed to load.
+ + 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below.
+ + 200 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slots are currently available.
+ + 503 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if the query parameter `fail_on_no_slot` is provided and no slots are currently available.
+
+- **GET** `/metrics`: Returns the Prometheus compatible metrics of the llama-box.
+ + This endpoint is only available if the `--metrics` flag is enabled.
+ + `llamacpp:prompt_tokens_total`: (Counter) Number of prompt tokens processed.
+ + `llamacpp:prompt_seconds_total`: (Counter) Prompt process time.
+ + `llamacpp:tokens_predicted_total`: (Counter) Number of generation tokens processed.
+ + `llamacpp:tokens_predicted_seconds_total`: (Counter) Predict process time.
+ + `llamacpp:prompt_tokens_seconds`: (Gauge) Average prompt throughput in tokens/s.
+ + `llamacpp:predicted_tokens_seconds`: (Gauge) Average generation throughput in tokens/s.
+ + `llamacpp:kv_cache_usage_ratio`: (Gauge) KV-cache usage. 1 means 100 percent usage.
+ + `llamacpp:kv_cache_tokens`: (Gauge) KV-cache tokens.
+ + `llamacpp:requests_processing`: (Gauge) Number of requests being processed.
+ + `llamacpp:requests_deferred`: (Gauge) Number of requests deferred.
+
+- **GET** `/props`: Returns current server settings.
+
+- **POST** `/infill`: Returns the completion of the given prompt.
+ + This endpoint is only available if the `--infill` flag is enabled.
+
+- **POST** `/tokenize`: Convert text to tokens.
+
+- **POST** `/detokenize`: Convert tokens to text.
+
+- **GET** `/slots`: Returns the current slots processing state.
+ + This endpoint is only available if the `--no-slots` flag is not provided.
+
+- **POST** `/slots/:id_slot?action={save|restore|erase}`: Operate specific slot via ID.
+ + This endpoint is only available if the `--no-slots` flag is not provided and `--slot-save-path` is provided.
+
+- **POST** `/completion`: Returns the completion of the given prompt.
+
+- **GET** `/v1/models`: (OpenAI-compatible) Returns the list of available models,
+ see https://platform.openai.com/docs/api-reference/models/list.
+
+- **POST** `/v1/completions`: (OpenAI-compatible) Returns the completion of the given prompt,
+ see https://platform.openai.com/docs/api-reference/completions/create.
+
+- **POST** `/v1/chat/completions`: (OpenAI-compatible) Returns the completion of the given prompt,
+ see https://platform.openai.com/docs/api-reference/chat/create.
+
+- **POST** `/v1/embeddings`: (OpenAI-compatible) Returns the embeddings of the given prompt,
+ see https://platform.openai.com/docs/api-reference/embeddings/create.
+ + This endpoint is only available if the `--embeddings` flag is enabled.
+
+## License
+
+MIT
diff --git a/llama-box/.clang-format b/llama-box/.clang-format
new file mode 100644
index 0000000..3728ce1
--- /dev/null
+++ b/llama-box/.clang-format
@@ -0,0 +1,14 @@
+---
+BasedOnStyle: LLVM
+IndentWidth: 4
+ColumnLimit: 100
+SeparateDefinitionBlocks: Always
+---
+Language: Cpp
+AllowShortFunctionsOnASingleLine: None
+AlignTrailingComments: true
+AlignEscapedNewlines: Left
+AlwaysBreakTemplateDeclarations: Yes
+ConstructorInitializerAllOnOneLineOrOnePerLine: true
+PackConstructorInitializers: NextLineOnly
+---
diff --git a/llama-box/main.cpp b/llama-box/main.cpp
new file mode 100644
index 0000000..7fe0916
--- /dev/null
+++ b/llama-box/main.cpp
@@ -0,0 +1,3272 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "llama.cpp/common/common.h"
+#include "llama.cpp/common/grammar-parser.h"
+#include "llama.cpp/common/json-schema-to-grammar.h"
+#define JSON_ASSERT GGML_ASSERT
+#include "llama.cpp/common/json.hpp"
+#include "llama.cpp/common/log.h"
+#include "llama.cpp/ggml.h"
+#include "llama.cpp/llama.h"
+
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 10485760
+#include "llama.cpp/examples/server/httplib.h"
+
+#include "param.hpp"
+#include "ratelimiter.hpp"
+#include "utils.hpp"
+
+using json = nlohmann::json;
+
+bool server_log_json = true; // NOTE(review): presumably switched by --log-format (the README lists json as the default) — confirm
+
+enum stop_type { // matching mode used by server_slot::find_stopping_strings
+ STOP_TYPE_FULL, // look for the complete stop word near the end of the text
+ STOP_TYPE_PARTIAL, // delegate to find_partial_stop_string()
+};
+
+enum slot_state { // coarse state of a server_slot
+ SLOT_STATE_IDLE, // together with SLOT_COMMAND_NONE this makes the slot available()
+ SLOT_STATE_PROCESSING, // release() only acts on a slot in this state
+};
+
+enum slot_command { // pending command recorded on a server_slot
+ SLOT_COMMAND_NONE,
+ SLOT_COMMAND_LOAD_PROMPT, // an idle slot carrying this command already counts as is_processing()
+ SLOT_COMMAND_RELEASE, // set by server_slot::release(); add_token_string() ignores tokens afterwards
+};
+
+enum server_state { // model-loading lifecycle of the server
+ SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully
+ // loaded yet
+ SERVER_STATE_READY, // Server is ready and model is loaded
+ SERVER_STATE_ERROR // An error occurred, load_model failed
+};
+
+enum server_task_type { // kind of work carried by a server_task (stored in server_task::type)
+ SERVER_TASK_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_CANCEL,
+ SERVER_TASK_TYPE_NEXT_RESPONSE,
+ SERVER_TASK_TYPE_METRICS,
+ SERVER_TASK_TYPE_SLOT_SAVE, // NOTE(review): the three SLOT_* kinds presumably back the README's /slots/:id_slot?action={save|restore|erase} — confirm
+ SERVER_TASK_TYPE_SLOT_RESTORE,
+ SERVER_TASK_TYPE_SLOT_ERASE,
+};
+
+struct server_task { // one unit of work: `type` selects the kind, `data` carries its JSON payload
+ int id = -1; // to be filled by server_queue
+ int id_multi = -1; // presumably the id of the owning server_task_multi, -1 when there is none — confirm
+ int id_target = -1; // presumably the id of the task this one acts on — confirm
+
+ server_task_type type;
+ json data;
+
+ bool infill = false;
+ bool embedding = false;
+
+ int tps = 0; // presumably the tokens-per-second limit for this task (README: 0 = disabled) — confirm
+};
+
+struct server_task_result {
+ int id = -1;
+ int id_multi = -1;
+
+ json data;
+
+ bool stop;
+ bool error;
+};
+
+struct server_task_multi { // groups the subtasks of one multi-part task (reading inferred from the field names — confirm)
+ int id = -1;
+
+ std::set subtasks_remaining; // presumably the ids of subtasks that have not finished yet — confirm
+ std::vector results; // presumably the results gathered from finished subtasks — confirm
+};
+
+struct slot_params { // per-request parameters stored on a server_slot as `params`
+ bool stream = true;
+ bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
+
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting
+ // context, 0 defaults to half
+ int32_t n_predict = -1; // new tokens to predict; -1 is treated as "no limit" by server_slot::has_budget
+
+ std::vector antiprompt; // stop words scanned by server_slot::find_stopping_strings
+
+ json input_prefix;
+ json input_suffix;
+};
+
+struct server_slot {
+ int id;
+ int id_task = -1;
+ int id_multi = -1;
+
+ struct slot_params params;
+
+ slot_state state = SLOT_STATE_IDLE;
+ slot_command command = SLOT_COMMAND_NONE;
+
+ // used to determine the slot that has been used the longest
+ int64_t t_last_used = -1;
+
+ // generation props
+ int32_t n_ctx = 0; // context size per slot
+ int32_t n_past = 0;
+ int32_t n_decoded = 0;
+ int32_t n_remaining = -1;
+ int32_t i_batch = -1;
+ int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
+
+ int32_t n_prompt_tokens = 0;
+ int32_t n_prompt_tokens_processed = 0;
+
+ json prompt; // can be either a string, array of strings or array of token
+ // ids
+
+ // when a task is submitted, we first tokenize the prompt and store it here
+ std::vector prompt_tokens;
+
+ std::string generated_text;
+ std::vector cache_tokens;
+ std::vector generated_token_probs;
+
+ bool infill = false;
+ bool embedding = false;
+ bool has_next_token = true;
+ bool truncated = false;
+ bool stopped_eos = false;
+ bool stopped_word = false;
+ bool stopped_limit = false;
+
+ std::string stopping_word;
+
+ bool oaicompat = false;
+ bool oaicompat_completion = false;
+ bool oaicompat_completion_chat = false;
+
+ // sampling
+ llama_token sampled;
+ struct llama_sampling_params sparams;
+ llama_sampling_context *ctx_sampling = nullptr;
+ json json_schema;
+
+ int32_t ga_i = 0; // group-attention state
+ int32_t ga_n = 1; // group-attention factor
+ int32_t ga_w = 512; // group-attention width
+
+ int32_t n_past_se = 0; // self-extend
+
+ // stats
+ size_t n_sent_text = 0; // number of sent text character
+ size_t n_sent_token_probs = 0;
+
+ int64_t t_start_process_prompt;
+ int64_t t_start_generation;
+
+ double t_prompt_processing; // ms
+ double t_token_generation; // ms
+
+ token_bucket *token_bkt = nullptr; // bucket for tokens per second
+
+ void reset() {
+ n_prompt_tokens = 0;
+ generated_text = "";
+ truncated = false;
+ stopped_eos = false;
+ stopped_word = false;
+ stopped_limit = false;
+ stopping_word = "";
+ n_past = 0;
+ n_sent_text = 0;
+ n_sent_token_probs = 0;
+ infill = false;
+ ga_i = 0;
+ n_past_se = 0;
+
+ generated_token_probs.clear();
+
+ if (token_bkt != nullptr) {
+ delete token_bkt;
+ token_bkt = nullptr;
+ }
+ }
+
+ bool has_budget(gpt_params &global_params) {
+ if (params.n_predict == -1 && global_params.n_predict == -1) {
+ return true; // limitless
+ }
+
+ n_remaining = -1;
+
+ if (params.n_predict != -1) {
+ n_remaining = params.n_predict - n_decoded;
+ } else if (global_params.n_predict != -1) {
+ n_remaining = global_params.n_predict - n_decoded;
+ }
+
+ return n_remaining > 0; // no budget
+ }
+
+ bool available() const {
+ return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
+ }
+
+ // True while the slot is processing, or is idle with a prompt load queued.
+ bool is_processing() const {
+     if (state == SLOT_STATE_PROCESSING) {
+         return true;
+     }
+     return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT;
+ }
+
+ // Append `token` to generated_token_probs unless the slot is being released.
+ void add_token_string(const completion_token_output &token) {
+     if (command != SLOT_COMMAND_RELEASE) {
+         generated_token_probs.push_back(token);
+     }
+ }
+
+ // Request release of a processing slot and record how long generation took.
+ // Slots in any other state are left untouched.
+ void release() {
+     if (state != SLOT_STATE_PROCESSING) {
+         return;
+     }
+
+     // ggml_time_us() is divided by 1e3 to get the ms value stored in t_token_generation
+     t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
+     command = SLOT_COMMAND_RELEASE;
+ }
+
+ // Build the timing report for this slot: token counts, elapsed ms, and the
+ // derived per-token / per-second rates for the prompt and generation phases.
+ json get_formated_timings() const {
+     // prompt phase
+     const auto prompt_ms_per_token = t_prompt_processing / n_prompt_tokens_processed;
+     const auto prompt_tokens_per_s = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+
+     // generation phase
+     const auto predicted_ms_per_token = t_token_generation / n_decoded;
+     const auto predicted_tokens_per_s = 1e3 / t_token_generation * n_decoded;
+
+     return json{
+         {"prompt_n", n_prompt_tokens_processed},
+         {"prompt_ms", t_prompt_processing},
+         {"prompt_per_token_ms", prompt_ms_per_token},
+         {"prompt_per_second", prompt_tokens_per_s},
+
+         {"predicted_n", n_decoded},
+         {"predicted_ms", t_token_generation},
+         {"predicted_per_token_ms", predicted_ms_per_token},
+         {"predicted_per_second", predicted_tokens_per_s},
+     };
+ }
+
+ // Look for any of params.antiprompt in `text` and return the earliest match
+ // position, or std::string::npos when none is found.
+ //
+ // STOP_TYPE_FULL searches for complete stop words, scanning only the tail
+ // that the last token (last_token_size bytes) could have completed; a hit
+ // also records the stop word and ends generation. Any other type delegates
+ // to find_partial_stop_string() and leaves the slot state unchanged.
+ size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
+                              const stop_type type) {
+     size_t earliest = std::string::npos;
+
+     for (const std::string &word : params.antiprompt) {
+         size_t pos;
+
+         if (type != STOP_TYPE_FULL) {
+             pos = find_partial_stop_string(word, text);
+         } else {
+             const size_t window   = word.size() + last_token_size;
+             const size_t from_pos = text.size() > window ? text.size() - window : 0;
+
+             pos = text.find(word, from_pos);
+         }
+
+         if (pos == std::string::npos) {
+             continue; // this word does not occur
+         }
+         if (earliest != std::string::npos && pos >= earliest) {
+             continue; // an earlier match is already known
+         }
+
+         if (type == STOP_TYPE_FULL) {
+             stopped_word   = true;
+             stopping_word  = word;
+             has_next_token = false;
+         }
+         earliest = pos;
+     }
+
+     return earliest;
+ }
+};
+
+// Aggregated throughput counters for the server.
+//
+// Fields ending in `_total` are only ever incremented here; the fields
+// without the suffix form a bucket that reset_bucket() zeroes.
+struct server_metrics {
+    int64_t t_start = 0;
+
+    uint64_t n_prompt_tokens_processed_total = 0;
+    uint64_t t_prompt_processing_total = 0;
+    uint64_t n_tokens_predicted_total = 0;
+    uint64_t t_tokens_generation_total = 0;
+
+    uint64_t n_prompt_tokens_processed = 0;
+    uint64_t t_prompt_processing = 0;
+
+    uint64_t n_tokens_predicted = 0;
+    uint64_t t_tokens_generation = 0;
+
+    // Remember when metrics collection started.
+    void init() {
+        t_start = ggml_time_us();
+    }
+
+    // Account for the prompt tokens a slot has just evaluated.
+    void on_prompt_eval(const server_slot &slot) {
+        const auto n_tokens = slot.n_prompt_tokens_processed;
+        const auto t_spent  = slot.t_prompt_processing;
+
+        n_prompt_tokens_processed += n_tokens;
+        n_prompt_tokens_processed_total += n_tokens;
+        t_prompt_processing += t_spent;
+        t_prompt_processing_total += t_spent;
+    }
+
+    // Account for the tokens a slot has generated.
+    void on_prediction(const server_slot &slot) {
+        const auto n_tokens = slot.n_decoded;
+        const auto t_spent  = slot.t_token_generation;
+
+        n_tokens_predicted += n_tokens;
+        n_tokens_predicted_total += n_tokens;
+        t_tokens_generation += t_spent;
+        t_tokens_generation_total += t_spent;
+    }
+
+    // Zero the bucket counters; the running totals are kept.
+    void reset_bucket() {
+        n_prompt_tokens_processed = 0;
+        n_tokens_predicted = 0;
+        t_prompt_processing = 0;
+        t_tokens_generation = 0;
+    }
+};
+
+// Task queue that drives the server main loop.
+//
+// post()/defer()/get_new_id() and the multitask helpers take mutex_tasks;
+// start_loop() drains queue_tasks, finishes completed multitasks and then
+// invokes the registered update callback, sleeping on condition_tasks when
+// there is nothing to do.
+struct server_queue {
+    int id = 0;           // next task id handed out by post() / get_new_id()
+    bool running = false; // set by start_loop(), cleared by terminate()
+
+    // queues
+    std::vector<server_task> queue_tasks;
+    std::vector<server_task> queue_tasks_deferred;
+
+    std::vector<server_task_multi> queue_multitasks;
+
+    std::mutex mutex_tasks;
+    std::condition_variable condition_tasks;
+
+    // callback functions
+    std::function<void(server_task &)> callback_new_task;
+    std::function<void(server_task_multi &)> callback_finish_multitask;
+    std::function<void(void)> callback_update_slots;
+
+    // Add a new task to the end of the queue and wake the loop.
+    // A task whose id is -1 gets a fresh id; the (possibly new) id is returned.
+    int post(server_task task) {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        if (task.id == -1) {
+            task.id = id++;
+        }
+        // read the id before the task is moved into the queue
+        const int task_id = task.id;
+        queue_tasks.push_back(std::move(task));
+        condition_tasks.notify_one();
+        return task_id;
+    }
+
+    // Add a new task, but defer until one slot is available
+    void defer(server_task task) {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        queue_tasks_deferred.push_back(std::move(task));
+    }
+
+    // Get the next id for creating a new task
+    int get_new_id() {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        int new_id = id++;
+        return new_id;
+    }
+
+    // Register function to process a new task
+    void on_new_task(std::function<void(server_task &)> callback) {
+        callback_new_task = std::move(callback);
+    }
+
+    // Register function to process a multitask when it is finished
+    void on_finish_multitask(std::function<void(server_task_multi &)> callback) {
+        callback_finish_multitask = std::move(callback);
+    }
+
+    // Register the function to be called when all slots data is ready to be
+    // processed
+    void on_update_slots(std::function<void(void)> callback) {
+        callback_update_slots = std::move(callback);
+    }
+
+    // Call when the state of one slot is changed
+    void notify_slot_changed() {
+        // move deferred tasks back to main loop
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        for (auto &task : queue_tasks_deferred) {
+            queue_tasks.push_back(std::move(task));
+        }
+        queue_tasks_deferred.clear();
+    }
+
+    // End the start_loop routine
+    void terminate() {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        running = false;
+        condition_tasks.notify_all();
+    }
+
+    /**
+     * Main loop consists of these steps:
+     * - Wait until a new task arrives
+     * - Process the task (i.e. maybe copy data into slot)
+     * - Check if multitask is finished
+     * - Update all slots
+     */
+    void start_loop() {
+        running = true;
+
+        while (true) {
+            while (true) {
+                std::unique_lock<std::mutex> lock(mutex_tasks);
+                if (queue_tasks.empty()) {
+                    lock.unlock();
+                    break;
+                }
+                // the front element is erased right away, so steal its contents
+                server_task task = std::move(queue_tasks.front());
+                queue_tasks.erase(queue_tasks.begin());
+                // run the callback without the lock so it may post new tasks
+                lock.unlock();
+                callback_new_task(task);
+            }
+
+            // check if we have any finished multitasks
+            auto queue_iterator = queue_multitasks.begin();
+            while (queue_iterator != queue_multitasks.end()) {
+                if (queue_iterator->subtasks_remaining.empty()) {
+                    // all subtasks done == multitask is done
+                    server_task_multi current_multitask = std::move(*queue_iterator);
+                    callback_finish_multitask(current_multitask);
+                    // remove this multitask
+                    queue_iterator = queue_multitasks.erase(queue_iterator);
+                } else {
+                    ++queue_iterator;
+                }
+            }
+
+            // all tasks in the current loop are processed, slots data is now
+            // ready
+            callback_update_slots();
+
+            {
+                std::unique_lock<std::mutex> lock(mutex_tasks);
+                if (queue_tasks.empty()) {
+                    if (!running) {
+                        return;
+                    }
+                    condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() || !running); });
+                }
+            }
+        }
+    }
+
+    //
+    // functions to manage multitasks
+    //
+
+    // add a multitask by specifying the id of all subtask (subtask is a
+    // server_task)
+    void add_multitask(int id_multi, std::vector<int> &sub_ids) {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        server_task_multi multi;
+        multi.id = id_multi;
+        std::copy(sub_ids.begin(), sub_ids.end(),
+                  std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
+        queue_multitasks.push_back(multi);
+    }
+
+    // update the remaining subtasks, while appending results to multitask
+    void update_multitask(int id_multi, int id_sub, server_task_result &result) {
+        std::lock_guard<std::mutex> lock(mutex_tasks);
+        for (auto &multitask : queue_multitasks) {
+            if (multitask.id == id_multi) {
+                multitask.subtasks_remaining.erase(id_sub);
+                multitask.results.push_back(result);
+            }
+        }
+    }
+};
+
+// Result queue shared between the worker that produces task results and the
+// request handlers that wait for them. All members are guarded by
+// mutex_results.
+struct server_response {
+    typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
+    callback_multitask_t callback_update_multitask;
+
+    // for keeping track of all tasks waiting for the result
+    std::set<int> waiting_task_ids;
+
+    // the main result queue
+    std::vector<server_task_result> queue_results;
+
+    std::mutex mutex_results;
+    std::condition_variable condition_results;
+
+    // add the id_task to the list of tasks waiting for response
+    void add_waiting_task_id(int id_task) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        waiting_task_ids.insert(id_task);
+    }
+
+    // when the request is finished, we can remove task associated with it
+    void remove_waiting_task_id(int id_task) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        waiting_task_ids.erase(id_task);
+    }
+
+    // This function blocks the thread until there is a response for this
+    // id_task, then removes that response from the queue and returns it.
+    server_task_result recv(int id_task) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+
+        // Wait for a result that belongs to id_task. Waiting only for a
+        // non-empty queue would spin while results for other tasks are pending.
+        int found = -1;
+        condition_results.wait(lock, [&] {
+            for (int i = 0; i < (int)queue_results.size(); i++) {
+                if (queue_results[i].id == id_task) {
+                    found = i;
+                    return true;
+                }
+            }
+            return false;
+        });
+
+        assert(queue_results[found].id_multi == -1);
+        server_task_result res = queue_results[found];
+        queue_results.erase(queue_results.begin() + found);
+        return res;
+    }
+
+    // Register the function to update multitask
+    void on_multitask_update(callback_multitask_t callback) {
+        callback_update_multitask = std::move(callback);
+    }
+
+    // Send a new result to a waiting id_task
+    void send(server_task_result result) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+        for (const auto &id_task : waiting_task_ids) {
+            // for now, tasks that have associated parent multitasks just get
+            // erased once multitask picks up the result
+            if (result.id_multi == id_task) {
+                callback_update_multitask(id_task, result.id, result);
+                continue;
+            }
+
+            if (result.id == id_task) {
+                queue_results.push_back(result);
+                // wake every waiter; each one re-checks for its own id
+                condition_results.notify_all();
+                return;
+            }
+        }
+    }
+};
+
+struct server_context {
+ llama_model *model = nullptr;
+ llama_context *ctx = nullptr;
+
+ gpt_params params;
+
+ llama_batch batch;
+
+ bool clean_kv_cache = true;
+ bool add_bos_token = true;
+
+ int32_t n_ctx; // total context for all clients / slots
+ int32_t n_tps; // max tokens per second
+
+ // system prompt
+ bool system_need_update = false;
+
+ std::string system_prompt;
+ std::vector system_tokens;
+
+ // slots / clients
+ std::vector slots;
+ json default_generation_settings_for_props;
+
+ server_queue queue_tasks;
+ server_response queue_results;
+
+ server_metrics metrics;
+
+ // Necessary similarity of prompt for slot selection
+ float slot_prompt_similarity = 0.0f;
+
+ // Release the llama context and model, then every slot's sampling context
+ // and token bucket, and finally the shared batch.
+ ~server_context() {
+     if (ctx != nullptr) {
+         llama_free(ctx);
+         ctx = nullptr;
+     }
+
+     if (model != nullptr) {
+         llama_free_model(model);
+         model = nullptr;
+     }
+
+     // per-slot resources
+     for (auto &slot : slots) {
+         if (slot.ctx_sampling != nullptr) {
+             llama_sampling_free(slot.ctx_sampling);
+         }
+         // deleting a null pointer is a no-op
+         delete slot.token_bkt;
+     }
+
+     llama_batch_free(batch);
+ }
+
+ // Load the model and context described by `bparams`.
+ //
+ // When bparams.n_tps is negative, the maximum tokens-per-second value is
+ // measured by decoding a short self-generated sequence and reading the
+ // context timings. Returns false when the model cannot be loaded.
+ bool load_model(const llama_box_params &bparams) {
+     params = bparams.gparams;
+
+     // dedicate one sequence to the system prompt
+     params.n_parallel += 1;
+
+     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+     params.n_parallel -= 1; // but be sneaky about it
+     if (model == nullptr) {
+         LOG_ERROR("unable to load model", {{"model", params.model}});
+         return false;
+     }
+
+     n_ctx = llama_n_ctx(ctx);
+     n_tps = bparams.n_tps;
+
+     add_bos_token = llama_should_add_bos_token(model);
+     GGML_ASSERT(llama_add_eos_token(model) != 1);
+
+     // sample tokens per second
+     if (n_tps < 0) {
+         LOG_INFO("sampling tokens per second, this will take some time...", {});
+         std::vector<llama_token> embd = {llama_token_bos(model)};
+         const int32_t n_check = std::min(n_ctx, params.n_ubatch);
+         llama_sampling_context *ctx_sampling = llama_sampling_init(params.sparams);
+         while (true) {
+             const int32_t i = embd.size();
+             if (i >= n_check) {
+                 break;
+             }
+             // decode the newest token at its own position (i - 1) in sequence 0,
+             // so every step extends the sequence instead of overwriting pos 0
+             if (llama_decode(ctx, llama_batch_get_one(&embd[i - 1], 1, i - 1, 0))) {
+                 break;
+             }
+             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
+             if (llama_token_is_eog(model, id)) {
+                 break;
+             }
+             llama_sampling_accept(ctx_sampling, ctx, id, false);
+             embd.push_back(id);
+         }
+         const llama_timings timings = llama_get_timings(ctx);
+         if (timings.n_eval > 0 && timings.t_eval_ms > 0.0) {
+             n_tps = ceil(1e3 / timings.t_eval_ms * timings.n_eval);
+         } else {
+             // nothing was measured: avoid converting NaN/inf to an integer
+             n_tps = 0;
+         }
+         // leave the context as if the calibration run never happened
+         llama_sampling_free(ctx_sampling);
+         llama_kv_cache_clear(ctx);
+         llama_synchronize(ctx);
+         llama_reset_timings(ctx);
+         LOG_INFO("sampled tokens per second", {{"tps", n_tps}});
+     }
+
+     return true;
+ }
+
+ // Check whether the model's built-in chat template can be applied, by
+ // formatting a single dummy user message without an output buffer.
+ bool validate_model_chat_template() const {
+     const llama_chat_message probe[] = {{"user", "test"}};
+
+     // a positive result means the template was recognised
+     return llama_chat_apply_template(model, nullptr, probe, 1, true, nullptr, 0) > 0;
+ }
+
+ // Return the chat template stored in the model metadata under
+ // "tokenizer.chat_template", or "chatml" when the model has none.
+ std::string load_chat_template() const {
+     const std::string template_key = "tokenizer.chat_template";
+
+     // longest known template is about 1200 bytes, but newer models ship
+     // longer ones, so the buffer is grown on demand below
+     std::vector<char> model_template(2048, 0);
+     int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(),
+                                            model_template.size());
+     if (res < 0) {
+         // worst case: there is no information about template, we will use chatml by default
+         return "chatml"; // see llama_chat_apply_template_internal
+     }
+
+     // The call reports the full value length even when the output was
+     // truncated; re-read into a buffer that is large enough (+1 for NUL).
+     if (static_cast<size_t>(res) >= model_template.size()) {
+         model_template.assign(static_cast<size_t>(res) + 1, 0);
+         res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(),
+                                        model_template.size());
+         if (res < 0) {
+             return "chatml";
+         }
+     }
+
+     // never read past the buffer, whatever length was reported
+     const size_t len = std::min(static_cast<size_t>(res), model_template.size() - 1);
+     return std::string(model_template.data(), len);
+ }
+
+ // Create params.n_parallel slots that share the context evenly, capture the
+ // default generation settings from the first slot, allocate the shared
+ // batch and start metrics collection.
+ void init() {
+     const int32_t n_ctx_slot = n_ctx / params.n_parallel;
+
+     LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
+
+     for (int idx = 0; idx < params.n_parallel; idx++) {
+         server_slot slot;
+
+         slot.id = idx;
+         slot.n_ctx = n_ctx_slot;
+         slot.n_predict = params.n_predict;
+
+         LOG_INFO("new slot", {{"id_slot", slot.id}, {"n_ctx_slot", slot.n_ctx}});
+
+         // group-attention (self-extend) settings are the same for every slot
+         const int ga_n = params.grp_attn_n;
+         const int ga_w = params.grp_attn_w;
+
+         const bool self_extend = ga_n != 1;
+         if (self_extend) {
+             GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
+             GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
+             // GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must
+             // be a multiple of ga_w"); // NOLINT GGML_ASSERT(n_ctx >=
+             // n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train *
+             // ga_n"); // NOLINT
+
+             LOG_INFO("slot self-extend",
+                      {{"id_slot", slot.id}, {"ga_n", ga_n}, {"ga_w", ga_w}});
+         }
+
+         slot.ga_i = 0;
+         slot.ga_n = ga_n;
+         slot.ga_w = ga_w;
+
+         slot.reset();
+
+         slots.push_back(slot);
+     }
+
+     // the first slot serves as the template for the advertised defaults
+     default_generation_settings_for_props = get_formated_generation(slots.front());
+     default_generation_settings_for_props["seed"] = -1;
+
+     // the update_slots() logic will always submit a maximum of n_batch
+     // tokens note that n_batch can be > n_ctx (e.g. for non-causal
+     // attention models such as BERT where the KV cache is not used)
+     const int32_t n_batch = llama_n_batch(ctx);
+
+     // only a single seq_id per token is needed
+     batch = llama_batch_init(n_batch, 0, 1);
+
+     metrics.init();
+ }
+
+ // Convert a JSON prompt into tokens.
+ //
+ // `json_prompt` is either a string or an array mixing strings and token
+ // ids. `add_special` is honoured only for a plain string, or for a string
+ // that is the first element of the array; later pieces never get special
+ // tokens prepended.
+ std::vector<llama_token> tokenize(const json &json_prompt, bool add_special) const {
+     // TODO: currently, we tokenize using special tokens by default
+     //       this is not always correct (see
+     //       https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
+     //       but it's better compared to completely ignoring ChatML and
+     //       other chat templates
+     const bool TMP_FORCE_SPECIAL = true;
+
+     std::vector<llama_token> prompt_tokens;
+
+     if (json_prompt.is_array()) {
+         bool first = true;
+         for (const auto &p : json_prompt) {
+             if (p.is_string()) {
+                 auto s = p.template get<std::string>();
+
+                 // named `piece` so it no longer shadows the loop variable `p`
+                 const std::vector<llama_token> piece =
+                     ::llama_tokenize(ctx, s, first && add_special, TMP_FORCE_SPECIAL);
+
+                 prompt_tokens.insert(prompt_tokens.end(), piece.begin(), piece.end());
+             } else {
+                 // a raw token id is taken as-is
+                 prompt_tokens.push_back(p.template get<llama_token>());
+             }
+
+             // only the very first element may receive special tokens
+             first = false;
+         }
+     } else {
+         auto s = json_prompt.template get<std::string>();
+         prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+     }
+
+     return prompt_tokens;
+ }
+
+ // Return the slot whose id equals `id`, or nullptr when there is none.
+ server_slot *get_slot_by_id(int id) {
+     for (size_t i = 0; i < slots.size(); ++i) {
+         if (slots[i].id == id) {
+             return &slots[i];
+         }
+     }
+
+     return nullptr;
+ }
+
+ // Pick an available slot for `prompt`, or return nullptr when none is free.
+ //
+ // When slot_prompt_similarity is non-zero and `prompt` is not empty, prefer
+ // the available slot whose stored prompt shares the longest common prefix
+ // with `prompt`, provided that prefix covers more than
+ // slot_prompt_similarity of the stored prompt. Otherwise fall back to the
+ // least recently used available slot.
+ server_slot *get_available_slot(const std::string &prompt) {
+     server_slot *ret = nullptr;
+
+     // find the slot that has at least n% prompt similarity
+     if (slot_prompt_similarity != 0.0f && !prompt.empty()) {
+         int max_lcp_len = 0;
+
+         for (server_slot &slot : slots) {
+             // skip the slot if it is not available
+             if (!slot.available()) {
+                 continue;
+             }
+
+             // skip the slot if it does not contain a string prompt
+             if (!slot.prompt.is_string()) {
+                 continue;
+             }
+
+             // current slot's prompt and its length
+             const std::string slot_prompt = slot.prompt.get<std::string>();
+             const int slot_prompt_len = slot_prompt.size();
+
+             // an empty stored prompt shares nothing and would give 0/0 below
+             if (slot_prompt_len == 0) {
+                 continue;
+             }
+
+             // length of the Longest Common Prefix between the current
+             // slot's prompt and the input prompt
+             const int lcp_len = common_part(slot_prompt, prompt);
+
+             // fraction of the common prefix length compared to the
+             // current slot's prompt length
+             const float similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+
+             // select the current slot if the criteria match
+             if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
+                 max_lcp_len = lcp_len;
+                 ret = &slot;
+             }
+         }
+     }
+
+     // find the slot that has been least recently used
+     if (ret == nullptr) {
+         int64_t t_last = ggml_time_us();
+         for (server_slot &slot : slots) {
+             // skip the slot if it is not available
+             if (!slot.available()) {
+                 continue;
+             }
+
+             // select the current slot if the criteria match
+             if (slot.t_last_used < t_last) {
+                 t_last = slot.t_last_used;
+                 ret = &slot;
+             }
+         }
+     }
+
+     return ret;
+ }
+
+ // Configure `slot` from the request carried by `task` and queue its prompt
+ // for loading.
+ //
+ // Request fields missing from task.data fall back to the server defaults in
+ // `params`. On an invalid request an error result is sent for the task and
+ // false is returned; on success the slot command becomes
+ // SLOT_COMMAND_LOAD_PROMPT and true is returned.
+ bool launch_slot_with_task(server_slot &slot, const server_task &task) {
+     llama_sampling_params sparams = params.sparams;
+     auto &data = task.data;
+
+     slot.oaicompat = json_value(data, "__oaicompat", false);
+     slot.oaicompat_completion = json_value(data, "__oaicompat_completion", false);
+     slot.oaicompat_completion_chat = json_value(data, "__oaicompat_completion_chat", false);
+
+     slot.params.stream = json_value(data, "stream", false);
+     slot.params.cache_prompt = json_value(data, "cache_prompt", false);
+     slot.params.n_keep = json_value(data, "n_keep", params.n_keep);
+     slot.params.n_predict = json_value(data, "n_predict", params.n_predict);
+     slot.params.n_discard = json_value(data, "n_discard", 0);
+     slot.params.input_prefix = json_value(data, "input_prefix", params.input_prefix);
+     slot.params.input_suffix = json_value(data, "input_suffix", params.input_suffix);
+
+     slot.sparams.top_k = json_value(data, "top_k", sparams.top_k);
+     slot.sparams.top_p = json_value(data, "top_p", sparams.top_p);
+     slot.sparams.min_p = json_value(data, "min_p", sparams.min_p);
+     slot.sparams.tfs_z = json_value(data, "tfs_z", sparams.tfs_z);
+     slot.sparams.typical_p = json_value(data, "typical_p", sparams.typical_p);
+     slot.sparams.temp = json_value(data, "temperature", sparams.temp);
+     slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", sparams.dynatemp_range);
+     slot.sparams.dynatemp_exponent =
+         json_value(data, "dynatemp_exponent", sparams.dynatemp_exponent);
+     slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", sparams.penalty_last_n);
+     slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", sparams.penalty_repeat);
+     slot.sparams.penalty_freq = json_value(data, "frequency_penalty", sparams.penalty_freq);
+     slot.sparams.penalty_present =
+         json_value(data, "presence_penalty", sparams.penalty_present);
+     slot.sparams.mirostat = json_value(data, "mirostat", sparams.mirostat);
+     slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", sparams.mirostat_tau);
+     slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", sparams.mirostat_eta);
+     slot.sparams.penalize_nl = json_value(data, "penalize_nl", sparams.penalize_nl);
+     slot.sparams.seed = json_value(data, "seed", sparams.seed);
+     slot.sparams.n_probs = json_value(data, "n_probs", sparams.n_probs);
+     slot.sparams.min_keep = json_value(data, "min_keep", sparams.min_keep);
+
+     // process "json_schema" and "grammar"
+     if (data.contains("json_schema") && !data.at("json_schema").is_null() &&
+         data.contains("grammar") && !data.at("grammar").is_null()) {
+         send_error(task,
+                    "Either \"json_schema\" or \"grammar\" can be "
+                    "specified, but not both",
+                    ERROR_TYPE_INVALID_REQUEST);
+         return false;
+     } else if (data.contains("json_schema") && !data.contains("grammar")) {
+         try {
+             auto schema = json_value(data, "json_schema", json::object());
+             slot.sparams.grammar = json_schema_to_grammar(schema);
+         } catch (const std::exception &e) {
+             send_error(task, std::string("\"json_schema\": ") + e.what(),
+                        ERROR_TYPE_INVALID_REQUEST);
+             return false;
+         }
+     } else {
+         slot.sparams.grammar = json_value(data, "grammar", sparams.grammar);
+     }
+
+     if (slot.params.cache_prompt && slot.ga_n != 1) {
+         LOG_WARNING("cache_prompt is not supported with group-attention", {});
+         slot.params.cache_prompt = false;
+     }
+
+     // clamp the request limit to the server-side limit of this slot
+     if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
+         // Might be better to reject the request with a 400 ?
+         LOG_WARNING("Max tokens to predict exceeds server configuration",
+                     {
+                         {"params.n_predict", slot.params.n_predict},
+                         {"slot.n_predict", slot.n_predict},
+                     });
+         slot.params.n_predict = slot.n_predict;
+     }
+
+     // get prompt
+     if (!task.infill) {
+         const auto &prompt = data.find("prompt");
+         if (prompt == data.end()) {
+             send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
+             return false;
+         }
+
+         // accepted: a string, a one-element array holding a string, or a
+         // non-empty array whose first element is an integer
+         if ((prompt->is_string()) ||
+             (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
+             (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
+             slot.prompt = *prompt;
+         } else {
+             send_error(task, "\"prompt\" must be a string or an array of integers",
+                        ERROR_TYPE_INVALID_REQUEST);
+             return false;
+         }
+     }
+
+     // penalize user-provided tokens
+     {
+         slot.sparams.penalty_prompt_tokens.clear();
+         slot.sparams.use_penalty_prompt_tokens = false;
+
+         const auto &penalty_prompt = data.find("penalty_prompt");
+
+         if (penalty_prompt != data.end()) {
+             if (penalty_prompt->is_string()) {
+                 const auto penalty_prompt_string = penalty_prompt->get<std::string>();
+                 slot.sparams.penalty_prompt_tokens =
+                     llama_tokenize(model, penalty_prompt_string, false);
+
+                 // leave room for the tokens that will be generated
+                 if (slot.params.n_predict > 0) {
+                     slot.sparams.penalty_prompt_tokens.reserve(
+                         slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
+                 }
+                 slot.sparams.use_penalty_prompt_tokens = true;
+             } else if (penalty_prompt->is_array()) {
+                 const auto n_tokens = penalty_prompt->size();
+                 slot.sparams.penalty_prompt_tokens.reserve(n_tokens +
+                                                            std::max(0, slot.params.n_predict));
+
+                 // keep only integers that are valid token ids
+                 const int n_vocab = llama_n_vocab(model);
+                 for (const auto &penalty_token : *penalty_prompt) {
+                     if (penalty_token.is_number_integer()) {
+                         const auto tok = penalty_token.get<llama_token>();
+                         if (tok >= 0 && tok < n_vocab) {
+                             slot.sparams.penalty_prompt_tokens.push_back(tok);
+                         }
+                     }
+                 }
+                 slot.sparams.use_penalty_prompt_tokens = true;
+             }
+         }
+     }
+
+     // logit bias
+     {
+         slot.sparams.logit_bias.clear();
+
+         if (json_value(data, "ignore_eos", false)) {
+             slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+         }
+
+         const auto &logit_bias = data.find("logit_bias");
+         if (logit_bias != data.end() && logit_bias->is_array()) {
+             const int n_vocab = llama_n_vocab(model);
+             for (const auto &el : *logit_bias) {
+                 // TODO: we may want to throw errors here, in case "el" is
+                 //       incorrect
+                 if (el.is_array() && el.size() == 2) {
+                     // el[1] is a number, or `false` meaning "never sample"
+                     float bias;
+                     if (el[1].is_number()) {
+                         bias = el[1].get<float>();
+                     } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+                         bias = -INFINITY;
+                     } else {
+                         continue;
+                     }
+
+                     // el[0] is a token id, or a string that is tokenized
+                     if (el[0].is_number_integer()) {
+                         llama_token tok = el[0].get<llama_token>();
+                         if (tok >= 0 && tok < n_vocab) {
+                             slot.sparams.logit_bias[tok] = bias;
+                         }
+                     } else if (el[0].is_string()) {
+                         auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
+                         for (auto tok : toks) {
+                             slot.sparams.logit_bias[tok] = bias;
+                         }
+                     }
+                 }
+             }
+         }
+     }
+
+     // stop words
+     {
+         slot.params.antiprompt.clear();
+
+         const auto &stop = data.find("stop");
+         if (stop != data.end() && stop->is_array()) {
+             for (const auto &word : *stop) {
+                 // json::empty() is always false for a string value, so the
+                 // decoded text is tested instead: an empty stop word would
+                 // match everywhere. Non-string entries are ignored.
+                 if (!word.is_string()) {
+                     continue;
+                 }
+                 const auto stop_word = word.get<std::string>();
+                 if (!stop_word.empty()) {
+                     slot.params.antiprompt.push_back(stop_word);
+                 }
+             }
+         }
+     }
+
+     // sampler order
+     {
+         const auto &samplers_sequence = data.find("samplers");
+         if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
+             std::vector<std::string> sampler_names;
+             for (const auto &sampler_name : *samplers_sequence) {
+                 if (sampler_name.is_string()) {
+                     sampler_names.emplace_back(sampler_name);
+                 }
+             }
+             slot.sparams.samplers_sequence =
+                 llama_sampling_types_from_names(sampler_names, false);
+         } else {
+             slot.sparams.samplers_sequence = sparams.samplers_sequence;
+         }
+     }
+
+     // rebuild the sampling context from the settings gathered above
+     {
+         if (slot.ctx_sampling != nullptr) {
+             llama_sampling_free(slot.ctx_sampling);
+         }
+         slot.ctx_sampling = llama_sampling_init(slot.sparams);
+         if (slot.ctx_sampling == nullptr) {
+             // for now, the only error that may happen here is invalid
+             // grammar
+             send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
+             return false;
+         }
+     }
+
+     // tokens-per-second limiter
+     {
+         delete slot.token_bkt; // deleting a null pointer is a no-op
+         slot.token_bkt = nullptr;
+
+         int tps = task.tps;
+#ifndef NDEBUG
+         tps = json_value(data, "tps",
+                          task.tps); // allow overriding tps for debugging
+         if (tps > n_tps) {
+             tps = n_tps;
+         }
+#endif
+         if (tps > 0) {
+             // plain `new` throws on failure instead of returning nullptr,
+             // so no null check is needed here
+             slot.token_bkt = new token_bucket(tps, tps);
+         }
+     }
+
+     slot.command = SLOT_COMMAND_LOAD_PROMPT;
+     slot.prompt_tokens.clear();
+
+     LOG_INFO("slot is processing task", {
+                                             {"id_slot", slot.id},
+                                             {"id_task", slot.id_task},
+                                         });
+
+     return true;
+ }
+
+ // Wipe the whole KV cache of the context and lower the clean_kv_cache flag.
+ void kv_cache_clear() {
+     clean_kv_cache = false;
+
+     // clear the entire KV cache
+     llama_kv_cache_clear(ctx);
+ }
+
+ // Re-evaluate the system prompt after it has changed.
+ //
+ // Clears the KV cache, tokenizes system_prompt, decodes it into sequence 0
+ // in chunks of at most llama_n_batch(ctx) tokens and copies the resulting
+ // cache to sequences 1..params.n_parallel. If a decode fails the function
+ // returns early and system_need_update is left unchanged.
+ void system_prompt_update() {
+     kv_cache_clear();
+     system_tokens.clear();
+
+     if (!system_prompt.empty()) {
+         system_tokens = ::llama_tokenize(ctx, system_prompt, true);
+
+         llama_batch_clear(batch);
+
+         for (int i = 0; i < (int)system_tokens.size(); ++i) {
+             llama_batch_add(batch, system_tokens[i], i, {0}, false);
+         }
+
+         const int32_t n_batch = llama_n_batch(ctx);
+
+         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+             // The chunk length must use the same batch size as the stride;
+             // otherwise chunks overlap or exceed what the context accepts.
+             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+             llama_batch batch_view = {
+                 n_tokens,
+                 batch.token + i,
+                 nullptr,
+                 batch.pos + i,
+                 batch.n_seq_id + i,
+                 batch.seq_id + i,
+                 batch.logits + i,
+                 0,
+                 0,
+                 0, // unused
+             };
+
+             if (llama_decode(ctx, batch_view) != 0) {
+                 LOG_ERROR("llama_decode() failed", {});
+                 return;
+             }
+         }
+
+         // assign the system KV cache to all parallel sequences
+         for (int32_t i = 1; i <= params.n_parallel; ++i) {
+             llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+         }
+     }
+
+     system_need_update = false;
+ }
+
+ // Store a new system prompt, ask every slot to release, and flag the
+ // prompt for re-evaluation. Always returns true.
+ bool system_prompt_set(const std::string &sys_prompt) {
+     system_prompt = sys_prompt;
+
+     // release all slots
+     for (size_t i = 0; i < slots.size(); ++i) {
+         slots[i].release();
+     }
+
+     system_need_update = true;
+
+     return true;
+ }
+
+ bool process_token(completion_token_output &result, server_slot &slot) { // handle one sampled token; returns true while generation should continue
+ // convert the token to text and remember which token was sampled - used
+ // for repetition penalties during sampling
+ const std::string token_str = llama_token_to_piece(ctx, result.tok, false);
+ slot.sampled = result.tok;
+
+ // append the piece; stop words are searched for (and removed) further below
+ slot.generated_text += token_str;
+ slot.has_next_token = true;
+
+ if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
+ // we can change penalty_prompt_tokens because it is always created
+ // from scratch each request
+ slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
+ }
+
+ // check if there is incomplete UTF-8 character at the end, by scanning
+ // back over at most 4 bytes for the lead byte of the last character
+ bool incomplete = false;
+ for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
+ unsigned char c = slot.generated_text[slot.generated_text.size() - i];
+ if ((c & 0xC0) == 0x80) {
+ // continuation byte: 10xxxxxx
+ continue;
+ }
+ if ((c & 0xE0) == 0xC0) {
+ // 2-byte character: 110xxxxx ...
+ incomplete = i < 2;
+ } else if ((c & 0xF0) == 0xE0) {
+ // 3-byte character: 1110xxxx ...
+ incomplete = i < 3;
+ } else if ((c & 0xF8) == 0xF0) {
+ // 4-byte character: 11110xxx ...
+ incomplete = i < 4;
+ }
+ // else 1-byte character or invalid byte
+ break;
+ }
+
+ if (!incomplete) { // only complete characters are examined and sent
+ size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
+
+ const std::string str_test = slot.generated_text.substr(pos); // text not sent yet
+ bool is_stop_full = false;
+
+ size_t stop_pos =
+ slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
+ if (stop_pos != std::string::npos) {
+ is_stop_full = true;
+ slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, // cut the stop word and everything after it
+ slot.generated_text.end());
+ pos = std::min(slot.n_sent_text, slot.generated_text.size());
+ } else {
+ is_stop_full = false;
+ stop_pos = // no full match: look for a stop word that may be starting
+ slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
+ }
+
+ // check if there is any token to predict
+ if (stop_pos == std::string::npos ||
+ (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
+ // do not send the stop word in the response
+ result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
+ slot.n_sent_text += result.text_to_send.size();
+ // add the token to slot queue and cache
+ }
+
+ slot.add_token_string(result);
+ if (slot.params.stream) {
+ send_partial_response(slot, result);
+ }
+ }
+
+ if (incomplete) { // keep generating until the character is complete
+ slot.has_next_token = true;
+ }
+
+ // check the limits
+ if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
+ slot.stopped_limit = true;
+ slot.has_next_token = false;
+ }
+
+ if (llama_token_is_eog(model, result.tok)) { // end-of-generation token
+ slot.stopped_eos = true;
+ slot.has_next_token = false;
+ }
+
+ auto n_ctx_train = llama_n_ctx_train(model);
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && // no limit set and self-extend disabled
+ slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+ LOG_WARNING("n_predict is not set and self-context extend is disabled."
+ " Limiting generated tokens to n_ctx_train to avoid EOS-less "
+ "generation "
+ "infinite loop",
+ {
+ {"id_slot", slot.id},
+ {"params.n_predict", slot.params.n_predict},
+ {"slot.n_prompt_tokens", slot.n_prompt_tokens},
+ {"slot.n_decoded", slot.n_decoded},
+ {"slot.n_predict", slot.n_predict},
+ {"n_slots", params.n_parallel},
+ {"slot.n_ctx", slot.n_ctx},
+ {"n_ctx", n_ctx},
+ {"n_ctx_train", n_ctx_train},
+ {"ga_n", slot.ga_n},
+ });
+ slot.truncated = true;
+ slot.stopped_limit = true;
+ slot.has_next_token = false; // stop prediction
+ }
+
+ return slot.has_next_token; // continue
+ }
+
+ // Describe the generation settings currently configured on `slot` as JSON.
+ //
+ // "ignore_eos" is derived from the logit bias: it is true when the EOS
+ // token carries a bias of negative infinity.
+ json get_formated_generation(const server_slot &slot) const {
+     const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
+     const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
+                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
+
+     std::vector<std::string> samplers_sequence;
+     samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
+     for (const auto &sampler_type : slot.sparams.samplers_sequence) {
+         samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
+     }
+
+     return json{{"n_ctx", slot.n_ctx},
+                 {"n_predict", slot.n_predict}, // server-side limit of the slot
+                 {"model", params.model_alias},
+                 {"seed", slot.sparams.seed},
+                 {"temperature", slot.sparams.temp},
+                 {"dynatemp_range", slot.sparams.dynatemp_range},
+                 {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
+                 {"top_k", slot.sparams.top_k},
+                 {"top_p", slot.sparams.top_p},
+                 {"min_p", slot.sparams.min_p},
+                 {"tfs_z", slot.sparams.tfs_z},
+                 {"typical_p", slot.sparams.typical_p},
+                 {"repeat_last_n", slot.sparams.penalty_last_n},
+                 {"repeat_penalty", slot.sparams.penalty_repeat},
+                 {"presence_penalty", slot.sparams.penalty_present},
+                 {"frequency_penalty", slot.sparams.penalty_freq},
+                 {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
+                 {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
+                 {"mirostat", slot.sparams.mirostat},
+                 {"mirostat_tau", slot.sparams.mirostat_tau},
+                 {"mirostat_eta", slot.sparams.mirostat_eta},
+                 {"penalize_nl", slot.sparams.penalize_nl},
+                 {"stop", slot.params.antiprompt},
+                 // The request-level limit used to be emitted under a second
+                 // "n_predict" key, which a JSON object cannot hold: the first
+                 // entry won and this value was dropped. It now has its own
+                 // key; "n_predict" keeps its previous value.
+                 {"max_tokens", slot.params.n_predict},
+                 {"n_keep", slot.params.n_keep},
+                 {"n_discard", slot.params.n_discard},
+                 {"ignore_eos", ignore_eos},
+                 {"stream", slot.params.stream},
+                 {"logit_bias", slot.sparams.logit_bias},
+                 {"n_probs", slot.sparams.n_probs},
+                 {"min_keep", slot.sparams.min_keep},
+                 {"grammar", slot.sparams.grammar},
+                 {"samplers", samplers_sequence}};
+ }
+
+ void send_error(const server_task &task, const std::string &error,
+ const enum error_type type = ERROR_TYPE_SERVER) {
+ send_error(task.id, task.id_multi, error, type);
+ }
+
+ void send_error(const server_slot &slot, const std::string &error,
+ const enum error_type type = ERROR_TYPE_SERVER) {
+ send_error(slot.id_task, slot.id_multi, error, type);
+ }
+
+ void send_error(const int id_task, const int id_multi, const std::string &error, // Logs the error and queues an error result.
+ const enum error_type type = ERROR_TYPE_SERVER) { // `type` falls back to ERROR_TYPE_SERVER when omitted.
+ LOG_ERROR("task error", {
+ {"id_multi", id_multi},
+ {"id_task", id_task},
+ {"error", error},
+ });
+ // build the result: marked as an error, with `stop` left false
+ server_task_result res;
+ res.id = id_task;
+ res.id_multi = id_multi;
+ res.stop = false;
+ res.error = true;
+ res.data = format_error_response(error, type); // payload is built from the message and the error type
+ // hand the result over to the results queue
+ queue_results.send(res);
+ }
+
+ void send_partial_response(server_slot &slot, completion_token_output tkn) { // Queues a non-final result carrying tkn.text_to_send.
+ server_task_result res;
+ res.id = slot.id_task;
+ res.id_multi = slot.id_multi;
+ res.error = false;
+ res.stop = false;
+ res.data = json{{"content", tkn.text_to_send},
+ {"stop", false},
+ {"id_slot", slot.id},
+ {"multimodal", false}};
+ // token probabilities are attached only when the request asked for them (n_probs > 0)
+ if (slot.sparams.n_probs > 0) {
+ const std::vector to_send_toks = // re-tokenize the outgoing text; only its token count is used below
+ llama_tokenize(ctx, tkn.text_to_send, false);
+ const size_t probs_pos = // start of the unsent window, clamped to the number of stored entries
+ std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
+ const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), // end of the window, clamped likewise
+ slot.generated_token_probs.size());
+ // copy [probs_pos, probs_stop_pos) out of generated_token_probs; stays empty when the window is empty
+ std::vector probs_output;
+ if (probs_pos < probs_stop_pos) {
+ probs_output = std::vector(
+ slot.generated_token_probs.begin() + probs_pos,
+ slot.generated_token_probs.begin() + probs_stop_pos);
+ }
+ slot.n_sent_token_probs = probs_stop_pos; // remember how far probabilities have been sent
+
+ res.data["completion_probabilities"] = probs_vector_to_json(
+ ctx, probs_output, slot.oaicompat_completion, slot.oaicompat_completion_chat);
+ }
+
+ queue_results.send(res);
+ }
+
+ void send_final_response(const server_slot &slot) { // Queues the final (stop = true) result for the slot's task.
+ server_task_result res;
+ res.id = slot.id_task;
+ res.id_multi = slot.id_multi;
+ res.error = false;
+ res.stop = true;
+ res.data = json{{"content", !slot.params.stream ? slot.generated_text : ""}, // full text only when not streaming, else ""
+ {"id_slot", slot.id},
+ {"stop", true},
+ {"model", params.model_alias},
+ {"tokens_predicted", slot.n_decoded},
+ {"tokens_evaluated", slot.n_prompt_tokens},
+ {"generation_settings", get_formated_generation(slot)},
+ {"prompt", slot.prompt},
+ {"truncated", slot.truncated},
+ {"stopped_eos", slot.stopped_eos},
+ {"stopped_word", slot.stopped_word},
+ {"stopped_limit", slot.stopped_limit},
+ {"stopping_word", slot.stopping_word},
+ {"tokens_cached", slot.n_past},
+ {"timings", slot.get_formated_timings()}};
+ // token probabilities are attached only when the request asked for them (n_probs > 0)
+ if (slot.sparams.n_probs > 0) {
+ std::vector probs;
+ if (!slot.params.stream && slot.stopped_word) {
+ const std::vector stop_word_toks = // tokenize the stop word to know how many trailing entries to drop
+ llama_tokenize(ctx, slot.stopping_word, false);
+ // never drop more entries than are stored
+ size_t safe_offset =
+ std::min(slot.generated_token_probs.size(), stop_word_toks.size());
+ probs = std::vector(slot.generated_token_probs.begin(),
+ slot.generated_token_probs.end() -
+ safe_offset);
+ } else {
+ probs = std::vector(slot.generated_token_probs.begin(), // otherwise report every stored entry
+ slot.generated_token_probs.end());
+ }
+
+ res.data["completion_probabilities"] = probs_vector_to_json(
+ ctx, probs, slot.oaicompat_completion, slot.oaicompat_completion_chat);
+ }
+
+ queue_results.send(res);
+ }
+
+ void send_embedding(const server_slot &slot, const llama_batch &batch) { // Queues the slot's embedding as a final result.
+ server_task_result res;
+ res.id = slot.id_task;
+ res.id_multi = slot.id_multi;
+ res.error = false;
+ res.stop = true;
+
+ const int n_embd = llama_n_embd(model);
+ // output buffer that llama_embd_normalize() writes into
+ std::vector embd_res(n_embd, 0.0f);
+
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { // skip entries without the logits flag or of another sequence
+ continue;
+ }
+ // try the per-sequence embedding first, then fall back to the per-token one
+ const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ if (embd == NULL) {
+ embd = llama_get_embeddings_ith(ctx, i);
+ }
+ // neither lookup worked: log it and report an all-zero vector of length n_embd
+ if (embd == NULL) {
+ LOG_ERROR("failed to get embeddings",
+ {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+
+ res.data = json{
+ {"embedding", std::vector(n_embd, 0.0f)},
+ };
+
+ continue;
+ }
+
+ llama_embd_normalize(embd, embd_res.data(), n_embd);
+ // res.data is overwritten on every matching entry, so the last one is what gets sent
+ res.data = json{
+ {"embedding", embd_res},
+ };
+ }
+
+ res.data["tokens_evaluated"] = slot.n_prompt_tokens;
+ queue_results.send(res);
+ }
+
+ void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding, // Builds a COMPLETION task and queues or splits it.
+ int tps = 0) {
+ server_task task;
+ task.id = id_task;
+ task.id_multi = id_multi;
+ task.id_target = 0;
+ task.data = std::move(data);
+ task.infill = infill;
+ task.embedding = embedding;
+ task.type = SERVER_TASK_TYPE_COMPLETION;
+ task.tps = tps;
+
+ // When a completion task's prompt array is not a singleton, we split it
+ // into multiple requests; otherwise it is a single-prompt task and we
+ // queue it directly. If there are numbers in the prompt array, it is
+ // treated as an array of tokens (a single prompt).
+ if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
+ bool numbers = false; // becomes true as soon as one element of "prompt" is a number
+ for (const auto &e : task.data.at("prompt")) {
+ if (e.is_number()) {
+ numbers = true;
+ break;
+ }
+ }
+
+ // NOTE: split_multiprompt_task() does not handle a mix of strings
+ // and numbers, it will completely stall the server. I don't know
+ // where the bug for this is.
+ //
+ // if there are numbers, it needs to be treated like a single
+ // prompt, queue_tasks handles a mix of strings and numbers just
+ // fine.
+ if (numbers) {
+ queue_tasks.post(task);
+ } else {
+ split_multiprompt_task(id_task, task); // id_task becomes the id_multi shared by the subtasks
+ }
+ } else {
+ queue_tasks.post(task);
+ }
+ }
+
+ void request_cancel(int id_task) { // Posts a CANCEL task whose target is `id_task`.
+ server_task task;
+ task.type = SERVER_TASK_TYPE_CANCEL;
+ task.id_target = id_task; // the CANCEL handler in process_single_task() matches this against slot.id_task
+
+ queue_tasks.post(task);
+ }
+
+ void split_multiprompt_task(int id_multi, const server_task &multiprompt_task) { // Fans a multi-prompt task out into one subtask per prompt.
+ const int prompt_count = multiprompt_task.data.at("prompt").size();
+ if (prompt_count <= 1) { // nothing to split: report an error for the original task instead
+ send_error(multiprompt_task, "error while handling multiple prompts");
+ return;
+ }
+
+ // generate all the ID for subtask
+ std::vector subtask_ids(prompt_count);
+ for (int i = 0; i < prompt_count; i++) {
+ subtask_ids[i] = queue_tasks.get_new_id();
+ }
+
+ // queue up the multitask so we can track its subtask progression
+ queue_tasks.add_multitask(id_multi, subtask_ids);
+
+ // add subtasks
+ for (int i = 0; i < prompt_count; i++) {
+ json subtask_data = multiprompt_task.data; // copy the request, then narrow "prompt" to the i-th element
+ subtask_data["prompt"] = subtask_data.at("prompt")[i];
+
+ // subtasks inherit everything else (infill mode, embedding mode,
+ // etc.)
+ request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill,
+ multiprompt_task.embedding, multiprompt_task.tps);
+ }
+ }
+
+ void process_single_task(const server_task &task) { // Handles one queued task, dispatching on task.type.
+ switch (task.type) {
+ case SERVER_TASK_TYPE_COMPLETION: {
+ const int id_slot = json_value(task.data, "id_slot", -1); // -1 is treated below as "no specific slot requested"
+ // use the requested slot, or let get_available_slot() pick one based on the prompt string
+ server_slot *slot;
+
+ if (id_slot != -1) {
+ slot = get_slot_by_id(id_slot);
+ } else {
+ std::string prompt; // stays empty unless "prompt" is a JSON string
+ if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
+ prompt = json_value(task.data, "prompt", std::string());
+ }
+
+ slot = get_available_slot(prompt);
+ }
+
+ if (slot == nullptr) {
+ // if no slot is available, we defer this task for
+ // processing later
+ queue_tasks.defer(task);
+ break;
+ }
+ if (!slot->available()) {
+ // if requested slot is unavailable, we defer this task for
+ // processing later
+ queue_tasks.defer(task);
+ break;
+ }
+ // a request-level "system_prompt" replaces the server's and resets n_past / n_past_se on every slot
+ if (task.data.contains("system_prompt")) {
+ std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
+ system_prompt_set(sys_prompt);
+
+ for (server_slot &slot : slots) { // note: this reference shadows the `slot` pointer inside the loop
+ slot.n_past = 0;
+ slot.n_past_se = 0;
+ }
+ }
+
+ slot->reset();
+ // bind the slot to this task
+ slot->id_task = task.id;
+ slot->id_multi = task.id_multi;
+ slot->infill = task.infill;
+ slot->embedding = task.embedding;
+
+ if (!launch_slot_with_task(*slot, task)) {
+ LOG_ERROR("error while launching slot", task.data);
+ break;
+ }
+ } break;
+ case SERVER_TASK_TYPE_CANCEL: {
+ // release slot linked with the task id
+ for (auto &slot : slots) {
+ if (slot.id_task == task.id_target) {
+ slot.release();
+ break;
+ }
+ }
+ } break;
+ case SERVER_TASK_TYPE_NEXT_RESPONSE: {
+ // do nothing
+ } break;
+ case SERVER_TASK_TYPE_METRICS: {
+ json slots_data = json::array(); // one JSON object per slot, collected below
+
+ int n_idle_slots = 0;
+ int n_processing_slots = 0;
+
+ for (server_slot &slot : slots) {
+ json slot_data = get_formated_generation(slot); // start from the slot's generation settings
+ slot_data["id"] = slot.id;
+ slot_data["id_task"] = slot.id_task;
+ slot_data["state"] = slot.state;
+ slot_data["prompt"] = slot.prompt;
+ slot_data["next_token"] = {
+ {"has_next_token", slot.has_next_token}, {"n_remain", slot.n_remaining},
+ {"n_decoded", slot.n_decoded}, {"stopped_eos", slot.stopped_eos},
+ {"stopped_word", slot.stopped_word}, {"stopped_limit", slot.stopped_limit},
+ {"stopping_word", slot.stopping_word},
+ };
+ // any state other than SLOT_STATE_IDLE is counted as processing
+ if (slot_data["state"] == SLOT_STATE_IDLE) {
+ n_idle_slots++;
+ } else {
+ n_processing_slots++;
+ }
+
+ slots_data.push_back(slot_data);
+ }
+ LOG_INFO("slot data", {{"id_task", task.id},
+ {"n_idle_slots", n_idle_slots},
+ {"n_processing_slots", n_processing_slots}});
+
+ server_task_result res;
+ res.id = task.id;
+ res.id_multi = task.id_multi;
+ res.stop = true;
+ res.error = false;
+ res.data = {
+ {"idle", n_idle_slots},
+ {"processing", n_processing_slots},
+ {"deferred", queue_tasks.queue_tasks_deferred.size()},
+ {"t_start", metrics.t_start},
+
+ {"n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
+ {"t_tokens_generation_total", metrics.t_tokens_generation_total},
+ {"n_tokens_predicted_total", metrics.n_tokens_predicted_total},
+ {"t_prompt_processing_total", metrics.t_prompt_processing_total},
+
+ {"n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
+ {"t_prompt_processing", metrics.t_prompt_processing},
+ {"n_tokens_predicted", metrics.n_tokens_predicted},
+ {"t_tokens_generation", metrics.t_tokens_generation},
+
+ {"kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
+ {"kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
+
+ {"slots", slots_data},
+ };
+ // the metrics bucket is reset only when the task asks for it via "reset_bucket"
+ if (json_value(task.data, "reset_bucket", false)) {
+ metrics.reset_bucket();
+ }
+ queue_results.send(res);
+ } break;
+ case SERVER_TASK_TYPE_SLOT_SAVE: {
+ int id_slot = task.data.at("id_slot");
+ server_slot *slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (!slot->available()) {
+ // if requested slot is unavailable, we defer this task for
+ // processing later
+ queue_tasks.defer(task);
+ break;
+ }
+
+ const size_t token_count = slot->cache_tokens.size();
+ const int64_t t_start = ggml_time_us();
+
+ std::string filename = task.data.at("filename"); // echoed back in the result
+ std::string filepath = task.data.at("filepath"); // path actually handed to the save call
+ // the slot's sequence id is slot->id + 1
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1,
+ slot->cache_tokens.data(), token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_save_ms = (t_end - t_start) / 1000.0; // microseconds -> milliseconds
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json{{"id_slot", id_slot},
+ {"filename", filename},
+ {"n_saved", token_count}, // tokens saved
+ {"n_written", nwrite}, // bytes written
+ {"timings", {{"save_ms", t_save_ms}}}};
+ queue_results.send(result);
+ } break;
+ case SERVER_TASK_TYPE_SLOT_RESTORE: {
+ int id_slot = task.data.at("id_slot");
+ server_slot *slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (!slot->available()) {
+ // if requested slot is unavailable, we defer this task for
+ // processing later
+ queue_tasks.defer(task);
+ break;
+ }
+
+ const int64_t t_start = ggml_time_us();
+
+ std::string filename = task.data.at("filename"); // echoed back in the result
+ std::string filepath = task.data.at("filepath"); // path actually handed to the load call
+ // make room for up to n_ctx tokens; shrunk to the real count after a successful load
+ slot->cache_tokens.resize(slot->n_ctx);
+ size_t token_count = 0;
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1,
+ slot->cache_tokens.data(),
+ slot->cache_tokens.size(), &token_count);
+ if (nread == 0) { // a zero read count is handled as failure: clear the cache and report
+ slot->cache_tokens.resize(0);
+ send_error(task,
+ "Unable to restore slot, no available space in KV "
+ "cache or invalid slot "
+ "save file",
+ ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ slot->cache_tokens.resize(token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_restore_ms = (t_end - t_start) / 1000.0; // microseconds -> milliseconds
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json{{"id_slot", id_slot},
+ {"filename", filename},
+ {"n_restored", token_count}, // tokens restored
+ {"n_read", nread}, // bytes read
+ {"timings", {{"restore_ms", t_restore_ms}}}};
+ queue_results.send(result);
+ } break;
+ case SERVER_TASK_TYPE_SLOT_ERASE: {
+ int id_slot = task.data.at("id_slot");
+ server_slot *slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (!slot->available()) {
+ // if requested slot is unavailable, we defer this task for
+ // processing later
+ queue_tasks.defer(task);
+ break;
+ }
+
+ // Erase token cache
+ const size_t n_erased = slot->cache_tokens.size(); // captured before clearing so it can be reported
+ llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+ slot->cache_tokens.clear();
+
+ server_task_result result;
+ result.id = task.id;
+ result.stop = true;
+ result.error = false;
+ result.data = json{{"id_slot", id_slot}, {"n_erased", n_erased}};
+ queue_results.send(result);
+ } break;
+ }
+ }
+
+ void on_finish_multitask(const server_task_multi &multitask) { // Merges all subtask results into one final result and queues it.
+ // all subtasks done == multitask is done
+ server_task_result result;
+ result.id = multitask.id;
+ result.stop = true;
+ result.error = false;
+
+ // collect json results into one json result
+ std::vector result_jsons;
+ for (const auto &subres : multitask.results) {
+ result_jsons.push_back(subres.data);
+ result.error = result.error || subres.error; // the merged result is an error as soon as any subtask failed
+ }
+ result.data = json{{"results", result_jsons}}; // subtask payloads are reported under "results", in iteration order
+
+ queue_results.send(result);
+ }
+
+ void update_slots() { // One scheduling pass: release slots, shift context, build a batch, decode it, sample per slot.
+ if (system_need_update) {
+ system_prompt_update();
+ }
+
+ // release slots
+ for (auto &slot : slots) {
+ if (slot.command == SLOT_COMMAND_RELEASE) {
+ slot.state = SLOT_STATE_IDLE;
+ slot.command = SLOT_COMMAND_NONE;
+ slot.t_last_used = ggml_time_us();
+
+ LOG_INFO("slot released", {{"id_slot", slot.id},
+ {"id_task", slot.id_task},
+ {"n_ctx", n_ctx},
+ {"n_past", slot.n_past},
+ {"n_system_tokens", system_tokens.size()},
+ {"n_cache_tokens", slot.cache_tokens.size()},
+ {"truncated", slot.truncated}});
+
+ queue_tasks.notify_slot_changed();
+ }
+ }
+
+ // check if all slots are idle
+ {
+ bool all_idle = true;
+
+ for (auto &slot : slots) {
+ if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
+ all_idle = false;
+ break;
+ }
+ }
+
+ if (all_idle) {
+ LOG_INFO("all slots are idle", {});
+ if (system_prompt.empty() && clean_kv_cache) {
+ kv_cache_clear();
+ }
+
+ return;
+ }
+ }
+ // at least one slot is busy: post a NEXT_RESPONSE task (handled as a no-op by process_single_task)
+ {
+ server_task task;
+ task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
+ task.id_target = -1;
+
+ queue_tasks.post(task);
+ }
+
+ // apply context-shift if needed
+ // TODO: simplify and improve
+ for (server_slot &slot : slots) {
+ if (slot.ga_n == 1) {
+ if (slot.is_processing() &&
+ (int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+ // Shift context
+ const int n_keep = slot.params.n_keep + add_bos_token;
+ const int n_left = (int)system_tokens.size() + slot.n_past - n_keep;
+ const int n_discard = // explicit n_discard if non-zero, otherwise half of what is left
+ slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+
+ LOG_INFO("slot context shift", {{"id_slot", slot.id},
+ {"id_task", slot.id_task},
+ {"n_keep", n_keep},
+ {"n_left", n_left},
+ {"n_discard", n_discard},
+ {"n_ctx", n_ctx},
+ {"n_past", slot.n_past},
+ {"n_system_tokens", system_tokens.size()},
+ {"n_cache_tokens", slot.cache_tokens.size()}});
+ // drop [n_keep, n_keep + n_discard) from the slot's sequence and move the rest back by n_discard
+ llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard);
+ llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard,
+ system_tokens.size() + slot.n_past, -n_discard);
+
+ if (slot.params.cache_prompt) {
+ for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+ slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+ }
+
+ slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+ }
+
+ slot.n_past -= n_discard;
+
+ slot.truncated = true;
+ }
+ }
+ }
+
+ // start populating the batch for this iteration
+ llama_batch_clear(batch);
+
+ // first, add sampled tokens from any ongoing sequences
+ for (auto &slot : slots) {
+ if (slot.state == SLOT_STATE_IDLE) {
+ continue;
+ }
+ // a slot with a token bucket is skipped for this pass when acquire() fails
+ if (slot.token_bkt && !slot.token_bkt->acquire()) {
+ continue;
+ }
+
+ slot.i_batch = batch.n_tokens;
+
+ const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+
+ // TODO: we always have to take into account the "system_tokens"
+ // this is not great and needs to be improved somehow
+ llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, {slot.id + 1},
+ true);
+
+ slot.n_past += 1;
+
+ if (slot.params.cache_prompt) {
+ slot.cache_tokens.push_back(slot.sampled);
+ }
+ }
+
+ // process in chunks of params.n_batch
+ int32_t n_batch = llama_n_batch(ctx);
+ int32_t n_ubatch = llama_n_ubatch(ctx);
+
+ // next, batch any pending prompts without exceeding n_batch
+ if (params.cont_batching || batch.n_tokens == 0) {
+ for (auto &slot : slots) {
+ // this slot still has a prompt to be processed
+ if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
+ auto &prompt_tokens = slot.prompt_tokens;
+
+ // we haven't tokenized the prompt yet - do it now:
+ if (prompt_tokens.empty()) {
+ slot.t_start_process_prompt = ggml_time_us();
+ slot.t_start_generation = 0;
+
+ if (slot.infill) {
+ bool suff_rm_leading_spc = true; // NOTE(review): the check below edits params.input_suffix, but slot.params.input_suffix is tokenized - confirm intent
+ if (params.input_suffix.find_first_of(' ') == 0 &&
+ params.input_suffix.size() > 1) {
+ params.input_suffix.erase(0, 1);
+ suff_rm_leading_spc = false;
+ }
+
+ auto prefix_tokens = tokenize(slot.params.input_prefix, false);
+ auto suffix_tokens = tokenize(slot.params.input_suffix, false);
+
+ const int space_token = 29871; // TODO: this should not be hardcoded
+ if (suff_rm_leading_spc && !suffix_tokens.empty() &&
+ suffix_tokens[0] == space_token) {
+ suffix_tokens.erase(suffix_tokens.begin());
+ }
+ // resulting layout: BOS, prefix token, prefix, suffix token, suffix, then the middle token if it is >= 0
+ prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
+ prefix_tokens.insert(prefix_tokens.begin(),
+ llama_token_bos(model)); // always add BOS
+ prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
+ prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(),
+ suffix_tokens.end());
+
+ const llama_token middle_token = llama_token_middle(model);
+ if (middle_token >= 0) {
+ prefix_tokens.push_back(middle_token);
+ }
+
+ prompt_tokens = prefix_tokens;
+ } else {
+ prompt_tokens = tokenize(slot.prompt,
+ system_prompt.empty()); // add BOS if there
+ // isn't system prompt
+ }
+
+ slot.n_past = 0;
+ slot.n_prompt_tokens = prompt_tokens.size();
+
+ // empty prompt passed -> release the slot and send
+ // empty response
+ if (prompt_tokens.empty()) {
+ LOG_INFO("empty prompt - releasing slot",
+ {{"id_slot", slot.id}, {"id_task", slot.id_task}});
+
+ slot.state = SLOT_STATE_PROCESSING;
+ slot.command = SLOT_COMMAND_NONE;
+ slot.release();
+ send_final_response(slot);
+ continue;
+ }
+
+ if (slot.embedding) {
+ // this prompt is too large to process - discard it
+ if (slot.n_prompt_tokens > n_ubatch) {
+ slot.state = SLOT_STATE_PROCESSING;
+ slot.command = SLOT_COMMAND_NONE;
+ slot.release();
+ send_error(slot,
+ "input is too large to process. "
+ "increase the physical "
+ "batch size",
+ ERROR_TYPE_SERVER);
+ continue;
+ }
+ } else {
+ if (slot.params.n_keep < 0) {
+ slot.params.n_keep = slot.n_prompt_tokens;
+ }
+ slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+ // if input prompt is too big, truncate it (if group
+ // attention self-extend is disabled)
+ if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+ const int n_left = slot.n_ctx - slot.params.n_keep;
+
+ const int n_block_size = n_left / 2;
+ const int erased_blocks =
+ (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) /
+ n_block_size;
+ // keep the first n_keep tokens, skip erased_blocks * n_block_size tokens, keep the tail
+ std::vector new_tokens(prompt_tokens.begin(),
+ prompt_tokens.begin() +
+ slot.params.n_keep);
+
+ new_tokens.insert(new_tokens.end(),
+ prompt_tokens.begin() + slot.params.n_keep +
+ erased_blocks * n_block_size,
+ prompt_tokens.end());
+
+ prompt_tokens = std::move(new_tokens);
+
+ slot.truncated = true;
+ slot.n_prompt_tokens = prompt_tokens.size();
+
+ GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
+ }
+
+ llama_sampling_reset(slot.ctx_sampling);
+
+ if (!slot.params.cache_prompt) {
+ slot.n_past_se = 0;
+ slot.ga_i = 0;
+ } else {
+ GGML_ASSERT(slot.ga_n == 1);
+
+ // reuse any previously computed tokens that are
+ // common with the new prompt
+ slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+
+ // push the prompt into the sampling context (do
+ // not apply grammar)
+ for (int i = 0; i < slot.n_past; ++i) {
+ llama_sampling_accept(slot.ctx_sampling, ctx,
+ slot.cache_tokens[i], false);
+ }
+ }
+ }
+
+ if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
+ // we have to evaluate at least 1 token to generate
+ // logits.
+ LOG_INFO("we have to evaluate at least 1 token to "
+ "generate logits",
+ {{"id_slot", slot.id}, {"id_task", slot.id_task}});
+
+ slot.n_past--;
+ if (slot.ga_i > 0) {
+ slot.n_past_se--;
+ }
+ }
+
+ slot.n_prompt_tokens_processed = 0;
+ }
+
+ if (slot.embedding) {
+ // cannot fit the prompt in the current batch - will try
+ // next iter
+ if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+ continue;
+ }
+ }
+
+ // keep only the common part
+ int p0 = (int)system_tokens.size() + slot.n_past;
+ if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+ // could not partially delete (likely using a
+ // non-Transformer model)
+ llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+
+ p0 = (int)system_tokens.size();
+ if (p0 != 0) {
+ // copy over the system prompt when there is one
+ llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
+ }
+
+ // there is no common part left (except for the system
+ // prompt)
+ slot.n_past = 0;
+ slot.n_past_se = 0;
+ slot.ga_i = 0;
+ // TODO: is the system prompt ever in the sampling
+ // context?
+ llama_sampling_reset(slot.ctx_sampling);
+ }
+
+ // remove the non-common part from the cache
+ slot.cache_tokens.resize(slot.n_past);
+
+ LOG_INFO("kv cache rm [p0, end)",
+ {{"id_slot", slot.id}, {"id_task", slot.id_task}, {"p0", p0}});
+
+ int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+ // local copies: the self-extend loop below advances ga_i without touching slot.ga_i
+ int32_t ga_i = slot.ga_i;
+ int32_t ga_n = slot.ga_n;
+ int32_t ga_w = slot.ga_w;
+
+ // add prompt tokens for processing in the current batch
+ // TODO: the self-extend stuff here is a mess - simplify
+ // and/or abstract it somehow
+ for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch;
+ ++slot.n_past) {
+ if (slot.ga_n != 1) {
+ while (slot_npast >= ga_i + ga_w) {
+ const int bd = (ga_w / ga_n) * (ga_n - 1);
+ slot_npast -= bd;
+ ga_i += ga_w / ga_n;
+ }
+ }
+
+ llama_batch_add(batch, prompt_tokens[slot.n_past],
+ system_tokens.size() + slot_npast, {slot.id + 1}, false);
+
+ if (slot.params.cache_prompt) {
+ slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
+ }
+
+ slot.n_prompt_tokens_processed++;
+ slot_npast++;
+ }
+
+ // entire prompt has been processed - start decoding new
+ // tokens
+ if (slot.n_past == slot.n_prompt_tokens) {
+ slot.state = SLOT_STATE_PROCESSING;
+ slot.command = SLOT_COMMAND_NONE;
+
+ GGML_ASSERT(batch.n_tokens > 0);
+
+ // extract the logits only for the last token
+ batch.logits[batch.n_tokens - 1] = true;
+
+ slot.n_decoded = 0;
+ slot.i_batch = batch.n_tokens - 1;
+ }
+ }
+
+ if (batch.n_tokens >= n_batch) {
+ break;
+ }
+ }
+ }
+
+ if (batch.n_tokens == 0) {
+ return;
+ }
+
+ // process the created batch of tokens
+ for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+ const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+
+ for (auto &slot : slots) {
+ if (slot.ga_n != 1) {
+ // context extension via Self-Extend
+ // TODO: simplify and/or abstract this
+ while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
+ const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
+ const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
+ const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
+
+ LOG_TEE("\n");
+ LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i,
+ slot.n_past_se, ib * bd, slot.ga_i + ib * bd,
+ slot.n_past_se + ib * bd);
+ LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd,
+ slot.ga_i + ib * bd + slot.ga_w, slot.ga_n,
+ (slot.ga_i + ib * bd) / slot.ga_n,
+ (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
+ LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n",
+ slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd,
+ slot.ga_i + ib * bd + slot.ga_w + dd,
+ slot.n_past_se + ib * bd + dd);
+
+ llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se,
+ ib * bd);
+ llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd,
+ slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+ llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w,
+ slot.n_past_se + ib * bd, dd);
+
+ slot.n_past_se -= bd;
+
+ slot.ga_i += slot.ga_w / slot.ga_n;
+
+ LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n",
+ slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
+ }
+
+ slot.n_past_se += n_tokens;
+ }
+ }
+ // view of n_tokens entries of `batch` starting at offset i
+ llama_batch batch_view = {
+ n_tokens,
+ batch.token + i,
+ nullptr,
+ batch.pos + i,
+ batch.n_seq_id + i,
+ batch.seq_id + i,
+ batch.logits + i,
+ 0,
+ 0,
+ 0, // unused
+ };
+
+ const int ret = llama_decode(ctx, batch_view);
+
+ if (ret != 0) {
+ if (n_batch == 1 || ret < 0) {
+ // if you get here, it means the KV cache is full - try
+ // increasing it via the context size
+ LOG_ERROR("failed to decode the batch: KV cache is full - "
+ "try increasing it "
+ "via the context size",
+ {
+ {"i", i},
+ {"n_batch", n_batch},
+ {"ret", ret},
+ });
+ for (auto &slot : slots) { // note: every slot is released and sent the error, not only those in this batch
+ slot.state = SLOT_STATE_PROCESSING;
+ slot.command = SLOT_COMMAND_NONE;
+ slot.release();
+ send_error(slot, "Input prompt is too big compared to "
+ "KV size. Please try "
+ "increasing KV size.");
+ }
+ break; // break loop of n_batch
+ }
+
+ // retry with half the batch size to try to find a free slot in
+ // the KV cache
+ n_batch /= 2;
+ i -= n_batch; // offsets the loop's `i += n_batch`, so the same offset is retried
+
+ LOG_WARNING("failed to find free space in the KV cache, "
+ "retrying with smaller batch size - "
+ "try increasing it via the context size or enable "
+ "defragmentation",
+ {
+ {"i", i},
+ {"n_batch", n_batch},
+ {"ret", ret},
+ });
+
+ continue; // continue loop of n_batch
+ }
+
+ for (auto &slot : slots) {
+ if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i ||
+ slot.i_batch >= (int)(i + n_tokens)) {
+ continue; // continue loop of slots
+ }
+
+ // prompt evaluated for embedding
+ if (slot.embedding) {
+ send_embedding(slot, batch_view);
+ slot.release();
+ slot.i_batch = -1;
+ continue; // continue loop of slots
+ }
+
+ completion_token_output result;
+ const llama_token id =
+ llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+
+ llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+
+ slot.n_decoded += 1;
+ if (slot.n_decoded == 1) { // first generated token: record prompt-processing time and report it to metrics
+ slot.t_start_generation = ggml_time_us();
+ slot.t_prompt_processing =
+ (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+ metrics.on_prompt_eval(slot);
+ }
+
+ llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(),
+ slot.ctx_sampling->cur.size(), false};
+ result.tok = id;
+
+ const size_t n_probs = std::min(cur_p.size, (size_t)slot.sparams.n_probs);
+ if (n_probs > 0) {
+ const size_t n_valid = slot.ctx_sampling->n_valid;
+
+ // Make sure at least n_probs top tokens are at the front of
+ // the vector:
+ if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
+ llama_sample_top_k(ctx, &cur_p, n_probs, 0);
+ }
+
+ if (slot.sparams.temp == 0.0f) {
+ // With greedy sampling the probabilities have possibly
+ // not been calculated.
+ for (size_t i = 0; i < n_probs; ++i) {
+ result.probs.push_back({cur_p.data[i].id, i == 0 ? 1.0f : 0.0f});
+ }
+ } else {
+ for (size_t i = 0; i < n_probs; ++i) {
+ result.probs.push_back({
+ cur_p.data[i].id,
+ i >= n_valid ? 0.0f
+ : cur_p.data[i].p // Tokens filtered out due to e.g.
+ // top_k have 0 probability.
+ });
+ }
+ }
+ }
+
+ if (!process_token(result, slot)) { // a false return ends generation for this slot
+ slot.release();
+ send_final_response(slot);
+ metrics.on_prediction(slot);
+ }
+
+ slot.i_batch = -1;
+ }
+ }
+ }
+
+ json model_meta() const { // Returns a JSON object of model metadata read through the llama_* accessors below.
+ return json{
+ {"vocab_type", llama_vocab_type(model)}, {"n_vocab", llama_n_vocab(model)},
+ {"n_ctx_train", llama_n_ctx_train(model)}, {"n_embd", llama_n_embd(model)},
+ {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)},
+ };
+ }
+};
+
+ static void log_server_request(const httplib::Request &req, const httplib::Response &res) { // Logs one HTTP request/response pair.
+ // requests to /v1/health and /v1/completions are not logged
+ // (originally noted as: skip GH copilot requests when using default port)
+ if (req.path == "/v1/health" || req.path == "/v1/completions") {
+ return;
+ }
+
+ LOG_INFO("request", {
+ {"remote_addr", req.remote_addr},
+ {"remote_port", req.remote_port},
+ {"status", res.status},
+ {"method", req.method},
+ {"path", req.path},
+ {"params", req.params},
+ });
+ }
+
+ std::function shutdown_handler; // called by signal_handler() with the signal number; presumably assigned in main() - not visible here
+ std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; // set by the first signal; a later signal then finds it set and exits
+
+ inline void signal_handler(int signal) { // First signal runs shutdown_handler; any further signal exits immediately.
+ if (is_terminating.test_and_set()) {
+ // In case shutdown hangs, the server can be force-terminated by hitting
+ // Ctrl+C twice. This is for a better developer experience and can be
+ // removed once the server is stable enough.
+ fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+ exit(1);
+ }
+ // first signal: the flag is now set, delegate to the registered shutdown routine
+ shutdown_handler(signal);
+ }
+
+int main(int argc, char **argv) {
+ log_set_target(stderr);
+
+ llama_box_params bparams;
+ if (!llama_box_params_parse(argc, argv, bparams)) {
+ llama_box_params_print_usage(argc, argv, bparams);
+ return 1;
+ }
+ gpt_params ¶ms = bparams.gparams;
+
+ server_log_json = params.log_json;
+
+ server_context ctx_server;
+ if (!params.system_prompt.empty()) {
+ ctx_server.system_prompt_set(params.system_prompt);
+ }
+
+ if (params.model_alias == "unknown") {
+ params.model_alias = params.model;
+ }
+
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
+ LOG_INFO("system info", {
+ {"n_threads", params.n_threads},
+ {"n_threads_batch", params.n_threads_batch},
+ {"total_threads", std::thread::hardware_concurrency()},
+ {"system_info", llama_print_system_info()},
+ });
+
+ httplib::Server svr;
+ std::atomic state{SERVER_STATE_LOADING_MODEL};
+
+ // default headers
+ svr.set_default_headers({{"Server", "llama-box"}});
+
+ // CORS preflight
+ svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) {
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ res.set_header("Access-Control-Allow-Methods", "POST");
+ res.set_header("Access-Control-Allow-Headers", "*");
+ return res.set_content("", "application/json; charset=utf-8");
+ });
+
+ // logger
+ svr.set_logger(log_server_request);
+
+ // error handlers
+ auto res_error = [](httplib::Response &res, json error_data) {
+ json final_response{{"error", error_data}};
+ res.set_content(final_response.dump(), "application/json; charset=utf-8");
+ res.status = json_value(error_data, "code", httplib::StatusCode::InternalServerError_500);
+ };
+ svr.set_exception_handler(
+ [&res_error](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) {
+ std::string message;
+ try {
+ std::rethrow_exception(std::move(ep));
+ } catch (std::exception &e) {
+ message = e.what();
+ } catch (...) {
+ message = "Unknown Exception";
+ }
+
+ json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
+ LOG_ERROR("Got exception", formatted_error);
+ res_error(res, formatted_error);
+ });
+ svr.set_error_handler([&res_error](const httplib::Request &, httplib::Response &res) {
+ if (res.status == 404) {
+ res_error(res, format_error_response("Not Found", ERROR_TYPE_NOT_FOUND));
+ }
+ // for other error codes, we skip processing here because it's
+ // already done by res_error()
+ });
+
+ // configure and bind
+ svr.set_read_timeout(params.timeout_read);
+ svr.set_write_timeout(params.timeout_write);
+ svr.set_payload_max_length(1024 * 1024 * 10);
+ svr.set_idle_interval(bparams.conn_idle);
+ svr.set_keep_alive_timeout(bparams.conn_keepalive);
+ if (!svr.bind_to_port(params.hostname, params.port)) {
+ LOG_ERROR("couldn't bind to server socket",
+ {{"hostname", params.hostname}, {"port", params.port}});
+ return 1;
+ }
+
+ std::unordered_map log_data;
+
+ log_data["hostname"] = params.hostname;
+ log_data["port"] = std::to_string(params.port);
+
+ // necessary similarity of prompt for slot selection
+ ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
+
+    // load the model, then mark the server ready so /health stops answering "Loading model"
+    if (!ctx_server.load_model(bparams)) {
+        return 1;
+    }
+    ctx_server.init();
+    state.store(SERVER_STATE_READY);
+    LOG_INFO("model loaded", {});
+
+ // if a custom chat template is not supplied, we will use the one that comes
+ // with the model (if any)
+ if (params.chat_template.empty()) {
+ params.chat_template = ctx_server.load_chat_template();
+ }
+ LOG_INFO("chat template", {{"template", params.chat_template}});
+
+ //
+ // Handlers
+ //
+
+ const auto handle_health = [&](const httplib::Request &req, httplib::Response &res) {
+ server_state current_state = state.load();
+ switch (current_state) {
+ case SERVER_STATE_READY: {
+ // request slots data using task queue
+ server_task task;
+ task.id_target = -1;
+ task.type = SERVER_TASK_TYPE_METRICS;
+
+ // post the task
+ task.id = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ const int n_idle_slots = result.data.at("idle");
+ const int n_processing_slots = result.data.at("processing");
+
+ json health = {{"status", "ok"},
+ {"slots_idle", n_idle_slots},
+ {"slots_processing", n_processing_slots}};
+
+ if (params.endpoint_slots && req.has_param("include_slots")) {
+ health["slots"] = result.data.at("slots");
+ }
+
+ if (n_idle_slots == 0) {
+ health["status"] = "no slot available";
+ if (req.has_param("fail_on_no_slot")) {
+ res.status = httplib::StatusCode::ServiceUnavailable_503;
+ }
+ }
+
+ res.set_content(health.dump(), "application/json");
+ break;
+ }
+ case SERVER_STATE_LOADING_MODEL:
+ res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+ break;
+ case SERVER_STATE_ERROR:
+ res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
+ break;
+ }
+ };
+
+    const auto handle_metrics = [&](const httplib::Request &, httplib::Response &res) {
+        // fetch the counters through a METRICS task and render them as Prometheus text
+        server_task task;
+        task.id_multi = -1;
+        task.id_target = -1;
+        task.type = SERVER_TASK_TYPE_METRICS;
+        task.data.push_back({{"reset_bucket", true}});
+
+        // post the task
+        task.id = ctx_server.queue_tasks.post(task);
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+
+        // get the result
+        server_task_result result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        std::stringstream metrics;
+        {
+            json data = result.data;
+            uint64_t n_prompt_tokens_processed_total = data.at("n_prompt_tokens_processed_total");
+            uint64_t t_prompt_processing_total = data.at("t_prompt_processing_total");
+            uint64_t n_tokens_predicted_total = data.at("n_tokens_predicted_total");
+            uint64_t t_tokens_generation_total = data.at("t_tokens_generation_total");
+            uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed");
+            uint64_t t_prompt_processing = data.at("t_prompt_processing");
+            uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
+            uint64_t t_tokens_generation = data.at("t_tokens_generation");
+            int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
+            uint64_t kv_cache_tokens_count = data.at("kv_cache_tokens_count");
+            uint64_t processing = data.at("processing");
+            uint64_t deferred = data.at("deferred");
+
+            // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
+            // every ratio tests its own divisor, so a zero time or context size yields 0, not inf/nan
+            json all_metrics_def = json{
+                {"counter",
+                 {{{"name", "prompt_tokens_total"},
+                   {"help", "Number of prompt tokens processed."},
+                   {"value", n_prompt_tokens_processed_total}},
+                  {{"name", "prompt_seconds_total"},
+                   {"help", "Prompt process time"},
+                   {"value", t_prompt_processing_total / 1.e3}},
+                  {{"name", "tokens_predicted_total"},
+                   {"help", "Number of generation tokens processed."},
+                   {"value", n_tokens_predicted_total}},
+                  {{"name", "tokens_predicted_seconds_total"},
+                   {"help", "Predict process time"},
+                   {"value", t_tokens_generation_total / 1.e3}}}},
+                {"gauge",
+                 {{{"name", "prompt_tokens_seconds"},
+                   {"help", "Average prompt throughput in tokens/s."},
+                   {"value", t_prompt_processing
+                                 ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed
+                                 : 0.}},
+                  {{"name", "predicted_tokens_seconds"},
+                   {"help", "Average generation throughput in tokens/s."},
+                   {"value",
+                    t_tokens_generation ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}},
+                  {{"name", "kv_cache_usage_ratio"},
+                   {"help", "KV-cache usage. 1 means 100 percent usage."},
+                   {"value", params.n_ctx > 0 ? 1. * kv_cache_used_cells / params.n_ctx : 0.}},
+                  {{"name", "kv_cache_tokens"},
+                   {"help", "KV-cache tokens."},
+                   {"value", kv_cache_tokens_count}},
+                  {{"name", "requests_processing"},
+                   {"help", "Number of request processing."},
+                   {"value", processing}},
+                  {{"name", "requests_deferred"},
+                   {"help", "Number of request deferred."},
+                   {"value", deferred}}}}};
+
+            for (const auto &el : all_metrics_def.items()) {
+                const auto &type = el.key();
+                const auto &metrics_def = el.value();
+
+                for (const auto &metric_def : metrics_def) {
+                    const std::string name = metric_def.at("name");
+                    const std::string help = metric_def.at("help");
+
+                    auto value = json_value(metric_def, "value", 0.);
+                    metrics << "# HELP llamacpp:" << name << " " << help << "\n"
+                            << "# TYPE llamacpp:" << name << " " << type << "\n"
+                            << "llamacpp:" << name << " " << value << "\n";
+                }
+            }
+        }
+        res.set_content(metrics.str(), "text/plain; version=0.0.4");
+    };
+
+ const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response &res) {
+ json props = {
+ {"system_prompt", ctx_server.system_prompt.c_str()},
+ {"default_generation_settings", ctx_server.default_generation_settings_for_props},
+ {"total_slots", ctx_server.params.n_parallel}};
+
+ res.set_content(props.dump(), "application/json; charset=utf-8");
+ };
+
+ const auto handle_infill = [&ctx_server, &res_error](const httplib::Request &req,
+ httplib::Response &res) {
+ int tps = 0;
+ {
+ const std::string tps_s = req.get_header_value("X-Request-Tokens-Per-Second");
+ if (!tps_s.empty()) {
+ try {
+ tps = std::stoi(tps_s);
+ } catch (const std::exception &) {
+ tps = ctx_server.n_tps;
+ }
+ }
+ if (tps > ctx_server.n_tps) {
+ // if the request exceeds the maximum tokens per second, return
+ // 410 Gone
+ if (ctx_server.n_tps > 0) {
+ res.status = httplib::StatusCode::Gone_410;
+ res.set_content("This request exceeds the maximum tokens per second",
+ "text/plain; charset=utf-8");
+ return;
+ }
+ // if the server is not limited by tokens per second, set tps to
+ // 0
+ tps = 0;
+ }
+ }
+
+ const json request = json::parse(req.body);
+
+ // post the task
+ const int id_task = ctx_server.queue_tasks.get_new_id();
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+ ctx_server.request_completion(id_task, -1, request, true, false, tps);
+
+ // process non-streaming requests
+ if (!json_value(request, "stream", false)) {
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ if (result.error || !result.stop) {
+ res_error(res, result.data);
+ } else {
+ res.set_header("X-Response-Tokens-Per-Second",
+ std::to_string(json_value(result.data.at("timings"),
+ "predicted_per_second", tps)));
+ const std::string infill =
+ result.data.dump(-1, ' ', false, json::error_handler_t::replace);
+ res.set_content(infill, "application/json; charset=utf-8");
+ }
+
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ return;
+ }
+
+ // process streaming requests
+ const auto on_chunk = [id_task, &ctx_server, tps](size_t, httplib::DataSink &sink) {
+ while (true) {
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ if (result.error) {
+ const std::string str = "error: failed to infill\n\n";
+ sink.write(str.c_str(), str.size());
+ sink.done();
+ return false;
+ }
+
+ const std::string infill =
+ "data: " + result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
+ "\n\n";
+ if (!sink.write(infill.c_str(), infill.size())) {
+ sink.done();
+ return false;
+ }
+
+ if (!result.stop) {
+ continue;
+ }
+
+ sink.done_with_trailer({{"X-Response-Tokens-Per-Second",
+ std::to_string(json_value(result.data.at("timings"),
+ "predicted_per_second", tps))}});
+ return true;
+ }
+ };
+ const auto on_complete = [id_task, &ctx_server](bool) {
+ ctx_server.request_cancel(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ };
+
+ res.set_header("Trailer", "X-Response-Tokens-Per-Second");
+ res.set_chunked_content_provider("text/event-stream", on_chunk, on_complete);
+ };
+
+ const auto handle_tokenize = [&ctx_server](const httplib::Request &req,
+ httplib::Response &res) {
+ const json request = json::parse(req.body);
+
+ std::vector tokens;
+ if (request.count("content") != 0) {
+ const bool add_special = json_value(request, "add_special", false);
+ tokens = ctx_server.tokenize(request.at("content"), add_special);
+ }
+
+ const json response = json{{"tokens", tokens}};
+ return res.set_content(response.dump(), "application/json; charset=utf-8");
+ };
+
+ const auto handle_detokenize = [&ctx_server](const httplib::Request &req,
+ httplib::Response &res) {
+ const json request = json::parse(req.body);
+
+ std::string content;
+ if (request.count("tokens") != 0) {
+ const std::vector tokens = request.at("tokens");
+ content = llama_detokenize_bpe(ctx_server.ctx, tokens);
+ }
+
+ const json response = json{{"content", content}};
+ return res.set_content(response.dump(), "application/json; charset=utf-8");
+ };
+
+ const auto handle_slots = [&](const httplib::Request &, httplib::Response &res) {
+ // request slots data using task queue
+ server_task task;
+ task.id_multi = -1;
+ task.id_target = -1;
+ task.type = SERVER_TASK_TYPE_METRICS;
+
+ // post the task
+ task.id = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ res.set_content(result.data.at("slots").dump(), "application/json");
+ };
+
+ const auto handle_slots_save = [&ctx_server, &res_error, ¶ms](const httplib::Request &req,
+ httplib::Response &res,
+ int id_slot) {
+ json request = json::parse(req.body);
+
+ std::string filename = request.at("filename");
+ if (!fs_validate_filename(filename)) {
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ std::string filepath = params.slot_save_path + filename;
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_SAVE;
+ task.data = {{"id_slot", id_slot}, {"filename", filename}, {"filepath", filepath}};
+
+ // post the task
+ task.id = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ if (result.error) {
+ res_error(res, result.data);
+ return;
+ }
+
+ res.set_content(result.data.dump(), "application/json");
+ };
+
+ const auto handle_slots_restore = [&ctx_server, &res_error,
+ ¶ms](const httplib::Request &req, httplib::Response &res,
+ int id_slot) {
+ json request = json::parse(req.body);
+
+ std::string filename = request.at("filename");
+ if (!fs_validate_filename(filename)) {
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ std::string filepath = params.slot_save_path + filename;
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
+ task.data = {{"id_slot", id_slot}, {"filename", filename}, {"filepath", filepath}};
+
+ // post the task
+ task.id = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ if (result.error) {
+ res_error(res, result.data);
+ return;
+ }
+
+ res.set_content(result.data.dump(), "application/json");
+ };
+
+ const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request &,
+ httplib::Response &res, int id_slot) {
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SLOT_ERASE;
+ task.data = {{"id_slot", id_slot}};
+
+ // post the task
+ task.id = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(task.id);
+
+ // get the result
+ server_task_result result = ctx_server.queue_results.recv(task.id);
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+ if (result.error) {
+ res_error(res, result.data);
+ return;
+ }
+
+ res.set_content(result.data.dump(), "application/json");
+ };
+
+ const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore,
+ &handle_slots_erase](const httplib::Request &req,
+ httplib::Response &res) {
+ int id_slot = -1;
+ {
+ const std::string id_slot_str = req.path_params.at("id_slot");
+ if (!id_slot_str.empty()) {
+ try {
+ id_slot = std::stoi(id_slot_str);
+ } catch (const std::exception &) {
+ id_slot = -1;
+ }
+ }
+ if (id_slot < 0) {
+ res_error(res,
+ format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ }
+
+ // forward
+ const std::string action = req.get_param_value("action");
+ if (action == "save") {
+ handle_slots_save(req, res, id_slot);
+ } else if (action == "restore") {
+ handle_slots_restore(req, res, id_slot);
+ } else if (action == "erase") {
+ handle_slots_erase(req, res, id_slot);
+ } else {
+ res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
+ }
+ };
+
+ const auto handle_completions = [&ctx_server, &res_error](const httplib::Request &req,
+ httplib::Response &res) {
+ int tps = 0;
+ {
+ const std::string tps_s = req.get_header_value("X-Request-Tokens-Per-Second");
+ if (!tps_s.empty()) {
+ try {
+ tps = std::stoi(tps_s);
+ } catch (const std::exception &) {
+ tps = ctx_server.n_tps;
+ }
+ }
+ if (tps > ctx_server.n_tps) {
+ // if the request exceeds the maximum tokens per second, return
+ // 410 Gone
+ if (ctx_server.n_tps > 0) {
+ res.status = httplib::StatusCode::Gone_410;
+ res.set_content("This request exceeds the maximum tokens per second",
+ "text/plain; charset=utf-8");
+ return;
+ }
+ // if the server is not limited by tokens per second, set tps to
+ // 0
+ tps = 0;
+ }
+ }
+
+ bool oaicompat = req.path.compare("/v1/completions") == 0;
+ json request = json::parse(req.body);
+ if (!request.contains("prompt")) {
+ res_error(res, format_error_response("\"prompt\" must be provided",
+ ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ if (oaicompat) {
+ request = oaicompat_completion_request(ctx_server.model, request, std::string());
+ }
+
+ // post the task
+ const int id_task = ctx_server.queue_tasks.get_new_id();
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+ ctx_server.request_completion(id_task, -1, request, false, false, tps);
+
+ const std::string completion_id = gen_cmplid();
+
+ // process non-streaming requests
+ if (!json_value(request, "stream", false)) {
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ if (result.error || !result.stop) {
+ res_error(res, result.data);
+ } else {
+ res.set_header("X-Response-Tokens-Per-Second",
+ std::to_string(json_value(result.data.at("timings"),
+ "predicted_per_second", tps)));
+
+ json completions_json = result.data;
+ if (req.path.compare("/v1/completions") == 0) {
+ completions_json =
+ oaicompat_completion_response(request, completions_json, completion_id);
+ }
+ const std::string completions =
+ completions_json.dump(-1, ' ', false, json::error_handler_t::replace);
+ res.set_content(completions, "application/json; charset=utf-8");
+ }
+
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ return;
+ }
+
+ // process streaming requests
+ const auto on_chunk = [id_task, &ctx_server, completion_id, oaicompat, request,
+ tps](size_t, httplib::DataSink &sink) {
+ while (true) {
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ if (result.error) {
+ const std::string str = "error: failed to complete\n\n";
+ sink.write(str.c_str(), str.size());
+ sink.done();
+ return false;
+ }
+
+ json completions_json = result.data;
+ if (oaicompat) {
+ completions_json = oaicompat_completion_response(request, completions_json,
+ completion_id, true);
+ }
+ const std::string completions =
+ "data: " +
+ completions_json.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n";
+ if (!sink.write(completions.c_str(), completions.size())) {
+ sink.done();
+ return false;
+ }
+
+ if (!result.stop) {
+ continue;
+ }
+
+ sink.done_with_trailer({{"X-Response-Tokens-Per-Second",
+ std::to_string(json_value(result.data.at("timings"),
+ "predicted_per_second", tps))}});
+ return true;
+ }
+ };
+ const auto on_complete = [id_task, &ctx_server](bool) {
+ ctx_server.request_cancel(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ };
+
+ res.set_header("Trailer", "X-Response-Tokens-Per-Second");
+ res.set_chunked_content_provider("text/event-stream", on_chunk, on_complete);
+ };
+
+ const auto handle_models = [&ctx_server, ¶ms](const httplib::Request &,
+ httplib::Response &res) {
+ json models = {{"object", "list"},
+ {"data",
+ {
+ {{"id", params.model_alias},
+ {"object", "model"},
+ {"created", std::time(0)},
+ {"owned_by", "llama-box"},
+ {"meta", ctx_server.model_meta()}},
+ }}};
+
+ res.set_content(models.dump(), "application/json; charset=utf-8");
+ };
+
+ const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](
+ const httplib::Request &req, httplib::Response &res) {
+ int tps = 0;
+ {
+ const std::string tps_s = req.get_header_value("X-Request-Tokens-Per-Second");
+ if (!tps_s.empty()) {
+ try {
+ tps = std::stoi(tps_s);
+ } catch (const std::exception &) {
+ tps = ctx_server.n_tps;
+ }
+ }
+ if (tps > ctx_server.n_tps) {
+ // if the request exceeds the maximum tokens per second, return
+ // 410 Gone
+ if (ctx_server.n_tps > 0) {
+ res.status = httplib::StatusCode::Gone_410;
+ res.set_content("This request exceeds the maximum tokens per second",
+ "text/plain; charset=utf-8");
+ return;
+ }
+ // if the server is not limited by tokens per second, set tps to
+ // 0
+ tps = 0;
+ }
+ }
+
+ json request = json::parse(req.body);
+ if (!request.contains("messages")) {
+ res_error(res, format_error_response("\"messages\" must be provided",
+ ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ request = oaicompat_completion_request(ctx_server.model, request, params.chat_template);
+
+ // post the task
+ const int id_task = ctx_server.queue_tasks.get_new_id();
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+ ctx_server.request_completion(id_task, -1, request, false, false, tps);
+
+ const std::string completion_id = gen_chatcmplid();
+
+ // process non-streaming requests
+ if (!json_value(request, "stream", false)) {
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ if (result.error || !result.stop) {
+ res_error(res, result.data);
+ } else {
+ res.set_header("X-Response-Tokens-Per-Second",
+ std::to_string(json_value(result.data.at("timings"),
+ "predicted_per_second", tps)));
+
+ const json chats_completion_json =
+ oaicompat_completion_response(request, result.data, completion_id);
+ const std::string chats_completion =
+ chats_completion_json.dump(-1, ' ', false, json::error_handler_t::replace);
+ res.set_content(chats_completion, "application/json; charset=utf-8");
+ }
+
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ return;
+ }
+
+ // process streaming requests
+ const auto on_chunk = [id_task, &ctx_server, completion_id, request,
+ tps](size_t, httplib::DataSink &sink) {
+ bool first = true;
+ while (true) {
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ if (result.error) {
+ const std::string str = "error: failed to chat\n\n";
+ sink.write(str.c_str(), str.size());
+ sink.done();
+ return false;
+ }
+
+ if (first) {
+ first = false;
+ json chat_completions_json = oaicompat_completion_response(
+ request, result.data, completion_id, true, true);
+ const std::string chat_completions =
+ "data: " +
+ chat_completions_json.dump(-1, ' ', false, json::error_handler_t::replace) +
+ "\n\n";
+ if (!sink.write(chat_completions.c_str(), chat_completions.size())) {
+ sink.done();
+ return false;
+ }
+ }
+
+ json chat_completions_json =
+ oaicompat_completion_response(request, result.data, completion_id, true);
+ const std::string chat_completions =
+ "data: " +
+ chat_completions_json.dump(-1, ' ', false, json::error_handler_t::replace) +
+ "\n\n";
+ if (!sink.write(chat_completions.c_str(), chat_completions.size())) {
+ sink.done();
+ return false;
+ }
+
+ if (!result.stop) {
+ continue;
+ }
+
+ sink.done_with_trailer({{"X-Response-Tokens-Per-Second",
+ std::to_string(json_value(result.data.at("timings"),
+ "predicted_per_second", tps))}});
+ return true;
+ }
+ };
+ auto on_complete = [id_task, &ctx_server](bool) {
+ ctx_server.request_cancel(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+ };
+
+ res.set_header("Trailer", "X-Response-Tokens-Per-Second");
+ res.set_chunked_content_provider("text/event-stream", on_chunk, on_complete);
+ };
+
+    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request &req,
+                                                             httplib::Response &res) {
+        // requires an "input" field, runs one embedding task and replies with its converted result
+        json request = json::parse(req.body);
+        if (!request.contains("input")) {
+            res_error(res, format_error_response("\"input\" must be provided",
+                                                 ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+        request = oaicompat_embedding_request(ctx_server.params, request);
+
+        // post the task
+        const int id_task = ctx_server.queue_tasks.get_new_id();
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+        ctx_server.request_completion(id_task, -1, request, false, true);
+
+        // get the result and unregister the waiting id right away, before any branch can
+        // return, so neither the error path nor the success path leaves the id registered
+        server_task_result result = ctx_server.queue_results.recv(id_task);
+        ctx_server.queue_results.remove_waiting_task_id(id_task);
+        if (result.error || !result.stop) {
+            res_error(res, result.data);
+            return;
+        }
+
+        const json embeddings_json = oaicompat_embedding_response(request, result.data);
+        res.set_content(embeddings_json.dump(), "application/json; charset=utf-8");
+    };
+
+ //
+ // Router
+ //
+
+ svr.Get("/health", handle_health);
+ if (params.endpoint_metrics) {
+ svr.Get("/metrics", handle_metrics);
+ }
+ svr.Get("/props", handle_props);
+ if (params.infill) {
+ svr.Post("/infill", handle_infill);
+ }
+ svr.Post("/tokenize", handle_tokenize);
+ svr.Post("/detokenize", handle_detokenize);
+ if (params.endpoint_slots) {
+ svr.Get("/slots", handle_slots);
+ if (!params.slot_save_path.empty()) {
+ // only enable slot endpoints if slot_save_path is set
+ svr.Post("/slots/:id_slot", handle_slots_action);
+ }
+ }
+ svr.Post("/completion", handle_completions);
+ svr.Get("/v1/models", handle_models);
+ svr.Post("/v1/completions", handle_completions);
+ svr.Post("/v1/chat/completions", handle_chat_completions);
+ if (params.embedding) {
+ svr.Post("/v1/embeddings", handle_embeddings);
+ }
+
+ //
+ // Middlewares
+ //
+
+ svr.set_post_routing_handler([](const httplib::Request &req, httplib::Response &res) {
+ if (req.method == "POST") {
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ }
+ return httplib::Server::HandlerResponse::Handled;
+ });
+
+ //
+ // Start
+ //
+
+ if (params.n_threads_http < 1) {
+ // +2 threads for monitoring endpoints: /metrics and /slots
+ params.n_threads_http =
+ std::max(params.n_parallel + 2, (int32_t)std::thread::hardware_concurrency() - 1);
+ }
+ log_data["n_threads_http"] = std::to_string(params.n_threads_http);
+ svr.new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); };
+
+ LOG_INFO("HTTP server listening", log_data);
+ // run the HTTP server in a thread - see comment below
+ std::thread t([&]() {
+ if (!svr.listen_after_bind()) {
+ state.store(SERVER_STATE_ERROR);
+ return 1;
+ }
+
+ return 0;
+ });
+
+ ctx_server.queue_tasks.on_new_task(
+ std::bind(&server_context::process_single_task, &ctx_server, std::placeholders::_1));
+ ctx_server.queue_tasks.on_finish_multitask(
+ std::bind(&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
+ ctx_server.queue_tasks.on_update_slots(std::bind(&server_context::update_slots, &ctx_server));
+ ctx_server.queue_results.on_multitask_update(
+ std::bind(&server_queue::update_multitask, &ctx_server.queue_tasks, std::placeholders::_1,
+ std::placeholders::_2, std::placeholders::_3));
+
+ shutdown_handler = [&](int) { ctx_server.queue_tasks.terminate(); };
+
+#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
+ struct sigaction sigint_action;
+ sigint_action.sa_handler = signal_handler;
+ sigemptyset(&sigint_action.sa_mask);
+ sigint_action.sa_flags = 0;
+ sigaction(SIGINT, &sigint_action, NULL);
+ sigaction(SIGTERM, &sigint_action, NULL);
+#elif defined(_WIN32)
+ auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+ return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
+ };
+ SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true);
+#endif
+
+ ctx_server.queue_tasks.start_loop();
+ svr.stop();
+ t.join();
+
+ llama_backend_free();
+ return 0;
+}
diff --git a/llama-box/param.hpp b/llama-box/param.hpp
new file mode 100644
index 0000000..bdb97e7
--- /dev/null
+++ b/llama-box/param.hpp
@@ -0,0 +1,1144 @@
+#pragma once
+
+#include
+
+#include "llama.cpp/common/common.h"
+#include "llama.cpp/common/grammar-parser.h"
+#include "llama.cpp/common/json-schema-to-grammar.h"
+#define JSON_ASSERT GGML_ASSERT
+#include "llama.cpp/common/json.hpp"
+#include "llama.cpp/ggml.h"
+#include "llama.cpp/llama.h"
+
+// version
+extern const char *LLAMA_BOX_BUILD_DATE;
+extern const char *LLAMA_BOX_GIT_TREE_STATE;
+extern const char *LLAMA_BOX_GIT_VERSION;
+extern const char *LLAMA_BOX_GIT_COMMIT;
+
+using json = nlohmann::json;
+
+// Aggregated configuration for llama-box: the llama.cpp common parameters
+// (gpt_params) plus the extra server-side settings declared here.
+struct llama_box_params {
+    // llama.cpp common parameters
+    gpt_params gparams;
+
+    // connection idle, in seconds
+    int32_t conn_idle{60};
+    // connection keep-alive, in seconds
+    int32_t conn_keepalive{15};
+    // maximum number of tokens per second
+    int32_t n_tps{0};
+};
+
+// Reports an unrecognized command-line flag.
+// Always throws std::invalid_argument with the text "Unknown argument: <flag>";
+// the trailing return statement is never reached.
+static int unknown(const char *flag) {
+    std::string message("Unknown argument: ");
+    message.append(flag);
+    throw std::invalid_argument(message);
+    return 1;
+}
+
+// Reports a flag that was given without its required value.
+// Always throws std::invalid_argument with the text "Missing argument: <flag>";
+// the trailing return statement is never reached.
+static int missing(const char *flag) {
+    std::string message("Missing argument: ");
+    message.append(flag);
+    throw std::invalid_argument(message);
+    return 1;
+}
+
+// Reports a flag whose value could not be accepted.
+// Always throws std::invalid_argument with the text "Invalid argument: <flag>";
+// the trailing return statement is never reached.
+static int invalid(const char *flag) {
+    std::string message("Invalid argument: ");
+    message.append(flag);
+    throw std::invalid_argument(message);
+    return 1;
+}
+
+// LLAMA_COMMON_ATTRIBUTE_FORMAT(...) expands to a GNU `format` attribute so that
+// printf-style functions get their arguments checked by the compiler:
+//   - GNU-compatible compiler targeting MinGW: gnu_printf archetype
+//   - any other GNU-compatible compiler:       printf archetype
+//   - everything else:                         expands to nothing
+#if defined(__GNUC__) && defined(__MINGW32__)
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#elif defined(__GNUC__)
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
+#endif
+
+// Prints the command-line help text to stdout.
+//
+// The first parameter (argc) is unused; argv[0] is shown as the program name in the
+// usage line. Default values displayed in the help are read from `bparams` and its
+// embedded gpt_params.
+void llama_box_params_print_usage(int, char **argv, const llama_box_params &bparams) {
+    // One help entry: either a group heading (only `grp` set) or an option row
+    // (`tags`, `args`, and a rendered description in `desc`).
+    struct opt {
+        // `desc` is a printf-style format string; the variadic arguments are rendered
+        // into it through a 1024-byte buffer (vsnprintf truncates anything longer).
+        // The format attribute indices are (4, 5) because the implicit `this` of a
+        // non-static member function counts as argument 1.
+        LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5)
+
+        opt(const std::string &tags, const char *args, const char *desc, ...)
+            : tags(tags), args(args), desc(desc) {
+            va_list args_list;
+            va_start(args_list, desc);
+            char buffer[1024];
+            vsnprintf(buffer, sizeof(buffer), desc, args_list);
+            va_end(args_list);
+            this->desc = buffer;
+        }
+
+        // Group heading entry.
+        opt(const std::string &grp)
+            : grp(grp) {
+        }
+
+        std::string tags;
+        std::string args;
+        std::string desc;
+        std::string grp;
+    };
+
+    const auto &params = bparams.gparams;
+    const auto &sparams = params.sparams;
+
+    // Build the default sampler sequence in both forms shown in the help:
+    // one character per sampler, and ';'-separated sampler names.
+    std::string sampler_type_chars;
+    std::string sampler_type_names;
+    for (const auto sampler_type : sparams.samplers_sequence) {
+        sampler_type_chars += static_cast<char>(sampler_type);
+        sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
+    }
+    // Drop the trailing ';'. Guarded because pop_back() on an empty string
+    // (empty sampler sequence) is undefined behavior.
+    if (!sampler_type_names.empty()) {
+        sampler_type_names.pop_back();
+    }
+
+    std::vector<opt> opts;
+    // clang-format off
+    opts.push_back({ "general" });
+    opts.push_back({ "*", "-h, --help, --usage", "print usage and exit" });
+    opts.push_back({ "*", " --version", "show version and build info" });
+    opts.push_back({ "*", "-m, --model FILE", "model path (default: %s)", DEFAULT_MODEL_PATH });
+    opts.push_back({ "*", "-a, --alias NAME", "model name alias (default: %s)", params.model_alias.c_str() });
+    opts.push_back({ "*", "-s, --seed N", "RNG seed (default: %d, use random seed for < 0)", params.seed });
+    opts.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads });
+    opts.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
+    opts.push_back({ "*", "-lcs, --lookup-cache-static FILE",
+                     "path to static lookup cache to use for lookup decoding (not updated by generation)" });
+    opts.push_back({ "*", "-lcd, --lookup-cache-dynamic FILE",
+                     "path to dynamic lookup cache to use for lookup decoding (updated by generation)" });
+    opts.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx });
+    opts.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict });
+    opts.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch });
+    opts.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch });
+    opts.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
+    opts.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
+    opts.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
+    opts.push_back({ "*", " --no-escape", "do not process escape sequences" });
+    opts.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
+                     "(default: %s)", sampler_type_names.c_str() });
+    opts.push_back({ "*", " --sampling-seq SEQUENCE",
+                     "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
+    opts.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
+    opts.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp });
+    opts.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
+    opts.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
+    opts.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
+    opts.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
+    opts.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
+    opts.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
+    opts.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
+    opts.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
+    opts.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq });
+    opts.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range });
+    opts.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent });
+    opts.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n"
+                     "Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"
+                     "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", sparams.mirostat });
+    opts.push_back({ "*", " --mirostat-lr N", "Mirostat learning rate, parameter eta (default: %.1f)", (double)sparams.mirostat_eta });
+    opts.push_back({ "*", " --mirostat-ent N", "Mirostat target entropy, parameter tau (default: %.1f)", (double)sparams.mirostat_tau });
+    opts.push_back({ "*", "-l --logit-bias TOKEN_ID(+/-)BIAS",
+                     "modifies the likelihood of token appearing in the completion,\n"
+                     "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
+                     "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
+    opts.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
+    opts.push_back({ "*", " --grammar-file FILE", "file to read grammar from" });
+    opts.push_back({ "*", "-j, --json-schema SCHEMA",
+                     "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\n"
+                     "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" });
+    opts.push_back({ "*", " --rope-scaling {none,linear,yarn}",
+                     "RoPE frequency scaling method, defaults to linear unless specified by the model" });
+    opts.push_back({ "*", " --rope-scale N", "RoPE context scaling factor, expands context by a factor of N" });
+    opts.push_back({ "*", " --rope-freq-base N", "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)" });
+    opts.push_back({ "*", " --rope-freq-scale N", "RoPE frequency scaling factor, expands context by a factor of 1/N" });
+    opts.push_back({ "*", " --yarn-orig-ctx N", "YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx });
+    opts.push_back({ "*", " --yarn-ext-factor N", "YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor });
+    opts.push_back({ "*", " --yarn-attn-factor N", "YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor });
+    opts.push_back({ "*", " --yarn-beta-fast N", "YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast });
+    opts.push_back({ "*", " --yarn-beta-slow N", "YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow });
+    opts.push_back({ "*", "-gan, --grp-attn-n N", "group-attention factor (default: %d)", params.grp_attn_n });
+    opts.push_back({ "*", "-gaw, --grp-attn-w N", "group-attention width (default: %.1f)", (double)params.grp_attn_w });
+    opts.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" });
+    opts.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() });
+    opts.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() });
+    opts.push_back({ "*", "-dt, --defrag-thold N", "KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold });
+    opts.push_back({ "*", "-np, --parallel N", "number of parallel sequences to decode (default: %d)", params.n_parallel });
+    opts.push_back({ "*", "-cb, --cont-batching", "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" });
+    opts.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" });
+    // --mlock / --no-mmap are only advertised when the backend reports support for them
+    if (llama_supports_mlock()) {
+        opts.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
+    }
+    if (llama_supports_mmap()) {
+        opts.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" });
+    }
+    opts.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n"
+                     " - distribute: spread execution evenly over all nodes\n"
+                     " - isolate: only spawn threads on CPUs on the node that execution started on\n"
+                     " - numactl: use the CPU map provided by numactl\n"
+                     "if run without this previously, it is recommended to drop the system page cache before using this\n"
+                     "see https://github.com/ggerganov/llama.cpp/issues/1437" });
+    opts.push_back({ "*", " --override-kv KEY=TYPE:VALUE",
+                     "advanced option to override model metadata by key. may be specified multiple times.\n"
+                     "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false" });
+    opts.push_back({ "*", " --lora FILE", "apply LoRA adapter (implies --no-mmap)" });
+    opts.push_back({ "*", " --lora-scaled FILE SCALE",
+                     "apply LoRA adapter with user defined scaling S (implies --no-mmap)" });
+    opts.push_back({ "*", " --lora-base FILE", "optional model to use as a base for the layers modified by the LoRA adapter" });
+    opts.push_back({ "*", " --control-vector FILE", "add a control vector" });
+    opts.push_back({ "*", " --control-vector-scaled FILE SCALE",
+                     "add a control vector with user defined scaling SCALE" });
+    opts.push_back({ "*", " --control-vector-layer-range START END",
+                     "layer range to apply the control vector(s) to, start and end inclusive" });
+    // GPU options are only advertised when the build supports GPU offload
+    if (llama_supports_gpu_offload()) {
+        opts.push_back({ "*", "-ngl, --gpu-layers N", "number of layers to store in VRAM" });
+        opts.push_back({ "*", "-sm, --split-mode SPLIT_MODE",
+                         "how to split the model across multiple GPUs, one of:\n"
+                         " - none: use one GPU only\n"
+                         " - layer (default): split layers and KV across GPUs\n"
+                         " - row: split rows across GPUs" });
+        opts.push_back({ "*", "-ts, --tensor-split SPLIT",
+                         "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1" });
+        opts.push_back({ "*", "-mg, --main-gpu N", "the GPU to use for the model (with split-mode = none),\n"
+                         "or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu });
+    }
+
+    opts.push_back({ "server" });
+    opts.push_back({ "server", " --host HOST", "ip address to listen (default: %s)", params.hostname.c_str() });
+    opts.push_back({ "server", " --port PORT", "port to listen (default: %d)", params.port });
+    opts.push_back({ "server", "-to --timeout N", "server read/write timeout in seconds (default: %d)", params.timeout_read });
+    opts.push_back({ "server", " --threads-http N", "number of threads used to process HTTP requests (default: %d)", params.n_threads_http });
+    opts.push_back({ "server", " --system-prompt-file FILE",
+                     "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications" });
+    opts.push_back({ "server", " --metrics", "enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled" });
+    opts.push_back({ "server", " --infill", "enable infill endpoint (default: %s)", params.infill? "enabled" : "disabled" });
+    opts.push_back({ "server", " --embeddings", "enable embedding endpoint (default: %s)", params.embedding ? "enabled" : "disabled" });
+    opts.push_back({ "server", " --no-slots", "disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled" });
+    opts.push_back({ "server", " --slot-save-path PATH", "path to save slot kv cache (default: disabled)" });
+    opts.push_back({ "server", " --chat-template JINJA_TEMPLATE",
+                     "set custom jinja chat template (default: template taken from model's metadata)\n"
+                     "only commonly used templates are accepted:\n"
+                     "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
+    opts.push_back({ "server", " --chat-template-file FILE",
+                     "set a file to load a custom jinja chat template" });
+    // no trailing '\n' here: the print loop below would turn it into an extra line of padding
+    opts.push_back({ "server", "-sps, --slot-prompt-similarity N",
+                     "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)", (double)params.slot_prompt_similarity });
+    opts.push_back({ "server", " --conn-idle N", "server connection idle in seconds (default: %d)", bparams.conn_idle });
+    opts.push_back({ "server", " --conn-keepalive N", "server connection keep-alive in seconds (default: %d)", bparams.conn_keepalive });
+    opts.push_back({ "server", "-tps --tokens-per-second N", "maximum number of tokens per second (default: %d, 0 = disabled, -1 = try to detect)", bparams.n_tps });
+
+    opts.push_back({ "logging" });
+    opts.push_back({ "logging", " --log-format {text,json}",
+                     "log output format: json or text (default: json)" });
+    // clang-format on
+
+    printf("usage: %s [options]\n", argv[0]);
+
+    for (const auto &o : opts) {
+        // group heading entry
+        if (!o.grp.empty()) {
+            printf("\n%s:\n\n", o.grp.c_str());
+            continue;
+        }
+        printf(" %-32s", o.args.c_str());
+        // argument text too wide for its column: start the description on the next line
+        if (o.args.length() > 30) {
+            printf("\n%34s", "");
+        }
+
+        // print the description line by line, padding every continuation line
+        // so it lines up under the description column
+        const auto &desc = o.desc;
+        size_t start = 0;
+        size_t end = desc.find('\n');
+        while (end != std::string::npos) {
+            printf("%s\n%34s", desc.substr(start, end - start).c_str(), "");
+            start = end + 1;
+            end = desc.find('\n', start);
+        }
+
+        printf("%s\n", desc.substr(start).c_str());
+    }
+    printf("\n");
+}
+
+bool llama_box_params_parse(int argc, char **argv, llama_box_params &bparams) {
+ try {
+ for (int i = 1; i < argc;) {
+ const char *flag = argv[i++];
+
+ if (*flag != '-') {
+ unknown(flag);
+ }
+
+ // general flags
+
+ if (!strcmp(flag, "-h") || !strcmp(flag, "--help") || !strcmp(flag, "--usage")) {
+ llama_box_params_print_usage(argc, argv, bparams);
+ exit(0);
+ }
+
+ if (!strcmp(flag, "--version")) {
+ fprintf(stderr, "version: %s (%s)\n", LLAMA_BOX_GIT_VERSION, LLAMA_BOX_GIT_COMMIT);
+ fprintf(stderr, "llama.cpp version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
+ fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+ exit(0);
+ }
+
+ if (!strcmp(flag, "-m") || !strcmp(flag, "--model")) {
+ if (i == argc) {
+ missing("--model");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.model = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "-a") || !strcmp(flag, "--alias")) {
+ if (i == argc) {
+ missing("--alias");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.model_alias = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "-s") || !strcmp(flag, "--seed")) {
+ if (i == argc) {
+ missing("--seed");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.seed = std::stoul(std::string(arg));
+ bparams.gparams.sparams.seed = bparams.gparams.seed;
+ continue;
+ }
+
+ if (!strcmp(flag, "-t") || !strcmp(flag, "--threads")) {
+ if (i == argc) {
+ missing("--threads");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_threads = std::stoi(std::string(arg));
+ if (bparams.gparams.n_threads <= 0) {
+ bparams.gparams.n_threads = std::thread::hardware_concurrency();
+ }
+ continue;
+ }
+
+ if (!strcmp(flag, "-tb") || !strcmp(flag, "--threads-batch")) {
+ if (i == argc) {
+ missing("--threads-batch");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_threads_batch = std::stoi(std::string(arg));
+ if (bparams.gparams.n_threads_batch <= 0) {
+ bparams.gparams.n_threads_batch = std::thread::hardware_concurrency();
+ }
+ continue;
+ }
+
+ if (!strcmp(flag, "-lcs") || !strcmp(flag, "--lookup-cache-static")) {
+ if (i == argc) {
+ missing("--lookup-cache-static");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.lookup_cache_static = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "-lcd") || !strcmp(flag, "--lookup-cache-dynamic")) {
+ if (i == argc) {
+ missing("--lookup-cache-dynamic");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.lookup_cache_dynamic = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "-c") || !strcmp(flag, "--ctx-size")) {
+ if (i == argc) {
+ missing("--ctx-size");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_ctx = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-n") || !strcmp(flag, "--predict")) {
+ if (i == argc) {
+ missing("--predict");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_predict = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-b") || !strcmp(flag, "--batch-size")) {
+ if (i == argc) {
+ missing("--batch-size");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_batch = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-ub") || !strcmp(flag, "--ubatch-size")) {
+ if (i == argc) {
+ missing("--ubatch-size");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_ubatch = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--keep")) {
+ if (i == argc) {
+ missing("--keep");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_keep = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--chunks")) {
+ if (i == argc) {
+ missing("--chunks");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_chunks = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-fa") || !strcmp(flag, "--flash-attn")) {
+ bparams.gparams.flash_attn = true;
+ continue;
+ }
+
+ if (!strcmp(flag, "--no-escape")) {
+ bparams.gparams.escape = false;
+ continue;
+ }
+
+ if (!strcmp(flag, "--samplers")) {
+ if (i == argc) {
+ missing("--samplers");
+ }
+ char *arg = argv[i++];
+ const auto sampler_names = string_split(arg, ';');
+ bparams.gparams.sparams.samplers_sequence =
+ llama_sampling_types_from_names(sampler_names, true);
+ continue;
+ }
+
+ if (!strcmp(flag, "--sampling-seq")) {
+ if (i == argc) {
+ missing("--sampling-seq");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.samplers_sequence = llama_sampling_types_from_chars(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "--penalize-nl")) {
+ bparams.gparams.sparams.penalize_nl = true;
+ continue;
+ }
+
+ if (!strcmp(flag, "--temp")) {
+ if (i == argc) {
+ missing("--temp");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.temp = std::stof(std::string(arg));
+ bparams.gparams.sparams.temp = std::max(bparams.gparams.sparams.temp, 0.0f);
+ continue;
+ }
+
+ if (!strcmp(flag, "--top-k")) {
+ if (i == argc) {
+ missing("--top-k");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.top_k = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--top-p")) {
+ if (i == argc) {
+ missing("--top-p");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.top_p = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--min-p")) {
+ if (i == argc) {
+ missing("--min-p");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.min_p = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--tfs")) {
+ if (i == argc) {
+ missing("--tfs");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.tfs_z = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--typical")) {
+ if (i == argc) {
+ missing("--typical");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.typical_p = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--repeat-last-n")) {
+ if (i == argc) {
+ missing("--repeat-last-n");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.penalty_last_n = std::stoi(std::string(arg));
+ bparams.gparams.sparams.n_prev = std::max(bparams.gparams.sparams.n_prev,
+ bparams.gparams.sparams.penalty_last_n);
+ continue;
+ }
+
+ if (!strcmp(flag, "--repeat-penalty")) {
+ if (i == argc) {
+ missing("--repeat-penalty");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.penalty_repeat = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--presence-penalty")) {
+ if (i == argc) {
+ missing("--presence-penalty");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.penalty_present = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--frequency-penalty")) {
+ if (i == argc) {
+ missing("--frequency-penalty");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.penalty_freq = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--dynatemp-range")) {
+ if (i == argc) {
+ missing("--dynatemp-range");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.dynatemp_range = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--dynatemp-exp")) {
+ if (i == argc) {
+ missing("--dynatemp-exp");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.dynatemp_exponent = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--mirostat")) {
+ if (i == argc) {
+ missing("--mirostat");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.mirostat = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--mirostat-lr")) {
+ if (i == argc) {
+ missing("--mirostat-lr");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.mirostat_eta = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--mirostat-ent")) {
+ if (i == argc) {
+ missing("--mirostat-ent");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.mirostat_tau = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-l") || !strcmp(flag, "--logit-bias")) {
+ if (i == argc) {
+ missing("--logit-bias");
+ }
+ char *arg = argv[i++];
+ std::stringstream ss(arg);
+ llama_token key;
+ char sign;
+ std::string value_str;
+ if (ss >> key && ss >> sign && std::getline(ss, value_str) &&
+ (sign == '+' || sign == '-')) {
+ bparams.gparams.sparams.logit_bias[key] =
+ std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ } else {
+ invalid("--logit-bias");
+ }
+ continue;
+ }
+
+ if (!strcmp(flag, "--grammar")) {
+ if (i == argc) {
+ missing("--grammar");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.grammar = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "--grammar-file")) {
+ if (i == argc) {
+ missing("--grammar-file");
+ }
+ char *arg = argv[i++];
+ std::ifstream file(arg);
+ if (!file) {
+ invalid("--grammar-file");
+ }
+ std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(),
+ std::back_inserter(bparams.gparams.sparams.grammar));
+ continue;
+ }
+
+ if (!strcmp(flag, "-j") || !strcmp(flag, "--json-schema")) {
+ if (i == argc) {
+ missing("--json-schema");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.sparams.grammar =
+ json_schema_to_grammar(json::parse(std::string(arg)));
+ continue;
+ }
+
+ if (!strcmp(flag, "--rope-scaling")) {
+ if (i == argc) {
+ missing("--rope-scaling");
+ }
+ char *arg = argv[i++];
+ std::string value(arg);
+ if (value == "none") {
+ bparams.gparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE;
+ } else if (value == "linear") {
+ bparams.gparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR;
+ } else if (value == "yarn") {
+ bparams.gparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN;
+ } else {
+ invalid("--rope-scaling");
+ }
+ continue;
+ }
+
+ if (!strcmp(flag, "--rope-scale")) {
+ if (i == argc) {
+ missing("--rope-scale");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.rope_freq_scale = 1.0f / std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--rope-freq-base")) {
+ if (i == argc) {
+ missing("--rope-freq-base");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.rope_freq_base = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--rope-freq-scale")) {
+ if (i == argc) {
+ missing("--rope-freq-scale");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.rope_freq_scale = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--yarn-orig-ctx")) {
+ if (i == argc) {
+ missing("--yarn-orig-ctx");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.yarn_orig_ctx = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--yarn-ext-factor")) {
+ if (i == argc) {
+ missing("--yarn-ext-factor");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.yarn_ext_factor = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--yarn-attn-factor")) {
+ if (i == argc) {
+ missing("--yarn-attn-factor");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.yarn_attn_factor = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--yarn-beta-fast")) {
+ if (i == argc) {
+ missing("--yarn-beta-fast");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.yarn_beta_fast = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--yarn-beta-slow")) {
+ if (i == argc) {
+ missing("--yarn-beta-slow");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.yarn_beta_slow = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-gan") || !strcmp(flag, "--grp-attn-n")) {
+ if (i == argc) {
+ missing("--grp-attn-n");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.grp_attn_n = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-gaw") || !strcmp(flag, "--grp-attn-w")) {
+ if (i == argc) {
+ missing("--grp-attn-w");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.grp_attn_w = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-nkvo") || !strcmp(flag, "--no-kv-offload")) {
+ bparams.gparams.no_kv_offload = true;
+ continue;
+ }
+
+ if (!strcmp(flag, "-ctk") || !strcmp(flag, "--cache-type-k")) {
+ if (i == argc) {
+ missing("--cache-type-k");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.cache_type_k = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "-ctv") || !strcmp(flag, "--cache-type-v")) {
+ if (i == argc) {
+ missing("--cache-type-v");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.cache_type_v = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "-dt") || !strcmp(flag, "--defrag-thold")) {
+ if (i == argc) {
+ missing("--defrag-thold");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.defrag_thold = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-np") || !strcmp(flag, "--parallel")) {
+ if (i == argc) {
+ missing("--parallel");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_parallel = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-cb") || !strcmp(flag, "--cont-batching")) {
+ bparams.gparams.cont_batching = true;
+ continue;
+ }
+
+ if (!strcmp(flag, "--mmproj")) {
+ if (i == argc) {
+ missing("--mmproj");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.mmproj = std::string(arg);
+ continue;
+ }
+
+ if (llama_supports_mlock()) {
+ if (!strcmp(flag, "--mlock")) {
+ bparams.gparams.use_mlock = true;
+ continue;
+ }
+ }
+
+ if (llama_supports_mmap()) {
+ if (!strcmp(flag, "--no-mmap")) {
+ bparams.gparams.use_mmap = false;
+ continue;
+ }
+ }
+
+ if (!strcmp(flag, "--numa")) {
+ if (i == argc) {
+ missing("--numa");
+ }
+ char *arg = argv[i++];
+ std::string value(arg);
+ if (value == "distribute") {
+ bparams.gparams.numa = GGML_NUMA_STRATEGY_DISTRIBUTE;
+ } else if (value == "isolate") {
+ bparams.gparams.numa = GGML_NUMA_STRATEGY_ISOLATE;
+ } else if (value == "numactl") {
+ bparams.gparams.numa = GGML_NUMA_STRATEGY_NUMACTL;
+ } else {
+ invalid("--numa");
+ }
+ }
+
+ if (!strcmp(flag, "--override-kv")) {
+ if (i == argc) {
+ missing("--override-kv");
+ }
+ char *arg = argv[i++];
+ if (!string_parse_kv_override(arg, bparams.gparams.kv_overrides)) {
+ invalid("--override-kv");
+ }
+ continue;
+ }
+
+ if (!strcmp(flag, "--lora")) {
+ if (i == argc) {
+ missing("--lora");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.lora_adapter.emplace_back(std::string(arg), 1.0f);
+ bparams.gparams.use_mmap = false;
+ continue;
+ }
+
+ if (!strcmp(flag, "--lora-scaled")) {
+ if (i == argc) {
+ missing("--lora-scaled");
+ }
+ char *n = argv[i++];
+ if (i == argc) {
+ invalid("--lora-scaled");
+ }
+ char *s = argv[i++];
+ bparams.gparams.lora_adapter.emplace_back(std::string(n),
+ std::stof(std::string(s)));
+ bparams.gparams.use_mmap = false;
+ continue;
+ }
+
+ if (!strcmp(flag, "--lora-base")) {
+ if (i == argc) {
+ missing("--lora-base");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.lora_base = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "--control-vector")) {
+ if (i == argc) {
+ missing("--control-vector");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.control_vectors.push_back({1.0f, std::string(arg)});
+ continue;
+ }
+
+ if (!strcmp(flag, "--control-vector-scaled")) {
+ if (i == argc) {
+ missing("--control-vector-scaled");
+ }
+ char *n = argv[i++];
+ if (i == argc) {
+ invalid("--control-vector-scaled");
+ }
+ char *s = argv[i++];
+ bparams.gparams.control_vectors.push_back(
+ {std::stof(std::string(s)), std::string(n)});
+ continue;
+ }
+
+ if (!strcmp(flag, "--control-vector-layer-range")) {
+ if (i == argc) {
+ missing("--control-vector-layer-range");
+ }
+ char *s = argv[i++];
+ if (i == argc) {
+ invalid("--control-vector-layer-range");
+ }
+ char *e = argv[i++];
+ bparams.gparams.control_vector_layer_start = std::stoi(std::string(s));
+ bparams.gparams.control_vector_layer_end = std::stoi(std::string(e));
+ continue;
+ }
+
+ if (llama_supports_gpu_offload()) {
+ if (!strcmp(flag, "-ngl") || !strcmp(flag, "--gpu-layers")) {
+ if (i == argc) {
+ missing("--gpu-layers");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_gpu_layers = std::stoi(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "-sm") || !strcmp(flag, "--split-mode")) {
+ if (i == argc) {
+ missing("--split-mode");
+ }
+ char *arg = argv[i++];
+ if (!strcmp(arg, "none")) {
+ bparams.gparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+ } else if (!strcmp(arg, "layer")) {
+ bparams.gparams.split_mode = LLAMA_SPLIT_MODE_LAYER;
+ } else if (!strcmp(arg, "row")) {
+ bparams.gparams.split_mode = LLAMA_SPLIT_MODE_ROW;
+ } else {
+ invalid("--split-mode");
+ }
+ continue;
+ }
+
+ if (!strcmp(flag, "-ts") || !strcmp(flag, "--tensor-split")) {
+ if (i == argc) {
+ missing("--tensor-split");
+ }
+ char *arg = argv[i++];
+ const std::regex regex{R"([,/]+)"};
+ std::string arg_s{arg};
+ std::sregex_token_iterator it{arg_s.begin(), arg_s.end(), regex, -1};
+ std::vector split_arg{it, {}};
+ if (split_arg.size() >= llama_max_devices()) {
+ invalid("--tensor-split");
+ }
+ for (size_t i = 0; i < llama_max_devices(); ++i) {
+ if (i < split_arg.size()) {
+ bparams.gparams.tensor_split[i] = std::stof(split_arg[i]);
+ } else {
+ bparams.gparams.tensor_split[i] = 0.0f;
+ }
+ }
+ continue;
+ }
+
+ if (!strcmp(flag, "-mg") || !strcmp(flag, "--main-gpu")) {
+ if (i == argc) {
+ missing("--main-gpu");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.main_gpu = std::stoi(std::string(arg));
+ continue;
+ }
+ }
+
+ // server flags
+
+ if (!strcmp(flag, "--host")) {
+ if (i == argc) {
+ missing("--host");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.hostname = std::string(arg);
+ continue;
+ }
+
+ if (!strcmp(flag, "--port")) {
+ if (i == argc) {
+ missing("--port");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.port = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-to") || !strcmp(flag, "--timeout")) {
+ if (i == argc) {
+ missing("--timeout");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.timeout_read = std::stoi(std::string(arg));
+ bparams.gparams.timeout_write = bparams.gparams.timeout_read;
+ continue;
+ }
+
+ if (!strcmp(flag, "--threads-http")) {
+ if (i == argc) {
+ missing("--threads-http");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.n_threads_http = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-spf") || !strcmp(flag, "--system-prompt-file")) {
+ if (i == argc) {
+ missing("--system-prompt-file");
+ }
+ char *arg = argv[i++];
+ std::ifstream file(arg);
+ if (!file) {
+ invalid("--system-prompt-file");
+ }
+ std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(),
+ std::back_inserter(bparams.gparams.system_prompt));
+ continue;
+ }
+
+ if (!strcmp(flag, "--metrics")) {
+ bparams.gparams.endpoint_metrics = true;
+ continue;
+ }
+
+ if (!strcmp(flag, "--infill")) {
+ bparams.gparams.infill = true;
+ continue;
+ }
+
+ if (!strcmp(flag, "--embedding") || !strcmp(flag, "--embeddings")) {
+ bparams.gparams.embedding = true;
+ continue;
+ }
+
+ if (!strcmp(flag, "--no-slots")) {
+ bparams.gparams.endpoint_slots = false;
+ continue;
+ }
+
+ if (!strcmp(flag, "--slot-save-path")) {
+ if (i == argc) {
+ missing("--slot-save-path");
+ }
+ char *arg = argv[i++];
+ if (arg[0] == '\0') {
+ invalid("--slot-save-path");
+ }
+ std::string p(arg);
+ if (p[p.size() - 1] != DIRECTORY_SEPARATOR) {
+ p += DIRECTORY_SEPARATOR;
+ }
+ bparams.gparams.slot_save_path = p;
+ continue;
+ }
+
+ if (!strcmp(flag, "--chat-template")) {
+ if (i == argc) {
+ missing("--chat-template");
+ }
+ char *arg = argv[i++];
+ if (arg[0] == '\0') {
+ invalid("--chat-template");
+ }
+ std::string t(arg);
+ if (!llama_chat_verify_template(t)) {
+ invalid("--chat-template");
+ }
+ bparams.gparams.chat_template = t;
+ continue;
+ }
+
+ if (!strcmp(flag, "--chat-template-file")) {
+ if (i == argc) {
+ missing("--chat-template-file");
+ }
+ char *arg = argv[i++];
+ std::ifstream file(arg);
+ if (!file) {
+ invalid("--chat-template-file");
+ }
+ std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(),
+ std::back_inserter(bparams.gparams.chat_template));
+ continue;
+ }
+
+ if (!strcmp(flag, "-sps") || !strcmp(flag, "--slot-prompt-similarity")) {
+ if (i == argc) {
+ missing("--slot-prompt-similarity");
+ }
+ char *arg = argv[i++];
+ bparams.gparams.slot_prompt_similarity = std::stof(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--conn-idle")) { // extend
+ if (i == argc) {
+ missing("--conn-idle");
+ }
+ char *arg = argv[i++];
+ bparams.conn_idle = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "--conn-keepalive")) { // extend
+ if (i == argc) {
+ missing("--conn-keepalive");
+ }
+ char *arg = argv[i++];
+ bparams.conn_keepalive = std::stoi(std::string(arg));
+ continue;
+ }
+
+ if (!strcmp(flag, "-tps") || !strcmp(flag, "--tokens-per-second")) { // extend
+ if (i == argc) {
+ missing("--tokens-per-second");
+ }
+ char *arg = argv[i++];
+ bparams.n_tps = std::stoi(std::string(arg));
+ continue;
+ }
+
+ // logging flags
+
+ if (!strcmp(flag, "--log-format")) {
+ if (i == argc) {
+ missing("--log-format");
+ }
+ char *arg = argv[i++];
+ if (!strcmp(arg, "json")) {
+ bparams.gparams.log_json = true;
+ } else if (!strcmp(arg, "text")) {
+ bparams.gparams.log_json = false;
+ } else {
+ unknown("--log-format");
+ }
+ continue;
+ }
+
+ unknown(flag);
+ }
+ } catch (const std::invalid_argument &ex) {
+ fprintf(stderr, "%s\n", ex.what());
+ return false;
+ }
+
+ if (!bparams.gparams.kv_overrides.empty()) {
+ bparams.gparams.kv_overrides.emplace_back();
+ bparams.gparams.kv_overrides.back().key[0] = 0;
+ }
+
+ return true;
+}
\ No newline at end of file
diff --git a/llama-box/ratelimiter.hpp b/llama-box/ratelimiter.hpp
new file mode 100644
index 0000000..571f9f0
--- /dev/null
+++ b/llama-box/ratelimiter.hpp
@@ -0,0 +1,39 @@
+#include
+#include
+
+// Token bucket rate limiter.
+//
+// The bucket starts full with `capacity` tokens and earns `rate` tokens per
+// second, measured on std::chrono::steady_clock. The class contains no locks
+// or atomics, so one instance must not be used from several threads without
+// external synchronization.
+class token_bucket {
+
+private:
+    int capacity; // maximum number of tokens the bucket can hold
+    int rate;     // tokens earned per second
+    int tokens;   // tokens currently available
+    // instant up to which elapsed time has already been converted into tokens
+    std::chrono::steady_clock::time_point last_time;
+
+    // Credits the tokens earned since `last_time`, never exceeding `capacity`.
+    void refill() {
+        auto const now = std::chrono::steady_clock::now();
+        if (rate <= 0) {
+            return; // nothing is ever earned
+        }
+
+        long long const room = static_cast<long long>(capacity) - tokens;
+        if (room <= 0) {
+            // already full: time spent full must not be banked for later
+            last_time = now;
+            return;
+        }
+
+        long long const elapsed =
+            std::chrono::duration_cast<std::chrono::milliseconds>(now - last_time).count();
+        // milliseconds needed to fill the remaining room (rounded up); comparing
+        // against it first keeps `elapsed * rate` below from overflowing
+        long long const fill_ms = (room * 1000 + rate - 1) / rate;
+        if (elapsed >= fill_ms) {
+            tokens = capacity;
+            last_time = now;
+            return;
+        }
+
+        long long const earned = elapsed * rate / 1000;
+        if (earned > 0) {
+            tokens += static_cast<int>(earned);
+            // advance only by the time that paid for whole tokens, so the
+            // remainder keeps counting towards the next token
+            last_time += std::chrono::milliseconds(earned * 1000 / rate);
+        }
+    }
+
+public:
+    // Creates a full bucket holding `capacity` tokens, refilled at `rate` tokens per second.
+    token_bucket(int capacity, int rate) : capacity(capacity), rate(rate) {
+        tokens = capacity;
+        last_time = std::chrono::steady_clock::now();
+    }
+
+    // Tries to take `tokens` tokens from the bucket.
+    // Returns true and deducts them when enough are available, false otherwise
+    // (nothing is deducted). A request for zero or fewer tokens always succeeds
+    // and leaves the bucket untouched.
+    bool acquire(int tokens = 1) {
+        if (tokens <= 0) {
+            return true;
+        }
+        // always account for elapsed time first; refilling only when short
+        // would let idle time accumulated while full be spent a second time
+        refill();
+        if (this->tokens < tokens) {
+            return false;
+        }
+        this->tokens -= tokens;
+        return true;
+    }
+};
\ No newline at end of file
diff --git a/llama-box/scripts/version.sh b/llama-box/scripts/version.sh
new file mode 100755
index 0000000..556339d
--- /dev/null
+++ b/llama-box/scripts/version.sh
@@ -0,0 +1,60 @@
+#!/bin/sh
+
+##
+# Inspired by github.com/kubernetes/kubernetes/hack/lib/version.sh
+##
+
+# -----------------------------------------------------------------------------
+# Version management helper. Prints, on stdout, C definitions for the
+# following values:
+#
+#    GIT_TREE_STATE - "clean" indicates no changes since the git commit id.
+#                     "dirty" indicates source code changes after the git commit id.
+#                     "unknown" indicates cannot find out the git tree.
+#        GIT_COMMIT - The git commit id corresponding to this
+#                     source code, "unknown" if it cannot be resolved.
+#       GIT_VERSION - The tag pointing at HEAD if there is one, otherwise the
+#                     branch name; "dev" if the tree is dirty. It can be
+#                     overridden via "VERSION".
+#        BUILD_DATE - The build date of the version.
+
+BUILD_DATE=$(date -u '+%Y-%m-%dT%H:%M:%SZ')
+GIT_TREE_STATE="unknown"
+GIT_COMMIT="unknown"
+GIT_VERSION="unknown"
+
+# query git only when the client exists and HEAD resolves to a commit;
+# otherwise keep the "unknown" defaults instead of clobbering them.
+if command -v git >/dev/null 2>&1 && git_commit=$(git rev-parse "HEAD^{commit}" 2>/dev/null); then
+  GIT_COMMIT="${git_commit}"
+
+  # specify as dirty if the tree is not clean.
+  if git_status=$(git status --porcelain 2>/dev/null) && [ -n "${git_status}" ]; then
+    GIT_TREE_STATE="dirty"
+  else
+    GIT_TREE_STATE="clean"
+  fi
+
+  # default to the branch name, prefer a tag placed on HEAD itself
+  # (--contains would also match tags created on later commits).
+  if git_branch=$(git rev-parse --abbrev-ref HEAD 2>/dev/null); then
+    GIT_VERSION="${git_branch}"
+    git_tag=$(git tag -l --points-at HEAD 2>/dev/null | head -n 1)
+    if [ -n "${git_tag}" ]; then
+      GIT_VERSION="${git_tag}"
+    fi
+  fi
+
+  # specify to dev if the tree is dirty.
+  if [ "${GIT_TREE_STATE}" = "dirty" ]; then
+    GIT_VERSION="dev"
+  fi
+fi
+
+# respect specified version, with or without a usable git checkout.
+GIT_VERSION=${VERSION:-${GIT_VERSION}}
+
+echo "char const *LLAMA_BOX_BUILD_DATE = \"${BUILD_DATE:-0}\";"
+echo "char const *LLAMA_BOX_GIT_TREE_STATE = \"${GIT_TREE_STATE}\";"
+echo "char const *LLAMA_BOX_GIT_COMMIT = \"${GIT_COMMIT}\";"
+echo "char const *LLAMA_BOX_GIT_VERSION = \"${GIT_VERSION}\";"
diff --git a/llama-box/utils.hpp b/llama-box/utils.hpp
new file mode 100644
index 0000000..9e7f6ee
--- /dev/null
+++ b/llama-box/utils.hpp
@@ -0,0 +1,645 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "llama.cpp/common/common.h"
+#define JSON_ASSERT GGML_ASSERT
+#include "llama.cpp/common/json.hpp"
+#include "llama.cpp/llama.h"
+
+#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
+
+using json = nlohmann::json;
+
+// Error categories reported to API clients; format_error_response() turns each
+// one into a "type" string and a numeric "code".
+// The first five follow the OpenAI error types listed at
+// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
+enum error_type {
+ ERROR_TYPE_INVALID_REQUEST,
+ ERROR_TYPE_AUTHENTICATION,
+ ERROR_TYPE_SERVER,
+ ERROR_TYPE_NOT_FOUND,
+ ERROR_TYPE_PERMISSION,
+ ERROR_TYPE_UNAVAILABLE, // custom error
+ ERROR_TYPE_NOT_SUPPORTED, // custom error
+};
+
+// Logging shorthands: each forwards to server_log() with a fixed level string
+// plus the calling function name and line number. The variadic part supplies
+// server_log()'s `extra` argument, so at least one argument after MSG is
+// required (the comma before __VA_ARGS__ is not removed when it is empty).
+#define LOG_ERROR(MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
+#define LOG_INFO(MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
+
+// Declared ahead of its definition because json_value() logs through LOG_WARNING.
+static inline void server_log(const char *level, const char *function, int line,
+ const char *message, const json &extra);
+
+// Reads `body[key]` converted to T.
+//
+// Returns `default_value` when the key is absent or holds null. When the stored
+// value cannot be converted to T, a warning is logged and `default_value` is
+// returned as well.
+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value) {
+    // Fallback null to default value
+    if (!body.contains(key) || body.at(key).is_null()) {
+        return default_value;
+    }
+
+    try {
+        return body.at(key);
+    } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
+        std::stringstream ss;
+        ss << "Wrong type supplied for parameter '" << key << "'. Expected '"
+           << json(default_value).type_name() << "', using default value.";
+        LOG_WARNING(ss.str().c_str(), body);
+        return default_value;
+    }
+}
+
+// Selects the output format of server_log(): JSON records when true, text lines otherwise.
+extern bool server_log_json;
+
+// Writes one log record to stdout and flushes it.
+//
+// Every record carries the calling thread id and the current timestamp. In
+// JSON mode the level, function, line and message are merged into the record,
+// followed by `extra`, and the record is printed as a single JSON document.
+// In text mode the level, function and message form a fixed-width prefix
+// (formatted into a 1024-byte buffer) followed by " |" and the remaining
+// record fields as key=value pairs.
+static inline void server_log(const char *level, const char *function, int line,
+                              const char *message, const json &extra) {
+    std::stringstream thread_id;
+    thread_id << std::this_thread::get_id();
+    json record = json{
+        {"tid", thread_id.str()},
+        {"timestamp", time(nullptr)},
+    };
+
+    if (!server_log_json) {
+        char prefix[1024];
+        snprintf(prefix, sizeof(prefix), "%4s [%24s] %s", level, function, message);
+
+        if (!extra.empty()) {
+            record.merge_patch(extra);
+        }
+
+        std::stringstream text;
+        text << prefix << " |";
+        for (const auto &field : record.items()) {
+            text << " " << field.key() << "="
+                 << field.value().dump(-1, ' ', false, json::error_handler_t::replace);
+        }
+
+        const std::string out = text.str();
+        printf("%.*s\n", (int)out.size(), out.data());
+        fflush(stdout);
+        return;
+    }
+
+    record.merge_patch({
+        {"level", level},
+        {"function", function},
+        {"line", line},
+        {"msg", message},
+    });
+    if (!extra.empty()) {
+        record.merge_patch(extra);
+    }
+
+    printf("%s\n", record.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
+    fflush(stdout);
+}
+
+//
+// chat template utils
+//
+
+// Format given chat. If tmpl is empty, we take the template from model metadata.
+//
+// Each message contributes its "role" and "content" fields (empty strings when
+// missing). Throws std::runtime_error when llama_chat_apply_template reports a
+// negative result, which cannot be a length.
+inline std::string format_chat(const struct llama_model *model, const std::string &tmpl,
+                               const std::vector<json> &messages) {
+    size_t alloc_size = 0;
+    // vector holding all allocated string to be passed to llama_chat_apply_template;
+    // `chat` only stores pointers into these strings, so `str` must outlive it
+    std::vector<std::string> str(messages.size() * 2);
+    std::vector<llama_chat_message> chat(messages.size());
+
+    for (size_t i = 0; i < messages.size(); ++i) {
+        const auto &curr_msg = messages[i];
+        str[i * 2 + 0] = json_value(curr_msg, "role", std::string(""));
+        str[i * 2 + 1] = json_value(curr_msg, "content", std::string(""));
+        alloc_size += str[i * 2 + 1].length();
+        chat[i].role = str[i * 2 + 0].c_str();
+        chat[i].content = str[i * 2 + 1].c_str();
+    }
+
+    const char *ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    // first guess: twice the total content length
+    std::vector<char> buf(alloc_size * 2);
+
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true,
+                                            buf.data(), buf.size());
+    if (res < 0) {
+        // casting a negative value to size_t below would request an enormous buffer
+        throw std::runtime_error("failed to apply the chat template");
+    }
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t)res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true,
+                                        buf.data(), buf.size());
+        if (res < 0) {
+            throw std::runtime_error("failed to apply the chat template");
+        }
+    }
+
+    return std::string(buf.data(), res);
+}
+
+//
+// base64 utils (TODO: move to common in the future)
+//
+
+// The 64 symbols of the base64 alphabet; a symbol's position is its 6-bit value.
+static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+                                        "abcdefghijklmnopqrstuvwxyz"
+                                        "0123456789+/";
+
+// Tells whether `c` is an alphanumeric character, '+' or '/'.
+static inline bool is_base64(uint8_t c) {
+    if (c == '+' || c == '/') {
+        return true;
+    }
+    return isalnum(c) != 0;
+}
+
+// Decodes base64 text into bytes.
+// Reading stops at the end of the input, at the first '=' or at the first
+// character rejected by is_base64(); whatever was decoded up to that point is
+// returned, so malformed input yields a shorter result instead of an error.
+static inline std::vector base64_decode(const std::string &encoded_string) {
+ int i = 0;
+ int j = 0;
+ int in_ = 0;
+
+ int in_len = encoded_string.size();
+
+ // four 6-bit symbols are combined into three 8-bit bytes
+ uint8_t char_array_4[4];
+ uint8_t char_array_3[3];
+
+ std::vector ret;
+
+ while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
+ char_array_4[i++] = encoded_string[in_];
+ in_++;
+ if (i == 4) {
+ // replace each character by its index in the base64 alphabet
+ for (i = 0; i < 4; i++) {
+ char_array_4[i] = base64_chars.find(char_array_4[i]);
+ }
+
+ char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+ for (i = 0; (i < 3); i++) {
+ ret.push_back(char_array_3[i]);
+ }
+
+ i = 0;
+ }
+ }
+
+ // a trailing group of i (1 to 3) characters contributes i - 1 bytes
+ if (i) {
+ for (j = i; j < 4; j++) {
+ char_array_4[j] = 0;
+ }
+
+ for (j = 0; j < 4; j++) {
+ char_array_4[j] = base64_chars.find(char_array_4[j]);
+ }
+
+ char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+ for (j = 0; j < i - 1; j++) {
+ ret.push_back(char_array_3[j]);
+ }
+ }
+
+ return ret;
+}
+
+//
+// random string / id
+//
+
+// Returns 32 characters drawn from [0-9A-Za-z], using a Mersenne Twister that
+// is freshly seeded from std::random_device on every call.
+static std::string random_string() {
+    static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
+
+    std::random_device rd;
+    std::mt19937 generator(rd());
+
+    std::string result(32, ' ');
+    for (char &slot : result) {
+        slot = str[generator() % str.size()];
+    }
+
+    return result;
+}
+
+// Returns a chat completion id: "chatcmpl-" followed by a random_string().
+static std::string gen_chatcmplid() {
+    return "chatcmpl-" + random_string();
+}
+
+// Returns a completion id: "cmpl-" followed by a random_string().
+static std::string gen_cmplid() {
+    return "cmpl-" + random_string();
+}
+
+//
+// other common utils
+//
+
+// Returns the length of the longest common prefix of `a` and `b`, i.e. the
+// number of leading positions at which both vectors hold equal elements.
+static size_t common_part(const std::vector &a, const std::vector &b) {
+ size_t i;
+ // the loop body is intentionally empty: the condition does all the work
+ for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {
+ }
+
+ return i;
+}
+
+// Returns the number of leading characters that `a` and `b` have in common.
+static size_t common_part(const std::string &a, const std::string &b) {
+    const size_t limit = a.size() < b.size() ? a.size() : b.size();
+    size_t shared = 0;
+    while (shared < limit && a[shared] == b[shared]) {
+        ++shared;
+    }
+    return shared;
+}
+
+// Tells whether `str` ends with `suffix`; an empty suffix always matches.
+static bool ends_with(const std::string &str, const std::string &suffix) {
+    if (suffix.size() > str.size()) {
+        return false;
+    }
+    const size_t offset = str.size() - suffix.size();
+    return str.compare(offset, suffix.size(), suffix) == 0;
+}
+
+// Looks for a prefix of `stop` that `text` ends with, trying longer prefixes
+// first. Returns the offset in `text` where that prefix starts, or
+// std::string::npos when there is none (or either string is empty).
+static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
+    if (text.empty() || stop.empty()) {
+        return std::string::npos;
+    }
+
+    const char tail = text.back();
+    for (int64_t last = stop.size() - 1; last >= 0; last--) {
+        // a candidate prefix must end with the final character of the text
+        if (stop[last] != tail) {
+            continue;
+        }
+        const size_t len = last + 1;
+        if (len <= text.size() && text.compare(text.size() - len, len, stop, 0, len) == 0) {
+            return text.size() - len;
+        }
+    }
+
+    return std::string::npos;
+}
+
+// Returns the text of `token` for output; token -1 yields an empty string.
+// A piece consisting of one byte with its high bit set is treated as part of
+// an incomplete UTF-8 character and rendered as "byte: \x" plus the byte in hex.
+static std::string tokens_to_output_formatted_string(const llama_context *ctx,
+                                                     const llama_token token) {
+    if (token == -1) {
+        return "";
+    }
+
+    std::string piece = llama_token_to_piece(ctx, token);
+
+    // longer pieces are already known tokens and pass through unchanged
+    const bool lone_high_byte = piece.size() == 1 && (piece[0] & 0x80) == 0x80;
+    if (!lone_high_byte) {
+        return piece;
+    }
+
+    std::stringstream hex;
+    hex << std::hex << (piece[0] & 0xff);
+    return "byte: \\x" + hex.str();
+}
+
+// One generated token together with the probability data reported for it.
+struct completion_token_output {
+ // the generated token
+ llama_token tok;
+ // NOTE(review): presumably the text emitted to the client for this token; it is not used in this file — confirm against the caller
+ std::string text_to_send;
+
+ // a candidate token paired with its probability value
+ struct token_prob {
+ llama_token tok;
+ float prob;
+ };
+
+ // candidates for this position; probs_vector_to_json() reads `tok` and `prob` from each entry
+ std::vector probs;
+};
+
+// convert a vector of completion_token_output to json
+//
+// Three layouts are produced:
+// - oaicompat_completion && oaicompat_completion_chat: an object whose
+// "content" array holds {token, logprob, bytes, top_logprobs} per token;
+// - oaicompat_completion only: an object with the parallel arrays "tokens",
+// "token_logprobs" and "top_logprobs";
+// - otherwise: an array of {content, probs: [{tok_str, prob}, ...]}.
+// oaicompat_completion_chat is only consulted when oaicompat_completion is set.
+// NOTE(review): `prob` values are copied unchanged into the "logprob" fields;
+// whether they already are log-probabilities is not visible here — confirm.
+static json probs_vector_to_json(const llama_context *ctx,
+ const std::vector &probs,
+ const bool oaicompat_completion = false,
+ const bool oaicompat_completion_chat = false) {
+ if (oaicompat_completion) {
+ if (oaicompat_completion_chat) {
+ json content = json::array();
+
+ for (const auto &prob : probs) {
+ const std::string token = tokens_to_output_formatted_string(ctx, prob.tok);
+ // stays 1.0f unless the token itself appears among its candidates
+ float token_logprob = 1.0f;
+ std::vector token_bytes(token.begin(), token.end());
+ json token_top_logprobs = json::array();
+ for (const auto &p : prob.probs) {
+ const std::string p_token = tokens_to_output_formatted_string(ctx, p.tok);
+ float p_token_logprob = p.prob;
+ std::vector p_token_bytes(p_token.begin(), p_token.end());
+ token_top_logprobs.push_back(json{
+ {"token", p_token},
+ {"logprob", p_token_logprob},
+ {"bytes", p_token_bytes},
+ });
+ if (p.tok == prob.tok) {
+ token_logprob = p_token_logprob;
+ }
+ }
+
+ content.push_back(json{
+ {"token", token},
+ {"logprob", token_logprob},
+ {"bytes", token_bytes},
+ {"top_logprobs", token_top_logprobs},
+ });
+ }
+
+ return json{{"content", content}};
+ } else {
+ json token_logprobs = json::array();
+ json tokens = json::array();
+ json top_logprobs = json::array();
+
+ for (const auto &prob : probs) {
+ const std::string token = tokens_to_output_formatted_string(ctx, prob.tok);
+ // stays 1.0f unless the token itself appears among its candidates
+ float token_logprob = 1.0f;
+ // keyed by candidate text, so candidates with equal text share one entry
+ json token_top_logprobs;
+ for (const auto &p : prob.probs) {
+ const std::string p_token = tokens_to_output_formatted_string(ctx, p.tok);
+ float p_token_logprob = p.prob;
+ token_top_logprobs[p_token] = p_token_logprob;
+ if (p.tok == prob.tok) {
+ token_logprob = p_token_logprob;
+ }
+ }
+
+ tokens.push_back(token);
+ token_logprobs.push_back(token_logprob);
+ top_logprobs.push_back(token_top_logprobs);
+ }
+
+ return json{{"tokens", tokens},
+ {"token_logprobs", token_logprobs},
+ {"top_logprobs", top_logprobs}};
+ }
+ }
+
+ // neither compatibility flag applies: one {content, probs} entry per token
+ json out = json::array();
+
+ for (const auto &prob : probs) {
+ json probs_for_token = json::array();
+
+ for (const auto &p : prob.probs) {
+ const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+ probs_for_token.push_back(json{
+ {"tok_str", tok_str},
+ {"prob", p.prob},
+ });
+ }
+
+ const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+ out.push_back(json{
+ {"content", tok_str},
+ {"probs", probs_for_token},
+ });
+ }
+
+ return out;
+}
+
+//
+// OAI utils
+//
+
+// Translates an OpenAI-style completion request body into the parameter
+// object used internally.
+//
+// A non-empty `chat_template` selects chat mode: the prompt is then built from
+// body["messages"] with format_chat(); otherwise body["prompt"] is used.
+// Throws std::runtime_error for an unknown "response_format" type, for "n"
+// other than 1, for "top_logprobs" without "logprobs", and for the unsupported
+// "tools" / "tool_choice" parameters.
+static json oaicompat_completion_request(const struct llama_model *model, const json &body,
+                                         const std::string &chat_template) {
+    bool chat = !chat_template.empty();
+    json llama_params;
+
+    // Annotations for OAI compatibility
+    llama_params["__oaicompat"] = true;
+    llama_params["__oaicompat_completion"] = true;
+    llama_params["__oaicompat_completion_chat"] = chat;
+
+    // Handle default field
+    llama_params["model"] = json_value(body, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
+    llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0f);
+    llama_params["temperature"] = json_value(body, "temperature", 1.0f);
+    llama_params["top_p"] = json_value(body, "top_p", 1.0f);
+
+    // Handle "max_tokens" field
+    llama_params["n_predict"] = json_value(body, "max_tokens", -1);
+
+    // Apply chat template to the list of messages
+    if (chat) {
+        llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    } else {
+        llama_params["prompt"] = json_value(body, "prompt", std::string());
+    }
+
+    // Handle "stop" field: a single string is wrapped into a one-element array
+    if (body.contains("stop") && body.at("stop").is_string()) {
+        llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }
+
+    // Handle "response_format" field
+    if (body.contains("response_format")) {
+        json response_format = json_value(body, "response_format", json::object());
+        std::string response_type = json_value(response_format, "type", std::string());
+        if (response_type == "json_object") {
+            llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+        } else if (!response_type.empty() && response_type != "text") {
+            throw std::runtime_error(
+                "\"response_format\" must be one of \"text\" or \"json_object\", but got: " +
+                response_type);
+        }
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Handle "logprobs" field: chat reads the count from "top_logprobs" (at most 20),
+    // plain completion from "logprobs" itself (at most 5); both default to 2
+    if (body.contains("logprobs")) {
+        if (chat) {
+            llama_params["n_probs"] = std::min(json_value(body, "top_logprobs", 2), 20);
+        } else {
+            llama_params["n_probs"] = std::min(json_value(body, "logprobs", 2), 5);
+        }
+    } else if (body.contains("top_logprobs")) {
+        throw std::runtime_error("\"top_logprobs\" requires \"logprobs\" to be set");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    static const std::vector<std::string> unsupported_params{"tools", "tool_choice"};
+    for (const auto &param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
+    }
+
+    // Copy remaining properties to llama_params
+    // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI
+    // endpoint. See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
+    for (const auto &item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by
+        // "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }
+
+    return llama_params;
+}
+
+// Builds the OpenAI-style response object for one completion result.
+//
+// `request` is the parameter object produced by oaicompat_completion_request();
+// its "__oaicompat_completion_chat" flag selects between the chat layouts
+// ("chat.completion" / "chat.completion.chunk") and "text_completion".
+// `streaming` selects the chunk layout for chat; `first` marks the chunk that
+// only announces the assistant role. Usage and timing figures are attached to
+// non-streaming responses only.
+static json oaicompat_completion_response(const json &request, const json &result,
+                                          const std::string &completion_id, bool streaming = false,
+                                          bool first = false) {
+    const bool stopped_word = json_value(result, "stopped_word", false);
+    const bool stopped_eos = json_value(result, "stopped_eos", false);
+    const bool stopped_limit = json_value(result, "stopped_limit", false);
+    const std::string content = json_value(result, "content", std::string(""));
+
+    // "length" wins when a stop flag and the limit flag are both set
+    std::string finish_reason;
+    if (stopped_word || stopped_eos) {
+        finish_reason = "stop";
+    }
+    if (stopped_limit) {
+        finish_reason = "length";
+    }
+    const bool finish = !finish_reason.empty();
+
+    json res = json{
+        {"id", completion_id},
+        {"created", std::time(0)},
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+    };
+
+    // every layout shares the index and a finish reason that is null until generation ends
+    json choice = json{
+        {"finish_reason", finish ? json(finish_reason) : json(nullptr)},
+        {"index", 0},
+    };
+
+    const bool chat = json_value(request, "__oaicompat_completion_chat", false);
+    if (!chat) {
+        // completion
+        res["object"] = "text_completion";
+        choice["text"] = content;
+    } else if (!streaming) {
+        // chat completion
+        res["object"] = "chat.completion";
+        choice["message"] = json{{"content", content}, {"role", "assistant"}};
+    } else {
+        // chat completion, streamed
+        res["object"] = "chat.completion.chunk";
+        if (finish) {
+            choice["delta"] = json::object();
+        } else if (first) {
+            choice["delta"] = json{{"role", "assistant"}};
+        } else {
+            choice["delta"] = json{{"content", content}};
+        }
+    }
+
+    if (result.contains("completion_probabilities")) {
+        choice["logprobs"] = result.at("completion_probabilities");
+    } else {
+        choice["logprobs"] = nullptr;
+    }
+    res["choices"] = json::array({choice});
+
+    // Add usage information
+    if (!streaming) {
+        int completion_tokens = json_value(result, "tokens_predicted", 0);
+        int prompt_tokens = json_value(result, "tokens_evaluated", 0);
+        json timings = json_value(result, "timings", json::object());
+        // time to first token in milliseconds.
+        int ttft = json_value(timings, "prompt_ms", 0);
+        // time per output token in milliseconds.
+        int tpot = json_value(timings, "predicted_per_token_ms", 0);
+        res["usage"] = json{{"completion_tokens", completion_tokens},
+                            {"prompt_tokens", prompt_tokens},
+                            {"total_tokens", completion_tokens + prompt_tokens},
+                            {"time_to_first_token_ms", ttft},
+                            {"time_per_output_token_ms", tpot}};
+    }
+
+    return res;
+}
+
+// Translates an OpenAI-style embedding request body into the parameter object
+// used internally. "model" falls back to `params.model_alias`, "input" (read as
+// a string) becomes the prompt, and "encoding_format" defaults to "float".
+static json oaicompat_embedding_request(const struct gpt_params &params, const json &body) {
+    json llama_params;
+
+    // Annotations for OAI compatibility
+    llama_params["__oaicompat"] = true;
+    llama_params["__oaicompat_embedding"] = true;
+
+    // Handle "model" field
+    llama_params["model"] = json_value(body, "model", params.model_alias);
+
+    // Handle "input" field
+    llama_params["prompt"] = json_value(body, "input", std::string());
+
+    // Handle "encoding_format" field
+    llama_params["encoding_format"] = json_value(body, "encoding_format", std::string("float"));
+
+    return llama_params;
+}
+
+// Wraps one embedding result into an OpenAI-style "list" response: a single
+// data entry at index 0 plus usage figures in which the total token count
+// equals the prompt token count.
+static json oaicompat_embedding_response(const json &request, const json &result) {
+    const int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
+
+    json entry = json{{"embedding", json_value(result, "embedding", json::array())},
+                      {"index", 0},
+                      {"object", "embedding"}};
+    json data = json::array();
+    data.push_back(entry);
+
+    json usage = json{{"prompt_tokens", num_prompt_tokens}, {"total_tokens", num_prompt_tokens}};
+
+    return json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+                {"object", "list"},
+                {"usage", usage},
+                {"data", data}};
+}
+
+// Builds the error object sent to clients: {"code", "message", "type"}.
+// Each error_type maps to a type string and a numeric code; a value outside
+// the enumerators keeps an empty type string and code 500.
+static json format_error_response(const std::string &message, const enum error_type type) {
+    std::string kind;
+    int status = 500;
+
+    switch (type) {
+    case ERROR_TYPE_INVALID_REQUEST:
+        status = 400;
+        kind = "invalid_request_error";
+        break;
+    case ERROR_TYPE_AUTHENTICATION:
+        status = 401;
+        kind = "authentication_error";
+        break;
+    case ERROR_TYPE_PERMISSION:
+        status = 403;
+        kind = "permission_error";
+        break;
+    case ERROR_TYPE_NOT_FOUND:
+        status = 404;
+        kind = "not_found_error";
+        break;
+    case ERROR_TYPE_SERVER:
+        status = 500;
+        kind = "server_error";
+        break;
+    case ERROR_TYPE_NOT_SUPPORTED:
+        status = 501;
+        kind = "not_supported_error";
+        break;
+    case ERROR_TYPE_UNAVAILABLE:
+        status = 503;
+        kind = "unavailable_error";
+        break;
+    }
+
+    return json{
+        {"code", status},
+        {"message", message},
+        {"type", kind},
+    };
+}
diff --git a/llama.cpp b/llama.cpp
new file mode 160000
index 0000000..d50f889
--- /dev/null
+++ b/llama.cpp
@@ -0,0 +1 @@
+Subproject commit d50f8897a797a5a03f31228d1b5a7b8130ee1bc2