b2bomber commited on
Commit
2d37af9
·
verified ·
1 Parent(s): 6c698ce

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
4
+ from PIL import Image
5
+
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+ # Load SD model (use SD1.5 or SDXL-based)
9
+ model_id = "stabilityai/stable-diffusion-2-1"
10
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
11
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
12
+ pipe = pipe.to(device)
13
+
14
+ # Preset styles
15
+ styles = {
16
+ "Pixar": "pixar style portrait of",
17
+ "Anime": "anime style portrait of",
18
+ "Cyberpunk": "cyberpunk futuristic avatar of",
19
+ "Disney": "disney movie character of",
20
+ "Sketch": "pencil sketch portrait of",
21
+ "Astronaut": "realistic astronaut with helmet, portrait of"
22
+ }
23
+
24
+ def generate_avatar(image, style):
25
+ if image is None:
26
+ return None
27
+
28
+ # Preprocess image (convert to prompt-only for simplicity)
29
+ base_prompt = styles[style]
30
+ prompt = f"{base_prompt} a person"
31
+
32
+ image = pipe(prompt=prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
33
+ return image
34
+
35
+ with gr.Blocks() as demo:
36
+ gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles")
37
+
38
+ with gr.Row():
39
+ with gr.Column():
40
+ image_input = gr.Image(label="Upload your photo", type="pil", sources=["upload", "webcam"])
41
+ style_selector = gr.Radio(choices=list(styles.keys()), label="Choose a style", value="Anime")
42
+ generate_btn = gr.Button("Generate Avatar")
43
+ with gr.Column():
44
+ output_image = gr.Image(label="Generated Avatar")
45
+
46
+ generate_btn.click(fn=generate_avatar, inputs=[image_input, style_selector], outputs=output_image)
47
+
48
+ demo.launch()