add NPU support for embedding and rerank models
openvino-dev-samples committed Mar 25, 2024
1 parent 67bfe36 commit 2b5cfda
Showing 3 changed files with 14 additions and 13 deletions.
10 changes: 8 additions & 2 deletions notebooks/254-llm-chatbot/254-rag-chatbot.ipynb
@@ -1126,11 +1126,14 @@
"source": [
"from ov_embedding_model import OVBgeEmbeddings\n",
"\n",
"embedding_model_max_length = 384\n",
"encode_kwargs = {'normalize_embeddings': embedding_model_configuration[\"do_norm\"]}\n",
"embedding_model_kwargs = {\"device\": embedding_device.value, \"model_max_length\": embedding_model_max_length if embedding_device.value==\"NPU\" else None}\n",
"\n",
"embedding = OVBgeEmbeddings(\n",
" model_dir=embedding_model_id.value,\n",
" model_kwargs=embedding_model_kwargs,\n",
" encode_kwargs=encode_kwargs,\n",
" device=embedding_device.value\n",
")"
]
},
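For context, a minimal standalone sketch of the updated calling convention, assuming a converted BGE model in a local folder; the directory name, device choice, and 384-token limit are illustrative, and the fixed `model_max_length` is only expected to matter for the NPU, which needs static input shapes:

from ov_embedding_model import OVBgeEmbeddings

embedding = OVBgeEmbeddings(
    model_dir="bge-small-en-v1.5",                             # hypothetical path to a converted model
    model_kwargs={"device": "NPU", "model_max_length": 384},   # 384 is an assumed static length
    encode_kwargs={"normalize_embeddings": True},
)
vector = embedding.embed_query("What is OpenVINO?")            # LangChain Embeddings interface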
@@ -1156,10 +1159,13 @@
"\n",
"from ov_rerank_model import OVRanker\n",
"\n",
"rerank_model_max_length = 384\n",
"rerank_top_n = 3\n",
"embedding_model_kwargs = {\"device\": rerank_device.value, \"model_max_length\": rerank_model_max_length if rerank_device.value==\"NPU\" else None}\n",
"\n",
"reranker = OVRanker(\n",
" model_dir=rerank_model_id.value,\n",
" device=rerank_device.value,\n",
" model_kwargs=rerank_model_kwargs,\n",
" top_n=rerank_top_n\n",
")"
]
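As a usage note, a reranker constructed this way is typically wrapped in LangChain's compression retriever; a hedged sketch, assuming a `retriever` built from the vector store earlier in the notebook:

from langchain.retrievers import ContextualCompressionRetriever

compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker,   # the OVRanker instance created above
    base_retriever=retriever,   # assumed to exist from the earlier vector-store step
)
docs = compression_retriever.get_relevant_documents("How do I run the chatbot on NPU?")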
9 changes: 3 additions & 6 deletions notebooks/254-llm-chatbot/ov_embedding_model.py
@@ -31,10 +31,8 @@ class OVBgeEmbeddings(BaseModel, Embeddings):
"""Tokenizer for embedding model."""
model_dir: str
"""Path to store models."""
device: str = "CPU"
"""Device for model deployment. """
ov_config: Dict[str, Any] = Field(default_factory=dict)
"""OpenVINO configuration arguments to pass to the model."""
model_kwargs: Dict[str, Any]
"""Keyword arguments passed to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the `encode` method of the model."""
query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION_EN
@@ -43,9 +41,8 @@ class OVBgeEmbeddings(BaseModel, Embeddings):
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)

self.ov_model = OVModelForFeatureExtraction.from_pretrained(
self.model_dir, device=self.device, ov_config=self.ov_config)
self.model_dir, **self.model_kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

if "-zh" in self.model_dir:
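The practical reason for threading `model_max_length` through `model_kwargs` is that the NPU plugin generally requires static input shapes; how the wrapper consumes the value is outside this hunk, but the usual optimum-intel pattern is roughly the sketch below, with the path and length as illustrative assumptions:

from optimum.intel.openvino import OVModelForFeatureExtraction

ov_model = OVModelForFeatureExtraction.from_pretrained("bge-small-en-v1.5", compile=False)
ov_model.reshape(1, 384)   # batch_size=1, sequence_length=384 -> static shapes for NPU
ov_model.to("NPU")
ov_model.compile()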
8 changes: 3 additions & 5 deletions notebooks/254-llm-chatbot/ov_rerank_model.py
@@ -32,18 +32,16 @@ class OVRanker(BaseDocumentCompressor):
"""Tokenizer for embedding model."""
model_dir: str
"""Path to store models."""
device: str = "CPU"
"""Device for model deployment. """
ov_config: Dict[str, Any] = Field(default_factory=dict)
"""OpenVINO configuration arguments to pass to the model."""
model_kwargs: Dict[str, Any]
"""Keyword arguments passed to the model."""
top_n: int = 4
"""return Top n texts."""

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.tokenizer = self._get_tokenizer()
self.ov_model = OVModelForSequenceClassification.from_pretrained(
self.model_dir, device=self.device, ov_config=self.ov_config)
self.model_dir, **self.model_kwargs)

def _load_vocab(self, vocab_file):

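For completeness, a small hedged sketch of exercising the refactored `OVRanker` directly through the `BaseDocumentCompressor` interface; the model path, documents, and query are made up, and the `Document` import path may differ across LangChain versions:

from langchain_core.documents import Document
from ov_rerank_model import OVRanker

reranker = OVRanker(
    model_dir="bge-reranker-large",        # hypothetical path to a converted reranker
    model_kwargs={"device": "CPU"},
    top_n=1,
)
docs = [
    Document(page_content="OpenVINO models can be compiled for CPU, GPU and NPU."),
    Document(page_content="Bananas are rich in potassium."),
]
best = reranker.compress_documents(docs, query="Which devices can OpenVINO target?")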
