- Python 3.6+
- PyTorch 1.13.0+
- Transformers 4.25.0+
```bash
pip install neurocache
```
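To make sure the minimum versions listed above are met, the dependencies can be pinned in the same command (`torch` and `transformers` are the standard PyPI package names):

```bash
pip install "torch>=1.13.0" "transformers>=4.25.0" neurocache
```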
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neurocache import (
    NeurocacheModelForCausalLM,
    OnDeviceCacheConfig,
)

model_name = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Place the cache near the top of the network.
cache_layer_idx = model.config.num_hidden_layers - 5

config = OnDeviceCacheConfig(
    # Layers whose hidden states are written to the external cache.
    cache_layers=[cache_layer_idx, cache_layer_idx + 3],
    # Layers that attend over the retrieved cache entries.
    attention_layers=list(range(cache_layer_idx, model.config.num_hidden_layers)),
    compression_factor=8,
    topk=8,
)

# Wrap the base model with Neurocache.
model = NeurocacheModelForCausalLM(model, config)

input_text = ["Hello, my dog is cute", " is cute"]
# Pad to a common length so the batch can be returned as PyTorch tensors.
tokenized_input = tokenizer(input_text, return_tensors="pt", padding=True)
# Flag which batch elements start a new sequence (True) and which continue
# a previously cached one (False).
tokenized_input["start_of_sequence"] = torch.tensor([1, 0]).bool()

outputs = model(**tokenized_input)
```
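Since `NeurocacheModelForCausalLM` wraps an ordinary Hugging Face causal language model, the forward pass above should return the familiar language-modeling outputs. The following sketch continues the snippet and assumes the wrapper exposes the standard `logits` field (not shown in the original example):

```python
# Sketch only: assumes standard causal-LM outputs with a `.logits` tensor
# of shape (batch_size, sequence_length, vocab_size).
logits = outputs.logits

# Index of the last non-padding token in each sequence.
last_positions = tokenized_input["attention_mask"].sum(dim=1) - 1

# Greedy prediction for the next token of each sequence in the batch.
next_token_ids = logits[torch.arange(logits.size(0)), last_positions].argmax(dim=-1)
print(tokenizer.batch_decode(next_token_ids.unsqueeze(-1)))
```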
```python
from neurocache.utils import NEUROCACHE_SUPPORTED_MODELS

print(NEUROCACHE_SUPPORTED_MODELS)
# ["opt", "llama", "mistral", "gptj"]
```
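Since only a handful of decoder architectures are supported, it can be useful to check the base model before wrapping it. A minimal sketch, assuming the entries of `NEUROCACHE_SUPPORTED_MODELS` match the `model_type` strings used in Hugging Face configs:

```python
from transformers import AutoModelForCausalLM
from neurocache.utils import NEUROCACHE_SUPPORTED_MODELS

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# Assumption: supported-model names correspond to `config.model_type` values.
if model.config.model_type not in NEUROCACHE_SUPPORTED_MODELS:
    raise ValueError(
        f"Model type '{model.config.model_type}' is not supported by Neurocache; "
        f"supported types: {NEUROCACHE_SUPPORTED_MODELS}"
    )
```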
- Benchmark the implementation and identify bottlenecks.
- Add support for more models and for grouped-query attention (for Mistral and larger LLaMA models).
- Add chunked storage function for generation (enables faster processing for long prompts).
- Add support for masking padding tokens in the cache (required for global cache only).