File size: 5,971 Bytes
2e7f273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from flask import Flask, render_template, request, jsonify
from PIL import Image
from io import BytesIO
import torch
from torchvision import models, transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import os

# Flask application object: the @app.route handlers below register on it,
# and the __main__ guard at the end of the file runs its server.
app = Flask(__name__)

# Load ImageNet class index
def load_imagenet_class_index(class_index_path='imagenet_classes.txt'):
    """Read ImageNet class labels from a text file, one label per line.

    Args:
        class_index_path: Path to the label file. A relative path is resolved
            against the current working directory. Defaults to
            ``'imagenet_classes.txt'``.

    Returns:
        A list of labels with surrounding whitespace stripped, in file order,
        so list index ``i`` is the label on line ``i`` of the file.

    Raises:
        FileNotFoundError: If no file exists at ``class_index_path``.
    """
    try:
        # EAFP: opening directly avoids the race between an exists() pre-check
        # and the open() call.
        with open(class_index_path, encoding='utf-8') as f:
            return [line.strip() for line in f]
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"ImageNet class index file not found at {class_index_path}"
        ) from err

# Labels for the classifier's output positions (real_image_analysis indexes
# this list with ResNet-50 output indices). Loaded at import time, so a
# missing label file prevents the app from starting.
imagenet_classes = load_imagenet_class_index()

# Pre-trained networks. nn.Module.eval() returns the module itself, so the
# switch to inference mode is chained onto construction.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; confirm against the pinned torchvision version.
resnet = models.resnet50(pretrained=True).eval()
fasterrcnn = fasterrcnn_resnet50_fpn(pretrained=True).eval()

# Classifier preprocessing: resize to 256, centre-crop 224x224, convert to a
# tensor, then normalize each channel with the mean/std below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# COCO category names, indexed by the integer label ids the Faster R-CNN
# detector returns (names are looked up as COCO_INSTANCE_CATEGORY_NAMES[label]).
# 91 entries, indices 0-90; index 0 is the background class. The 'N/A'
# entries presumably stand in for COCO ids with no category - keep them so
# every later index stays aligned with its label id.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# Function for real image analysis
def real_image_analysis(image, score_threshold=0.5, top_k=3, max_objects=5):
    """Classify and detect objects in *image* and summarise its colours and scene.

    Args:
        image: A PIL image; the /ask route passes an RGB-converted upload.
        score_threshold: Minimum detector confidence for a Faster R-CNN
            detection to be reported. Without a cut-off every low-confidence
            detection would be listed. 0.5 is a conventional default.
        top_k: Number of ResNet-50 (ImageNet) predictions to include.
        max_objects: Maximum number of labels returned under ``"objects"``.

    Returns:
        A dict with:
          - ``"objects"``: de-duplicated labels. Classifier predictions come
            first (most confident first), followed by detections in
            descending score order. At most ``max_objects`` labels are kept.
          - ``"colors"``: the result of ``get_dominant_colors(image)``.
          - ``"scene"``: ``"outdoor"`` or ``"indoor"``.
    """
    # --- Whole-image classification (ImageNet labels) ---
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = resnet(batch)
    # topk instead of a full sort: only the best few class scores are needed.
    top_indices = torch.topk(logits[0], k=top_k).indices.tolist()
    labels = [imagenet_classes[i] for i in top_indices]

    # --- Object detection (COCO labels) ---
    detector_input = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        detection = fasterrcnn(detector_input)[0]
    scored = sorted(
        zip(detection['scores'].tolist(), detection['labels'].tolist()),
        key=lambda pair: pair[0],
        reverse=True,
    )
    labels.extend(
        COCO_INSTANCE_CATEGORY_NAMES[label]
        for score, label in scored
        if score >= score_threshold
    )

    # dict.fromkeys de-duplicates while keeping first-seen order; set() would
    # scramble it and make the max_objects truncation below arbitrary.
    objects = list(dict.fromkeys(labels))

    colors = get_dominant_colors(image)

    # NOTE(review): none of these keywords appears in
    # COCO_INSTANCE_CATEGORY_NAMES, so only an exact ImageNet label match can
    # produce "outdoor". Confirm the heuristic is intended.
    outdoor_keywords = {'sky', 'tree', 'grass', 'mountain'}
    scene = "outdoor" if any(obj in outdoor_keywords for obj in objects) else "indoor"

    return {
        "objects": objects[:max_objects],
        "colors": colors,
        "scene": scene,
    }

