github-actions[bot]
Auto-sync from demo at Thu Aug 28 10:06:44 UTC 2025
ba8b592
raw
history blame
4.21 kB
import html
import json
import os
import re
from typing import Any
from .log import logger
def pack_history_conversations(*args: str):
    """Pack positional message strings into alternating chat turns.

    Arguments at even positions (0, 2, ...) become ``"user"`` turns and those
    at odd positions become ``"assistant"`` turns.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, one per argument,
        in argument order (empty when called with no arguments).
    """
    return [
        {"role": "assistant" if position % 2 else "user", "content": message}
        for position, message in enumerate(args)
    ]
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers.

    Empty markers are ignored. Longer markers are tried before shorter ones,
    so a marker that is a prefix of another (e.g. ``"<|"`` and
    ``"<|COMPLETE|>"``) cannot split the longer one apart.

    Args:
        content: Text to split.
        markers: Literal delimiter strings (not regexes; they are escaped).

    Returns:
        The whitespace-stripped, non-empty pieces between markers. If no
        non-empty marker is supplied, ``[content]`` is returned unchanged.
    """
    # An empty marker would make the pattern match the empty string, which
    # re.split treats as a split point between every character.
    # Python regex alternation is leftmost-first (not longest-match), so sort
    # longest-first to prefer the longest marker at any given position.
    usable_markers = sorted((m for m in markers if m), key=len, reverse=True)
    if not usable_markers:
        return [content]
    pattern = "|".join(re.escape(marker) for marker in usable_markers)
    stripped = (piece.strip() for piece in re.split(pattern, content))
    return [piece for piece in stripped if piece]
# Adapted from the utility functions of the official GraphRAG implementation:
# https://github.com/microsoft/graphrag
def clean_str(input: Any) -> str:
    """Normalize a string: trim it, decode HTML escapes, drop control chars.

    Leading/trailing whitespace is stripped first, then HTML character
    references are unescaped, and finally C0 (``\\x00``-``\\x1f``), DEL and
    C1 (``\\x7f``-``\\x9f``) control characters are removed.

    Non-string values are returned unchanged (so the actual return type is
    whatever was passed in when it is not a ``str``).
    """
    if isinstance(input, str):
        unescaped = html.unescape(input.strip())
        # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
        return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
    # Pass non-string input straight through.
    return input
async def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Turn one parsed extraction record into an entity (graph node) dict.

    Args:
        record_attributes: Fields of one extracted record. The first field
            must be the literal ``'"entity"'`` (quotes included) and at least
            four fields must be present: tag, name, type, description.
        chunk_key: Identifier of the chunk the record came from; stored as
            ``source_id``.

    Returns:
        A dict with ``entity_name``, ``entity_type``, ``description`` and
        ``source_id``, or ``None`` if the record is not an entity record or
        its cleaned name is blank.
    """
    is_entity_record = (
        len(record_attributes) >= 4 and record_attributes[0] == '"entity"'
    )
    if not is_entity_record:
        return None

    # NOTE(review): names are upper-cased *before* clean_str unescapes HTML,
    # so lowercase-only named references may not decode. Kept as-is because
    # relationship endpoints are normalized the same way and must match.
    name = clean_str(record_attributes[1].upper())
    if not name.strip():
        return None

    return {
        "entity_name": name,
        "entity_type": clean_str(record_attributes[2].upper()),
        "description": clean_str(record_attributes[3]),
        "source_id": chunk_key,
    }
def is_float_regex(value):
    """Return True if *value* is a plain (non-exponent) decimal number string.

    Accepts an optional sign, optional integer digits, an optional decimal
    point and at least one trailing digit (e.g. ``"3"``, ``"-0.5"``,
    ``".25"``). Rejects exponents and a trailing point such as ``"1."``.
    Because the pattern ends in ``$`` and is applied with ``re.match``, a
    single trailing newline is also tolerated.
    """
    found = re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)
    return found is not None
