GS workspaces can be bucket subfolders (#604)
AkshitaB authored Oct 4, 2023
1 parent b955ef7 commit 3857415
Showing 6 changed files with 135 additions and 68 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## Unreleased
 
+### Added
+
+- The `GSWorkspace()` can now be initialized with google cloud bucket subfolders.
+
 ### Fixed
 
 - Removed unnecessary code coverage dev requirements.
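For illustration, here is a minimal sketch of what the new changelog entry enables. The bucket and subfolder names are hypothetical, and default gcloud credentials are assumed:

from tango.integrations.gs import GSWorkspace

# Previously only a bare bucket name worked; with this change a
# "bucket/subfolder" path is also accepted (names below are made up).
workspace = GSWorkspace("my-tango-bucket/team-a/experiments")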
109 changes: 72 additions & 37 deletions tango/integrations/gs/common.py
@@ -9,7 +9,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import google.auth
 from google.api_core import exceptions
@@ -27,33 +27,50 @@
 logger = logging.getLogger(__name__)
 
 
-def empty_bucket(bucket_name: str):
+def get_bucket_and_prefix(folder_name: str) -> Tuple[str, str]:
     """
-    Removes all the tango-related blobs from the specified bucket.
+    Split bucket name and subfolder name, if present.
+    """
+    split = folder_name.split("/")
+    return split[0], "/".join(split[1:])
+
+
+def empty_bucket_folder(folder_name: str):
+    """
+    Removes all the tango-related blobs from the specified bucket folder.
     Used for testing.
     """
     credentials, project = google.auth.default()
     client = storage.Client(project=project, credentials=credentials)
+    bucket_name, prefix = get_bucket_and_prefix(folder_name)
+
+    prefix = prefix + "/tango-" if prefix else "tango-"
+
     bucket = client.bucket(bucket_name)
     try:
-        bucket.delete_blobs(list(bucket.list_blobs(prefix="tango-")))
+        bucket.delete_blobs(list(bucket.list_blobs(prefix=prefix)))
     except exceptions.NotFound:
         pass
 
 
-def empty_datastore(namespace: str):
+def empty_datastore(folder_name: str):
     """
-    Removes all the tango-related entities from the specified namespace in datastore.
+    Removes all the tango-related entities from the specified namespace subfolder in datastore.
     Used for testing.
     """
     from google.cloud import datastore
 
     credentials, project = google.auth.default()
+    namespace, prefix = get_bucket_and_prefix(folder_name)
+
+    run_kind = prefix + "/run" if prefix else "run"
+    stepinfo_kind = prefix + "/stepinfo" if prefix else "stepinfo"
+
     client = datastore.Client(project=project, credentials=credentials, namespace=namespace)
-    run_query = client.query(kind="run")
+    run_query = client.query(kind=run_kind)
     run_query.keys_only()
     keys = [entity.key for entity in run_query.fetch()]
-    stepinfo_query = client.query(kind="stepinfo")
+    stepinfo_query = client.query(kind=stepinfo_kind)
     stepinfo_query.keys_only()
     keys += [entity.key for entity in stepinfo_query.fetch()]
     client.delete_multi(keys)
@@ -108,12 +125,19 @@ class GSArtifactWriteError(TangoError):
     pass
 
 
+def join_path(*args) -> str:
+    """
+    We use this since we cannot use `os.path.join` for cloud storage paths.
+    """
+    return "/".join(args).strip("/")
+
+
 class GSClient:
     """
     A client for interacting with Google Cloud Storage. The authorization works by
     providing OAuth2 credentials.
-    :param bucket_name: The name of the Google Cloud bucket to use.
+    :param folder_name: The name of the Google Cloud bucket folder to use.
     :param credentials: OAuth2 credentials can be provided. If not provided, default
         gcloud credentials are inferred.
     :param project: Optionally, the project ID can be provided. This is not essential
@@ -123,7 +147,7 @@ class GSClient:

     placeholder_file = ".placeholder"
     """
-    The placeholder file is used for creation of a folder in the cloud bucket,
+    The placeholder file is used for creation of a folder in the cloud bucket folder,
     as empty folders are not allowed. It is also used as a marker for the creation
     time of the folder, hence we use a separate file to mark the artifact as
     uncommitted.
uncommitted.
Expand All @@ -143,17 +167,20 @@ class GSClient:

     def __init__(
         self,
-        bucket_name: str,
+        folder_name: str,
         credentials: Optional[Credentials] = None,
         project: Optional[str] = None,
     ):
         if not credentials:
             credentials, project = google.auth.default()
 
         self.storage = storage.Client(project=project, credentials=credentials)
-        self.bucket_name = bucket_name
+        self.folder_name = folder_name
 
+        self.bucket_name, self.prefix = get_bucket_and_prefix(folder_name)
+        settings_file = self._gs_path(self.settings_file)
+
-        blob = self.storage.bucket(bucket_name).blob(self.settings_file)  # no HTTP request yet
+        blob = self.storage.bucket(self.bucket_name).blob(settings_file)  # no HTTP request yet
         try:
             with blob.open("r") as file_ref:
                 json.load(file_ref)
@@ -166,13 +193,12 @@ def url(self, artifact: Optional[str] = None):
"""
Returns the remote url of the storage artifact.
"""
path = f"gs://{self.bucket_name}"
path = f"gs://{self.folder_name}"
if artifact is not None:
path = f"{path}/{artifact}"
return path

@classmethod
def _convert_blobs_to_artifact(cls, blobs: List[storage.Blob]) -> GSArtifact:
def _convert_blobs_to_artifact(self, blobs: List[storage.Blob]) -> GSArtifact:
"""
Converts a list of `google.cloud.storage.Blob` to a `GSArtifact`.
"""
@@ -182,22 +208,24 @@ def _convert_blobs_to_artifact(cls, blobs: List[storage.Blob]) -> GSArtifact:
         committed: bool = True
 
         for blob in blobs:
-            if blob.name.endswith(cls.placeholder_file):
+            if blob.name.endswith(self.placeholder_file):
                 created = blob.time_created
-                name = blob.name.replace("/" + cls.placeholder_file, "")
+                name = blob.name.replace("/" + self.placeholder_file, "")
+                if self.prefix:
+                    name = name.replace(self.prefix + "/", "")
                 artifact_path = name  # does not contain bucket info here.
-            elif blob.name.endswith(cls.uncommitted_file):
+            elif blob.name.endswith(self.uncommitted_file):
                 committed = False
 
         assert name is not None, "Folder is not a GSArtifact, should not have happened."
         return GSArtifact(name, artifact_path, created, committed)
 
     @classmethod
-    def from_env(cls, bucket_name: str):
+    def from_env(cls, folder_name: str):
         """
         Constructs the client object from the environment, using default credentials.
         """
-        return cls(bucket_name)
+        return cls(folder_name)
 
     def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:
         """
@@ -210,42 +238,46 @@ def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:
             # We have an artifact, and we recreate it with refreshed info.
             path = artifact.artifact_path
 
-        blobs = list(self.storage.bucket(self.bucket_name).list_blobs(prefix=path))
+        prefix = self._gs_path(path)
+        blobs = list(self.storage.bucket(self.bucket_name).list_blobs(prefix=prefix))
         if len(blobs) > 0:
             return self._convert_blobs_to_artifact(blobs)
         else:
             raise GSArtifactNotFound()
 
-    @classmethod
-    def _gs_path(cls, *args):
+    def _gs_path(self, *args):
         """
-        Returns path within google cloud storage bucket. We use this since we cannot
-        use `os.path.join` for cloud storage paths.
+        Returns path within google cloud storage bucket.
         """
-        return "/".join(args)
+        return join_path(self.prefix, *args)
 
     def create(self, artifact: str):
         """
         Creates a new artifact in the remote location. By default, it is uncommitted.
         """
         bucket = self.storage.bucket(self.bucket_name)
         # gives refreshed information
-        if bucket.blob(self._gs_path(artifact, self.placeholder_file)).exists():
+
+        artifact_path = self._gs_path(artifact, self.placeholder_file)
+        if bucket.blob(artifact_path).exists():
             raise GSArtifactConflict(f"{artifact} already exists!")
         else:
             # Additional safety check
             if bucket.blob(self._gs_path(artifact, self.uncommitted_file)).exists():
                 raise GSArtifactConflict(f"{artifact} already exists!")
             bucket.blob(self._gs_path(artifact, self.placeholder_file)).upload_from_string("")
             bucket.blob(self._gs_path(artifact, self.uncommitted_file)).upload_from_string("")
-        return self._convert_blobs_to_artifact(list(bucket.list_blobs(prefix=artifact)))
+        return self._convert_blobs_to_artifact(
+            list(bucket.list_blobs(prefix=self._gs_path(artifact)))
+        )
 
     def delete(self, artifact: GSArtifact):
         """
         Removes the artifact from the remote location.
         """
         bucket = self.storage.bucket(self.bucket_name)
-        blobs = list(bucket.list_blobs(prefix=artifact.artifact_path))
+        prefix = self._gs_path(artifact.artifact_path)
+        blobs = list(bucket.list_blobs(prefix=prefix))
         bucket.delete_blobs(blobs)
 
     def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
@@ -260,7 +292,7 @@ def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
         source_path = str(objects_dir)
 
         def _sync_blob(source_file_path: str, target_file_path: str):
-            blob = self.storage.bucket(self.bucket_name).blob(target_file_path)
+            blob = self.storage.bucket(self.bucket_name).blob(self._gs_path(target_file_path))
             blob.upload_from_filename(source_file_path)
 
         import concurrent.futures
@@ -277,7 +309,7 @@ def _sync_blob(source_file_path: str, target_file_path: str):
             for dirpath, _, filenames in os.walk(source_path):
                 for filename in filenames:
                     source_file_path = os.path.join(dirpath, filename)
-                    target_file_path = self._gs_path(
+                    target_file_path = join_path(
                         folder_path, source_file_path.replace(source_path + "/", "")
                     )
                     upload_futures.append(
@@ -328,14 +360,16 @@ def _fetch_blob(blob: storage.Blob):
         import concurrent.futures
 
         bucket = self.storage.bucket(self.bucket_name)
-        bucket.update()
+        # We may not need updates that frequently, with list_blobs(prefix).
+        # bucket.update()
 
         try:
             with concurrent.futures.ThreadPoolExecutor(
                 max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix="GSClient.download()-"
             ) as executor:
                 download_futures = []
-                for blob in bucket.list_blobs(prefix=artifact.artifact_path):
+                prefix = self._gs_path(artifact.artifact_path)
+                for blob in bucket.list_blobs(prefix=prefix):
                     download_futures.append(executor.submit(_fetch_blob, blob))
                 for future in concurrent.futures.as_completed(download_futures):
                     future.result()
@@ -348,6 +382,7 @@ def artifacts(self, prefix: str, uncommitted: bool = True) -> List[GSArtifact]:
         `match` and `uncommitted` criteria. These can include steps and runs.
         """
         list_of_artifacts = []
+        prefix = self._gs_path(prefix)
         for folder_name in self.storage.list_blobs(
             self.bucket_name, prefix=prefix, delimiter="/"
         )._get_next_page_response()["prefixes"]:
@@ -405,15 +440,15 @@ def get_credentials(credentials: Optional[Union[str, Credentials]] = None) -> Cr


 def get_client(
-    bucket_name: str,
+    folder_name: str,
     credentials: Optional[Union[str, Credentials]] = None,
     project: Optional[str] = None,
 ) -> GSClient:
     """
-    Returns a `GSClient` object for a google cloud bucket.
+    Returns a `GSClient` object for a google cloud bucket folder.
     """
     credentials = get_credentials(credentials)
-    return GSClient(bucket_name, credentials=credentials, project=project)
+    return GSClient(folder_name, credentials=credentials, project=project)
 
 
 class Constants(RemoteConstants):
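To make the new path handling concrete, here is a small self-contained sketch of the two helpers introduced in this file (bodies copied from the diff; the bucket and subfolder names in the assertions are made up). Plain "/" joins are used because GCS object names are not OS paths:

from typing import Tuple


def get_bucket_and_prefix(folder_name: str) -> Tuple[str, str]:
    # "bucket/sub/folder" -> ("bucket", "sub/folder"); a bare bucket name
    # yields an empty prefix.
    split = folder_name.split("/")
    return split[0], "/".join(split[1:])


def join_path(*args) -> str:
    # `os.path.join` would insert "\\" on Windows; blob names always use "/".
    return "/".join(args).strip("/")


assert get_bucket_and_prefix("my-bucket") == ("my-bucket", "")
assert get_bucket_and_prefix("my-bucket/team/project") == ("my-bucket", "team/project")

# An empty prefix vanishes after strip("/"), so GSClient._gs_path builds the
# same kind of blob path whether or not a subfolder was given:
assert join_path("", "tango-step-abc", ".placeholder") == "tango-step-abc/.placeholder"
assert join_path("team/project", "tango-step-abc") == "team/project/tango-step-abc"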
14 changes: 7 additions & 7 deletions tango/integrations/gs/step_cache.py
@@ -11,6 +11,7 @@
     GSArtifactNotFound,
     GSArtifactWriteError,
     GSClient,
+    get_bucket_and_prefix,
 )
 from tango.step import Step
 from tango.step_cache import StepCache
@@ -32,24 +33,23 @@ class GSStepCache(RemoteStepCache):
     .. tip::
         Registered as a :class:`~tango.step_cache.StepCache` under the name "gs".
-    :param bucket_name: The name of the google cloud bucket to use.
+    :param folder_name: The name of the google cloud bucket folder to use.
     :param client: The google cloud storage client to use.
     """
 
     Constants = Constants
 
-    def __init__(self, bucket_name: str, client: Optional[GSClient] = None):
+    def __init__(self, folder_name: str, client: Optional[GSClient] = None):
         if client is not None:
+            bucket_name, _ = get_bucket_and_prefix(folder_name)
             assert (
                 bucket_name == client.bucket_name
             ), "Assert that bucket name is same as client bucket until we do better"
-            self.bucket_name = bucket_name
+            self.folder_name = folder_name
             self._client = client
         else:
-            self._client = GSClient(bucket_name)
-        super().__init__(
-            tango_cache_dir() / "gs_cache" / make_safe_filename(self._client.bucket_name)
-        )
+            self._client = GSClient(folder_name)
+        super().__init__(tango_cache_dir() / "gs_cache" / make_safe_filename(folder_name))
 
     @property
     def client(self):
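A hedged usage sketch for the updated step cache (hypothetical names; default gcloud credentials assumed, and GSStepCache is assumed to be re-exported from tango.integrations.gs):

from tango.integrations.gs import GSStepCache

# Cache step results under a subfolder instead of the bucket root.
cache = GSStepCache("my-tango-bucket/experiments")
# get_bucket_and_prefix("my-tango-bucket/experiments") resolves this to
# bucket "my-tango-bucket" with prefix "experiments", so artifacts land
# under gs://my-tango-bucket/experiments/.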
(The diffs for the remaining 3 changed files, totaling 52 additions and 24 deletions, are not shown.)
