Skip to content

Commit

Permalink
Add extra=Forbid to NodeItem
Browse files Browse the repository at this point in the history
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
  • Loading branch information
cau-git committed Sep 27, 2024
1 parent 5084853 commit adc16f3
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 43 deletions.
2 changes: 2 additions & 0 deletions docling_core/types/experimental/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ class NodeItem(BaseModel):
parent: Optional[RefItem] = None
children: List[RefItem] = []

model_config = ConfigDict(extra="forbid")

def get_ref(self):
"""get_ref."""
return RefItem(cref=self.self_ref)
Expand Down
2 changes: 1 addition & 1 deletion test/data/docling_document/unit/SectionItem.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
children: []
label: text
level: 1
level: 2
orig: whatever
parent: null
prov: []
Expand Down
7 changes: 0 additions & 7 deletions test/data/experimental/dummy_doc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ texts:
- orig: "arXiv:2206.01062v1 [cs.CV] 2 Jun 2022"
text: "arXiv:2206.01062v1 [cs.CV] 2 Jun 2022"
self_ref: "#/texts/0"
hash: 132103230
label: "page_header"
parent:
$ref: "#/furniture"
Expand All @@ -52,7 +51,6 @@ texts:
- orig: "DocLayNet: A Large Human-Annotated Dataset for\nDocument-Layout Analysis"
text: "DocLayNet: A Large Human-Annotated Dataset for Document-Layout Analysis"
self_ref: "#/texts/1"
hash: 2349732 # uint64 hash of self_ref
label: "title"
parent:
$ref: "#/body"
Expand All @@ -68,7 +66,6 @@ texts:
- orig: "OPERATION (cont.)" # nested inside the figure
text: "OPERATION (cont.)"
self_ref: "#/texts/2"
hash: 6978483
label: "section_header"
parent:
$ref: "/pictures/0"
Expand All @@ -85,7 +82,6 @@ texts:
- orig: "Figure 1: Four examples of complex page layouts across dif-\nferent document categories" # nested inside the figure
text: "Figure 1: Four examples of complex page layouts across different document categories"
self_ref: "#/texts/3"
hash: 6978483
label: "caption"
parent:
$ref: "#/body"
Expand All @@ -103,7 +99,6 @@ texts:

tables: # All tables...
- self_ref: "#/table/0"
hash: 98574
label: "table"
parent:
$ref: "#/body"
Expand Down Expand Up @@ -133,7 +128,6 @@ tables: # All tables...

pictures: # All pictures...
- self_ref: "#/pictures/0"
hash: 7782482
label: "picture"
parent:
$ref: "#/body"
Expand Down Expand Up @@ -168,7 +162,6 @@ key_value_items: [ ] # All KV-items
# We should consider this for pages
pages: # Optional, for layout documents
1:
hash: 6203680922337857390
size:
width: 768.23
height: 583.15
Expand Down
103 changes: 68 additions & 35 deletions test/test_docling_doc.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
from collections import deque

import pytest
import yaml

from collections import deque

from docling_core.types.experimental.document import (
BasePictureData,
BaseTableData,
DescriptionItem,
DoclingDocument,
TableCell,
NodeItem,
DocItem,
TextItem,
DoclingDocument,
FloatingItem,
KeyValueItem,
SectionItem,
PictureItem,
SectionItem,
TableCell,
TableItem,
BasePictureData,
BaseTableData
TextItem,
)
from docling_core.types.experimental.labels import DocItemLabel, GroupLabel


def test_docitems():

# Iterative function to find all subclasses
def find_all_subclasses_iterative(base_class):
subclasses = deque([base_class]) # Use a deque for efficient popping from the front
subclasses = deque(
[base_class]
) # Use a deque for efficient popping from the front
all_subclasses = []

while subclasses:
Expand All @@ -40,57 +40,91 @@ def find_all_subclasses_iterative(base_class):
def serialise(obj):
return yaml.safe_dump(obj.model_dump(mode="json", by_alias=True))

def write(name:str, serialisation:str):
def write(name: str, serialisation: str):
with open(f"./test/data/docling_document/unit/{name}.yaml", "w") as fw:
fw.write(serialisation)

def read(name:str):
def read(name: str):
with open(f"./test/data/docling_document/unit/{name}.yaml", "r") as fr:
gold = fr.read()
return gold

def generate(dc, obj):
write(dc.__name__, pred)

def verify(dc, obj):
pred = serialise(obj)
#print(f"\t{dc.__name__}:\n {pred}")
pred = serialise(obj)
# print(f"\t{dc.__name__}:\n {pred}")
gold = read(dc.__name__)

assert pred==gold, f"pred!=gold for {dc.__name__}"
assert pred == gold, f"pred!=gold for {dc.__name__}"

# Iterate over the derived classes of the BaseClass
derived_classes = find_all_subclasses_iterative(DocItem)
for dc in derived_classes:

if dc is TextItem:
obj = dc(text="whatever", orig="whatever", dloc="sdvsd", label=DocItemLabel.TEXT, self_ref="#")
obj = dc(
text="whatever",
orig="whatever",
dloc="sdvsd",
label=DocItemLabel.TEXT,
self_ref="#",
)
verify(dc, obj)

elif dc is FloatingItem:
obj = dc(text="whatever", orig="whatever", dloc="sdvsd", label=DocItemLabel.TEXT, self_ref="#")
obj = dc(
text="whatever",
orig="whatever",
dloc="sdvsd",
label=DocItemLabel.TEXT,
self_ref="#",
)
verify(dc, obj)

elif dc is KeyValueItem:
obj = dc(text="whatever", orig="whatever", dloc="sdvsd", label=DocItemLabel.TEXT, self_ref="#")
obj = dc(
text="whatever",
orig="whatever",
dloc="sdvsd",
label=DocItemLabel.TEXT,
self_ref="#",
)
verify(dc, obj)

elif dc is SectionItem:
obj = dc(text="whatever", orig="whatever", dloc="sdvsd", label=DocItemLabel.TEXT, self_ref="#")
obj = dc(
text="whatever",
orig="whatever",
dloc="sdvsd",
label=DocItemLabel.TEXT,
self_ref="#",
level=2,
)
verify(dc, obj)

elif dc is PictureItem:
obj = dc(text="whatever", orig="whatever", dloc="sdvsd", label=DocItemLabel.TEXT, self_ref="#",
data=BasePictureData())
obj = dc(
text="whatever",
orig="whatever",
dloc="sdvsd",
label=DocItemLabel.TEXT,
self_ref="#",
data=BasePictureData(),
)
verify(dc, obj)

elif dc is TableItem:
obj = dc(text="whatever", orig="whatever", dloc="sdvsd", label=DocItemLabel.TEXT, self_ref="#",
data=BaseTableData(num_rows=3, num_cols=5, cells=[]))
obj = dc(
text="whatever",
orig="whatever",
dloc="sdvsd",
label=DocItemLabel.TEXT,
self_ref="#",
data=BaseTableData(num_rows=3, num_cols=5, cells=[]),
)
verify(dc, obj)

else:
print(f"{dc.__name__} is not known")
print(f"{dc.__name__} is not known")
assert False, "new derived class detected {dc.__name__}: {e}"


Expand Down Expand Up @@ -315,4 +349,3 @@ def _construct_doc() -> DoclingDocument:
fig_item = doc.add_picture(data=BasePictureData(), caption=fig_caption)

return doc

0 comments on commit adc16f3

Please sign in to comment.