From 38574152fcde104e827a2a904b271e958cf31033 Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Wed, 4 Oct 2023 14:00:49 -0700
Subject: [PATCH] GS workspaces can be bucket subfolders (#604)

---
 CHANGELOG.md                             |   4 +
 tango/integrations/gs/common.py          | 109 +++++++++++++++--------
 tango/integrations/gs/step_cache.py      |  14 +--
 tango/integrations/gs/workspace.py       |  39 +++++---
 tests/integrations/gs/step_cache_test.py |  13 ++-
 tests/integrations/gs/workspace_test.py  |  24 +++--
 6 files changed, 135 insertions(+), 68 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 075775320..9093a4ce0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- The `GSWorkspace()` can now be initialized with Google Cloud bucket subfolders.
+
 ### Fixed
 
 - Removed unnecessary code coverage dev requirements.
diff --git a/tango/integrations/gs/common.py b/tango/integrations/gs/common.py
index 2eda3b6b3..875d04b88 100644
--- a/tango/integrations/gs/common.py
+++ b/tango/integrations/gs/common.py
@@ -9,7 +9,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import google.auth
 from google.api_core import exceptions
@@ -27,33 +27,50 @@
 logger = logging.getLogger(__name__)
 
 
-def empty_bucket(bucket_name: str):
+def get_bucket_and_prefix(folder_name: str) -> Tuple[str, str]:
     """
-    Removes all the tango-related blobs from the specified bucket.
+    Splits the folder name into the bucket name and the subfolder prefix, if present.
+    """
+    split = folder_name.split("/")
+    return split[0], "/".join(split[1:])
+
+
+def empty_bucket_folder(folder_name: str):
+    """
+    Removes all the tango-related blobs from the specified bucket folder.
     Used for testing.
     """
    credentials, project = google.auth.default()
     client = storage.Client(project=project, credentials=credentials)
+    bucket_name, prefix = get_bucket_and_prefix(folder_name)
+
+    prefix = prefix + "/tango-" if prefix else "tango-"
+
     bucket = client.bucket(bucket_name)
     try:
-        bucket.delete_blobs(list(bucket.list_blobs(prefix="tango-")))
+        bucket.delete_blobs(list(bucket.list_blobs(prefix=prefix)))
     except exceptions.NotFound:
         pass
 
 
-def empty_datastore(namespace: str):
+def empty_datastore(folder_name: str):
     """
-    Removes all the tango-related entities from the specified namespace in datastore.
+    Removes all the tango-related entities for the specified bucket folder from the datastore namespace.
     Used for testing.
     """
     from google.cloud import datastore
 
     credentials, project = google.auth.default()
+    namespace, prefix = get_bucket_and_prefix(folder_name)
+
+    run_kind = prefix + "/run" if prefix else "run"
+    stepinfo_kind = prefix + "/stepinfo" if prefix else "stepinfo"
+
     client = datastore.Client(project=project, credentials=credentials, namespace=namespace)
-    run_query = client.query(kind="run")
+    run_query = client.query(kind=run_kind)
     run_query.keys_only()
     keys = [entity.key for entity in run_query.fetch()]
-    stepinfo_query = client.query(kind="stepinfo")
+    stepinfo_query = client.query(kind=stepinfo_kind)
     stepinfo_query.keys_only()
     keys += [entity.key for entity in stepinfo_query.fetch()]
     client.delete_multi(keys)
@@ -108,12 +125,19 @@ class GSArtifactWriteError(TangoError):
     pass
 
 
+def join_path(*args) -> str:
+    """
+    We use this since `os.path.join` is platform-dependent and cloud storage paths always use "/".
+    """
+    return "/".join(args).strip("/")
+
+
 class GSClient:
     """
     A client for interacting with Google Cloud Storage. The authorization works by
     providing OAuth2 credentials.
 
-    :param bucket_name: The name of the Google Cloud bucket to use.
+    :param folder_name: The name of the Google Cloud bucket folder to use.
     :param credentials: OAuth2 credentials can be provided. If not provided, default
         gcloud credentials are inferred.
     :param project: Optionally, the project ID can be provided. This is not essential
@@ -123,7 +147,7 @@ class GSClient:
     placeholder_file = ".placeholder"
     """
-    The placeholder file is used for creation of a folder in the cloud bucket,
+    The placeholder file is used for the creation of a folder in the cloud bucket folder,
     as empty folders are not allowed. It is also used as a marker for the creation
     time of the folder, hence we use a separate file to mark the artifact as
     uncommitted.
     """
@@ -143,7 +167,7 @@ class GSClient:
 
     def __init__(
         self,
-        bucket_name: str,
+        folder_name: str,
         credentials: Optional[Credentials] = None,
         project: Optional[str] = None,
     ):
@@ -151,9 +175,12 @@ def __init__(
             credentials, project = google.auth.default()
 
         self.storage = storage.Client(project=project, credentials=credentials)
-        self.bucket_name = bucket_name
+        self.folder_name = folder_name
+
+        self.bucket_name, self.prefix = get_bucket_and_prefix(folder_name)
+        settings_file = self._gs_path(self.settings_file)
-        blob = self.storage.bucket(bucket_name).blob(self.settings_file)  # no HTTP request yet
+        blob = self.storage.bucket(self.bucket_name).blob(settings_file)  # no HTTP request yet
         try:
             with blob.open("r") as file_ref:
                 json.load(file_ref)
@@ -166,13 +193,12 @@ def url(self, artifact: Optional[str] = None):
         """
         Returns the remote url of the storage artifact.
         """
-        path = f"gs://{self.bucket_name}"
+        path = f"gs://{self.folder_name}"
         if artifact is not None:
             path = f"{path}/{artifact}"
         return path
 
-    @classmethod
-    def _convert_blobs_to_artifact(cls, blobs: List[storage.Blob]) -> GSArtifact:
+    def _convert_blobs_to_artifact(self, blobs: List[storage.Blob]) -> GSArtifact:
         """
         Converts a list of `google.cloud.storage.Blob` to a `GSArtifact`.
         """
@@ -182,22 +208,24 @@ def _convert_blobs_to_artifact(cls, blobs: List[storage.Blob]) -> GSArtifact:
         committed: bool = True
 
         for blob in blobs:
-            if blob.name.endswith(cls.placeholder_file):
+            if blob.name.endswith(self.placeholder_file):
                 created = blob.time_created
-                name = blob.name.replace("/" + cls.placeholder_file, "")
+                name = blob.name.replace("/" + self.placeholder_file, "")
+                if self.prefix:
+                    name = name.replace(self.prefix + "/", "")
                 artifact_path = name  # does not contain bucket info here.
-            elif blob.name.endswith(cls.uncommitted_file):
+            elif blob.name.endswith(self.uncommitted_file):
                 committed = False
 
         assert name is not None, "Folder is not a GSArtifact, should not have happened."
 
         return GSArtifact(name, artifact_path, created, committed)
 
     @classmethod
-    def from_env(cls, bucket_name: str):
+    def from_env(cls, folder_name: str):
         """
         Constructs the client object from the environment, using default credentials.
         """
-        return cls(bucket_name)
+        return cls(folder_name)
 
     def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:
         """
@@ -210,19 +238,18 @@ def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:
             # We have an artifact, and we recreate it with refreshed info.
             path = artifact.artifact_path
 
-        blobs = list(self.storage.bucket(self.bucket_name).list_blobs(prefix=path))
+        prefix = self._gs_path(path)
+        blobs = list(self.storage.bucket(self.bucket_name).list_blobs(prefix=prefix))
         if len(blobs) > 0:
             return self._convert_blobs_to_artifact(blobs)
         else:
             raise GSArtifactNotFound()
 
-    @classmethod
-    def _gs_path(cls, *args):
+    def _gs_path(self, *args):
         """
-        Returns path within google cloud storage bucket. We use this since we cannot
-        use `os.path.join` for cloud storage paths.
+        Returns the path within the Google Cloud Storage bucket, prefixed with the subfolder, if any.
         """
-        return "/".join(args)
+        return join_path(self.prefix, *args)
 
     def create(self, artifact: str):
         """
@@ -230,7 +257,9 @@ def create(self, artifact: str):
         Creates a new artifact in the remote location.
         """
         bucket = self.storage.bucket(self.bucket_name)  # gives refreshed information
-        if bucket.blob(self._gs_path(artifact, self.placeholder_file)).exists():
+
+        artifact_path = self._gs_path(artifact, self.placeholder_file)
+        if bucket.blob(artifact_path).exists():
             raise GSArtifactConflict(f"{artifact} already exists!")
         else:
             # Additional safety check
@@ -238,14 +267,17 @@ def create(self, artifact: str):
                 raise GSArtifactConflict(f"{artifact} already exists!")
             bucket.blob(self._gs_path(artifact, self.placeholder_file)).upload_from_string("")
             bucket.blob(self._gs_path(artifact, self.uncommitted_file)).upload_from_string("")
-        return self._convert_blobs_to_artifact(list(bucket.list_blobs(prefix=artifact)))
+        return self._convert_blobs_to_artifact(
+            list(bucket.list_blobs(prefix=self._gs_path(artifact)))
+        )
 
     def delete(self, artifact: GSArtifact):
         """
         Removes the artifact from the remote location.
         """
         bucket = self.storage.bucket(self.bucket_name)
-        blobs = list(bucket.list_blobs(prefix=artifact.artifact_path))
+        prefix = self._gs_path(artifact.artifact_path)
+        blobs = list(bucket.list_blobs(prefix=prefix))
         bucket.delete_blobs(blobs)
 
     def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
@@ -260,7 +292,7 @@ def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
         source_path = str(objects_dir)
 
         def _sync_blob(source_file_path: str, target_file_path: str):
-            blob = self.storage.bucket(self.bucket_name).blob(target_file_path)
+            blob = self.storage.bucket(self.bucket_name).blob(self._gs_path(target_file_path))
             blob.upload_from_filename(source_file_path)
 
         import concurrent.futures
@@ -277,7 +309,7 @@ def _sync_blob(source_file_path: str, target_file_path: str):
         for dirpath, _, filenames in os.walk(source_path):
             for filename in filenames:
                 source_file_path = os.path.join(dirpath, filename)
-                target_file_path = self._gs_path(
+                target_file_path = join_path(
                     folder_path, source_file_path.replace(source_path + "/", "")
                 )
                 upload_futures.append(
@@ -328,14 +360,16 @@ def _fetch_blob(blob: storage.Blob):
         import concurrent.futures
 
         bucket = self.storage.bucket(self.bucket_name)
-        bucket.update()
+        # With list_blobs(prefix), we likely don't need to refresh bucket metadata this often.
+        # bucket.update()
 
         try:
             with concurrent.futures.ThreadPoolExecutor(
                 max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix="GSClient.download()-"
             ) as executor:
                 download_futures = []
-                for blob in bucket.list_blobs(prefix=artifact.artifact_path):
+                prefix = self._gs_path(artifact.artifact_path)
+                for blob in bucket.list_blobs(prefix=prefix):
                     download_futures.append(executor.submit(_fetch_blob, blob))
                 for future in concurrent.futures.as_completed(download_futures):
                     future.result()
@@ -348,6 +382,7 @@ def artifacts(self, prefix: str, uncommitted: bool = True) -> List[GSArtifact]:
         `match` and `uncommitted` criteria. These can include steps and runs.
         """
         list_of_artifacts = []
+        prefix = self._gs_path(prefix)
         for folder_name in self.storage.list_blobs(
             self.bucket_name, prefix=prefix, delimiter="/"
         )._get_next_page_response()["prefixes"]:
@@ -405,15 +440,15 @@ def get_credentials(credentials: Optional[Union[str, Credentials]] = None) -> Cr
 
 
 def get_client(
-    bucket_name: str,
+    folder_name: str,
     credentials: Optional[Union[str, Credentials]] = None,
     project: Optional[str] = None,
 ) -> GSClient:
     """
-    Returns a `GSClient` object for a google cloud bucket.
+    Returns a `GSClient` object for a Google Cloud bucket folder.
     """
     credentials = get_credentials(credentials)
-    return GSClient(bucket_name, credentials=credentials, project=project)
+    return GSClient(folder_name, credentials=credentials, project=project)
 
 
 class Constants(RemoteConstants):
diff --git a/tango/integrations/gs/step_cache.py b/tango/integrations/gs/step_cache.py
index 392abfd58..9b18053da 100644
--- a/tango/integrations/gs/step_cache.py
+++ b/tango/integrations/gs/step_cache.py
@@ -11,6 +11,7 @@
     GSArtifactNotFound,
     GSArtifactWriteError,
     GSClient,
+    get_bucket_and_prefix,
 )
 from tango.step import Step
 from tango.step_cache import StepCache
@@ -32,24 +33,23 @@ class GSStepCache(RemoteStepCache):
     .. tip::
         Registered as a :class:`~tango.step_cache.StepCache` under the name "gs".
 
-    :param bucket_name: The name of the google cloud bucket to use.
+    :param folder_name: The name of the Google Cloud bucket folder to use.
     :param client: The google cloud storage client to use.
     """
 
     Constants = Constants
 
-    def __init__(self, bucket_name: str, client: Optional[GSClient] = None):
+    def __init__(self, folder_name: str, client: Optional[GSClient] = None):
         if client is not None:
+            bucket_name, _ = get_bucket_and_prefix(folder_name)
             assert (
                 bucket_name == client.bucket_name
             ), "Assert that bucket name is same as client bucket until we do better"
-            self.bucket_name = bucket_name
+            self.folder_name = folder_name
             self._client = client
         else:
-            self._client = GSClient(bucket_name)
-        super().__init__(
-            tango_cache_dir() / "gs_cache" / make_safe_filename(self._client.bucket_name)
-        )
+            self._client = GSClient(folder_name)
+        super().__init__(tango_cache_dir() / "gs_cache" / make_safe_filename(folder_name))
 
     @property
     def client(self):
diff --git a/tango/integrations/gs/workspace.py b/tango/integrations/gs/workspace.py
index 9ad211905..cd4ef31e6 100644
--- a/tango/integrations/gs/workspace.py
+++ b/tango/integrations/gs/workspace.py
@@ -12,6 +12,7 @@
 from tango.integrations.gs.common import (
     Constants,
     GCSStepLock,
+    get_bucket_and_prefix,
     get_client,
     get_credentials,
 )
@@ -32,7 +33,7 @@ class GSWorkspace(RemoteWorkspace):
     .. tip::
         Registered as a :class:`~tango.workspace.Workspace` under the name "gs".
 
-    :param workspace: The name or ID of the Google Cloud bucket to use.
+    :param workspace: The name or ID of the Google Cloud bucket folder to use.
     :param project: The Google project ID. This is required for the datastore. If not provided,
         it will be inferred from the Google cloud credentials.
@@ -61,7 +62,7 @@ def __init__(
         self,
         workspace: str,
         project: Optional[str] = None,
         credentials: Optional[Union[str, Credentials]] = None,
     ):
-        self.client = get_client(bucket_name=workspace, credentials=credentials, project=project)
+        self.client = get_client(folder_name=workspace, credentials=credentials, project=project)
         self.client.NUM_CONCURRENT_WORKERS = self.NUM_CONCURRENT_WORKERS
         self._cache = GSStepCache(workspace, client=self.client)
@@ -71,7 +72,11 @@ def __init__(
             credentials = get_credentials()
         project = project or credentials.quota_project_id
-        self._ds = datastore.Client(namespace=workspace, project=project, credentials=credentials)
+
+        self.bucket_name, self.prefix = get_bucket_and_prefix(workspace)
+        self._ds = datastore.Client(
+            namespace=self.bucket_name, project=project, credentials=credentials
+        )
 
     @property
     def cache(self):
@@ -108,19 +113,29 @@ def _remote_lock(self, step: Step) -> GCSStepLock:
     def _step_location(self, step: Step) -> str:
         return self.client.url(self.Constants.step_artifact_name(step))
 
+    @property
+    def _run_key(self):
+        return self.client._gs_path("run")
+
+    @property
+    def _stepinfo_key(self):
+        return self.client._gs_path("stepinfo")
+
     def _save_run(
         self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None
     ) -> Run:
         if name is None:
             while True:
                 name = petname.generate() + str(random.randint(0, 100))
-                if not self._ds.get(self._ds.key("run", name)):
+                if not self._ds.get(self._ds.key(self._run_key, name)):
                     break
         else:
-            if self._ds.get(self._ds.key("run", name)):
+            if self._ds.get(self._ds.key(self._run_key, name)):
                 raise ValueError(f"Run name '{name}' is already in use")
 
-        run_entity = self._ds.entity(key=self._ds.key("run", name), exclude_from_indexes=("steps",))
+        run_entity = self._ds.entity(
+            key=self._ds.key(self._run_key, name), exclude_from_indexes=("steps",)
+        )
         # Even though the run's name is part of the key, we add this as a
         # field so we can index on it and order asc/desc (indices on the key field don't allow ordering).
         run_entity["name"] = name
@@ -164,7 +179,7 @@ def registered_runs(self) -> Dict[str, Run]:
                 thread_name_prefix="GSWorkspace.registered_runs()-",
             ) as executor:
                 run_futures = []
-                for run_entity in self._ds.query(kind="run").fetch():
+                for run_entity in self._ds.query(kind=self._run_key).fetch():
                     run_futures.append(executor.submit(self._get_run_from_entity, run_entity))
                 for future in concurrent.futures.as_completed(run_futures):
                     run = future.result()
@@ -225,7 +240,7 @@ def _fetch_run_entities(
 
         if sort_field is not None and not sort_locally:
             order = [sort_field if not sort_descending else f"-{sort_field}"]
-        query = self._ds.query(kind="run", order=order)
+        query = self._ds.query(kind=self._run_key, order=order)
         if match:
             # HACK: Datastore has no direct string matching functionality,
             # but this comparison is equivalent to checking if 'name' starts with 'match'.
@@ -309,7 +324,7 @@ def _fetch_step_info_entities(
 
         if sort_field is not None and not sort_locally:
             order = [sort_field if not sort_descending else f"-{sort_field}"]
-        query = self._ds.query(kind="stepinfo", order=order)
+        query = self._ds.query(kind=self._stepinfo_key, order=order)
 
         if match is not None:
             # HACK: Datastore has no direct string matching functionality,
@@ -341,7 +356,7 @@ def _fetch_step_info_entities(
 
     def registered_run(self, name: str) -> Run:
         err_msg = f"Run '{name}' not found in workspace"
-        run_entity = self._ds.get(key=self._ds.key("run", name))
+        run_entity = self._ds.get(key=self._ds.key(self._run_key, name))
         if not run_entity:
             raise KeyError(err_msg)
 
@@ -355,7 +370,7 @@ def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
         unique_id = (
             step_or_unique_id if isinstance(step_or_unique_id, str) else step_or_unique_id.unique_id
         )
-        step_info_entity = self._ds.get(key=self._ds.key("stepinfo", unique_id))
+        step_info_entity = self._ds.get(key=self._ds.key(self._stepinfo_key, unique_id))
         if step_info_entity is not None:
             step_info_bytes = step_info_entity["step_info_dict"]
             step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))
@@ -369,7 +384,7 @@ def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
 
     def _update_step_info(self, step_info: StepInfo):
         step_info_entity = self._ds.entity(
-            key=self._ds.key("stepinfo", step_info.unique_id),
+            key=self._ds.key(self._stepinfo_key, step_info.unique_id),
             exclude_from_indexes=("step_info_dict",),
         )
diff --git a/tests/integrations/gs/step_cache_test.py b/tests/integrations/gs/step_cache_test.py
index ddceeb59c..4dfe100b2 100644
--- a/tests/integrations/gs/step_cache_test.py
+++ b/tests/integrations/gs/step_cache_test.py
@@ -1,23 +1,28 @@
 import os
 
+import pytest
+
 from tango.common.testing import TangoTestCase
 from tango.common.testing.steps import FloatStep
-from tango.integrations.gs.common import empty_bucket
+from tango.integrations.gs.common import empty_bucket_folder
 from tango.integrations.gs.step_cache import GSStepCache
 
 GS_BUCKET_NAME = os.environ.get("GS_BUCKET_NAME", "allennlp-tango-bucket")
+GS_SUBFOLDER = f"{GS_BUCKET_NAME}/my-workspaces/workspace1"
 
 
 class TestGSStepCache(TangoTestCase):
     def setup_method(self):
         super().setup_method()
-        empty_bucket(GS_BUCKET_NAME)
+        empty_bucket_folder(GS_BUCKET_NAME)
+        empty_bucket_folder(GS_SUBFOLDER)
 
     def teardown_method(self):
         super().teardown_method()
 
-    def test_step_cache(self):
-        cache = GSStepCache(bucket_name=GS_BUCKET_NAME)
+    @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER])
+    def test_step_cache(self, gs_path):
+        cache = GSStepCache(folder_name=gs_path)
         step = FloatStep(result=1.0)
         cache[step] = 1.0
         assert step in cache
diff --git a/tests/integrations/gs/workspace_test.py b/tests/integrations/gs/workspace_test.py
index 28318db8e..d34f82e02 100644
--- a/tests/integrations/gs/workspace_test.py
+++ b/tests/integrations/gs/workspace_test.py
@@ -1,34 +1,42 @@
 import os
 
+import pytest
+
 from tango.common.testing import TangoTestCase
 from tango.common.testing.steps import FloatStep
-from tango.integrations.gs.common import empty_bucket, empty_datastore
+from tango.integrations.gs.common import empty_bucket_folder, empty_datastore
 from tango.integrations.gs.workspace import GSWorkspace
 from tango.step_info import StepState
 from tango.workspace import Workspace
 
 GS_BUCKET_NAME = os.environ.get("GS_BUCKET_NAME", "allennlp-tango-bucket")
+GS_SUBFOLDER = f"{GS_BUCKET_NAME}/my-workspaces/workspace1"
 
 
 class TestGSWorkspace(TangoTestCase):
def setup_method(self): super().setup_method() - empty_bucket(GS_BUCKET_NAME) + empty_bucket_folder(GS_BUCKET_NAME) + empty_bucket_folder(GS_SUBFOLDER) empty_datastore(GS_BUCKET_NAME) + empty_datastore(GS_SUBFOLDER) def teardown_method(self): super().teardown_method() - def test_from_url(self): - workspace = Workspace.from_url(f"gs://{GS_BUCKET_NAME}") + @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER]) + def test_from_url(self, gs_path: str): + workspace = Workspace.from_url(f"gs://{gs_path}") assert isinstance(workspace, GSWorkspace) - def test_from_params(self): - workspace = Workspace.from_params({"type": "gs", "workspace": GS_BUCKET_NAME}) + @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER]) + def test_from_params(self, gs_path: str): + workspace = Workspace.from_params({"type": "gs", "workspace": gs_path}) assert isinstance(workspace, GSWorkspace) - def test_direct_usage(self): - workspace = GSWorkspace(GS_BUCKET_NAME) + @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER]) + def test_direct_usage(self, gs_path: str): + workspace = GSWorkspace(gs_path) step = FloatStep(step_name="float", result=1.0) run = workspace.register_run([step])
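
Usage sketch (illustrative, not part of the patch). It exercises the helpers this patch adds to `tango.integrations.gs.common`; `my-bucket` and the `my-workspaces/workspace1` subfolder are hypothetical names, and the last two lines additionally assume valid Google Cloud credentials and an existing bucket.

    from tango.integrations.gs.common import get_bucket_and_prefix, join_path
    from tango.integrations.gs.workspace import GSWorkspace
    from tango.workspace import Workspace

    # A folder name is split on the first "/": the part before it is the bucket,
    # the rest is the subfolder prefix (empty when there is no subfolder).
    assert get_bucket_and_prefix("my-bucket/my-workspaces/workspace1") == (
        "my-bucket",
        "my-workspaces/workspace1",
    )
    assert get_bucket_and_prefix("my-bucket") == ("my-bucket", "")

    # join_path joins components with "/" and strips stray leading/trailing
    # slashes, so an empty prefix collapses cleanly instead of yielding "/run".
    assert join_path("", "run") == "run"
    assert join_path("my-workspaces/workspace1", "run") == "my-workspaces/workspace1/run"

    # Both forms below require real Google Cloud credentials and a real bucket.
    workspace = GSWorkspace("my-bucket/my-workspaces/workspace1")
    workspace = Workspace.from_url("gs://my-bucket/my-workspaces/workspace1")

The datastore side mirrors this split: the namespace stays the bucket name while the entity kinds become `<prefix>/run` and `<prefix>/stepinfo`, which is what keeps runs from different workspace subfolders of one bucket from colliding.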