import argparse
import base64
import concurrent.futures
import json
import os
import re
import sys
import time
import traceback
from functools import partial

import requests
from openai import AzureOpenAI, OpenAI
from tqdm import tqdm
from volcenginesdkarkruntime import Ark


IMAGINE_AGENT_SYSTEM_PROMPT = """ |
|
|
You are an intelligent AI assistant specializing in answering video question-answering problems through reasoning and imagination. |
|
|
Your task is to answer a multiple-choice question based on an initial, limited set of video frames. |
|
|
|
|
|
You will receive a few uniformly sampled frames to get a basic understanding of the video. |
|
|
These frames may not contain all the visual evidence needed to directly answer the question. |
|
|
|
|
|
If the provided frame information is insufficient, you must use the `imagine_frame` tool to generate new, imagined frames to fill in the visual gaps and aid your reasoning. |
|
|
You can call this tool multiple times to construct a sequence of imagined events. |
|
|
|
|
|
Your strategy should be: |
|
|
1. Analyze the initial frames and the user's question. |
|
|
2. Form a hypothesis about the missing content. |
|
|
3. If you need more visual information, call the `imagine_frame` tool. Provide a text `prompt` describing the scene you want to imagine, and select a `reference_image_id` from existing frames. The `reference_image_id` MUST be one of the IDs explicitly provided to you in the conversation history (e.g., "Frame ID: X" or "New Frame ID: Y"). Do not invent or assume frame IDs. |
|
|
4. Analyze the newly generated frame in conjunction with the existing ones. |
|
|
5. Continue this process of reasoning and imagination until you are confident in your answer. Please ensure you have found or created the relevant visual cues before answering the question. |
|
|
6. Each tool call can only generate one frame. |
|
|
|
|
|
IMPORTANT: Your text `prompt` for image generation must be safe and general. Avoid descriptions that could be interpreted as sensitive, harmful, or explicit to prevent generation failures. |
|
|
|
|
|
After your reasoning, provide the final answer in a JSON code block. The JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'. |
|
|
|
|
|
Your output must strictly follow this format: |
|
|
<Your step-by-step reasoning process here, including why you chose to imagine a certain frame> |
|
|
```json |
|
|
{"answer": "X"} |
|
|
``` |
|
|
Do not include any other text after the JSON code block. |
|
|
""" |


IMAGINE_FRAME_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "imagine_frame",
        "description": (
            "When visual evidence is insufficient, generates a new image based on "
            "a text prompt and a reference image to help answer the question. Use "
            "it to imagine what might have happened between the provided frames."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "reference_image_id": {
                    "type": "integer",
                    "description": (
                        "The ID of an existing frame to use as a style and content "
                        "reference. It can be one of the original frames or a "
                        "previously generated one."
                    ),
                },
                "prompt": {
                    "type": "string",
                    "description": "A detailed text description of the frame you want to imagine and generate.",
                },
            },
            "required": ["reference_image_id", "prompt"],
        },
    },
}
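

# For reference, a tool call against this schema would arrive inside an
# assistant message roughly shaped like the following (illustrative values
# only; the wire format follows the OpenAI tool-calling convention used below,
# where `arguments` is a JSON-encoded string):
# {
#     "id": "call_abc123",
#     "type": "function",
#     "function": {
#         "name": "imagine_frame",
#         "arguments": "{\"reference_image_id\": 3, \"prompt\": \"The person places the red mug on the table.\"}",
#     },
# }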


def imagine_frame(
    reference_image_id: int,
    prompt: str,
    all_frame_paths: dict,
    output_dir: str,
    generation_count: int,
):
    """
    Tool implementation: calls an image generation model to create a new frame.

    Args:
        reference_image_id (int): The ID of the reference frame.
        prompt (str): The text prompt for image generation.
        all_frame_paths (dict): Maps frame IDs to paths for all currently available frames (original + generated).
        output_dir (str): The directory in which to save the generated image.
        generation_count (int): The current generation count, used to name the file.

    Returns:
        str or None: The path of the newly generated image on success, otherwise None.
    """
    print(f"\n[Tool Call] Imagining new frame with prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Error: Environment variable ARK_API_KEY is not set.")

    client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=ark_api_key,
    )

    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        raise FileNotFoundError(f"Reference image ID not found: {reference_image_id}")

    try:
        # Send the reference frame as a data URI alongside the text prompt.
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        images_response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",
            response_format="url",
            watermark=False,
        )
        image_url = images_response.data[0].url

        # Download the generated image; the timeout keeps a stalled download
        # from hanging a worker process indefinitely.
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()

        new_frame_filename = (
            f"generated_frame_{generation_count}_ref_{reference_image_id}.jpg"
        )
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(response.content)

        print(f"[Tool Success] Generated frame saved to: {new_frame_path}")
        return new_frame_path

    except Exception as e:
        print(f"An error occurred during image generation or download: {e}")
        traceback.print_exc()
        return None
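

# Illustrative usage (hypothetical IDs and paths):
# new_path = imagine_frame(
#     reference_image_id=3,
#     prompt="The person walks out of the kitchen holding a red mug.",
#     all_frame_paths={3: "/data/frames/video_001/frame_3.jpg"},
#     output_dir="./generated_outputs/generated_images/video_001",
#     generation_count=1,
# )
# # -> "./generated_outputs/generated_images/video_001/generated_frame_1_ref_3.jpg"
# #    on success, None on failure.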


def parse_arguments():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Video QA Evaluation Framework with Imagine-and-Reason Agent"
    )
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="The model to be evaluated (e.g., gpt-4o).",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the root directory containing video frames.",
    )
    parser.add_argument(
        "--output-path",
        "-op",
        type=str,
        default="./generated_outputs",
        help="Path to store generated images and results.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the evaluation dataset JSON file.",
    )
    parser.add_argument(
        "--initial-frames-num",
        "-ifn",
        type=int,
        default=8,
        help="Number of initial uniformly sampled frames.",
    )
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for failed API calls.",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=10,
        help="Number of parallel processes.",
    )
    parser.add_argument(
        "--base_url",
        type=str,
        required=True,
        help="API endpoint URL for the target model service.",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        required=True,
        help="API key for the target model service.",
    )
    return parser.parse_args()
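

# Example invocation (hypothetical script name, paths, and key):
# python imagine_agent_eval.py \
#     --target-model gpt-4o \
#     --frames-path /data/frames \
#     --data-file /data/eval_set.json \
#     --base_url https://api.example.com/v1 \
#     --api_key sk-... \
#     --output-path ./generated_outputs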


def save_json_file(data, output_file):
    """Save data to a JSON file."""
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def extract_json_from_response(response):
    """Extract the JSON answer from the model's text response."""
    if not response:
        return None
    match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except (json.JSONDecodeError, IndexError):
            return None
    return None
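

# Example (illustrative): given a response ending in
#     ```json
#     {"answer": "B"}
#     ```
# this returns {"answer": "B"}; a response without a fenced JSON block yields None.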


def calculate_metrics(results):
    """Calculate aggregate metrics from the evaluation results."""
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }
    answered_samples = sum(
        1 for x in valid_results if x.get("model_answer") is not None
    )
    correct_answers = sum(1 for x in valid_results if x.get("is_correct"))
    # Note: accuracy is computed over answered samples, not all valid samples.
    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0
    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }
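

# Worked example (illustrative numbers): with 50 valid results, of which 48
# produced a parseable answer and 36 were correct, this returns
# {"total_samples": 50, "answered_samples": 48, "correct_answers": 36,
#  "accuracy": 0.75}  # 36 / 48, i.e., accuracy over answered samples only.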


def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Make a single model API call with retry logic."""
    params = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message
        except Exception as e:
            retry_times += 1
            print(
                f"API call error (Item {item_id}): {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            # Re-raise once the retry budget is exhausted; otherwise back off briefly.
            if retry_times == max_retry_times:
                raise e
            time.sleep(5)


def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample up to `num_frames` frames from a directory and encode them."""
    if not os.path.isdir(frames_dir):
        return [], {}

    # Only consider files matching the expected "frame_<N>.jpg" naming scheme,
    # so a stray .jpg file cannot crash the numeric sort below.
    frame_pattern = re.compile(r"frame_(\d+)\.jpg$")
    frame_files = sorted(
        [f for f in os.listdir(frames_dir) if frame_pattern.search(f)],
        key=lambda x: int(frame_pattern.search(x).group(1)),
    )

    total_frames = len(frame_files)
    if total_frames == 0:
        return [], {}

    if total_frames > num_frames:
        indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
        sampled_files = [frame_files[i] for i in indices]
    else:
        sampled_files = frame_files

    frame_path_map = {}
    encoded_frames = []
    for f in sampled_files:
        path = os.path.join(frames_dir, f)
        frame_id = int(frame_pattern.search(f).group(1))
        b64_image = encode_image(path)

        # Each frame is preceded by a text label so the model can cite its ID
        # when requesting an imagined frame.
        encoded_frames.append({"type": "text", "text": f"This is Frame ID: {frame_id}"})
        encoded_frames.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
            }
        )
        frame_path_map[frame_id] = path

    return encoded_frames, frame_path_map
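

# Sampling example (illustrative): with total_frames=100 and num_frames=8,
# indices = [0, 12, 25, 37, 50, 62, 75, 87], i.e., roughly evenly spaced
# positions across the video.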


def evaluate_single_item_agentic_imagination(
    data_item,
    initial_frames,
    initial_frame_paths,
    generated_images_dir,
    target_model,
    api_key,
    base_url,
    max_retry_times,
):
    """
    Core logic for evaluating a single data item with the Imagine-and-Reason Agent.
    """
    # Pick a client implementation with a simple substring heuristic over the
    # endpoint URL: Ark for Volcengine, OpenAI for Aliyun or local endpoints,
    # and Azure OpenAI otherwise.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [IMAGINE_FRAME_TOOL_SCHEMA]

    # Tracks every frame the agent may reference: the initial samples plus any
    # frames it imagines along the way.
    available_frame_paths = initial_frame_paths.copy()

    initial_prompt_content = [
        {
            "type": "text",
            "text": "Here are the initial sampled video frames provided to you:",
        },
        *initial_frames,
        {
            "type": "text",
            "text": f"Please answer the following question:\n{data_item['question']}",
        },
    ]

    messages = [
        {"role": "system", "content": IMAGINE_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt_content},
    ]

    response_content = None
    max_tool_calls = 5  # Cap on agent turns; the final answer must arrive within this budget.
    generation_count = 0

    for _ in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        messages.append(response_message.model_dump(exclude_none=True))

        if response_message.tool_calls:
            # Only the first tool call is handled; the system prompt instructs
            # the model to request one frame per call.
            tool_call = response_message.tool_calls[0]
            function_name = tool_call.function.name

            if function_name == "imagine_frame":
                generation_count += 1
                function_args = json.loads(tool_call.function.arguments)
                new_frame_path = imagine_frame(
                    **function_args,
                    all_frame_paths=available_frame_paths,
                    output_dir=generated_images_dir,
                    generation_count=generation_count,
                )

                if new_frame_path:
                    # Assign the next free frame ID and register the new frame.
                    new_frame_id = (
                        max(available_frame_paths.keys())
                        if available_frame_paths
                        else 0
                    ) + 1
                    available_frame_paths[new_frame_id] = new_frame_path

                    b64_image = encode_image(new_frame_path)
                    tool_response_content = [
                        {
                            "type": "text",
                            "text": f"Here is the frame you requested to imagine (New Frame ID: {new_frame_id}). Please use it to continue your reasoning.",
                        },
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                        },
                    ]

                    # The tool message satisfies the tool-calling protocol with
                    # plain JSON text; the image itself is delivered in a
                    # follow-up user message.
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {"status": "success", "new_frame_id": new_frame_id}
                            ),
                        }
                    )
                    messages.append({"role": "user", "content": tool_response_content})
                else:
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "error",
                                    "message": "Failed to generate image.",
                                }
                            ),
                        }
                    )
        else:
            # No tool call: the model has produced its final reasoning and answer.
            response_content = response_message.content
            break

    # If the tool-call budget ran out before a final answer, ask for one explicitly.
    if response_content is None and response_message:
        final_prompt = "You have reached the maximum number of tool calls. Please provide a final answer in the specified JSON format based on the information you have gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response_message = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        if final_response_message:
            messages.append(final_response_message.model_dump(exclude_none=True))
            response_content = final_response_message.content

    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        gold_answer = data_item["answer"].strip().upper()
        if model_answer_cleaned == gold_answer:
            is_correct = True

    return {
        **data_item,
        "agent_conversation": messages,
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
        "generated_images_path": generated_images_dir,
    }
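

# A successful result record extends the input item, e.g. (illustrative values):
# {**data_item,
#  "agent_conversation": [...full message history...],
#  "model_reasoning_and_answer": "<reasoning>```json\n{\"answer\": \"B\"}\n```",
#  "model_answer": "B", "is_correct": True,
#  "generated_images_path": "./generated_outputs/generated_images/video_001"}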


def encode_image(image_path):
    """Encode an image file as a Base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_single_data(data_item, args):
    """Worker function to process a single data item in parallel."""
    item_key = data_item["key"]
    try:
        # Each item gets its own subdirectory for generated (imagined) frames.
        generated_images_dir = os.path.join(
            args.output_path, "generated_images", item_key
        )
        os.makedirs(generated_images_dir, exist_ok=True)

        specific_frames_path = os.path.join(args.frames_path, item_key)
        initial_frames, initial_frame_paths = uniformly_sample_frames_and_encode(
            specific_frames_path, args.initial_frames_num
        )

        if not initial_frames:
            raise FileNotFoundError(f"Initial frames not found for item '{item_key}'")

        result = evaluate_single_item_agentic_imagination(
            data_item,
            initial_frames,
            initial_frame_paths,
            generated_images_dir,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )
        return result

    except Exception as e:
        print(f"\nA critical error occurred while processing item {item_key}: {str(e)}")
        traceback.print_exc()
        # Return an error record so the failure is logged without aborting the run.
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }


def load_test_data(json_file):
    """Load test data from a JSON file."""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: JSON file is malformed: {json_file}")
        sys.exit(1)
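

# Expected dataset shape, inferred from field access elsewhere in this script
# (illustrative): a list of items, each with at least "key", "question", and
# "answer", plus an optional "uid":
# [
#     {"key": "video_001", "uid": "abc-123",
#      "question": "What happens next? A) ... B) ... C) ... D) ...",
#      "answer": "A"},
# ]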


def main():
    """Main function orchestrating the entire evaluation flow."""
    args = parse_arguments()

    print("--- Video QA Imagine-and-Reason Agent Framework ---")
    print(f"Evaluating Model: {args.target_model}")
    print(f"Output Path: {args.output_path}")
    print(f"Dataset: {args.data_file}")
    print("---------------------------------")

    os.makedirs(args.output_path, exist_ok=True)

    # Build output filenames from the model name and dataset name.
    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]

    output_prefix = f"{model_name_safe}_{data_filename_base}_imagine_agent"
    results_output_file = os.path.join(
        args.output_path, f"{output_prefix}_results.json"
    )
    metrics_output_file = os.path.join(
        args.output_path, f"{output_prefix}_metrics.json"
    )
    error_log_file = os.path.join(args.output_path, f"{output_prefix}_errors.log")

    all_test_data = load_test_data(args.data_file)
    tasks_to_process = all_test_data

    all_results = []

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.pool_processes
    ) as executor:
        func = partial(process_single_data, args=args)
        results_iterator = executor.map(func, tasks_to_process)

        for result in tqdm(
            results_iterator, total=len(tasks_to_process), desc="Processing Videos"
        ):
            if result:
                if "error" in result:
                    with open(error_log_file, "a", encoding="utf-8") as f:
                        f.write(
                            f"Error on item {result.get('key', 'N/A')}:\n Error: {result['error']}\n---\n"
                        )
                all_results.append(result)

                # Checkpoint intermediate results every 10 items.
                if len(all_results) % 10 == 0:
                    save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")

    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nEvaluation metrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))


if __name__ == "__main__":
    main()