# auto-hajimi-mosaic / batch_process.py
# Author: Frinkleko
# Commit: d228e38 ("init commit", verified)
import os
import argparse
from PIL import Image
from util import load_models, classify_image, segment_image, apply_mask
# File extensions (compared lower-case) that are treated as input images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')

# Classification labels for which "no mask found" is worth a warning.
_SENSITIVE_CATEGORIES = ('porn', 'hentai')


def _extract_masks(segmentation_results):
    """Return ``(masks, class_ids)`` taken from the first segmentation result.

    If the result exposes no usable masks/boxes (reading ``.masks.data`` or
    ``.boxes.cls`` raises ``AttributeError``; presumably ``masks`` is ``None``
    when nothing was detected -- confirm against ``segment_image``), two empty
    lists are returned instead.
    """
    first = segmentation_results[0]
    try:
        masks = first.masks.data.cpu().numpy()
        class_ids = first.boxes.cls.cpu().numpy().astype(int)
    except AttributeError:
        return [], []
    return masks, class_ids


def _output_filename(filename):
    """Return the file name used for the processed copy of *filename*.

    The processed image is always encoded as PNG, so the name must end in
    ``.png``. For non-PNG inputs the original extension is kept in the name
    (``a.jpg`` -> ``processed_a.jpg.png``) so that inputs differing only in
    extension do not overwrite each other.
    """
    name = f"processed_{filename}"
    if not filename.lower().endswith(".png"):
        name += ".png"
    return name


def process_images(input_folder, output_folder, pattern_image_path, head_image_path,
                   *, selected_classes=None):
    """Mask every supported image in *input_folder* and save the results.

    Each image is classified and segmented with the models returned by
    ``load_models()``. Every selected segmentation mask is filled via
    ``apply_mask`` using the pattern and head images. Images with at least one
    applied mask are saved to *output_folder* as PNG. Images without any
    applicable mask are reported and not written.

    Args:
        input_folder: Directory scanned (non-recursively, in sorted order) for
            files whose extension is in ``_IMAGE_EXTENSIONS``.
        output_folder: Destination directory. It must already exist; it is not
            created here.
        pattern_image_path: Path of the pattern image passed to ``apply_mask``.
        head_image_path: Path of the head image; converted to RGBA before use.
        selected_classes: Optional iterable of segmentation class names. Only
            masks whose class name (via ``segmentation_model.names``) is in it
            are applied. ``None`` (the default) applies every mask.
    """
    classification_model, segmentation_model = load_models()
    names = segmentation_model.names
    pattern_image = Image.open(pattern_image_path)
    head_image = Image.open(head_image_path).convert("RGBA")
    wanted = None if selected_classes is None else set(selected_classes)

    for filename in sorted(os.listdir(input_folder)):
        if not filename.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        image_path = os.path.join(input_folder, filename)

        # The context manager releases the file handle even if a model call raises.
        with Image.open(image_path) as image:
            # Classify the image; the label is only used to flag suspicious misses.
            category = classify_image(image, classification_model)
            category_name = classification_model.names[category[0]]

            # Segment the image
            masks, class_ids = _extract_masks(segment_image(image, segmentation_model))
            if len(class_ids) == 0 and category_name in _SENSITIVE_CATEGORIES:
                print(f"Warning: {filename} is classified as sensitive content, but no mask was found.")

            selected = [
                mask
                for mask, class_id in zip(masks, class_ids)
                if wanted is None or names[class_id] in wanted
            ]
            if not selected:
                print(f"Skipped (no masks to apply): {image_path}")
                continue

            # Work on a copy so the opened source image is never modified.
            image_with_fill = image.copy()
            for mask in selected:
                image_with_fill = apply_mask(image_with_fill, mask, pattern_image, head_image)

        # Save the processed image (always PNG-encoded; the name matches that).
        output_path = os.path.join(output_folder, _output_filename(filename))
        image_with_fill.save(output_path, format="PNG")
        print(f"Processed and saved: {output_path}")
if __name__ == "__main__":
    # Command-line entry point: parse arguments, ensure the output folder
    # exists, then hand everything to process_images().
    cli = argparse.ArgumentParser(description="Batch process images with YOLO models.")

    # Positional arguments: where images are read from and written to.
    cli.add_argument(
        "input_folder",
        type=str,
        help="Path to the input folder containing images.",
    )
    cli.add_argument(
        "output_folder",
        type=str,
        help="Path to the output folder to save processed images.",
    )

    # Optional overlay assets; relative defaults resolve against the current
    # working directory.
    cli.add_argument(
        "--pattern_image",
        type=str,
        default="assets/pattern.png",
        help="Path to the pattern image.",
    )
    cli.add_argument(
        "--head_image",
        type=str,
        default="assets/head.png",
        help="Path to the head image.",
    )

    opts = cli.parse_args()

    # process_images() saves into this folder but does not create it.
    os.makedirs(opts.output_folder, exist_ok=True)
    process_images(
        input_folder=opts.input_folder,
        output_folder=opts.output_folder,
        pattern_image_path=opts.pattern_image,
        head_image_path=opts.head_image,
    )