Commit 66a4e8f · Daniel Marques committed
Parent(s): b606edb

feat: add stream

Files changed:
- main.py +2 -2
- run_localGPT.py +3 -7
main.py CHANGED
@@ -14,9 +14,9 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
 
+
 # from langchain.embeddings import HuggingFaceEmbeddings
 from run_localGPT import load_model
-from prompt_template_utils import get_prompt_template
 
 # from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import Chroma
@@ -45,7 +45,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=False)
 
 template = """you are a helpful, respectful and honest assistant.
 Your name is Katara llma. You should only use the source documents provided to answer the questions.
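With this change, main.py builds the LLM once at import time and explicitly opts out of token streaming at its call site. A minimal usage sketch of the new signature follows; the "cuda" device string, the prompt text, and the assumption that MODEL_ID and MODEL_BASENAME live in the project's constants.py are illustrative, not code from this commit:

# Hypothetical call-site sketch for the updated load_model signature.
# MODEL_ID / MODEL_BASENAME are assumed to come from the project's constants.py.
from constants import MODEL_ID, MODEL_BASENAME
from run_localGPT import load_model

LLM = load_model(
    device_type="cuda",             # example device; "cpu" or "mps" also work
    model_id=MODEL_ID,
    model_basename=MODEL_BASENAME,
    stream=False,                   # no token-by-token stdout for the API path
)
print(LLM("Summarize the indexed documents in one sentence."))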
run_localGPT.py CHANGED
@@ -10,8 +10,6 @@ from langchain.callbacks.manager import CallbackManager
 
 torch.set_grad_enabled(False)
 
-callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
 from prompt_template_utils import get_prompt_template
 
 from langchain.vectorstores import Chroma
@@ -38,7 +36,7 @@ from constants import (
 
 
 
-def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
+def load_model(device_type, model_id, model_basename=None, LOGGING=logging, stream=False):
     """
     Select a model for text generation using the HuggingFace library.
     If you are running this for the first time, it will download a model for you.
@@ -91,15 +89,13 @@ def load_model(device_type, model_id, model_basename=None, LOGGING=logging):
         top_k=40,
         repetition_penalty=1.0,
         generation_config=generation_config,
-
-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id
+        callback=[StreamingStdOutCallbackHandler()]
     )
 
     local_llm = HuggingFacePipeline(pipeline=pipe)
     logging.info("Local LLM Loaded")
 
-    return
+    return local_llm
 
 
 def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
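Taken together, load_model now accepts a stream flag, returns the wrapped pipeline instead of None, and passes a LangChain StreamingStdOutCallbackHandler into the text-generation pipeline so tokens are echoed as they are produced. For comparison, a minimal sketch of stdout streaming using transformers' own TextStreamer (an alternative mechanism shown for illustration only; the commit itself uses the LangChain callback):

# Sketch: token-by-token stdout streaming with transformers.TextStreamer.
# "gpt2" is a stand-in model id; the project loads its own MODEL_ID instead.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# skip_prompt=True stops the prompt itself from being echoed back.
streamer = TextStreamer(tokenizer, skip_prompt=True)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=64,
    streamer=streamer,  # forwarded to generate(); prints tokens as they arrive
)
pipe("Retrieval-augmented generation works by")

Wrapped as HuggingFacePipeline(pipeline=pipe), the streamed text appears on stdout while LangChain still receives the final string.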