Skip to content

Commit

Permalink
Refactor: NJDFeature 型の定義を types.py に移動し、if TYPE_CHECKING: せずとも NJDFe…
Browse files Browse the repository at this point in the history
…ature を型ヒントに使えるよう改善

従来は openjtalk.pyx の型定義ファイルである openjtalk.pyi に記述されていたため型としては実態がなく、呼び出すには from __future__ import annotations に加え if TYPE_CHECKING: の中で import する必要があり非常に面倒だった
  • Loading branch information
tsukumijima committed Aug 1, 2024
1 parent 3302a30 commit 2a7faf5
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 42 deletions.
8 changes: 2 additions & 6 deletions pyopenjtalk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

import atexit
import os
import sys
from contextlib import ExitStack
from os.path import exists
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Tuple, Union
from typing import Any, List, Tuple, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -25,11 +23,9 @@
from .openjtalk import OpenJTalk
from .openjtalk import build_mecab_dictionary as _build_mecab_dictionary
from .openjtalk import mecab_dict_index as _mecab_dict_index
from .types import NJDFeature
from .utils import merge_njd_marine_features, modify_kanji_yomi, modify_masu_acc, retreat_acc_nuc

if TYPE_CHECKING:
from .openjtalk import NJDFeature

_file_manager = ExitStack()
atexit.register(_file_manager.close)

Expand Down
27 changes: 3 additions & 24 deletions pyopenjtalk/openjtalk.pyi
Original file line number Diff line number Diff line change
@@ -1,29 +1,8 @@
# flake8: noqa

import sys
from typing import Dict, Iterable, List

if sys.version_info >= (3, 8):
from typing import TypedDict

class NJDFeature(TypedDict):
string: str
pos: str
pos_group1: str
pos_group2: str
pos_group3: str
ctype: str
cform: str
orig: str
read: str
pron: str
acc: int
mora_size: int
chain_rule: str
chain_flag: int

else:
NJDFeature = Dict[str, str | int]
from typing import Iterable, List

from .types import NJDFeature

class OpenJTalk:
def __init__(self, dn_mecab: bytes = b"/usr/local/dic", userdic: bytes = b"") -> None:
Expand Down
18 changes: 18 additions & 0 deletions pyopenjtalk/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import TypedDict


class NJDFeature(TypedDict):
string: str
pos: str
pos_group1: str
pos_group2: str
pos_group3: str
ctype: str
cform: str
orig: str
read: str
pron: str
acc: int
mora_size: int
chain_rule: str
chain_flag: int
8 changes: 2 additions & 6 deletions pyopenjtalk/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List
from typing import Any, Dict, List

from sudachipy import dictionary, tokenizer

from .types import NJDFeature
from .yomi_model.nani_predict import predict

if TYPE_CHECKING:
from .openjtalk import NJDFeature


def merge_njd_marine_features(
njd_features: List[NJDFeature], marine_results: Dict[str, Any]
Expand Down
8 changes: 2 additions & 6 deletions pyopenjtalk/yomi_model/nani_predict.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from __future__ import annotations

import os

# import pickle
from typing import TYPE_CHECKING, List, Union
from typing import List, Union

# import pandas as pd

if TYPE_CHECKING:
from ..openjtalk import NJDFeature
from ..types import NJDFeature

X_COLS = ["pos", "pos_group1", "pos_group2", "pron", "ctype", "cform"]
model_dir = os.path.dirname(__file__)
Expand Down

0 comments on commit 2a7faf5

Please sign in to comment.