RamaManna commited on
Commit
7411c73
·
verified ·
1 Parent(s): 179df07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -28
app.py CHANGED
@@ -1,55 +1,90 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
 
5
- # We'll use the smallest available pipeline
6
- from diffusers import DiffusionPipeline
7
 
8
- # Load model (cached after first run)
9
  @gr.cache()
10
  def load_model():
11
- return DiffusionPipeline.from_pretrained(
12
- "OFA-Sys/small-stable-diffusion-v0",
13
  torch_dtype=torch.float16,
14
  safety_checker=None,
15
  use_safetensors=True
16
- ).to("cpu")
 
 
 
 
 
17
 
18
- def generate_character(description, seed=42):
19
  try:
20
  pipe = load_model()
 
21
 
22
- torch.manual_seed(seed)
23
  with torch.inference_mode():
24
  image = pipe(
25
- prompt=f"pixel art character, {description}",
26
- num_inference_steps=15,
27
- guidance_scale=7.0,
28
- width=256,
29
- height=256
 
 
30
  ).images[0]
31
 
32
  return image
33
  except Exception as e:
34
- return f"Error: {str(e)}\nTry a simpler description."
35
 
36
- with gr.Blocks(title="Lightweight Character Generator") as demo:
37
- gr.Markdown("# 🎨 Tiny Character Creator")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  with gr.Row():
40
- desc = gr.Textbox(
41
- label="Describe your character",
42
- placeholder="e.g., 'green alien with one eye'",
43
- max_lines=2
44
  )
45
 
46
- generate_btn = gr.Button("Generate Character")
47
- output = gr.Image(label="Your Character", shape=(256, 256))
 
48
 
49
- generate_btn.click(
50
- generate_character,
51
- inputs=desc,
52
- outputs=output
53
- )
 
54
 
55
- demo.launch(debug=False)
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler
4
  from PIL import Image
5
 
6
+ # Load a memory-efficient SD variant (under 12GB)
7
+ model_id = "runwayml/stable-diffusion-v1-5"
8
 
 
9
  @gr.cache()
10
  def load_model():
11
+ pipe = StableDiffusionPipeline.from_pretrained(
12
+ model_id,
13
  torch_dtype=torch.float16,
14
  safety_checker=None,
15
  use_safetensors=True
16
+ )
17
+ pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)
18
+ pipe = pipe.to("cpu")
19
+ pipe.enable_attention_slicing() # Reduces memory by 30%
20
+ pipe.enable_model_cpu_offload() # Only loads needed components
21
+ return pipe
22
 
23
+ def generate_character(prompt, seed=42):
24
  try:
25
  pipe = load_model()
26
+ generator = torch.Generator(device="cpu").manual_seed(seed)
27
 
 
28
  with torch.inference_mode():
29
  image = pipe(
30
+ prompt=f"cartoon character {prompt}, vibrant colors, clean lines",
31
+ negative_prompt="blurry, deformed, ugly",
32
+ num_inference_steps=20,
33
+ guidance_scale=7.5,
34
+ width=512,
35
+ height=512,
36
+ generator=generator
37
  ).images[0]
38
 
39
  return image
40
  except Exception as e:
41
+ return f"Error: {str(e)}\nTry simplifying your prompt."
42
 
43
+ # Animation through img2img
44
+ def generate_animation(prompt, frames=3):
45
+ base_image = generate_character(prompt)
46
+ if isinstance(base_image, str): # If error
47
+ return base_image
48
+
49
+ images = [base_image]
50
+ pipe = load_model()
51
+
52
+ for i in range(1, frames):
53
+ result = pipe(
54
+ prompt=prompt,
55
+ image=images[-1],
56
+ strength=0.3, # Small changes per frame
57
+ generator=torch.Generator().manual_seed(i)
58
+ )
59
+ images.append(result.images[0])
60
+
61
+ images[0].save(
62
+ "animation.gif",
63
+ save_all=True,
64
+ append_images=images[1:],
65
+ duration=500,
66
+ loop=0
67
+ )
68
+ return "animation.gif"
69
+
70
+ with gr.Blocks(theme=gr.themes.Base()) as demo:
71
+ gr.Markdown("# 🎬 Character Animator (12GB Optimized)")
72
 
73
  with gr.Row():
74
+ prompt = gr.Textbox(
75
+ label="Character Description",
76
+ placeholder="e.g. 'cyberpunk fox wearing sunglasses'"
 
77
  )
78
 
79
+ with gr.Tab("Single Image"):
80
+ img_out = gr.Image(label="Generated Character", type="pil")
81
+ gen_btn = gr.Button("Generate")
82
 
83
+ with gr.Tab("Animation"):
84
+ anim_out = gr.Image(label="Animation", format="gif")
85
+ anim_btn = gr.Button("Create Animation (3 frames)")
86
+
87
+ gen_btn.click(generate_character, inputs=prompt, outputs=img_out)
88
+ anim_btn.click(generate_animation, inputs=prompt, outputs=anim_out)
89
 
90
+ demo.launch()