diff --git a/graphrag/index/graph/extractors/graph/graph_extractor.py b/graphrag/index/graph/extractors/graph/graph_extractor.py index 49ca671a72..5fb4a061fd 100644 --- a/graphrag/index/graph/extractors/graph/graph_extractor.py +++ b/graphrag/index/graph/extractors/graph/graph_extractor.py @@ -137,6 +137,7 @@ async def __call__( all_records, prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER), prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER), + prompt_variables.get(self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER), ) return GraphExtractionResult( @@ -185,6 +186,7 @@ async def _process_results( results: dict[int, str], tuple_delimiter: str, record_delimiter: str, + complete_delimiter: str, ) -> nx.Graph: """Parse the result string to create an undirected unipartite graph. @@ -192,12 +194,15 @@ async def _process_results( - results - dict of results from the extraction chain - tuple_delimiter - delimiter between tuples in an output record, default is '<|>' - record_delimiter - delimiter between records, default is '##' + - complete_delimiter - completion delimiter; the output is split on it before splitting into records, default is '<|COMPLETE|>' Returns: - output - unipartite graph in graphML format """ graph = nx.Graph() for source_doc_id, extracted_data in results.items(): - records = [r.strip() for r in extracted_data.split(record_delimiter)] + records = [r.strip() for r in extracted_data.split(complete_delimiter)] + records = [r.split(record_delimiter) for r in records] + records = [r for sublist in records for r in sublist] for record in records: record = re.sub(r"^\(|\)$", "", record.strip()) @@ -238,7 +243,7 @@ async def _process_results( source_id=str(source_doc_id), ) - if ( + elif ( record_attributes[0] == '"relationship"' and len(record_attributes) >= 5 ): @@ -290,7 +295,6 @@ async def _process_results( description=edge_description, source_id=edge_source_id, ) - return graph