VOIDER commited on
Commit
842de2a
·
verified ·
1 Parent(s): 713959a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -173
app.py CHANGED
@@ -4,7 +4,8 @@ import io
4
  import os
5
  import pandas as pd
6
  import torch
7
- from transformers import pipeline as transformers_pipeline , AutoModelForImageClassification, CLIPImageProcessor
 
8
  import open_clip
9
  import re
10
  import matplotlib.pyplot as plt
@@ -12,7 +13,7 @@ import json
12
  from collections import defaultdict
13
  import numpy as np
14
  import logging
15
- import time # Для замера времени
16
 
17
  # --- ONNX Related Imports and Setup ---
18
  try:
@@ -23,7 +24,6 @@ except ImportError:
23
 
24
  from huggingface_hub import hf_hub_download
25
 
26
- # imgutils для rgb_encode
27
  try:
28
  from imgutils.data import rgb_encode
29
  IMGUTILS_AVAILABLE = True
@@ -37,17 +37,12 @@ except ImportError:
37
  img_arr = np.transpose(img_arr, (2, 0, 1))
38
  return img_arr.astype(np.uint8)
39
 
40
- # --- Модель Конфигурация и Загрузка ---
41
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
42
  print(f"INFO: PyTorch Device: {DEVICE}")
43
  ONNX_EXECUTION_PROVIDER = "CUDAExecutionProvider" if DEVICE == "cuda" and onnxruntime and "CUDAExecutionProvider" in onnxruntime.get_available_providers() else "CPUExecutionProvider"
44
- if onnxruntime:
45
- print(f"INFO: ONNX Execution Provider: {ONNX_EXECUTION_PROVIDER}")
46
- else:
47
- print("INFO: ONNX Runtime not available, ONNX models will be skipped.")
48
 
49
-
50
- # --- Helper for ONNX models (deepghs) ---
51
  @torch.no_grad()
52
  def _img_preprocess_for_onnx(image: Image.Image, size: tuple = (384, 384), normalize_mean=0.5, normalize_std=0.5):
53
  image = image.resize(size, Image.Resampling.BILINEAR)
@@ -61,31 +56,24 @@ def _img_preprocess_for_onnx(image: Image.Image, size: tuple = (384, 384), norma
61
  onnx_sessions_cache = {}
62
  def get_onnx_session_and_meta(repo_id, model_subfolder, current_log_list):
63
  cache_key = f"{repo_id}/{model_subfolder}"
64
- if cache_key in onnx_sessions_cache:
65
- return onnx_sessions_cache[cache_key]
66
-
67
  if not onnxruntime:
68
  msg = f"ERROR: ONNX Runtime not available for get_onnx_session_and_meta ({cache_key}). Skipping."
69
- print(msg)
70
- current_log_list.append(msg)
71
- onnx_sessions_cache[cache_key] = (None, [], None) # Cache error state
72
  return None, [], None
73
-
74
  try:
75
  msg = f"INFO: Loading ONNX model {repo_id}/{model_subfolder}..."
76
  print(msg); current_log_list.append(msg)
77
  model_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/model.onnx")
78
  meta_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/meta.json")
79
-
80
  options = onnxruntime.SessionOptions()
81
  options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
82
  if ONNX_EXECUTION_PROVIDER == "CPUExecutionProvider" and hasattr(os, 'cpu_count'):
83
  options.intra_op_num_threads = os.cpu_count()
84
-
85
  session = onnxruntime.InferenceSession(model_path, options, providers=[ONNX_EXECUTION_PROVIDER])
86
  with open(meta_path, 'r') as f: meta = json.load(f)
87
  labels = meta.get('labels', [])
88
-
89
  msg = f"INFO: ONNX model {cache_key} loaded successfully with provider {ONNX_EXECUTION_PROVIDER}."
90
  print(msg); current_log_list.append(msg)
91
  onnx_sessions_cache[cache_key] = (session, labels, meta)
@@ -96,29 +84,25 @@ def get_onnx_session_and_meta(repo_id, model_subfolder, current_log_list):
96
  onnx_sessions_cache[cache_key] = (None, [], None)
97
  return None, [], None
98
 
99
- # --- Модели PyTorch и Transformers ---
100
- # 1. ImageReward
101
  reward_processor, reward_model = None, None
102
- try:
103
- print("INFO: Loading THUDM/ImageReward model...")
104
- reward_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
105
- reward_model = AutoModelForImageClassification.from_pretrained("THUDM/ImageReward").to(DEVICE)
106
- reward_model.eval()
107
- print("INFO: THUDM/ImageReward loaded successfully.")
108
- except Exception as e:
109
- print(f"ERROR: Failed to load THUDM/ImageReward: {e}")
 
 
110
 
111
- # 2. Anime Aesthetic (deepghs ONNX) - Константы
112
  ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"
113
  ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
114
  ANIME_AESTHETIC_IMG_SIZE = (448, 448)
115
  ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}
116
-
117
- # 3. MANIQA (Technical Quality) - ВРЕМЕННО ОТКЛЮЧЕНО
118
- # maniqa_pipe = None (уже объявлено в глобальной области видимости неявно)
119
  print("INFO: MANIQA (honklers/maniqa-nr) is currently disabled.")
120
 
121
- # 4. CLIP Score (laion/CLIP-ViT-L-14-laion2B-s32B-b82K) - open_clip
122
  clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None
123
  try:
124
  clip_model_name = 'ViT-L-14'
@@ -133,8 +117,6 @@ try:
133
  except Exception as e:
134
  print(f"ERROR: Failed to load CLIP model {clip_model_name} (laion2b_s32b_b82k): {e}")
135
 
136
- # 5. AI Detectors
137
- # Organika/sdxl-detector - Transformers pipeline
138
  sdxl_detector_pipe = None
139
  try:
140
  print("INFO: Loading Organika/sdxl-detector model...")
@@ -143,24 +125,18 @@ try:
143
  except Exception as e:
144
  print(f"ERROR: Failed to load Organika/sdxl-detector: {e}")
145
 
146
- # deepghs/anime_ai_check - ONNX - Константы
147
  ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"
148
  ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
149
  ANIME_AI_CHECK_IMG_SIZE = (384, 384)
150
 
151
-
152
- # --- Функции извлечения метаданных (без изменений в логике, только print) ---
153
  def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
154
- # ... (остальной код extract_sd_parameters без изменений)
155
  if image_pil is None: return "", "N/A", "N/A", "N/A", {}
156
  parameters_str = image_pil.info.get("parameters", "")
157
  if not parameters_str:
158
  current_log_list.append(f"DEBUG [{filename_for_log}]: No metadata found in image.")
159
  return "", "N/A", "N/A", "N/A", {}
160
-
161
- current_log_list.append(f"DEBUG [{filename_for_log}]: Raw metadata: {parameters_str[:100]}...") # Логируем начало
162
  prompt, negative_prompt, model_name, model_hash, other_params_dict = "", "N/A", "N/A", "N/A", {}
163
- # ... (остальной парсинг)
164
  try:
165
  neg_prompt_index = parameters_str.find("Negative prompt:")
166
  steps_meta_index = parameters_str.find("Steps:")
@@ -179,10 +155,8 @@ def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
179
  prompt = parameters_str[:steps_meta_index].strip()
180
  params_part = parameters_str[steps_meta_index:]
181
  else:
182
- prompt = parameters_str.strip() # Весь текст - промпт
183
- params_part = "" # Нет блока параметров
184
-
185
- if params_part: # Если есть блок параметров после Steps:
186
  params_list = [p.strip() for p in params_part.split(",")]
187
  temp_other_params = {}
188
  for param_val_str in params_list:
@@ -194,57 +168,51 @@ def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
194
  elif key.lower() == "model hash": model_hash = value
195
  for k,v in temp_other_params.items():
196
  if k.lower() not in ["model", "model hash"]: other_params_dict[k] = v
197
-
198
  if model_name == "N/A" and model_hash != "N/A": model_name = f"hash_{model_hash}"
199
  if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]
200
  if model_name == "N/A" and "model" in other_params_dict: model_name = other_params_dict["model"]
201
  current_log_list.append(f"DEBUG [{filename_for_log}]: Parsed Prompt: {prompt[:50]}... | Model: {model_name}")
202
-
203
  except Exception as e:
204
  current_log_list.append(f"ERROR [{filename_for_log}]: Failed to parse metadata: {e}")
205
  return prompt, negative_prompt, model_name, model_hash, other_params_dict
206
 
207
- # --- Функции оценки (добавлено логирование и замер времени) ---
208
  @torch.no_grad()
209
  def get_image_reward(image_pil, filename_for_log, current_log_list):
210
- if not reward_model or not reward_processor:
211
- current_log_list.append(f"INFO [{filename_for_log}]: ImageReward model not loaded, skipping.")
212
- return "N/A"
213
- t_start = time.time()
214
- current_log_list.append(f"DEBUG [{filename_for_log}]: Starting ImageReward score (PyTorch Device: {DEVICE})...")
215
- try:
216
- inputs = reward_processor(images=image_pil, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
217
- outputs = reward_model(**inputs)
218
- score = round(outputs.logits.item(), 4)
219
- t_end = time.time()
220
- current_log_list.append(f"DEBUG [{filename_for_log}]: ImageReward score: {score} (took {t_end - t_start:.2f}s)")
221
- return score
222
- except Exception as e:
223
- current_log_list.append(f"ERROR [{filename_for_log}]: ImageReward scoring failed: {e}")
224
- return "Error"
 
 
225
 
226
  def get_anime_aesthetic_score_deepghs(image_pil, filename_for_log, current_log_list):
227
  session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, current_log_list)
228
  if not session or not labels:
229
  current_log_list.append(f"INFO [{filename_for_log}]: AnimeAesthetic ONNX model not loaded, skipping.")
230
  return "N/A"
231
- t_start = time.time()
232
- current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAesthetic (ONNX) score...")
233
  try:
234
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AESTHETIC_IMG_SIZE)
235
- input_name = session.get_inputs()[0].name
236
- output_name = session.get_outputs()[0].name
237
  onnx_output, = session.run([output_name], {input_name: input_data})
238
- scores = onnx_output[0]
239
- exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
240
  weighted_score = sum(probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS.get(label, 0.0) for i, label in enumerate(labels))
241
- score = round(weighted_score, 4)
242
- t_end = time.time()
243
  current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAesthetic (ONNX) score: {score} (took {t_end - t_start:.2f}s)")
244
  return score
245
  except Exception as e:
246
- current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAesthetic (ONNX) scoring failed: {e}")
247
- return "Error"
248
 
249
  @torch.no_grad()
250
  def get_maniqa_score(image_pil, filename_for_log, current_log_list):
@@ -259,72 +227,56 @@ def calculate_clip_score_value(image_pil, prompt_text, filename_for_log, current
259
  if not prompt_text or prompt_text == "N/A":
260
  current_log_list.append(f"INFO [{filename_for_log}]: Empty prompt, skipping CLIPScore.")
261
  return "N/A (Empty Prompt)"
262
-
263
- t_start = time.time()
264
- current_log_list.append(f"DEBUG [{filename_for_log}]: Starting CLIPScore (PyTorch Device: {DEVICE})...")
265
  try:
266
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
267
- text_for_tokenizer = str(prompt_text)
268
- text_input = clip_tokenizer([text_for_tokenizer]).to(DEVICE)
269
- image_features = clip_model_instance.encode_image(image_input)
270
- text_features = clip_model_instance.encode_text(text_input)
271
  image_features_norm = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
272
  text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
273
  score_val = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
274
- score = round(score_val, 2)
275
- t_end = time.time()
276
  current_log_list.append(f"DEBUG [{filename_for_log}]: CLIPScore: {score} (took {t_end - t_start:.2f}s)")
277
  return score
278
  except Exception as e:
279
- current_log_list.append(f"ERROR [{filename_for_log}]: CLIPScore calculation failed: {e}")
280
- return "Error"
281
 
282
  @torch.no_grad()
283
  def get_sdxl_detection_score(image_pil, filename_for_log, current_log_list):
284
  if not sdxl_detector_pipe:
285
  current_log_list.append(f"INFO [{filename_for_log}]: SDXL_Detector model not loaded, skipping.")
286
  return "N/A"
287
- t_start = time.time()
288
- current_log_list.append(f"DEBUG [{filename_for_log}]: Starting SDXL_Detector score (Device for pipeline: {sdxl_detector_pipe.device})...")
289
  try:
290
- result = sdxl_detector_pipe(image_pil.copy())
291
- ai_score_val = 0.0
292
  for item in result:
293
  if item['label'].lower() == 'artificial': ai_score_val = item['score']; break
294
- score = round(ai_score_val, 4)
295
- t_end = time.time()
296
  current_log_list.append(f"DEBUG [{filename_for_log}]: SDXL_Detector AI Prob: {score} (took {t_end - t_start:.2f}s)")
297
  return score
298
  except Exception as e:
299
- current_log_list.append(f"ERROR [{filename_for_log}]: SDXL_Detector scoring failed: {e}")
300
- return "Error"
301
 
302
  def get_anime_ai_check_score_deepghs(image_pil, filename_for_log, current_log_list):
303
  session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER, current_log_list)
304
  if not session or not labels:
305
  current_log_list.append(f"INFO [{filename_for_log}]: AnimeAI_Check ONNX model not loaded, skipping.")
306
  return "N/A"
307
- t_start = time.time()
308
- current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAI_Check (ONNX) score...")
309
  try:
310
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AI_CHECK_IMG_SIZE)
311
- input_name = session.get_inputs()[0].name
312
- output_name = session.get_outputs()[0].name
313
  onnx_output, = session.run([output_name], {input_name: input_data})
314
- scores = onnx_output[0]
315
- exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
316
  ai_prob_val = 0.0
317
  for i, label in enumerate(labels):
318
  if label.lower() == 'ai': ai_prob_val = probabilities[i]; break
319
- score = round(ai_prob_val, 4)
320
- t_end = time.time()
321
  current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAI_Check (ONNX) AI Prob: {score} (took {t_end - t_start:.2f}s)")
322
  return score
323
  except Exception as e:
324
- current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAI_Check (ONNX) scoring failed: {e}")
325
- return "Error"
326
 
327
- # --- Основная функция обработки (стала генератором) ---
328
  def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
329
  if not files:
330
  yield pd.DataFrame(), None, None, None, None, "Please upload some images.", "No files to process."
@@ -332,52 +284,52 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
332
 
333
  all_results = []
334
  log_accumulator = [f"INFO: Starting processing for {len(files)} images..."]
335
- yield pd.DataFrame(), None, None, None, None, "Processing...", "\n".join(log_accumulator)
336
-
 
 
337
 
338
  for i, file_obj in enumerate(files):
339
  filename_for_log = "Unknown File"
340
  current_img_total_time_start = time.time()
341
  try:
342
- filename_for_log = os.path.basename(getattr(file_obj, 'name', f"file_{i}_{time.time()}"))
343
  log_accumulator.append(f"--- Processing image {i+1}/{len(files)}: {filename_for_log} ---")
344
 
345
- # Обновляем UI перед началом обработки файла
346
- progress.update(amount=(i+1)/len(files), desc=f"Img {i+1}/{len(files)}: {filename_for_log}")
347
- yield (pd.DataFrame(all_results), None, None, None, None,
348
- f"Processing image {i+1}/{len(files)}: {filename_for_log}",
 
 
349
  "\n".join(log_accumulator))
350
 
351
  img = Image.open(getattr(file_obj, 'name', str(file_obj)))
352
  if img.mode != "RGB": img = img.convert("RGB")
 
353
 
354
- prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img, filename_for_log, log_accumulator)
355
 
 
356
  reward = get_image_reward(img, filename_for_log, log_accumulator)
357
  anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img, filename_for_log, log_accumulator)
358
  maniqa = get_maniqa_score(img, filename_for_log, log_accumulator)
359
  clip_val = calculate_clip_score_value(img, prompt, filename_for_log, log_accumulator)
360
  sdxl_detect = get_sdxl_detection_score(img, filename_for_log, log_accumulator)
361
  anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img, filename_for_log, log_accumulator)
362
-
363
  current_img_total_time_end = time.time()
364
  log_accumulator.append(f"INFO [{filename_for_log}]: Finished all scores (total for image: {current_img_total_time_end - current_img_total_time_start:.2f}s)")
365
 
366
-
367
  all_results.append({
368
  "Filename": filename_for_log, "Prompt": prompt if prompt else "N/A", "Model Name": model_n, "Model Hash": model_h,
369
  "ImageReward": reward, "AnimeAesthetic_dg": anime_aes_deepghs, "MANIQA_TQ": maniqa,
370
  "CLIPScore": clip_val, "SDXL_Detector_AI_Prob": sdxl_detect, "AnimeAI_Check_dg_Prob": anime_ai_chk_deepghs,
371
  })
372
-
373
- # Обновляем UI после обработки каждого файла с текущими результатами
374
- # Графики и файлы для скачивания будут генерироваться только в конце
375
- # Но можно передавать df для обновления таблицы
376
  df_so_far = pd.DataFrame(all_results)
377
- yield (df_so_far, None, None, None, None, # Пока без графиков и файлов
 
 
378
  f"Processed image {i+1}/{len(files)}: {filename_for_log}",
379
  "\n".join(log_accumulator))
380
-
381
  except Exception as e:
382
  log_accumulator.append(f"CRITICAL ERROR processing {filename_for_log}: {e}")
383
  print(f"CRITICAL ERROR processing {filename_for_log}: {e}")
@@ -387,23 +339,25 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
387
  "CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
388
  })
389
  df_so_far = pd.DataFrame(all_results)
390
- yield (df_so_far, None, None, None, None,
 
391
  f"Error on image {i+1}/{len(files)}: {filename_for_log}",
392
  "\n".join(log_accumulator))
393
 
394
  log_accumulator.append("--- Generating final plots and download files ---")
395
- yield (pd.DataFrame(all_results), None, None, None, None,
 
 
396
  "Generating final plots...",
397
  "\n".join(log_accumulator))
