Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements on ContextReviewer agent & Minor tweaks. #46

Merged
merged 5 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ jobs:

runs-on: ${{ matrix.os }}
steps:
- name: Delete huge unnecessary tools folder
run: rm -rf /opt/hostedtoolcache

- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ if __name__ == '__main__':
lrcer.run('./data/test.mp3', target_lang='zh-cn')

# Clear temp folder after processing done
lrcer.run('./data/test.mp3', target_lang='zh-cn', clear_temp_folder=True)
lrcer.run('./data/test.mp3', target_lang='zh-cn', clear_temp=True)

# Change base_url
lrcer = LRCer(base_url_config={'openai': 'https://api.g4f.icu/v1',
Expand Down
60 changes: 53 additions & 7 deletions openlrc/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from openlrc.chatbot import route_chatbot, GPTBot, ClaudeBot
from openlrc.context import TranslationContext, TranslateInfo
from openlrc.logger import logger
from openlrc.prompter import BaseTranslatePrompter, ContextReviewPrompter, POTENTIAL_PREFIX_COMBOS, \
ProofreaderPrompter, PROOFREAD_PREFIX
from openlrc.prompter import ChunkedTranslatePrompter, ContextReviewPrompter, ProofreaderPrompter, PROOFREAD_PREFIX, \
ContextReviewerValidatePrompter
from openlrc.validators import POTENTIAL_PREFIX_COMBOS


class Agent(abc.ABC):
Expand All @@ -32,13 +33,13 @@ class ChunkedTranslatorAgent(Agent):
TEMPERATURE = 1.0

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.3, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.chatbot_model = chatbot_model
self.info = info
self.chatbot = self._initialize_chatbot(chatbot_model, fee_limit, proxy, base_url_config)
self.prompter = BaseTranslatePrompter(src_lang, target_lang, info)
self.prompter = ChunkedTranslatePrompter(src_lang, target_lang, info)
self.cost = 0

def __str__(self):
Expand Down Expand Up @@ -106,30 +107,75 @@ class ContextReviewerAgent(Agent):
TODO: Add chunking support.
"""

TEMPERATURE = 0.8
TEMPERATURE = 0.6

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', retry_model=None,
fee_limit: float = 0.3, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.src_lang = src_lang
self.target_lang = target_lang
self.info = info
self.chatbot_model = chatbot_model
self.validate_prompter = ContextReviewerValidatePrompter()
self.prompter = ContextReviewPrompter(src_lang, target_lang)
self.chatbot = self._initialize_chatbot(chatbot_model, fee_limit, proxy, base_url_config)
self.retry_chatbot = self._initialize_chatbot(
retry_model, fee_limit, proxy, base_url_config
) if retry_model else None

def __str__(self):
    """Human-readable label that includes the backing chatbot model name."""
    return 'Context Reviewer Agent ({})'.format(self.chatbot_model)

def _validate_context(self, context: str) -> bool:
    """Ask the validator chatbot whether *context* looks acceptable.

    Builds a two-message conversation from the validate prompter and returns
    True when the validator's reply contains 'true' (case-insensitive).
    """
    conversation = [
        {'role': role, 'content': content}
        for role, content in (
            ('system', self.validate_prompter.system()),
            ('user', self.validate_prompter.user(context)),
        )
    ]
    reply = self.chatbot.message(conversation, output_checker=self.validate_prompter.check_format)[0]
    verdict = self.chatbot.get_content(reply)
    return 'true' in verdict.lower()

def build_context(self, texts, title='', glossary: Optional[dict] = None) -> str:
    """Generate a translation context (glossary, tone, synopsis, ...) for *texts*.

    The context is generated with ``self.prompter`` and then checked with
    ``self._validate_context``. On validation failure it retries: first with
    ``self.retry_chatbot`` (if configured), then up to two more times with the
    primary chatbot. If every candidate fails validation, the longest candidate
    seen is returned as a best effort.

    Args:
        texts: Iterable of subtitle lines to build the context from.
        title: Optional media title forwarded to the prompter.
        glossary: Optional user-provided glossary forwarded to the prompter.

    Returns:
        The generated (preferably validated) context string.
    """
    text_content = '\n'.join(texts)

    messages_list = [
        {'role': 'system', 'content': self.prompter.system()},
        {'role': 'user', 'content': self.prompter.user(text_content, title=title, given_glossary=glossary)},
    ]
    resp = self.chatbot.message(messages_list, output_checker=self.prompter.check_format)[0]
    context = self.chatbot.get_content(resp)

    context_pool = [context]
    # Validate; on failure, regenerate with the retry model and then with repeats.
    if not self._validate_context(context):
        validated = False
        if self.retry_chatbot:
            logger.info(f'Failed to validate the context using {self.chatbot}, retrying with {self.retry_chatbot}')
            # Fix: regenerated contexts come from the context prompter's messages,
            # so they must be format-checked with self.prompter.check_format
            # (the original wrongly used the validator prompter's checker here).
            resp = self.retry_chatbot.message(messages_list, output_checker=self.prompter.check_format)[0]
            context = self.retry_chatbot.get_content(resp)
            context_pool.append(context)
            if self._validate_context(context):
                validated = True
            else:
                logger.warning(f'Failed to validate the context using {self.retry_chatbot}: {context}')

        if not validated:
            for i in range(2, 4):
                logger.warning(f'Retry to generate the context using {self.chatbot} at {i} retries.')
                resp = self.chatbot.message(messages_list, output_checker=self.prompter.check_format)[0]
                context = self.chatbot.get_content(resp)
                context_pool.append(context)
                if self._validate_context(context):
                    validated = True
                    break

        if not validated:
            logger.warning(
                f'Finally failed to validate the context: {context}, you may check the context manually.')
            # Best effort: the longest candidate usually carries the most information.
            context = max(context_pool, key=len)
            logger.info(f'Now using the longest context: {context}')

    return context


Expand All @@ -140,7 +186,7 @@ class ProofreaderAgent(Agent):
TEMPERATURE = 0.8

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.25, proxy: str = None,
chatbot_model: str = 'gpt-3.5-turbo', fee_limit: float = 0.3, proxy: str = None,
base_url_config: Optional[dict] = None):
super().__init__()
self.src_lang = src_lang
Expand Down
7 changes: 5 additions & 2 deletions openlrc/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def route_chatbot(model):
class ChatBot:
pricing = None

def __init__(self, pricing, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.25):
def __init__(self, pricing, temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.3):
self.pricing = pricing
self._model = None

Expand Down Expand Up @@ -155,6 +155,9 @@ def message(self, messages_list: Union[List[Dict], List[List[Dict]]],

return results

def __str__(self):
    """Human-readable label that includes the underlying model name."""
    return 'ChatBot ({})'.format(self.model)


@_register_chatbot
class GPTBot(ChatBot):
Expand Down Expand Up @@ -251,7 +254,7 @@ class ClaudeBot(ChatBot):
'claude-3-5-sonnet-20240620': (3, 15),
}

def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.25,
def __init__(self, model='claude-3-sonnet-20240229', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.3,
proxy=None, base_url_config=None):

# clamp temperature to 0-1
Expand Down
34 changes: 22 additions & 12 deletions openlrc/openlrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ class LRCer:
retry_model: The model to use when retrying the translation. Default: None
"""

def __init__(self, whisper_model='large-v3', compute_type='float16', device='cuda', chatbot_model: str = 'gpt-3.5-turbo',
fee_limit=0.25, consumer_thread=4, asr_options=None, vad_options=None, preprocess_options=None,
proxy=None, base_url_config=None, glossary: Union[dict, str, Path] = None, retry_model=None):
def __init__(self, whisper_model='large-v3', compute_type='float16', device='cuda',
chatbot_model: str = 'gpt-3.5-turbo', fee_limit=0.3, consumer_thread=4, asr_options=None,
vad_options=None, preprocess_options=None, proxy=None, base_url_config=None,
glossary: Union[dict, str, Path] = None, retry_model=None):
self.chatbot_model = chatbot_model
self.fee_limit = fee_limit
self.api_fee = 0 # Can be updated in different thread, operation should be thread-safe
Expand Down Expand Up @@ -181,7 +182,9 @@ def consumer_worker(self, transcription_queue, target_lang, skip_trans, bilingua
return

# Copy preprocessed/xxx_preprocessed.lrc or preprocessed/xxx_preprocessed.srt to xxx.lrc or xxx.srt
subtitle_format = 'srt' if audio_name in self.from_video else 'lrc'
original_name_wo_suffix = transcribed_path.parents[
1] / f"{transcribed_path.name.replace('_preprocessed_transcribed.json', '')}"
subtitle_format = 'srt' if original_name_wo_suffix in self.from_video else 'lrc'
subtitle_path = getattr(final_subtitle, f'to_{subtitle_format}')()
result_path = subtitle_path.parents[1] / subtitle_path.name.replace(f'_preprocessed.{subtitle_format}',
f'.{subtitle_format}')
Expand All @@ -191,6 +194,8 @@ def consumer_worker(self, transcription_queue, target_lang, skip_trans, bilingua
bilingual_subtitle = BilingualSubtitle.from_preprocessed(
transcribed_path.parent, audio_name.replace('_preprocessed', '')
)
bilingual_optimizer = SubtitleOptimizer(bilingual_subtitle)
bilingual_optimizer.extend_time()
# TODO: consider the edge case (audio file name contains _preprocessed)
getattr(bilingual_subtitle, f'to_{subtitle_format}')()
bilingual_lrc_path = bilingual_subtitle.filename.with_suffix(bilingual_subtitle.suffix)
Expand Down Expand Up @@ -250,7 +255,7 @@ def _translate(self, audio_name, target_lang, transcribed_opt_sub, translated_pa
return final_subtitle

def run(self, paths: Union[str, Path, List[Union[str, Path]]], src_lang: Optional[str] = None, target_lang='zh-cn',
skip_trans=False, noise_suppress=False, bilingual_sub=False, clear_temp_folder=False) -> List[str]:
skip_trans=False, noise_suppress=False, bilingual_sub=False, clear_temp=False) -> List[str]:
"""
Split the translation into 2 phases: transcription and translation. They're running in parallel.
Firstly, transcribe the audios one-by-one. At the same time, translation threads are created and waiting for
Expand All @@ -264,7 +269,7 @@ def run(self, paths: Union[str, Path, List[Union[str, Path]]], src_lang: Optiona
skip_trans (bool): Whether to skip the translation process. (Default to False)
noise_suppress (bool): Whether to suppress the noise in the audio. (Default to False)
bilingual_sub (bool): Whether to generate bilingual subtitles. (Default to False)
clear_temp_folder (bool): Whether to clear the temporary folder.
clear_temp (bool): Whether to clear all the temporary files, including the generated .wav from video.
Note, set this back to False to see more intermediate results if error encountered. (Default to False)

Returns:
Expand Down Expand Up @@ -305,14 +310,13 @@ def run(self, paths: Union[str, Path, List[Union[str, Path]]], src_lang: Optiona

logger.info(f'Totally used API fee: {self.api_fee:.4f} USD')

if clear_temp_folder:
if clear_temp:
logger.info('Clearing temporary folder...')
self.clear_temp_files(audio_paths)

return self.transcribed_paths

@staticmethod
def clear_temp_files(paths):
def clear_temp_files(self, paths):
"""
Clear the temporary files generated during the transcription and translation process.
"""
Expand All @@ -323,6 +327,12 @@ def clear_temp_files(paths):
shutil.rmtree(folder)
logger.debug(f'Removed {folder}')

for input_video_path in self.from_video:
generated_wave = input_video_path.with_suffix('.wav')
if generated_wave.exists():
generated_wave.unlink()
logger.debug(f'Removed generated wav (from video): {generated_wave}')

@staticmethod
def to_json(segments: List[Segment], name, lang):
result = {
Expand Down Expand Up @@ -352,10 +362,10 @@ def pre_process(self, paths, noise_suppress=False):
if not path.exists() or not path.is_file():
raise FileNotFoundError(f'File not found: {path}')

paths[i] = extract_audio(path)

if get_file_type(path) == 'video':
self.from_video.add(path.stem + '_preprocessed')
self.from_video.add(path.with_suffix(''))

paths[i] = extract_audio(path)

# Audio-based process
preprocessor = Preprocessor(paths, options=self.preprocess_options)
Expand Down
12 changes: 10 additions & 2 deletions openlrc/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import zhconv

from openlrc.logger import logger
from openlrc.subtitle import Subtitle
from openlrc.subtitle import Subtitle, BilingualSubtitle
from openlrc.utils import extend_filename, format_timestamp

# Thresholds for different languages
Expand Down Expand Up @@ -42,7 +42,7 @@ class SubtitleOptimizer:
SubtitleOptimizer class is used to optimize subtitles by performing various operations.
"""

def __init__(self, subtitle: Union[Path, Subtitle]):
def __init__(self, subtitle: Union[Path, Subtitle, BilingualSubtitle]):
if isinstance(subtitle, Path):
subtitle = Subtitle.from_json(subtitle)

Expand Down Expand Up @@ -139,6 +139,10 @@ def cut_long(self, max_length=20):
"""
Cut long texts based on language-specific thresholds.
"""
if isinstance(self.subtitle, BilingualSubtitle):
logger.warning('Bilingual subtitle is not supported for cut_long operation.')
return

threshold = CUT_LONG_THRESHOLD.get(self.lang.lower(), 150)

for element in self.subtitle.segments:
Expand All @@ -157,6 +161,10 @@ def punctuation_optimization(self):
"""
Replace English punctuation with Chinese punctuation.
"""
if isinstance(self.subtitle, BilingualSubtitle):
logger.warning('Bilingual subtitle is not supported for punctuation_optimization operation.')
return

for element in self.subtitle.segments:
element.text = self._replace_punctuation_with_chinese(element.text)

Expand Down
Loading
Loading