lionelgarnier committed on
Commit
3dc3dff
·
1 Parent(s): 52efc32

Fix text generation pipeline tokenizer and reorder model loading checks

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -22,7 +22,6 @@ _image_gen_pipeline = None
22
  def get_image_gen_pipeline():
23
  global _image_gen_pipeline
24
  if _image_gen_pipeline is None:
25
- print("Loading image generation model on first use...") # Optional debug message
26
  try:
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  dtype = torch.bfloat16
@@ -53,12 +52,17 @@ def get_text_gen_pipeline():
53
  "mistralai/Mistral-7B-Instruct-v0.3",
54
  use_fast=True
55
  )
 
 
 
 
56
  _text_gen_pipeline = pipeline(
57
  "text-generation",
58
  model="mistralai/Mistral-7B-Instruct-v0.3",
59
- tokenizer=tokenizer,
60
  max_new_tokens=2048,
61
  device=device,
 
62
  )
63
  except Exception as e:
64
  print(f"Error loading text generation model: {e}")
@@ -100,14 +104,14 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
100
  try:
101
  progress(0, desc="Starting generation...")
102
 
103
- pipe = get_image_gen_pipeline()
104
- if pipe is None:
105
- return None, "Image generation model is unavailable."
106
-
107
  # Validate that prompt is not empty
108
  if not prompt or prompt.strip() == "":
109
  return None, "Please provide a valid prompt."
110
 
 
 
 
 
111
  # Validate width/height dimensions
112
  is_valid, error_msg = validate_dimensions(width, height)
113
  if not is_valid:
 
22
  def get_image_gen_pipeline():
23
  global _image_gen_pipeline
24
  if _image_gen_pipeline is None:
 
25
  try:
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
  dtype = torch.bfloat16
 
52
  "mistralai/Mistral-7B-Instruct-v0.3",
53
  use_fast=True
54
  )
55
+ # Set pad_token_id to eos_token_id if pad_token is not set
56
+ if tokenizer.pad_token is None:
57
+ tokenizer.pad_token = tokenizer.eos_token
58
+
59
  _text_gen_pipeline = pipeline(
60
  "text-generation",
61
  model="mistralai/Mistral-7B-Instruct-v0.3",
62
+ tokenizer=tokenizer,
63
  max_new_tokens=2048,
64
  device=device,
65
+ pad_token_id=tokenizer.pad_token_id # Explicitly set pad_token_id
66
  )
67
  except Exception as e:
68
  print(f"Error loading text generation model: {e}")
 
104
  try:
105
  progress(0, desc="Starting generation...")
106
 
 
 
 
 
107
  # Validate that prompt is not empty
108
  if not prompt or prompt.strip() == "":
109
  return None, "Please provide a valid prompt."
110
 
111
+ pipe = get_image_gen_pipeline()
112
+ if pipe is None:
113
+ return None, "Image generation model is unavailable."
114
+
115
  # Validate width/height dimensions
116
  is_valid, error_msg = validate_dimensions(width, height)
117
  if not is_valid: