Spaces:
Running
Running
The input type now supports text files in addition to images.
Browse filesWhen a text file is used as input, you can perform actions such as "reorganize the article" and "Prepend/Append Additional tags".
app.py
CHANGED
@@ -17,14 +17,16 @@ from datetime import datetime
|
|
17 |
from collections import defaultdict
|
18 |
from classifyTags import classify_tags
|
19 |
|
20 |
-
TITLE = "WaifuDiffusion Tagger multiple images"
|
21 |
DESCRIPTION = """
|
22 |
-
Demo for the WaifuDiffusion tagger models
|
|
|
23 |
Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
24 |
|
|
|
25 |
Features of This Modified Version:
|
26 |
-
- Supports batch processing of multiple images
|
27 |
-
- Displays tag results in categorized groups: the generated tags will now be analyzed and categorized into corresponding groups.
|
28 |
"""
|
29 |
|
30 |
# Dataset v3 series of models:
|
@@ -124,33 +126,34 @@ class Timer:
|
|
124 |
|
125 |
def report(self, is_clear_checkpoints = True):
|
126 |
# Determine the max label width for alignment
|
127 |
-
max_label_length = max(len(label) for label, _ in self.checkpoints)
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
134 |
|
135 |
if is_clear_checkpoints:
|
136 |
-
self.checkpoints.
|
137 |
-
self.checkpoint() # Store checkpoints
|
138 |
|
139 |
def report_all(self):
|
140 |
"""Print all recorded checkpoints and total execution time with aligned formatting."""
|
141 |
print("\n> Execution Time Report:")
|
142 |
|
143 |
# Determine the max label width for alignment
|
144 |
-
max_label_length = max(len(label) for label, _ in self.checkpoints) if
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
154 |
|
155 |
self.checkpoints.clear()
|
156 |
|
@@ -384,12 +387,12 @@ class Predictor:
|
|
384 |
|
385 |
def create_file(self, text: str, directory: str, fileName: str) -> str:
|
386 |
# Write the text to a file
|
387 |
-
|
|
|
388 |
file.write(text)
|
|
|
389 |
|
390 |
-
|
391 |
-
|
392 |
-
def predict(
|
393 |
self,
|
394 |
gallery,
|
395 |
model_repo,
|
@@ -404,34 +407,36 @@ class Predictor:
|
|
404 |
tag_results,
|
405 |
progress=gr.Progress()
|
406 |
):
|
|
|
|
|
|
|
|
|
407 |
gallery_len = len(gallery)
|
408 |
-
print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
|
409 |
|
410 |
timer = Timer() # Create a timer
|
411 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1
|
412 |
-
progressTotal = gallery_len + 1
|
413 |
current_progress = 0
|
414 |
|
415 |
self.load_model(model_repo)
|
416 |
-
current_progress +=
|
417 |
progress(current_progress, desc="Initialize wd model finished")
|
418 |
timer.checkpoint(f"Initialize wd model")
|
419 |
-
|
420 |
# Result
|
421 |
txt_infos = []
|
422 |
output_dir = tempfile.mkdtemp()
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
rating = None
|
428 |
-
character_res = None
|
429 |
-
general_res = None
|
430 |
|
|
|
431 |
if llama3_reorganize_model_repo:
|
432 |
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
433 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
434 |
-
current_progress +=
|
435 |
progress(current_progress, desc="Initialize llama3 model finished")
|
436 |
timer.checkpoint(f"Initialize llama3 model")
|
437 |
|
@@ -458,7 +463,7 @@ class Predictor:
|
|
458 |
|
459 |
input_name = self.model.get_inputs()[0].name
|
460 |
label_name = self.model.get_outputs()[0].name
|
461 |
-
print(f"Gallery {idx
|
462 |
preds = self.model.run([label_name], {input_name: image})[0]
|
463 |
|
464 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
@@ -473,9 +478,7 @@ class Predictor:
|
|
473 |
if general_mcut_enabled:
|
474 |
general_probs = np.array([x[1] for x in general_names])
|
475 |
general_thresh = mcut_threshold(general_probs)
|
476 |
-
|
477 |
-
general_res = [x for x in general_names if x[1] > general_thresh]
|
478 |
-
general_res = dict(general_res)
|
479 |
|
480 |
# Everything else is characters: pick any where prediction confidence > threshold
|
481 |
character_names = [labels[i] for i in self.character_indexes]
|
@@ -484,16 +487,10 @@ class Predictor:
|
|
484 |
character_probs = np.array([x[1] for x in character_names])
|
485 |
character_thresh = mcut_threshold(character_probs)
|
486 |
character_thresh = max(0.15, character_thresh)
|
487 |
-
|
488 |
-
character_res = [x for x in character_names if x[1] > character_thresh]
|
489 |
-
character_res = dict(character_res)
|
490 |
character_list = list(character_res.keys())
|
491 |
|
492 |
-
sorted_general_list = sorted(
|
493 |
-
general_res.items(),
|
494 |
-
key=lambda x: x[1],
|
495 |
-
reverse=True,
|
496 |
-
)
|
497 |
sorted_general_list = [x[0] for x in sorted_general_list]
|
498 |
#Remove values from character_list that already exist in sorted_general_list
|
499 |
character_list = [item for item in character_list if item not in sorted_general_list]
|
@@ -503,57 +500,181 @@ class Predictor:
|
|
503 |
if append_list:
|
504 |
sorted_general_list = [item for item in sorted_general_list if item not in append_list]
|
505 |
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
|
|
511 |
|
512 |
-
current_progress += progressRatio/progressTotal
|
513 |
-
progress(current_progress, desc=f"
|
514 |
-
timer.checkpoint(f"
|
515 |
|
516 |
-
if
|
517 |
print(f"Starting reorganize with llama3...")
|
518 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
519 |
-
reorganize_strings
|
520 |
-
|
521 |
-
|
522 |
-
|
|
|
523 |
|
524 |
-
current_progress += progressRatio/progressTotal
|
525 |
-
progress(current_progress, desc=f"
|
526 |
-
timer.checkpoint(f"
|
527 |
|
528 |
txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
|
529 |
-
txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
|
530 |
|
531 |
tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
timer.report()
|
|
|
533 |
except Exception as e:
|
534 |
print(traceback.format_exc())
|
535 |
-
print("Error
|
536 |
-
|
|
|
537 |
# Result
|
538 |
download = []
|
539 |
-
if txt_infos
|
540 |
-
|
|
|
541 |
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
|
542 |
for info in txt_infos:
|
543 |
# Get file name from lookup
|
544 |
taggers_zip.write(info["path"], arcname=info["name"])
|
545 |
download.append(downloadZipPath)
|
546 |
|
547 |
-
if
|
548 |
llama3_reorganize.release_vram()
|
549 |
-
del llama3_reorganize
|
550 |
|
551 |
-
progress(1, desc=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
552 |
timer.report_all() # Print all recorded times
|
553 |
-
print("
|
|
|
|
|
|
|
554 |
|
555 |
-
return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
|
556 |
-
|
557 |
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
558 |
if not selected_state:
|
559 |
return selected_state
|
@@ -590,10 +711,15 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
|
|
590 |
if not gallery or not selected_image:
|
591 |
return gallery
|
592 |
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
gallery
|
|
|
|
|
|
|
|
|
|
|
597 |
return gallery
|
598 |
|
599 |
|
@@ -605,7 +731,6 @@ def main():
|
|
605 |
width: 55.5% !important;
|
606 |
}
|
607 |
"""
|
608 |
-
|
609 |
args = parse_args()
|
610 |
|
611 |
predictor = Predictor()
|
@@ -626,34 +751,93 @@ def main():
|
|
626 |
SWINV2_MODEL_IS_DSV1_REPO,
|
627 |
EVA02_LARGE_MODEL_IS_DSV1_REPO,
|
628 |
]
|
629 |
-
|
630 |
llama_list = [
|
631 |
META_LLAMA_3_3B_REPO,
|
632 |
META_LLAMA_3_8B_REPO,
|
633 |
]
|
634 |
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
639 |
gr.Markdown(value=DESCRIPTION)
|
|
|
640 |
with gr.Row():
|
641 |
with gr.Column():
|
642 |
submit = gr.Button(value="Submit", variant="primary", size="lg")
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
model_repo = gr.Dropdown(
|
652 |
dropdown_list,
|
653 |
value=EVA02_LARGE_MODEL_DSV3_REPO,
|
654 |
-
label="Model",
|
655 |
)
|
656 |
-
with gr.Row():
|
657 |
general_thresh = gr.Slider(
|
658 |
0,
|
659 |
1,
|
@@ -667,7 +851,7 @@ def main():
|
|
667 |
label="Use MCut threshold",
|
668 |
scale=1,
|
669 |
)
|
670 |
-
with gr.Row():
|
671 |
character_thresh = gr.Slider(
|
672 |
0,
|
673 |
1,
|
@@ -681,18 +865,20 @@ def main():
|
|
681 |
label="Use MCut threshold",
|
682 |
scale=1,
|
683 |
)
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
|
|
|
|
690 |
with gr.Row():
|
691 |
llama3_reorganize_model_repo = gr.Dropdown(
|
692 |
[None] + llama_list,
|
693 |
value=None,
|
694 |
-
label="Llama3
|
695 |
-
info="
|
696 |
)
|
697 |
with gr.Row():
|
698 |
additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
|
@@ -701,6 +887,7 @@ def main():
|
|
701 |
clear = gr.ClearButton(
|
702 |
components=[
|
703 |
gallery,
|
|
|
704 |
model_repo,
|
705 |
general_thresh,
|
706 |
general_mcut_enabled,
|
@@ -714,14 +901,16 @@ def main():
|
|
714 |
variant="secondary",
|
715 |
size="lg",
|
716 |
)
|
|
|
717 |
with gr.Column(variant="panel"):
|
718 |
download_file = gr.File(label="Output (Download)")
|
719 |
-
sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True)
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
|
|
725 |
clear.add(
|
726 |
[
|
727 |
download_file,
|
@@ -733,35 +922,51 @@ def main():
|
|
733 |
unclassified,
|
734 |
]
|
735 |
)
|
736 |
-
|
737 |
tag_results = gr.State({})
|
|
|
|
|
|
|
738 |
# Define the event listener to add the uploaded image to the gallery
|
739 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
740 |
# When the upload button is clicked, add the new images to the gallery
|
741 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
742 |
# Event to update the selected image when an image is clicked in the gallery
|
743 |
-
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
744 |
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
|
745 |
# Event to remove a selected image from the gallery
|
746 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
747 |
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
|
|
|
|
|
|
765 |
|
766 |
gr.Examples(
|
767 |
[["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
|
|
|
17 |
from collections import defaultdict
|
18 |
from classifyTags import classify_tags
|
19 |
|
20 |
+
TITLE = "WaifuDiffusion Tagger multiple images/texts"
|
21 |
DESCRIPTION = """
|
22 |
+
Demo for the WaifuDiffusion tagger models and text processing.
|
23 |
+
Select input type below. For images, it will generate tags. For text files, it will process existing tags.
|
24 |
Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
25 |
|
26 |
+
This project was duplicated from the Space of [wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) by the author SmilingWolf.
|
27 |
Features of This Modified Version:
|
28 |
+
- Supports batch processing of multiple images or text files.
|
29 |
+
- Displays tag results in categorized groups: the generated tags will now be analyzed and categorized into corresponding groups. (for images)
|
30 |
"""
|
31 |
|
32 |
# Dataset v3 series of models:
|
|
|
126 |
|
127 |
def report(self, is_clear_checkpoints = True):
|
128 |
# Determine the max label width for alignment
|
129 |
+
max_label_length = max(len(label) for label, _ in self.checkpoints) if self.checkpoints else 0
|
130 |
+
|
131 |
+
if len(self.checkpoints) > 1:
|
132 |
+
prev_time = self.checkpoints[0][1]
|
133 |
+
for label, curr_time in self.checkpoints[1:]:
|
134 |
+
elapsed = curr_time - prev_time
|
135 |
+
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
136 |
+
prev_time = curr_time
|
137 |
|
138 |
if is_clear_checkpoints:
|
139 |
+
self.checkpoints = [("Start", time.perf_counter())]
|
|
|
140 |
|
141 |
def report_all(self):
|
142 |
"""Print all recorded checkpoints and total execution time with aligned formatting."""
|
143 |
print("\n> Execution Time Report:")
|
144 |
|
145 |
# Determine the max label width for alignment
|
146 |
+
max_label_length = max(len(label) for label, _ in self.checkpoints) if self.checkpoints else 0
|
147 |
+
|
148 |
+
if len(self.checkpoints) > 1:
|
149 |
+
prev_time = self.start_time
|
150 |
+
for label, curr_time in self.checkpoints[1:]:
|
151 |
+
elapsed = curr_time - prev_time
|
152 |
+
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
153 |
+
prev_time = curr_time
|
154 |
+
|
155 |
+
total_time = self.checkpoints[-1][1] - self.start_time
|
156 |
+
print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
|
157 |
|
158 |
self.checkpoints.clear()
|
159 |
|
|
|
387 |
|
388 |
def create_file(self, text: str, directory: str, fileName: str) -> str:
|
389 |
# Write the text to a file
|
390 |
+
filepath = os.path.join(directory, fileName)
|
391 |
+
with open(filepath, 'w', encoding="utf-8") as file:
|
392 |
file.write(text)
|
393 |
+
return filepath
|
394 |
|
395 |
+
def predict_from_images(
|
|
|
|
|
396 |
self,
|
397 |
gallery,
|
398 |
model_repo,
|
|
|
407 |
tag_results,
|
408 |
progress=gr.Progress()
|
409 |
):
|
410 |
+
if not gallery:
|
411 |
+
gr.Warning("No images in the gallery to process.")
|
412 |
+
return None, "", "{}", "", "", "", "{}", {}
|
413 |
+
|
414 |
gallery_len = len(gallery)
|
415 |
+
print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
|
416 |
|
417 |
timer = Timer() # Create a timer
|
418 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1
|
419 |
+
progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
|
420 |
current_progress = 0
|
421 |
|
422 |
self.load_model(model_repo)
|
423 |
+
current_progress += 1 / progressTotal
|
424 |
progress(current_progress, desc="Initialize wd model finished")
|
425 |
timer.checkpoint(f"Initialize wd model")
|
426 |
+
|
427 |
# Result
|
428 |
txt_infos = []
|
429 |
output_dir = tempfile.mkdtemp()
|
430 |
+
|
431 |
+
last_sorted_general_strings = ""
|
432 |
+
last_classified_tags, last_unclassified_tags = {}, {}
|
433 |
+
last_rating, last_character_res, last_general_res = None, None, None
|
|
|
|
|
|
|
434 |
|
435 |
+
llama3_reorganize = None
|
436 |
if llama3_reorganize_model_repo:
|
437 |
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
438 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
439 |
+
current_progress += 1 / progressTotal
|
440 |
progress(current_progress, desc="Initialize llama3 model finished")
|
441 |
timer.checkpoint(f"Initialize llama3 model")
|
442 |
|
|
|
463 |
|
464 |
input_name = self.model.get_inputs()[0].name
|
465 |
label_name = self.model.get_outputs()[0].name
|
466 |
+
print(f"Gallery {idx+1}/{gallery_len}: Starting run wd model...")
|
467 |
preds = self.model.run([label_name], {input_name: image})[0]
|
468 |
|
469 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
|
|
478 |
if general_mcut_enabled:
|
479 |
general_probs = np.array([x[1] for x in general_names])
|
480 |
general_thresh = mcut_threshold(general_probs)
|
481 |
+
general_res = dict([x for x in general_names if x[1] > general_thresh])
|
|
|
|
|
482 |
|
483 |
# Everything else is characters: pick any where prediction confidence > threshold
|
484 |
character_names = [labels[i] for i in self.character_indexes]
|
|
|
487 |
character_probs = np.array([x[1] for x in character_names])
|
488 |
character_thresh = mcut_threshold(character_probs)
|
489 |
character_thresh = max(0.15, character_thresh)
|
490 |
+
character_res = dict([x for x in character_names if x[1] > character_thresh])
|
|
|
|
|
491 |
character_list = list(character_res.keys())
|
492 |
|
493 |
+
sorted_general_list = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
|
|
|
|
|
|
|
|
|
494 |
sorted_general_list = [x[0] for x in sorted_general_list]
|
495 |
#Remove values from character_list that already exist in sorted_general_list
|
496 |
character_list = [item for item in character_list if item not in sorted_general_list]
|
|
|
500 |
if append_list:
|
501 |
sorted_general_list = [item for item in sorted_general_list if item not in append_list]
|
502 |
|
503 |
+
final_tags_list = prepend_list + sorted_general_list + append_list
|
504 |
+
if characters_merge_enabled:
|
505 |
+
final_tags_list = character_list + final_tags_list
|
506 |
+
|
507 |
+
sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
|
508 |
+
classified_tags, unclassified_tags = classify_tags(final_tags_list)
|
509 |
|
510 |
+
current_progress += progressRatio / progressTotal
|
511 |
+
progress(current_progress, desc=f"Image {idx+1}/{gallery_len}, predict finished")
|
512 |
+
timer.checkpoint(f"Image {idx+1}/{gallery_len}, predict finished")
|
513 |
|
514 |
+
if llama3_reorganize:
|
515 |
print(f"Starting reorganize with llama3...")
|
516 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
517 |
+
if reorganize_strings:
|
518 |
+
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
519 |
+
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
520 |
+
reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
|
521 |
+
sorted_general_strings += "," + reorganize_strings
|
522 |
|
523 |
+
current_progress += progressRatio / progressTotal
|
524 |
+
progress(current_progress, desc=f"Image {idx+1}/{gallery_len}, llama3 reorganize finished")
|
525 |
+
timer.checkpoint(f"Image {idx+1}/{gallery_len}, llama3 reorganize finished")
|
526 |
|
527 |
txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
|
528 |
+
txt_infos.append({"path": txt_file, "name": image_name + ".txt"})
|
529 |
|
530 |
tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
|
531 |
+
|
532 |
+
# Store last result for UI display
|
533 |
+
last_sorted_general_strings = sorted_general_strings
|
534 |
+
last_classified_tags = classified_tags
|
535 |
+
last_rating = rating
|
536 |
+
last_character_res = character_res
|
537 |
+
last_general_res = general_res
|
538 |
+
last_unclassified_tags = unclassified_tags
|
539 |
timer.report()
|
540 |
+
|
541 |
except Exception as e:
|
542 |
print(traceback.format_exc())
|
543 |
+
print("Error predicting image: " + str(e))
|
544 |
+
gr.Warning(f"Failed to process image {os.path.basename(value[0])}. Error: {e}")
|
545 |
+
|
546 |
# Result
|
547 |
download = []
|
548 |
+
if txt_infos:
|
549 |
+
zip_filename = "images-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip"
|
550 |
+
downloadZipPath = os.path.join(output_dir, zip_filename)
|
551 |
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
|
552 |
for info in txt_infos:
|
553 |
# Get file name from lookup
|
554 |
taggers_zip.write(info["path"], arcname=info["name"])
|
555 |
download.append(downloadZipPath)
|
556 |
|
557 |
+
if llama3_reorganize:
|
558 |
llama3_reorganize.release_vram()
|
|
|
559 |
|
560 |
+
progress(1, desc="Image processing completed")
|
561 |
+
timer.report_all()
|
562 |
+
print("Image prediction is complete.")
|
563 |
+
|
564 |
+
return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
|
565 |
+
|
566 |
+
# NEW: Method to process text files
|
567 |
+
def predict_from_text(
|
568 |
+
self,
|
569 |
+
text_files,
|
570 |
+
llama3_reorganize_model_repo,
|
571 |
+
additional_tags_prepend,
|
572 |
+
additional_tags_append,
|
573 |
+
progress=gr.Progress()
|
574 |
+
):
|
575 |
+
if not text_files:
|
576 |
+
gr.Warning("No text files uploaded to process.")
|
577 |
+
return None, "", "{}", "", "", "", "{}", {}
|
578 |
+
|
579 |
+
files_len = len(text_files)
|
580 |
+
print(f"Predict from text: processing {files_len} files.")
|
581 |
+
|
582 |
+
timer = Timer()
|
583 |
+
progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
|
584 |
+
progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
|
585 |
+
current_progress = 0
|
586 |
+
|
587 |
+
txt_infos = []
|
588 |
+
output_dir = tempfile.mkdtemp()
|
589 |
+
last_processed_string = ""
|
590 |
+
|
591 |
+
llama3_reorganize = None
|
592 |
+
if llama3_reorganize_model_repo:
|
593 |
+
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
594 |
+
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
595 |
+
current_progress += 1 / progressTotal
|
596 |
+
progress(current_progress, desc="Initialize llama3 model finished")
|
597 |
+
timer.checkpoint(f"Initialize llama3 model")
|
598 |
+
|
599 |
+
timer.report()
|
600 |
+
|
601 |
+
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
602 |
+
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
603 |
+
if prepend_list and append_list:
|
604 |
+
append_list = [item for item in append_list if item not in prepend_list]
|
605 |
+
|
606 |
+
name_counters = defaultdict(int)
|
607 |
+
for idx, file_obj in enumerate(text_files):
|
608 |
+
try:
|
609 |
+
file_path = file_obj.name
|
610 |
+
file_name_base = os.path.splitext(os.path.basename(file_path))[0]
|
611 |
+
|
612 |
+
name_counters[file_name_base] += 1
|
613 |
+
if name_counters[file_name_base] > 1:
|
614 |
+
output_file_name = f"{file_name_base}_{name_counters[file_name_base]:02d}.txt"
|
615 |
+
else:
|
616 |
+
output_file_name = f"{file_name_base}.txt"
|
617 |
+
|
618 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
619 |
+
original_content = f.read()
|
620 |
+
|
621 |
+
# Process tags
|
622 |
+
tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
|
623 |
+
|
624 |
+
if prepend_list:
|
625 |
+
tags_list = [item for item in tags_list if item not in prepend_list]
|
626 |
+
if append_list:
|
627 |
+
tags_list = [item for item in tags_list if item not in append_list]
|
628 |
+
|
629 |
+
final_tags_list = prepend_list + tags_list + append_list
|
630 |
+
processed_string = ", ".join(final_tags_list)
|
631 |
+
|
632 |
+
current_progress += progressRatio / progressTotal
|
633 |
+
progress(current_progress, desc=f"File {idx+1}/{files_len}, base processing finished")
|
634 |
+
timer.checkpoint(f"File {idx+1}/{files_len}, base processing finished")
|
635 |
+
|
636 |
+
if llama3_reorganize:
|
637 |
+
print(f"Starting reorganize with llama3...")
|
638 |
+
reorganize_strings = llama3_reorganize.reorganize(processed_string)
|
639 |
+
if reorganize_strings:
|
640 |
+
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
641 |
+
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
642 |
+
reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
|
643 |
+
processed_string += "," + reorganize_strings
|
644 |
+
|
645 |
+
current_progress += progressRatio / progressTotal
|
646 |
+
progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
647 |
+
timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
648 |
+
|
649 |
+
txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
|
650 |
+
txt_infos.append({"path": txt_file_path, "name": output_file_name})
|
651 |
+
last_processed_string = processed_string
|
652 |
+
timer.report()
|
653 |
+
|
654 |
+
except Exception as e:
|
655 |
+
print(traceback.format_exc())
|
656 |
+
print("Error processing text file: " + str(e))
|
657 |
+
gr.Warning(f"Failed to process file {os.path.basename(file_obj.name)}. Error: {e}")
|
658 |
+
|
659 |
+
download = []
|
660 |
+
if txt_infos:
|
661 |
+
zip_filename = "texts-processed-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip"
|
662 |
+
downloadZipPath = os.path.join(output_dir, zip_filename)
|
663 |
+
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as processed_zip:
|
664 |
+
for info in txt_infos:
|
665 |
+
processed_zip.write(info["path"], arcname=info["name"])
|
666 |
+
download.append(downloadZipPath)
|
667 |
+
|
668 |
+
if llama3_reorganize:
|
669 |
+
llama3_reorganize.release_vram()
|
670 |
+
|
671 |
+
progress(1, desc="Text processing completed")
|
672 |
timer.report_all() # Print all recorded times
|
673 |
+
print("Text processing is complete.")
|
674 |
+
|
675 |
+
# Return values in the same structure as the image path, with placeholders for unused outputs
|
676 |
+
return download, last_processed_string, "{}", "", "", "", "{}", {}
|
677 |
|
|
|
|
|
678 |
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
679 |
if not selected_state:
|
680 |
return selected_state
|
|
|
711 |
if not gallery or not selected_image:
|
712 |
return gallery
|
713 |
|
714 |
+
try:
|
715 |
+
selected_image = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
|
716 |
+
# Remove the selected image from the gallery
|
717 |
+
if selected_image in gallery:
|
718 |
+
gallery.remove(selected_image)
|
719 |
+
except (ValueError, SyntaxError):
|
720 |
+
# Handle cases where the string is not a valid literal
|
721 |
+
print(f"Warning: Could not parse selected_image string: {selected_image}")
|
722 |
+
|
723 |
return gallery
|
724 |
|
725 |
|
|
|
731 |
width: 55.5% !important;
|
732 |
}
|
733 |
"""
|
|
|
734 |
args = parse_args()
|
735 |
|
736 |
predictor = Predictor()
|
|
|
751 |
SWINV2_MODEL_IS_DSV1_REPO,
|
752 |
EVA02_LARGE_MODEL_IS_DSV1_REPO,
|
753 |
]
|
754 |
+
|
755 |
llama_list = [
|
756 |
META_LLAMA_3_3B_REPO,
|
757 |
META_LLAMA_3_8B_REPO,
|
758 |
]
|
759 |
|
760 |
+
# NEW: Wrapper function to decide which prediction method to call
|
761 |
+
def run_prediction(
|
762 |
+
input_type, gallery, text_files, model_repo, general_thresh,
|
763 |
+
general_mcut_enabled, character_thresh, character_mcut_enabled,
|
764 |
+
characters_merge_enabled, llama3_reorganize_model_repo,
|
765 |
+
additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()
|
766 |
+
):
|
767 |
+
if input_type == 'Image':
|
768 |
+
return predictor.predict_from_images(
|
769 |
+
gallery, model_repo, general_thresh, general_mcut_enabled,
|
770 |
+
character_thresh, character_mcut_enabled, characters_merge_enabled,
|
771 |
+
llama3_reorganize_model_repo, additional_tags_prepend,
|
772 |
+
additional_tags_append, tag_results, progress
|
773 |
+
)
|
774 |
+
else: # 'Text file (.txt)'
|
775 |
+
# For text files, some parameters are not used, but we must return
|
776 |
+
# a tuple of the same size. `predict_from_text` handles this.
|
777 |
+
return predictor.predict_from_text(
|
778 |
+
text_files, llama3_reorganize_model_repo,
|
779 |
+
additional_tags_prepend, additional_tags_append, progress
|
780 |
+
)
|
781 |
+
|
782 |
+
with gr.Blocks(title=TITLE, css=css) as demo:
|
783 |
+
gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
|
784 |
gr.Markdown(value=DESCRIPTION)
|
785 |
+
|
786 |
with gr.Row():
|
787 |
with gr.Column():
|
788 |
submit = gr.Button(value="Submit", variant="primary", size="lg")
|
789 |
+
|
790 |
+
# Input type selector
|
791 |
+
input_type_radio = gr.Radio(
|
792 |
+
choices=['Image', 'Text file (.txt)'],
|
793 |
+
value='Image',
|
794 |
+
label="Input Type"
|
795 |
+
)
|
796 |
+
|
797 |
+
# Group for image inputs, initially visible
|
798 |
+
with gr.Column(visible=True) as image_inputs_group:
|
799 |
+
with gr.Column(variant="panel"):
|
800 |
+
# Create an Image component for uploading images
|
801 |
+
image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
|
802 |
+
with gr.Row():
|
803 |
+
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
804 |
+
remove_button = gr.Button("Remove Selected Image", size="sm")
|
805 |
+
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
|
806 |
+
|
807 |
+
# NEW: Group for text file inputs, initially hidden
|
808 |
+
with gr.Column(visible=False) as text_inputs_group:
|
809 |
+
text_files_input = gr.Files(
|
810 |
+
label="Upload .txt files",
|
811 |
+
file_types=[".txt"],
|
812 |
+
file_count="multiple",
|
813 |
+
height=500
|
814 |
+
)
|
815 |
|
816 |
+
# NEW: Logic to show/hide input groups based on radio selection
|
817 |
+
def change_input_type(input_type):
|
818 |
+
is_image = (input_type == 'Image')
|
819 |
+
return {
|
820 |
+
image_inputs_group: gr.update(visible=is_image),
|
821 |
+
text_inputs_group: gr.update(visible=not is_image),
|
822 |
+
# Also update visibility of image-specific settings
|
823 |
+
model_repo: gr.update(visible=is_image),
|
824 |
+
general_thresh_row: gr.update(visible=is_image),
|
825 |
+
character_thresh_row: gr.update(visible=is_image),
|
826 |
+
characters_merge_enabled: gr.update(visible=is_image),
|
827 |
+
categorized: gr.update(visible=is_image),
|
828 |
+
rating: gr.update(visible=is_image),
|
829 |
+
character_res: gr.update(visible=is_image),
|
830 |
+
general_res: gr.update(visible=is_image),
|
831 |
+
unclassified: gr.update(visible=is_image),
|
832 |
+
}
|
833 |
+
|
834 |
+
# Image-specific settings
|
835 |
model_repo = gr.Dropdown(
|
836 |
dropdown_list,
|
837 |
value=EVA02_LARGE_MODEL_DSV3_REPO,
|
838 |
+
label="Model (for Images)",
|
839 |
)
|
840 |
+
with gr.Row(visible=True) as general_thresh_row:
|
841 |
general_thresh = gr.Slider(
|
842 |
0,
|
843 |
1,
|
|
|
851 |
label="Use MCut threshold",
|
852 |
scale=1,
|
853 |
)
|
854 |
+
with gr.Row(visible=True) as character_thresh_row:
|
855 |
character_thresh = gr.Slider(
|
856 |
0,
|
857 |
1,
|
|
|
865 |
label="Use MCut threshold",
|
866 |
scale=1,
|
867 |
)
|
868 |
+
characters_merge_enabled = gr.Checkbox(
|
869 |
+
value=True,
|
870 |
+
label="Merge characters into the string output",
|
871 |
+
scale=1,
|
872 |
+
visible=True,
|
873 |
+
)
|
874 |
+
|
875 |
+
# Common settings
|
876 |
with gr.Row():
|
877 |
llama3_reorganize_model_repo = gr.Dropdown(
|
878 |
[None] + llama_list,
|
879 |
value=None,
|
880 |
+
label="Use the Llama3 model to reorganize the article",
|
881 |
+
info="(Note: very slow)",
|
882 |
)
|
883 |
with gr.Row():
|
884 |
additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
|
|
|
887 |
clear = gr.ClearButton(
|
888 |
components=[
|
889 |
gallery,
|
890 |
+
text_files_input,
|
891 |
model_repo,
|
892 |
general_thresh,
|
893 |
general_mcut_enabled,
|
|
|
901 |
variant="secondary",
|
902 |
size="lg",
|
903 |
)
|
904 |
+
|
905 |
with gr.Column(variant="panel"):
|
906 |
download_file = gr.File(label="Output (Download)")
|
907 |
+
sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True, lines=5)
|
908 |
+
# Image-specific outputs
|
909 |
+
categorized = gr.JSON(label="Categorized (tags)", visible=True)
|
910 |
+
rating = gr.Label(label="Rating", visible=True)
|
911 |
+
character_res = gr.Label(label="Output (characters)", visible=True)
|
912 |
+
general_res = gr.Label(label="Output (tags)", visible=True)
|
913 |
+
unclassified = gr.JSON(label="Unclassified (tags)", visible=True)
|
914 |
clear.add(
|
915 |
[
|
916 |
download_file,
|
|
|
922 |
unclassified,
|
923 |
]
|
924 |
)
|
925 |
+
|
926 |
tag_results = gr.State({})
|
927 |
+
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
928 |
+
|
929 |
+
# Event Listeners
|
930 |
# Define the event listener to add the uploaded image to the gallery
|
931 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
932 |
# When the upload button is clicked, add the new images to the gallery
|
933 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
934 |
# Event to update the selected image when an image is clicked in the gallery
|
|
|
935 |
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
|
936 |
# Event to remove a selected image from the gallery
|
937 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
938 |
+
|
939 |
+
# Connect the radio button to the visibility function
|
940 |
+
input_type_radio.change(
|
941 |
+
fn=change_input_type,
|
942 |
+
inputs=input_type_radio,
|
943 |
+
outputs=[
|
944 |
+
image_inputs_group, text_inputs_group, model_repo,
|
945 |
+
general_thresh_row, character_thresh_row, characters_merge_enabled,
|
946 |
+
categorized, rating, character_res, general_res, unclassified
|
947 |
+
]
|
948 |
+
)
|
949 |
|
950 |
+
# submit click now calls the wrapper function
|
951 |
+
submit.click(
|
952 |
+
fn=run_prediction,
|
953 |
+
inputs=[
|
954 |
+
input_type_radio,
|
955 |
+
gallery,
|
956 |
+
text_files_input,
|
957 |
+
model_repo,
|
958 |
+
general_thresh,
|
959 |
+
general_mcut_enabled,
|
960 |
+
character_thresh,
|
961 |
+
character_mcut_enabled,
|
962 |
+
characters_merge_enabled,
|
963 |
+
llama3_reorganize_model_repo,
|
964 |
+
additional_tags_prepend,
|
965 |
+
additional_tags_append,
|
966 |
+
tag_results,
|
967 |
+
],
|
968 |
+
outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],
|
969 |
+
)
|
970 |
|
971 |
gr.Examples(
|
972 |
[["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
|
webui.bat
CHANGED
@@ -1,21 +1,37 @@
|
|
1 |
@echo off
|
2 |
|
3 |
-
:: The source of the webui.bat file is stable-diffusion-webui
|
4 |
-
::
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
if not defined PYTHON (set PYTHON=python)
|
7 |
-
if not defined
|
|
|
8 |
|
9 |
mkdir tmp 2>NUL
|
10 |
|
|
|
11 |
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
|
12 |
if %ERRORLEVEL% == 0 goto :check_pip
|
13 |
echo Couldn't launch python
|
14 |
goto :show_stdout_stderr
|
15 |
|
16 |
:check_pip
|
|
|
17 |
%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
|
18 |
if %ERRORLEVEL% == 0 goto :start_venv
|
|
|
19 |
if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
|
20 |
%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
|
21 |
if %ERRORLEVEL% == 0 goto :start_venv
|
@@ -23,33 +39,106 @@ echo Couldn't install pip
|
|
23 |
goto :show_stdout_stderr
|
24 |
|
25 |
:start_venv
|
26 |
-
if
|
27 |
-
if ["%
|
|
|
|
|
28 |
|
|
|
29 |
dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
|
30 |
-
if %ERRORLEVEL% == 0 goto :
|
31 |
|
|
|
|
|
32 |
for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
|
33 |
echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
|
34 |
%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
|
35 |
-
if %ERRORLEVEL%
|
36 |
-
echo Unable to create venv in directory "%VENV_DIR%"
|
37 |
-
goto :show_stdout_stderr
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
|
41 |
-
echo venv %PYTHON%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
:
|
|
|
|
|
|
|
|
|
44 |
goto :launch
|
45 |
|
46 |
:launch
|
47 |
-
|
|
|
|
|
|
|
48 |
pause
|
49 |
exit /b
|
50 |
|
51 |
-
:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
|
|
|
|
|
|
53 |
echo.
|
54 |
echo exit code: %errorlevel%
|
55 |
|
@@ -61,13 +150,13 @@ type tmp\stdout.txt
|
|
61 |
|
62 |
:show_stderr
|
63 |
for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
|
64 |
-
if %size% equ 0 goto :
|
65 |
echo.
|
66 |
echo stderr:
|
67 |
type tmp\stderr.txt
|
68 |
|
69 |
:endofscript
|
70 |
-
|
71 |
echo.
|
72 |
echo Launch unsuccessful. Exiting.
|
73 |
pause
|
|
|
|
1 |
@echo off
|
2 |
|
3 |
+
:: The original source of the webui.bat file is stable-diffusion-webui
|
4 |
+
:: Modified and enhanced by Gemini with features for venv management and requirements handling.
|
5 |
|
6 |
+
:: --------- Configuration ---------
|
7 |
+
set COMMANDLINE_ARGS=
|
8 |
+
:: Define the name of the Launch application
|
9 |
+
set APPLICATION_NAME=app.py
|
10 |
+
:: Define the name of the virtual environment directory
|
11 |
+
set VENV_NAME=venv
|
12 |
+
:: Set to 1 to always attempt to update packages from requirements.txt on every launch
|
13 |
+
set ALWAYS_UPDATE_REQS=0
|
14 |
+
:: ---------------------------------
|
15 |
+
|
16 |
+
|
17 |
+
:: Set PYTHON executable if not already defined
|
18 |
if not defined PYTHON (set PYTHON=python)
|
19 |
+
:: Set VENV_DIR using VENV_NAME if not already defined
|
20 |
+
if not defined VENV_DIR (set "VENV_DIR=%~dp0%VENV_NAME%")
|
21 |
|
22 |
mkdir tmp 2>NUL
|
23 |
|
24 |
+
:: Check if Python is callable
|
25 |
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
|
26 |
if %ERRORLEVEL% == 0 goto :check_pip
|
27 |
echo Couldn't launch python
|
28 |
goto :show_stdout_stderr
|
29 |
|
30 |
:check_pip
|
31 |
+
:: Check if pip is available
|
32 |
%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
|
33 |
if %ERRORLEVEL% == 0 goto :start_venv
|
34 |
+
:: If pip is not available and PIP_INSTALLER_LOCATION is set, try to install pip
|
35 |
if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
|
36 |
%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
|
37 |
if %ERRORLEVEL% == 0 goto :start_venv
|
|
|
39 |
goto :show_stdout_stderr
|
40 |
|
41 |
:start_venv
|
42 |
+
:: Skip venv creation/activation if VENV_DIR is explicitly set to "-"
|
43 |
+
if ["%VENV_DIR%"] == ["-"] goto :skip_venv_entirely
|
44 |
+
:: Skip venv creation/activation if SKIP_VENV is set to "1"
|
45 |
+
if ["%SKIP_VENV%"] == ["1"] goto :skip_venv_entirely
|
46 |
|
47 |
+
:: Check if the venv already exists by looking for Python.exe in its Scripts directory
|
48 |
dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
|
49 |
+
if %ERRORLEVEL% == 0 goto :activate_venv_and_maybe_update
|
50 |
|
51 |
+
:: Venv does not exist, create it
|
52 |
+
echo Virtual environment not found in "%VENV_DIR%". Creating a new one.
|
53 |
for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
|
54 |
echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
|
55 |
%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
|
56 |
+
if %ERRORLEVEL% NEQ 0 (
|
57 |
+
echo Unable to create venv in directory "%VENV_DIR%"
|
58 |
+
goto :show_stdout_stderr
|
59 |
+
)
|
60 |
+
echo Venv created.
|
61 |
+
|
62 |
+
:: Install requirements for the first time if venv was just created
|
63 |
+
:: This section handles the initial installation of packages from requirements.txt
|
64 |
+
:: immediately after a new virtual environment is created.
|
65 |
+
echo Checking for requirements.txt for initial setup in %~dp0
|
66 |
+
if exist "%~dp0requirements.txt" (
|
67 |
+
echo Found requirements.txt, attempting to install for initial setup...
|
68 |
+
call "%VENV_DIR%\Scripts\activate.bat"
|
69 |
+
echo Installing packages from requirements.txt ^(initial setup^)...
|
70 |
+
"%VENV_DIR%\Scripts\python.exe" -m pip install -r "%~dp0requirements.txt"
|
71 |
+
if %ERRORLEVEL% NEQ 0 (
|
72 |
+
echo Failed to install requirements during initial setup. Please check the output above.
|
73 |
+
pause
|
74 |
+
goto :show_stdout_stderr_custom_pip_initial
|
75 |
+
)
|
76 |
+
echo Initial requirements installed successfully.
|
77 |
+
call "%VENV_DIR%\Scripts\deactivate.bat"
|
78 |
+
) else (
|
79 |
+
echo No requirements.txt found for initial setup, skipping package installation.
|
80 |
+
)
|
81 |
+
goto :activate_venv_and_maybe_update
|
82 |
+
|
83 |
+
|
84 |
+
:activate_venv_and_maybe_update
|
85 |
+
:: This label is reached if the venv exists or was just created.
|
86 |
+
:: Set PYTHON to point to the venv's Python interpreter.
|
87 |
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
|
88 |
+
echo Activating venv: %PYTHON%
|
89 |
+
|
90 |
+
:: Always update requirements if ALWAYS_UPDATE_REQS is 1
|
91 |
+
:: This section allows for updating packages from requirements.txt on every launch
|
92 |
+
:: if the ALWAYS_UPDATE_REQS variable is set to 1.
|
93 |
+
if defined ALWAYS_UPDATE_REQS (
|
94 |
+
if "%ALWAYS_UPDATE_REQS%"=="1" (
|
95 |
+
echo ALWAYS_UPDATE_REQS is enabled.
|
96 |
+
if exist "%~dp0requirements.txt" (
|
97 |
+
echo Attempting to update packages from requirements.txt...
|
98 |
+
REM No need to call activate.bat here again, PYTHON is already set to the venv's python
|
99 |
+
%PYTHON% -m pip install -r "%~dp0requirements.txt"
|
100 |
+
if %ERRORLEVEL% NEQ 0 (
|
101 |
+
echo Failed to update requirements. Please check the output above.
|
102 |
+
pause
|
103 |
+
goto :endofscript
|
104 |
+
)
|
105 |
+
echo Requirements updated successfully.
|
106 |
+
) else (
|
107 |
+
echo ALWAYS_UPDATE_REQS is enabled, but no requirements.txt found. Skipping update.
|
108 |
+
)
|
109 |
+
) else (
|
110 |
+
echo ALWAYS_UPDATE_REQS is not enabled or not set to 1. Skipping routine update.
|
111 |
+
)
|
112 |
+
)
|
113 |
|
114 |
+
goto :launch
|
115 |
+
|
116 |
+
:skip_venv_entirely
|
117 |
+
:: This label is reached if venv usage is explicitly skipped.
|
118 |
+
echo Skipping venv.
|
119 |
goto :launch
|
120 |
|
121 |
:launch
|
122 |
+
:: Launch the main application
|
123 |
+
echo Launching Web UI with arguments: %COMMANDLINE_ARGS% %*
|
124 |
+
%PYTHON% %APPLICATION_NAME% %COMMANDLINE_ARGS% %*
|
125 |
+
echo Launch finished.
|
126 |
pause
|
127 |
exit /b
|
128 |
|
129 |
+
:show_stdout_stderr_custom_pip_initial
|
130 |
+
:: Custom error handler for failures during the initial pip install process.
|
131 |
+
echo.
|
132 |
+
echo exit code ^(pip initial install^): %errorlevel%
|
133 |
+
echo Errors during initial pip install. See output above.
|
134 |
+
echo.
|
135 |
+
echo Launch unsuccessful. Exiting.
|
136 |
+
pause
|
137 |
+
exit /b
|
138 |
|
139 |
+
|
140 |
+
:show_stdout_stderr
|
141 |
+
:: General error handler: displays stdout and stderr from the tmp directory.
|
142 |
echo.
|
143 |
echo exit code: %errorlevel%
|
144 |
|
|
|
150 |
|
151 |
:show_stderr
|
152 |
for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
|
153 |
+
if %size% equ 0 goto :endofscript
|
154 |
echo.
|
155 |
echo stderr:
|
156 |
type tmp\stderr.txt
|
157 |
|
158 |
:endofscript
|
|
|
159 |
echo.
|
160 |
echo Launch unsuccessful. Exiting.
|
161 |
pause
|
162 |
+
exit /b
|