Spaces:
Sleeping
Sleeping
""" | |
Test the MultiModalAgent. | |
""" | |
import json
import logging
import os
import sys

# Make modules that live next to this script (such as ``agent``) importable
# regardless of the current working directory.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_HERE)

from agent import MultiModalAgent  # noqa: E402 -- relies on the sys.path entry above

# Root logging at INFO with "time - logger name - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('test_agent')
def _load_test_questions(metadata_path, limit=5):
    """Collect questions that come with an attached file from a JSON Lines file.

    Args:
        metadata_path: Path to the metadata file. Each non-blank line should
            hold one JSON object; an entry is selected only if it has a
            ``Question`` key and a non-empty ``file_name``.
        limit: Stop after this many questions; ``None`` means no limit.

    Returns:
        A list of dicts with keys ``task_id``, ``question``, ``file_name`` and
        ``expected_answer`` (the entry's ``Final answer``, or ``None``). The
        list is empty when the file is missing or no entry qualifies.
    """
    questions = []
    if limit is not None and limit <= 0:
        return questions
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # json.loads('') raises; a blank line carries no entry.
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning("Skipping malformed JSON on line %d of %s",
                                   line_no, metadata_path)
                    continue
                if not isinstance(entry, dict):
                    continue
                if 'Question' not in entry or not entry.get('file_name'):
                    continue
                questions.append({
                    'task_id': entry.get('task_id'),
                    'question': entry['Question'],
                    'file_name': entry['file_name'],
                    'expected_answer': entry.get('Final answer'),
                })
                if limit is not None and len(questions) >= limit:
                    break
    except FileNotFoundError:
        # Fall through to the caller's generic-question fallback.
        logger.warning("Metadata file not found: %s", metadata_path)
    return questions


def _answers_match(answer, expected):
    """Return True when *answer* and *expected* agree as whitespace-stripped strings.

    Both sides go through ``str()`` so a non-string answer (e.g. ``None`` or a
    number) or a numeric expected value is compared instead of raising
    ``AttributeError`` on ``.strip()``.
    """
    return str(answer).strip() == str(expected).strip()


def main(max_questions=5):
    """Test the MultiModalAgent with some sample questions.

    Questions are read from ``resource/metadata.jsonl`` next to this script.
    Only entries with an attached ``file_name`` are used, up to
    *max_questions* (``None`` for no limit). If none are found, or the
    metadata file is missing, two generic questions without expected answers
    are used instead.

    Each answer is logged. When an expected answer is known, it is compared
    by exact match after stripping whitespace, and a score is logged at the
    end. If the agent raises on one question, the exception is logged and
    the run continues with the next question.

    Args:
        max_questions: Maximum number of metadata questions to run.
    """
    # Initialize the agent
    script_dir = os.path.dirname(os.path.abspath(__file__))
    resource_dir = os.path.join(script_dir, 'resource')
    agent = MultiModalAgent(resource_dir=resource_dir)

    # Load test questions from metadata.jsonl
    metadata_path = os.path.join(resource_dir, 'metadata.jsonl')
    test_questions = _load_test_questions(metadata_path, max_questions)

    # If no questions with files were found, use some generic questions
    if not test_questions:
        test_questions = [
            {
                'question': "What's the oldest Blu-Ray in the inventory spreadsheet?",
                'file_name': None,
                'expected_answer': None,
            },
            {
                'question': "How many files are in the resource directory?",
                'file_name': None,
                'expected_answer': None,
            },
        ]

    graded = 0
    correct = 0
    for i, q in enumerate(test_questions, start=1):
        question = q['question']
        logger.info("Testing question %d: %s", i, question)
        # NOTE(review): only the question text is passed to the agent; the
        # collected file_name is unused here -- confirm the agent locates
        # attachments on its own (e.g. via resource_dir).
        try:
            answer = agent(question)
        except Exception:
            # One failing question should not hide the results of the rest.
            logger.exception("Agent raised an error on question %d", i)
        else:
            logger.info("Answer: %s", answer)
            expected = q.get('expected_answer')
            # Grade only when an expected answer is known (the generic
            # fallback questions have none).
            if expected is not None and str(expected).strip():
                logger.info("Expected answer: %s", expected)
                graded += 1
                if _answers_match(answer, expected):
                    correct += 1
                    logger.info("Correct answer!")
                else:
                    logger.warning("Incorrect answer.")
        logger.info("-" * 80)

    if graded:
        logger.info("Score: %d/%d correct", correct, graded)


if __name__ == "__main__":
    main()