Abe committed on
Commit
7ddc847
·
1 Parent(s): 8247a04

t2i ok, i2i bork

Browse files
Files changed (2) hide show
  1. app.py +31 -9
  2. inference.py +69 -22
app.py CHANGED
@@ -24,31 +24,53 @@ def text_to_image_fn(prompt, model, negative_prompt=None, guidance_scale=7.5, nu
24
  num_inference_steps=num_inference_steps
25
  )
26
 
 
 
 
27
  return image, None
28
  except Exception as e:
29
- return None, str(e)
 
 
30
 
31
  def image_to_image_fn(image, prompt, model, negative_prompt=None, guidance_scale=7.5, num_inference_steps=50):
32
  """
33
  Handle image to image transformation request
34
  """
 
 
 
 
 
 
 
 
 
 
 
 
35
  try:
36
- if not model:
37
- model = config.DEFAULT_IMG2IMG_MODEL
38
-
39
- # Call the inference module
40
  result = inference.image_to_image(
41
  image=image,
42
  prompt=prompt,
43
  model_name=model,
44
- negative_prompt=negative_prompt,
45
- guidance_scale=guidance_scale,
46
- num_inference_steps=num_inference_steps
47
  )
48
 
 
 
 
49
  return result, None
50
  except Exception as e:
51
- return None, str(e)
 
 
 
 
 
52
 
53
  # Create Gradio UI
54
  with gr.Blocks(title="Diffusion Models") as app:
 
24
  num_inference_steps=num_inference_steps
25
  )
26
 
27
+ if image is None:
28
+ return None, "No image was generated. Check the model and parameters."
29
+
30
  return image, None
31
  except Exception as e:
32
+ error_msg = f"Error: {str(e)}"
33
+ print(error_msg)
34
+ return None, error_msg
35
 
36
  def image_to_image_fn(image, prompt, model, negative_prompt=None, guidance_scale=7.5, num_inference_steps=50):
37
  """
38
  Handle image to image transformation request
39
  """
40
+ if image is None:
41
+ return None, "No input image provided."
42
+
43
+ if not prompt:
44
+ prompt = ""
45
+
46
+ if not model:
47
+ model = config.DEFAULT_IMG2IMG_MODEL
48
+
49
+ print(f"Input type: {type(image)}")
50
+ print(f"Processing image-to-image with prompt: '{prompt}', model: {model}")
51
+
52
  try:
53
+ # Call the inference module with explicit parameters
 
 
 
54
  result = inference.image_to_image(
55
  image=image,
56
  prompt=prompt,
57
  model_name=model,
58
+ negative_prompt=negative_prompt if negative_prompt else None,
59
+ guidance_scale=float(guidance_scale),
60
+ num_inference_steps=int(num_inference_steps)
61
  )
62
 
63
+ if result is None:
64
+ return None, "No image was generated. Check the model and parameters."
65
+
66
  return result, None
67
  except Exception as e:
68
+ error_msg = f"Error: {str(e)}"
69
+ print(error_msg)
70
+ print(f"Input image type: {type(image)}")
71
+ print(f"Prompt: {prompt}")
72
+ print(f"Model: {model}")
73
+ return None, error_msg
74
 
75
  # Create Gradio UI
76
  with gr.Blocks(title="Diffusion Models") as app:
inference.py CHANGED
@@ -30,20 +30,29 @@ class DiffusionInference:
30
  """
31
  model = model_name or config.DEFAULT_TEXT2IMG_MODEL
32
 
33
- # Set up parameters dictionary
34
- params = {"prompt": prompt}
 
 
 
35
 
36
- if negative_prompt:
 
37
  params["negative_prompt"] = negative_prompt
38
-
39
- # Add any additional parameters
40
- params.update(kwargs)
 
 
41
 
42
  try:
43
- image = self.client.text_to_image(model=model, **params)
 
44
  return image
45
  except Exception as e:
46
  print(f"Error generating image: {e}")
 
 
47
  raise
48
 
49
  def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
@@ -60,27 +69,65 @@ class DiffusionInference:
60
  Returns:
61
  PIL.Image: The generated image
62
  """
 
 
 
63
  model = model_name or config.DEFAULT_IMG2IMG_MODEL
