Skip to content

Commit

Permalink
Merge pull request #356 from backend-developers-ltd/s3-diag
Browse files Browse the repository at this point in the history
Add s3 diagnostic management command
  • Loading branch information
emnoor-reef authored Jan 8, 2025
2 parents a989c24 + 8344d41 commit baf555f
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 8 deletions.
32 changes: 31 additions & 1 deletion validator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ export AWS_DEFAULT_REGION=BUCKETS_REGION

At the end of the script, it will show the values for `S3_BUCKET_NAME_PROMPTS`, `S3_BUCKET_NAME_ANSWERS`.
If you use the `--create-user` flag, it will also show the values for `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`.
opy these variables in your validator `.env` file and restart your validator.
Copy these variables in your validator `.env` file and restart your validator.

> [!WARNING]
> Even if you did not use `--create-user`, you still need to provide `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` in your validator `.env` file.
Expand All @@ -114,6 +114,36 @@ opy these variables in your validator `.env` file and restart your validator.
> [!NOTE]
> We have tested the AWS S3. The buckets allow quick and concurrent upload and download of multiple (but tiny) text files.
## Verifying S3 bucket configurations

After setting up your `.env` file by following the above section,
you can verify your S3 setup with the following command:

```sh
docker compose exec validator-runner docker compose exec app python manage.py check_s3_setup
```

If for some reason your S3 setup does not work,
you can try different values for the S3 configuration settings with this command's flags:

```
--aws-access-key-id AWS_ACCESS_KEY_ID
Override the value of AWS_ACCESS_KEY_ID from .env
--aws-secret-access-key AWS_SECRET_ACCESS_KEY
Override the value of AWS_SECRET_ACCESS_KEY from .env
--aws-region-name AWS_REGION_NAME
Override the value of AWS region (by default read from environment variable AWS_DEFAULT_REGION)
--aws-endpoint-url AWS_ENDPOINT_URL
Override the value of AWS_ENDPOINT_URL from .env
--s3-bucket-name-prompts S3_BUCKET_NAME_PROMPTS
Override the value of S3_BUCKET_NAME_PROMPTS from .env
--s3-bucket-name-answers S3_BUCKET_NAME_ANSWERS
Override the value of S3_BUCKET_NAME_ANSWERS from .env
```

After finding the configuration values for which the script works,
copy the values in your validator `.env` file and restart your validator.

## Updated validator .env

Add or update these variables in the validator `.env` file:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import logging
import os
import re
import secrets
import uuid

import requests
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
from django.conf import settings
from django.core.management.base import BaseCommand
from requests import HTTPError

from compute_horde_validator.validator import s3

logger = logging.getLogger(__name__)
# Matches the region-mismatch text S3 returns when requests are routed to the
# wrong region; group 1 is the (wrong) configured region, group 2 the region
# the bucket actually lives in.  NOTE(review): assumes this exact S3 error
# wording — confirm against a live mismatch response.
REGION_ERROR = re.compile(
    "the region '([a-z]+-[a-z]+-[0-9]+)' is wrong; expecting '([a-z]+-[a-z]+-[0-9]+)'"
)


def _print_region_env_warnings(writer):
msg = ""
if aws_region := os.getenv("AWS_REGION"):
msg += f"You have environment variable AWS_REGION={aws_region}\n"
if aws_default_region := os.getenv("AWS_DEFAULT_REGION"):
msg += f"You have environment variable AWS_DEFAULT_REGION={aws_default_region}\n"

if msg:
msg = "Possible issues with environment variables:\n" + msg

writer.write(msg)


