diff --git a/src/ragas/metrics/base.py b/src/ragas/metrics/base.py index 8d3457e0e2..cca5996c6a 100644 --- a/src/ragas/metrics/base.py +++ b/src/ragas/metrics/base.py @@ -2,12 +2,16 @@ import asyncio import logging +import re +import nltk + import typing as t from abc import ABC, abstractmethod from collections import Counter from dataclasses import dataclass, field from enum import Enum - +from pysbd.cleaner import Cleaner +from pysbd.utils import TextSpan from pysbd import Segmenter from ragas.callbacks import ChainType, new_group @@ -15,7 +19,7 @@ from ragas.executor import is_event_loop_running from ragas.prompt import PromptMixin from ragas.run_config import RunConfig -from ragas.utils import RAGAS_SUPPORTED_LANGUAGE_CODES, deprecated +from ragas.utils import RAGAS_SUPPORTED_LANGUAGE_CODES, RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD, deprecated if t.TYPE_CHECKING: from langchain_core.callbacks import Callbacks @@ -452,16 +456,21 @@ def get_segmenter( """ Get a sentence segmenter for a given language """ + language = language.lower() if language not in RAGAS_SUPPORTED_LANGUAGE_CODES: raise ValueError( f"Language '{language}' not supported. 
Supported languages: {RAGAS_SUPPORTED_LANGUAGE_CODES.keys()}" ) - return Segmenter( - language=RAGAS_SUPPORTED_LANGUAGE_CODES[language], - clean=clean, - char_span=char_span, - ) + + if language in RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD: + return Segmenter( + language=RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD[language], + clean=clean, + char_span=char_span, + ) + else: + return NLTKSegmenter(language=language, char_span=char_span) def is_reproducable(metric: Metric) -> bool: @@ -472,3 +481,46 @@ def is_reproducable(metric: Metric) -> bool: ensembler = Ensember() + + +class NLTKSegmenter: + def __init__(self, language: str = "english", char_span: bool = False, clean: bool = False): + self.language = language.lower() + self.char_span = char_span + self.clean = clean + + def sentences_with_char_spans(self, sentences): + sent_spans = [] + prior_end_char_idx = 0 + for sent in sentences: + for match in re.finditer('{0}\s*'.format(re.escape(sent)), self.original_text): + match_str = match.group() + match_start_idx, match_end_idx = match.span() + if match_end_idx > prior_end_char_idx: + sent_spans.append( + TextSpan(match_str, match_start_idx, match_end_idx)) + prior_end_char_idx = match_end_idx + break + return sent_spans + + def cleaner(self, text): + return Cleaner(text, self.language_module) + + def segment(self, text): + self.original_text = text + if not text: + return [] + + if self.clean: + text = self.cleaner(text).clean() + + postprocessed_sents = nltk.tokenize.sent_tokenize(text, language=self.language) + sentence_w_char_spans = self.sentences_with_char_spans(postprocessed_sents) + if self.char_span: + return sentence_w_char_spans + elif self.clean: + # clean and destructed sentences + return postprocessed_sents + else: + # nondestructive with whitespaces + return [textspan.sent for textspan in sentence_w_char_spans] \ No newline at end of file diff --git a/src/ragas/prompt/mixin.py b/src/ragas/prompt/mixin.py index 79e551298a..e2caf890bc 100644 --- 
a/src/ragas/prompt/mixin.py +++ b/src/ragas/prompt/mixin.py @@ -53,7 +53,8 @@ def set_prompts(self, **prompts): setattr(self, key, value) async def adapt_prompts( - self, language: str, llm: BaseRagasLLM, adapt_instruction: bool = False + self, language: str, llm: BaseRagasLLM, adapt_instruction: bool = False, + google_translate: bool = False ) -> t.Dict[str, PydanticPrompt]: """ Adapts the prompts in the class to the given language and using the given LLM. @@ -67,7 +68,7 @@ async def adapt_prompts( prompts = self.get_prompts() adapted_prompts = {} for name, prompt in prompts.items(): - adapted_prompt = await prompt.adapt(language, llm, adapt_instruction) + adapted_prompt = await prompt.adapt(language, llm, adapt_instruction, google_translate) adapted_prompts[name] = adapted_prompt return adapted_prompts diff --git a/src/ragas/prompt/pydantic_prompt.py b/src/ragas/prompt/pydantic_prompt.py index 950252ec8e..0232a859ef 100644 --- a/src/ragas/prompt/pydantic_prompt.py +++ b/src/ragas/prompt/pydantic_prompt.py @@ -4,12 +4,15 @@ import json import logging import os +import hashlib + import typing as t from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompt_values import StringPromptValue as PromptValue from pydantic import BaseModel +from deep_translator import GoogleTranslator from ragas._version import __version__ from ragas.callbacks import ChainType, new_group @@ -29,21 +32,126 @@ InputModel = t.TypeVar("InputModel", bound=BaseModel) OutputModel = t.TypeVar("OutputModel", bound=BaseModel) +def hash_recursive(obj): + hash_obj = hashlib.sha256() + + def _hash_recursive(obj): + if isinstance(obj, (tuple, list, set)): + for indice, value in enumerate(obj): + hash_obj.update(str(indice).encode()) + _hash_recursive(value) + elif isinstance(obj, dict): + for key, value in sorted(obj.items()): + hash_obj.update(str(key).encode()) + _hash_recursive(value) + elif hasattr(obj, 
'__dict__'): + for name in dir(obj): + if not name.startswith('_'): + value = getattr(obj, name) + if not callable(value): + hash_obj.update(name.encode()) + _hash_recursive(value) + else: + hash_obj.update(str(obj).encode()) + + _hash_recursive(obj) + + return int.from_bytes(hash_obj.digest(), 'big') + +class PydanticPromptStrings: + output_signature = ( + "Please return the output in a format that complies with the " + "following schema as specified in JSON Schema and OpenAPI specification:" + ) + examples_intro = "These are some examples to show how to perform the above instruction" + instruction_prompt = "Now perform the above instruction with the following input" + only_json = "Respond only with a valid JSON object that complies with the specified schema." + + language = 'english' + + async def adapt(self, target_language: str, llm: BaseRagasLLM | None, translate_with_google: bool) -> PydanticPromptStrings: + if self.language == target_language: + return self + + data = copy.deepcopy(self) + + translator = None + if translate_with_google: + translator = GoogleTranslator(source='en', target=target_language) + + vars = dir(data) + vars.remove('language') + + for name in vars: + value = getattr(data, name) + if not callable(value) and not name.startswith('_'): + if translate_with_google: + translated_str = translator.translate(value) + else: + translated_str = (await translate_statements_prompt.generate( + llm=llm, + data=ToTranslate( + target_language=target_language, statements=[value] + ), + )).statements[0] + setattr(data, name, translated_str) + + data.language = target_language + + return data + + def save(self, file_path: str): + data = {} + + for name in dir(self): + value = getattr(self, name) + if not callable(value) and not name.startswith('_'): + data[name] = value + + data['ragas_version'] = __version__ + if os.path.exists(file_path): + raise FileExistsError(f"The file '{file_path}' already exists.") + with open(file_path, "w", encoding='utf-8') as f: + 
json.dump(data, f, indent=2, ensure_ascii=False) + print(f"Prompt strings saved to {file_path}") + + @staticmethod + def load(file_path: str) -> PydanticPromptStrings: + with open(file_path, "r") as f: + data = json.load(f) + + prompt_strings = PydanticPromptStrings() + + ragas_version = data.pop("ragas_version") + + if ragas_version != __version__: + logger.warning( + "Prompt strings were saved with Ragas v%s, but you are loading it with Ragas v%s. " + "There might be incompatibilities.", + ragas_version, + __version__, + ) + + for name, value in data.items(): + setattr(prompt_strings, name, value) + + return prompt_strings + class PydanticPrompt(BasePrompt, t.Generic[InputModel, OutputModel]): input_model: t.Type[InputModel] output_model: t.Type[OutputModel] instruction: str examples: t.List[t.Tuple[InputModel, OutputModel]] = [] + strings: PydanticPromptStrings = PydanticPromptStrings() def _generate_instruction(self) -> str: return self.instruction def _generate_output_signature(self, indent: int = 4) -> str: return ( - f"Please return the output in a JSON format that complies with the " - f"following schema as specified in JSON Schema and OpenAPI specification:\n" - f"{self.output_model.model_json_schema()}" + self.strings.output_signature + '\n' + + str(self.output_model.model_json_schema()) ) def _generate_examples(self): @@ -62,7 +170,7 @@ def _generate_examples(self): ) return ( - "These are some examples to show how to perform the above instruction\n" + self.strings.examples_intro + '\n' + "\n\n".join(example_strings) ) # if no examples are provided @@ -76,13 +184,13 @@ def to_string(self, data: t.Optional[InputModel] = None) -> str: + self._generate_output_signature() + "\n" + self._generate_examples() - + "\nNow perform the above instruction with the following input\n" + + "\n" + self.strings.instruction_prompt + '\n' + ( "input: " + data.model_dump_json(indent=4) + "\n" if data is not None else "input: (None)\n" ) - + "Respond only with a valid JSON 
object that complies with the specified schema.\n" + + self.strings.only_json + "\n" + "output: " ) @@ -224,12 +332,18 @@ def process_output(self, output: OutputModel, input: InputModel) -> OutputModel: return output async def adapt( - self, target_language: str, llm: BaseRagasLLM, adapt_instruction: bool = False + self, target_language: str, llm: BaseRagasLLM | None, adapt_instruction: bool = False, + translate_with_google: bool = False ) -> "PydanticPrompt[InputModel, OutputModel]": """ Adapt the prompt to a new language. """ - + if not translate_with_google and llm is None: + raise ValueError("You must provide an LLM if you are not using Google Translate.") + + if self.language == target_language: + return copy.deepcopy(self) + # throws ValueError if language is not supported _check_if_language_is_supported(target_language) @@ -239,29 +353,47 @@ async def adapt( self.original_hash = hash(self) strings = get_all_strings(self.examples) - translated_strings = await translate_statements_prompt.generate( - llm=llm, - data=ToTranslate(target_language=target_language, statements=strings), - ) + translator = None - translated_examples = update_strings( - obj=self.examples, - old_strings=strings, - new_strings=translated_strings.statements, - ) + if translate_with_google: + translator = GoogleTranslator(source=self.language, target=target_language) + translated_strings = [translator.translate(s) for s in strings] + + translated_examples = update_strings( + obj=self.examples, + old_strings=strings, + new_strings=translated_strings, + ) + else: + translated_strings = await translate_statements_prompt.generate( + llm=llm, + data=ToTranslate(target_language=target_language, statements=strings), + ) + + translated_examples = update_strings( + obj=self.examples, + old_strings=strings, + new_strings=translated_strings.statements, + ) new_prompt = copy.deepcopy(self) new_prompt.examples = translated_examples new_prompt.language = target_language + new_prompt.strings = await 
self.strings.adapt(target_language, llm, translate_with_google) if adapt_instruction: - translated_instruction = await translate_statements_prompt.generate( - llm=llm, - data=ToTranslate( - target_language=target_language, statements=[self.instruction] - ), - ) - new_prompt.instruction = translated_instruction.statements[0] + if translate_with_google: + translated_instruction = translator.translate(self.instruction) + else: + translated_instruction = await translate_statements_prompt.generate( + llm=llm, + data=ToTranslate( + target_language=target_language, statements=[self.instruction] + ), + ) + translated_instruction = translated_instruction.statements[0] + + new_prompt.instruction = translated_instruction return new_prompt @@ -284,25 +416,7 @@ def __str__(self): return f"{self.__class__.__name__}({json_str})" def __hash__(self): - # convert examples to json string for hashing - examples = [] - for example in self.examples: - input_model, output_model = example - examples.append( - (input_model.model_dump_json(), output_model.model_dump_json()) - ) - - # not sure if input_model and output_model should be included - return hash( - ( - self.name, - self.input_model, - self.output_model, - self.instruction, - *examples, - self.language, - ) - ) + return hash_recursive(self) def __eq__(self, other): if not isinstance(other, PydanticPrompt): @@ -338,6 +452,12 @@ def save(self, file_path: str): json.dump(data, f, indent=2, ensure_ascii=False) print(f"Prompt saved to {file_path}") + dir, _ = os.path.split(file_path) + path_strings = os.path.join(dir, f'strings_{__version__}.json') + + if not os.path.exists(path_strings): + self.strings.save(path_strings) + @classmethod def load(cls, file_path: str) -> "PydanticPrompt[InputModel, OutputModel]": with open(file_path, "r") as f: @@ -368,6 +488,12 @@ def load(cls, file_path: str) -> "PydanticPrompt[InputModel, OutputModel]": prompt.examples = examples prompt.language = data.get("language", prompt.language) + dir, _ = 
os.path.split(file_path) + path_strings = os.path.join(dir, f'strings_{__version__}.json') + + if os.path.exists(path_strings): + prompt.strings = PydanticPromptStrings.load(path_strings) + # Optionally, verify the loaded prompt's hash matches the saved hash if original_hash is not None and hash(prompt) != original_hash: logger.warning("Loaded prompt hash does not match the saved hash.") @@ -397,7 +523,7 @@ async def parse_output_string( prompt_value: PromptValue, llm: BaseRagasLLM, callbacks: Callbacks, - retries_left: int = 1, + retries_left: int = 3, ): callbacks = callbacks or [] try: diff --git a/src/ragas/utils.py b/src/ragas/utils.py index 59a835a7d6..160f6c8dce 100644 --- a/src/ragas/utils.py +++ b/src/ragas/utils.py @@ -3,6 +3,7 @@ import logging import os import re +import nltk import typing as t import warnings from functools import lru_cache @@ -10,15 +11,51 @@ import numpy as np from datasets import Dataset from pysbd.languages import LANGUAGE_CODES +from datasets import Dataset +from deep_translator import GoogleTranslator if t.TYPE_CHECKING: from ragas.metrics.base import Metric + DEBUG_ENV_VAR = "RAGAS_DEBUG" -RAGAS_SUPPORTED_LANGUAGE_CODES = { + +nltk.download('punkt_tab') + +path = nltk.data.find('tokenizers/punkt_tab').path + +slovene = os.path.join(path, 'slovene') + +if os.path.exists(slovene): + os.rename(slovene, os.path.join(path, 'slovenian')) + +dirs = os.listdir(path) +supported_languages = [item for item in dirs if os.path.isdir(os.path.join(path, item))] + +supported_languages = [lang.split('.')[0] for lang in supported_languages] + +RAGAS_SUPPORTED_LANGUAGE_CODES_GOOGLE = GoogleTranslator().get_supported_languages(as_dict=True) + +RAGAS_SUPPORTED_LANGUAGE_CODES_NLTK = { + k.lower(): RAGAS_SUPPORTED_LANGUAGE_CODES_GOOGLE[k] for k in supported_languages +} + +RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD = { v.__name__.lower(): k for k, v in LANGUAGE_CODES.items() } +RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['chinese (simplified)'] = 
RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['chinese'] +RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['myanmar'] = RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['burmese'] +RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['german'] = RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['deutsch'] + +del RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['chinese'] +del RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['burmese'] +del RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD['deutsch'] + +RAGAS_SUPPORTED_LANGUAGE_CODES = { + **RAGAS_SUPPORTED_LANGUAGE_CODES_NLTK, + **{k: RAGAS_SUPPORTED_LANGUAGE_CODES_GOOGLE[k] for k, v in RAGAS_SUPPORTED_LANGUAGE_CODES_PYSBD.items() if k not in RAGAS_SUPPORTED_LANGUAGE_CODES_NLTK} +} @lru_cache(maxsize=1) def get_cache_dir() -> str: