Added a "Remove tags" feature to exclude specified tags from the output.
app.py CHANGED
@@ -134,7 +134,7 @@ class Timer:
 elapsed = curr_time - prev_time
 print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
 prev_time = curr_time
-
+
 if is_clear_checkpoints:
 self.checkpoints = [("Start", time.perf_counter())]

@@ -151,7 +151,7 @@ class Timer:
 elapsed = curr_time - prev_time
 print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
 prev_time = curr_time
-
+
 total_time = self.checkpoints[-1][1] - self.start_time
 print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")

@@ -252,7 +252,7 @@ class Llama3Reorganize:
 import ctranslate2
 import transformers
 try:
-print('\n\nLoading model:
+print(f'\n\nLoading model: {self.modelPath}\n\n')
 kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
 kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
 self.roleSystem = {"role": "system", "content": self.system_prompt}

@@ -270,12 +270,11 @@ class Llama3Reorganize:
 try:
 import torch
 if torch.cuda.is_available():
-if
+if hasattr(self, "Model") and hasattr(self.Model, "unload_model"):
 self.Model.unload_model()
-
-if getattr(self, "Tokenizer", None) is not None:
+if hasattr(self, "Tokenizer"):
 del self.Tokenizer
-if
+if hasattr(self, "Model"):
 del self.Model
 import gc
 gc.collect()

@@ -283,14 +282,13 @@ class Llama3Reorganize:
 torch.cuda.empty_cache()
 except Exception as e:
 print(traceback.format_exc())
-print("\tcuda empty cache, error: "
+print(f"\tcuda empty cache, error: {e}")
 print("release vram end.")
 except Exception as e:
 print(traceback.format_exc())
-print("Error release vram: "
+print(f"Error release vram: {e}")

 def reorganize(self, text: str, max_length: int = 400):
-output = None
 result = None
 try:
 input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)

@@ -298,19 +296,18 @@ class Llama3Reorganize:
 output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
 target = output[0]
 result = self.Tokenizer.decode(target.sequences_ids[0])
-
 if len(result) > 2:
-if result[0] == "
+if result[0] == '"' and result[-1] == '"':
 result = result[1:-1]
-elif result[0] == "'" and result[
+elif result[0] == "'" and result[-1] == "'":
 result = result[1:-1]
-elif result[0] ==
+elif result[0] == '「' and result[-1] == '」':
 result = result[1:-1]
-elif result[0] ==
+elif result[0] == '『' and result[-1] == '』':
 result = result[1:-1]
 except Exception as e:
 print(traceback.format_exc())
-print("Error reorganize text: "
+print(f"Error reorganize text: {e}")

 return result

@@ -339,28 +336,19 @@ class Predictor:

 tags_df = pd.read_csv(csv_path)
 sep_tags = load_labels(tags_df)
-
-self.tag_names = sep_tags[0]
-self.rating_indexes = sep_tags[1]
-self.general_indexes = sep_tags[2]
-self.character_indexes = sep_tags[3]
-
+self.tag_names, self.rating_indexes, self.general_indexes, self.character_indexes = sep_tags
 model = rt.InferenceSession(model_path)
-_, height,
+_, height, _, _ = model.get_inputs()[0].shape
 self.model_target_size = height
-
 self.last_loaded_repo = model_repo
 self.model = model

 def prepare_image(self, path):
-image = Image.open(path)
-image = image.convert("RGBA")
-target_size = self.model_target_size
-
+image = Image.open(path).convert("RGBA")
 canvas = Image.new("RGBA", image.size, (255, 255, 255))
 canvas.alpha_composite(image)
 image = canvas.convert("RGB")
-
+
 # Pad image to square
 image_shape = image.size
 max_dim = max(image_shape)

@@ -369,14 +357,14 @@ class Predictor:

 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
 padded_image.paste(image, (pad_left, pad_top))
-
+
 # Resize
-if max_dim !=
+if max_dim != self.model_target_size:
 padded_image = padded_image.resize(
-(
+(self.model_target_size, self.model_target_size),
 Image.BICUBIC,
 )
-
+
 # Convert to numpy array
 image_array = np.asarray(padded_image, dtype=np.float32)

@@ -404,6 +392,7 @@ class Predictor:
 llama3_reorganize_model_repo,
 additional_tags_prepend,
 additional_tags_append,
+tags_to_remove,
 tag_results,
 progress=gr.Progress()
 ):

@@ -413,7 +402,7 @@ class Predictor:

 gallery_len = len(gallery)
 print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
-
+
 timer = Timer() # Create a timer
 progressRatio = 0.5 if llama3_reorganize_model_repo else 1
 progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load

@@ -423,7 +412,7 @@ class Predictor:
 current_progress += 1 / progressTotal
 progress(current_progress, desc="Initialize wd model finished")
 timer.checkpoint(f"Initialize wd model")
-
+
 # Result
 txt_infos = []
 output_dir = tempfile.mkdtemp()

@@ -439,14 +428,15 @@ class Predictor:
 current_progress += 1 / progressTotal
 progress(current_progress, desc="Initialize llama3 model finished")
 timer.checkpoint(f"Initialize llama3 model")
-
+
 timer.report()

 prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
 append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
+remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
 if prepend_list and append_list:
 append_list = [item for item in append_list if item not in prepend_list]
-
+
 # Dictionary to track counters for each filename
 name_counters = defaultdict(int)
 for idx, value in enumerate(gallery):

@@ -467,11 +457,11 @@ class Predictor:
 preds = self.model.run([label_name], {input_name: image})[0]

 labels = list(zip(self.tag_names, preds[0].astype(float)))
-
+
 # First 4 labels are actually ratings: pick one with argmax
 ratings_names = [labels[i] for i in self.rating_indexes]
 rating = dict(ratings_names)
-
+
 # Then we have general tags: pick any where prediction confidence > threshold
 general_names = [labels[i] for i in self.general_indexes]

@@ -479,7 +469,7 @@ class Predictor:
 general_probs = np.array([x[1] for x in general_names])
 general_thresh = mcut_threshold(general_probs)
 general_res = dict([x for x in general_names if x[1] > general_thresh])
-
+
 # Everything else is characters: pick any where prediction confidence > threshold
 character_names = [labels[i] for i in self.character_indexes]

@@ -503,7 +493,12 @@ class Predictor:
 final_tags_list = prepend_list + sorted_general_list + append_list
 if characters_merge_enabled:
 final_tags_list = character_list + final_tags_list
-
+
+# Apply removal logic
+if remove_list:
+remove_set = set(remove_list)
+final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
+
 sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
 classified_tags, unclassified_tags = classify_tags(final_tags_list)

@@ -553,23 +548,24 @@ class Predictor:
 # Get file name from lookup
 taggers_zip.write(info["path"], arcname=info["name"])
 download.append(downloadZipPath)
-
+
 if llama3_reorganize:
 llama3_reorganize.release_vram()
-
+
 progress(1, desc="Image processing completed")
 timer.report_all()
 print("Image prediction is complete.")

 return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
-
-#
+
+# Method to process text files
 def predict_from_text(
 self,
 text_files,
 llama3_reorganize_model_repo,
 additional_tags_prepend,
 additional_tags_append,
+tags_to_remove,
 progress=gr.Progress()
 ):
 if not text_files:

@@ -583,7 +579,7 @@ class Predictor:
 progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
 progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
 current_progress = 0
-
+
 txt_infos = []
 output_dir = tempfile.mkdtemp()
 last_processed_string = ""

@@ -600,6 +596,7 @@ class Predictor:

 prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
 append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
+remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
 if prepend_list and append_list:
 append_list = [item for item in append_list if item not in prepend_list]

@@ -608,7 +605,7 @@ class Predictor:
 try:
 file_path = file_obj.name
 file_name_base = os.path.splitext(os.path.basename(file_path))[0]
-
+
 name_counters[file_name_base] += 1
 if name_counters[file_name_base] > 1:
 output_file_name = f"{file_name_base}_{name_counters[file_name_base]:02d}.txt"

@@ -617,16 +614,22 @@ class Predictor:

 with open(file_path, 'r', encoding='utf-8') as f:
 original_content = f.read()
-
+
 # Process tags
 tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
-
+
 if prepend_list:
 tags_list = [item for item in tags_list if item not in prepend_list]
 if append_list:
 tags_list = [item for item in tags_list if item not in append_list]

 final_tags_list = prepend_list + tags_list + append_list
+
+# Apply removal logic
+if remove_list:
+remove_set = set(remove_list)
+final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
+
 processed_string = ", ".join(final_tags_list)

 current_progress += progressRatio / progressTotal

@@ -645,7 +648,7 @@ class Predictor:
 current_progress += progressRatio / progressTotal
 progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
 timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
-
+
 txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
 txt_infos.append({"path": txt_file_path, "name": output_file_name})
 last_processed_string = processed_string

@@ -671,7 +674,7 @@ class Predictor:
 progress(1, desc="Text processing completed")
 timer.report_all() # Print all recorded times
 print("Text processing is complete.")
-
+
 # Return values in the same structure as the image path, with placeholders for unused outputs
 return download, last_processed_string, "{}", "", "", "", "{}", {}

@@ -679,9 +682,8 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
 if not selected_state:
 return selected_state

-tag_result =
-
-tag_result = tag_results[selected_state.value["image"]["path"]]
+tag_result = tag_results.get(selected_state.value["image"]["path"],
+{"strings": "", "classified_tags": "{}", "rating": "", "character_res": "", "general_res": "", "unclassified_tags": "{}"})

 return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]

@@ -690,7 +692,7 @@ def append_gallery(gallery: list, image: str):
 gallery = []
 if not image:
 return gallery, None
-
+
 gallery.append(image)

 return gallery, None

@@ -712,14 +714,14 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
 return gallery

 try:
-
+selected_image_tuple = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
 # Remove the selected image from the gallery
-if
-gallery.remove(
+if selected_image_tuple in gallery:
+gallery.remove(selected_image_tuple)
 except (ValueError, SyntaxError):
 # Handle cases where the string is not a valid literal
 print(f"Warning: Could not parse selected_image string: {selected_image}")
-
+
 return gallery

@@ -751,32 +753,33 @@ def main():
 SWINV2_MODEL_IS_DSV1_REPO,
 EVA02_LARGE_MODEL_IS_DSV1_REPO,
 ]
-
+
 llama_list = [
 META_LLAMA_3_3B_REPO,
 META_LLAMA_3_8B_REPO,
 ]
-
-#
+
+# Wrapper function to decide which prediction method to call
 def run_prediction(
 input_type, gallery, text_files, model_repo, general_thresh,
 general_mcut_enabled, character_thresh, character_mcut_enabled,
 characters_merge_enabled, llama3_reorganize_model_repo,
-additional_tags_prepend, additional_tags_append,
+additional_tags_prepend, additional_tags_append, tags_to_remove,
+tag_results, progress=gr.Progress()
 ):
 if input_type == 'Image':
 return predictor.predict_from_images(
 gallery, model_repo, general_thresh, general_mcut_enabled,
 character_thresh, character_mcut_enabled, characters_merge_enabled,
 llama3_reorganize_model_repo, additional_tags_prepend,
-additional_tags_append, tag_results, progress
+additional_tags_append, tags_to_remove, tag_results, progress
 )
 else: # 'Text file (.txt)'
 # For text files, some parameters are not used, but we must return
 # a tuple of the same size. `predict_from_text` handles this.
 return predictor.predict_from_text(
 text_files, llama3_reorganize_model_repo,
-additional_tags_prepend, additional_tags_append, progress
+additional_tags_prepend, additional_tags_append, tags_to_remove, progress
 )

 with gr.Blocks(title=TITLE, css=css) as demo:

@@ -793,7 +796,7 @@ def main():
 value='Image',
 label="Input Type"
 )
-
+
 # Group for image inputs, initially visible
 with gr.Column(visible=True) as image_inputs_group:
 with gr.Column(variant="panel"):

@@ -803,8 +806,8 @@ def main():
 upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
 remove_button = gr.Button("Remove Selected Image", size="sm")
 gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
-
-#
+
+# Group for text file inputs, initially hidden
 with gr.Column(visible=False) as text_inputs_group:
 text_files_input = gr.Files(
 label="Upload .txt files",

@@ -813,24 +816,6 @@ def main():
 height=500
 )

-# NEW: Logic to show/hide input groups based on radio selection
-def change_input_type(input_type):
-is_image = (input_type == 'Image')
-return {
-image_inputs_group: gr.update(visible=is_image),
-text_inputs_group: gr.update(visible=not is_image),
-# Also update visibility of image-specific settings
-model_repo: gr.update(visible=is_image),
-general_thresh_row: gr.update(visible=is_image),
-character_thresh_row: gr.update(visible=is_image),
-characters_merge_enabled: gr.update(visible=is_image),
-categorized: gr.update(visible=is_image),
-rating: gr.update(visible=is_image),
-character_res: gr.update(visible=is_image),
-general_res: gr.update(visible=is_image),
-unclassified: gr.update(visible=is_image),
-}
-
 # Image-specific settings
 model_repo = gr.Dropdown(
 dropdown_list,

@@ -883,6 +868,10 @@ def main():
 with gr.Row():
 additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
 additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
+
+# NEW: Add the remove tags input box
+tags_to_remove = gr.Text(label="Remove tags (comma split)")
+
 with gr.Row():
 clear = gr.ClearButton(
 components=[

@@ -897,6 +886,7 @@ def main():
 llama3_reorganize_model_repo,
 additional_tags_prepend,
 additional_tags_append,
+tags_to_remove,
 ],
 variant="secondary",
 size="lg",

@@ -935,7 +925,25 @@ def main():
 gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
 # Event to remove a selected image from the gallery
 remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
-
+
+# Logic to show/hide input groups based on radio selection
+def change_input_type(input_type):
+is_image = (input_type == 'Image')
+return {
+image_inputs_group: gr.update(visible=is_image),
+text_inputs_group: gr.update(visible=not is_image),
+# Also update visibility of image-specific settings
+model_repo: gr.update(visible=is_image),
+general_thresh_row: gr.update(visible=is_image),
+character_thresh_row: gr.update(visible=is_image),
+characters_merge_enabled: gr.update(visible=is_image),
+categorized: gr.update(visible=is_image),
+rating: gr.update(visible=is_image),
+character_res: gr.update(visible=is_image),
+general_res: gr.update(visible=is_image),
+unclassified: gr.update(visible=is_image),
+}
+
 # Connect the radio button to the visibility function
 input_type_radio.change(
 fn=change_input_type,

@@ -946,7 +954,7 @@ def main():
 categorized, rating, character_res, general_res, unclassified
 ]
 )
-
+
 # submit click now calls the wrapper function
 submit.click(
 fn=run_prediction,

@@ -963,6 +971,7 @@ def main():
 llama3_reorganize_model_repo,
 additional_tags_prepend,
 additional_tags_append,
+tags_to_remove,
 tag_results,
 ],
 outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],