# CVAgentArena / app.py
# Peiran
# typo fixed
# c759ce1
import os, uuid, csv, random
from datetime import datetime
from PIL import Image
import gradio as gr
# —— 1. Environment & file preparation ——
os.environ["GRADIO_SSR_MODE"] = "False"  # turn off Gradio server-side rendering

# Number of task categories. Each task N owns data/images/task<N>/,
# data/metadatas/meta<N>.csv and data/evaluations/eval<N>.csv.
_TASK_COUNT = 7

# Columns of a metadata file: one row per stored (original, output 1, output 2) set.
_META_FIELDS = [
    "id", "original_path", "prompt",
    "agent1", "img1_path", "agent2", "img2_path",
]

# Columns of an evaluation file: one row per submitted human judgement.
_EVAL_FIELDS = [
    "timestamp", "record_id",
    "a1_follow", "a1_creativity", "a1_finesse",
    "a2_follow", "a2_creativity", "a2_finesse",
]


def _ensure_csv_with_header(path, fieldnames):
    """Create *path* holding only a header row unless it already has content.

    A file that exists but is empty is treated like a missing one, so readers
    using ``csv.DictReader`` never mistake the first data row for the header.
    """
    if os.path.exists(path) and os.path.getsize(path) > 0:
        return
    with open(path, "w", newline="", encoding="utf-8") as f:
        csv.DictWriter(f, fieldnames=fieldnames).writeheader()


# Make sure the data directory and its sub-directories exist.
for _task_no in range(_TASK_COUNT):
    for _subdir in ("orig_imgs", "processed_imgs"):
        os.makedirs(f"data/images/task{_task_no}/{_subdir}", exist_ok=True)
os.makedirs("data/evaluations", exist_ok=True)
os.makedirs("data/metadatas", exist_ok=True)

# Per-task metadata files.
meta0 = "data/metadatas/meta0.csv"
meta1 = "data/metadatas/meta1.csv"
meta2 = "data/metadatas/meta2.csv"
meta3 = "data/metadatas/meta3.csv"
meta4 = "data/metadatas/meta4.csv"
meta5 = "data/metadatas/meta5.csv"
meta6 = "data/metadatas/meta6.csv"
for _meta_path in (meta0, meta1, meta2, meta3, meta4, meta5, meta6):
    _ensure_csv_with_header(_meta_path, _META_FIELDS)

# Per-task evaluation files.
eval0 = "data/evaluations/eval0.csv"
eval1 = "data/evaluations/eval1.csv"
eval2 = "data/evaluations/eval2.csv"
eval3 = "data/evaluations/eval3.csv"
eval4 = "data/evaluations/eval4.csv"
eval5 = "data/evaluations/eval5.csv"
eval6 = "data/evaluations/eval6.csv"
for _eval_path in (eval0, eval1, eval2, eval3, eval4, eval5, eval6):
    _ensure_csv_with_header(_eval_path, _EVAL_FIELDS)
def run_agent_on_image(original_img: Image.Image, prompt: str, agent_name: str) -> Image.Image:
    """Apply the named agent to an image according to a prompt (placeholder).

    The agent itself is not implemented yet: after validating the inputs the
    function hands back the very same image object it was given, and
    ``agent_name`` is currently unused.

    Args:
        original_img: Image to process; must not be ``None``.
        prompt: Editing instruction; must contain non-whitespace text.
        agent_name: Name of the agent that should do the processing.

    Returns:
        The processed image (for now, ``original_img`` itself).

    Raises:
        ValueError: If the image is ``None`` or the prompt is empty/blank.
    """
    if original_img is None:
        raise ValueError("Input image cannot be None")

    prompt_is_blank = (not prompt) or prompt.strip() == ""
    if prompt_is_blank:
        raise ValueError("Prompt cannot be empty")

    # TODO: implement actual agent processing
    return original_img
