Skip to content

Commit

Permalink
ENH: Clear cache for embedding and rerank (#1360)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Apr 24, 2024
1 parent 9ddff32 commit 2ba72b0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
13 changes: 13 additions & 0 deletions xinference/model/embedding/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union, no_type_check

import numpy as np

from ...device_utils import empty_cache
from ...types import Embedding, EmbeddingData, EmbeddingUsage
from ..core import CacheableModelSpec, ModelDescription
from ..utils import get_cache_dir, is_model_cached
Expand All @@ -28,6 +31,10 @@
# Init when registering all the builtin models.
MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
EMBEDDING_EMPTY_CACHE_COUNT = int(
os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
)
assert EMBEDDING_EMPTY_CACHE_COUNT > 0


def get_embedding_model_descriptions():
Expand Down Expand Up @@ -116,6 +123,7 @@ def __init__(self, model_uid: str, model_path: str, device: Optional[str] = None
self._model_path = model_path
self._device = device
self._model = None
self._counter = 0

def load(self):
try:
Expand All @@ -134,6 +142,11 @@ def load(self):
self._model = SentenceTransformer(self._model_path, device=self._device)

def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
self._counter += 1
if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0:
logger.debug("Empty embedding cache.")
gc.collect()
empty_cache()
from sentence_transformers import SentenceTransformer

kwargs.setdefault("normalize_embeddings", True)
Expand Down
10 changes: 10 additions & 0 deletions xinference/model/rerank/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging
import os
import uuid
Expand All @@ -21,6 +22,7 @@
import numpy as np

from ...constants import XINFERENCE_CACHE_DIR
from ...device_utils import empty_cache
from ...types import Document, DocumentObj, Rerank
from ..core import CacheableModelSpec, ModelDescription
from ..utils import is_model_cached
Expand All @@ -31,6 +33,8 @@
# Init when registering all the builtin models.
MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
RERANK_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
RERANK_EMPTY_CACHE_COUNT = int(os.getenv("XINFERENCE_RERANK_EMPTY_CACHE_COUNT", "10"))
assert RERANK_EMPTY_CACHE_COUNT > 0


def get_rerank_model_descriptions():
Expand Down Expand Up @@ -113,6 +117,7 @@ def __init__(
self._model_config = model_config or dict()
self._use_fp16 = use_fp16
self._model = None
self._counter = 0

def load(self):
if self._model_spec.type == "normal":
Expand Down Expand Up @@ -160,6 +165,11 @@ def rerank(
return_documents: Optional[bool],
**kwargs,
) -> Rerank:
self._counter += 1
if self._counter % RERANK_EMPTY_CACHE_COUNT == 0:
logger.debug("Empty rerank cache.")
gc.collect()
empty_cache()
assert self._model is not None
if kwargs:
raise ValueError("rerank hasn't support extra parameter.")
Expand Down

0 comments on commit 2ba72b0

Please sign in to comment.