File size: 2,793 Bytes
d228e38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import argparse
from PIL import Image
from util import load_models, classify_image, segment_image, apply_mask

def process_images(input_folder, output_folder, pattern_image_path, head_image_path):
    """Classify, segment and mask every image file in ``input_folder``.

    Each directory entry whose name ends (case-insensitively) in .png, .jpg,
    .jpeg, .bmp or .webp is classified and segmented with the models returned
    by ``load_models``. Every mask produced for an image is applied in turn
    through ``apply_mask`` (together with the pattern and head images), and
    the result is written to ``output_folder`` as ``processed_<filename>``.
    Images for which no mask is available are not written.

    Args:
        input_folder: Directory scanned with ``os.listdir`` (not recursive).
        output_folder: Directory the results are written to. This function
            does not create it.
        pattern_image_path: Path of the pattern image; opened with PIL and
            handed to ``apply_mask`` as-is.
        head_image_path: Path of the head image; opened with PIL and
            converted to RGBA before being handed to ``apply_mask``.
    """
    classification_model, segmentation_model = load_models()

    pattern_image = Image.open(pattern_image_path)
    head_image = Image.open(head_image_path).convert("RGBA")

    # Hoisted so the tuple is not rebuilt for every directory entry.
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')

    for filename in os.listdir(input_folder):
        if not filename.lower().endswith(image_extensions):
            continue

        image_path = os.path.join(input_folder, filename)

        # The context manager releases the file handle after each image so
        # handles do not accumulate over a large batch.
        with Image.open(image_path) as image:
            # Classify the image
            category = classify_image(image, classification_model)
            category_name = classification_model.names[category[0]]

            # Segment the image
            segmentation_results = segment_image(image, segmentation_model)
            try:
                masks = segmentation_results[0].masks.data.cpu().numpy()
                class_ids = segmentation_results[0].boxes.cls.cpu().numpy().astype(int)
            except AttributeError:
                # Presumably `masks` is None when nothing was detected, which
                # makes the `.data` access fail -- verify against the
                # segmentation model's result type.
                if category_name in ['porn', 'hentai']:
                    print(f"Warning: {filename} is classified as sensitive content, but no mask was found.")
                continue

            # Nothing detected: leave the image unprocessed and unsaved.
            if len(class_ids) == 0:
                continue

            # All detected masks are applied; each call receives the output
            # of the previous one so the fills stack on a single copy.
            image_with_fill = image.copy()
            for mask in masks:
                image_with_fill = apply_mask(image_with_fill, mask, pattern_image, head_image)

            # Save the processed image
            # NOTE(review): the data is always PNG-encoded while the output
            # name keeps the input's extension (e.g. "processed_x.jpg" holds
            # PNG data). Left unchanged to avoid renaming outputs that other
            # tooling may rely on.
            output_path = os.path.join(output_folder, f"processed_{filename}")
            image_with_fill.save(output_path, format="PNG")
            print(f"Processed and saved: {output_path}")

if __name__ == "__main__":
    # Command-line entry point: parse the folder/asset arguments, make sure
    # the output folder exists, then run the batch.
    cli = argparse.ArgumentParser(description="Batch process images with YOLO models.")

    # Required positional arguments, registered in command-line order.
    for arg_name, arg_help in (
        ("input_folder", "Path to the input folder containing images."),
        ("output_folder", "Path to the output folder to save processed images."),
    ):
        cli.add_argument(arg_name, type=str, help=arg_help)

    # Optional asset overrides with their default locations.
    for flag, fallback, flag_help in (
        ("--pattern_image", "assets/pattern.png", "Path to the pattern image."),
        ("--head_image", "assets/head.png", "Path to the head image."),
    ):
        cli.add_argument(flag, type=str, default=fallback, help=flag_help)

    options = cli.parse_args()

    os.makedirs(options.output_folder, exist_ok=True)
    process_images(
        options.input_folder,
        options.output_folder,
        options.pattern_image,
        options.head_image,
    )