398
 
399
  df = pd.DataFrame(all_results)
400
  plot_model_avg_scores_buffer, plot_prompt_clip_scores_buffer = None, None
401
- csv_buffer_val, json_buffer_val = "", ""
402
 
403
  if not df.empty:
404
- numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"] # MANIQA TQ будет NaN, нормально
405
  for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce')
406
-
407
  df_model_plot = df[(df["Model Name"] != "N/A") & (df["Model Name"].notna())]
408
  if not df_model_plot.empty and df_model_plot["Model Name"].nunique() > 0:
409
  try:
@@ -415,7 +369,6 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
415
  plot_model_avg_scores_buffer = io.BytesIO(); fig1.savefig(plot_model_avg_scores_buffer, format="png"); plot_model_avg_scores_buffer.seek(0); plt.close(fig1)
416
  log_accumulator.append("INFO: Model average scores plot generated.")
417
  except Exception as e: log_accumulator.append(f"ERROR: Failed to generate model average scores plot: {e}")
418
-
419
  df_prompt_plot = df[(df["Prompt"] != "N/A") & (df["Prompt"].notna()) & (df["CLIPScore"].notna())]
420
  if not df_prompt_plot.empty and df_prompt_plot["Prompt"].nunique() > 0 :
421
  try:
@@ -429,80 +382,62 @@ def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
429
  log_accumulator.append("INFO: Prompt CLIP scores plot generated.")
430
  except Exception as e: log_accumulator.append(f"ERROR: Failed to generate prompt CLIP scores plot: {e}")
431
 
432
- csv_b = io.StringIO(); df.to_csv(csv_b, index=False); csv_buffer_val = csv_b.getvalue()
433
- json_b = io.StringIO(); df.to_json(json_b, orient='records', indent=4); json_buffer_val = json_b.getvalue()
434
- log_accumulator.append("INFO: CSV and JSON data prepared for download.")
 
 
 
 
 
 
 
 
 
435
 
436
- final_status = f"Finished processing {len(all_results)} images. Total time: {sum(entry.get('total_time', 0) for entry in all_results):.2f}s (approx, if times were logged per image)"
437
- # ^Это не совсем точно, т.к. total_time не собирается в entry, но идея понятна
438
  log_accumulator.append(final_status)
439
 
440
  yield (
441
  df,
442
- gr.Image(value=plot_model_avg_scores_buffer, type="pil", visible=plot_model_avg_scores_buffer is not None),
443
- gr.Image(value=plot_prompt_clip_scores_buffer, type="pil", visible=plot_prompt_clip_scores_buffer is not None),
444
- gr.File(value=csv_buffer_val or None, label="Download CSV Results", visible=bool(csv_buffer_val), file_name="evaluation_results.csv"),
445
- gr.File(value=json_buffer_val or None, label="Download JSON Results", visible=bool(json_buffer_val), file_name="evaluation_results.json"),
446
  final_status,
447
  "\n".join(log_accumulator)
448
  )
449
 
 
450
 
451
- # --- Интерфейс Gradio ---
452
  with gr.Blocks(css="footer {display: none !important}") as demo:
453
  gr.Markdown("# AI Image Model Evaluation Tool")
454
  gr.Markdown("Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them...")
455
-
456
- with gr.Row():
457
- image_uploader = gr.Files(
458
- label="Upload Images (PNG)",
459
- file_count="multiple",
460
- file_types=["image"]
461
- )
462
-
463
  process_button = gr.Button("Evaluate Images", variant="primary")
464
-
465
  status_textbox = gr.Textbox(label="Overall Status", interactive=False)
466
-
467
- log_output_textbox = gr.Textbox(label="Detailed Logs", lines=15, interactive=False, autoscroll=True) # Новый логгер
468
-
469
  gr.Markdown("## Evaluation Results Table")
470
  results_table = gr.DataFrame(headers=[
471
  "Filename", "Prompt", "Model Name", "Model Hash", "ImageReward", "AnimeAesthetic_dg",
472
  "MANIQA_TQ", "CLIPScore", "SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
473
  ], wrap=True)
474
-
475
  with gr.Row():
476
- download_csv_button = gr.File(label="Download CSV Results", interactive=False)
477
- download_json_button = gr.File(label="Download JSON Results", interactive=False)
478
-
479
  gr.Markdown("## Visualizations")
480
  with gr.Row():
481
  plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)
482
  plot_output_prompt_clip = gr.Image(label="Average CLIPScore per Prompt", type="pil", interactive=False)
483
-
484
  process_button.click(
485
- fn=process_images_generator, # Изменено на генератор
486
- inputs=[image_uploader],
487
- outputs=[
488
- results_table,
489
- plot_output_model_avg,
490
- plot_output_prompt_clip,
491
- download_csv_button,
492
- download_json_button,
493
- status_textbox,
494
- log_output_textbox # Добавлен вывод для логов
495
- ]
496
  )
497
-
498
  gr.Markdown("""**Metric Explanations:** ... (без изменений)""")
499
 
500
  if __name__ == "__main__":
501
- # Загрузка моделей при старте (вне функции Gradio)
502
  print("--- Initializing models, please wait... ---")
503
- # Вызов функций загрузки ONNX моделей, чтобы они кэшировались при старте, если возможно
504
- # Это не будет выводить логи в UI, только в консоль сервера при запуске.
505
- # Но поможет понять, загружаются ли они вообще.
506
  initial_dummy_logs = []
507
  if onnxruntime:
508
  get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, initial_dummy_logs)
@@ -512,5 +447,4 @@ if __name__ == "__main__":
512
  for log_line in initial_dummy_logs: print(log_line)
513
  print("-----------------------------------------")
514
  print("--- Model initialization attempt complete. Launching Gradio. ---")
515
-
516
- demo.queue().launch(debug=True) # queue() важен для генераторов
 
4
  import os
5
  import pandas as pd
6
  import torch
7
+ # from transformers import pipeline as transformers_pipeline , AutoModelForImageClassification, CLIPImageProcessor # ImageReward пока отключен
8
+ from transformers import pipeline as transformers_pipeline , CLIPImageProcessor # Убрали AutoModelForImageClassification
9
  import open_clip
