DHEIVER commited on
Commit
15646ab
·
verified ·
1 Parent(s): aace76b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -37
app.py CHANGED
@@ -1,21 +1,72 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq
4
- from PIL import Image
5
  import numpy as np
6
  import os
7
  from huggingface_hub import snapshot_download
8
  import logging
 
 
 
 
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class NutritionalAnalyzer:
15
  def __init__(self):
16
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
  self.models = {}
18
  self.processors = {}
 
19
 
20
  def initialize_model(self, model_name):
21
  """Initialize a specific model with error handling and caching"""
@@ -37,7 +88,7 @@ class NutritionalAnalyzer:
37
 
38
  config = model_configs.get(model_name)
39
  if not config:
40
- raise ValueError(f"Unsupported model: {model_name}")
41
 
42
  # Ensure cache directory exists
43
  os.makedirs(config["local_cache"], exist_ok=True)
@@ -72,19 +123,6 @@ class NutritionalAnalyzer:
72
  logger.error(f"Error initializing {model_name}: {str(e)}")
73
  return False
74
 
75
- def prepare_image(self, image):
76
- """Prepare image for model input"""
77
- if isinstance(image, str):
78
- image = Image.open(image)
79
- elif isinstance(image, np.ndarray):
80
- image = Image.fromarray(image)
81
-
82
- # Ensure image is in RGB mode
83
- if image.mode != "RGB":
84
- image = image.convert("RGB")
85
-
86
- return image
87
-
88
  def generate_nutritional_prompt(self, user_question):
89
  """Generate a comprehensive nutritional analysis prompt"""
90
  return f"""Como nutricionista especializado, analise esta refeição detalhadamente:
@@ -111,6 +149,9 @@ Por favor, forneça uma análise detalhada em português."""
111
  def analyze_image(self, image, question, model_choice):
112
  """Analyze image with nutritional focus"""
113
  try:
 
 
 
114
  # Convert model choice to internal name
115
  model_name = model_choice.lower().replace("-", "")
116
 
@@ -118,34 +159,48 @@ Por favor, forneça uma análise detalhada em português."""
118
  if not self.initialize_model(model_name):
119
  return "Erro: Não foi possível inicializar o modelo. Por favor, tente novamente."
120
 
121
- # Prepare image and prompt
122
- processed_image = self.prepare_image(image)
 
 
 
 
 
 
 
123
  nutritional_prompt = self.generate_nutritional_prompt(question)
124
 
125
  # Process input
126
- inputs = self.processors[model_name](
127
- images=processed_image,
128
- text=nutritional_prompt,
129
- return_tensors="pt"
130
- ).to(self.device)
 
 
 
131
 
132
  # Generate response with enhanced parameters
133
- with torch.no_grad():
134
- outputs = self.models[model_name].generate(
135
- **inputs,
136
- max_new_tokens=300,
137
- num_beams=5,
138
- temperature=0.7,
139
- top_p=0.9,
140
- repetition_penalty=1.2,
141
- length_penalty=1.0
142
- )
143
-
144
- # Decode and format response
145
- response = self.processors[model_name].decode(outputs[0], skip_special_tokens=True)
146
- formatted_response = self.format_response(response)
 
 
 
147
 
148
- return formatted_response
 
149
 
150
  except Exception as e:
151
  logger.error(f"Analysis error: {str(e)}")
@@ -234,6 +289,7 @@ def create_interface():
234
  2. Capture todos os elementos do prato
235
  3. Evite ângulos muito inclinados
236
  4. Seja específico em suas perguntas
 
237
  """)
238
 