def save_to_library(task_id, orig_img, prompt, a1, a2, img1, img2):
    """Persist one arena run: three PNG files plus a metadata CSV row.

    The original image goes to ``data/images/task<task_id>/orig_imgs`` and the
    two agent outputs to ``.../processed_imgs``; all three share one random
    hex id. A row describing them is appended to
    ``data/metadatas/meta<task_id>.csv`` (a header row is written first when
    that file is missing or empty).

    The operation is all-or-nothing on a best-effort basis: if anything fails
    after an image has been written, every image written by this call is
    removed again so the library never holds files without a metadata row.

    Args:
        task_id: Index of the task category; selects the folder and CSV file.
        orig_img: Uploaded image; anything exposing ``save(path)``.
        prompt: Prompt given to both agents; must not be blank.
        a1: Name of the first agent.
        a2: Name of the second agent.
        img1: Output of the first agent; anything exposing ``save(path)``.
        img2: Output of the second agent; anything exposing ``save(path)``.

    Raises:
        Exception: On any validation or I/O failure. The message starts with
            ``"Error in save_to_library: "`` and the underlying error is
            chained as ``__cause__``.
    """
    fieldnames = [
        "id", "original_path", "prompt",
        "agent1", "img1_path", "agent2", "img2_path",
    ]
    written_paths = []  # image paths this call may have created

    try:
        if task_id is None:
            raise ValueError("Task category must be selected")
        if any(img is None for img in (orig_img, img1, img2)):
            raise ValueError("All images must be valid")
        if not prompt or not prompt.strip():
            raise ValueError("Prompt cannot be empty")

        orig_id = uuid.uuid4().hex
        orig_path = f"data/images/task{task_id}/orig_imgs/{orig_id}.png"
        img1_path = f"data/images/task{task_id}/processed_imgs/{orig_id}_a1.png"
        img2_path = f"data/images/task{task_id}/processed_imgs/{orig_id}_a2.png"

        try:
            for image, path in ((orig_img, orig_path), (img1, img1_path), (img2, img2_path)):
                # Register the path before saving so that a partially written
                # file is cleaned up as well.
                written_paths.append(path)
                image.save(path)
        except Exception as e:
            raise IOError(f"Failed to save images: {str(e)}") from e

        try:
            meta_path = f"data/metadatas/meta{task_id}.csv"
            needs_header = (
                not os.path.exists(meta_path) or os.path.getsize(meta_path) == 0
            )
            with open(meta_path, "a", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                if needs_header:
                    writer.writeheader()
                writer.writerow({
                    "id": orig_id,
                    "original_path": orig_path,
                    "prompt": prompt,
                    "agent1": a1,
                    "img1_path": img1_path,
                    "agent2": a2,
                    "img2_path": img2_path,
                })
        except Exception as e:
            raise IOError(f"Failed to write metadata: {str(e)}") from e
    except Exception as e:
        # Roll back: drop whatever images made it to disk. Cleanup is
        # best-effort and must not mask the error that brought us here.
        for path in written_paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError:
                pass
        raise Exception(f"Error in save_to_library: {str(e)}") from e
def generate_and_store(task_id, orig_img, prompt, a1, a2):
    """Run two agents on an image and store the result in the library.

    Used as the click handler of the "Run Agents" button. The request is
    rejected up front, before any agent is run, when no task is selected,
    no image is uploaded, the prompt is blank, an agent is missing, or both
    agents are the same.

    Args:
        task_id: Index of the selected task category (``None`` if unselected).
        orig_img: Uploaded image, or ``None``.
        prompt: Prompt to hand to both agents.
        a1: Name of the first agent.
        a2: Name of the second agent.

    Returns:
        ``(out1, out2)`` with the two agent outputs, or ``(None, None)`` when
        the request is rejected or anything fails. This function never raises;
        failures are printed instead.
    """
    try:
        if task_id is None or orig_img is None:
            return None, None
        if not prompt or prompt.strip() == "":
            return None, None
        if not a1 or not a2:
            return None, None  # both agents must be chosen
        if a1 == a2:
            return None, None  # comparing an agent with itself is not allowed

        out1 = run_agent_on_image(orig_img, prompt, a1)
        out2 = run_agent_on_image(orig_img, prompt, a2)
        save_to_library(task_id, orig_img, prompt, a1, a2, out1, out2)
        return out1, out2
    except Exception as e:
        print(f"Error in generate_and_store: {str(e)}")
        return None, None
def _recently_evaluated_ids(eval_file, now, window_minutes=5):
    """Return the ids of records judged within the last *window_minutes*.

    Rows that lack the expected columns or carry an unparsable timestamp are
    skipped individually. A failure to read the file at all is printed and
    yields whatever ids were collected up to that point.
    """
    recent_ids = set()
    if not os.path.exists(eval_file):
        return recent_ids
    try:
        # newline="" lets the csv module handle line breaks inside quoted fields.
        with open(eval_file, "r", newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                try:
                    evaluated_at = datetime.fromisoformat(row["timestamp"])
                    minutes_ago = (now - evaluated_at).total_seconds() / 60
                    evaluated_id = row["record_id"]
                except (KeyError, TypeError, ValueError):
                    # Malformed row: ignore it, keep scanning the rest.
                    continue
                if minutes_ago <= window_minutes:
                    recent_ids.add(evaluated_id)
    except Exception as e:
        print(f"Error reading evaluation file: {str(e)}")
    return recent_ids


def load_random_record(task_id):
    """Pick a random stored record of a task for human judging.

    Records that were evaluated within the last five minutes (according to
    ``data/evaluations/eval<task_id>.csv``) are left out. Among the remaining
    ones, records whose image files are missing on disk are passed over as
    long as another usable record exists.

    Args:
        task_id: Index of the task category; selects the metadata and
            evaluation CSV files.

    Returns:
        A 5-tuple ``(record_id, original_path, prompt, img1_path, img2_path)``.
        On failure the tuple is ``("", None, message, None, None)``, i.e. the
        explanation travels in the prompt slot. This function never raises.
    """
    def _failure(message):
        return "", None, message, None, None

    try:
        meta_file = f"data/metadatas/meta{task_id}.csv"
        if not os.path.exists(meta_file):
            return _failure("Metadata file not found")

        # newline="" keeps multi-line prompts intact when parsing.
        with open(meta_file, "r", newline="", encoding="utf-8") as f:
            all_records = list(csv.DictReader(f))
        if not all_records:
            return _failure("No records in library")

        recent_ids = _recently_evaluated_ids(
            f"data/evaluations/eval{task_id}.csv", datetime.now()
        )
        candidates = [r for r in all_records if r["id"] not in recent_ids]
        if not candidates:
            return _failure("All available records have been recently evaluated")

        # Walk the candidates in random order and take the first one whose
        # images are all present; a single broken record must not block judging.
        random.shuffle(candidates)
        first_missing = None
        for rec in candidates:
            paths = (rec["original_path"], rec["img1_path"], rec["img2_path"])
            missing = next((p for p in paths if not os.path.exists(p)), None)
            if missing is None:
                return (
                    rec["id"],
                    rec["original_path"],
                    rec["prompt"],
                    rec["img1_path"],
                    rec["img2_path"],
                )
            if first_missing is None:
                first_missing = missing
        return _failure(f"Image file not found: {first_missing}")
    except Exception as e:
        return _failure(f"Error loading record: {str(e)}")
def save_evaluation(task_id, record_id,
                    a1_follow, a1_creativity, a1_finesse,
                    a2_follow, a2_creativity, a2_finesse):
    """Append one human judgement to the task's evaluation CSV.

    Used as the click handler of the "Submit Evaluation" button. Whatever the
    outcome, a fresh random record of the same task is loaded and returned
    alongside the status message.

    Args:
        task_id: Index of the task category; selects ``eval<task_id>.csv``.
        record_id: Id of the record being judged; must be non-empty.
        a1_follow, a1_creativity, a1_finesse: Scores for agent 1's output.
        a2_follow, a2_creativity, a2_finesse: Scores for agent 2's output.

    Returns:
        A 6-tuple: the status message followed by the five values returned by
        ``load_random_record(task_id)``.
    """
    score_columns = [
        "a1_follow", "a1_creativity", "a1_finesse",
        "a2_follow", "a2_creativity", "a2_finesse",
    ]
    score_values = (
        a1_follow, a1_creativity, a1_finesse,
        a2_follow, a2_creativity, a2_finesse,
    )

    try:
        # Guard: there must be a record on screen to attach the scores to.
        if not record_id:
            return "❌ Invalid record ID", *load_random_record(task_id)

        # Guard: every criterion needs a score.
        for value in score_values:
            if value is None:
                return "❌ Please complete all evaluations", *load_random_record(task_id)

        eval_path = f"data/evaluations/eval{task_id}.csv"
        with open(eval_path, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(
                f, fieldnames=["timestamp", "record_id", *score_columns]
            )
            row = {
                "timestamp": datetime.now().isoformat(),
                "record_id": record_id,
            }
            row.update(zip(score_columns, score_values))
            writer.writerow(row)

        # The file is closed (and flushed) before the next record is drawn.
        return "✅ Evaluation submitted!", *load_random_record(task_id)
    except Exception as e:
        return f"❌ Error saving evaluation: {str(e)}", *load_random_record(task_id)
# Agents a user can choose between in the arena tab.
MODEL_CHOICES = [
    "Model A",
    "Model B",
    "Model C",
]

# Task categories offered in the task dropdowns. The dropdowns are created
# with type="index", so a category's position in this list is the task id
# used in data paths (task<N> folders, meta<N>.csv, eval<N>.csv).
TASK_CHOICES = [
    "Image Restoration",                  # task 0
    "Image Enhancement",                  # task 1
    "Domain & Style Transfer",            # task 2
    "Semantic-Aware Editing",             # task 3
    "Image Composition & Expansion",      # task 4
    "Face & Appeal Editing",              # task 5
    "Steganography & Security Handling",  # task 6
]
# Score scale offered for every judging criterion.
_SCORE_CHOICES = [0, 1, 2, 3, 4, 5]

with gr.Blocks() as demo:
    with gr.Tabs():
        # ——— Tab 1: Agent Arena ———
        # Upload an image, pick a task and two agents, run both and store
        # the outputs in the library.
        with gr.TabItem("Agent Arena"):
            gr.Markdown("## CV Agent Arena 🎨🤖")
            with gr.Row():
                with gr.Column():
                    # type="index" passes the position in TASK_CHOICES to the
                    # callback instead of the label.
                    task_dropdown = gr.Dropdown(choices=TASK_CHOICES, label="Task Category", type="index")
                    original = gr.Image(type="pil", label="Upload Original Image")
                    prompt = gr.Textbox(lines=2, label="Prompt",
                                        placeholder="e.g. Make it look like a sunny day")
                with gr.Column():
                    agent1 = gr.Dropdown(choices=MODEL_CHOICES, label="Select Agent 1")
                    agent2 = gr.Dropdown(choices=MODEL_CHOICES, label="Select Agent 2")
                    run_btn = gr.Button("Run Agents")
            with gr.Row():
                out1 = gr.Image(type="pil", label="Agent 1 Output")
                out2 = gr.Image(type="pil", label="Agent 2 Output")
            run_btn.click(
                fn=generate_and_store,
                inputs=[task_dropdown, original, prompt, agent1, agent2],
                outputs=[out1, out2],
                show_api=False
            )

        # ——— Tab 2: Human as Judge ———
        # Show a random stored record and collect scores for both outputs.
        with gr.TabItem("Human as Judge"):
            # Id of the record currently on screen; "" means none is loaded.
            record_id_state = gr.State("")
            # Separate name so the arena tab's dropdown is not rebound.
            judge_task_dropdown = gr.Dropdown(choices=TASK_CHOICES, label="Task Category", type="index")
            # Original image and prompt side by side.
            with gr.Row():
                judge_orig = gr.Image(label="Original Image")
                judge_prompt = gr.Textbox(label="Prompt", interactive=False)
            # The two result images side by side.
            with gr.Row():
                judge_out1 = gr.Image(label="Agent 1 Result")
                judge_out2 = gr.Image(label="Agent 2 Result")
            with gr.Row():
                gr.Markdown(
                    "## Please Evaluate the Processed Images from 3 Aspects",
                    elem_classes=["center-text"]
                )
            with gr.Row():
                with gr.Column():
                    a1_follow = gr.Radio(_SCORE_CHOICES, label="Follow Prompt")
                    a1_creativity = gr.Radio(_SCORE_CHOICES, label="Creativity")
                    a1_finesse = gr.Radio(_SCORE_CHOICES, label="Finesse/Detail")
                with gr.Column():
                    a2_follow = gr.Radio(_SCORE_CHOICES, label="Follow Prompt")
                    a2_creativity = gr.Radio(_SCORE_CHOICES, label="Creativity")
                    a2_finesse = gr.Radio(_SCORE_CHOICES, label="Finesse/Detail")
            submit_btn = gr.Button("Submit Evaluation")
            submit_status = gr.Textbox(label="Status", interactive=False)

            rating_radios = [
                a1_follow, a1_creativity, a1_finesse,
                a2_follow, a2_creativity, a2_finesse,
            ]
            record_outputs = [record_id_state, judge_orig, judge_prompt, judge_out1, judge_out2]

            # Load a random sample when a task is chosen, then clear the
            # ratings so scores given to a previous sample cannot carry over.
            judge_task_dropdown.change(
                fn=load_random_record,
                inputs=[judge_task_dropdown],
                outputs=record_outputs,
                show_api=False
            ).then(
                fn=lambda: [None] * len(rating_radios),
                inputs=None,
                outputs=rating_radios,
                show_api=False
            )

            # save_evaluation always returns a new sample, so the ratings are
            # cleared after every submission for the same reason.
            submit_btn.click(
                fn=save_evaluation,
                inputs=[judge_task_dropdown, record_id_state, *rating_radios],
                outputs=[submit_status, *record_outputs],
                show_api=False
            ).then(
                fn=lambda: [None] * len(rating_radios),
                inputs=None,
                outputs=rating_radios,
                show_api=False
            )

demo.queue()
demo.launch(
    share=False,
    show_api=False,
    ssr_mode=False
)