10
  import re
11
  import matplotlib.pyplot as plt
 
13
  from collections import defaultdict
14
  import numpy as np
15
  import logging
16
+ import time
17
 
18
  # --- ONNX Related Imports and Setup ---
19
  try:
 
24
 
25
  from huggingface_hub import hf_hub_download
26
 
 
27
  try:
28
  from imgutils.data import rgb_encode
29
  IMGUTILS_AVAILABLE = True
 
37
  img_arr = np.transpose(img_arr, (2, 0, 1))
38
  return img_arr.astype(np.uint8)
39
 
 
40
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
  print(f"INFO: PyTorch Device: {DEVICE}")
42
  ONNX_EXECUTION_PROVIDER = "CUDAExecutionProvider" if DEVICE == "cuda" and onnxruntime and "CUDAExecutionProvider" in onnxruntime.get_available_providers() else "CPUExecutionProvider"
43
+ if onnxruntime: print(f"INFO: ONNX Execution Provider: {ONNX_EXECUTION_PROVIDER}")
44
+ else: print("INFO: ONNX Runtime not available, ONNX models will be skipped.")
 
 
45
 
 
 
46
  @torch.no_grad()
47
  def _img_preprocess_for_onnx(image: Image.Image, size: tuple = (384, 384), normalize_mean=0.5, normalize_std=0.5):
48
  image = image.resize(size, Image.Resampling.BILINEAR)
 
56
  onnx_sessions_cache = {}
57
  def get_onnx_session_and_meta(repo_id, model_subfolder, current_log_list):
58
  cache_key = f"{repo_id}/{model_subfolder}"
59
+ if cache_key in onnx_sessions_cache: return onnx_sessions_cache[cache_key]
 
 
60
  if not onnxruntime:
61
  msg = f"ERROR: ONNX Runtime not available for get_onnx_session_and_meta ({cache_key}). Skipping."
62
+ print(msg); current_log_list.append(msg)
63
+ onnx_sessions_cache[cache_key] = (None, [], None)
 
64
  return None, [], None
 
65
  try:
66
  msg = f"INFO: Loading ONNX model {repo_id}/{model_subfolder}..."
67
  print(msg); current_log_list.append(msg)
68
  model_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/model.onnx")
69
  meta_path = hf_hub_download(repo_id, filename=f"{model_subfolder}/meta.json")
 
70
  options = onnxruntime.SessionOptions()
71
  options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
72
  if ONNX_EXECUTION_PROVIDER == "CPUExecutionProvider" and hasattr(os, 'cpu_count'):
73
  options.intra_op_num_threads = os.cpu_count()
 
74
  session = onnxruntime.InferenceSession(model_path, options, providers=[ONNX_EXECUTION_PROVIDER])
75
  with open(meta_path, 'r') as f: meta = json.load(f)
76
  labels = meta.get('labels', [])
 
77
  msg = f"INFO: ONNX model {cache_key} loaded successfully with provider {ONNX_EXECUTION_PROVIDER}."
78
  print(msg); current_log_list.append(msg)
79
  onnx_sessions_cache[cache_key] = (session, labels, meta)
 
84
  onnx_sessions_cache[cache_key] = (None, [], None)
85
  return None, [], None
86
 
87
+ # 1. ImageReward - ВРЕМЕННО ОТКЛЮЧЕНО
 
88
  reward_processor, reward_model = None, None
89
+ print("INFO: THUDM/ImageReward is temporarily disabled due to loading issues.")
90
+ # try:
91
+ # print("INFO: Loading THUDM/ImageReward model...")
92
+ # # reward_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
93
+ # # reward_model = AutoModelForImageClassification.from_pretrained("THUDM/ImageReward", trust_remote_code=True).to(DEVICE) # Попытка с trust_remote_code
94
+ # # reward_model.eval()
95
+ # # print("INFO: THUDM/ImageReward loaded successfully.")
96
+ # except Exception as e:
97
+ # print(f"ERROR: Failed to load THUDM/ImageReward: {e}")
98
+
99
 
 
100
  ANIME_AESTHETIC_REPO = "deepghs/anime_aesthetic"
101
  ANIME_AESTHETIC_SUBFOLDER = "swinv2pv3_v0_448_ls0.2_x"
102
  ANIME_AESTHETIC_IMG_SIZE = (448, 448)
103
  ANIME_AESTHETIC_LABEL_WEIGHTS = {"normal": 0.0, "slight": 1.0, "moderate": 2.0, "strong": 3.0, "extreme": 4.0}
 
 
 
104
  print("INFO: MANIQA (honklers/maniqa-nr) is currently disabled.")
105
 
 
106
  clip_model_instance, clip_preprocess, clip_tokenizer = None, None, None
107
  try:
108
  clip_model_name = 'ViT-L-14'
 
117
  except Exception as e:
118
  print(f"ERROR: Failed to load CLIP model {clip_model_name} (laion2b_s32b_b82k): {e}")
119
 
 
 
120
  sdxl_detector_pipe = None
121
  try:
122
  print("INFO: Loading Organika/sdxl-detector model...")
 
125
  except Exception as e:
126
  print(f"ERROR: Failed to load Organika/sdxl-detector: {e}")
127
 
 
128
  ANIME_AI_CHECK_REPO = "deepghs/anime_ai_check"
129
  ANIME_AI_CHECK_SUBFOLDER = "caformer_s36_plus_sce"
130
  ANIME_AI_CHECK_IMG_SIZE = (384, 384)
131
 
 
 
132
  def extract_sd_parameters(image_pil, filename_for_log, current_log_list):
 
133
  if image_pil is None: return "", "N/A", "N/A", "N/A", {}
134
  parameters_str = image_pil.info.get("parameters", "")
135
  if not parameters_str:
136
  current_log_list.append(f"DEBUG [{filename_for_log}]: No metadata found in image.")
137
  return "", "N/A", "N/A", "N/A", {}
138
+ current_log_list.append(f"DEBUG [{filename_for_log}]: Raw metadata: {parameters_str[:100]}...")
 
