From 9e2ae866c4f44074b55f837a35650d117d5ab61b Mon Sep 17 00:00:00 2001 From: Unai Garay Maestre Date: Wed, 29 Nov 2023 10:58:53 +0100 Subject: [PATCH] langchain[patch]: Adds progress bar to GooglePalmEmbeddings (#13812) - **Description:** Adds a tqdm progress bar to GooglePalmEmbeddings when embedding a list. - **Issue:** #13637 - **Dependencies:** TQDM as a main dependency (instead of extra) Signed-off-by: ugm2 --------- Signed-off-by: ugm2 Co-authored-by: Harrison Chase --- .../langchain/embeddings/google_palm.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/embeddings/google_palm.py b/libs/langchain/langchain/embeddings/google_palm.py index a61dbd3b2997f..db6314d47f466 100644 --- a/libs/langchain/langchain/embeddings/google_palm.py +++ b/libs/langchain/langchain/embeddings/google_palm.py @@ -60,6 +60,8 @@ class GooglePalmEmbeddings(BaseModel, Embeddings): google_api_key: Optional[str] model_name: str = "models/embedding-gecko-001" """Model name to use.""" + show_progress_bar: bool = False + """Whether to show a tqdm progress bar. Must have `tqdm` installed.""" @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -79,7 +81,20 @@ def validate_environment(cls, values: Dict) -> Dict: return values def embed_documents(self, texts: List[str]) -> List[List[float]]: - return [self.embed_query(text) for text in texts] + if self.show_progress_bar: + try: + from tqdm import tqdm + + iter_ = tqdm(texts, desc="GooglePalmEmbeddings") + except ImportError: + logger.warning( + "Unable to show progress bar because tqdm could not be imported. " + "Please install with `pip install tqdm`." + ) + iter_ = texts + else: + iter_ = texts + return [self.embed_query(text) for text in iter_] def embed_query(self, text: str) -> List[float]: """Embed query text."""