Vestiq / test_enhanced_analysis.py
Hashii1729's picture
Add enhanced fashion analysis features with refined prompt
bb71884
#!/usr/bin/env python3
"""
Test script for the enhanced fashion analysis with refined prompt
"""
import requests
import json
import sys
from pathlib import Path
def test_enhanced_analysis(image_path, server_url="http://localhost:7861", timeout=300):
    """Upload an image to the ``/analyze-enhanced`` endpoint and print the reply.

    Despite the ``test_`` prefix this is a manual smoke check that reports
    by printing to stdout; it makes no assertions.

    Args:
        image_path: Path of the image file to upload.
        server_url: Base URL of the server; the endpoint path is appended
            to it directly.
        timeout: Seconds to wait for the server before giving up. ``None``
            waits indefinitely (the behaviour before this parameter existed).

    Returns:
        None. Results and errors are only printed.
    """
    if not Path(image_path).exists():
        print(f"Error: Image file {image_path} not found")
        return

    print("πŸ” Testing Enhanced Fashion Analysis")
    print("=" * 50)
    try:
        with open(image_path, 'rb') as f:
            # Without a timeout, requests blocks forever on a stalled server.
            response = requests.post(
                f"{server_url}/analyze-enhanced",
                files={'file': f},
                timeout=timeout,
            )
        if response.status_code == 200:
            print("βœ… Enhanced Analysis Results:")
            print("-" * 30)
            print(response.text)
            print("-" * 30)
        else:
            print(f"❌ Error: {response.status_code} - {response.text}")
    except requests.exceptions.ConnectionError:
        print(f"❌ Error: Could not connect to server at {server_url}")
        print("Make sure the FastAPI server is running with: python fast.py")
    except requests.exceptions.Timeout:
        print(f"❌ Error: Request to {server_url} timed out after {timeout} seconds")
    except Exception as e:
        # Last-resort boundary for a CLI helper: report and carry on.
        print(f"❌ Error: {str(e)}")


# The function name matches pytest's default ``test_*`` discovery pattern, and
# pytest would treat ``image_path`` as a missing fixture. Opt out of collection.
test_enhanced_analysis.__test__ = False
def compare_analysis_methods(image_path, server_url="http://localhost:7861", timeout=300):
    """Send the same image to each analysis endpoint and print every reply.

    The endpoints are queried in order: ``/analyze-enhanced``,
    ``/analyze-structured`` and ``/analyze``. The structured endpoint's reply
    is pretty-printed as JSON when it parses; otherwise the raw text is shown.
    A connection failure stops the comparison, any other error only skips
    the endpoint that raised it.

    Args:
        image_path: Path of the image file to upload.
        server_url: Base URL of the server; each endpoint path is appended
            to it directly.
        timeout: Seconds to wait for each request before giving up. ``None``
            waits indefinitely (the behaviour before this parameter existed).

    Returns:
        None. Results and errors are only printed.
    """
    if not Path(image_path).exists():
        print(f"Error: Image file {image_path} not found")
        return

    print("πŸ” Comparing Analysis Methods")
    print("=" * 50)
    endpoints = [
        ("/analyze-enhanced", "Enhanced Prompt Analysis"),
        ("/analyze-structured", "Structured Analysis"),
        ("/analyze", "Basic Analysis"),
    ]
    for endpoint, name in endpoints:
        print(f"\nπŸ“‹ {name}")
        print("-" * 30)
        try:
            # Reopen the file per request so each upload starts at byte 0.
            with open(image_path, 'rb') as f:
                response = requests.post(
                    f"{server_url}{endpoint}",
                    files={'file': f},
                    timeout=timeout,
                )
            if response.status_code != 200:
                print(f"❌ Error: {response.status_code} - {response.text}")
            elif endpoint == "/analyze-structured":
                # This endpoint returns JSON; fall back to raw text if the
                # body does not parse. JSON decode errors are ValueErrors, so
                # Ctrl-C is no longer swallowed as it was by a bare except.
                try:
                    data = response.json()
                    print(json.dumps(data, indent=2))
                except ValueError:
                    print(response.text)
            else:
                print(response.text)
        except requests.exceptions.ConnectionError:
            print(f"❌ Error: Could not connect to server at {server_url}")
            # If the server is unreachable the remaining endpoints will fail too.
            break
        except requests.exceptions.Timeout:
            print(f"❌ Error: Request to {server_url}{endpoint} timed out after {timeout} seconds")
        except Exception as e:
            print(f"❌ Error: {str(e)}")
def test_refined_prompt(server_url="http://localhost:7861", timeout=30):
    """Fetch ``/refined-prompt`` from the server and print the response text.

    Despite the ``test_`` prefix this is a manual smoke check that reports
    by printing to stdout; it makes no assertions.

    Args:
        server_url: Base URL of the server; the endpoint path is appended
            to it directly.
        timeout: Seconds to wait for the server before giving up. ``None``
            waits indefinitely (the behaviour before this parameter existed).

    Returns:
        None. Results and errors are only printed.
    """
    print("πŸ“ Testing Refined Prompt")
    print("=" * 50)
    try:
        # Without a timeout, requests blocks forever on a stalled server.
        response = requests.get(f"{server_url}/refined-prompt", timeout=timeout)
        if response.status_code == 200:
            print("βœ… Refined Prompt:")
            print("-" * 30)
            print(response.text)
            print("-" * 30)
        else:
            print(f"❌ Error: {response.status_code} - {response.text}")
    except requests.exceptions.ConnectionError:
        print(f"❌ Error: Could not connect to server at {server_url}")
    except requests.exceptions.Timeout:
        print(f"❌ Error: Request to {server_url} timed out after {timeout} seconds")
    except Exception as e:
        # Last-resort boundary for a CLI helper: report and carry on.
        print(f"❌ Error: {str(e)}")


# The function name matches pytest's default ``test_*`` discovery pattern;
# opt out so a pytest run does not fire a live network request.
test_refined_prompt.__test__ = False
def main():
    """Run every check in sequence using command-line arguments.

    Reads ``<image_path>`` and an optional ``[server_url]`` from ``sys.argv``.
    With no image path it prints usage help and returns without contacting
    the server.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        for usage_line in (
            "Usage: python test_enhanced_analysis.py <image_path> [server_url]",
            "Example: python test_enhanced_analysis.py test_image.jpg",
            "Example: python test_enhanced_analysis.py test_image.jpg http://localhost:7861",
        ):
            print(usage_line)
        return

    image_path = cli_args[0]
    # The second positional argument, when present, overrides the default server.
    if len(cli_args) > 1:
        server_url = cli_args[1]
    else:
        server_url = "http://localhost:7861"

    separator = "=" * 50
    print("🎽 Enhanced Fashion Analysis Test")
    print(separator)
    print(f"Image: {image_path}")
    print(f"Server: {server_url}")
    print(separator)

    # Step 1: fetch the refined prompt.
    test_refined_prompt(server_url)
    print("\n")

    # Step 2: run the enhanced analysis on the image.
    test_enhanced_analysis(image_path, server_url)
    print("\n")

    # Step 3: run all analysis endpoints side by side.
    compare_analysis_methods(image_path, server_url)
# Entry point: run the checks only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()