avans06 committed
Commit f8fbbc6 · 1 Parent(s): c5eece5

Added a "Remove tags" feature to exclude specified tags from the output.

Files changed (1)
  1. app.py +97 -88
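Below is a minimal, self-contained sketch of the tag-removal behaviour this commit adds (the helper name remove_tags and the demo at the bottom are illustrative only; in app.py the same comma-splitting and set-based filtering happen inline in predict_from_images and predict_from_text, as the diff below shows):

def remove_tags(final_tags_list: list[str], tags_to_remove: str) -> list[str]:
    # Parse the comma-separated "Remove tags" textbox, dropping empty entries.
    remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()]
    if not remove_list:
        return final_tags_list
    remove_set = set(remove_list)
    # Keep only the tags that were not asked to be removed.
    return [tag for tag in final_tags_list if tag not in remove_set]

if __name__ == "__main__":
    tags = ["1girl", "solo", "smile", "watermark"]
    print(remove_tags(tags, "watermark, signature"))  # -> ['1girl', 'solo', 'smile']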
app.py CHANGED
@@ -134,7 +134,7 @@ class Timer:
  elapsed = curr_time - prev_time
  print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
  prev_time = curr_time
-
+
  if is_clear_checkpoints:
  self.checkpoints = [("Start", time.perf_counter())]
 
@@ -151,7 +151,7 @@ class Timer:
  elapsed = curr_time - prev_time
  print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
  prev_time = curr_time
-
+
  total_time = self.checkpoints[-1][1] - self.start_time
  print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
 
@@ -252,7 +252,7 @@ class Llama3Reorganize:
  import ctranslate2
  import transformers
  try:
- print('\n\nLoading model: %s\n\n' % self.modelPath)
+ print(f'\n\nLoading model: {self.modelPath}\n\n')
  kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
  kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
  self.roleSystem = {"role": "system", "content": self.system_prompt}
@@ -270,12 +270,11 @@ class Llama3Reorganize:
  try:
  import torch
  if torch.cuda.is_available():
- if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
+ if hasattr(self, "Model") and hasattr(self.Model, "unload_model"):
  self.Model.unload_model()
-
- if getattr(self, "Tokenizer", None) is not None:
+ if hasattr(self, "Tokenizer"):
  del self.Tokenizer
- if getattr(self, "Model", None) is not None:
+ if hasattr(self, "Model"):
  del self.Model
  import gc
  gc.collect()
@@ -283,14 +282,13 @@ class Llama3Reorganize:
  torch.cuda.empty_cache()
  except Exception as e:
  print(traceback.format_exc())
- print("\tcuda empty cache, error: " + str(e))
+ print(f"\tcuda empty cache, error: {e}")
  print("release vram end.")
  except Exception as e:
  print(traceback.format_exc())
- print("Error release vram: " + str(e))
+ print(f"Error release vram: {e}")
 
  def reorganize(self, text: str, max_length: int = 400):
- output = None
  result = None
  try:
  input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
@@ -298,19 +296,18 @@
  output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
  target = output[0]
  result = self.Tokenizer.decode(target.sequences_ids[0])
-
  if len(result) > 2:
- if result[0] == "\"" and result[len(result) - 1] == "\"":
+ if result[0] == '"' and result[-1] == '"':
  result = result[1:-1]
- elif result[0] == "'" and result[len(result) - 1] == "'":
+ elif result[0] == "'" and result[-1] == "'":
  result = result[1:-1]
- elif result[0] == "" and result[len(result) - 1] == "":
+ elif result[0] == '' and result[-1] == '':
  result = result[1:-1]
- elif result[0] == "" and result[len(result) - 1] == "":
+ elif result[0] == '' and result[-1] == '':
  result = result[1:-1]
  except Exception as e:
  print(traceback.format_exc())
- print("Error reorganize text: " + str(e))
+ print(f"Error reorganize text: {e}")
 
  return result
 
@@ -339,28 +336,19 @@ class Predictor:
 
  tags_df = pd.read_csv(csv_path)
  sep_tags = load_labels(tags_df)
-
- self.tag_names = sep_tags[0]
- self.rating_indexes = sep_tags[1]
- self.general_indexes = sep_tags[2]
- self.character_indexes = sep_tags[3]
-
+ self.tag_names, self.rating_indexes, self.general_indexes, self.character_indexes = sep_tags
  model = rt.InferenceSession(model_path)
- _, height, width, _ = model.get_inputs()[0].shape
+ _, height, _, _ = model.get_inputs()[0].shape
  self.model_target_size = height
-
  self.last_loaded_repo = model_repo
  self.model = model
 
  def prepare_image(self, path):
- image = Image.open(path)
- image = image.convert("RGBA")
- target_size = self.model_target_size
-
+ image = Image.open(path).convert("RGBA")
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
  canvas.alpha_composite(image)
  image = canvas.convert("RGB")
-
+
  # Pad image to square
  image_shape = image.size
  max_dim = max(image_shape)
@@ -369,14 +357,14 @@ class Predictor:
 
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
  padded_image.paste(image, (pad_left, pad_top))
-
+
  # Resize
- if max_dim != target_size:
+ if max_dim != self.model_target_size:
  padded_image = padded_image.resize(
- (target_size, target_size),
+ (self.model_target_size, self.model_target_size),
  Image.BICUBIC,
  )
-
+
  # Convert to numpy array
  image_array = np.asarray(padded_image, dtype=np.float32)
 
@@ -404,6 +392,7 @@ class Predictor:
  llama3_reorganize_model_repo,
  additional_tags_prepend,
  additional_tags_append,
+ tags_to_remove,
  tag_results,
  progress=gr.Progress()
  ):
@@ -413,7 +402,7 @@ class Predictor:
 
  gallery_len = len(gallery)
  print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
-
+
  timer = Timer() # Create a timer
  progressRatio = 0.5 if llama3_reorganize_model_repo else 1
  progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
@@ -423,7 +412,7 @@ class Predictor:
  current_progress += 1 / progressTotal
  progress(current_progress, desc="Initialize wd model finished")
  timer.checkpoint(f"Initialize wd model")
-
+
  # Result
  txt_infos = []
  output_dir = tempfile.mkdtemp()
@@ -439,14 +428,15 @@ class Predictor:
  current_progress += 1 / progressTotal
  progress(current_progress, desc="Initialize llama3 model finished")
  timer.checkpoint(f"Initialize llama3 model")
-
+
  timer.report()
 
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
  append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
+ remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
  if prepend_list and append_list:
  append_list = [item for item in append_list if item not in prepend_list]
-
+
  # Dictionary to track counters for each filename
  name_counters = defaultdict(int)
  for idx, value in enumerate(gallery):
@@ -467,11 +457,11 @@ class Predictor:
  preds = self.model.run([label_name], {input_name: image})[0]
 
  labels = list(zip(self.tag_names, preds[0].astype(float)))
-
+
  # First 4 labels are actually ratings: pick one with argmax
  ratings_names = [labels[i] for i in self.rating_indexes]
  rating = dict(ratings_names)
-
+
  # Then we have general tags: pick any where prediction confidence > threshold
  general_names = [labels[i] for i in self.general_indexes]
 
@@ -479,7 +469,7 @@ class Predictor:
  general_probs = np.array([x[1] for x in general_names])
  general_thresh = mcut_threshold(general_probs)
  general_res = dict([x for x in general_names if x[1] > general_thresh])
-
+
  # Everything else is characters: pick any where prediction confidence > threshold
  character_names = [labels[i] for i in self.character_indexes]
 
@@ -503,7 +493,12 @@ class Predictor:
  final_tags_list = prepend_list + sorted_general_list + append_list
  if characters_merge_enabled:
  final_tags_list = character_list + final_tags_list
-
+
+ # Apply removal logic
+ if remove_list:
+ remove_set = set(remove_list)
+ final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
+
  sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
  classified_tags, unclassified_tags = classify_tags(final_tags_list)
 
@@ -553,23 +548,24 @@ class Predictor:
  # Get file name from lookup
  taggers_zip.write(info["path"], arcname=info["name"])
  download.append(downloadZipPath)
-
+
  if llama3_reorganize:
  llama3_reorganize.release_vram()
-
+
  progress(1, desc="Image processing completed")
  timer.report_all()
  print("Image prediction is complete.")
 
  return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
-
- # NEW: Method to process text files
+
+ # Method to process text files
  def predict_from_text(
  self,
  text_files,
  llama3_reorganize_model_repo,
  additional_tags_prepend,
  additional_tags_append,
+ tags_to_remove,
  progress=gr.Progress()
  ):
  if not text_files:
@@ -583,7 +579,7 @@ class Predictor:
  progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
  progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
  current_progress = 0
-
+
  txt_infos = []
  output_dir = tempfile.mkdtemp()
  last_processed_string = ""
@@ -600,6 +596,7 @@ class Predictor:
 
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
  append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
+ remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
  if prepend_list and append_list:
  append_list = [item for item in append_list if item not in prepend_list]
 
@@ -608,7 +605,7 @@ class Predictor:
  try:
  file_path = file_obj.name
  file_name_base = os.path.splitext(os.path.basename(file_path))[0]
-
+
  name_counters[file_name_base] += 1
  if name_counters[file_name_base] > 1:
  output_file_name = f"{file_name_base}_{name_counters[file_name_base]:02d}.txt"
@@ -617,16 +614,22 @@ class Predictor:
 
  with open(file_path, 'r', encoding='utf-8') as f:
  original_content = f.read()
-
+
  # Process tags
  tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
-
+
  if prepend_list:
  tags_list = [item for item in tags_list if item not in prepend_list]
  if append_list:
  tags_list = [item for item in tags_list if item not in append_list]
 
  final_tags_list = prepend_list + tags_list + append_list
+
+ # Apply removal logic
+ if remove_list:
+ remove_set = set(remove_list)
+ final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
+
  processed_string = ", ".join(final_tags_list)
 
  current_progress += progressRatio / progressTotal
@@ -645,7 +648,7 @@ class Predictor:
  current_progress += progressRatio / progressTotal
  progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
  timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
-
+
  txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
  txt_infos.append({"path": txt_file_path, "name": output_file_name})
  last_processed_string = processed_string
@@ -671,7 +674,7 @@ class Predictor:
  progress(1, desc="Text processing completed")
  timer.report_all() # Print all recorded times
  print("Text processing is complete.")
-
+
  # Return values in the same structure as the image path, with placeholders for unused outputs
  return download, last_processed_string, "{}", "", "", "", "{}", {}
 
@@ -679,9 +682,8 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
  if not selected_state:
  return selected_state
 
- tag_result = { "strings": "", "classified_tags": "{}", "rating": "", "character_res": "", "general_res": "", "unclassified_tags": "{}" }
- if selected_state.value["image"]["path"] in tag_results:
- tag_result = tag_results[selected_state.value["image"]["path"]]
+ tag_result = tag_results.get(selected_state.value["image"]["path"],
+ {"strings": "", "classified_tags": "{}", "rating": "", "character_res": "", "general_res": "", "unclassified_tags": "{}"})
 
  return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
 
@@ -690,7 +692,7 @@ def append_gallery(gallery: list, image: str):
  gallery = []
  if not image:
  return gallery, None
-
+
  gallery.append(image)
 
  return gallery, None
@@ -712,14 +714,14 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
  return gallery
 
  try:
- selected_image = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
+ selected_image_tuple = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
  # Remove the selected image from the gallery
- if selected_image in gallery:
- gallery.remove(selected_image)
+ if selected_image_tuple in gallery:
+ gallery.remove(selected_image_tuple)
  except (ValueError, SyntaxError):
  # Handle cases where the string is not a valid literal
  print(f"Warning: Could not parse selected_image string: {selected_image}")
-
+
  return gallery
 
 
@@ -751,32 +753,33 @@ def main():
  SWINV2_MODEL_IS_DSV1_REPO,
  EVA02_LARGE_MODEL_IS_DSV1_REPO,
  ]
-
+
  llama_list = [
  META_LLAMA_3_3B_REPO,
  META_LLAMA_3_8B_REPO,
  ]
-
- # NEW: Wrapper function to decide which prediction method to call
+
+ # Wrapper function to decide which prediction method to call
  def run_prediction(
  input_type, gallery, text_files, model_repo, general_thresh,
  general_mcut_enabled, character_thresh, character_mcut_enabled,
  characters_merge_enabled, llama3_reorganize_model_repo,
- additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()
+ additional_tags_prepend, additional_tags_append, tags_to_remove,
+ tag_results, progress=gr.Progress()
  ):
  if input_type == 'Image':
  return predictor.predict_from_images(
  gallery, model_repo, general_thresh, general_mcut_enabled,
  character_thresh, character_mcut_enabled, characters_merge_enabled,
  llama3_reorganize_model_repo, additional_tags_prepend,
- additional_tags_append, tag_results, progress
+ additional_tags_append, tags_to_remove, tag_results, progress
  )
  else: # 'Text file (.txt)'
  # For text files, some parameters are not used, but we must return
  # a tuple of the same size. `predict_from_text` handles this.
  return predictor.predict_from_text(
  text_files, llama3_reorganize_model_repo,
- additional_tags_prepend, additional_tags_append, progress
+ additional_tags_prepend, additional_tags_append, tags_to_remove, progress
  )
 
  with gr.Blocks(title=TITLE, css=css) as demo:
@@ -793,7 +796,7 @@ def main():
  value='Image',
  label="Input Type"
  )
-
+
  # Group for image inputs, initially visible
  with gr.Column(visible=True) as image_inputs_group:
  with gr.Column(variant="panel"):
@@ -803,8 +806,8 @@ def main():
  upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
  remove_button = gr.Button("Remove Selected Image", size="sm")
  gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
-
- # NEW: Group for text file inputs, initially hidden
+
+ # Group for text file inputs, initially hidden
  with gr.Column(visible=False) as text_inputs_group:
  text_files_input = gr.Files(
  label="Upload .txt files",
@@ -813,24 +816,6 @@ def main():
  height=500
  )
 
- # NEW: Logic to show/hide input groups based on radio selection
- def change_input_type(input_type):
- is_image = (input_type == 'Image')
- return {
- image_inputs_group: gr.update(visible=is_image),
- text_inputs_group: gr.update(visible=not is_image),
- # Also update visibility of image-specific settings
- model_repo: gr.update(visible=is_image),
- general_thresh_row: gr.update(visible=is_image),
- character_thresh_row: gr.update(visible=is_image),
- characters_merge_enabled: gr.update(visible=is_image),
- categorized: gr.update(visible=is_image),
- rating: gr.update(visible=is_image),
- character_res: gr.update(visible=is_image),
- general_res: gr.update(visible=is_image),
- unclassified: gr.update(visible=is_image),
- }
-
  # Image-specific settings
  model_repo = gr.Dropdown(
  dropdown_list,
@@ -883,6 +868,10 @@ def main():
  with gr.Row():
  additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
  additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
+
+ # NEW: Add the remove tags input box
+ tags_to_remove = gr.Text(label="Remove tags (comma split)")
+
  with gr.Row():
  clear = gr.ClearButton(
  components=[
@@ -897,6 +886,7 @@ def main():
  llama3_reorganize_model_repo,
  additional_tags_prepend,
  additional_tags_append,
+ tags_to_remove,
  ],
  variant="secondary",
  size="lg",
@@ -935,7 +925,25 @@ def main():
  gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
  # Event to remove a selected image from the gallery
  remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
-
+
+ # Logic to show/hide input groups based on radio selection
+ def change_input_type(input_type):
+ is_image = (input_type == 'Image')
+ return {
+ image_inputs_group: gr.update(visible=is_image),
+ text_inputs_group: gr.update(visible=not is_image),
+ # Also update visibility of image-specific settings
+ model_repo: gr.update(visible=is_image),
+ general_thresh_row: gr.update(visible=is_image),
+ character_thresh_row: gr.update(visible=is_image),
+ characters_merge_enabled: gr.update(visible=is_image),
+ categorized: gr.update(visible=is_image),
+ rating: gr.update(visible=is_image),
+ character_res: gr.update(visible=is_image),
+ general_res: gr.update(visible=is_image),
+ unclassified: gr.update(visible=is_image),
+ }
+
  # Connect the radio button to the visibility function
  input_type_radio.change(
  fn=change_input_type,
@@ -946,7 +954,7 @@ def main():
  categorized, rating, character_res, general_res, unclassified
  ]
  )
-
+
  # submit click now calls the wrapper function
  submit.click(
  fn=run_prediction,
@@ -963,6 +971,7 @@ def main():
  llama3_reorganize_model_repo,
  additional_tags_prepend,
  additional_tags_append,
+ tags_to_remove,
  tag_results,
  ],
  outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],