diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index bdde5d4..2ecf449 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -40,7 +40,7 @@ def test_estimate_fee(self): {'role': 'user', 'content': 'Hello'}, ] fee = bot.estimate_fee(messages) - assert isclose(fee, 6e-06) + self.assertTrue(isclose(fee, 6e-06)) def test_gpt_update_fee(self): bot = self.gpt_bot @@ -55,8 +55,7 @@ def test_gpt_update_fee(self): bot.api_fees += [0] response3 = OpenAIResponse(usage=OpenAIUsage(prompt_tokens=300, completion_tokens=600, total_tokens=900)) bot.update_fee(response3) - - assert bot.api_fees == [0.00035, 0.0007, 0.00105] + self.assertListEqual(bot.api_fees, [0.00035, 0.0007, 0.00105]) def test_claude_update_fee(self): bot = self.claude_bot @@ -72,7 +71,7 @@ def test_claude_update_fee(self): response3 = OpenAIResponse(usage=AnthropicUsage(input_tokens=300, output_tokens=600)) bot.update_fee(response3) - assert bot.api_fees == [0.0033, 0.0066, 0.0099] + self.assertListEqual(bot.api_fees, [0.0033, 0.0066, 0.0099]) def test_gpt_message_async(self): bot = self.gpt_bot @@ -85,7 +84,8 @@ def test_gpt_message_async(self): ], ] results = bot.message(messages_list) - assert all(['hello' in bot.get_content(r).lower() for r in results]) + + self.assertTrue(all(['hello' in bot.get_content(r).lower() for r in results])) def test_claude_message_async(self): bot = self.claude_bot @@ -98,7 +98,8 @@ def test_claude_message_async(self): ], ] results = bot.message(messages_list) - assert all(['hello' in bot.get_content(r).lower() for r in results]) + + self.assertTrue(all(['hello' in bot.get_content(r).lower() for r in results])) def test_gpt_message_seq(self): bot = self.gpt_bot @@ -108,7 +109,8 @@ def test_gpt_message_seq(self): ] ] results = bot.message(messages_list) - assert 'hello' in bot.get_content(results[0]).lower() + + self.assertIn('hello', bot.get_content(results[0]).lower()) def test_claude_message_seq(self): bot = self.claude_bot @@ -119,3 +121,5 @@ def test_claude_message_seq(self): ] results = bot.message(messages_list) assert 'hello' in bot.get_content(results[0]).lower() + + self.assertIn('hello', bot.get_content(results[0]).lower()) diff --git a/tests/test_context.py b/tests/test_context.py index c63932d..13cd993 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023. Hao Zheng +# Copyright (C) 2024. Hao Zheng # All rights reserved. import tempfile @@ -15,10 +15,10 @@ def setUp(self) -> None: def test_init(self): context = self.context - assert context.background == 'test background' - assert context.audio_type == 'test audio type' - assert context.description_map == {'test audio name': 'description'} - assert context.config_path is None + self.assertEqual(context.background, 'test background') + self.assertEqual(context.audio_type, 'test audio type') + self.assertEqual(context.description_map, {'test audio name': 'description'}) + self.assertIsNone(context.config_path) def test_init_with_config_file(self): with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: @@ -28,10 +28,10 @@ def test_init_with_config_file(self): context = Context(config_path=config_path) - assert context.background == 'config background' - assert context.audio_type == 'config audio type' - assert context.description_map == {'config': 'config description'} - assert context.config_path == config_path + self.assertEqual(context.background, 'config background') + self.assertEqual(context.audio_type, 'config audio type') + self.assertEqual(context.description_map, {'config': 'config description'}) + self.assertEqual(context.config_path, config_path) config_path.unlink() @@ -48,10 +48,10 @@ def test_load_config(self): context.load_config(config_path) - assert context.background == 'config background' - assert context.audio_type == 'config audio type' - assert context.description_map == {'config': 'config description'} - assert context.config_path == config_path + self.assertEqual(context.background, 'config background') + self.assertEqual(context.audio_type, 'config audio type') + self.assertEqual(context.description_map, {'config': 'config description'}) + self.assertEqual(context.config_path, config_path) config_path.unlink() @@ -66,17 +66,17 @@ def test_save_config(self): with open(config_path, 'r') as file: config = file.read() - assert 'background: test background' in config - assert 'audio_type: test audio type' in config - assert 'description_map:\n test audio name: description' in config + self.assertIn('background: test background', config) + self.assertIn('audio_type: test audio type', config) + self.assertIn('description_map:\n test audio name: description', config) config_path.unlink() def test_get_description(self): context = self.context - assert context.get_description('test audio name') == 'description' - assert context.get_description('audio name without description') == '' + self.assertEqual(context.get_description('test audio name'), 'description') + self.assertEqual(context.get_description('audio name without description'), '') def test_str(self): - assert str(self.context) == \ - 'Context(background=test background, audio_type=test audio type, description_map={\'test audio name\': \'description\'})' + self.assertEqual(str(self.context), + 'Context(background=test background, audio_type=test audio type, description_map={\'test audio name\': \'description\'})') diff --git a/tests/test_opt.py b/tests/test_opt.py index c00e91f..1c1af9f 100644 --- a/tests/test_opt.py +++ b/tests/test_opt.py @@ -18,44 +18,44 @@ def test_merge_same(self): original_len = len(subtitle) optimizer = SubtitleOptimizer(subtitle) optimizer.merge_same() - assert len(optimizer.subtitle.segments) == original_len - 1 + self.assertEqual(len(optimizer.subtitle.segments), original_len - 1) def test_merge_short(self): subtitle = self.subtitle original_len = len(subtitle) optimizer = SubtitleOptimizer(subtitle) optimizer.merge_short() - assert len(optimizer.subtitle.segments) == original_len - 1 + self.assertEqual(len(optimizer.subtitle.segments), original_len - 1) def test_merge_repeat(self): subtitle = self.subtitle optimizer = SubtitleOptimizer(subtitle) optimizer.merge_repeat() - assert optimizer.subtitle.segments[2].text == '好好...' + self.assertEqual(optimizer.subtitle.segments[2].text, '好好...') def test_cut_long(self): subtitle = self.subtitle optimizer = SubtitleOptimizer(subtitle) optimizer.cut_long(keep=2) - assert optimizer.subtitle.segments[4].text == '这太' + self.assertEqual(optimizer.subtitle.segments[4].text, '这太') def test_traditional2mandarin(self): subtitle = self.subtitle optimizer = SubtitleOptimizer(subtitle) optimizer.traditional2mandarin() - assert optimizer.subtitle.segments[5].text == '繁体的字' + self.assertEqual(optimizer.subtitle.segments[5].text, '繁体的字') def test_punctuation_optimization(self): subtitle = self.subtitle optimizer = SubtitleOptimizer(subtitle) optimizer.punctuation_optimization() - assert optimizer.subtitle.segments[0].text == '你好,你好...你好!你好。' + self.assertEqual(optimizer.subtitle.segments[0].text, '你好,你好...你好!你好。') def test_remove_unk(self): subtitle = self.subtitle optimizer = SubtitleOptimizer(subtitle) optimizer.remove_unk() - assert optimizer.subtitle.segments[6].text == 'unk' + self.assertEqual(optimizer.subtitle.segments[6].text, 'unk') def test_remove_empty(self): subtitle = self.subtitle @@ -63,7 +63,7 @@ def test_remove_empty(self): original_len = len(subtitle) optimizer = SubtitleOptimizer(subtitle) optimizer.remove_empty() - assert len(optimizer.subtitle.segments) == original_len - 1 + self.assertEqual(len(optimizer.subtitle.segments), original_len - 1) def test_save(self): subtitle = self.subtitle @@ -74,7 +74,7 @@ def test_save(self): with open('data/test_subtitle_optimized.json', 'r', encoding='utf-8') as f: optimized_subtitle = json.load(f) - assert optimized_subtitle['language'] == 'zh' - assert len(optimized_subtitle['segments']) == 7 + self.assertEqual(optimized_subtitle['language'], 'zh') + self.assertEqual(len(optimized_subtitle['segments']), 7) os.remove('data/test_subtitle_optimized.json') diff --git a/tests/test_prompter.py b/tests/test_prompter.py index 18a0282..94f4a4f 100644 --- a/tests/test_prompter.py +++ b/tests/test_prompter.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023. Hao Zheng +# Copyright (C) 2024. Hao Zheng # All rights reserved. import unittest @@ -46,14 +46,16 @@ def test_user_prompt(self): Original> 生き残る秘訣は、進化し続けることです。 Translation>''' - assert self.prompter.user(1, user_input, ['test chunk1 summary', 'test chunk2 summary'], - 'test scene content') == self.formatted_user_input + self.assertEqual( + self.prompter.user(1, user_input, ['test chunk1 summary', 'test chunk2 summary'], 'test scene content'), + self.formatted_user_input + ) def test_format_texts(self): texts = [(1, '変わりゆく時代において、'), (2, '生き残る秘訣は、進化し続けることです。')] expected_output = '#1\nOriginal>\n変わりゆく時代において、\nTranslation>\n\n#2\nOriginal>\n' \ '生き残る秘訣は、進化し続けることです。\nTranslation>\n' - assert BaseTranslatePrompter.format_texts(texts) == expected_output + self.assertEqual(BaseTranslatePrompter.format_texts(texts), expected_output) def test_check_format(self): messages = [{'role': 'system', 'content': 'system content'}, @@ -82,4 +84,4 @@ def test_check_format(self): Summary Scene ''' - assert self.prompter.check_format(messages, content) is True + self.assertTrue(self.prompter.check_format(messages, content)) diff --git a/tests/test_translate.py b/tests/test_translate.py index 49122da..0c55720 100644 --- a/tests/test_translate.py +++ b/tests/test_translate.py @@ -25,17 +25,15 @@ def test_single_chunk_translation(self): translator = LLMTranslator(chatbot_model) translation = translator.translate(text, 'en', 'es')[0] - assert get_similarity(translation, 'Hola, ¿cómo estás?') > 0.618 - self.tearDown() + self.assertGreater(get_similarity(translation, 'Hola, ¿cómo estás?'), 0.618) def test_multiple_chunk_translation(self): for chatbot_model in test_models: texts = ['Hello, how are you?', 'I am fine, thank you.'] translator = LLMTranslator(chatbot_model) translations = translator.translate(texts, 'en', 'es') - assert get_similarity(translations[0], 'Hola, ¿cómo estás?') > 0.618 - assert get_similarity(translations[1], 'Estoy bien, gracias.') > 0.618 - self.tearDown() + self.assertGreater(get_similarity(translations[0], 'Hola, ¿cómo estás?'), 0.618) + self.assertGreater(get_similarity(translations[1], 'Estoy bien, gracias.'), 0.618) def test_different_language_translation(self): for chatbot_model in test_models: @@ -43,28 +41,27 @@ def test_different_language_translation(self): translator = LLMTranslator(chatbot_model) try: translation = translator.translate(text, 'en', 'ja')[0] - assert (get_similarity(translation, 'こんにちは、お元気ですか?') > 0.618 or - get_similarity(translation, 'こんにちは、調子はどうですか?') > 0.618) + self.assertTrue( + get_similarity(translation, 'こんにちは、お元気ですか?') > 0.618 or + get_similarity(translation, 'こんにちは、調子はどうですか?') > 0.618 + ) except (openai.OpenAIError, anthropic.APIError): pass - self.tearDown() def test_empty_text_list_translation(self): for chatbot_model in test_models: texts = [] translator = LLMTranslator(chatbot_model) translations = translator.translate(texts, 'en', 'es') - assert translations == [] - self.tearDown() + self.assertEqual(translations, []) def test_atomic_translate(self): for chatbot_model in test_models: texts = ['Hello, how are you?', 'I am fine, thank you.'] translator = LLMTranslator(chatbot_model) translations = translator.atomic_translate(texts, 'en', 'zh') - assert get_similarity(translations[0], '你好,你好吗?') > 0.618 - assert get_similarity(translations[1], '我很好,谢谢。') > 0.618 - self.tearDown() + self.assertGreater(get_similarity(translations[0], '你好,你好吗?'), 0.618) + self.assertGreater(get_similarity(translations[1], '我很好,谢谢。'), 0.618) # Not integrated by the openlrc main function because of performance # diff --git a/tests/test_utils.py b/tests/test_utils.py index a8f85f3..b5aca4f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,73 +20,60 @@ def tearDown(self) -> None: self.video_file.with_suffix('.wav').unlink(missing_ok=True) def test_extract_audio(self): - # Test extracting audio from a video file extracted_audio_file = extract_audio(self.video_file) - assert extracted_audio_file == self.video_file.with_suffix('.wav') + self.assertEqual(extracted_audio_file, self.video_file.with_suffix('.wav')) - # Test extracting audio from an audio file extracted_audio_file = extract_audio(self.audio_file) - assert extracted_audio_file == self.audio_file + self.assertEqual(extracted_audio_file, self.audio_file) - # Test extracting audio from an unsupported file type with self.assertRaises(RuntimeError): extract_audio(self.unsupported) def test_get_file_type(self): - # Test getting the file type of video file - file_type = get_file_type(self.video_file) - assert file_type == 'video' + self.assertEqual(get_file_type(self.video_file), 'video') + self.assertEqual(get_file_type(self.audio_file), 'audio') - # Test getting the file type of audio file - file_type = get_file_type(self.audio_file) - assert file_type == 'audio' - - # Test getting the file type of unsupported file type with self.assertRaises(RuntimeError): get_file_type(self.unsupported) def test_lrc_format(self): - assert format_timestamp(1.2345, 'lrc') == '00:01.23' - assert format_timestamp(61.2345, 'lrc') == '01:01.23' - assert format_timestamp(3661.2345, 'lrc') == '01:01.23' + self.assertEqual(format_timestamp(1.2345, 'lrc'), '00:01.23') + self.assertEqual(format_timestamp(61.2345, 'lrc'), '01:01.23') + self.assertEqual(format_timestamp(3661.2345, 'lrc'), '01:01.23') - assert parse_timestamp('1:23.45', 'lrc') == 83.45 - assert parse_timestamp('0:00.01', 'lrc') == 0.01 - assert parse_timestamp('10:00.00', 'lrc') == 600.0 + self.assertEqual(parse_timestamp('1:23.45', 'lrc'), 83.45) + self.assertEqual(parse_timestamp('0:00.01', 'lrc'), 0.01) + self.assertEqual(parse_timestamp('10:00.00', 'lrc'), 600.0) def test_srt_format(self): - assert format_timestamp(1.2345, 'srt') == '00:00:01,234' - assert format_timestamp(61.2345, 'srt') == '00:01:01,234' - assert format_timestamp(3661.2345, 'srt') == '01:01:01,234' + self.assertEqual(format_timestamp(1.2345, 'srt'), '00:00:01,234') + self.assertEqual(format_timestamp(61.2345, 'srt'), '00:01:01,234') + self.assertEqual(format_timestamp(3661.2345, 'srt'), '01:01:01,234') - assert parse_timestamp('01:23:45,678', 'srt') == 5025.678 - assert parse_timestamp('00:00:01,000', 'srt') == 1.0 - assert parse_timestamp('01:00:00,000', 'srt') == 3600.0 + self.assertEqual(parse_timestamp('01:23:45,678', 'srt'), 5025.678) + self.assertEqual(parse_timestamp('00:00:01,000', 'srt'), 1.0) + self.assertEqual(parse_timestamp('01:00:00,000', 'srt'), 3600.0) def test_negative_timestamp(self): with self.assertRaises(AssertionError): format_timestamp(-1.2345, 'lrc') - with self.assertRaises(ValueError): parse_timestamp('-1:23.45', 'lrc') - with self.assertRaises(AssertionError): format_timestamp(-1.2345, 'srt') - with self.assertRaises(ValueError): parse_timestamp('-01:23:45,678', 'srt') def test_invalid_format(self): with self.assertRaises(ValueError): format_timestamp(1.2345, 'invalid') - with self.assertRaises(ValueError): parse_timestamp('1:23.45', 'invalid') def test_get_text_token_number(self): - assert get_text_token_number('Hello, world!') == 4 - assert get_text_token_number('This is a longer sentence.') == 6 - assert get_text_token_number('') == 0 + self.assertEqual(get_text_token_number('Hello, world!'), 4) + self.assertEqual(get_text_token_number('This is a longer sentence.'), 6) + self.assertEqual(get_text_token_number(''), 0) def test_get_messages_token_number(self): messages = [ @@ -94,7 +81,7 @@ def test_get_messages_token_number(self): {'content': 'This is a longer sentence.'}, {'content': ''} ] - assert get_messages_token_number(messages) == 10 + self.assertEqual(get_messages_token_number(messages), 10) messages = [ {'content': 'Hello, world!'}, @@ -102,37 +89,36 @@ def test_get_messages_token_number(self): {'content': ''}, {'content': 'Another message.'} ] - assert get_messages_token_number(messages) == 13 + self.assertEqual(get_messages_token_number(messages), 13) def test_extend_filename(self): - assert extend_filename(Path('file.txt'), '_new') == Path('file_new.txt') - assert extend_filename(Path('file.txt'), '') == Path('file.txt') + self.assertEqual(extend_filename(Path('file.txt'), '_new'), Path('file_new.txt')) + self.assertEqual(extend_filename(Path('file.txt'), ''), Path('file.txt')) def test_release_memory(self): model = torch.nn.Module() if torch.cuda.is_available(): model.cuda() release_memory(model) - assert torch.cuda.memory_allocated() == 0 + self.assertEqual(torch.cuda.memory_allocated(), 0) def test_normalize(self): - # Full width to half width & lower case alphabet_fw = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' alphabet_hw = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz' - assert normalize(alphabet_fw) == alphabet_hw + self.assertEqual(normalize(alphabet_fw), alphabet_hw) number_fw = '0123456789' number_hw = '0123456789' - assert normalize(number_fw) == number_hw + self.assertEqual(normalize(number_fw), number_hw) sign_fw = '!#$%&()*+,-./:;<=>?@[]^_`{|}”’¥~' sign_hw = '!#$%&()*+,-./:;<=>?@[]^_`{|}"\'¥~' - assert normalize(sign_fw) == sign_hw + self.assertEqual(normalize(sign_fw), sign_hw) kana_fw = 'アイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワヲンガギグゲゴザジズゼゾダヂヅデドバビブベボパピプペポヴァィゥェォッャュョ・ー、。・「」' kana_hw = 'アイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワヲンガギグゲゴザジズゼゾダヂヅデドバビブベボパピプペポヴァィゥェォッャュョ・ー、。・「」' - assert normalize(kana_fw) == kana_hw + self.assertEqual(normalize(kana_fw), kana_hw) space_fw = ' ' # Full-width space (U+3000) space_hw = ' ' # Half-width space (U+0020) - assert normalize(space_fw) == space_hw + self.assertEqual(normalize(space_fw), space_hw)