diff --git a/pyproject.toml b/pyproject.toml index 470109b..5e4b98d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ragchecker" -version = "0.1.5" +version = "0.1.6" description = "RAGChecker: A Fine-grained Framework For Diagnosing Retrieval-Augmented Generation (RAG) systems." authors = [ "Xiangkun Hu ", diff --git a/ragchecker/evaluator.py b/ragchecker/evaluator.py index 8558d78..f5e04bb 100644 --- a/ragchecker/evaluator.py +++ b/ragchecker/evaluator.py @@ -53,6 +53,7 @@ def __init__( joint_check_num=5, sagemaker_client=None, sagemaker_params=None, + sagemaker_get_response_func=None, **kwargs ): if openai_api_key: @@ -63,6 +64,7 @@ def __init__( self.kwargs = kwargs self.sagemaker_client = sagemaker_client self.sagemaker_params = sagemaker_params + self.sagemaker_get_response_func = sagemaker_get_response_func self.extractor = LLMExtractor( model=extractor_name, @@ -111,6 +113,7 @@ def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"): max_new_tokens=self.extractor_max_new_tokens, sagemaker_client=self.sagemaker_client, sagemaker_params=self.sagemaker_params, + sagemaker_get_response_func=self.sagemaker_get_response_func, **self.kwargs ) claims = [[c.content for c in res.claims] for res in extraction_results] @@ -173,6 +176,7 @@ def check_claims(self, results: RAGResults, check_type="answer2response"): joint_check_num=self.joint_check_num, sagemaker_client=self.sagemaker_client, sagemaker_params=self.sagemaker_params, + sagemaker_get_response_func=self.sagemaker_get_response_func, **self.kwargs ) for i, result in enumerate(results):