# Flask image Q&A service: classifies an uploaded image (ResNet-50),
# detects objects in it (Faster R-CNN), extracts dominant colors, and
# answers user questions from that metadata.
from flask import Flask, render_template, request, jsonify
from PIL import Image
from io import BytesIO
import torch
from torchvision import models, transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import os
# Flask application instance; routes are registered on it below.
app = Flask(__name__)
# Load ImageNet class index
def load_imagenet_class_index(class_index_path='imagenet_classes.txt'):
    """Load ImageNet class labels, one label per line.

    Args:
        class_index_path: Path to the labels file. Defaults to the
            historical hard-coded location, so existing callers are
            unaffected; the parameter makes the loader reusable and
            testable.

    Returns:
        list[str]: Class labels in file order, stripped of surrounding
        whitespace.

    Raises:
        FileNotFoundError: If the labels file does not exist.
    """
    if not os.path.exists(class_index_path):
        raise FileNotFoundError(f"ImageNet class index file not found at {class_index_path}")
    # Iterate the file directly instead of materializing readlines();
    # pin the encoding rather than relying on the platform default.
    with open(class_index_path, encoding='utf-8') as f:
        return [line.strip() for line in f]
# Labels for the ResNet classification output, loaded once at import time.
imagenet_classes = load_imagenet_class_index()
# Load pre-trained models
# ResNet-50 for whole-image classification; eval() switches off
# dropout/batch-norm training behavior for inference.
resnet = models.resnet50(pretrained=True)
resnet.eval()
# Faster R-CNN (ResNet-50 FPN backbone) for object detection; its label
# ids index into COCO_INSTANCE_CATEGORY_NAMES below.
fasterrcnn = fasterrcnn_resnet50_fpn(pretrained=True)
fasterrcnn.eval()
# Image transformation
# Standard ImageNet preprocessing for the classifier input:
# resize -> center crop to 224x224 -> tensor -> channel-wise normalize.
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# COCO dataset class names
# Position-aligned with Faster R-CNN label ids; 'N/A' entries are COCO
# category ids that have no class assigned.
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
# Function for real image analysis
def real_image_analysis(image):
    """Analyze a PIL image: classify it, detect objects, and extract colors.

    Args:
        image: PIL.Image.Image, expected in RGB mode (the /ask route
            converts before calling).

    Returns:
        dict with keys:
            "objects": up to 5 unique object/class names, classification
                results first, then detections, in model-rank order;
            "colors": dominant color names from get_dominant_colors();
            "scene": "outdoor" or "indoor" (crude keyword heuristic).
    """
    # --- Classification (ResNet-50 over ImageNet labels) ---
    batch_t = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = resnet(batch_t)
    # Rank class logits; take the top-3 label names.
    _, indices = torch.sort(output, descending=True)
    objects = [imagenet_classes[idx.item()] for idx in indices[0][:3]]

    # --- Object detection (Faster R-CNN over COCO labels) ---
    img_tensor = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        prediction = fasterrcnn(img_tensor)
    objects.extend(COCO_INSTANCE_CATEGORY_NAMES[i] for i in prediction[0]['labels'])

    # BUG FIX: the original used list(set(objects)), which scrambles order
    # nondeterministically, so the [:5] truncation below kept arbitrary
    # entries. dict.fromkeys de-duplicates while preserving insertion order
    # (top classifications stay first).
    objects = list(dict.fromkeys(objects))

    colors = get_dominant_colors(image)
    # Heuristic only: these keywords rarely occur verbatim in the ImageNet
    # or COCO label sets, so most images will be tagged "indoor".
    scene = "outdoor" if any(obj in ['sky', 'tree', 'grass', 'mountain'] for obj in objects) else "indoor"
    return {
        "objects": objects[:5],  # Limit to top 5 objects
        "colors": colors,
        "scene": scene,
    }
