disLodge committed on
Commit
1a1cf30
·
verified ·
1 Parent(s): f9dce8d

Replaced ChatHuggingFace with a custom Runnable wrapper around InferenceClient

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -12,11 +12,42 @@ from langchain.text_splitter import CharacterTextSplitter
12
  from huggingface_hub import InferenceClient
13
  import logging
14
 
15
- logging.basicConfig(level=logging.INFO)
16
- logger = logging.getLogger(__name__)
17
 
18
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def extract_pdf_text(url: str) -> str:
21
  response = requests.get(url)
22
  pdf_file = BytesIO(response.content)
@@ -39,11 +70,7 @@ vectorstore = Chroma.from_documents(
39
  )
40
  retriever = vectorstore.as_retriever()
41
 
42
- llm = ChatHuggingFace(
43
- huggingfacehub_api_token=None,
44
- model_id="HuggingFaceH4/zephyr-7b-beta",
45
- interference_client=client,
46
- )
47
 
48
  # Before RAG chain
49
  before_rag_template = "What is {topic}"
@@ -75,9 +102,8 @@ after_rag_chain = (
75
  )
76
 
77
  def process_query(role, system_message, max_tokens, temperature, top_p):
78
- client.max_tokens = max_tokens
79
- client.temperature = temperature
80
- client.top_p = top_p
81
 
82
  # Before RAG
83
  before_rag_result = before_rag_chain.invoke({"topic": "Hugging Face"})
 
12
  from huggingface_hub import InferenceClient
13
  import logging
14
 
15
+ # logging.basicConfig(level=logging.INFO)
16
+ # logger = logging.getLogger(__name__)
17
 
18
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
19
 
20
+ class HuggingFaceInterferenceClientRunnable(Runnable):
21
+ def __init__(self, client, max_tokens=512, temperature=0.7, top_p=0.95):
22
+ self.client = client
23
+ self.max_tokens = max_tokens
24
+ self.temperature = temperature
25
+ self.top_p = top_p
26
+
27
+ def invoke(self, input, config=None):
28
+ prompt = input.to_messages()[0].content
29
+ messages = [{"role": "user", "content": prompt}]
30
+
31
+ response = ""
32
+ for part in self.client.chat_completion(
33
+ messages,
34
+ max_tokens=self.max_tokens,
35
+ stream=True,
36
+ temperature=self.temperature,
37
+ top_p=self.top_p
38
+ ):
39
+ token = part.choices[0].delta.content
40
+ if token:
41
+ response += token
42
+
43
+ return response
44
+
45
+ def update_params(self, max_tokens, temperature, top_p):
46
+ self.max_tokens = max_tokens
47
+ self.temperature=temperature
48
+ self.top_p=top_p
49
+
50
+
51
  def extract_pdf_text(url: str) -> str:
52
  response = requests.get(url)
53
  pdf_file = BytesIO(response.content)
 
70
  )
71
  retriever = vectorstore.as_retriever()
72
 
73
+ llm = HuggingFaceInterferenceClientRunnable(client)
 
 
 
 
74
 
75
  # Before RAG chain
76
  before_rag_template = "What is {topic}"
 
102
  )
103
 
104
  def process_query(role, system_message, max_tokens, temperature, top_p):
105
+
106
+ llm.update_params(max_tokens, temperature, top_p)
 
107
 
108
  # Before RAG
109
  before_rag_result = before_rag_chain.invoke({"topic": "Hugging Face"})