File size: 4,054 Bytes
2c50826
 
 
 
 
 
 
 
 
 
 
 
 
 
199a7d9
2c50826
 
199a7d9
2c50826
 
199a7d9
2c50826
199a7d9
2c50826
 
 
199a7d9
2c50826
 
 
 
 
 
 
199a7d9
 
 
 
2c50826
 
 
 
199a7d9
2c50826
 
199a7d9
2c50826
 
199a7d9
2c50826
 
 
199a7d9
2c50826
199a7d9
2c50826
199a7d9
2c50826
 
 
 
 
 
 
199a7d9
2c50826
 
 
 
 
 
 
199a7d9
 
 
 
2c50826
199a7d9
2c50826
 
199a7d9
2c50826
 
199a7d9
2c50826
 
 
199a7d9
2c50826
199a7d9
2c50826
199a7d9
2c50826
 
 
 
199a7d9
2c50826
 
 
 
 
 
199a7d9
 
 
2c50826
 
199a7d9
2c50826
199a7d9
2c50826
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import json
from pathlib import Path
from typing import List

from tqdm import tqdm

from api import create_api
from benchmark import create_benchmark


def _load_metadata(metadata_file: Path) -> dict:
    """Read a JSON Lines metadata file into a dict keyed by each entry's "filepath".

    Returns an empty dict when the file does not exist.
    """
    entries = {}
    if not metadata_file.exists():
        return entries

    with open(metadata_file, "r") as f:
        for line in f:
            # Tolerate blank lines (e.g. a stray trailing newline) rather than
            # letting json.loads raise on an empty string.
            if not line.strip():
                continue
            entry = json.loads(line)
            entries[entry["filepath"]] = entry
    return entries


def _save_metadata(metadata_file: Path, entries: dict) -> None:
    """Overwrite the metadata file with one JSON object per line."""
    with open(metadata_file, "w") as f:
        f.writelines(json.dumps(entry) + "\n" for entry in entries.values())


def generate_images(api_type: str, benchmarks: List[str]):
    """Generate the missing images for each benchmark and record their metadata.

    Images are written under ``images/<api_type>/<benchmark_type>/``. For every
    benchmark a ``metadata.jsonl`` file in that directory holds one entry per
    generated image with its ``filepath``, ``prompt`` and ``inference_time``
    (the value returned by ``api.generate_image``). Images that already exist
    on disk are skipped, so the function can be re-run to resume earlier work.

    Layout per benchmark type:

    * ``"geneval"``: the benchmark yields ``(metadata, folder_name)`` pairs and
      the image goes to ``<folder_name>/samples/0000.png``; the metadata key is
      the full image path.
    * anything else: the benchmark yields ``(prompt, image_path)`` pairs and
      the image goes to ``<image_path>``; the metadata key is that relative
      path.

    Args:
        api_type: Name passed to ``create_api``; also used as the output
            sub-directory name.
        benchmarks: Benchmark names passed one by one to ``create_benchmark``.

    Errors raised while generating a single image are printed and that image
    is skipped; they do not abort the run.
    """
    images_dir = Path("images")
    api = create_api(api_type)

    api_dir = images_dir / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        print(f"\nProcessing benchmark: {benchmark_type}")

        benchmark = create_benchmark(benchmark_type)
        is_geneval = benchmark_type == "geneval"

        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)

        metadata_file = benchmark_dir / "metadata.jsonl"
        existing_metadata = _load_metadata(metadata_file)

        try:
            for item, location in tqdm(
                benchmark, desc=f"Generating images for {benchmark_type}", leave=False
            ):
                if is_geneval:
                    # geneval items are dicts carrying the prompt; each prompt
                    # gets its own folder with a "samples" sub-directory.
                    prompt = item["prompt"]
                    samples_path = benchmark_dir / location / "samples"
                    samples_path.mkdir(parents=True, exist_ok=True)
                    target_path = samples_path / "0000.png"
                    record_key = str(target_path)
                else:
                    prompt = item
                    target_path = benchmark_dir / location
                    record_key = str(location)

                if target_path.exists():
                    continue

                try:
                    inference_time = api.generate_image(prompt, target_path)
                except Exception as e:
                    # Best effort: report the failure and move on to the next prompt.
                    print(f"\nError generating image for prompt: {prompt}")
                    print(f"Error: {str(e)}")
                    continue

                existing_metadata[record_key] = {
                    "filepath": record_key,
                    "prompt": prompt,
                    "inference_time": inference_time,
                }
        finally:
            # Persist metadata for every benchmark type (previously the geneval
            # branch never wrote it), and do so even if the run is interrupted:
            # finished images are skipped on the next run, so their timings
            # could not be recovered later.
            _save_metadata(metadata_file, existing_metadata)


def main():
    """Command-line entry point: parse the API type and benchmark names, then run generation."""
    cli = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )
    # One required API name followed by one or more benchmark names.
    cli.add_argument("api_type", help="Type of API to use for image generation")
    cli.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")

    options = cli.parse_args()
    generate_images(api_type=options.api_type, benchmarks=options.benchmarks)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()