async def handle_single_relationship_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Turn one parsed extraction record into a relationship (edge) dict.

    Args:
        record_attributes: Fields of one extracted record. The first field
            must be the literal ``'"relationship"'`` (quotes included) and at
            least four fields must be present: tag, source, target,
            description.
        chunk_key: Identifier of the chunk the record came from; stored as
            ``source_id``.

    Returns:
        A dict with ``src_id``, ``tgt_id``, ``description`` and
        ``source_id``, or ``None`` if the record is not a relationship record
        or either endpoint name is blank after cleaning.
    """
    if len(record_attributes) < 4 or record_attributes[0] != '"relationship"':
        return None
    # add this record as edge; endpoint names are upper-cased exactly like
    # entity names so that they line up with the extracted entities.
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    # Match the entity handler, which rejects blank names: an edge with a
    # blank endpoint would reference an empty entity name.
    if not source.strip() or not target.strip():
        return None
    edge_description = clean_str(record_attributes[3])
    return {
        "src_id": source,
        "tgt_id": target,
        "description": edge_description,
        "source_id": chunk_key,
    }
def load_json(file_name):
    """Load and return the JSON document stored at ``file_name``.

    Args:
        file_name: Path of a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON value, or ``None`` if the file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
        OSError: for open failures other than a missing file.
    """
    # EAFP: opening directly avoids the race where the file disappears
    # between an existence check and the open() call. Only the open is
    # guarded so decoding errors still propagate.
    try:
        f = open(file_name, encoding="utf-8")
    except FileNotFoundError:
        return None
    with f:
        return json.load(f)
def write_json(json_obj, file_name):
    """Serialize ``json_obj`` to ``file_name`` as indented UTF-8 JSON.

    Missing parent directories are created. Non-ASCII characters are written
    as-is (``ensure_ascii=False``) with a 4-space indent.

    Args:
        json_obj: Any JSON-serializable value.
        file_name: Destination path; may be a bare filename in the current
            working directory.
    """
    parent_dir = os.path.dirname(file_name)
    # A bare filename has dirname "" and os.makedirs("") raises, so only
    # create directories when the path actually has a parent component.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(json_obj, f, indent=4, ensure_ascii=False)
def _format_alpaca_item(item: dict[str, Any]) -> dict[str, Any]:
    """Map one question/answer pair to an Alpaca instruction record."""
    return {
        "instruction": item["question"],
        "input": "",
        "output": item["answer"],
    }


def _format_sharegpt_item(item: dict[str, Any]) -> dict[str, Any]:
    """Map one question/answer pair to a ShareGPT two-turn conversation."""
    return {
        "conversations": [
            {"from": "human", "value": item["question"]},
            {"from": "gpt", "value": item["answer"]},
        ]
    }


def _format_chatml_item(item: dict[str, Any]) -> dict[str, Any]:
    """Map one question/answer pair to a ChatML user/assistant message list."""
    return {
        "messages": [
            {"role": "user", "content": item["question"]},
            {"role": "assistant", "content": item["answer"]},
        ]
    }


# Supported ``output_data_format`` values (matched case-sensitively) mapped to
# the function that formats a single result item.
_GENERATION_RESULT_FORMATTERS = {
    "Alpaca": _format_alpaca_item,
    "Sharegpt": _format_sharegpt_item,
    "ChatML": _format_chatml_item,
}


def format_generation_results(
    results: dict[str, Any], output_data_format: str
) -> list[dict[str, Any]]:
    """Convert generated QA pairs into a fine-tuning dataset format.

    Args:
        results: Mapping whose values are dicts carrying ``"question"`` and
            ``"answer"`` keys; the mapping's keys are ignored.
        output_data_format: One of ``"Alpaca"``, ``"Sharegpt"`` or
            ``"ChatML"`` (case-sensitive).

    Returns:
        One formatted record per value of ``results``, in iteration order.

    Raises:
        ValueError: if ``output_data_format`` is not supported.
        KeyError: if an item lacks ``"question"`` or ``"answer"``.
    """
    formatter = _GENERATION_RESULT_FORMATTERS.get(output_data_format)
    if formatter is None:
        raise ValueError(f"Unknown output data format: {output_data_format}")
    logger.info(f"Output data format: {output_data_format}")
    # Build a new list rather than rebinding the dict-typed ``results``.
    return [formatter(item) for item in results.values()]