Spaces: Running on Zero
Commit · 0276240
Parent: f9d275c
maybe fixed qwen3
utils/models.py CHANGED (+6 −7)

```diff
@@ -3,7 +3,7 @@ os.environ['MKL_THREADING_LAYER'] = 'GNU'
 import spaces
 
 import torch
-from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
 
@@ -142,7 +142,7 @@ def run_inference(model_name, context, question):
             "text-generation",
             model=model_name,
             tokenizer=tokenizer,
-            device_map='
+            device_map='cuda',
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
         )
@@ -153,16 +153,15 @@ def run_inference(model_name, context, question):
             tokenize=False,
             **tokenizer_kwargs,
         )
-
-
-
+
+        input_length = len(formatted)
         # Check interrupt before generation
         if generation_interrupt.is_set():
            return ""
 
-        outputs = pipe(formatted,
+        outputs = pipe(formatted, **generation_kwargs)
         #print(outputs[0]['generated_text'])
-        result = outputs[0]['generated_text']
+        result = outputs[0]['generated_text'][input_length:]
 
     except Exception as e:
         print(f"Error in inference for {model_name}: {e}")
```
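For context on why the new slice works: by default, the transformers text-generation pipeline returns the full string, prompt plus continuation, in `generated_text`, so recording `len(formatted)` before generating and slicing it off afterwards leaves only the model's answer. Below is a minimal sketch of that pattern, assuming an illustrative model name and generation settings; the actual Space supplies its own `model_name` and `generation_kwargs` at runtime.

```python
# Sketch of the prompt-stripping pattern from this commit.
# Assumptions: the model name and max_new_tokens below are illustrative,
# not taken from this repo.
import torch
from transformers import AutoTokenizer, pipeline

model_name = "Qwen/Qwen3-0.6B"  # illustrative chat model
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

messages = [{"role": "user", "content": "What does RAG stand for?"}]
formatted = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# The pipeline's default return_full_text=True puts prompt + completion in
# 'generated_text', so the prompt's character length is recorded up front
# and sliced off afterwards, which is exactly what the commit does.
input_length = len(formatted)
outputs = pipe(formatted, max_new_tokens=256)
result = outputs[0]["generated_text"][input_length:]
print(result)
```

An equivalent route is passing `return_full_text=False` to the pipeline call, which makes it return only the continuation. The character-based slice used in the commit works because, with string input, `generated_text` begins with the exact prompt string that was passed in.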