class Command(BaseCommand):
    """Diagnose the validator's S3 configuration.

    For each of the prompts and answers buckets, performs a full round trip:
    generate a pre-signed upload URL, upload a random payload, then download
    it back through the bucket's public URL and compare the contents.  Each
    setting can be overridden on the command line for experimentation.
    """

    def add_arguments(self, parser):
        parser.add_argument(
            "--aws-access-key-id",
            help="Override the value of AWS_ACCESS_KEY_ID from .env",
        )
        parser.add_argument(
            "--aws-secret-access-key",
            help="Override the value of AWS_SECRET_ACCESS_KEY from .env",
        )
        parser.add_argument(
            "--aws-region-name",
            help="Override the value of AWS region (by default read from environment variable AWS_DEFAULT_REGION)",
        )
        parser.add_argument(
            "--aws-endpoint-url",
            help="Override the value of AWS_ENDPOINT_URL from .env",
        )
        parser.add_argument(
            "--s3-bucket-name-prompts",
            help="Override the value of S3_BUCKET_NAME_PROMPTS from .env",
        )
        parser.add_argument(
            "--s3-bucket-name-answers",
            help="Override the value of S3_BUCKET_NAME_ANSWERS from .env",
        )

    def handle(self, *args, **options):
        # Keep the diagnostic output readable: suppress INFO/DEBUG noise
        # from boto3/botocore and friends.
        logging.basicConfig(level="ERROR")

        s3_client = s3.get_s3_client(
            aws_access_key_id=options["aws_access_key_id"],
            aws_secret_access_key=options["aws_secret_access_key"],
            region_name=options["aws_region_name"],
            endpoint_url=options["aws_endpoint_url"],
        )

        # Command-line flags take precedence over settings (.env).
        prompts_bucket = options["s3_bucket_name_prompts"] or settings.S3_BUCKET_NAME_PROMPTS
        answers_bucket = options["s3_bucket_name_answers"] or settings.S3_BUCKET_NAME_ANSWERS

        for bucket_name in (prompts_bucket, answers_bucket):
            if not self._check_bucket(s3_client, bucket_name):
                return

        self.stdout.write("\n\n🎉 Your S3 configuration works! 🎉")

    def _check_bucket(self, s3_client, bucket_name) -> bool:
        """Run one upload/download round trip against *bucket_name*.

        Writes a human-readable diagnosis to stderr on failure.
        Returns True on success, False on any failure.
        """
        test_file = f"diagnostic-{uuid.uuid4()}"
        file_contents = secrets.token_hex()

        # Step 1: generate a pre-signed upload URL (fails fast when no
        # credentials are configured at all).
        try:
            upload_url = s3.generate_upload_url(
                test_file, bucket_name=bucket_name, s3_client=s3_client
            )
        except (NoCredentialsError, PartialCredentialsError):
            self.stderr.write(
                "You did not provide credentials.\n"
                "Please provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY in .env file.\n"
                "For testing, you can also provide your credentials to this command with --aws-access-key-id and --aws-secret-access-key flags.\n",
            )
            return False

        # Step 2: upload the payload to the pre-signed URL.
        try:
            s3.upload_prompts_to_s3_url(upload_url, file_contents)
        except HTTPError as exc:
            if exc.response.status_code == 403:
                self.stderr.write(
                    f"Your configured credentials do not have permissions to write to this bucket: {bucket_name}\n"
                )
            elif match := REGION_ERROR.search(exc.response.text):
                wrong, expecting = match.groups()
                self.stderr.write(
                    "Your bucket requests are being routed to an invalid AWS region.\n"
                )
                self.stderr.write(
                    f"Your bucket {bucket_name!r} is in region {expecting!r}, but it is configured as {wrong!r}\n"
                )
                _print_region_env_warnings(self.stderr)
            else:
                # Separate the raw response text from the advice with a
                # newline so the two don't run together on one line.
                self.stderr.write(
                    f"Failed to write to the bucket {bucket_name} with the following error:\n"
                    + str(exc)
                    + "\n"
                    + exc.response.text
                    + "\n"
                    + "Please check if the buckets you have configured exist and are accessible with the configured credentials.\n"
                )
            return False

        # Step 3: download the payload back via the bucket's public URL.
        download_url = s3.get_public_url(
            test_file, bucket_name=bucket_name, s3_client=s3_client
        )
        response = requests.get(download_url)

        # A 301 from S3 indicates the request hit the wrong regional endpoint;
        # diagnose it before the generic error handling below.
        if response.status_code == 301:
            bucket_region = response.headers.get("x-amz-bucket-region")
            self.stderr.write(
                "Your bucket requests are being routed to an invalid AWS region.\n"
            )
            if bucket_region:
                self.stderr.write(
                    f"Your bucket {bucket_name!r} is in region {bucket_region!r}\n"
                )
            _print_region_env_warnings(self.stderr)
            return False

        # Report download failures instead of raising an unhandled exception.
        if not response.ok:
            self.stderr.write(
                f"Failed to read back from the bucket {bucket_name}: HTTP {response.status_code}\n"
                + response.text
                + "\n"
            )
            return False

        # Explicit check instead of `assert`: asserts are stripped when
        # Python runs with -O, and this is the whole point of the diagnostic.
        if file_contents != response.text.strip():
            self.stderr.write(
                f"The contents downloaded from the bucket {bucket_name} do not match what was uploaded.\n"
            )
            return False

        return True
32 changes: 25 additions & 7 deletions validator/app/src/compute_horde_validator/validator/s3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import logging
from typing import Any

import boto3
import httpx
Expand All @@ -10,12 +11,25 @@
logger = logging.getLogger(__name__)


def get_s3_client():
def get_s3_client(
aws_access_key_id=None,
aws_secret_access_key=None,
region_name=None,
endpoint_url=None,
):
if aws_access_key_id is None:
aws_access_key_id = settings.AWS_ACCESS_KEY_ID
if aws_secret_access_key is None:
aws_secret_access_key = settings.AWS_SECRET_ACCESS_KEY
if endpoint_url is None:
endpoint_url = settings.AWS_ENDPOINT_URL

return boto3.client(
"s3",
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
endpoint_url=settings.AWS_ENDPOINT_URL,
region_name=region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
endpoint_url=endpoint_url,
)


Expand All @@ -26,8 +40,10 @@ def _generate_presigned_url(
bucket_name: str,
prefix: str = "",
expiration: int = 3600,
s3_client: Any = None,
) -> str:
s3_client = get_s3_client()
if s3_client is None:
s3_client = get_s3_client()

return s3_client.generate_presigned_url( # type: ignore
method,
Expand All @@ -40,8 +56,10 @@ def _generate_presigned_url(
generate_download_url = functools.partial(_generate_presigned_url, "get_object")


def get_public_url(key: str, *, bucket_name: str, prefix: str = "") -> str:
endpoint_url = settings.AWS_ENDPOINT_URL or "https://s3.amazonaws.com"
def get_public_url(key: str, *, bucket_name: str, prefix: str = "", s3_client: Any = None) -> str:
if s3_client is None:
s3_client = get_s3_client()
endpoint_url = s3_client.meta.endpoint_url
return f"{endpoint_url}/{bucket_name}/{prefix}{key}"


Expand Down

0 comments on commit baf555f

Please sign in to comment.