Skip to content

Commit

Permalink
Add support for node tags
Browse files Browse the repository at this point in the history
  • Loading branch information
jschlyter committed Jan 10, 2025
1 parent 4b930cb commit 83c9bcd
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 9 deletions.
8 changes: 7 additions & 1 deletion nodeman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,13 @@ def command_create(args: argparse.Namespace) -> NodeBootstrapInformation:

client = get_admin_client(args)

payload = {
**({"name": args.name} if args.name else {}),
**({"tags": args.tags.split(",")} if hasattr(args, "tags") and args.tags else {}),
}

try:
response = client.post(urljoin(args.server, "/api/v1/node"))
response = client.post(urljoin(args.server, "/api/v1/node"), json=payload)
response.raise_for_status()
except httpx.HTTPError as exc:
logging.error("Failed to create node: %s", str(exc))
Expand Down Expand Up @@ -314,6 +319,7 @@ def main() -> None:
admin_create_parser.set_defaults(func=command_create)
add_admin_arguments(admin_create_parser)
admin_create_parser.add_argument("--name", metavar="name", help="Node name")
admin_create_parser.add_argument("--tags", metavar="tags", help="Node tags")

admin_get_parser = subparsers.add_parser("get", help="Get node")
admin_get_parser.set_defaults(func=command_get)
Expand Down
4 changes: 3 additions & 1 deletion nodeman/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cryptography import x509
from cryptography.hazmat.primitives import serialization
from cryptography.x509.oid import ExtensionOID
from mongoengine import DateTimeField, DictField, Document, StringField, ValidationError
from mongoengine import DateTimeField, DictField, Document, ListField, StringField, ValidationError
from mongoengine.errors import NotUniqueError

from .names import get_deterministic_name, get_random_name
Expand All @@ -24,6 +24,8 @@ class TapirNode(Document):
activated = DateTimeField()
deleted = DateTimeField()

tags = ListField(StringField())

@classmethod
def create_random_node(cls, domain: str) -> Self:
name = get_random_name() + "." + domain
Expand Down
7 changes: 7 additions & 0 deletions nodeman/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def from_accept(cls, accept: str | None) -> Self:
raise ValueError(f"Unsupported format. Acceptable formats: {[f.value for f in cls]} or */*")


class NodeCreateRequest(BaseModel):
name: str | None = Field(title="Node name", default=None)
tags: list[str] | None = Field(title="Node tags", default=None)


class NodeRequest(BaseModel):
timestamp: AwareDatetime = Field(title="Timestamp")
x509_csr: str = Field(title="X.509 Client Certificate Bundle")
Expand All @@ -49,13 +54,15 @@ class NodeInformation(BaseModel):
name: str = Field(title="Node name")
public_key: PublicJwk | None = Field(title="Public key")
activated: AwareDatetime | None = Field(title="Activated")
tags: list[str] | None = Field(title="Node tags")

@classmethod
def from_db_model(cls, node: TapirNode):
return cls(
name=node.name,
public_key=public_key_factory(node.public_key) if node.public_key else None,
activated=node.activated,
tags=node.tags,
)


Expand Down
20 changes: 17 additions & 3 deletions nodeman/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
NodeCertificate,
NodeCollection,
NodeConfiguration,
NodeCreateRequest,
NodeEnrollmentResult,
NodeInformation,
PublicKeyFormat,
Expand Down Expand Up @@ -128,22 +129,35 @@ def process_csr_request(request: Request, csr: x509.CertificateSigningRequest, n
response_model_exclude_none=True,
)
async def create_node(
request: Request, username: Annotated[str, Depends(get_current_username)], name: str | None = None
username: Annotated[str, Depends(get_current_username)],
request: Request,
create_request: NodeCreateRequest | None = None,
) -> NodeBootstrapInformation:
"""Create node"""

node_enrollment_key = JWK.generate(kty="oct", size=256, alg="HS256")
if create_request:
name = create_request.name
tags = create_request.tags
else:
name = None
tags = None

domain = request.app.settings.nodes.domain

node_enrollment_key = JWK.generate(kty="oct", size=256, alg="HS256")

if name is None:
node = TapirNode.create_next_node(domain=request.app.settings.nodes.domain)
node = TapirNode.create_next_node(domain=domain)
elif name.endswith(f".{domain}"):
logging.debug("Explicit node name %s requested", name, extra={"nodename": name})
node = TapirNode(name=name, domain=domain).save()
else:
logging.warning("Explicit node name %s not acceptable", name, extra={"nodename": name})
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail="Invalid node name")

node.tags = tags
node.save()

TapirNodeEnrollment(
name=node.name,
key=node_enrollment_key.export(as_dict=True, private_key=node_enrollment_key.kty == "oct"),
Expand Down
10 changes: 6 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ def _test_enroll(data_key: JWK, x509_key: PrivateKey, requested_name: str | None
logging.basicConfig(level=logging.DEBUG)
logging.debug("Testing enrollment")

tags = ["test", str(uuid.uuid4())]

#############
# Create node

response = admin_client.post(
urljoin(server, "/api/v1/node"), params={"name": requested_name} if requested_name else None
)
node_create_request = {**({"name": requested_name} if requested_name else {}), "tags": tags}
response = admin_client.post(urljoin(server, "/api/v1/node"), json=node_create_request)
if response.status_code != status.HTTP_201_CREATED:
raise FailedToCreateNode
assert response.status_code == status.HTTP_201_CREATED
Expand All @@ -93,6 +94,7 @@ def _test_enroll(data_key: JWK, x509_key: PrivateKey, requested_name: str | None
node_information = response.json()
assert node_information["name"] == name
assert node_information["activated"] is None
assert "test" in node_information["tags"]

#####################
# Enroll created node
Expand Down Expand Up @@ -333,7 +335,7 @@ def test_enroll_bad_data_signature() -> None:

logging.basicConfig(level=logging.DEBUG)

response = admin_client.post(urljoin(server, "/api/v1/node"))
response = admin_client.post(urljoin(server, "/api/v1/node"), json={})
assert response.status_code == status.HTTP_201_CREATED
create_response = response.json()
name = create_response["name"]
Expand Down

0 comments on commit 83c9bcd

Please sign in to comment.