139
  prompt, negative_prompt, model_name, model_hash, other_params_dict = "", "N/A", "N/A", "N/A", {}
 
140
  try:
141
  neg_prompt_index = parameters_str.find("Negative prompt:")
142
  steps_meta_index = parameters_str.find("Steps:")
 
155
  prompt = parameters_str[:steps_meta_index].strip()
156
  params_part = parameters_str[steps_meta_index:]
157
  else:
158
+ prompt = parameters_str.strip(); params_part = ""
159
+ if params_part:
 
 
160
  params_list = [p.strip() for p in params_part.split(",")]
161
  temp_other_params = {}
162
  for param_val_str in params_list:
 
168
  elif key.lower() == "model hash": model_hash = value
169
  for k,v in temp_other_params.items():
170
  if k.lower() not in ["model", "model hash"]: other_params_dict[k] = v
 
171
  if model_name == "N/A" and model_hash != "N/A": model_name = f"hash_{model_hash}"
172
  if model_name == "N/A" and "Checkpoint" in other_params_dict: model_name = other_params_dict["Checkpoint"]
173
  if model_name == "N/A" and "model" in other_params_dict: model_name = other_params_dict["model"]
174
  current_log_list.append(f"DEBUG [{filename_for_log}]: Parsed Prompt: {prompt[:50]}... | Model: {model_name}")
 
175
  except Exception as e:
176
  current_log_list.append(f"ERROR [{filename_for_log}]: Failed to parse metadata: {e}")
177
  return prompt, negative_prompt, model_name, model_hash, other_params_dict
178
 
 
179
  @torch.no_grad()
180
  def get_image_reward(image_pil, filename_for_log, current_log_list):
181
+ # current_log_list.append(f"INFO [{filename_for_log}]: ImageReward model not loaded (disabled), skipping.")
182
+ return "N/A (Disabled)" # Временно отключено
183
+ # if not reward_model or not reward_processor:
184
+ # current_log_list.append(f"INFO [{filename_for_log}]: ImageReward model not loaded, skipping.")
185
+ # return "N/A"
186
+ # t_start = time.time()
187
+ # current_log_list.append(f"DEBUG [{filename_for_log}]: Starting ImageReward score (PyTorch Device: {DEVICE})...")
188
+ # try:
189
+ # inputs = reward_processor(images=image_pil, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
190
+ # outputs = reward_model(**inputs)
191
+ # score = round(outputs.logits.item(), 4)
192
+ # t_end = time.time()
193
+ # current_log_list.append(f"DEBUG [{filename_for_log}]: ImageReward score: {score} (took {t_end - t_start:.2f}s)")
194
+ # return score
195
+ # except Exception as e:
196
+ # current_log_list.append(f"ERROR [{filename_for_log}]: ImageReward scoring failed: {e}")
197
+ # return "Error"
198
 
199
  def get_anime_aesthetic_score_deepghs(image_pil, filename_for_log, current_log_list):
200
  session, labels, meta = get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, current_log_list)
201
  if not session or not labels:
202
  current_log_list.append(f"INFO [{filename_for_log}]: AnimeAesthetic ONNX model not loaded, skipping.")
203
  return "N/A"
204
+ t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAesthetic (ONNX) score...")
 
205
  try:
206
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AESTHETIC_IMG_SIZE)
207
+ input_name = session.get_inputs()[0].name; output_name = session.get_outputs()[0].name
 
208
  onnx_output, = session.run([output_name], {input_name: input_data})
209
+ scores = onnx_output[0]; exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
 
210
  weighted_score = sum(probabilities[i] * ANIME_AESTHETIC_LABEL_WEIGHTS.get(label, 0.0) for i, label in enumerate(labels))
211
+ score = round(weighted_score, 4); t_end = time.time()
 
212
  current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAesthetic (ONNX) score: {score} (took {t_end - t_start:.2f}s)")
213
  return score
214
  except Exception as e:
215
+ current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAesthetic (ONNX) scoring failed: {e}"); return "Error"
 
216
 
217
  @torch.no_grad()
218
  def get_maniqa_score(image_pil, filename_for_log, current_log_list):
 
227
  if not prompt_text or prompt_text == "N/A":
228
  current_log_list.append(f"INFO [{filename_for_log}]: Empty prompt, skipping CLIPScore.")
229
  return "N/A (Empty Prompt)"
230
+ t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting CLIPScore (PyTorch Device: {DEVICE})...")
 
 
231
  try:
232
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
233
+ text_for_tokenizer = str(prompt_text); text_input = clip_tokenizer([text_for_tokenizer]).to(DEVICE)
234
+ image_features = clip_model_instance.encode_image(image_input); text_features = clip_model_instance.encode_text(text_input)
 
 
235
  image_features_norm = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
236
  text_features_norm = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
237
  score_val = (text_features_norm @ image_features_norm.T).squeeze().item() * 100.0
238
+ score = round(score_val, 2); t_end = time.time()
 
239
  current_log_list.append(f"DEBUG [{filename_for_log}]: CLIPScore: {score} (took {t_end - t_start:.2f}s)")
240
  return score
241
  except Exception as e:
242
+ current_log_list.append(f"ERROR [{filename_for_log}]: CLIPScore calculation failed: {e}"); return "Error"
 
243
 
244
  @torch.no_grad()
245
  def get_sdxl_detection_score(image_pil, filename_for_log, current_log_list):
246
  if not sdxl_detector_pipe:
247
  current_log_list.append(f"INFO [{filename_for_log}]: SDXL_Detector model not loaded, skipping.")
248
  return "N/A"
249
+ t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting SDXL_Detector score (Device: {sdxl_detector_pipe.device})...")
 
250
  try:
251
+ result = sdxl_detector_pipe(image_pil.copy()); ai_score_val = 0.0
 
252
  for item in result:
253
  if item['label'].lower() == 'artificial': ai_score_val = item['score']; break
254
+ score = round(ai_score_val, 4); t_end = time.time()
 
255
  current_log_list.append(f"DEBUG [{filename_for_log}]: SDXL_Detector AI Prob: {score} (took {t_end - t_start:.2f}s)")
256
  return score
257
  except Exception as e:
