File size: 2,134 Bytes
779acf3
 
 
 
11cddad
779acf3
 
 
 
11cddad
 
 
f5e83aa
 
 
779acf3
93fd2ea
11cddad
779acf3
 
 
 
 
 
 
93fd2ea
 
 
 
779acf3
 
 
93fd2ea
779acf3
 
e3c9822
779acf3
 
93fd2ea
 
11cddad
93fd2ea
11cddad
93fd2ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import annotations

import PIL.Image
import torch
from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline


class Model:
    """Wrapper holding an Attend-and-Excite pipeline and a plain Stable Diffusion pipeline.

    Both pipelines are loaded from ``CompVis/stable-diffusion-v1-4``; :meth:`run`
    chooses between them per call via ``apply_attend_and_excite``.
    """

    def __init__(self):
        """Load both pipelines and move them to ``cuda:0`` if available, else CPU."""
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model_id = "CompVis/stable-diffusion-v1-4"
        self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id)
        self.ax_pipe.to(self.device)
        # NOTE(review): this loads a second full copy of the same weights. Sharing
        # components with ax_pipe would save memory, but is only safe if ax_pipe
        # never modifies shared modules -- verify before changing.
        self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
        self.sd_pipe.to(self.device)

    def get_token_table(self, prompt: str) -> list[tuple[int, str]]:
        """Return ``(index, token_text)`` pairs for the tokens of ``prompt``.

        The first and last tokenizer ids are dropped and numbering starts at 1.
        The intent is presumably that these indices match the positions the user
        passes back to :meth:`run` as ``indices_to_alter_str``, with position 0
        being the start token. Confirm this against the tokenizer.
        """
        tokenizer = self.ax_pipe.tokenizer
        tokens = [tokenizer.decode(t) for t in tokenizer(prompt)["input_ids"]]
        # Drop the first and last ids (presumably the start/end special tokens).
        tokens = tokens[1:-1]
        return list(enumerate(tokens, start=1))

    @staticmethod
    def _parse_token_indices(indices_to_alter_str: str) -> list[int]:
        """Parse a comma-separated string such as ``"2,5"`` into ``[2, 5]``.

        Whitespace around each entry is tolerated (``int`` strips it).

        Raises:
            ValueError: If the input is not a string of comma-separated integers
                (including empty entries such as ``""`` or ``"2,"``).
        """
        try:
            return [int(part) for part in indices_to_alter_str.split(",")]
        except (AttributeError, TypeError, ValueError) as err:
            # AttributeError/TypeError cover non-str input (e.g. None, bytes).
            raise ValueError("Invalid token indices.") from err

    def run(
        self,
        prompt: str,
        indices_to_alter_str: str,
        seed: int = 0,
        apply_attend_and_excite: bool = True,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> PIL.Image.Image:
        """Generate one image for ``prompt``.

        Args:
            prompt: Text prompt passed to the selected pipeline.
            indices_to_alter_str: Comma-separated token indices (e.g. ``"2,5"``),
                forwarded as ``token_indices``. Only used when
                ``apply_attend_and_excite`` is true.
            seed: Seed for the ``torch.Generator`` created on ``self.device``.
            apply_attend_and_excite: Use ``ax_pipe`` if true, else ``sd_pipe``.
            num_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded to either pipeline.
            scale_factor: Attend-and-Excite only; forwarded to ``ax_pipe``.
            thresholds: Attend-and-Excite only; forwarded to ``ax_pipe``.
                ``None`` (the default) means ``{10: 0.5, 20: 0.8}``.
            max_iter_to_alter: Attend-and-Excite only; forwarded to ``ax_pipe``.

        Returns:
            The first image in the pipeline output's ``images``.

        Raises:
            ValueError: If Attend-and-Excite is enabled and
                ``indices_to_alter_str`` cannot be parsed.
        """
        if thresholds is None:
            # Build a fresh dict per call. A dict-literal default would be a single
            # object shared by every call, so any mutation would leak between calls.
            thresholds = {10: 0.5, 20: 0.8}
        generator = torch.Generator(device=self.device).manual_seed(seed)

        if apply_attend_and_excite:
            token_indices = self._parse_token_indices(indices_to_alter_str)
            out = self.ax_pipe(
                prompt=prompt,
                token_indices=token_indices,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
                max_iter_to_alter=max_iter_to_alter,
                thresholds=thresholds,
                scale_factor=scale_factor,
            )
        else:
            out = self.sd_pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
            )
        return out.images[0]