Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add doc metadata extractor and ID generator classes #34

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docling_core/transforms/id_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Define the ID generator types."""

from docling_core.transforms.id_generator.base import BaseIDGenerator # noqa
from docling_core.transforms.id_generator.doc_hash_id_generator import ( # noqa
DocHashIDGenerator,
)
from docling_core.transforms.id_generator.uuid_generator import UUIDGenerator # noqa
30 changes: 30 additions & 0 deletions docling_core/transforms/id_generator/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Base document ID generator module."""

from abc import ABC, abstractmethod
from typing import Any

from docling_core.types import Document as DLDocument


class BaseIDGenerator(ABC):
    """Abstract interface for objects that produce document IDs."""

    @abstractmethod
    def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
        """Produce an ID for ``doc``.

        Concrete subclasses must override this method; the implementation
        here only raises.

        Args:
            doc (DLDocument): the document an ID is requested for
            *args (Any): extra positional arguments, unused here
            **kwargs (Any): extra keyword arguments, unused here

        Raises:
            NotImplementedError: always, since this is the abstract definition

        Returns:
            str: the ID generated for the document
        """
        raise NotImplementedError
27 changes: 27 additions & 0 deletions docling_core/transforms/id_generator/doc_hash_id_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Doc-hash-based ID generator module."""


from typing import Any

from docling_core.transforms.id_generator import BaseIDGenerator
from docling_core.types import Document as DLDocument


class DocHashIDGenerator(BaseIDGenerator):
    """ID generator that uses the document hash as the ID."""

    def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
        """Return the hash stored in the document's file info.

        Args:
            doc (DLDocument): the document an ID is requested for
            *args (Any): extra positional arguments, ignored
            **kwargs (Any): extra keyword arguments, ignored

        Returns:
            str: the value of ``doc.file_info.document_hash``
        """
        file_info = doc.file_info
        return file_info.document_hash
34 changes: 34 additions & 0 deletions docling_core/transforms/id_generator/uuid_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""UUID-based ID generator module."""

from random import Random
from typing import Annotated, Any, Optional
from uuid import UUID

from pydantic import BaseModel, Field

from docling_core.transforms.id_generator import BaseIDGenerator
from docling_core.types import Document as DLDocument


class UUIDGenerator(BaseModel, BaseIDGenerator):
    """UUID-based ID generator class.

    IDs are built from 128 random bits. When ``seed`` is set, the generator
    yields a reproducible *sequence* of IDs: two instances created with the
    same seed return the same IDs in the same order.
    """

    # Seed handed to ``random.Random``; ``None`` leaves seeding to ``Random``.
    seed: Optional[int] = None
    # Passed as ``version`` to ``uuid.UUID``; must be a strict int in [1, 5].
    uuid_version: Annotated[int, Field(strict=True, ge=1, le=5)] = 4

    # RNG shared by all calls on this instance, created on first use from the
    # ``seed`` value at that time. Re-creating it per call would make a seeded
    # generator return the same UUID for every document. The leading
    # underscore makes pydantic handle it as a private attribute, not a field.
    _rng: Optional[Random] = None

    def generate_id(self, doc: DLDocument, *args: Any, **kwargs: Any) -> str:
        """Generate an ID for the given document.

        The document content is not used; each call draws the next 128 bits
        from this instance's random number generator.

        Args:
            doc (DLDocument): document to generate ID for
            *args (Any): extra positional arguments, ignored
            **kwargs (Any): extra keyword arguments, ignored

        Returns:
            str: the generated ID, formatted as a UUID string
        """
        if self._rng is None:
            self._rng = Random(x=self.seed)
        return str(UUID(int=self._rng.getrandbits(128), version=self.uuid_version))
13 changes: 13 additions & 0 deletions docling_core/transforms/metadata_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Define the metadata extractor types."""

from docling_core.transforms.metadata_extractor.base import ( # noqa
BaseMetadataExtractor,
)
from docling_core.transforms.metadata_extractor.simple_metadata_extractor import ( # noqa
SimpleMetadataExtractor,
)
59 changes: 59 additions & 0 deletions docling_core/transforms/metadata_extractor/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Base metadata extractor module."""


from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel

from docling_core.types import Document as DLDocument


class BaseMetadataExtractor(BaseModel, ABC):
    """Abstract interface for document metadata extractors."""

    @abstractmethod
    def get_metadata(
        self,
        doc: DLDocument,
        *args: Any,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Produce the metadata of ``doc``.

        Args:
            doc (DLDocument): the document whose metadata is requested
            *args (Any): extra positional arguments, unused here
            **kwargs (Any): extra keyword arguments, unused here

        Raises:
            NotImplementedError: always, since this is the abstract definition

        Returns:
            dict[str, Any]: the metadata extracted from the document
        """
        raise NotImplementedError

    @abstractmethod
    def get_excluded_embed_metadata_keys(self) -> list[str]:
        """List the metadata keys that must be left out of embedding.

        Raises:
            NotImplementedError: always, since this is the abstract definition

        Returns:
            list[str]: the metadata keys to exclude
        """
        raise NotImplementedError

    @abstractmethod
    def get_excluded_llm_metadata_keys(self) -> list[str]:
        """List the metadata keys that must be left out of LLM generation.

        Raises:
            NotImplementedError: always, since this is the abstract definition

        Returns:
            list[str]: the metadata keys to exclude
        """
        raise NotImplementedError
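61 changes: 61 additions & 0 deletions docling_core/transforms/metadata_extractor/simple_metadata_extractor.py
Original file line number Diff line number Diff line change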
@@ -0,0 +1,61 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#

"""Simple metadata extractor module."""


from enum import Enum
from typing import Any

from docling_core.transforms.metadata_extractor import BaseMetadataExtractor
from docling_core.types import Document as DLDocument


class SimpleMetadataExtractor(BaseMetadataExtractor):
    """Simple metadata extractor class.

    Emits the document hash and, when ``include_origin`` is set, the origin
    passed by the caller. Every key it can emit is also reported as excluded
    from both embedding and LLM generation.
    """

    class _Keys(str, Enum):
        """Names of the metadata keys produced by this extractor."""

        DL_DOC_HASH = "dl_doc_hash"
        ORIGIN = "origin"

    # Whether the origin is added to the metadata and to the excluded keys.
    include_origin: bool = False

    def get_metadata(
        self, doc: DLDocument, origin: str, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """Extract metadata for the given document.

        Args:
            doc (DLDocument): document to extract metadata for
            origin (str): the document origin; only used when
                ``include_origin`` is set
            *args (Any): extra positional arguments, ignored
            **kwargs (Any): extra keyword arguments, ignored

        Returns:
            dict[str, Any]: the extracted metadata, keyed by plain strings
        """
        # Use the enum *values* as keys: the members compare equal to them,
        # but str()/format() of a member can render as "_Keys.<NAME>", which
        # would leak into any text built from the keys.
        meta: dict[str, Any] = {
            self._Keys.DL_DOC_HASH.value: doc.file_info.document_hash,
        }
        if self.include_origin:
            meta[self._Keys.ORIGIN.value] = origin
        return meta

    def get_excluded_embed_metadata_keys(self) -> list[str]:
        """Get metadata keys to exclude from embedding.

        Returns:
            list[str]: the metadata to exclude, as plain strings
        """
        excl_keys: list[str] = [self._Keys.DL_DOC_HASH.value]
        if self.include_origin:
            excl_keys.append(self._Keys.ORIGIN.value)
        return excl_keys

    def get_excluded_llm_metadata_keys(self) -> list[str]:
        """Get metadata keys to exclude from LLM generation.

        The same keys are excluded as for embedding.

        Returns:
            list[str]: the metadata to exclude, as plain strings
        """
        return self.get_excluded_embed_metadata_keys()
Loading