Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generically handle all entities in DocGen entity expansion. #130

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions aws_doc_sdk_examples_tools/doc_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json

from collections import defaultdict
from dataclasses import dataclass, field, is_dataclass, asdict
from dataclasses import dataclass, field, fields, is_dataclass, asdict
from functools import reduce
from pathlib import Path
from typing import Dict, Iterable, Optional, Set, Tuple, List, Any
Expand Down Expand Up @@ -90,6 +90,24 @@ def languages(self) -> Set[str]:
def expand_entities(self, text: str) -> Tuple[str, EntityErrors]:
return expand_all_entities(text, self.entities)

def expand_entity_fields(self, obj: object):
if isinstance(obj, list):
for o in obj:
self.expand_entity_fields(o)
if isinstance(obj, dict):
for val in obj.values():
self.expand_entity_fields(val)
if is_dataclass(obj) and not isinstance(obj, type):
for f in fields(obj):
val = getattr(obj, f.name)
if isinstance(val, str):
[expanded, errs] = self.expand_entities(val)
if errs:
self.errors.extend(errs)
else:
setattr(obj, f.name, expanded)
self.expand_entity_fields(val)

def merge(self, other: "DocGen") -> MetadataErrors:
"""Merge fields from other into self, prioritizing self fields."""
warnings = MetadataErrors()
Expand Down Expand Up @@ -332,7 +350,7 @@ def count_genai(d: Dict[str, int], e: Example):
# and arguably not useful either.
class DocGenEncoder(json.JSONEncoder):
def default(self, obj):
if is_dataclass(obj):
if is_dataclass(obj) and not isinstance(obj, type):
return asdict(obj)

if isinstance(obj, Path):
Expand Down
38 changes: 4 additions & 34 deletions aws_doc_sdk_examples_tools/doc_gen_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,43 +48,13 @@ def main():
unmerged_doc_gen = DocGen.from_root(Path(root))
merged_doc_gen.merge(unmerged_doc_gen)

if args.strict and merged_doc_gen.errors:
logging.error("Errors found in metadata: %s", merged_doc_gen.errors)
exit(1)

if not args.skip_entity_expansion:
# Replace entities
for example in merged_doc_gen.examples.values():
errors = EntityErrors()
title, title_errors = merged_doc_gen.expand_entities(example.title)
errors.extend(title_errors)

title_abbrev, title_abbrev_errors = merged_doc_gen.expand_entities(
example.title_abbrev
)
errors.extend(title_abbrev_errors)

synopsis, synopsis_errors = merged_doc_gen.expand_entities(example.synopsis)
errors.extend(synopsis_errors)
merged_doc_gen.expand_entity_fields(merged_doc_gen)

synopsis_list = []
for synopsis in example.synopsis_list:
expanded_synopsis, synopsis_errors = merged_doc_gen.expand_entities(
synopsis
)
synopsis_list.append(expanded_synopsis)
errors.extend(synopsis_errors)

if args.strict and errors:
logging.error(
f"Errors expanding entities for example: {example}. {errors}"
)
exit(1)

example.title = title
example.title_abbrev = title_abbrev
example.synopsis = synopsis
example.synopsis_list = synopsis_list
if args.strict and merged_doc_gen.errors:
logging.error("Errors found in metadata: %s", merged_doc_gen.errors)
exit(1)

serialized = json.dumps(merged_doc_gen, cls=DocGenEncoder)

Expand Down
22 changes: 17 additions & 5 deletions aws_doc_sdk_examples_tools/doc_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,16 @@ def sample_doc_gen() -> DocGen:
root=Path("/test/root"),
errors=metadata_errors,
prefix="test_prefix",
entities={"&S3long;": "Amazon Simple Storage Service", "&S3;": "Amazon S3"},
entities={
"&S3long;": "Amazon Simple Storage Service",
"&S3;": "Amazon S3",
"&PYLong;": "Python SDK v1",
"&PYShort;": "Python V1",
},
sdks={
"python": Sdk(
name="python",
versions=[
SdkVersion(version=1, long="Python SDK v1", short="Python v1")
],
versions=[SdkVersion(version=1, long="&PYLong;", short="&PYShort;")],
guide="Python Guide",
property="python",
)
Expand Down Expand Up @@ -123,6 +126,15 @@ def test_expand_entities(sample_doc_gen: DocGen):
assert not errors


def test_expand_entity_fields(sample_doc_gen: DocGen):
error_count = len(sample_doc_gen.errors)
sample_doc_gen.expand_entity_fields(sample_doc_gen)
assert sample_doc_gen.services["s3"].long == "Amazon Simple Storage Service"
assert sample_doc_gen.sdks["python"].versions[0].long == "Python SDK v1"
# The fixture has an error, so make sure we don't have _more_ errors.
assert error_count == len(sample_doc_gen.errors)


def test_doc_gen_encoder(sample_doc_gen: DocGen):
encoded = json.dumps(sample_doc_gen, cls=DocGenEncoder)
decoded = json.loads(encoded)
Expand All @@ -137,7 +149,7 @@ def test_doc_gen_encoder(sample_doc_gen: DocGen):
assert decoded["sdks"]["python"]["name"] == "python"
assert decoded["sdks"]["python"]["guide"] == "Python Guide"
assert decoded["sdks"]["python"]["versions"][0]["version"] == 1
assert decoded["sdks"]["python"]["versions"][0]["long"] == "Python SDK v1"
assert decoded["sdks"]["python"]["versions"][0]["long"] == "&PYLong;"

# Verify service information
assert "services" in decoded
Expand Down
8 changes: 4 additions & 4 deletions aws_doc_sdk_examples_tools/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from dataclasses import dataclass
import re

from .metadata_errors import ErrorsList
from .metadata_errors import ErrorsList, MetadataError


@dataclass
class EntityError(Exception):
class EntityError(MetadataError):
"""
Base error. Do not use directly.
"""

entity: Optional[str]
entity: Optional[str] = None

def message(self) -> str:
return ""
Expand Down Expand Up @@ -54,4 +54,4 @@ def expand_entity(
if expanded is not None:
return entity.replace(entity, expanded), None
else:
return entity, MissingEntityError(entity)
return entity, MissingEntityError(entity=entity)
2 changes: 1 addition & 1 deletion aws_doc_sdk_examples_tools/entities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def test_entity_errors_append():
errors = EntityErrors()
errors.append(MissingEntityError("entity1"))
errors.append(MissingEntityError(entity="entity1"))
assert len(errors._errors) == 1
assert errors._errors[0].entity == "entity1"

Expand Down
Loading