Ntdeseb commited on
Commit
9fe9983
·
1 Parent(s): 8d8ad99

Corregir API de gráficos vectoriales - usar requests en lugar de InferenceClient

Browse files
Files changed (2) hide show
  1. app.py +56 -23
  2. test_vector_graphics.py +42 -20
app.py CHANGED
@@ -761,10 +761,12 @@ def load_vector_model(model_name):
761
 
762
  # Para SVGDreamer, no necesitamos cargar el modelo localmente
763
  # Solo verificamos que podemos acceder a la API
764
- from huggingface_hub import InferenceClient
765
 
766
  try:
767
- client = InferenceClient("jree423/svgdreamer")
 
 
768
  # Hacer una prueba simple para verificar acceso
769
  test_payload = {
770
  "inputs": "test",
@@ -784,7 +786,7 @@ def load_vector_model(model_name):
784
 
785
  model_cache[model_name] = {
786
  "type": "svgdreamer",
787
- "client": "huggingface_hub.InferenceClient"
788
  }
789
 
790
  else:
@@ -793,10 +795,11 @@ def load_vector_model(model_name):
793
  print("✅ Modelo de vector configurado (usa API de Hugging Face)")
794
 
795
  # Verificar acceso al modelo
796
- from huggingface_hub import InferenceClient
797
 
798
  try:
799
- client = InferenceClient(model_name)
 
800
  print(f"✅ Acceso a {model_name} verificado")
801
  except Exception as e:
802
  print(f"⚠️ Advertencia: No se pudo verificar acceso a {model_name}: {e}")
@@ -804,7 +807,7 @@ def load_vector_model(model_name):
804
 
805
  model_cache[model_name] = {
806
  "type": "vector_generic",
807
- "client": "huggingface_hub.InferenceClient"
808
  }
809
 
810
  load_time = time.time() - start_time
@@ -819,7 +822,7 @@ def load_vector_model(model_name):
819
  print("🔄 Fallback a SVGDreamer...")
820
  model_cache[model_name] = {
821
  "type": "svgdreamer",
822
- "client": "huggingface_hub.InferenceClient"
823
  }
824
  print("✅ Fallback exitoso con SVGDreamer")
825
  else:
@@ -1159,20 +1162,45 @@ def generate_vector(prompt, model_name, style="iconography", n_particle=4, num_i
1159
  print(f"🚀 Enviando request a SVGDreamer...")
1160
  print(f"📦 Payload: {payload}")
1161
 
1162
- # Realizar request
1163
- result = client.post(json=payload)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1164
 
1165
  print(f"✅ Respuesta recibida de SVGDreamer")
1166
  print(f"📊 Tipo de respuesta: {type(result)}")
1167
 
1168
  # Procesar respuesta
1169
- if hasattr(result, 'content'):
1170
- # Si es una respuesta HTTP
1171
- svg_content = result.content
1172
- if isinstance(svg_content, bytes):
1173
- svg_content = svg_content.decode('utf-8')
1174
  else:
1175
- # Si es una respuesta directa
1176
  svg_content = result
1177
 
1178
  # Si la respuesta es una lista de partículas
@@ -1229,10 +1257,11 @@ def generate_vector(prompt, model_name, style="iconography", n_particle=4, num_i
1229
  # Para otros modelos de vector (usar API genérica)
1230
  print(f"🎨 Usando modelo genérico para vector: {model_name}")
1231
 
1232
- # Usar la API de Hugging Face genérica
1233
- from huggingface_hub import InferenceClient
1234
 
1235
- client = InferenceClient(model_name)
 
1236
 
1237
  # Preparar payload genérico
1238
  payload = {
@@ -1250,15 +1279,19 @@ def generate_vector(prompt, model_name, style="iconography", n_particle=4, num_i
1250
  print(f"🚀 Enviando request a {model_name}...")
1251
 
1252
  # Realizar request
1253
- result = client.post(json=payload)
 
 
 
1254
 
 
1255
  print(f"✅ Respuesta recibida de {model_name}")
1256
 
1257
  # Procesar respuesta genérica
1258
- if hasattr(result, 'content'):
1259
- svg_content = result.content
1260
- if isinstance(svg_content, bytes):
1261
- svg_content = svg_content.decode('utf-8')
1262
  else:
1263
  svg_content = result
1264
 
 
761
 
762
  # Para SVGDreamer, no necesitamos cargar el modelo localmente
763
  # Solo verificamos que podemos acceder a la API
764
+ import requests
765
 
766
  try:
767
+ API_URL = "https://api-inference.huggingface.co/models/jree423/svgdreamer"
768
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
769
+
770
  # Hacer una prueba simple para verificar acceso
771
  test_payload = {
772
  "inputs": "test",
 
786
 
787
  model_cache[model_name] = {
788
  "type": "svgdreamer",
789
+ "client": "requests"
790
  }
791
 
792
  else:
 
795
  print("✅ Modelo de vector configurado (usa API de Hugging Face)")
796
 
797
  # Verificar acceso al modelo
798
+ import requests
799
 
800
  try:
801
+ API_URL = f"https://api-inference.huggingface.co/models/{model_name}"
802
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
803
  print(f"✅ Acceso a {model_name} verificado")
804
  except Exception as e:
805
  print(f"⚠️ Advertencia: No se pudo verificar acceso a {model_name}: {e}")
 
807
 
808
  model_cache[model_name] = {
809
  "type": "vector_generic",
810
+ "client": "requests"
811
  }
812
 
813
  load_time = time.time() - start_time
 
822
  print("🔄 Fallback a SVGDreamer...")
823
  model_cache[model_name] = {
824
  "type": "svgdreamer",
825
+ "client": "requests"
826
  }
827
  print("✅ Fallback exitoso con SVGDreamer")
828
  else:
 
1162
  print(f"🚀 Enviando request a SVGDreamer...")
1163
  print(f"📦 Payload: {payload}")
1164
 
1165
+ # Realizar request usando el método correcto
1166
+ try:
1167
+ # Intentar con el método text_generation para SVGDreamer
1168
+ result = client.text_generation(
1169
+ prompt,
1170
+ model="jree423/svgdreamer",
1171
+ parameters={
1172
+ "n_particle": n_particle,
1173
+ "num_iter": num_iter,
1174
+ "guidance_scale": guidance_scale,
1175
+ "style": style,
1176
+ "width": width,
1177
+ "height": height,
1178
+ "seed": seed
1179
+ }
1180
+ )
1181
+ except Exception as e:
1182
+ print(f"⚠️ Error con text_generation: {e}")
1183
+ # Fallback: usar requests directamente
1184
+ import requests
1185
+
1186
+ API_URL = "https://api-inference.huggingface.co/models/jree423/svgdreamer"
1187
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
1188
+
1189
+ response = requests.post(API_URL, headers=headers, json=payload)
1190
+ if response.status_code == 200:
1191
+ result = response.json()
1192
+ else:
1193
+ raise Exception(f"Error en API request: {response.status_code} - {response.text}")
1194
 
1195
  print(f"✅ Respuesta recibida de SVGDreamer")
1196
  print(f"📊 Tipo de respuesta: {type(result)}")
1197
 
1198
  # Procesar respuesta
1199
+ if isinstance(result, dict) and 'generated_text' in result:
1200
+ svg_content = result['generated_text']
1201
+ elif isinstance(result, list):
1202
+ svg_content = result
 
1203
  else:
 
1204
  svg_content = result
1205
 
1206
  # Si la respuesta es una lista de partículas
 
1257
  # Para otros modelos de vector (usar API genérica)
1258
  print(f"🎨 Usando modelo genérico para vector: {model_name}")
1259
 
1260
+ # Usar requests directamente para modelos genéricos
1261
+ import requests
1262
 
1263
+ API_URL = f"https://api-inference.huggingface.co/models/{model_name}"
1264
+ headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
1265
 
1266
  # Preparar payload genérico
1267
  payload = {
 
1279
  print(f"🚀 Enviando request a {model_name}...")
1280
 
1281
  # Realizar request
1282
+ response = requests.post(API_URL, headers=headers, json=payload)
1283
+
1284
+ if response.status_code != 200:
1285
+ raise Exception(f"Error en API request: {response.status_code} - {response.text}")
1286
 
1287
+ result = response.json()
1288
  print(f"✅ Respuesta recibida de {model_name}")
1289
 
1290
  # Procesar respuesta genérica
1291
+ if isinstance(result, dict) and 'generated_text' in result:
1292
+ svg_content = result['generated_text']
1293
+ elif isinstance(result, list):
1294
+ svg_content = result[0] if len(result) > 0 else str(result)
1295
  else:
1296
  svg_content = result
1297
 
test_vector_graphics.py CHANGED
@@ -8,15 +8,17 @@ import os
8
  import sys
9
  import time
10
  import tempfile
11
- from huggingface_hub import InferenceClient
12
 
13
  def test_svgdreamer_basic():
14
  """Prueba básica de SVGDreamer"""
15
  print("🎨 Probando SVGDreamer - Generación básica...")
16
 
17
  try:
18
- # Configurar cliente
19
- client = InferenceClient("jree423/svgdreamer")
 
 
20
 
21
  # Payload básico
22
  payload = {
@@ -36,17 +38,21 @@ def test_svgdreamer_basic():
36
 
37
  # Realizar request
38
  start_time = time.time()
39
- result = client.post(json=payload)
40
  generation_time = time.time() - start_time
41
 
 
 
 
 
42
  print(f"✅ Respuesta recibida en {generation_time:.2f}s")
43
  print(f"📊 Tipo de respuesta: {type(result)}")
44
 
45
  # Procesar respuesta
46
- if hasattr(result, 'content'):
47
- svg_content = result.content
48
- if isinstance(svg_content, bytes):
49
- svg_content = svg_content.decode('utf-8')
50
  else:
51
  svg_content = result
52
 
@@ -84,7 +90,9 @@ def test_svgdreamer_multiple_styles():
84
  prompt = "a friendly robot character"
85
 
86
  try:
87
- client = InferenceClient("jree423/svgdreamer")
 
 
88
 
89
  for style in styles:
90
  print(f"🎯 Probando estilo: {style}")
@@ -103,10 +111,13 @@ def test_svgdreamer_multiple_styles():
103
  }
104
 
105
  start_time = time.time()
106
- result = client.post(json=payload)
107
  generation_time = time.time() - start_time
108
 
109
- print(f"✅ Estilo {style} completado en {generation_time:.2f}s")
 
 
 
110
 
111
  # Pausa entre requests para no sobrecargar
112
  time.sleep(1)
@@ -123,7 +134,9 @@ def test_svgdreamer_multiple_particles():
123
  print("\n🎨 Probando SVGDreamer - Múltiples partículas...")
124
 
125
  try:
126
- client = InferenceClient("jree423/svgdreamer")
 
 
127
 
128
  payload = {
129
  "inputs": "geometric patterns in bright colors",
@@ -141,16 +154,20 @@ def test_svgdreamer_multiple_particles():
141
  print(f"📦 Enviando payload con 3 partículas...")
142
 
143
  start_time = time.time()
144
- result = client.post(json=payload)
145
  generation_time = time.time() - start_time
146
 
 
 
 
 
147
  print(f"✅ Respuesta recibida en {generation_time:.2f}s")
148
 
149
  # Procesar respuesta
150
- if hasattr(result, 'content'):
151
- svg_content = result.content
152
- if isinstance(svg_content, bytes):
153
- svg_content = svg_content.decode('utf-8')
154
  else:
155
  svg_content = result
156
 
@@ -176,7 +193,9 @@ def test_error_handling():
176
  print("\n🎨 Probando manejo de errores...")
177
 
178
  try:
179
- client = InferenceClient("jree423/svgdreamer")
 
 
180
 
181
  # Payload con parámetros inválidos
182
  payload = {
@@ -195,8 +214,11 @@ def test_error_handling():
195
  print("🧪 Probando parámetros extremos...")
196
 
197
  try:
198
- result = client.post(json=payload)
199
- print("⚠️ Request completado (esperaba error)")
 
 
 
200
  except Exception as e:
201
  print(f"✅ Error capturado correctamente: {type(e).__name__}")
202
 
 
8
  import sys
9
  import time
10
  import tempfile
11
+ import requests
12
 
13
  def test_svgdreamer_basic():
14
  """Prueba básica de SVGDreamer"""
15
  print("🎨 Probando SVGDreamer - Generación básica...")
16
 
17
  try:
18
+ # Configurar API
19
+ API_URL = "https://api-inference.huggingface.co/models/jree423/svgdreamer"
20
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
21
+ headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
22
 
23
  # Payload básico
24
  payload = {
 
38
 
39
  # Realizar request
40
  start_time = time.time()
41
+ response = requests.post(API_URL, headers=headers, json=payload)
42
  generation_time = time.time() - start_time
43
 
44
+ if response.status_code != 200:
45
+ raise Exception(f"Error en API: {response.status_code} - {response.text}")
46
+
47
+ result = response.json()
48
  print(f"✅ Respuesta recibida en {generation_time:.2f}s")
49
  print(f"📊 Tipo de respuesta: {type(result)}")
50
 
51
  # Procesar respuesta
52
+ if isinstance(result, dict) and 'generated_text' in result:
53
+ svg_content = result['generated_text']
54
+ elif isinstance(result, list):
55
+ svg_content = result
56
  else:
57
  svg_content = result
58
 
 
90
  prompt = "a friendly robot character"
91
 
92
  try:
93
+ API_URL = "https://api-inference.huggingface.co/models/jree423/svgdreamer"
94
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
95
+ headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
96
 
97
  for style in styles:
98
  print(f"🎯 Probando estilo: {style}")
 
111
  }
112
 
113
  start_time = time.time()
114
+ response = requests.post(API_URL, headers=headers, json=payload)
115
  generation_time = time.time() - start_time
116
 
117
+ if response.status_code == 200:
118
+ print(f"✅ Estilo {style} completado en {generation_time:.2f}s")
119
+ else:
120
+ print(f"❌ Error con estilo {style}: {response.status_code}")
121
 
122
  # Pausa entre requests para no sobrecargar
123
  time.sleep(1)
 
134
  print("\n🎨 Probando SVGDreamer - Múltiples partículas...")
135
 
136
  try:
137
+ API_URL = "https://api-inference.huggingface.co/models/jree423/svgdreamer"
138
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
139
+ headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
140
 
141
  payload = {
142
  "inputs": "geometric patterns in bright colors",
 
154
  print(f"📦 Enviando payload con 3 partículas...")
155
 
156
  start_time = time.time()
157
+ response = requests.post(API_URL, headers=headers, json=payload)
158
  generation_time = time.time() - start_time
159
 
160
+ if response.status_code != 200:
161
+ raise Exception(f"Error en API: {response.status_code} - {response.text}")
162
+
163
+ result = response.json()
164
  print(f"✅ Respuesta recibida en {generation_time:.2f}s")
165
 
166
  # Procesar respuesta
167
+ if isinstance(result, dict) and 'generated_text' in result:
168
+ svg_content = result['generated_text']
169
+ elif isinstance(result, list):
170
+ svg_content = result
171
  else:
172
  svg_content = result
173
 
 
193
  print("\n🎨 Probando manejo de errores...")
194
 
195
  try:
196
+ API_URL = "https://api-inference.huggingface.co/models/jree423/svgdreamer"
197
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
198
+ headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
199
 
200
  # Payload con parámetros inválidos
201
  payload = {
 
214
  print("🧪 Probando parámetros extremos...")
215
 
216
  try:
217
+ response = requests.post(API_URL, headers=headers, json=payload)
218
+ if response.status_code == 200:
219
+ print("⚠️ Request completado (esperaba error)")
220
+ else:
221
+ print(f"✅ Error capturado correctamente: {response.status_code}")
222
  except Exception as e:
223
  print(f"✅ Error capturado correctamente: {type(e).__name__}")
224