Skip to content

Commit

Permalink
Replace manually assert with unittest assertion.
Browse files Browse the repository at this point in the history
  • Loading branch information
zh-plus committed May 11, 2024
1 parent 29d5417 commit 3b470d3
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 98 deletions.
18 changes: 11 additions & 7 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_estimate_fee(self):
{'role': 'user', 'content': 'Hello'},
]
fee = bot.estimate_fee(messages)
assert isclose(fee, 6e-06)
self.assertTrue(isclose(fee, 6e-06))

def test_gpt_update_fee(self):
bot = self.gpt_bot
Expand All @@ -55,8 +55,7 @@ def test_gpt_update_fee(self):
bot.api_fees += [0]
response3 = OpenAIResponse(usage=OpenAIUsage(prompt_tokens=300, completion_tokens=600, total_tokens=900))
bot.update_fee(response3)

assert bot.api_fees == [0.00035, 0.0007, 0.00105]
self.assertListEqual(bot.api_fees, [0.00035, 0.0007, 0.00105])

def test_claude_update_fee(self):
bot = self.claude_bot
Expand All @@ -72,7 +71,7 @@ def test_claude_update_fee(self):
response3 = OpenAIResponse(usage=AnthropicUsage(input_tokens=300, output_tokens=600))
bot.update_fee(response3)

assert bot.api_fees == [0.0033, 0.0066, 0.0099]
self.assertListEqual(bot.api_fees, [0.0033, 0.0066, 0.0099])

def test_gpt_message_async(self):
bot = self.gpt_bot
Expand All @@ -85,7 +84,8 @@ def test_gpt_message_async(self):
],
]
results = bot.message(messages_list)
assert all(['hello' in bot.get_content(r).lower() for r in results])

self.assertTrue(all(['hello' in bot.get_content(r).lower() for r in results]))

def test_claude_message_async(self):
bot = self.claude_bot
Expand All @@ -98,7 +98,8 @@ def test_claude_message_async(self):
],
]
results = bot.message(messages_list)
assert all(['hello' in bot.get_content(r).lower() for r in results])

self.assertTrue(all(['hello' in bot.get_content(r).lower() for r in results]))

def test_gpt_message_seq(self):
bot = self.gpt_bot
Expand All @@ -108,7 +109,8 @@ def test_gpt_message_seq(self):
]
]
results = bot.message(messages_list)
assert 'hello' in bot.get_content(results[0]).lower()

self.assertIn('hello', bot.get_content(results[0]).lower())

def test_claude_message_seq(self):
bot = self.claude_bot
Expand All @@ -119,3 +121,5 @@ def test_claude_message_seq(self):
]
results = bot.message(messages_list)
assert 'hello' in bot.get_content(results[0]).lower()

self.assertIn('hello', bot.get_content(results[0]).lower())
40 changes: 20 additions & 20 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2023. Hao Zheng
# Copyright (C) 2024. Hao Zheng
# All rights reserved.

import tempfile
Expand All @@ -15,10 +15,10 @@ def setUp(self) -> None:

def test_init(self):
context = self.context
assert context.background == 'test background'
assert context.audio_type == 'test audio type'
assert context.description_map == {'test audio name': 'description'}
assert context.config_path is None
self.assertEqual(context.background, 'test background')
self.assertEqual(context.audio_type, 'test audio type')
self.assertEqual(context.description_map, {'test audio name': 'description'})
self.assertIsNone(context.config_path)

def test_init_with_config_file(self):
with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
Expand All @@ -28,10 +28,10 @@ def test_init_with_config_file(self):

context = Context(config_path=config_path)

assert context.background == 'config background'
assert context.audio_type == 'config audio type'
assert context.description_map == {'config': 'config description'}
assert context.config_path == config_path
self.assertEqual(context.background, 'config background')
self.assertEqual(context.audio_type, 'config audio type')
self.assertEqual(context.description_map, {'config': 'config description'})
self.assertEqual(context.config_path, config_path)

config_path.unlink()

Expand All @@ -48,10 +48,10 @@ def test_load_config(self):

context.load_config(config_path)