239
  analyze_btn.click(
 
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq
4
+ from PIL import Image, ImageOps
5
  import numpy as np
6
  import os
7
  from huggingface_hub import snapshot_download
8
  import logging
9
+ from pathlib import Path
10
+ import tempfile
11
+ import requests
12
+ from io import BytesIO
13
 
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
+ class ImageHandler:
19
+ """Handle image processing and conversion"""
20
+ @staticmethod
21
+ def convert_to_rgb(image_path):
22
+ """Convert image to RGB format supporting multiple formats"""
23
+ try:
24
+ # If image is a URL, download it first
25
+ if isinstance(image_path, str) and (image_path.startswith('http://') or image_path.startswith('https://')):
26
+ response = requests.get(image_path)
27
+ image_data = BytesIO(response.content)
28
+ image = Image.open(image_data)
29
+ else:
30
+ image = Image.open(image_path)
31
+
32
+ # Convert RGBA to RGB if needed
33
+ if image.mode == 'RGBA':
34
+ background = Image.new('RGB', image.size, (255, 255, 255))
35
+ background.paste(image, mask=image.split()[3])
36
+ image = background
37
+ # Convert any other mode to RGB
38
+ elif image.mode != 'RGB':
39
+ image = image.convert('RGB')
40
+
41
+ return image
42
+
43
+ except Exception as e:
44
+ logger.error(f"Error converting image: {str(e)}")
45
+ raise ValueError(f"Não foi possível processar a imagem. Erro: {str(e)}")
46
+
47
+ @staticmethod
48
+ def process_image(image):
49
+ """Process image from various input types"""
50
+ try:
51
+ if isinstance(image, np.ndarray):
52
+ return Image.fromarray(image)
53
+ elif isinstance(image, Image.Image):
54
+ return image
55
+ elif isinstance(image, (str, Path)):
56
+ return ImageHandler.convert_to_rgb(image)
57
+ else:
58
+ raise ValueError("Formato de imagem não suportado")
59
+
60
+ except Exception as e:
61
+ logger.error(f"Error processing image: {str(e)}")
62
+ raise ValueError(f"Erro no processamento da imagem: {str(e)}")
63
+
64
  class NutritionalAnalyzer:
65
  def __init__(self):
66
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
67
  self.models = {}
68
  self.processors = {}
69
+ self.image_handler = ImageHandler()
70
 
71
  def initialize_model(self, model_name):
72
  """Initialize a specific model with error handling and caching"""
 
88
 
89
  config = model_configs.get(model_name)
90
  if not config:
91
+ raise ValueError(f"Modelo não suportado: {model_name}")
92
 
93
  # Ensure cache directory exists
94
  os.makedirs(config["local_cache"], exist_ok=True)
 
123
  logger.error(f"Error initializing {model_name}: {str(e)}")
124
  return False
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def generate_nutritional_prompt(self, user_question):
127
  """Generate a comprehensive nutritional analysis prompt"""
128
  return f"""Como nutricionista especializado, analise esta refeição detalhadamente:
 
149
  def analyze_image(self, image, question, model_choice):
150
  """Analyze image with nutritional focus"""
151
  try:
152
+ if image is None:
153
+ return "Por favor, faça upload de uma imagem para análise."
154
+
155
  # Convert model choice to internal name
156
  model_name = model_choice.lower().replace("-", "")
157
 
 
159
  if not self.initialize_model(model_name):
160
  return "Erro: Não foi possível inicializar o modelo. Por favor, tente novamente."
161
 
162
+ # Process image with enhanced error handling
163
+ try:
164
+ processed_image = self.image_handler.process_image(image)
165
+ except ValueError as e:
166
+ return str(e)
167
+ except Exception as e:
168
+ return f"Erro no processamento da imagem: {str(e)}"
169
+
170
+ # Generate and process prompt
171
  nutritional_prompt = self.generate_nutritional_prompt(question)
172
 
173
  # Process input
174
+ try:
175
+ inputs = self.processors[model_name](
176
+ images=processed_image,
177
+ text=nutritional_prompt,
178
+ return_tensors="pt"
179
+ ).to(self.device)
180
+ except Exception as e:
181
+ return f"Erro no processamento do modelo: {str(e)}"
182
 
183
  # Generate response with enhanced parameters
184
+ try:
185
+ with torch.no_grad():
186
+ outputs = self.models[model_name].generate(
187
+ **inputs,
188
+ max_new_tokens=300,
189
+ num_beams=5,
190
+ temperature=0.7,
191
+ top_p=0.9,
192
+ repetition_penalty=1.2,
193
+ length_penalty=1.0
194
+ )
195
+
196
+ # Decode and format response
197
+ response = self.processors[model_name].decode(outputs[0], skip_special_tokens=True)
198
+ formatted_response = self.format_response(response)
199
+
200
+ return formatted_response
201
 
202
+ except Exception as e:
203
+ return f"Erro na geração da análise: {str(e)}"
204
 
205
  except Exception as e:
206
  logger.error(f"Analysis error: {str(e)}")
 
289
  2. Capture todos os elementos do prato
290
  3. Evite ângulos muito inclinados
291
  4. Seja específico em suas perguntas
292
+ 5. Formatos de imagem suportados: JPG, PNG, WEBP
293
  """)
294
 
295
  analyze_btn.click(