lionelgarnier commited on
Commit
b6b421e
·
1 Parent(s): b0c8c02

bugfix cursor gpu

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -11,24 +11,23 @@ from huggingface_hub import login
11
  hf_token = os.getenv("hf_token")
12
  login(token=hf_token)
13
 
14
-
15
- dtype = torch.bfloat16
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
-
18
  MAX_SEED = np.iinfo(np.int32).max
19
  MAX_IMAGE_SIZE = 2048
20
 
21
  _text_gen_pipeline = None
 
22
  def get_text_gen_pipeline():
23
  global _text_gen_pipeline
24
  if _text_gen_pipeline is None:
25
  try:
 
26
  _text_gen_pipeline = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", max_new_tokens=2048, device=device)
27
  except Exception as e:
28
  print(f"Error loading text generation model: {e}")
29
  return None
30
  return _text_gen_pipeline
31
 
 
32
  def refine_prompt(prompt):
33
  text_gen = get_text_gen_pipeline()
34
  if text_gen is None:
@@ -53,6 +52,10 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
53
  try:
54
  progress(0, desc="Starting generation...")
55
 
 
 
 
 
56
  # Validate that prompt is not empty
57
  if not prompt or prompt.strip() == "":
58
  return None, "Please provide a valid prompt."
 
11
  hf_token = os.getenv("hf_token")
12
  login(token=hf_token)
13
 
 
 
 
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 2048
16
 
17
  _text_gen_pipeline = None
18
+ @spaces.GPU()
19
  def get_text_gen_pipeline():
20
  global _text_gen_pipeline
21
  if _text_gen_pipeline is None:
22
  try:
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
  _text_gen_pipeline = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", max_new_tokens=2048, device=device)
25
  except Exception as e:
26
  print(f"Error loading text generation model: {e}")
27
  return None
28
  return _text_gen_pipeline
29
 
30
+ @spaces.GPU()
31
  def refine_prompt(prompt):
32
  text_gen = get_text_gen_pipeline()
33
  if text_gen is None:
 
52
  try:
53
  progress(0, desc="Starting generation...")
54
 
55
+ device = "cuda" if torch.cuda.is_available() else "cpu"
56
+ dtype = torch.bfloat16
57
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
58
+
59
  # Validate that prompt is not empty
60
  if not prompt or prompt.strip() == "":
61
  return None, "Please provide a valid prompt."