diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
new file mode 100644
index 0000000..ab7d218
--- /dev/null
+++ b/.github/workflows/test.yaml
@@ -0,0 +1,35 @@
+name: test
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: ['3.10', '3.11', '3.12']
+
+ runs-on: ${{ matrix.os }}
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install dependencies
+ run: |
+ pip install -r requirements.txt
+
+ - name: Run tests
+ run: |
+ python -m unittest discover
\ No newline at end of file
diff --git a/README.md b/README.md
index 0ddad0d..e16d412 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
Xiwen 析文
-
+
+
@@ -69,7 +70,7 @@ The table below lists the number of simplified hanzi per grade, and the number o
[![](https://img.shields.io/badge/GitHub-xiwen-181717.svg?flat&logo=GitHub&logoColor=white)](https://github.com/essteer/xiwen)
-Clone the `xiwen` repo from GitHub for the full source code. The repo includes the CSV and text files used to generate the character lists and a test suite.
+Clone the `xiwen` repo from GitHub for the full source code, the files used to generate the character lists, and the test suite.
```console
$ git clone git@github.com:essteer/xiwen
diff --git a/pyproject.toml b/pyproject.toml
index 05f0fa7..0412c70 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,6 @@ dependencies = [
"masquer>=1.1.1",
"polars==0.20.31",
"requests>=2.32.3",
- "tqdm==4.66.2",
]
requires-python = ">=3.9"
license = { file = "LICENSE" }
@@ -21,7 +20,6 @@ classifiers = [
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@@ -31,6 +29,7 @@ classifiers = [
dev = [
"pre-commit==3.7.0",
"ruff>=0.4.5",
+ "tqdm==4.66.2",
]
[project.urls]
diff --git a/requirements.txt b/requirements.txt
index dad8444..d2f17b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,5 @@ requests==2.32.3
# via xiwen (pyproject.toml)
soupsieve==2.5
# via beautifulsoup4
-tqdm==4.66.3
- # via xiwen (pyproject.toml)
urllib3==2.2.1
# via requests
diff --git a/src/resources/main.py b/src/resources/main.py
index 2be3baf..7d62fde 100644
--- a/src/resources/main.py
+++ b/src/resources/main.py
@@ -1,4 +1,4 @@
-import pandas as pd
+import polars as pl
from tqdm import tqdm
from xiwen.utils.config import (
ASSETS_DIR,
@@ -24,7 +24,7 @@
##########################################################################
# Read hsk30-chars-ext.csv
-df = pd.read_csv(HSK_PATH)
+df = pl.read_csv(HSK_PATH)
# Extract character columns and HSK grades
df = df[["Hanzi", "Traditional", "Level"]]
# Rename columns
@@ -33,8 +33,8 @@
)
# Get pinyin based on traditional characters
trad_hanzi = df["Traditional"].tolist()
-# pinyin_df = pd.DataFrame("Pinyin": get_pinyin(trad_hanzi, pinyin_map))
-df["Pinyin"] = pd.DataFrame({"Pinyin": get_pinyin(trad_hanzi, pinyin_map)})
+# pinyin_df = pl.DataFrame("Pinyin": get_pinyin(trad_hanzi, pinyin_map))
+df["Pinyin"] = pl.DataFrame({"Pinyin": get_pinyin(trad_hanzi, pinyin_map)})
##########################################################################
# Add unicode for simplified and traditional hanzi
@@ -67,7 +67,7 @@
# DataFrame of Jun Da character frequencies
cols = ["Simplified", "JD Rank", "JD Frequency", "JD Percentile"]
-junda_df = pd.DataFrame(junda_freqs, columns=cols)
+junda_df = pl.DataFrame(junda_freqs, schema=cols)
##########################################################################
# Map frequencies to HSK set
diff --git a/src/xiwen/utils/analysis.py b/src/xiwen/utils/analysis.py
index 25c6db1..fcde6ed 100644
--- a/src/xiwen/utils/analysis.py
+++ b/src/xiwen/utils/analysis.py
@@ -1,4 +1,5 @@
import polars as pl
+import sys
from .config import HSK_GRADES, STATS_COLUMNS
from .counters import cumulative_counts, get_counts, granular_counts
@@ -20,6 +21,8 @@ def identify_variant(hsk_simp: list, hsk_trad: list) -> str:
str
text character variant
"""
+ # Use epsilon to mitigate float rounding errors
+ epsilon = sys.float_info.epsilon
# Threshold beyond which to decide that text belongs to one variant
threshold = 0.90
simp_set = set(hsk_simp) - set(hsk_trad)
@@ -29,9 +32,11 @@ def identify_variant(hsk_simp: list, hsk_trad: list) -> str:
return "Unknown"
ratio = len(simp_set) / (len(simp_set) + len(trad_set))
- if ratio >= threshold:
+ if ratio >= threshold - epsilon:
return "Simplified"
- return "Traditional"
+ elif ratio <= 1 - threshold + epsilon:
+ return "Traditional"
+ return "Unknown"
def compute_stats(
@@ -143,7 +148,6 @@ def analyse(
"Unknown": trad_list,
}
# Get counts of each hanzi
- # hanzi_df = get_counts(variants[variant], variant, HSK_HANZI)
hanzi_df = get_counts(variants[variant], variant)
# Get counts of hanzi by grade
grade_counts = granular_counts(hanzi_df, hanzi_list, variant)
diff --git a/src/xiwen/utils/counters.py b/src/xiwen/utils/counters.py
index 2d52f2b..e5d9cd7 100644
--- a/src/xiwen/utils/counters.py
+++ b/src/xiwen/utils/counters.py
@@ -1,6 +1,6 @@
import polars as pl
from .config import HSK_GRADES
-from .hanzi import get_hanzi_processor_instance
+from .hanzi import get_HSKHanzi_instance
def unit_counts(hanzi: list) -> dict:
@@ -100,15 +100,18 @@ def get_counts(hanzi_subset: list, variant: str) -> pl.DataFrame:
merged_df : pl.DataFrame
DataFrame of HSK_HANZI with counts applied
"""
+ # Get DataFrame of the full HSK character list
+ hsk_hanzi = get_HSKHanzi_instance().HSK_HANZI
# Count occurrences of each character in hanzi_subset
counts = unit_counts(hanzi_subset)
+
+ # Treat "Unknown" as traditional when merging on the variant column
+ if variant == "Unknown":
+ variant = "Traditional"
# Create DataFrame from counts dictionary
counts_df = pl.DataFrame(
list(counts.items()), schema={variant: pl.String, "Count": pl.Int32}
)
- # Get DataFrame of full HSK character liss
- hsk_hanzi = get_hanzi_processor_instance().HSK_HANZI
- # Merge on variant column
merged_df = hsk_hanzi.join(counts_df, on=variant, coalesce=True, how="left")
# Fill null values and convert counts to integers
merged_df = merged_df.fill_null(0).with_columns(pl.col("Count").cast(pl.Int32))
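
For context on the join in get_counts: the occurrence counts from a text are left-joined onto the full HSK table on whichever variant column applies, and characters that never appear keep a count of 0. A small self-contained sketch of the same polars pattern with made-up rows (the column names mirror the HSK table; the data is hypothetical):

```python
import polars as pl

# Hypothetical stand-in for the full HSK character list
hsk_hanzi = pl.DataFrame(
    {"Simplified": ["爱", "气", "车"], "Traditional": ["愛", "氣", "車"]}
)
# Occurrence counts observed in some text, keyed by the simplified forms
counts = {"爱": 3, "车": 2}
counts_df = pl.DataFrame(
    list(counts.items()), schema={"Simplified": pl.String, "Count": pl.Int32}
)
# Left join keeps every HSK row; "气" gets a null count,
# which fill_null(0) turns into 0 before the integer cast
merged_df = (
    hsk_hanzi.join(counts_df, on="Simplified", coalesce=True, how="left")
    .fill_null(0)
    .with_columns(pl.col("Count").cast(pl.Int32))
)
print(merged_df)
```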
diff --git a/src/xiwen/utils/hanzi.py b/src/xiwen/utils/hanzi.py
index 0f2ad64..cb712c8 100644
--- a/src/xiwen/utils/hanzi.py
+++ b/src/xiwen/utils/hanzi.py
@@ -3,7 +3,7 @@
from .config import ASSETS_DIR, HSK30_HANZI_SCHEMA
-class HanziProcessor:
+class HSKHanzi:
"""
Loads and retains HSK character lists
Singleton pattern -> only one instance exists
@@ -25,7 +25,7 @@ class HanziProcessor:
def __new__(cls):
if cls._instance is None:
- cls._instance = super(HanziProcessor, cls).__new__(cls)
+ cls._instance = super(HSKHanzi, cls).__new__(cls)
cls._instance._initialize()
return cls._instance
@@ -38,8 +38,8 @@ def _initialize(self):
self.HSK_TRAD = self.HSK_HANZI.select("Traditional").to_series().to_list()
-def get_hanzi_processor_instance():
+def get_HSKHanzi_instance():
"""
- Gets and returns the HanziProcessor class
+ Gets and returns the HSKHanzi class
"""
- return HanziProcessor()
+ return HSKHanzi()
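
The renamed HSKHanzi class keeps the singleton behaviour of the old HanziProcessor: __new__ hands back one shared instance, so the HSK tables are read from disk only once per process. A minimal sketch of the pattern, independent of the xiwen code:

```python
class Cache:
    """Minimal singleton sketch: every call to Cache() returns the same object."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(Cache, cls).__new__(cls)
            cls._instance._initialize()  # expensive setup runs exactly once
        return cls._instance

    def _initialize(self):
        self.data = "loaded once"


assert Cache() is Cache()
```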
diff --git a/src/xiwen/utils/transform.py b/src/xiwen/utils/transform.py
index b52b2a1..adf6886 100644
--- a/src/xiwen/utils/transform.py
+++ b/src/xiwen/utils/transform.py
@@ -1,4 +1,4 @@
-from .hanzi import get_hanzi_processor_instance
+from .hanzi import get_HSKHanzi_instance
def partition_hanzi(hanzi_list: list) -> tuple[list]:
@@ -23,8 +23,8 @@ def partition_hanzi(hanzi_list: list) -> tuple[list]:
outliers : list
characters not in above lists
"""
- hsk_simp = get_hanzi_processor_instance().HSK_SIMP
- hsk_trad = get_hanzi_processor_instance().HSK_TRAD
+ hsk_simp = get_HSKHanzi_instance().HSK_SIMP
+ hsk_trad = get_HSKHanzi_instance().HSK_TRAD
simp = [zi for zi in hanzi_list if zi in hsk_simp]
trad = [zi for zi in hanzi_list if zi in hsk_trad]
diff --git a/tests/test_analysis.py b/tests/test_analysis.py
index ea97a25..4d66f1e 100644
--- a/tests/test_analysis.py
+++ b/tests/test_analysis.py
@@ -1,5 +1,5 @@
import os
-import pandas as pd
+import polars as pl
import unittest
from src.xiwen.utils.analysis import identify_variant
from src.xiwen.utils.config import ENCODING
@@ -9,9 +9,9 @@
TEST_ASSETS = os.path.abspath(os.path.join("tests", "assets"))
# Combine script directory with relative path to the file
-filepath = os.path.join("src", "xiwen", "assets", "hsk30_hanzi.csv")
+filepath = os.path.join("src", "xiwen", "assets", "hsk30_hanzi.parquet")
# Load HSK Hanzi database (unigrams only)
-HSK_HANZI = pd.read_csv(filepath)
+HSK_HANZI = pl.read_parquet(filepath)
HSK_SIMP = list(HSK_HANZI["Simplified"])
HSK_TRAD = list(HSK_HANZI["Traditional"])
@@ -354,3 +354,7 @@ def test_known_figures(self):
simp, trad, outliers = partition_hanzi(hanzi)
# Check identified character variant
self.assertEqual(identify_variant(simp, trad), TEST_CASES[test_case][0])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_counters.py b/tests/test_counters.py
index 7d88faa..fa7dd9b 100644
--- a/tests/test_counters.py
+++ b/tests/test_counters.py
@@ -1,30 +1,298 @@
import os
+import polars as pl
+import sys
import unittest
-from src.xiwen.utils.counters import unit_counts
+from polars.testing import assert_frame_equal
+from src.xiwen.utils.config import HSK_GRADES
+from src.xiwen.utils.counters import (
+ cumulative_counts,
+ get_counts,
+ granular_counts,
+ unit_counts,
+)
+from src.xiwen.utils.extract import filter_text
+from src.xiwen.utils.hanzi import get_HSKHanzi_instance
+from src.xiwen.utils.transform import partition_hanzi
TEST_ASSETS = os.path.abspath(os.path.join("tests", "assets"))
-# Combine script directory with relative path to the file
+
+TEST_CASES = {
+ # Simplified only
+ "bjzd.txt": {
+ "Simplified": {
+ 0: (1751, 18896),
+ 1: [281, 11145],
+ 2: [271, 3269],
+ 3: [259, 1809],
+ 4: [217, 981],
+ 5: [185, 472],
+ 6: [133, 325],
+ 7: [314, 646],
+ },
+ "Traditional": {
+ 0: (1751, 18896),
+ 1: [191, 8432],
+ 2: [179, 2263],
+ 3: [163, 1255],
+ 4: [139, 625],
+ 5: [113, 285],
+ 6: [88, 225],
+ 7: [196, 392],
+ },
+ "Unknown": {
+ 0: (1751, 18896),
+ 1: [191, 8432],
+ 2: [179, 2263],
+ 3: [163, 1255],
+ 4: [139, 625],
+ 5: [113, 285],
+ 6: [88, 225],
+ 7: [196, 392],
+ },
+ },
+ # Traditional only
+ "ttc.txt": {
+ "Simplified": {
+ 0: (810, 5686),
+ 1: [96, 1765],
+ 2: [82, 1075],
+ 3: [75, 532],
+ 4: [51, 438],
+ 5: [34, 124],
+ 6: [40, 148],
+ 7: [106, 308],
+ },
+ "Traditional": {
+ 0: (810, 5686),
+ 1: [134, 1926],
+ 2: [118, 1357],
+ 3: [112, 688],
+ 4: [88, 712],
+ 5: [52, 189],
+ 6: [63, 221],
+ 7: [146, 373],
+ },
+ "Unknown": {
+ 0: (810, 5686),
+ 1: [134, 1926],
+ 2: [118, 1357],
+ 3: [111, 687],
+ 4: [88, 712],
+ 5: [52, 189],
+ 6: [63, 221],
+ 7: [146, 373],
+ },
+ },
+ # Latin alphabet (no hanzi)
+ "iliad.txt": {
+ "Simplified": {
+ 0: (0, 0),
+ 1: [0, 0],
+ 2: [0, 0],
+ 3: [0, 0],
+ 4: [0, 0],
+ 5: [0, 0],
+ 6: [0, 0],
+ 7: [0, 0],
+ },
+ "Traditional": {
+ 0: (0, 0),
+ 1: [0, 0],
+ 2: [0, 0],
+ 3: [0, 0],
+ 4: [0, 0],
+ 5: [0, 0],
+ 6: [0, 0],
+ 7: [0, 0],
+ },
+ "Unknown": {
+ 0: (0, 0),
+ 1: [0, 0],
+ 2: [0, 0],
+ 3: [0, 0],
+ 4: [0, 0],
+ 5: [0, 0],
+ 6: [0, 0],
+ 7: [0, 0],
+ },
+ },
+}
class TestUnitCounts(unittest.TestCase):
def test_counts(self):
"""Test counts match across character variants"""
+ hanzi = []
+ test = dict()
+ self.assertEqual(unit_counts(hanzi), test)
hanzi = ["爱", "气", "爱", "气", "车", "爱", "气", "车", "愛", "氣", "車"]
test = {"爱": 3, "气": 3, "车": 2, "愛": 1, "氣": 1, "車": 1}
self.assertEqual(unit_counts(hanzi), test)
-# class TestGetCounts(unittest.TestCase):
-# def test_get_counts(self):
-# """Test counts DataFrame"""
-# all = ["爱", "八", "爸", "杯", "子", "愛", "八", "爸", "杯", "子"]
-# simp = ["爱", "八", "爸", "杯", "子"]
-# trad = ["愛", "八", "爸", "杯", "子"]
-# df_data = {
-# "Simplified": ["爱", "八", "爸", "杯", "子"],
-# "Traditional": ["愛", "八", "爸", "杯", "子"],
-# }
-# df = pd.DataFrame(df_data)
-# results = _get_counts(df, all, (simp, trad), "Unknown")
-# print(results)
+class TestCumulativeCounts(unittest.TestCase):
+ @unittest.skipIf(
+ sys.platform.startswith("win"), "Skip on Windows: test case decode issue"
+ )
+ def test_simplified_set(self):
+ """Test counts match for simplified character set"""
+ variant = "Simplified"
+ for test_case in TEST_CASES.keys():
+ with open(os.path.join(TEST_ASSETS, test_case), "r") as f:
+ text = f.read()
+ # Extract hanzi from text (with duplicates)
+ hanzi_list = filter_text(text)
+ simp, _, _ = partition_hanzi(hanzi_list)
+ # Get counts of each hanzi
+ hanzi_df = get_counts(simp, variant)
+ # Get counts by grade (test case)
+ counts = granular_counts(hanzi_df, hanzi_list, variant)
+
+ cumulative_num_unique = 0
+ cumulative_num_grade = 0
+ for i in range(1, HSK_GRADES + 1):
+ cumulative_num_unique += counts[i][0]
+ cumulative_num_grade += counts[i][1]
+ self.assertEqual(cumulative_counts(counts)[i][0], cumulative_num_unique)
+ self.assertEqual(cumulative_counts(counts)[i][1], cumulative_num_grade)
+
+ @unittest.skipIf(
+ sys.platform.startswith("win"), "Skip on Windows: test case decode issue"
+ )
+ def test_traditional_set(self):
+ """Test counts match for traditional character set"""
+ variant = "Traditional"
+ for test_case in TEST_CASES.keys():
+ with open(os.path.join(TEST_ASSETS, test_case), "r") as f:
+ text = f.read()
+ # Extract hanzi from text (with duplicates)
+ hanzi_list = filter_text(text)
+ _, trad, _ = partition_hanzi(hanzi_list)
+ # Get counts of each hanzi
+ hanzi_df = get_counts(trad, variant)
+ # Get counts by grade (test case)
+ counts = granular_counts(hanzi_df, hanzi_list, variant)
+
+ cumulative_num_unique = 0
+ cumulative_num_grade = 0
+ for i in range(1, HSK_GRADES + 1):
+ cumulative_num_unique += counts[i][0]
+ cumulative_num_grade += counts[i][1]
+ self.assertEqual(cumulative_counts(counts)[i][0], cumulative_num_unique)
+ self.assertEqual(cumulative_counts(counts)[i][1], cumulative_num_grade)
+
+
+class TestGetCounts(unittest.TestCase):
+ @unittest.skipIf(
+ sys.platform.startswith("win"), "Skip on Windows: test case decode issue"
+ )
+ def test_simplified_set(self):
+ """Test counts correct for simplified characters"""
+ variant = "Simplified"
+ # Get DataFrame of the full HSK character list
+ hsk_hanzi = get_HSKHanzi_instance().HSK_HANZI
+ for test_case in TEST_CASES.keys():
+ with open(os.path.join(TEST_ASSETS, test_case), "r") as f:
+ text = f.read()
+ # Extract hanzi from text (with duplicates)
+ hanzi_list = filter_text(text)
+ simp, _, _ = partition_hanzi(hanzi_list)
+ counts = unit_counts(simp)
+ # Create DataFrame from counts dictionary
+ counts_df = pl.DataFrame(
+ list(counts.items()), schema={variant: pl.String, "Count": pl.Int32}
+ )
+ merged_df = hsk_hanzi.join(counts_df, on=variant, coalesce=True, how="left")
+ # Fill null values and convert counts to integers
+ merged_df = merged_df.fill_null(0).with_columns(
+ pl.col("Count").cast(pl.Int32)
+ )
+ self.assertIsNone(assert_frame_equal(get_counts(simp, variant), merged_df))
+
+ @unittest.skipIf(
+ sys.platform.startswith("win"), "Skip on Windows: test case decode issue"
+ )
+ def test_traditional_set(self):
+ """Test counts correct for traditional characters"""
+ variant = "Traditional"
+ # Get DataFrame of the full HSK character list
+ hsk_hanzi = get_HSKHanzi_instance().HSK_HANZI
+ for test_case in TEST_CASES.keys():
+ with open(os.path.join(TEST_ASSETS, test_case), "r") as f:
+ text = f.read()
+ # Extract hanzi from text (with duplicates)
+ hanzi_list = filter_text(text)
+ _, trad, _ = partition_hanzi(hanzi_list)
+ counts = unit_counts(trad)
+ # Create DataFrame from counts dictionary
+ counts_df = pl.DataFrame(
+ list(counts.items()), schema={variant: pl.String, "Count": pl.Int32}
+ )
+ merged_df = hsk_hanzi.join(counts_df, on=variant, coalesce=True, how="left")
+ # Fill null values and convert counts to integers
+ merged_df = merged_df.fill_null(0).with_columns(
+ pl.col("Count").cast(pl.Int32)
+ )
+ self.assertIsNone(assert_frame_equal(get_counts(trad, variant), merged_df))
+
+
+class TestGranularCounts(unittest.TestCase):
+ @unittest.skipIf(
+ sys.platform.startswith("win"), "Skip on Windows: test case decode issue"
+ )
+ def test_simplified_set(self):
+ """Test correct breakdown for simplified character set"""
+ variant = "Simplified"
+ for test_case in TEST_CASES.keys():
+ with open(os.path.join(TEST_ASSETS, test_case), "r") as f:
+ text = f.read()
+ # Extract hanzi from text (with duplicates)
+ hanzi_list = filter_text(text)
+ simp, _, _ = partition_hanzi(hanzi_list)
+ # Get counts of each hanzi
+ hanzi_df = get_counts(simp, variant)
+ # Get counts by grade (test case)
+ counts = granular_counts(hanzi_df, hanzi_list, variant)
+ self.assertEqual(TEST_CASES[test_case][variant], counts)
+
+ @unittest.skipIf(
+ sys.platform.startswith("win"), "Skip on Windows: test case decode issue"
+ )
+ def test_traditional_set(self):
+ """Test correct breakdown for traditional character set"""
+ variant = "Traditional"
+ for test_case in TEST_CASES.keys():
+ with open(os.path.join(TEST_ASSETS, test_case), "r") as f:
+ text = f.read()
+ # Extract hanzi from text (with duplicates)
+ hanzi_list = filter_text(text)
+ _, trad, _ = partition_hanzi(hanzi_list)
+ # Get counts of each hanzi
+ hanzi_df = get_counts(trad, variant)
+ # Get counts by grade (test case)
+ counts = granular_counts(hanzi_df, hanzi_list, variant)
+ self.assertEqual(TEST_CASES[test_case][variant], counts)
+
+ @unittest.skipIf(
+ sys.platform.startswith("win"), "Skip on Windows: test case decode issue"
+ )
+ def test_unknown_set(self):
+ """Test correct breakdown for unknown character set"""
+ variant = "Unknown"
+ for test_case in TEST_CASES.keys():
+ with open(os.path.join(TEST_ASSETS, test_case), "r") as f:
+ text = f.read()
+ # Extract hanzi from text (with duplicates)
+ hanzi_list = filter_text(text)
+ _, trad, _ = partition_hanzi(hanzi_list)
+ # Get counts of each hanzi
+ hanzi_df = get_counts(trad, variant)
+ # Get counts by grade (test case)
+ counts = granular_counts(hanzi_df, hanzi_list, variant)
+ # "Unknown" texts are counted against the traditional column
+ self.assertEqual(TEST_CASES[test_case][variant], counts)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_hanzi.py b/tests/test_hanzi.py
new file mode 100644
index 0000000..f48c176
--- /dev/null
+++ b/tests/test_hanzi.py
@@ -0,0 +1,59 @@
+import os
+import polars as pl
+import unittest
+from polars.testing import assert_frame_equal
+from src.xiwen.utils.config import ASSETS_DIR, HSK30_HANZI_SCHEMA
+from src.xiwen.utils.hanzi import get_HSKHanzi_instance, HSKHanzi
+
+
+class TestHSKHanzi(unittest.TestCase):
+ def test_references_to_HSKHanzi(self):
+ """Test separate references are equal"""
+ A = HSKHanzi()
+ B = HSKHanzi()
+ self.assertEqual(A, B)
+
+ def test_one_HSKHanzi_exists(self):
+ """Test just one instance exists despite multiple calls"""
+ A = HSKHanzi()
+ B = HSKHanzi()
+ self.assertIs(A, B._instance)
+
+ def test_HSKHanzi_attributes_exist(self):
+ """Test HSKHanzi has expected attributes"""
+ A = HSKHanzi()
+ self.assertTrue(hasattr(A, "HSK_HANZI"))
+ self.assertTrue(hasattr(A, "HSK_SIMP"))
+ self.assertTrue(hasattr(A, "HSK_TRAD"))
+
+ def test_HSKHanzi_attributes_correct(self):
+ """Test class attributes match expected dataframes and lists"""
+ HSK_HANZI = pl.read_parquet(
+ os.path.join(ASSETS_DIR, "hsk30_hanzi.parquet"),
+ hive_schema=HSK30_HANZI_SCHEMA,
+ )
+ A = HSKHanzi().HSK_HANZI
+ self.assertIsNone(assert_frame_equal(HSK_HANZI, A))
+ HSK_SIMP = HSK_HANZI.select("Simplified").to_series().to_list()
+ B = HSKHanzi().HSK_SIMP
+ self.assertEqual(HSK_SIMP, B)
+ HSK_TRAD = HSK_HANZI.select("Traditional").to_series().to_list()
+ C = HSKHanzi().HSK_TRAD
+ self.assertEqual(HSK_TRAD, C)
+
+
+class TestGetHSKHanziInstance(unittest.TestCase):
+ def test_instance_returned(self):
+ """Test function returns an HSKHanzi instance"""
+ A = get_HSKHanzi_instance()
+ self.assertTrue(isinstance(A, HSKHanzi))
+
+ def test_returns_same_instance(self):
+ """Test multiple calls return same HSKHanzi instance"""
+ A = get_HSKHanzi_instance()
+ B = get_HSKHanzi_instance()
+ self.assertIs(A, B._instance)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_pinyin.py b/tests/test_pinyin.py
index a726af0..972c645 100644
--- a/tests/test_pinyin.py
+++ b/tests/test_pinyin.py
@@ -3,7 +3,7 @@
from src.xiwen.utils.pinyin import get_pinyin, map_pinyin
PINYIN_PATH = os.path.join(
- os.getcwd(), "src", "file_prep", "assets", "hanzi_pinyin_characters.tsv.txt"
+ os.getcwd(), "src", "resources", "assets", "hanzi_pinyin_characters.tsv.txt"
)
SIMP_HANZI_TO_PINYIN = {
diff --git a/tests/test_transform.py b/tests/test_transform.py
index 6021af6..c35a3e8 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -28,8 +28,8 @@
class TestPartitionHanzi(unittest.TestCase):
def test_partition(self):
"""Test characters are separated appropriately"""
- simp = ["爱", "气", "车", "电", "话", "点", "脑", "视", "东"]
- trad = ["愛", "氣", "車", "電", "話", "點", "腦", "視", "東"]
+ simp = ["爱", "气", "车", "电", "话", "点", "脑", "视", "东", "不", "了"]
+ trad = ["愛", "氣", "車", "電", "話", "點", "腦", "視", "東", "不", "了"]
test = [
"爱",
"气",
@@ -49,10 +49,12 @@ def test_partition(self):
"腦",
"視",
"東",
+ "不",
+ "了",
]
self.assertEqual(partition_hanzi(test), (simp, trad, []))
- self.assertEqual(partition_hanzi(simp), (simp, [], []))
- self.assertEqual(partition_hanzi(trad), ([], trad, []))
+ self.assertEqual(partition_hanzi(simp), (simp, ["不", "了"], []))
+ self.assertEqual(partition_hanzi(trad), (["不", "了"], trad, []))
test = [
"爱",
"气",
@@ -72,6 +74,8 @@ def test_partition(self):
"腦",
"視",
"東",
+ "不",
+ "了",
"朕",
]
self.assertEqual(partition_hanzi(test), (simp, trad, ["朕"]))