isayahc committed
Commit 2b65fe3 · 1 Parent(s): 103bc92

attempt to stream output

Files changed (1):
  1. app.py (+49 -1)
app.py CHANGED
@@ -7,6 +7,9 @@ from langchain.document_loaders import WebBaseLoader
 
 from huggingface_hub import AsyncInferenceClient
 
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=350, chunk_overlap=10)
 
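This hunk only adds the transformers import; TextIteratorStreamer is what drives the token-by-token streaming below. Neither visible hunk ever instantiates the model or tokenizer that the new generate() function relies on, so app.py presumably sets them up elsewhere, along these lines (a sketch; the checkpoint name is a placeholder, not from the commit):

```python
# Assumed setup for the objects generate() uses; not visible in this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical checkpoint name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
```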
 
 
@@ -38,6 +41,49 @@ retriever = db.as_retriever()
 global qa
 qa = RetrievalQA.from_chain_type(llm=model_id, chain_type="stuff", retriever=retriever, return_source_documents=True)
 
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
 
 def add_text(history, text):
     history = history + [(text, None)]
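generate() is the standard TextIteratorStreamer recipe: model.generate runs on a background thread while the calling thread drains the streamer, yielding a progressively longer string for Gradio to render. Three names it depends on never appear in the visible hunks, so the file would also need something like this (assumed, not shown in the commit):

```python
# Assumed imports and constant for generate(); not visible in this diff.
from threading import Thread
from typing import Iterator

MAX_INPUT_TOKEN_LENGTH = 4096  # hypothetical value; the diff never defines it
```

For the stream to actually reach the UI, the existing bot callback would also have to iterate the generator, roughly like this (hypothetical wiring, not part of this commit):

```python
# Hypothetical bot(): write each partial completion into the last chat turn
# so Gradio re-renders as tokens arrive.
def bot(history):
    user_message = history[-1][0]
    history[-1][1] = ""
    for partial in generate(user_message, history[:-1], system_prompt=""):
        history[-1][1] = partial
        yield history
```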
 
@@ -77,4 +123,6 @@ with gr.Blocks(css=css) as demo:
         bot, chatbot, chatbot
     )
 
-demo.launch()
+if __name__ == "__main__":
+
+    demo.launch()
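The last hunk moves demo.launch() behind a __main__ guard. One caveat worth noting: in the Gradio releases current at this commit, generator callbacks only stream to the browser once the queue is enabled, so the guard would presumably need to read:

```python
# Sketch: enable Gradio's queue, which generator (streaming) callbacks require.
if __name__ == "__main__":
    demo.queue()
    demo.launch()
```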