258
+ current_log_list.append(f"ERROR [{filename_for_log}]: SDXL_Detector scoring failed: {e}"); return "Error"
 
259
 
260
  def get_anime_ai_check_score_deepghs(image_pil, filename_for_log, current_log_list):
261
  session, labels, meta = get_onnx_session_and_meta(ANIME_AI_CHECK_REPO, ANIME_AI_CHECK_SUBFOLDER, current_log_list)
262
  if not session or not labels:
263
  current_log_list.append(f"INFO [{filename_for_log}]: AnimeAI_Check ONNX model not loaded, skipping.")
264
  return "N/A"
265
+ t_start = time.time(); current_log_list.append(f"DEBUG [{filename_for_log}]: Starting AnimeAI_Check (ONNX) score...")
 
266
  try:
267
  input_data = _img_preprocess_for_onnx(image_pil.copy(), size=ANIME_AI_CHECK_IMG_SIZE)
268
+ input_name = session.get_inputs()[0].name; output_name = session.get_outputs()[0].name
 
269
  onnx_output, = session.run([output_name], {input_name: input_data})
270
+ scores = onnx_output[0]; exp_scores = np.exp(scores - np.max(scores)); probabilities = exp_scores / np.sum(exp_scores)
 
271
  ai_prob_val = 0.0
272
  for i, label in enumerate(labels):
273
  if label.lower() == 'ai': ai_prob_val = probabilities[i]; break
274
+ score = round(ai_prob_val, 4); t_end = time.time()
 
275
  current_log_list.append(f"DEBUG [{filename_for_log}]: AnimeAI_Check (ONNX) AI Prob: {score} (took {t_end - t_start:.2f}s)")
276
  return score
277
  except Exception as e:
278
+ current_log_list.append(f"ERROR [{filename_for_log}]: AnimeAI_Check (ONNX) scoring failed: {e}"); return "Error"
 
279
 
 
280
  def process_images_generator(files, progress=gr.Progress(track_tqdm=True)):
281
  if not files:
282
  yield pd.DataFrame(), None, None, None, None, "Please upload some images.", "No files to process."
 
284
 
285
  all_results = []
286
  log_accumulator = [f"INFO: Starting processing for {len(files)} images..."]
287
+ # Начальный yield для лога и статуса
288
+ yield (pd.DataFrame(all_results), None, None,
289
+ gr.File(visible=False), gr.File(visible=False), # Скрываем кнопки скачивания вначале
290
+ "Processing...", "\n".join(log_accumulator))
291
 
292
  for i, file_obj in enumerate(files):
293
  filename_for_log = "Unknown File"
294
  current_img_total_time_start = time.time()
295
  try:
296
+ filename_for_log = os.path.basename(getattr(file_obj, 'name', f"file_{i}_{int(time.time())}"))
297
  log_accumulator.append(f"--- Processing image {i+1}/{len(files)}: {filename_for_log} ---")
298
 
299
+ # Используем progress(float, desc=...)
300
+ progress( (i + 0.1) / len(files), desc=f"Img {i+1}/{len(files)}: Loading {filename_for_log}")
301
+ # Немедленно обновляем UI с логом перед тяжелой загрузкой изображения
302
+ yield (pd.DataFrame(all_results), None, None,
303
+ gr.File(visible=False), gr.File(visible=False),
304
+ f"Loading image {i+1}/{len(files)}: {filename_for_log}",
305
  "\n".join(log_accumulator))
306
 
307
  img = Image.open(getattr(file_obj, 'name', str(file_obj)))
308
  if img.mode != "RGB": img = img.convert("RGB")
309
+ progress( (i + 0.3) / len(files), desc=f"Img {i+1}/{len(files)}: Scoring {filename_for_log}")
310
 
 
311
 
312
+ prompt, neg_prompt, model_n, model_h, other_p = extract_sd_parameters(img, filename_for_log, log_accumulator)
313
  reward = get_image_reward(img, filename_for_log, log_accumulator)
314
  anime_aes_deepghs = get_anime_aesthetic_score_deepghs(img, filename_for_log, log_accumulator)
315
  maniqa = get_maniqa_score(img, filename_for_log, log_accumulator)
316
  clip_val = calculate_clip_score_value(img, prompt, filename_for_log, log_accumulator)
317
  sdxl_detect = get_sdxl_detection_score(img, filename_for_log, log_accumulator)
318
  anime_ai_chk_deepghs = get_anime_ai_check_score_deepghs(img, filename_for_log, log_accumulator)
 
319
  current_img_total_time_end = time.time()
320
  log_accumulator.append(f"INFO [{filename_for_log}]: Finished all scores (total for image: {current_img_total_time_end - current_img_total_time_start:.2f}s)")
321
 
 
322
  all_results.append({
323
  "Filename": filename_for_log, "Prompt": prompt if prompt else "N/A", "Model Name": model_n, "Model Hash": model_h,
324
  "ImageReward": reward, "AnimeAesthetic_dg": anime_aes_deepghs, "MANIQA_TQ": maniqa,
325
  "CLIPScore": clip_val, "SDXL_Detector_AI_Prob": sdxl_detect, "AnimeAI_Check_dg_Prob": anime_ai_chk_deepghs,
326
  })
 
 
 
 
327
  df_so_far = pd.DataFrame(all_results)
328
+ progress( (i + 1.0) / len(files), desc=f"Img {i+1}/{len(files)}: Done {filename_for_log}")
329
+ yield (df_so_far, None, None,
330
+ gr.File(visible=False), gr.File(visible=False),
331
  f"Processed image {i+1}/{len(files)}: {filename_for_log}",
332
  "\n".join(log_accumulator))
 
333
  except Exception as e:
334
  log_accumulator.append(f"CRITICAL ERROR processing {filename_for_log}: {e}")
335
  print(f"CRITICAL ERROR processing {filename_for_log}: {e}")
 
339
  "CLIPScore": "Error", "SDXL_Detector_AI_Prob": "Error", "AnimeAI_Check_dg_Prob": "Error"
340
  })
341
  df_so_far = pd.DataFrame(all_results)