64
 
65
- # Convert image path to PIL Image if needed
66
- if isinstance(image, str):
67
- image = Image.open(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- # Set up parameters dictionary
70
- params = {"image": image}
71
-
72
- if prompt:
73
- params["prompt"] = prompt
74
 
75
- if negative_prompt:
76
- params["negative_prompt"] = negative_prompt
 
 
 
 
 
 
 
 
 
 
77
 
78
- # Add any additional parameters
79
- params.update(kwargs)
80
-
81
- try:
82
- result = self.client.image_to_image(model=model, **params)
83
  return result
 
84
  except Exception as e:
85
  print(f"Error transforming image: {e}")
 
 
86
  raise
 
 
 
 
 
 
 
 
 
30
  """
31
  model = model_name or config.DEFAULT_TEXT2IMG_MODEL
32
 
33
+ # Create parameters dictionary for all keyword arguments
34
+ params = {
35
+ "prompt": prompt,
36
+ "model": model
37
+ }
38
 
39
+ # Add negative prompt if provided
40
+ if negative_prompt is not None:
41
  params["negative_prompt"] = negative_prompt
42
+
43
+ # Add any other parameters
44
+ for k, v in kwargs.items():
45
+ if k not in ["prompt", "model", "negative_prompt"]:
46
+ params[k] = v
47
 
48
  try:
49
+ # Call the API with all parameters as kwargs
50
+ image = self.client.text_to_image(**params)
51
  return image
52
  except Exception as e:
53
  print(f"Error generating image: {e}")
54
+ print(f"Model: {model}")
55
+ print(f"Prompt: {prompt}")
56
  raise
57
 
58
  def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
 
69
  Returns:
70
  PIL.Image: The generated image
71
  """
72
+ import tempfile
73
+ import os
74
+
75
  model = model_name or config.DEFAULT_IMG2IMG_MODEL
76
 
77
+ # Create a temporary file for the image if it's a PIL Image
78
+ temp_file = None
79
+ try:
80
+ # Handle different image input types
81
+ if isinstance(image, str):
82
+ # If it's already a file path, use it directly
83
+ image_path = image
84
+ elif isinstance(image, Image.Image):
85
+ # If it's a PIL Image, save it to a temporary file
86
+ temp_dir = tempfile.gettempdir()
87
+ temp_file = os.path.join(temp_dir, "temp_image.png")
88
+ image.save(temp_file, format="PNG")
89
+ image_path = temp_file
90
+ else:
91
+ # If it's something else, try to convert it to a PIL Image first
92
+ try:
93
+ pil_image = Image.fromarray(image)
94
+ temp_dir = tempfile.gettempdir()
95
+ temp_file = os.path.join(temp_dir, "temp_image.png")
96
+ pil_image.save(temp_file, format="PNG")
97
+ image_path = temp_file
98
+ except Exception as e:
99
+ raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}")
100
 
101
+ # Create a parameters dictionary including all the required keyword args
102
+ params = {"model": model}
 
 
 
103
 
104
+ # Add prompt if provided (MUST be as a keyword arg, not positional)
105
+ if prompt is not None:
106
+ params["prompt"] = prompt
107
+
108
+ # Add negative_prompt if provided
109
+ if negative_prompt is not None:
110
+ params["negative_prompt"] = negative_prompt
111
+
112
+ # Add additional parameters
113
+ for k, v in kwargs.items():
114
+ if k not in ["prompt", "model", "negative_prompt", "image"]:
115
+ params[k] = v
116
 
117
+ # Make the API call with image as the only positional arg, all others as kwargs
118
+ result = self.client.image_to_image(image_path, **params)
 
 
 
119
  return result
120
+
121
  except Exception as e:
122
  print(f"Error transforming image: {e}")
123
+ print(f"Model: {model}")
124
+ print(f"Prompt: {prompt}")
125
  raise
126
+
127
+ finally:
128
+ # Clean up the temporary file if it was created
129
+ if temp_file and os.path.exists(temp_file):
130
+ try:
131
+ os.remove(temp_file)
132
+ except Exception as e:
133
+ print(f"Warning: Could not delete temporary file {temp_file}: {e}")