ovedrive commited on
Commit
3bca564
·
1 Parent(s): 76b0872

make cpu compatible

Browse files
Files changed (2) hide show
  1. inference.py +5 -4
  2. requirements.txt +0 -1
inference.py CHANGED
@@ -18,6 +18,7 @@ class DiffusionInference:
18
  provider="hf-inference",
19
  api_key=self.api_key,
20
  )
 
21
 
22
  def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
23
  """
@@ -154,17 +155,17 @@ class DiffusionInference:
154
  if seed is not None:
155
  try:
156
  # Convert to integer and add to params
157
- generator = torch.Generator(device="cuda").manual_seed(seed)
158
  except (ValueError, TypeError):
159
  # Use random seed if conversion fails
160
  random_seed = random.randint(0, 3999999999) # Max 32-bit integer
161
- generator = torch.Generator(device="cuda").manual_seed(random_seed)
162
  print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
163
  else:
164
  # Generate random seed when none is provided
165
  random_seed = random.randint(0, 3999999999) # Max 32-bit integer
166
- generator = torch.Generator(device="cuda").manual_seed(random_seed)
167
  print(f"Using random seed: {random_seed}")
168
- pipeline = AutoPipelineForText2Image.from_pretrained(model_name, generator=generator, torch_dtype=torch.float16).to("cuda")
169
  image = pipeline(**kwargs).images[0]
170
  return image
 
18
  provider="hf-inference",
19
  api_key=self.api_key,
20
  )
21
+ self.device = torch.device("cuda" if torch.cuda else "cpu")
22
 
23
  def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
24
  """
 
155
  if seed is not None:
156
  try:
157
  # Convert to integer and add to params
158
+ generator = torch.Generator(device=self.device).manual_seed(seed)
159
  except (ValueError, TypeError):
160
  # Use random seed if conversion fails
161
  random_seed = random.randint(0, 3999999999) # Max 32-bit integer
162
+ generator = torch.Generator(device=self.device).manual_seed(random_seed)
163
  print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
164
  else:
165
  # Generate random seed when none is provided
166
  random_seed = random.randint(0, 3999999999) # Max 32-bit integer
167
+ generator = torch.Generator(device=self.device).manual_seed(random_seed)
168
  print(f"Using random seed: {random_seed}")
169
+ pipeline = AutoPipelineForText2Image.from_pretrained(model_name, generator=generator, torch_dtype=torch.float16).to(self.device)
170
  image = pipeline(**kwargs).images[0]
171
  return image
requirements.txt CHANGED
@@ -8,7 +8,6 @@ torch
8
  transformers
9
  diffusers
10
  spaces>=0.14.0
11
- xformers
12
  numpy
13
  accelerate
14
  sentencepiece
 
8
  transformers
9
  diffusers
10
  spaces>=0.14.0
 
11
  numpy
12
  accelerate
13
  sentencepiece