Spaces:
Runtime error
Runtime error
import html | |
import json | |
import os | |
import re | |
from typing import Any | |
from .log import logger | |
def pack_history_conversations(*args: str):
    """Turn alternating utterances into a list of chat messages.

    The positional arguments are read as a conversation that starts with
    the user: even positions (0, 2, ...) get the "user" role and odd
    positions get the "assistant" role.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, one per argument,
        in the order given.
    """
    messages = []
    for position, text in enumerate(args):
        speaker = "user" if position % 2 == 0 else "assistant"
        messages.append({"role": speaker, "content": text})
    return messages
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string on any of several literal markers.

    Args:
        content: Text to split.
        markers: Literal delimiter strings (not regular expressions). Empty
            strings are ignored.

    Returns:
        The whitespace-stripped, non-empty pieces of ``content``. When there
        are no usable markers, ``[content]`` is returned unchanged.
    """
    # An empty marker would add an empty alternative to the pattern, which
    # makes re.split cut the text between every single character.
    usable_markers = [marker for marker in markers if marker]
    if not usable_markers:
        return [content]
    pattern = "|".join(re.escape(marker) for marker in usable_markers)
    pieces = (piece.strip() for piece in re.split(pattern, content))
    return [piece for piece in pieces if piece]
# Refer to the utility functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Normalize a string by unescaping HTML entities and dropping control characters.

    Leading and trailing whitespace is stripped first, then HTML entities
    are unescaped, and finally every character in the ranges U+0000-U+001F
    and U+007F-U+009F is removed. Any value that is not a ``str`` is
    returned untouched.
    """
    if isinstance(input, str):
        unescaped = html.unescape(input.strip())
        # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
        return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
    # Non-string input is handed back as is.
    return input
async def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Build an entity node dict from one parsed extraction record.

    Args:
        record_attributes: Fields of the record. The first must be the
            literal ``'"entity"'``, followed by name, type and description.
        chunk_key: Identifier stored as the node's ``source_id``.

    Returns:
        A dict with ``entity_name``, ``entity_type``, ``description`` and
        ``source_id``, or ``None`` when the record has fewer than four
        fields, is not tagged as an entity, or has a blank name.
    """
    is_entity_record = (
        len(record_attributes) >= 4 and record_attributes[0] == '"entity"'
    )
    if not is_entity_record:
        return None

    # Names are upper-cased before cleaning; a blank name means there is
    # nothing to add as a node.
    name = clean_str(record_attributes[1].upper())
    if not name.strip():
        return None

    kind = clean_str(record_attributes[2].upper())
    description = clean_str(record_attributes[3])
    return {
        "entity_name": name,
        "entity_type": kind,
        "description": description,
        "source_id": chunk_key,
    }
def is_float_regex(value):
    """Report whether ``value`` is text shaped like a plain decimal number.

    Accepted form: an optional sign, optional integer digits, an optional
    decimal point, and at least one trailing digit (for example ``"3"``,
    ``"-0.5"``, ``".25"``). Exponent notation and a trailing bare point
    (``"1."``) are rejected.
    """
    matched = re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)
    return matched is not None
async def handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Build an edge dict from one parsed extraction record.

    Args:
        record_attributes: Fields of the record. The first must be the
            literal ``'"relationship"'``, followed by source name, target
            name and description.
        chunk_key: Identifier stored as the edge's ``source_id``.

    Returns:
        A dict with ``src_id``, ``tgt_id``, ``description`` and
        ``source_id``, or ``None`` when the record has fewer than four
        fields or is not tagged as a relationship.
    """
    is_relationship_record = (
        len(record_attributes) >= 4 and record_attributes[0] == '"relationship"'
    )
    if not is_relationship_record:
        return None

    # Endpoint names are upper-cased before cleaning.
    src_name = clean_str(record_attributes[1].upper())
    tgt_name = clean_str(record_attributes[2].upper())
    description = clean_str(record_attributes[3])
    return {
        "src_id": src_name,
        "tgt_id": tgt_name,
        "description": description,
        "source_id": chunk_key,
    }
def load_json(file_name):
    """Read and parse a UTF-8 JSON file.

    Returns:
        The decoded JSON value, or ``None`` when ``file_name`` does not exist.
    """
    if os.path.exists(file_name):
        with open(file_name, encoding="utf-8") as handle:
            return json.load(handle)
    return None
def write_json(json_obj, file_name):
    """Serialize ``json_obj`` to ``file_name`` as indented UTF-8 JSON.

    Missing parent directories are created first. The output uses a
    four-space indent and keeps non-ASCII characters unescaped.

    Args:
        json_obj: Any JSON-serializable value.
        file_name: Destination path; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(file_name)
    # A bare file name has no directory component, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, indent=4, ensure_ascii=False)
def format_generation_results(
    results: dict[str, Any], output_data_format: str
) -> list[dict[str, Any]]:
    """Convert question/answer records into a named dataset layout.

    Args:
        results: Mapping whose values are dicts with ``"question"`` and
            ``"answer"`` entries; the keys are not used.
        output_data_format: One of ``"Alpaca"``, ``"Sharegpt"`` or
            ``"ChatML"``.

    Returns:
        One formatted record per value of ``results``, in iteration order.

    Raises:
        ValueError: If ``output_data_format`` is not a supported name.
    """

    def _as_alpaca(qa):
        # Instruction/input/output triple with an always-empty input.
        return {
            "instruction": qa["question"],
            "input": "",
            "output": qa["answer"],
        }

    def _as_sharegpt(qa):
        # Two-turn conversation using "from"/"value" keys.
        return {
            "conversations": [
                {"from": "human", "value": qa["question"]},
                {"from": "gpt", "value": qa["answer"]},
            ]
        }

    def _as_chatml(qa):
        # Two-turn conversation using "role"/"content" keys.
        return {
            "messages": [
                {"role": "user", "content": qa["question"]},
                {"role": "assistant", "content": qa["answer"]},
            ]
        }

    # (format name, log line, per-record converter), checked in order.
    layouts = (
        ("Alpaca", "Output data format: Alpaca", _as_alpaca),
        ("Sharegpt", "Output data format: Sharegpt", _as_sharegpt),
        ("ChatML", "Output data format: ChatML", _as_chatml),
    )
    for layout_name, log_line, convert in layouts:
        if output_data_format == layout_name:
            logger.info(log_line)
            return [convert(qa) for qa in results.values()]
    raise ValueError(f"Unknown output data format: {output_data_format}")