Files changed (1)
  1. app.py +43 -21
app.py CHANGED
@@ -4,34 +4,56 @@ import torch
4
  from PIL import Image
5
  import io
6
  import os
 
7
 
8
- # Force CPU usage
 
 
 
 
9
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
10
 
11
  @st.cache_resource
12
  def load_model():
 
 
13
  pipe = AutoPipelineForText2Image.from_pretrained(
14
- "stabilityai/sd-turbo",
15
- torch_dtype=torch.float32 # CPU-compatible
16
  )
17
  pipe.to("cpu")
18
  return pipe
19
 
20
- st.title("⚑ Fast AI Image Generator (under 1 minute)")
21
-
22
- prompt = st.text_input("Enter your prompt:",
23
- "A glowing alien city with floating islands and neon rivers, concept art, 8K")
24
-
25
- guidance = st.slider("Guidance scale (higher = more faithful to prompt)", 1.0, 10.0, 3.0)
26
-
27
- if st.button("Generate Image"):
28
- with st.spinner("Generating image (approx. 20–40 seconds on CPU)..."):
29
- pipe = load_model()
30
- result = pipe(prompt, guidance_scale=guidance, num_inference_steps=20)
31
- image = result.images[0]
32
-
33
- st.image(image, caption="Generated Image", use_column_width=True)
34
-
35
- buf = io.BytesIO()
36
- image.save(buf, format="PNG")
37
- st.download_button("Download Image", buf.getvalue(), "generated.png", "image/png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
  import io
6
  import os
7
+ import requests
8
 
9
+ # --- CONFIG ---
10
+ USE_GROQ = False # Set to True when Groq image API is available
11
+ GROQ_API_URL = "https://your-groq-image-api.com/generate" # Placeholder
12
+
13
+ # Force CPU
14
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
15
 
16
  @st.cache_resource
17
  def load_model():
18
+ if USE_GROQ:
19
+ return None # Skip local model
20
  pipe = AutoPipelineForText2Image.from_pretrained(
21
+ "stabilityai/sd-turbo",
22
+ torch_dtype=torch.float32
23
  )
24
  pipe.to("cpu")
25
  return pipe
26
 
27
+ def generate_image_local(prompt, guidance_scale):
28
+ pipe = load_model()
29
+ result = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=20)
30
+ return result.images[0]
31
+
32
+ def generate_image_from_groq(prompt):
33
+ response = requests.post(GROQ_API_URL, json={"prompt": prompt})
34
+ if response.status_code == 200:
35
+ image_bytes = io.BytesIO(response.content)
36
+ return Image.open(image_bytes)
37
+ else:
38
+ raise Exception(f"GROQ API failed: {response.text}")
39
+
40
+ # UI
41
+ st.title("🧠 AI Image Generator (Fast with API / Groq-ready)")
42
+
43
+ prompt = st.text_input("Prompt:", "A glowing alien forest with floating orbs, concept art, 8K")
44
+ guidance = st.slider("Guidance scale", 1.0, 10.0, 3.0)
45
+
46
+ if st.button("Generate"):
47
+ with st.spinner("Generating..."):
48
+ try:
49
+ if USE_GROQ:
50
+ image = generate_image_from_groq(prompt)
51
+ else:
52
+ image = generate_image_local(prompt, guidance)
53
+
54
+ st.image(image, caption="Generated Image", use_column_width=True)
55
+ buf = io.BytesIO()
56
+ image.save(buf, format="PNG")
57
+ st.download_button("Download Image", buf.getvalue(), "generated.png", "image/png")
58
+ except Exception as e:
59
+ st.error(f"Error: {e}")