342
+ yield (df_so_far, None, None,
343
+ gr.File(visible=False), gr.File(visible=False),
344
  f"Error on image {i+1}/{len(files)}: {filename_for_log}",
345
  "\n".join(log_accumulator))
346
 
347
  log_accumulator.append("--- Generating final plots and download files ---")
348
+ progress(1.0, desc="Generating final plots...")
349
+ yield (pd.DataFrame(all_results), None, None,
350
+ gr.File(visible=False), gr.File(visible=False),
351
  "Generating final plots...",
352
  "\n".join(log_accumulator))
353
 
354
  df = pd.DataFrame(all_results)
355
  plot_model_avg_scores_buffer, plot_prompt_clip_scores_buffer = None, None
356
+ csv_file_path_out, json_file_path_out = None, None # Будем возвращать пути к файлам
357
 
358
  if not df.empty:
359
+ numeric_cols = ["ImageReward", "AnimeAesthetic_dg", "MANIQA_TQ", "CLIPScore"]
360
  for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce')
 
361
  df_model_plot = df[(df["Model Name"] != "N/A") & (df["Model Name"].notna())]
362
  if not df_model_plot.empty and df_model_plot["Model Name"].nunique() > 0:
363
  try:
 
369
  plot_model_avg_scores_buffer = io.BytesIO(); fig1.savefig(plot_model_avg_scores_buffer, format="png"); plot_model_avg_scores_buffer.seek(0); plt.close(fig1)
370
  log_accumulator.append("INFO: Model average scores plot generated.")
371
  except Exception as e: log_accumulator.append(f"ERROR: Failed to generate model average scores plot: {e}")
 
372
  df_prompt_plot = df[(df["Prompt"] != "N/A") & (df["Prompt"].notna()) & (df["CLIPScore"].notna())]
373
  if not df_prompt_plot.empty and df_prompt_plot["Prompt"].nunique() > 0 :
374
  try:
 
382
  log_accumulator.append("INFO: Prompt CLIP scores plot generated.")
383
  except Exception as e: log_accumulator.append(f"ERROR: Failed to generate prompt CLIP scores plot: {e}")
384
 
385
+ # Сохраняем файлы во временные файлы и возвращаем пути
386
+ try:
387
+ with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv", encoding='utf-8') as tmp_csv:
388
+ df.to_csv(tmp_csv, index=False)
389
+ csv_file_path_out = tmp_csv.name
390
+ with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json", encoding='utf-8') as tmp_json:
391
+ df.to_json(tmp_json, orient='records', indent=4)
392
+ json_file_path_out = tmp_json.name
393
+ log_accumulator.append("INFO: CSV and JSON data prepared for download.")
394
+ except Exception as e:
395
+ log_accumulator.append(f"ERROR preparing download files: {e}")
396
+
397
 
398
+ final_status = f"Finished processing {len(all_results)} images."
 
399
  log_accumulator.append(final_status)
400
 
401
  yield (
402
  df,
403
+ gr.Image(value=plot_model_avg_scores_buffer, visible=plot_model_avg_scores_buffer is not None),
404
+ gr.Image(value=plot_prompt_clip_scores_buffer, visible=plot_prompt_clip_scores_buffer is not None),
405
+ gr.File(value=csv_file_path_out, visible=csv_file_path_out is not None), # Убрали file_name
406
+ gr.File(value=json_file_path_out, visible=json_file_path_out is not None), # Убрали file_name
407
  final_status,
408
  "\n".join(log_accumulator)
409
  )
410
 
411
+ import tempfile # Для gr.File
412
 
 
413
  with gr.Blocks(css="footer {display: none !important}") as demo:
414
  gr.Markdown("# AI Image Model Evaluation Tool")
415
  gr.Markdown("Upload PNG images (ideally with Stable Diffusion metadata) to evaluate them...")
416
+ with gr.Row(): image_uploader = gr.Files(label="Upload Images (PNG)", file_count="multiple", file_types=["image"])
 
 
 
 
 
 
 
417
  process_button = gr.Button("Evaluate Images", variant="primary")
 
418
  status_textbox = gr.Textbox(label="Overall Status", interactive=False)
419
+ log_output_textbox = gr.Textbox(label="Detailed Logs", lines=15, interactive=False, autoscroll=True)
 
 
420
  gr.Markdown("## Evaluation Results Table")
421
  results_table = gr.DataFrame(headers=[
422
  "Filename", "Prompt", "Model Name", "Model Hash", "ImageReward", "AnimeAesthetic_dg",
423
  "MANIQA_TQ", "CLIPScore", "SDXL_Detector_AI_Prob", "AnimeAI_Check_dg_Prob"
424
  ], wrap=True)
 
425
  with gr.Row():
426
+ download_csv_button = gr.File(label="Download CSV Results", interactive=False) # Будет обновляться из yield
427
+ download_json_button = gr.File(label="Download JSON Results", interactive=False) # Будет обновляться из yield
 
428
  gr.Markdown("## Visualizations")
429
  with gr.Row():
430
  plot_output_model_avg = gr.Image(label="Average Scores per Model", type="pil", interactive=False)
431
  plot_output_prompt_clip = gr.Image(label="Average CLIPScore per Prompt", type="pil", interactive=False)
 
432
  process_button.click(
433
+ fn=process_images_generator, inputs=[image_uploader],
434
+ outputs=[results_table, plot_output_model_avg, plot_output_prompt_clip,
435
+ download_csv_button, download_json_button, status_textbox, log_output_textbox]
 
 
 
 
 
 
 
 
436
  )
 
437
  gr.Markdown("""**Metric Explanations:** ... (без изменений)""")
438
 
439
  if __name__ == "__main__":
 
440
  print("--- Initializing models, please wait... ---")
 
 
 
441
  initial_dummy_logs = []
442
  if onnxruntime:
443
  get_onnx_session_and_meta(ANIME_AESTHETIC_REPO, ANIME_AESTHETIC_SUBFOLDER, initial_dummy_logs)
 
447
  for log_line in initial_dummy_logs: print(log_line)
448
  print("-----------------------------------------")
449
  print("--- Model initialization attempt complete. Launching Gradio. ---")
450
+ demo.queue().launch(debug=True)