diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py index 35af7d2cd6..331c4a8b9b 100644 --- a/xinference/model/embedding/core.py +++ b/xinference/model/embedding/core.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import logging +import os from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union, no_type_check import numpy as np +from ...device_utils import empty_cache from ...types import Embedding, EmbeddingData, EmbeddingUsage from ..core import CacheableModelSpec, ModelDescription from ..utils import get_cache_dir, is_model_cached @@ -28,6 +31,10 @@ # Init when registering all the builtin models. MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list) EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list) +EMBEDDING_EMPTY_CACHE_COUNT = int( + os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10") +) +assert EMBEDDING_EMPTY_CACHE_COUNT > 0 def get_embedding_model_descriptions(): @@ -116,6 +123,7 @@ def __init__(self, model_uid: str, model_path: str, device: Optional[str] = None self._model_path = model_path self._device = device self._model = None + self._counter = 0 def load(self): try: @@ -134,6 +142,11 @@ def load(self): self._model = SentenceTransformer(self._model_path, device=self._device) def create_embedding(self, sentences: Union[str, List[str]], **kwargs): + self._counter += 1 + if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0: + logger.debug("Empty embedding cache.") + gc.collect() + empty_cache() from sentence_transformers import SentenceTransformer kwargs.setdefault("normalize_embeddings", True) diff --git a/xinference/model/rerank/core.py b/xinference/model/rerank/core.py index 4b0a3a9fd2..4069918e51 100644 --- a/xinference/model/rerank/core.py +++ b/xinference/model/rerank/core.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import logging import os import uuid @@ -21,6 +22,7 @@ import numpy as np from ...constants import XINFERENCE_CACHE_DIR +from ...device_utils import empty_cache from ...types import Document, DocumentObj, Rerank from ..core import CacheableModelSpec, ModelDescription from ..utils import is_model_cached @@ -31,6 +33,8 @@ # Init when registering all the builtin models. MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list) RERANK_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list) +RERANK_EMPTY_CACHE_COUNT = int(os.getenv("XINFERENCE_RERANK_EMPTY_CACHE_COUNT", "10")) +assert RERANK_EMPTY_CACHE_COUNT > 0 def get_rerank_model_descriptions(): @@ -113,6 +117,7 @@ def __init__( self._model_config = model_config or dict() self._use_fp16 = use_fp16 self._model = None + self._counter = 0 def load(self): if self._model_spec.type == "normal": @@ -160,6 +165,11 @@ def rerank( return_documents: Optional[bool], **kwargs, ) -> Rerank: + self._counter += 1 + if self._counter % RERANK_EMPTY_CACHE_COUNT == 0: + logger.debug("Empty rerank cache.") + gc.collect() + empty_cache() assert self._model is not None if kwargs: raise ValueError("rerank hasn't support extra parameter.")