Spaces: Build error
Daniel Marques committed
Commit · e72e226
1 Parent(s): dc8d635
fix: add callback

- load_models.py +2 -2
- main.py +0 -3
load_models.py CHANGED
@@ -3,6 +3,7 @@ import logging
 from auto_gptq import AutoGPTQForCausalLM
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp, HuggingFacePipeline
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 from transformers import (
     AutoModelForCausalLM,
@@ -204,8 +205,6 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
 
     streamer = TextStreamer(tokenizer, skip_prompt=True)
 
-    logging.info(streamer)
-
     pipe = pipeline(
         "text-generation",
         model=model,
@@ -217,6 +216,7 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stre
         repetition_penalty=1.0,
         generation_config=generation_config,
         streamer=streamer
+        callbacks=[StreamingStdOutCallbackHandler()]
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
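For reference, below is a minimal, hedged sketch of how a streaming callback like this is commonly wired in classic LangChain. It differs from the diff above in one respect: the callbacks list is attached to the HuggingFacePipeline wrapper, which inherits a callbacks field from LangChain's base LLM class, rather than being passed to transformers.pipeline(...). The model id, prompt, and generation parameters are placeholders, and whether tokens actually reach the callback depends on the LLM implementation; this is an illustration, not the Space's actual code.

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline

model_id = "gpt2"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# TextStreamer prints tokens to stdout as transformers' generate() produces them.
streamer = TextStreamer(tokenizer, skip_prompt=True)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=64,        # placeholder generation settings
    repetition_penalty=1.0,
    streamer=streamer,
)

# In classic LangChain, the callbacks list is a field of the LLM wrapper,
# not of transformers.pipeline(); registering the handler here lets
# LangChain invoke it around LLM calls.
local_llm = HuggingFacePipeline(
    pipeline=pipe,
    callbacks=[StreamingStdOutCallbackHandler()],
)

print(local_llm("Hello"))  # placeholder prompt

The idea of this layout is to keep transformers-level streaming (TextStreamer) and LangChain-level callbacks on their respective layers.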
main.py CHANGED
@@ -179,9 +179,6 @@ async def predict(data: Predict):
                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
             )
 
-
-
-
         return {"response": prompt_response_dict}
     else:
         raise HTTPException(status_code=400, detail="Prompt Incorrect")