assert context.background == 'config background'
assert context.audio_type == 'config audio type'
assert context.description_map == {'config': 'config description'}
assert context.config_path == config_path
self.assertEqual(context.background, 'config background')
self.assertEqual(context.audio_type, 'config audio type')
self.assertEqual(context.description_map, {'config': 'config description'})
self.assertEqual(context.config_path, config_path)

config_path.unlink()

Expand All @@ -66,17 +66,17 @@ def test_save_config(self):
with open(config_path, 'r') as file:
config = file.read()

assert 'background: test background' in config
assert 'audio_type: test audio type' in config
assert 'description_map:\n test audio name: description' in config
self.assertIn('background: test background', config)
self.assertIn('audio_type: test audio type', config)
self.assertIn('description_map:\n test audio name: description', config)

config_path.unlink()

def test_get_description(self):
context = self.context
assert context.get_description('test audio name') == 'description'
assert context.get_description('audio name without description') == ''
self.assertEqual(context.get_description('test audio name'), 'description')
self.assertEqual(context.get_description('audio name without description'), '')

def test_str(self):
assert str(self.context) == \
'Context(background=test background, audio_type=test audio type, description_map={\'test audio name\': \'description\'})'
self.assertEqual(str(self.context),
'Context(background=test background, audio_type=test audio type, description_map={\'test audio name\': \'description\'})')
20 changes: 10 additions & 10 deletions tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,52 @@ def test_merge_same(self):
original_len = len(subtitle)
optimizer = SubtitleOptimizer(subtitle)
optimizer.merge_same()
assert len(optimizer.subtitle.segments) == original_len - 1
self.assertEqual(len(optimizer.subtitle.segments), original_len - 1)

def test_merge_short(self):
subtitle = self.subtitle
original_len = len(subtitle)
optimizer = SubtitleOptimizer(subtitle)
optimizer.merge_short()
assert len(optimizer.subtitle.segments) == original_len - 1
self.assertEqual(len(optimizer.subtitle.segments), original_len - 1)

def test_merge_repeat(self):
subtitle = self.subtitle
optimizer = SubtitleOptimizer(subtitle)
optimizer.merge_repeat()
assert optimizer.subtitle.segments[2].text == '好好...'
self.assertEqual(optimizer.subtitle.segments[2].text, '好好...')

def test_cut_long(self):
subtitle = self.subtitle
optimizer = SubtitleOptimizer(subtitle)
optimizer.cut_long(keep=2)
assert optimizer.subtitle.segments[4].text == '这太'
self.assertEqual(optimizer.subtitle.segments[4].text, '这太')

def test_traditional2mandarin(self):
subtitle = self.subtitle
optimizer = SubtitleOptimizer(subtitle)
optimizer.traditional2mandarin()
assert optimizer.subtitle.segments[5].text == '繁体的字'
self.assertEqual(optimizer.subtitle.segments[5].text, '繁体的字')

def test_punctuation_optimization(self):
subtitle = self.subtitle
optimizer = SubtitleOptimizer(subtitle)
optimizer.punctuation_optimization()
assert optimizer.subtitle.segments[0].text == '你好,你好...你好!你好。'
self.assertEqual(optimizer.subtitle.segments[0].text, '你好,你好...你好!你好。')

def test_remove_unk(self):
subtitle = self.subtitle
optimizer = SubtitleOptimizer(subtitle)
optimizer.remove_unk()
assert optimizer.subtitle.segments[6].text == 'unk'
self.assertEqual(optimizer.subtitle.segments[6].text, 'unk')

def test_remove_empty(self):
subtitle = self.subtitle
subtitle.segments[0].text = ''
original_len = len(subtitle)
optimizer = SubtitleOptimizer(subtitle)
optimizer.remove_empty()
assert len(optimizer.subtitle.segments) == original_len - 1
self.assertEqual(len(optimizer.subtitle.segments), original_len - 1)

def test_save(self):
subtitle = self.subtitle
Expand All @@ -74,7 +74,7 @@ def test_save(self):
with open('data/test_subtitle_optimized.json', 'r', encoding='utf-8') as f:
optimized_subtitle = json.load(f)

assert optimized_subtitle['language'] == 'zh'
assert len(optimized_subtitle['segments']) == 7
self.assertEqual(optimized_subtitle['language'], 'zh')
self.assertEqual(len(optimized_subtitle['segments']), 7)