# Function to get dominant colors from an image
def get_dominant_colors(image, num_colors=3):
    """Return coarse names of the dominant colors in *image*.

    Args:
        image: PIL.Image.Image; a copy is analyzed, the original is
            left untouched.
        num_colors: Maximum number of colors to report.

    Returns:
        list[str]: Color names from rgb_to_name(), most frequent first.
        May contain fewer than num_colors entries for low-color images.
    """
    # Downscale first: palette quantization cost scales with pixel count.
    img = image.copy()
    img.thumbnail((100, 100))
    # Quantize to an adaptive palette of at most num_colors entries.
    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()
    # (count, palette_index) pairs, most frequent first.
    color_counts = sorted(paletted.getcolors(), reverse=True)
    colors = []
    # BUG FIX: the adaptive palette can hold fewer than num_colors entries
    # (e.g. a solid-color image), so the original range(num_colors) loop
    # raised IndexError. Only iterate over the colors that actually exist.
    for _count, palette_index in color_counts[:num_colors]:
        dominant_color = palette[palette_index * 3:palette_index * 3 + 3]
        colors.append(rgb_to_name(dominant_color))
    return colors
# Function to convert RGB to color name (simplified)
def rgb_to_name(rgb):
    """Map an (r, g, b) triple to a coarse color name.

    Returns "red", "green", or "blue" when that channel strictly exceeds
    both of the others; any tie for the maximum falls through to "gray".
    """
    red, green, blue = rgb
    if red > max(green, blue):
        return "red"
    if green > max(red, blue):
        return "green"
    if blue > max(red, green):
        return "blue"
    return "gray"
# Function to simulate the generation of answers from metadata
def generate_answer_from_metadata(metadata, question, complexity):
    """Return a (currently simulated) answer to *question* about an image.

    Args:
        metadata: dict with "objects" (list[str]), "colors" (list[str]),
            and "scene" (str), as produced by real_image_analysis().
        question: The user's question about the image.
        complexity: Desired response complexity label (e.g. "Default").

    Returns:
        str: A simulated answer embedding the raw metadata. Replace the
        return with a real LLM client call to make use of *prompt*.
    """
    object_list = ', '.join(metadata['objects'])
    color_list = ', '.join(metadata['colors'])
    # Assembled for a future API client; the simulated path ignores it.
    prompt = (
        "\n"
        f"The image contains the following objects: {object_list}.\n"
        f"The dominant colors are {color_list}.\n"
        f"It appears to be an {metadata['scene']} scene.\n"
        f"Based on this, provide a {complexity.lower()} response to the following question: {question}\n"
    )
    return (
        f"Simulated answer based on metadata: {metadata}. "
        f"Question: {question}, Complexity: {complexity}."
    )
# Flask routes
@app.route('/')
def index():
    """Serve the single-page UI from templates/index.html."""
    rendered = render_template('index.html')
    return rendered
@app.route('/ask', methods=['POST'])
def ask_question():
    """Answer a question about an uploaded image.

    Expects multipart form data:
        image: the image file (required)
        question: the question text (required)
        complexity: optional complexity label, defaults to 'Default'

    Returns:
        JSON {"answer": ...} on success, or {"error": ...} with status
        400 (missing/invalid input) or 500 (analysis/answer failure).
    """
    image_file = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')
    if not image_file or not question:
        return jsonify({"error": "Missing image or question"}), 400
    # BUG FIX: decoding used to happen outside any try block, so a corrupt
    # or non-image upload surfaced as an unhandled 500 traceback. Reject it
    # with a clean 400 instead.
    try:
        image = Image.open(image_file).convert("RGB")
    except Exception as e:
        print(f"Error decoding image: {str(e)}")
        return jsonify({"error": "Invalid image file"}), 400
    try:
        # Run classification/detection/color analysis, then compose the
        # answer; both stages are covered by the 500 handler (the analysis
        # step was previously unprotected).
        metadata = real_image_analysis(image)
        answer = generate_answer_from_metadata(metadata, question, complexity)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error generating answer: {str(e)}")
        return jsonify({"error": "Failed to generate answer"}), 500
if __name__ == '__main__':
    # Development entry point only: debug=True enables the Werkzeug
    # reloader and interactive debugger — never use in production.
    app.run(debug=True)