Manireddy1508 committed on
Commit 37abfc2 · verified · 1 Parent(s): 2e17b63

Update app.py

using base model for all products

Files changed (1)
  1. app.py +54 -60
app.py CHANGED
@@ -2,97 +2,91 @@
 
 import gradio as gr
 from PIL import Image
-import base64
-import requests
 import os
-from io import BytesIO
-
+import torch
+from diffusers import StableDiffusionXLInpaintPipeline
 from utils.planner import extract_scene_plan  # 🧠 Brain Layer
 
-# 🔐 Hugging Face keys
-HF_API_KEY = os.getenv("HF_API_KEY")
-SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"  # ✅ Correct model for image-to-image
-SDXL_API_URL = f"https://api-inference.huggingface.co/models/{SDXL_MODEL_ID}"
-HEADERS = {"Authorization": f"Bearer {HF_API_KEY}"}
-
-# 🚀 Image generation (img2img)
-def process_image(prompt, image, num_variations):
+# ----------------------------
+# 🔧 Device Setup
+# ----------------------------
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if device == "cuda" else torch.float32
+
+# ----------------------------
+# 📦 Load Inpainting Model
+# ----------------------------
+pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+    "diffusers/stable-diffusion-xl-1.0-inpainting",
+    torch_dtype=dtype,
+    variant="fp16" if device == "cuda" else None
+).to(device)
+
+pipe.enable_attention_slicing()
+pipe.enable_model_cpu_offload()
+
+# ----------------------------
+# 🎨 Image Generation Function
+# ----------------------------
+def process_image(prompt, image, mask, num_variations):
     try:
         print("🧠 Prompt received:", prompt)
-
-        # Step 1: Brain Layer
+
+        # 🧠 Step 1: Brain Layer
         reasoning_json = extract_scene_plan(prompt)
         print("🧠 Scene plan extracted:", reasoning_json)
 
-        # Step 2: Encode input image
-        buffered = BytesIO()
-        image.save(buffered, format="JPEG")
-        img_bytes = buffered.getvalue()
-        encoded_image = base64.b64encode(img_bytes).decode("utf-8")
+        # Resize inputs to 1024x1024 (required for SDXL)
+        image = image.resize((1024, 1024)).convert("RGB")
+        mask = mask.resize((1024, 1024)).convert("L")
 
-        # Step 3: Send image + prompt to HF API
-        outputs = []
+        results = []
         for i in range(num_variations):
-            payload = {
-                "image": encoded_image,
-                "prompt": prompt,
-                "negative_prompt": "blurry, deformed, cropped",
-                "strength": 25,
-                "guidance_scale": 7.5
-            }
-
-            print(f"📤 Sending request to HF (variation {i+1})")
-            response = requests.post(SDXL_API_URL, headers=HEADERS, json=payload)
-
-            if response.status_code == 200:
-                try:
-                    result_json = response.json()
-                    if "images" in result_json:
-                        base64_img = result_json["images"][0]
-                        result_image = Image.open(BytesIO(base64.b64decode(base64_img)))
-                        outputs.append(result_image)
-                        print(f"✅ Decoded image variation {i+1} successfully")
-                    else:
-                        print(f"⚠️ No 'images' key found in response")
-                        outputs.append("❌ No image in response.")
-                except Exception as decode_err:
-                    print("❌ Image decode error:", decode_err)
-                    outputs.append("❌ Failed to decode image.")
-            else:
-                print(f"❌ HF API error: {response.status_code} - {response.text}")
-                outputs.append(f"Error {response.status_code}: {response.text}")
-
-        return outputs, reasoning_json
+            print(f"🎨 Generating variation {i + 1}...")
+            output = pipe(
+                prompt=prompt,
+                image=image,
+                mask_image=mask,
+                strength=0.98,
+                guidance_scale=7.5,
+                num_inference_steps=40
+            ).images[0]
+            results.append(output)
+
+        return results, reasoning_json
 
     except Exception as e:
-        print("❌ General Exception in process_image:", e)
-        return ["Processing error occurred"], {"error": str(e)}
+        print("❌ Error during generation:", e)
+        return ["❌ Generation failed"], {"error": str(e)}
 
-# 🎨 Gradio UI
+# ----------------------------
+# 🖼️ Gradio UI
+# ----------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("# 🧠 NewCrux AI Demo: Image-to-Image using Fast SDXL + Brain Layer")
+    gr.Markdown("## 🧠 NewCrux Inpainting Demo (SDXL)\nUpload a product image, a mask, and a prompt to generate realistic content.")
 
     with gr.Row():
         with gr.Column():
-            prompt_input = gr.Textbox(label="Enter Prompt")
+            prompt_input = gr.Textbox(label="Prompt")
            image_input = gr.Image(type="pil", label="Upload Product Image")
+            mask_input = gr.Image(type="pil", label="Upload Mask (white = keep, black = replace)")
             variation_slider = gr.Slider(1, 4, step=1, value=1, label="Number of Variations")
             generate_btn = gr.Button("Generate")
 
         with gr.Column():
             output_gallery = gr.Gallery(
-                label="Generated Image Variations",
+                label="Generated Variations",
                 columns=2,
                 rows=2,
                 height="auto"
             )
-            json_output = gr.JSON(label="Brain Layer Reasoning (Scene Plan)")
+            json_output = gr.JSON(label="🧠 Brain Layer Reasoning")
 
     generate_btn.click(
         fn=process_image,
-        inputs=[prompt_input, image_input, variation_slider],
+        inputs=[prompt_input, image_input, mask_input, variation_slider],
         outputs=[output_gallery, json_output]
     )
 
-demo.launch(share=True)
+demo.launch()
 
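
The reworked demo expects a mask to be uploaded alongside the product image. Below is a minimal sketch (not part of this commit) for preparing such a mask with PIL; the file name and rectangle coordinates are placeholder assumptions, and note that diffusers inpainting pipelines repaint the white area of the mask while preserving the black area.

```python
# Sketch: build a mask for the new "Upload Mask" input.
# Assumes a local product photo named "product.png" whose product sits roughly
# inside the rectangle below (both are illustrative placeholders).
from PIL import Image, ImageDraw

product = Image.open("product.png").convert("RGB")

# Start fully white: by default the whole scene would be regenerated.
mask = Image.new("L", product.size, 255)

# Paint the product's bounding box black so the product itself is kept.
draw = ImageDraw.Draw(mask)
draw.rectangle((200, 150, 800, 900), fill=0)

mask.save("mask.png")  # upload alongside product.png in the Gradio demo
```

Since process_image resizes both inputs to 1024x1024, the mask does not have to match the product photo's resolution exactly, though keeping the same aspect ratio avoids distorting the protected region.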