oliver-aizip committed
Commit 0276240 · 1 Parent(s): f9d275c

maybe fixed qwen3

Files changed (1)
  1. utils/models.py +6 -7
utils/models.py CHANGED
@@ -3,7 +3,7 @@ os.environ['MKL_THREADING_LAYER'] = 'GNU'
 import spaces
 
 import torch
-from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
 
@@ -142,7 +142,7 @@ def run_inference(model_name, context, question):
             "text-generation",
             model=model_name,
             tokenizer=tokenizer,
-            device_map='auto',
+            device_map='cuda',
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
         )
@@ -153,16 +153,15 @@ def run_inference(model_name, context, question):
             tokenize=False,
             **tokenizer_kwargs,
         )
-
-
-
+
+        input_length = len(formatted)
         # Check interrupt before generation
         if generation_interrupt.is_set():
            return ""
 
-        outputs = pipe(formatted, skip_special_tokens=True, **generation_kwargs, )
+        outputs = pipe(formatted, **generation_kwargs)
        #print(outputs[0]['generated_text'])
-        result = outputs[0]['generated_text']
+        result = outputs[0]['generated_text'][input_length:]
 
    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")