Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Adds porting of network configuration to generated base job templates #392

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions prefect_aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@
from prefect.utilities.pydantic import JsonPatch
from pydantic import VERSION as PYDANTIC_VERSION

from prefect_aws.utilities import assemble_document_for_patches

if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import Field, root_validator, validator
else:
Expand Down Expand Up @@ -739,6 +741,23 @@ async def generate_work_pool_base_job_template(self) -> dict:
)

if self.task_customizations:
network_config_patches = JsonPatch(
[
patch
for patch in self.task_customizations
if "networkConfiguration" in patch["path"]
]
)
minimal_network_config = assemble_document_for_patches(
network_config_patches
)
if minimal_network_config:
minimal_network_config_with_patches = network_config_patches.apply(
minimal_network_config
)
base_job_template["variables"]["properties"]["network_configuration"][
"default"
] = minimal_network_config_with_patches["networkConfiguration"]
try:
base_job_template["job_configuration"]["task_run_request"] = (
self.task_customizations.apply(
Expand Down
81 changes: 81 additions & 0 deletions prefect_aws/utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for working with AWS services."""

from typing import Dict, List, Union

from prefect.utilities.collections import visit_collection


Expand Down Expand Up @@ -33,3 +35,82 @@ def make_hashable(item):
collection, visit_fn=make_hashable, return_data=True
)
return hash(hashable_collection)


def ensure_path_exists(doc: Union[Dict, List], path: List[str]):
"""
Ensures the path exists in the document, creating empty dictionaries or lists as
needed.

Args:
doc: The current level of the document or sub-document.
path: The remaining path parts to ensure exist.
"""
if not path:
return
current_path = path.pop(0)
# Check if the next path part exists and is a digit
next_path_is_digit = path and path[0].isdigit()

# Determine if the current path is for an array or an object
if isinstance(doc, list): # Path is for an array index
current_path = int(current_path)
# Ensure the current level of the document is a list and long enough

while len(doc) <= current_path:
doc.append({})
next_level = doc[current_path]
else: # Path is for an object
if current_path not in doc or (
next_path_is_digit and not isinstance(doc.get(current_path), list)
):
doc[current_path] = [] if next_path_is_digit else {}
next_level = doc[current_path]

ensure_path_exists(next_level, path)


def assemble_document_for_patches(patches):
"""
Assembles an initial document that can successfully accept the given JSON Patch
operations.

Args:
patches: A list of JSON Patch operations.

Returns:
An initial document structured to accept the patches.

Example:

```python
patches = [
{"op": "replace", "path": "/name", "value": "Jane"},
{"op": "add", "path": "/contact/address", "value": "123 Main St"},
{"op": "remove", "path": "/age"}
]

initial_document = assemble_document_for_patches(patches)

#output
{
"name": {},
"contact": {},
"age": {}
}
```
"""
document = {}

for patch in patches:
operation = patch["op"]
path = patch["path"].lstrip("/").split("/")

if operation == "add":
# Ensure all but the last element of the path exists
ensure_path_exists(document, path[:-1])
elif operation in ["remove", "replace"]:
# For remove adn replace, the entire path should exist
ensure_path_exists(document, path)

return document
28 changes: 20 additions & 8 deletions tests/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,6 +2128,15 @@ def base_job_template_with_defaults(default_base_job_template, aws_credentials):
base_job_template_with_defaults["variables"]["properties"][
"auto_deregister_task_definition"
]["default"] = False
base_job_template_with_defaults["variables"]["properties"]["network_configuration"][
"default"
] = {
"awsvpcConfiguration": {
"subnets": ["subnet-***"],
"assignPublicIp": "DISABLED",
"securityGroups": ["sg-***"],
}
}
return base_job_template_with_defaults


Expand Down Expand Up @@ -2188,10 +2197,20 @@ async def test_generate_work_pool_base_job_template(
cpu=2048,
memory=4096,
task_customizations=[
{
"op": "replace",
"path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp",
"value": "DISABLED",
},
{
"op": "add",
"path": "/networkConfiguration/awsvpcConfiguration/subnets",
"value": ["subnet-***"],
},
{
"op": "add",
"path": "/networkConfiguration/awsvpcConfiguration/securityGroups",
"value": ["sg-d72e9599956a084f5"],
"value": ["sg-***"],
},
],
family="test-family",
Expand Down Expand Up @@ -2229,10 +2248,3 @@ async def test_generate_work_pool_base_job_template(
template = await job.generate_work_pool_base_job_template()

assert template == expected_template

if job_config == "custom":
assert (
"Unable to apply task customizations to the base job template."
"You may need to update the template manually."
in caplog.text
)
59 changes: 58 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest

from prefect_aws.utilities import hash_collection
from prefect_aws.utilities import (
assemble_document_for_patches,
ensure_path_exists,
hash_collection,
)


class TestHashCollection:
Expand Down Expand Up @@ -32,3 +36,56 @@ def test_unhashable_structure(self):
assert hash_collection(typically_unhashable_structure) == hash_collection(
typically_unhashable_structure
), "Unhashable structure hashing failed after transformation"


class TestAssembleDocumentForPatches:
def test_initial_document(self):
patches = [
{"op": "replace", "path": "/name", "value": "Jane"},
{"op": "add", "path": "/contact/address", "value": "123 Main St"},
{"op": "remove", "path": "/age"},
]

initial_document = assemble_document_for_patches(patches)

expected_document = {"name": {}, "contact": {}, "age": {}}

assert initial_document == expected_document, "Initial document assembly failed"


class TestEnsurePathExists:
def test_existing_path(self):
doc = {"key1": {"subkey1": "value1"}}
path = ["key1", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {
"key1": {"subkey1": "value1"}
}, "Existing path modification failed"

def test_new_path_object(self):
doc = {}
path = ["key1", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {"key1": {"subkey1": {}}}, "New path creation for object failed"

def test_new_path_array(self):
doc = {}
path = ["key1", "0"]
ensure_path_exists(doc, path)
assert doc == {"key1": [{}]}, "New path creation for array failed"

def test_existing_path_array(self):
doc = {"key1": [{"subkey1": "value1"}]}
path = ["key1", "0", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {
"key1": [{"subkey1": "value1"}]
}, "Existing path modification for array failed"

def test_existing_path_array_index_out_of_range(self):
doc = {"key1": []}
path = ["key1", "0", "subkey1"]
ensure_path_exists(doc, path)
assert doc == {
"key1": [{"subkey1": {}}]
}, "Existing path modification for array index out of range failed"
Loading