pradeep6kumar2024 committed on
Commit
90ec6ed
·
1 Parent(s): ef7a9a3

updated for cpu

Browse files
Files changed (2) hide show
  1. README.md +5 -0
  2. app.py +9 -4
README.md CHANGED
@@ -7,6 +7,11 @@ sdk: gradio
7
  sdk_version: "3.50.2"
8
  app_file: app.py
9
  pinned: false
 
 
 
 
 
10
  ---
11
 
12
  # Style-Guided Image Generation with Purple Enhancement
 
7
  sdk_version: "3.50.2"
8
  app_file: app.py
9
  pinned: false
10
+ hardware: true
11
+ resources:
12
+ cpu: 1
13
+ memory: "16Gi"
14
+ gpu: 1
15
  ---
16
 
17
  # Style-Guided Image Generation with Purple Enhancement
app.py CHANGED
@@ -40,10 +40,14 @@ styles = {
40
 
41
  def load_pipeline():
42
  """Load and prepare the pipeline with all style embeddings"""
 
 
 
 
43
  pipe = StableDiffusionPipeline.from_pretrained(
44
  "CompVis/stable-diffusion-v1-4",
45
- torch_dtype=torch.float16
46
- ).to("cuda")
47
 
48
  # Load all embeddings
49
  for style_info in styles.values():
@@ -71,8 +75,9 @@ def generate_image(prompt, style, seed, apply_guidance, guidance_strength=0.5):
71
  # Get style info
72
  style_info = styles[style]
73
 
74
- # Prepare generator
75
- generator = torch.Generator("cuda").manual_seed(int(seed))
 
76
 
77
  # Create styled prompt
78
  styled_prompt = f"{prompt} {style_info['token']}"
 
40
 
41
  def load_pipeline():
42
  """Load and prepare the pipeline with all style embeddings"""
43
+ # Check if CUDA is available
44
+ device = "cuda" if torch.cuda.is_available() else "cpu"
45
+ dtype = torch.float16 if device == "cuda" else torch.float32
46
+
47
  pipe = StableDiffusionPipeline.from_pretrained(
48
  "CompVis/stable-diffusion-v1-4",
49
+ torch_dtype=dtype
50
+ ).to(device)
51
 
52
  # Load all embeddings
53
  for style_info in styles.values():
 
75
  # Get style info
76
  style_info = styles[style]
77
 
78
+ # Prepare generator with appropriate device
79
+ device = "cuda" if torch.cuda.is_available() else "cpu"
80
+ generator = torch.Generator(device).manual_seed(int(seed))
81
 
82
  # Create styled prompt
83
  styled_prompt = f"{prompt} {style_info['token']}"