Skip to content

Commit

Permalink
feat: add doc metadata extractor and ID generator classes
Browse files Browse the repository at this point in the history
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
  • Loading branch information
vagenas committed Sep 26, 2024
1 parent b49e93e commit ef88fe1
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 0 deletions.
12 changes: 12 additions & 0 deletions docling_core/transforms/id_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Define the ID generator types."""

from docling_core.transforms.id_generator.base import BaseIDGenerator # noqa
from docling_core.transforms.id_generator.doc_hash_id_generator import ( # noqa
DocHashIDGenerator,
)
from docling_core.transforms.id_generator.uuid_generator import UUIDGenerator # noqa
30 changes: 30 additions & 0 deletions docling_core/transforms/id_generator/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Base document ID generator module."""

from abc import ABC, abstractmethod
from typing import Any

from docling_core.types import Document as DLDocument


class BaseIDGenerator(ABC):
    """Abstract interface for components that assign an ID to a document."""

    @abstractmethod
    def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
        """Produce the ID of the given document.

        Args:
            doc (DLDocument): the document an ID is requested for
            *args (Any): extra positional arguments, available to subclasses
            **kwargs (Any): extra keyword arguments, available to subclasses

        Raises:
            NotImplementedError: always, as this base method has no
                implementation of its own

        Returns:
            str: the ID of the document
        """
        raise NotImplementedError
27 changes: 27 additions & 0 deletions docling_core/transforms/id_generator/doc_hash_id_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Doc-hash-based ID generator module."""


from typing import Any

from docling_core.transforms.id_generator import BaseIDGenerator
from docling_core.types import Document as DLDocument


class DocHashIDGenerator(BaseIDGenerator):
    """ID generator that reuses the document hash as the document ID."""

    def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
        """Return the document hash stored in the document's file info.

        Args:
            doc (DLDocument): the document an ID is requested for

        Returns:
            str: the value of ``doc.file_info.document_hash``
        """
        file_info = doc.file_info
        return file_info.document_hash
34 changes: 34 additions & 0 deletions docling_core/transforms/id_generator/uuid_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""UUID-based ID generator module."""

from random import Random
from typing import Annotated, Any, Optional
from uuid import UUID

from pydantic import BaseModel, Field

from docling_core.transforms.id_generator import BaseIDGenerator
from docling_core.types import Document as DLDocument


class UUIDGenerator(BaseModel, BaseIDGenerator):
    """ID generator that builds UUID strings from random bits."""

    # seed handed to the random number generator; None means no fixed seed
    seed: Optional[int] = None
    # version passed to the UUID constructor; validated as a strict int in 1..5
    uuid_version: Annotated[int, Field(strict=True, ge=1, le=5)] = 4

    def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
        """Produce a UUID string to serve as the ID of the given document.

        The document itself is not inspected. A new random number generator
        is created from ``seed`` on every call, so with a non-``None`` seed
        each call yields the same ID.

        Args:
            doc (DLDocument): the document an ID is requested for

        Returns:
            str: the string form of the generated UUID
        """
        rng = Random(x=self.seed)
        random_bits = rng.getrandbits(128)
        generated = UUID(int=random_bits, version=self.uuid_version)
        return str(generated)
13 changes: 13 additions & 0 deletions docling_core/transforms/metadata_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Define the metadata extractor types."""

from docling_core.transforms.metadata_extractor.base import ( # noqa
BaseMetadataExtractor,
)
from docling_core.transforms.metadata_extractor.simple_metadata_extractor import ( # noqa
SimpleMetadataExtractor,
)
59 changes: 59 additions & 0 deletions docling_core/transforms/metadata_extractor/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Base metadata extractor module."""


from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel

from docling_core.types import Document as DLDocument


class BaseMetadataExtractor(BaseModel, ABC):
    """Abstract interface for components that derive metadata from a document."""

    @abstractmethod
    def get_metadata(
        self, doc: DLDocument, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """Build the metadata of the given document.

        Args:
            doc (DLDocument): the document whose metadata is requested
            *args (Any): extra positional arguments, available to subclasses
            **kwargs (Any): extra keyword arguments, available to subclasses

        Raises:
            NotImplementedError: always, as this base method has no
                implementation of its own

        Returns:
            dict[str, Any]: the metadata of the document
        """
        raise NotImplementedError

    @abstractmethod
    def get_excluded_embed_metadata_keys(self) -> list[str]:
        """List the metadata keys that must be left out of embedding.

        Raises:
            NotImplementedError: always, as this base method has no
                implementation of its own

        Returns:
            list[str]: the metadata keys to exclude
        """
        raise NotImplementedError

    @abstractmethod
    def get_excluded_llm_metadata_keys(self) -> list[str]:
        """List the metadata keys that must be left out of LLM generation.

        Raises:
            NotImplementedError: always, as this base method has no
                implementation of its own

        Returns:
            list[str]: the metadata keys to exclude
        """
        raise NotImplementedError
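61 changes: 61 additions & 0 deletions docling_core/transforms/metadata_extractor/simple_metadata_extractor.py
Original file line number Diff line number Diff line change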
@@ -0,0 +1,61 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Simple metadata extractor module."""


from enum import Enum
from typing import Any

from docling_core.transforms.metadata_extractor import BaseMetadataExtractor
from docling_core.types import Document as DLDocument


class SimpleMetadataExtractor(BaseMetadataExtractor):
    """Simple metadata extractor class.

    Emits the document hash and, when ``include_origin`` is enabled, the
    caller-supplied origin. Every key this extractor can emit is also reported
    as excluded from both embedding and LLM generation.
    """

    class _Keys(str, Enum):
        """Names of the metadata keys produced by this extractor."""

        DL_DOC_HASH = "dl_doc_hash"
        ORIGIN = "origin"

    # whether the origin is added to the metadata and to the exclusion lists
    include_origin: bool = False

    def get_metadata(
        self, doc: DLDocument, origin: str, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """Extract metadata for the given document.

        Args:
            doc (DLDocument): document to extract metadata for
            origin (str): the document origin; only emitted when
                ``include_origin`` is set

        Returns:
            dict[str, Any]: the extracted metadata, keyed by plain strings
        """
        # Emit the plain ``str`` values instead of the enum members: the
        # members compare and hash equal to these strings, but ``str()`` of a
        # member yields its qualified name rather than the key text, which
        # breaks consumers that render or serialize the keys.
        meta: dict[str, Any] = {
            self._Keys.DL_DOC_HASH.value: doc.file_info.document_hash,
        }
        if self.include_origin:
            meta[self._Keys.ORIGIN.value] = origin
        return meta

    def get_excluded_embed_metadata_keys(self) -> list[str]:
        """Get metadata keys to exclude from embedding.

        Returns:
            list[str]: the metadata keys to exclude, as plain strings; a new
                list is built on every call
        """
        # Same reasoning as in get_metadata: hand out real strings so the
        # result matches the annotated ``list[str]`` exactly.
        excl_keys: list[str] = [self._Keys.DL_DOC_HASH.value]
        if self.include_origin:
            excl_keys.append(self._Keys.ORIGIN.value)
        return excl_keys

    def get_excluded_llm_metadata_keys(self) -> list[str]:
        """Get metadata keys to exclude from LLM generation.

        Returns:
            list[str]: the metadata keys to exclude; identical in content to
                the keys excluded from embedding
        """
        return self.get_excluded_embed_metadata_keys()

0 comments on commit ef88fe1

Please sign in to comment.