prithivMLmods committed on
Commit
e7dfb91
Β·
verified Β·
1 Parent(s): ccaf688

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -3
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import os
2
  from collections.abc import Iterator
3
  from threading import Thread
4
-
 
 
5
  import gradio as gr
6
  import spaces
7
  import torch
@@ -27,6 +29,38 @@ model = AutoModelForCausalLM.from_pretrained(
27
  model.config.sliding_window = 4096
28
  model.eval()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  @spaces.GPU(duration=120)
32
  def generate(
@@ -41,6 +75,16 @@ def generate(
41
  conversation = chat_history.copy()
42
  conversation.append({"role": "user", "content": message})
43
 
 
 
 
 
 
 
 
 
 
 
44
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
45
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
46
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -67,7 +111,6 @@ def generate(
67
  outputs.append(text)
68
  yield "".join(outputs)
69
 
70
-
71
  demo = gr.ChatInterface(
72
  fn=generate,
73
  additional_inputs=[
@@ -121,6 +164,5 @@ demo = gr.ChatInterface(
121
  fill_height=True,
122
  )
123
 
124
-
125
  if __name__ == "__main__":
126
  demo.queue(max_size=20).launch()
 
1
  import os
2
  from collections.abc import Iterator
3
  from threading import Thread
4
+ import requests
5
+ from bs4 import BeautifulSoup
6
+ from readability import Document
7
  import gradio as gr
8
  import spaces
9
  import torch
 
29
  model.config.sliding_window = 4096
30
  model.eval()
31
 
32
def extract_text_from_webpage(html_content):
    """Return the readable main-content HTML extracted from *html_content*.

    Delegates to readability's ``Document``, which strips navigation and
    boilerplate and returns the article body via ``summary()``.
    """
    return Document(html_content).summary()
35
+
36
def search(query):
    """Run a Google web search for *query* and return per-result page extracts.

    Returns a list of ``{"link": url, "text": extract}`` dicts, one per
    result; ``text`` is ``None`` when the linked page could not be fetched.
    Each extract is truncated so the downstream prompt stays bounded.
    """
    max_chars_per_page = 8000  # cap per-page text fed into the model context
    # Single shared header dict instead of two duplicated literals.
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
    }
    all_results = []
    with requests.Session() as session:
        resp = session.get(
            url="https://www.google.com/search",
            headers=headers,
            params={"q": query, "num": 4, "udm": 14},  # udm=14 selects the plain "web" results tab
            timeout=5,
        )
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        for result in soup.find_all("div", attrs={"class": "g"}):
            anchor = result.find("a", href=True)
            if anchor is None:
                # Some result blocks carry no hyperlink; the original code
                # would raise AttributeError here and abort the whole search.
                continue
            link = anchor["href"]
            try:
                # NOTE(security): verify=False disables TLS certificate
                # verification; kept to preserve best-effort scraping of
                # arbitrary sites, but worth reviewing.
                webpage = session.get(link, headers=headers, timeout=5, verify=False)
                webpage.raise_for_status()
                visible_text = extract_text_from_webpage(webpage.text)
                if len(visible_text) > max_chars_per_page:
                    visible_text = visible_text[:max_chars_per_page]
                all_results.append({"link": link, "text": visible_text})
            except requests.exceptions.RequestException:
                # Best-effort: record the link even when the fetch failed.
                all_results.append({"link": link, "text": None})
    return all_results
64
 
65
  @spaces.GPU(duration=120)
66
  def generate(
 
75
  conversation = chat_history.copy()
76
  conversation.append({"role": "user", "content": message})
77
 
78
+ # Check if the message requires a web search
79
+ if "search" in message.lower() or "find" in message.lower():
80
+ search_query = message
81
+ search_results = search(search_query)
82
+ if search_results:
83
+ search_context = "\n".join([result["text"] for result in search_results if result["text"]])
84
+ conversation.append({"role": "assistant", "content": f"Here are some search results:\n{search_context}"})
85
+ else:
86
+ conversation.append({"role": "assistant", "content": "I couldn't find any relevant information."})
87
+
88
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
89
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
90
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
111
  outputs.append(text)
112
  yield "".join(outputs)
113
 
 
114
  demo = gr.ChatInterface(
115
  fn=generate,
116
  additional_inputs=[
 
164
  fill_height=True,
165
  )
166
 
 
167
  if __name__ == "__main__":
168
  demo.queue(max_size=20).launch()