From 959b71cd2418cb976771cd4faef67f5847b39de7 Mon Sep 17 00:00:00 2001 From: Michael Sekamanya Date: Fri, 16 Aug 2024 19:51:08 -0700 Subject: [PATCH] Refactor samba document processor --- nesis/api/core/document_loaders/s3.py | 4 +- nesis/api/core/document_loaders/samba.py | 458 ++++++++---------- nesis/api/core/tasks/document_management.py | 11 +- .../tests/core/document_loaders/test_samba.py | 12 +- .../tests/tasks/test_document_management.py | 28 +- 5 files changed, 219 insertions(+), 294 deletions(-) diff --git a/nesis/api/core/document_loaders/s3.py b/nesis/api/core/document_loaders/s3.py index f4cc361..e46ce6a 100644 --- a/nesis/api/core/document_loaders/s3.py +++ b/nesis/api/core/document_loaders/s3.py @@ -211,9 +211,9 @@ def _sync_document( ) _LOG.info(f"Done syncing object {item['Key']} in bucket {bucket_name}") - except Exception as ex: + except: _LOG.warning( - f"Error when getting and ingesting document {item['Key']} - {ex}", + f"Error when getting and ingesting document {item['Key']}", exc_info=True, ) diff --git a/nesis/api/core/document_loaders/samba.py b/nesis/api/core/document_loaders/samba.py index dd368a6..fc6f3d3 100644 --- a/nesis/api/core/document_loaders/samba.py +++ b/nesis/api/core/document_loaders/samba.py @@ -1,316 +1,254 @@ -import uuid +import logging import pathlib -import json -import memcache +import uuid +from concurrent.futures import as_completed from datetime import datetime from typing import Dict, Any +import memcache import smbprotocol from smbclient import scandir, stat, shutil -import logging -from nesis.api.core.models.entities import Document -from nesis.api.core.services import util -from nesis.api.core.services.util import ( - save_document, - get_document, - delete_document, - get_documents, - ValidationException, - ingest_file, -) +from nesis.api.core.document_loaders.loader_helper import DocumentProcessor +from nesis.api.core.models.entities import Datasource from nesis.api.core.util import http, clean_control, isblank +from nesis.api.core.util.concurrency import IOBoundPool from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT, DEFAULT_SAMBA_PORT -from nesis.api.core.util.dateutil import strptime _LOG = logging.getLogger(__name__) -def fetch_documents( - connection: Dict[str, str], - rag_endpoint: str, - http_client: http.HttpClient, - cache_client: memcache.Client, - metadata: Dict[str, Any], -) -> None: - try: - _sync_samba_documents( - connection=connection, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=metadata, - cache_client=cache_client, - ) - except: - _LOG.exception(f"Error syncing documents") - - try: - _unsync_samba_documents( - connection=connection, rag_endpoint=rag_endpoint, http_client=http_client - ) - except Exception as ex: - _LOG.exception(f"Error unsyncing documents") - - -def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: - port = connection.get("port") or DEFAULT_SAMBA_PORT - _valid_keys = ["port", "endpoint", "user", "password", "dataobjects"] - if not str(port).isnumeric(): - raise ValueError("Port value cannot be non numeric") +class Processor(DocumentProcessor): + def __init__( + self, + config, + http_client: http.HttpClient, + cache_client: memcache.Client, + datasource: Datasource, + ): + super().__init__(config, http_client, datasource) + self._config = config + self._http_client = http_client + self._cache_client = cache_client + self._datasource = datasource + self._futures = [] + + def run(self, metadata: Dict[str, Any]): + connection = self._datasource.connection + try: + self._sync_samba_documents( + metadata=metadata, + ) + except: + _LOG.exception(f"Error syncing documents") - assert not isblank( - connection.get("endpoint") - ), "A valid share address must be supplied" + try: + self._unsync_samba_documents( + connection=connection, + ) + except: + _LOG.exception(f"Error unsyncing documents") - try: - _connect_samba_server(connection) - except Exception as ex: - _LOG.exception( - f"Failed to connect to samba server at {connection['endpoint']}", - ) - raise ValueError(ex) - connection["port"] = port - return { - key: val - for key, val in connection.items() - if key in _valid_keys and not isblank(connection[key]) - } + for future in as_completed(self._futures): + try: + future.result() + except: + _LOG.warning(future.exception()) + def _sync_samba_documents(self, metadata): -def _connect_samba_server(connection): - username = connection.get("user") - password = connection.get("password") - endpoint = connection.get("endpoint") - port = connection.get("port") - next(scandir(endpoint, username=username, password=password, port=port)) + connection = self._datasource.connection + username = connection["user"] + password = connection["password"] + endpoint = connection["endpoint"] + port = connection["port"] + # These are any folder specified to scope the sync to + dataobjects = connection.get("dataobjects") or "" -def _sync_samba_documents( - connection, rag_endpoint, http_client, metadata, cache_client -): + dataobjects_parts = [do.strip() for do in dataobjects.split(",")] - username = connection["user"] - password = connection["password"] - endpoint = connection["endpoint"] - port = connection["port"] - # These are any folder specified to scope the sync to - dataobjects = connection.get("dataobjects") or "" + try: + file_shares = scandir( + endpoint, username=username, password=password, port=port + ) + except Exception as ex: + _LOG.exception( + f"Error while scanning share on samba server {endpoint} - {ex}" + ) + raise - dataobjects_parts = [do.strip() for do in dataobjects.split(",")] + work_dir = f"/tmp/{uuid.uuid4()}" + pathlib.Path(work_dir).mkdir(parents=True) - try: - file_shares = scandir(endpoint, username=username, password=password, port=port) - except Exception as ex: - _LOG.exception(f"Error while scanning share on samba server {endpoint} - {ex}") - raise - - work_dir = f"/tmp/{uuid.uuid4()}" - pathlib.Path(work_dir).mkdir(parents=True) - - for file_share in file_shares: - if ( - len(dataobjects_parts) > 0 - and file_share.is_dir() - and file_share.name not in dataobjects_parts - ): - continue - try: - self_link = file_share.path - _lock_key = clean_control(f"{__name__}/locks/{self_link}") + for file_share in file_shares: + if ( + len(dataobjects_parts) > 0 + and file_share.is_dir() + and file_share.name not in dataobjects_parts + ): + continue + try: - if cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): _metadata = { **(metadata or {}), "file_name": file_share.path, - "self_link": self_link, } - try: - _process_file( + + self._futures.append( + IOBoundPool.submit( + self._process_file, connection=connection, file_share=file_share, work_dir=work_dir, - http_client=http_client, - rag_endpoint=rag_endpoint, metadata=_metadata, ) - finally: - cache_client.delete(_lock_key) - else: - _LOG.info(f"Document {self_link} is already processing") - except: - _LOG.warn( - f"Error fetching and updating documents from shared_file share {file_share.path} - ", - exc_info=True, - ) - - _LOG.info( - f"Completed syncing files from samba server {endpoint} " - f"to endpoint {rag_endpoint}" - ) - - -def _process_file( - connection, file_share, work_dir, http_client, rag_endpoint, metadata -): - username = connection["user"] - password = connection["password"] - endpoint = connection["endpoint"] - port = connection["port"] + ) - if file_share.is_dir(): - if not file_share.name.startswith("."): - dir_files = scandir( - file_share.path, username=username, password=password, port=port - ) - for dir_file in dir_files: - _process_file( - connection=connection, - file_share=dir_file, - work_dir=work_dir, - http_client=http_client, - rag_endpoint=rag_endpoint, - metadata=metadata, + except: + _LOG.warning( + f"Error fetching and updating documents from shared_file share {file_share.path} - ", + exc_info=True, ) - return - file_name = file_share.name - file_stats = stat(file_share.path, username=username, password=password, port=port) - last_change_datetime = datetime.fromtimestamp(file_stats.st_chgtime) + def _process_file(self, connection, file_share, work_dir, metadata): + username = connection["user"] + password = connection["password"] + endpoint = connection["endpoint"] + port = connection["port"] - try: - file_path = f"{work_dir}/{file_share.name}" - file_unique_id = f"{uuid.uuid5(uuid.NAMESPACE_DNS, file_share.path)}" + if file_share.is_dir(): + if not file_share.name.startswith("."): + dir_files = scandir( + file_share.path, username=username, password=password, port=port + ) + for dir_file in dir_files: + self._process_file( + connection=connection, + file_share=dir_file, + work_dir=work_dir, + metadata=metadata, + ) + return - _LOG.info( - f"Starting syncing shared_file {file_name} in shared directory share {file_share.path}" + file_name = file_share.name + file_stats = stat( + file_share.path, username=username, password=password, port=port ) + last_change_datetime = datetime.fromtimestamp(file_stats.st_chgtime) try: - shutil.copyfile( - file_share.path, - file_path, - username=username, - password=password, - port=port, - ) - except Exception as ex: - _LOG.warn( - f"Failed to copy contents of shared_file {file_name} from shared location {file_share.path}", - exc_info=True, + file_path = f"{work_dir}/{file_share.name}" + file_unique_id = f"{uuid.uuid5(uuid.NAMESPACE_DNS, file_share.path)}" + + _LOG.info( + f"Starting syncing shared_file {file_name} in shared directory share {file_share.path}" ) - return - """ - Here we check if this file has been updated. - If the file has been updated, we delete it from the vector store and re-ingest the new updated file - """ - document: Document = get_document(document_id=file_unique_id) - if document and document.base_uri == endpoint: - store_metadata = document.store_metadata - if store_metadata and store_metadata.get("last_modified"): - if not strptime(date_string=store_metadata["last_modified"]).replace( - tzinfo=None - ) < last_change_datetime.replace(tzinfo=None).replace(microsecond=0): - _LOG.debug(f"Skipping shared_file {file_name} already up to date") - return - rag_metadata: dict = document.rag_metadata - if rag_metadata is None: - return - for document_data in rag_metadata.get("data") or []: + try: + shutil.copyfile( + file_share.path, + file_path, + username=username, + password=password, + port=port, + ) + self_link = file_share.path + _lock_key = clean_control(f"{__name__}/locks/{self_link}") + + metadata["self_link"] = self_link + + if self._cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): try: - util.un_ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - doc_id=document_data["doc_id"], - ) - except: - _LOG.warn( - f"Failed to delete document {document_data['doc_id']}", - exc_info=True, + self.sync( + endpoint, + file_path, + last_modified=last_change_datetime, + metadata=metadata, + store_metadata={ + "shared_folder": file_share.name, + "file_path": file_share.path, + "filename": file_share.path, + "file_id": file_unique_id, + "size": file_stats.st_size, + "name": file_name, + "last_modified": last_change_datetime.strftime( + DEFAULT_DATETIME_FORMAT + ), + }, ) - try: - delete_document(document_id=file_unique_id) - except: - _LOG.warn( - f"Failed to delete shared_file {file_name}'s record. Continuing anyway...", - exc_info=True, - ) - - file_metadata = { - "shared_folder": file_share.name, - "file_path": file_share.path, - "file_id": file_unique_id, - "size": file_stats.st_size, - "name": file_name, - "last_modified": last_change_datetime.strftime(DEFAULT_DATETIME_FORMAT), - } + finally: + self._cache_client.delete(_lock_key) + else: + _LOG.info(f"Document {self_link} is already processing") + + except: + _LOG.warning( + f"Failed to copy contents of shared_file {file_name} from shared location {file_share.path}", + exc_info=True, + ) + return - try: - response = ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - metadata=metadata, - file_path=file_path, + _LOG.info( + f"Done syncing shared_file {file_name} in location {file_share.path}" + ) + except Exception as ex: + _LOG.warn( + f"Error when getting and ingesting shared_file {file_name} - {ex}", + exc_info=True, ) - except UserWarning: - _LOG.debug(f"File {file_path} is already processing") - return - response_json = json.loads(response) - - save_document( - document_id=file_unique_id, - filename=file_share.path, - base_uri=endpoint, - rag_metadata=response_json, - store_metadata=file_metadata, - ) - - _LOG.info(f"Done syncing shared_file {file_name} in location {file_share.path}") - except Exception as ex: - _LOG.warn( - f"Error when getting and ingesting shared_file {file_name} - {ex}", - exc_info=True, - ) - _LOG.info( - f"Completed syncing files from shared_file share {file_share.path} to endpoint {rag_endpoint}" - ) - -def _unsync_samba_documents(connection, rag_endpoint, http_client): - try: + def _unsync_samba_documents(self, connection): username = connection["user"] password = connection["password"] - endpoint = connection["endpoint"] port = connection["port"] - work_dir = f"/tmp/{uuid.uuid4()}" - pathlib.Path(work_dir).mkdir(parents=True) - - documents = get_documents(base_uri=endpoint) - for document in documents: - store_metadata = document.store_metadata - rag_metadata = document.rag_metadata - + def clean(**kwargs): + store_metadata = kwargs["store_metadata"] file_path = store_metadata["file_path"] try: stat(file_path, username=username, password=password, port=port) + return False except smbprotocol.exceptions.SMBOSError as error: - if "No such file" not in str(error): + if "No such file" in str(error): + return True + else: raise - try: - http_client.deletes( - [ - f"{rag_endpoint}/v1/ingest/documents/{document_data['doc_id']}" - for document_data in rag_metadata.get("data") or [] - ] - ) - _LOG.info(f"Deleting document {document.filename}") - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {document.filename}", - exc_info=True, - ) - _LOG.info(f"Completed unsyncing files from endpoint {rag_endpoint}") - except: - _LOG.warn("Error fetching and updating documents", exc_info=True) + + try: + self.unsync(clean=clean) + except: + _LOG.warning("Error fetching and updating documents", exc_info=True) + + +def _connect_samba_server(connection): + username = connection.get("user") + password = connection.get("password") + endpoint = connection.get("endpoint") + port = connection.get("port") + next(scandir(endpoint, username=username, password=password, port=port)) + + +def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: + port = connection.get("port") or DEFAULT_SAMBA_PORT + _valid_keys = ["port", "endpoint", "user", "password", "dataobjects"] + if not str(port).isnumeric(): + raise ValueError("Port value cannot be non numeric") + + assert not isblank( + connection.get("endpoint") + ), "A valid share address must be supplied" + + try: + _connect_samba_server(connection) + except Exception as ex: + _LOG.exception( + f"Failed to connect to samba server at {connection['endpoint']}", + ) + raise ValueError(ex) + connection["port"] = port + return { + key: val + for key, val in connection.items() + if key in _valid_keys and not isblank(connection[key]) + } diff --git a/nesis/api/core/tasks/document_management.py b/nesis/api/core/tasks/document_management.py index 80fbbd6..552688d 100644 --- a/nesis/api/core/tasks/document_management.py +++ b/nesis/api/core/tasks/document_management.py @@ -59,13 +59,16 @@ def ingest_datasource(**kwargs) -> None: metadata={"datasource": datasource.name}, ) case DatasourceType.WINDOWS_SHARE: - samba.fetch_documents( - connection=datasource.connection, - rag_endpoint=rag_endpoint, + + ingestor = samba.Processor( + config=config, http_client=http_client, - metadata={"datasource": datasource.name}, cache_client=cache_client, + datasource=datasource, ) + + ingestor.run(metadata=metadata) + case DatasourceType.S3: minio_ingestor = s3.Processor( config=config, diff --git a/nesis/api/tests/core/document_loaders/test_samba.py b/nesis/api/tests/core/document_loaders/test_samba.py index ca65db8..3391b2d 100644 --- a/nesis/api/tests/core/document_loaders/test_samba.py +++ b/nesis/api/tests/core/document_loaders/test_samba.py @@ -53,7 +53,7 @@ def test_fetch_documents( ) -> None: data = { "name": "s3 documents", - "engine": "s3", + "engine": "samba", "connection": { "endpoint": "https://s3.endpoint", "user": "user", @@ -87,12 +87,14 @@ def test_fetch_documents( http_client = mock.MagicMock() http_client.upload.return_value = json.dumps({}) - samba.fetch_documents( - connection=data["connection"], + ingestor = samba.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.upload.call_args_list[0] diff --git a/nesis/api/tests/tasks/test_document_management.py b/nesis/api/tests/tasks/test_document_management.py index 98bae52..9fa1c27 100644 --- a/nesis/api/tests/tasks/test_document_management.py +++ b/nesis/api/tests/tasks/test_document_management.py @@ -97,15 +97,9 @@ def test_ingest_datasource_minio( @mock.patch("nesis.api.core.document_loaders.samba.scandir") -@mock.patch("nesis.api.core.tasks.document_management.samba._unsync_samba_documents") -@mock.patch("nesis.api.core.tasks.document_management.samba._sync_samba_documents") +@mock.patch("nesis.api.core.tasks.document_management.samba.Processor") def test_ingest_datasource_samba( - _sync_samba_documents: mock.MagicMock, - _unsync_samba_documents: mock.MagicMock, - scandir, - tc, - cache_client, - http_client, + ingestor: mock.MagicMock, scandir, tc, cache_client, http_client ): """ Test the ingestion happy path @@ -117,7 +111,7 @@ def test_ingest_datasource_samba( password=tests.admin_password, ) datasource: Datasource = create_datasource( - token=admin_user.token, datasource_type="windows_share" + token=admin_user.token, datasource_type="WINDOWS_SHARE" ) ingest_datasource( @@ -127,21 +121,9 @@ def test_ingest_datasource_samba( params={"datasource": {"id": datasource.uuid}}, ) - _, kwargs_sync_samba_documents = _sync_samba_documents.call_args_list[0] - assert ( - kwargs_sync_samba_documents["rag_endpoint"] == tests.config["rag"]["endpoint"] - ) - tc.assertDictEqual(kwargs_sync_samba_documents["connection"], datasource.connection) - tc.assertDictEqual( - kwargs_sync_samba_documents["metadata"], {"datasource": datasource.name} - ) - - _, kwargs_unsync_samba_documents = _unsync_samba_documents.call_args_list[0] - assert ( - kwargs_unsync_samba_documents["rag_endpoint"] == tests.config["rag"]["endpoint"] - ) + _, kwargs_fetch_documents = ingestor.return_value.run.call_args_list[0] tc.assertDictEqual( - kwargs_unsync_samba_documents["connection"], datasource.connection + kwargs_fetch_documents["metadata"], {"datasource": datasource.name} )