# Function to get dominant colors from an image
def get_dominant_colors(image, num_colors=3):
    """Return colour names for the most common colours in *image*.

    The work is done on a copy shrunk to fit in 100x100 px (the caller's
    image is not modified). The copy is quantized to an adaptive palette of
    ``num_colors`` entries, the palette entries are ranked by pixel count,
    and each is converted with ``rgb_to_name``.

    Args:
        image: A PIL image.
        num_colors: Palette size, and the maximum number of names returned.

    Returns:
        A list of at most ``num_colors`` colour names, most frequent first.
        When the quantized image uses fewer distinct colours (for example, a
        solid-colour image), the shorter list is returned instead of raising
        IndexError.
    """
    thumb = image.copy()
    thumb.thumbnail((100, 100))

    paletted = thumb.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()
    # getcolors() yields (count, palette_index) pairs for the indices actually
    # used; sorting descending puts the most frequent colour first.
    color_counts = sorted(paletted.getcolors(), reverse=True)

    colors = []
    # Slice rather than index by range(num_colors): there may be fewer pairs.
    for _count, palette_index in color_counts[:num_colors]:
        rgb = palette[palette_index * 3:palette_index * 3 + 3]
        colors.append(rgb_to_name(rgb))
    return colors

# Function to convert RGB to color name (simplified)
def rgb_to_name(rgb):
    """Map an ``(r, g, b)`` triple to a coarse colour name.

    The name of the channel that is strictly greater than both other channels
    is returned (``"red"``, ``"green"`` or ``"blue"``). If no single channel
    dominates (a tie for the maximum), ``"gray"`` is returned.
    """
    red, green, blue = rgb
    # Each entry: (name, channel value, the other two channels in the order
    # they are compared against).
    candidates = (
        ("red", red, (green, blue)),
        ("green", green, (red, blue)),
        ("blue", blue, (red, green)),
    )
    for name, value, others in candidates:
        if all(value > other for other in others):
            return name
    return "gray"

# Function to simulate the generation of answers from metadata
def generate_answer_from_metadata(metadata, question, complexity):
    """Return a simulated answer to *question* built from image *metadata*.

    A language-model prompt is assembled from the metadata, but it is not sent
    anywhere yet. The returned string only echoes the inputs.

    Args:
        metadata: Mapping with ``'objects'`` and ``'colors'`` (iterables of
            strings) and ``'scene'``.
        question: The user's question.
        complexity: Requested answer style; must support ``.lower()``.

    Returns:
        The simulated answer string.

    Raises:
        KeyError: If *metadata* lacks one of the keys above.
    """
    objects_text = ', '.join(metadata['objects'])
    colors_text = ', '.join(metadata['colors'])
    scene = metadata['scene']
    style = complexity.lower()

    # Unused until a real model client is wired in; replace the simulated
    # return below with a call that sends this prompt.
    prompt = f"""

    The image contains the following objects: {objects_text}.

    The dominant colors are {colors_text}.

    It appears to be an {scene} scene.



    Based on this, provide a {style} response to the following question: {question}

    """

    simulated = f"Simulated answer based on metadata: {metadata}. Question: {question}, Complexity: {complexity}."
    return simulated

# Flask routes
@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)

@app.route('/ask', methods=['POST'])
def ask_question():
    """Answer a question about an uploaded image.

    Expects form data with an ``image`` file and a ``question`` field.
    ``complexity`` is optional and defaults to ``'Default'``.

    Returns:
        JSON ``{"answer": ...}`` on success. On failure, JSON
        ``{"error": ...}`` with one of these statuses:
          - 400 when the image or question is missing, or the upload cannot
            be decoded as an image;
          - 500 when image analysis or answer generation fails.
    """
    upload = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')

    if not upload or not question:
        return jsonify({"error": "Missing image or question"}), 400

    # The upload is client-controlled. Undecodable data raises OSError
    # (including PIL's UnidentifiedImageError), and oversized images raise
    # DecompressionBombError. Report both as a client error instead of an
    # unhandled 500.
    try:
        image = Image.open(upload).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        app.logger.warning("Rejected upload that could not be decoded as an image")
        return jsonify({"error": "Invalid image file"}), 400

    try:
        metadata = real_image_analysis(image)
    except Exception:
        # Request boundary: log with traceback and return a JSON error.
        app.logger.exception("Error analyzing image")
        return jsonify({"error": "Failed to analyze image"}), 500

    try:
        answer = generate_answer_from_metadata(metadata, question, complexity)
    except Exception:
        # logger.exception keeps the traceback that print() discarded.
        app.logger.exception("Error generating answer")
        return jsonify({"error": "Failed to generate answer"}), 500
    return jsonify({"answer": answer})

if __name__ == '__main__':
    # debug=True enables Werkzeug's interactive debugger, which lets anyone
    # who can reach an error page run arbitrary code. Make it opt-in
    # (FLASK_DEBUG=1) instead of always on.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1')