os.remove('data/test_subtitle_optimized.json')
12 changes: 7 additions & 5 deletions tests/test_prompter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2023. Hao Zheng
# Copyright (C) 2024. Hao Zheng
# All rights reserved.

import unittest
Expand Down Expand Up @@ -46,14 +46,16 @@ def test_user_prompt(self):
Original>
生き残る秘訣は、進化し続けることです。
Translation>'''
assert self.prompter.user(1, user_input, ['test chunk1 summary', 'test chunk2 summary'],
'test scene content') == self.formatted_user_input
self.assertEqual(
self.prompter.user(1, user_input, ['test chunk1 summary', 'test chunk2 summary'], 'test scene content'),
self.formatted_user_input
)

def test_format_texts(self):
texts = [(1, '変わりゆく時代において、'), (2, '生き残る秘訣は、進化し続けることです。')]
expected_output = '#1\nOriginal>\n変わりゆく時代において、\nTranslation>\n\n#2\nOriginal>\n' \
'生き残る秘訣は、進化し続けることです。\nTranslation>\n'
assert BaseTranslatePrompter.format_texts(texts) == expected_output
self.assertEqual(BaseTranslatePrompter.format_texts(texts), expected_output)

def test_check_format(self):
messages = [{'role': 'system', 'content': 'system content'},
Expand Down Expand Up @@ -82,4 +84,4 @@ def test_check_format(self):
<summary>Summary</summary>
<scene>Scene</scene>
'''
assert self.prompter.check_format(messages, content) is True
self.assertTrue(self.prompter.check_format(messages, content))
23 changes: 10 additions & 13 deletions tests/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,46 +25,43 @@ def test_single_chunk_translation(self):
translator = LLMTranslator(chatbot_model)
translation = translator.translate(text, 'en', 'es')[0]

assert get_similarity(translation, 'Hola, ¿cómo estás?') > 0.618
self.tearDown()
self.assertGreater(get_similarity(translation, 'Hola, ¿cómo estás?'), 0.618)

def test_multiple_chunk_translation(self):
for chatbot_model in test_models:
texts = ['Hello, how are you?', 'I am fine, thank you.']
translator = LLMTranslator(chatbot_model)
translations = translator.translate(texts, 'en', 'es')
assert get_similarity(translations[0], 'Hola, ¿cómo estás?') > 0.618
assert get_similarity(translations[1], 'Estoy bien, gracias.') > 0.618
self.tearDown()
self.assertGreater(get_similarity(translations[0], 'Hola, ¿cómo estás?'), 0.618)
self.assertGreater(get_similarity(translations[1], 'Estoy bien, gracias.'), 0.618)

def test_different_language_translation(self):
for chatbot_model in test_models:
text = 'Hello, how are you?'
translator = LLMTranslator(chatbot_model)
try:
translation = translator.translate(text, 'en', 'ja')[0]
assert (get_similarity(translation, 'こんにちは、お元気ですか?') > 0.618 or
get_similarity(translation, 'こんにちは、調子はどうですか?') > 0.618)
self.assertTrue(
get_similarity(translation, 'こんにちは、お元気ですか?') > 0.618 or
get_similarity(translation, 'こんにちは、調子はどうですか?') > 0.618
)
except (openai.OpenAIError, anthropic.APIError):
pass
self.tearDown()

def test_empty_text_list_translation(self):
for chatbot_model in test_models:
texts = []
translator = LLMTranslator(chatbot_model)
translations = translator.translate(texts, 'en', 'es')
assert translations == []
self.tearDown()
self.assertEqual(translations, [])

def test_atomic_translate(self):
for chatbot_model in test_models:
texts = ['Hello, how are you?', 'I am fine, thank you.']
translator = LLMTranslator(chatbot_model)
translations = translator.atomic_translate(texts, 'en', 'zh')
assert get_similarity(translations[0], '你好,你好吗?') > 0.618
assert get_similarity(translations[1], '我很好,谢谢。') > 0.618
self.tearDown()
self.assertGreater(get_similarity(translations[0], '你好,你好吗?'), 0.618)
self.assertGreater(get_similarity(translations[1], '我很好,谢谢。'), 0.618)

# Not integrated by the openlrc main function because of performance
#
Expand Down
Loading

0 comments on commit 3b470d3

Please sign in to comment.