Skip to content

Commit

Permalink
Change DoclingDocument.iterate_elements and add print tree function
Browse files Browse the repository at this point in the history
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
  • Loading branch information
cau-git committed Sep 24, 2024
1 parent 4f1c190 commit 0a1e6ce
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
24 changes: 16 additions & 8 deletions docling_core/types/experimental/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,15 +659,16 @@ def build_page_trees(self):
def iterate_elements(
self,
root: Optional[NodeItem] = None,
omit_groups: bool = True,
with_groups: bool = False,
traverse_figures: bool = True,
) -> typing.Iterable[NodeItem]:
level=0,
) -> typing.Iterable[Tuple[NodeItem, int]]: # tuple of node and level
# Yield the current node
if not root:
root = self.body

if omit_groups and not isinstance(root, GroupItem):
yield root
if not isinstance(root, GroupItem) or with_groups:
yield root, level

# Traverse children
for child_ref in root.children:
Expand All @@ -676,9 +677,16 @@ def iterate_elements(
if isinstance(child, NodeItem):
# If the child is a NodeItem, recursively traverse it
if not isinstance(child, FigureItem) or traverse_figures:
yield from self.iterate_elements(child)
yield from self.iterate_elements(child, level=level + 1)
else: # leaf
yield child
yield child, level

def print_element_tree(self):
    """Print an indented outline of the document's element tree.

    Walks ``self.iterate_elements(with_groups=True)``, which yields
    ``(node, level)`` pairs, and prints one line per node. Each line is
    prefixed by one space per nesting level and shows the node's running
    index in the traversal followed by its ``name`` (for ``GroupItem``)
    or its ``label`` (for ``DocItem``). Nodes of any other type produce
    no output, although they still consume an index.
    """
    walk = self.iterate_elements(with_groups=True)
    for position, (node, depth) in enumerate(walk):
        if isinstance(node, GroupItem):
            caption = node.name
        elif isinstance(node, DocItem):
            caption = node.label
        else:
            # Neither a group nor a document item: nothing to show.
            continue
        # print() inserts its default separator between the indent
        # string and the text, so even level 0 lines start with a space.
        print(" " * depth, f"{position}: {caption}")

def export_to_markdown(
self,
Expand Down Expand Up @@ -718,7 +726,7 @@ def export_to_markdown(

skip_count = 0
if len(self.body.children):
for ix, item in enumerate(self.iterate_elements(self.body)):
for ix, (item, level) in enumerate(self.iterate_elements(self.body)):
if skip_count < from_element:
skip_count += 1
continue # skip as many items as you want
Expand Down Expand Up @@ -837,7 +845,7 @@ def export_to_document_tokens(

skip_count = 0
if len(self.body.children):
for ix, item in enumerate(self.iterate_elements(self.body)):
for ix, (item, level) in enumerate(self.iterate_elements(self.body)):

if skip_count < from_element:
skip_count += 1
Expand Down
9 changes: 5 additions & 4 deletions test/test_docling_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_load_serialize_doc():

### Iterate all elements

for item in doc.iterate_elements():
print(f"Item: {item}")
for item, level in doc.iterate_elements():
print(f"Item: {item} at level {level}")


def test_construct_doc():
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_construct_doc():

### Iterate all elements

for item in doc.iterate_elements():
for item, level in doc.iterate_elements():
print(f"Item: {item}")

## Export stuff
Expand All @@ -170,11 +170,12 @@ def test_construct_doc():
table.export_to_html()
table.export_to_dataframe()
table.export_to_document_tokens(doc)
1 == 1

for fig in doc.figures:
fig.export_to_document_tokens(doc)

doc.print_element_tree()

### Serialize and deserialize stuff

yaml_dump = yaml.safe_dump(doc.model_dump(mode="json", by_alias=True))
Expand Down

0 comments on commit 0a1e6ce

Please sign in to comment.