import json
import os
import sys
import tempfile
from importlib.resources import files
import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from graphgen.graphgen import GraphGen
from graphgen.models import OpenAIModel, Tokenizer
from graphgen.models.llm.limitter import RPM, TPM
from graphgen.utils import set_logger
from webui.base import GraphGenParams
from webui.cache_utils import cleanup_workspace, setup_workspace
from webui.count_tokens import count_tokens
from webui.i18n import Translate
from webui.i18n import gettext as _
from webui.test_api import test_api_connection
root_dir = str(files("webui").parent)
sys.path.append(root_dir)
load_dotenv()
css = """
.center-row {
display: flex;
justify-content: center;
align-items: center;
}
"""
def init_graph_gen(config: dict, env: dict) -> GraphGen:
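    """Build a GraphGen instance wired to the configured LLM endpoints and tokenizer."""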
# Set up working directory
log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
set_logger(log_file, if_stream=True)
graph_gen = GraphGen(working_dir=working_dir, config=config)
# Set up LLM clients
graph_gen.synthesizer_llm_client = OpenAIModel(
model_name=env.get("SYNTHESIZER_MODEL", ""),
base_url=env.get("SYNTHESIZER_BASE_URL", ""),
api_key=env.get("SYNTHESIZER_API_KEY", ""),
request_limit=True,
rpm=RPM(env.get("RPM", 1000)),
tpm=TPM(env.get("TPM", 50000)),
)
graph_gen.trainee_llm_client = OpenAIModel(
model_name=env.get("TRAINEE_MODEL", ""),
base_url=env.get("TRAINEE_BASE_URL", ""),
api_key=env.get("TRAINEE_API_KEY", ""),
request_limit=True,
rpm=RPM(env.get("RPM", 1000)),
tpm=TPM(env.get("TPM", 50000)),
)
graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
return graph_gen
# pylint: disable=too-many-statements
def run_graphgen(params, progress=gr.Progress()):
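    """Run the GraphGen pipeline end to end and return the output file plus token stats."""

    # Total tokens consumed across all requests recorded by an LLM client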
def sum_tokens(client):
return sum(u["total_tokens"] for u in client.token_usage)
config = {
"if_trainee_model": params.if_trainee_model,
"input_file": params.input_file,
"output_data_type": params.output_data_type,
"output_data_format": params.output_data_format,
"tokenizer": params.tokenizer,
"search": {"enabled": False},
"quiz_and_judge_strategy": {
"enabled": params.if_trainee_model,
"quiz_samples": params.quiz_samples,
},
"traverse_strategy": {
"bidirectional": params.bidirectional,
"expand_method": params.expand_method,
"max_extra_edges": params.max_extra_edges,
"max_tokens": params.max_tokens,
"max_depth": params.max_depth,
"edge_sampling": params.edge_sampling,
"isolated_node_strategy": params.isolated_node_strategy,
"loss_strategy": params.loss_strategy,
},
"chunk_size": params.chunk_size,
}
env = {
"SYNTHESIZER_BASE_URL": params.synthesizer_url,
"SYNTHESIZER_MODEL": params.synthesizer_model,
"TRAINEE_BASE_URL": params.trainee_url,
"TRAINEE_MODEL": params.trainee_model,
"SYNTHESIZER_API_KEY": params.api_key,
"TRAINEE_API_KEY": params.trainee_api_key,
"RPM": params.rpm,
"TPM": params.tpm,
}
# Test API connection
test_api_connection(
env["SYNTHESIZER_BASE_URL"],
env["SYNTHESIZER_API_KEY"],
env["SYNTHESIZER_MODEL"],
)
if config["if_trainee_model"]:
test_api_connection(
env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
)
# Load input data
file = config["input_file"]
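    # gr.File can hand back a list even for single uploads; be defensive and take the first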
if isinstance(file, list):
file = file[0]
data = []
if file.endswith(".jsonl"):
config["input_data_type"] = "raw"
with open(file, "r", encoding="utf-8") as f:
            data.extend(json.loads(line) for line in f if line.strip())
elif file.endswith(".json"):
config["input_data_type"] = "chunked"
with open(file, "r", encoding="utf-8") as f:
data.extend(json.load(f))
elif file.endswith(".txt"):
        # Read the whole file, then split it into raw-format chunks of chunk_size characters
config["input_data_type"] = "raw"
content = ""
with open(file, "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
content += line.strip() + " "
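        # Naive fixed-size chunking: chunk_size counts characters here, not tokens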
size = int(config.get("chunk_size", 512))
chunks = [content[i : i + size] for i in range(0, len(content), size)]
data.extend([{"content": chunk} for chunk in chunks])
else:
raise ValueError(f"Unsupported file type: {file}")
# Initialize GraphGen
graph_gen = init_graph_gen(config, env)
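    # Start from a clean workspace so artifacts from a previous run don't leak in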
graph_gen.clear()
graph_gen.progress_bar = progress
try:
# Process the data
graph_gen.insert()
if config["if_trainee_model"]:
# Generate quiz
graph_gen.quiz()
# Judge statements
graph_gen.judge()
else:
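            # No trainee model means no loss scores, so fall back to random edge sampling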
graph_gen.traverse_strategy.edge_sampling = "random"
# Traverse graph
graph_gen.traverse()
# Save output
output_data = graph_gen.qa_storage.data
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
        ) as tmpfile:
            # Write one JSON object per line to match the .jsonl suffix
            for item in output_data:
                tmpfile.write(json.dumps(item, ensure_ascii=False) + "\n")
            output_file = tmpfile.name
synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
trainee_tokens = (
sum_tokens(graph_gen.trainee_llm_client)
if config["if_trainee_model"]
else 0
)
total_tokens = synthesizer_tokens + trainee_tokens
data_frame = params.token_counter
try:
_update_data = [
[data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
]
new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
data_frame = new_df
        except Exception as e:
            raise gr.Error(f"DataFrame operation error: {str(e)}") from e
return output_file, gr.DataFrame(
label="Token Stats",
headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
datatype="str",
interactive=False,
value=data_frame,
visible=True,
wrap=True,
)
    except Exception as e:  # pylint: disable=broad-except
        raise gr.Error(f"Error occurred: {str(e)}") from e
finally:
# Clean up workspace
cleanup_workspace(graph_gen.working_dir)
with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
# Header
gr.Image(
value=os.path.join(root_dir, "resources", "images", "logo.png"),
label="GraphGen Banner",
elem_id="banner",
interactive=False,
container=False,
show_download_button=False,
show_fullscreen_button=False,
)
lang_btn = gr.Radio(
choices=[
("English", "en"),
("简体中文", "zh"),
],
value="en",
# label=_("Language"),
render=False,
container=False,
elem_classes=["center-row"],
)
gr.HTML(
"""
<div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
<a href="https://github.com/open-sciencelab/GraphGen/releases">
<img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
</a>
<a href="https://graphgen-docs.example.com">
<img src="https://img.shields.io/badge/Docs-Latest-brightgreen" alt="Documentation">
</a>
<a href="https://github.com/open-sciencelab/GraphGen/issues/10">
<img src="https://img.shields.io/github/stars/open-sciencelab/GraphGen?style=social" alt="GitHub Stars">
</a>
<a href="https://arxiv.org/abs/2505.20416">
<img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
</a>
</div>
"""
)
with Translate(
os.path.join(root_dir, "webui", "translation.json"),
lang_btn,
placeholder_langs=["en", "zh"],
persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
):
lang_btn.render()
gr.Markdown(
value="# "
+ _("Title")
+ "\n\n"
+ "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
+ _("Intro")
)
if_trainee_model = gr.Checkbox(
label=_("Use Trainee Model"), value=False, interactive=True
)
with gr.Accordion(label=_("Model Config"), open=False):
synthesizer_url = gr.Textbox(
label="Synthesizer URL",
value="https://api.siliconflow.cn/v1",
info=_("Synthesizer URL Info"),
interactive=True,
)
synthesizer_model = gr.Textbox(
label="Synthesizer Model",
value="Qwen/Qwen2.5-7B-Instruct",
info=_("Synthesizer Model Info"),
interactive=True,
)
trainee_url = gr.Textbox(
label="Trainee URL",
value="https://api.siliconflow.cn/v1",
info=_("Trainee URL Info"),
interactive=True,
visible=if_trainee_model.value is True,
)
trainee_model = gr.Textbox(
label="Trainee Model",
value="Qwen/Qwen2.5-7B-Instruct",
info=_("Trainee Model Info"),
interactive=True,
visible=if_trainee_model.value is True,
)
trainee_api_key = gr.Textbox(
label=_("SiliconFlow Token for Trainee Model"),
type="password",
value="",
info="https://cloud.siliconflow.cn/account/ak",
visible=if_trainee_model.value is True,
)
with gr.Accordion(label=_("Generation Config"), open=False):
chunk_size = gr.Slider(
label="Chunk Size",
minimum=256,
maximum=4096,
value=512,
step=256,
interactive=True,
)
tokenizer = gr.Textbox(
label="Tokenizer", value="cl100k_base", interactive=True
)
output_data_type = gr.Radio(
choices=["atomic", "multi_hop", "aggregated"],
label="Output Data Type",
value="aggregated",
interactive=True,
)
output_data_format = gr.Radio(
choices=["Alpaca", "Sharegpt", "ChatML"],
label="Output Data Format",
value="Alpaca",
interactive=True,
)
quiz_samples = gr.Number(
label="Quiz Samples",
value=2,
minimum=1,
interactive=True,
visible=if_trainee_model.value is True,
)
bidirectional = gr.Checkbox(
label="Bidirectional", value=True, interactive=True
)
expand_method = gr.Radio(
choices=["max_width", "max_tokens"],
label="Expand Method",
value="max_tokens",
interactive=True,
)
max_extra_edges = gr.Slider(
minimum=1,
maximum=10,
value=5,
label="Max Extra Edges",
step=1,
interactive=True,
visible=expand_method.value == "max_width",
)
max_tokens = gr.Slider(
minimum=64,
maximum=1024,
value=256,
label="Max Tokens",
step=64,
interactive=True,
visible=(expand_method.value != "max_width"),
)
max_depth = gr.Slider(
minimum=1,
maximum=5,
value=2,
label="Max Depth",
step=1,
interactive=True,
)
edge_sampling = gr.Radio(
choices=["max_loss", "min_loss", "random"],
label="Edge Sampling",
value="max_loss",
interactive=True,
visible=if_trainee_model.value is True,
)
isolated_node_strategy = gr.Radio(
choices=["add", "ignore"],
label="Isolated Node Strategy",
value="ignore",
interactive=True,
)
loss_strategy = gr.Radio(
choices=["only_edge", "both"],
label="Loss Strategy",
value="only_edge",
interactive=True,
)
with gr.Row(equal_height=True):
with gr.Column(scale=3):
api_key = gr.Textbox(
label=_("SiliconFlow Token"),
type="password",
value="",
info="https://cloud.siliconflow.cn/account/ak",
)
with gr.Column(scale=1):
test_connection_btn = gr.Button(_("Test Connection"))
with gr.Blocks():
with gr.Row(equal_height=True):
with gr.Column():
rpm = gr.Slider(
label="RPM",
minimum=10,
maximum=10000,
value=1000,
step=100,
interactive=True,
visible=True,
)
with gr.Column():
tpm = gr.Slider(
label="TPM",
minimum=5000,
maximum=5000000,
value=50000,
step=1000,
interactive=True,
visible=True,
)
with gr.Blocks():
with gr.Row(equal_height=True):
with gr.Column(scale=1):
upload_file = gr.File(
label=_("Upload File"),
file_count="single",
file_types=[".txt", ".json", ".jsonl"],
interactive=True,
)
examples_dir = os.path.join(root_dir, "webui", "examples")
gr.Examples(
examples=[
[os.path.join(examples_dir, "txt_demo.txt")],
[os.path.join(examples_dir, "raw_demo.jsonl")],
[os.path.join(examples_dir, "chunked_demo.json")],
],
inputs=upload_file,
label=_("Example Files"),
examples_per_page=3,
)
with gr.Column(scale=1):
output = gr.File(
label="Output(See Github FAQ)",
file_count="single",
interactive=False,
)
with gr.Blocks():
token_counter = gr.DataFrame(
label="Token Stats",
headers=[
"Source Text Token Count",
"Estimated Token Usage",
"Token Used",
],
datatype="str",
interactive=False,
visible=False,
wrap=True,
)
submit_btn = gr.Button(_("Run GraphGen"))
# Test Connection
test_connection_btn.click(
test_api_connection,
inputs=[synthesizer_url, api_key, synthesizer_model],
outputs=[],
)
        # Also test the trainee endpoint, but only when the checkbox is on at runtime
        # (checking if_trainee_model.value at build time would never fire, since it
        # defaults to False when the handlers are registered)
        test_connection_btn.click(
            lambda url, key, model, use_trainee: (
                test_api_connection(url, key, model) if use_trainee else None
            ),
            inputs=[trainee_url, api_key, trainee_model, if_trainee_model],
            outputs=[],
        )
expand_method.change(
lambda method: (
gr.update(visible=method == "max_width"),
gr.update(visible=method != "max_width"),
),
inputs=expand_method,
outputs=[max_extra_edges, max_tokens],
)
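        # Show or hide the five trainee-related controls together when the checkbox flips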
if_trainee_model.change(
lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
inputs=if_trainee_model,
outputs=[
trainee_url,
trainee_model,
quiz_samples,
edge_sampling,
trainee_api_key,
],
)
upload_file.change(
            lambda _: gr.update(visible=True),
inputs=[upload_file],
outputs=[token_counter],
).then(
count_tokens,
inputs=[upload_file, tokenizer, token_counter],
outputs=[token_counter],
)
# run GraphGen
submit_btn.click(
            lambda _: gr.update(visible=False),
inputs=[token_counter],
outputs=[token_counter],
)
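        # NOTE: the positional order of `inputs` below must match the args[...] indices
        # used inside the lambda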
submit_btn.click(
lambda *args: run_graphgen(
GraphGenParams(
if_trainee_model=args[0],
input_file=args[1],
tokenizer=args[2],
output_data_type=args[3],
output_data_format=args[4],
bidirectional=args[5],
expand_method=args[6],
max_extra_edges=args[7],
max_tokens=args[8],
max_depth=args[9],
edge_sampling=args[10],
isolated_node_strategy=args[11],
loss_strategy=args[12],
synthesizer_url=args[13],
synthesizer_model=args[14],
trainee_model=args[15],
api_key=args[16],
chunk_size=args[17],
rpm=args[18],
tpm=args[19],
quiz_samples=args[20],
trainee_url=args[21],
trainee_api_key=args[22],
token_counter=args[23],
)
),
inputs=[
if_trainee_model,
upload_file,
tokenizer,
output_data_type,
output_data_format,
bidirectional,
expand_method,
max_extra_edges,
max_tokens,
max_depth,
edge_sampling,
isolated_node_strategy,
loss_strategy,
synthesizer_url,
synthesizer_model,
trainee_model,
api_key,
chunk_size,
rpm,
tpm,
quiz_samples,
trainee_url,
trainee_api_key,
token_counter,
],
outputs=[output, token_counter],
)
if __name__ == "__main__":
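    # api_open=False hides the queue's REST API; at most two jobs run concurrently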
demo.queue(api_open=False, default_concurrency_limit=2)
demo.launch(server_name="0.0.0.0")