add support for dynamic state-tuned models
josStorer committed May 12, 2024 · 1 parent b52873c · commit a2bbbab
Showing 12 changed files with 230 additions and 15 deletions.
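
The changes below thread a new `state` field through the config API so that a state-tuned `.pth` file can be applied to an already-loaded base model, and they surface backend failures ("state shape mismatch", unsupported file format) as localized toasts in the frontend. As a rough sketch of how the new field is exercised over HTTP (the base URL and the exact route served by update_config in backend-python/routes/config.py are assumptions, and the state file name is a placeholder):

```python
# Sketch only: the base URL, the /update-config route path, and the state file name
# are assumptions; the "state" field itself is what this commit adds.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/update-config",
    json={
        "max_tokens": 1000,
        "temperature": 1.0,
        "top_p": 0.3,
        "state": "state-models/my-state-tuned.pth",  # hypothetical file in the new state-models dir
    },
)
if resp.status_code != 200:
    # load_rwkv_state reports failures such as "state shape mismatch" via HTTP 400
    print("state failed to load:", resp.text)
```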
4 changes: 3 additions & 1 deletion backend-golang/app.go
@@ -125,6 +125,7 @@ func (a *App) OnStartup(ctx context.Context) {
os.Chmod(a.exDir+"backend-rust/web-rwkv-converter", 0777)
os.Mkdir(a.exDir+"models", os.ModePerm)
os.Mkdir(a.exDir+"lora-models", os.ModePerm)
os.Mkdir(a.exDir+"state-models", os.ModePerm)
os.Mkdir(a.exDir+"finetune/json2binidx_tool/data", os.ModePerm)
trainLogPath := "lora-models/train_log.txt"
if !a.FileExists(trainLogPath) {
@@ -151,8 +152,9 @@ func (a *App) OnBeforeClose(ctx context.Context) bool {
func (a *App) watchFs() {
watcher, err := fsnotify.NewWatcher()
if err == nil {
watcher.Add(a.exDir + "./lora-models")
watcher.Add(a.exDir + "./models")
watcher.Add(a.exDir + "./lora-models")
watcher.Add(a.exDir + "./state-models")
go func() {
for {
select {
3 changes: 3 additions & 0 deletions backend-python/routes/config.py
@@ -120,6 +120,9 @@ def update_config(body: ModelConfigBody):
model_config = ModelConfigBody()
global_var.set(global_var.Model_Config, model_config)
merge_model(model_config, body)
exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state)
if exception is not None:
raise exception
print("Updated Model Config:", model_config)

return "success"
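
load_rwkv_state (added in backend-python/utils/rwkv.py below) returns an HTTPException instead of raising it, so each caller decides whether the error is fatal: update_config re-raises it, and FastAPI turns it into a 400 response whose text the frontend matches against. A minimal sketch of that convention, with hypothetical stand-in names:

```python
from typing import Optional

from fastapi import HTTPException, status


def load_state_stub() -> Optional[HTTPException]:
    # Stand-in for load_rwkv_state(): it returns the error object instead of raising it.
    return HTTPException(status.HTTP_400_BAD_REQUEST, "state shape mismatch")


def update_config_stub():
    exception = load_state_stub()
    if exception is not None:
        raise exception  # FastAPI converts this into an HTTP 400 response
    return "success"
```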
13 changes: 13 additions & 0 deletions backend-python/routes/state_cache.py
@@ -176,6 +176,19 @@ def reset_state():
return "success"


def force_reset_state():
global trie, dtrie

if trie is None:
return

import cyac

trie = cyac.Trie()
dtrie = {}
gc.collect()


class LongestPrefixStateBody(BaseModel):
prompt: str

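
force_reset_state exists because every cached entry maps a prompt prefix to a model state that was computed starting from the previous initial state; once a different state-tuned file is loaded (or unloaded), those cached states no longer match the model and all of them must be dropped. A simplified illustration of that invalidation, using a plain dict in place of the cyac trie the real module uses:

```python
import gc

# Simplified stand-in for the module's trie/dtrie pair:
# prompt -> {"tokens": [...], "state": ..., "logits": ...}
dtrie: dict = {}


def add_cache_entry(prompt: str, entry: dict):
    dtrie[prompt] = entry


def force_reset_state():
    # States cached under the old initial state would silently corrupt generations
    # after a state swap, so the whole cache is rebuilt from scratch.
    global dtrie
    dtrie = {}
    gc.collect()
```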
92 changes: 86 additions & 6 deletions backend-python/utils/rwkv.py
@@ -7,7 +7,7 @@
import time
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
from utils.log import quick_log
from fastapi import HTTPException
from fastapi import HTTPException, status
from pydantic import BaseModel, Field
from routes import state_cache
import global_var
@@ -27,6 +27,7 @@ def __init__(self, model, pipeline):
self.EOS_ID = 0

self.name = "rwkv"
self.model_path = ""
self.version = 4
self.model = model
self.pipeline = pipeline
@@ -43,6 +44,8 @@ def __init__(self, model, pipeline):
self.penalty_alpha_frequency = 1
self.penalty_decay = 0.996
self.global_penalty = False
self.state_path = ""
self.state_tuned = None

@abstractmethod
def adjust_occurrence(self, occurrence: Dict, token: int):
@@ -236,7 +239,10 @@ def generate(
except HTTPException:
pass
if cache is None or cache["prompt"] == "" or cache["state"] is None:
self.model_state = None
if self.state_path:
self.model_state = copy.deepcopy(self.state_tuned)
else:
self.model_state = None
self.model_tokens = []
else:
delta_prompt = prompt[len(cache["prompt"]) :]
@@ -606,13 +612,13 @@ def get_model_path(model_path: str) -> str:


def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
model = get_model_path(model)
model_path = get_model_path(model)

rwkv_beta = global_var.get(global_var.Args).rwkv_beta
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
webgpu = global_var.get(global_var.Args).webgpu

if "midi" in model.lower() or "abc" in model.lower():
if "midi" in model_path.lower() or "abc" in model_path.lower():
os.environ["RWKV_RESCALE_LAYER"] = "999"

# dynamic import to make RWKV_CUDA_ON work
@@ -637,8 +643,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
)
from rwkv_pip.utils import PIPELINE

filename, _ = os.path.splitext(os.path.basename(model))
model = Model(model, strategy)
filename, _ = os.path.splitext(os.path.basename(model_path))
model = Model(model_path, strategy)
if not tokenizer:
tokenizer = get_tokenizer(len(model.w["emb.weight"]))
pipeline = PIPELINE(model, tokenizer)
@@ -671,6 +677,7 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
else:
rwkv = TextRWKV(model, pipeline)
rwkv.name = filename
rwkv.model_path = model_path
rwkv.version = model.version

return rwkv
@@ -688,6 +695,7 @@ class ModelConfigBody(BaseModel):
default=None,
description="When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.",
)
state: str = Field(default=None, description="state-tuned file path")

model_config = {
"json_schema_extra": {
@@ -699,11 +707,80 @@ class ModelConfigBody(BaseModel):
"frequency_penalty": 1,
"penalty_decay": 0.996,
"global_penalty": False,
"state": "",
}
}
}


def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
if model:
if state_path:
if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
import torch

state_path = get_model_path(state_path)
if model.state_path == state_path:
return

state_raw = torch.load(state_path, map_location="cpu")
state_raw_shape = next(iter(state_raw.values())).shape

args = model.model.args
if (
len(state_raw) != args.n_layer
or state_raw_shape[0] * state_raw_shape[1] != args.n_embd
):
if model.state_path:
pass
else:
print("state failed to load")
return HTTPException(
status.HTTP_400_BAD_REQUEST, "state shape mismatch"
)

strategy = model.model.strategy
model.state_tuned = [None] * args.n_layer * 3

for i in range(args.n_layer):
dd = strategy[i]
dev = dd.device
atype = dd.atype
model.state_tuned[i * 3 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
model.state_tuned[i * 3 + 1] = (
state_raw[f"blocks.{i}.att.time_state"]
.transpose(1, 2)
.to(dtype=torch.float, device=dev)
.requires_grad_(False)
.contiguous()
)
model.state_tuned[i * 3 + 2] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()

state_cache.force_reset_state()
model.state_path = state_path
print("state loaded")
else:
if model.state_path:
pass
else:
print("state failed to load")
return HTTPException(
status.HTTP_400_BAD_REQUEST,
"file format of the model or state model not supported",
)
else:
state_cache.force_reset_state()
model.state_path = ""
model.state_tuned = None # TODO cached
print("state unloaded")
else:
print("state not loaded")


def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
if body.max_tokens is not None:
model.max_tokens_per_generation = body.max_tokens
@@ -724,6 +801,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
model.top_k = body.top_k
if body.global_penalty is not None:
model.global_penalty = body.global_penalty
if body.state is not None:
load_rwkv_state(model, body.state)


def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
@@ -736,4 +815,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
penalty_decay=model.penalty_decay,
top_k=model.top_k,
global_penalty=model.global_penalty,
state=model.state_path,
)
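
In load_rwkv_state above, only the wkv slot (index i * 3 + 1) of each layer receives the tuned blocks.{i}.att.time_state tensor, while the two shift slots stay zero-initialized, and generate() then deep-copies state_tuned as the starting state of every request that misses the cache. The compatibility check can be reproduced offline on a state file; in this sketch the file name and the layer/embedding sizes are placeholders (the real code reads them from model.model.args):

```python
import torch

state_path = "state-models/my-state-tuned.pth"  # hypothetical file
n_layer, n_embd = 24, 2048                      # placeholders; taken from model.model.args in the real code

state_raw = torch.load(state_path, map_location="cpu")
per_layer_shape = next(iter(state_raw.values())).shape

compatible = (
    len(state_raw) == n_layer
    and per_layer_shape[0] * per_layer_shape[1] == n_embd
)
print("layers in state file:", len(state_raw))
print("per-layer tensor shape:", tuple(per_layer_shape))
print("compatible with the loaded model:", compatible)
```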
7 changes: 6 additions & 1 deletion frontend/src/_locales/ja/main.json
@@ -354,5 +354,10 @@
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "モデル内部には、一般的な問題の処理を改善するためのデフォルトのプロンプトがありますが、役割演技の効果を低下させる可能性があります。このオプションを無効にすることで、より良い役割演技効果を得ることができます。",
"Exit without saving": "保存せずに終了",
"Content has been changed, are you sure you want to exit without saving?": "コンテンツが変更されています、保存せずに終了してもよろしいですか?",
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。"
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。",
"State-tuned Model": "State調整モデル",
"See More": "もっと見る",
"State Model": "Stateモデル",
"State model mismatch": "Stateモデルの不一致",
"File format of the model or state model not supported": "モデルまたはStateモデルのファイル形式がサポートされていません"
}
7 changes: 6 additions & 1 deletion frontend/src/_locales/zh-hans/main.json
@@ -354,5 +354,10 @@
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "模型内部有一个默认提示来改善模型处理常规问题的效果, 但它可能会让角色扮演的效果变差, 你可以关闭此选项来获得更好的角色扮演效果",
"Exit without saving": "退出而不保存",
"Content has been changed, are you sure you want to exit without saving?": "内容已经被修改, 你确定要退出而不保存吗?",
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名"
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名",
"State-tuned Model": "State微调模型",
"See More": "查看更多",
"State Model": "State模型",
"State model mismatch": "State模型不匹配",
"File format of the model or state model not supported": "模型或state模型的文件格式不支持"
}
13 changes: 12 additions & 1 deletion frontend/src/components/RunButton.tsx
@@ -214,7 +214,18 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
presence_penalty: modelConfig.apiParameters.presencePenalty,
frequency_penalty: modelConfig.apiParameters.frequencyPenalty,
penalty_decay: modelConfig.apiParameters.penaltyDecay,
global_penalty: modelConfig.apiParameters.globalPenalty
global_penalty: modelConfig.apiParameters.globalPenalty,
state: modelConfig.apiParameters.stateModel
}).then(async r => {
if (r.status !== 200) {
const error = await r.text();
if (error.includes('state shape mismatch'))
toast(t('State model mismatch'), { type: 'error' });
else if (error.includes('file format of the model or state model not supported'))
toast(t('File format of the model or state model not supported'), { type: 'error' });
else
toast(error, { type: 'error' });
}
});
}

50 changes: 47 additions & 3 deletions frontend/src/pages/Configs.tsx
@@ -7,11 +7,13 @@ import {
Dropdown,
Input,
Label,
Link,
Option,
PresenceBadge,
Select,
Switch,
Text
Text,
Tooltip
} from '@fluentui/react-components';
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
import React, { FC, useCallback, useEffect, useRef } from 'react';
@@ -27,7 +29,7 @@ import { Page } from '../components/Page';
import { useNavigate } from 'react-router';
import { RunButton } from '../components/RunButton';
import { updateConfig } from '../apis';
import { getStrategy } from '../utils';
import { getStrategy, isDynamicStateSupported } from '../utils';
import { useTranslation } from 'react-i18next';
import strategyImg from '../assets/images/strategy.jpg';
import strategyZhImg from '../assets/images/strategy_zh.jpg';
@@ -36,6 +38,7 @@ import { useMediaQuery } from 'usehooks-ts';
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
import { defaultPenaltyDecay } from './defaultConfigs';
import { BrowserOpenURL } from '../../wailsjs/runtime';

const ConfigSelector: FC<{
selectedIndex: number,
@@ -112,14 +115,27 @@ const Configs: FC = observer(() => {

const onClickSave = () => {
commonStore.setModelConfig(selectedIndex, selectedConfig);
// When clicking RunButton on the Configs page, updateConfig will be called twice,
// because RunButton also appears on other pages, and the updateConfig calls in both places are necessary.
updateConfig({
max_tokens: selectedConfig.apiParameters.maxResponseToken,
temperature: selectedConfig.apiParameters.temperature,
top_p: selectedConfig.apiParameters.topP,
presence_penalty: selectedConfig.apiParameters.presencePenalty,
frequency_penalty: selectedConfig.apiParameters.frequencyPenalty,
penalty_decay: selectedConfig.apiParameters.penaltyDecay,
global_penalty: selectedConfig.apiParameters.globalPenalty
global_penalty: selectedConfig.apiParameters.globalPenalty,
state: selectedConfig.apiParameters.stateModel
}).then(async r => {
if (r.status !== 200) {
const error = await r.text();
if (error.includes('state shape mismatch'))
toast(t('State model mismatch'), { type: 'error' });
else if (error.includes('file format of the model or state model not supported'))
toast(t('File format of the model or state model not supported'), { type: 'error' });
else
toast(error, { type: 'error' });
}
});
toast(t('Config Saved'), { autoClose: 300, type: 'success' });
};
@@ -200,6 +216,34 @@
});
}} />
} />
{isDynamicStateSupported(selectedConfig) &&
<div className="sm:col-span-2 flex gap-2 items-center min-w-0">
<Tooltip content={<div>
{t('State-tuned Model')}, {t('See More')}: <Link
onClick={() => BrowserOpenURL('https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead')}>{'https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead'}
</Link>
</div>} showDelay={0} hideDelay={0}
relationship="description">
<div className="shrink-0">
{t('State Model') + ' *'}
</div>
</Tooltip>
<Select style={{ minWidth: 0 }} className="grow"
value={selectedConfig.apiParameters.stateModel}
onChange={(e, data) => {
setSelectedConfigApiParams({
stateModel: data.value
});
}}>
<option key={-1} value={''}>
{t('None')}
</option>
{commonStore.stateModels.map((modelName, index) =>
<option key={index} value={modelName}>{modelName}</option>
)}
</Select>
</div>
}
<Accordion className="sm:col-span-2" collapsible
openItems={!commonStore.apiParamsCollapsed && 'advanced'}
onToggle={(e, data) => {