DrishtiSharma committed
Commit 0733fab · verified · 1 Parent(s): f4bd53a

Create main.py

Files changed (1):
  1. main.py +186 -0

main.py ADDED
@@ -0,0 +1,186 @@
+ import json  # to work with JSON
+ import threading  # to run generation in a background thread for streaming
+ import time  # to pace the delivery of streamed tokens
+
+ import faiss  # to create a search index
+ import gradio  # for the interface
+ import numpy  # to work with vectors
+ import pandas  # to work with dataframes
+ import sentence_transformers  # to load an embedding model
+ import spaces  # for GPU access on Hugging Face Spaces
+ import transformers  # to load an LLM
+
+ # Constants
+ GREETING = (
+     "Howdy! "
+     "I'm an AI agent that uses a [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the [Design Research Collective](https://cmudrc.github.io/). "
+     "And the best part is that I always try to cite my sources! "
+     "I still make some mistakes though. "
+     "What can I tell you about today?"
+ )
+ EXAMPLE_QUERIES = [
+     "Tell me about new research at the intersection of additive manufacturing and machine learning.",
+     "What is a physics-informed neural network and what can it be used for?",
+     "What can agent-based models do about climate change?",
+     "What's the difference between a Markov chain and a hidden Markov model?",
+     "What are the latest advancements in reinforcement learning?",
+     "What is known about different modes for human-AI teaming?",
+ ]
+ EMBEDDING_MODEL_NAME = "allenai-specter"
+ LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
+ PUBLICATIONS_TO_RETRIEVE = 5
+ PARQUET_URL = "hf://datasets/ccm/publications/data/train-00000-of-00001.parquet"
+
+ # Load the dataset and convert to pandas
+ data = pandas.read_parquet(PARQUET_URL)
+
+ # Filter out any publications without an abstract
+ abstract_is_null = [
+     '"abstract": null' in json.dumps(bibdict) for bibdict in data["bib_dict"].values
+ ]
+ data = data[~pandas.Series(abstract_is_null)]
+ data.reset_index(inplace=True)
+
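+ # The string check above works because json serializes Python None as null,
+ # e.g. (illustrative): json.dumps({"abstract": None}) == '{"abstract": null}'
+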
+ # Load the model for later use in embeddings
+ model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
+
+ # Create an LLM pipeline that we can send queries to
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+     LLM_MODEL_NAME, trust_remote_code=True
+ )
+ streamer = transformers.TextIteratorStreamer(
+     tokenizer, skip_prompt=True, skip_special_tokens=True
+ )
+ chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
+     LLM_MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
+ )
+
+ # Create a FAISS index for fast similarity search. The embedding vectors are
+ # L2-normalized, so inner-product search is equivalent to cosine similarity.
+ vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype(numpy.float32)  # FAISS expects float32
+ faiss.normalize_L2(vectors)
+ index = faiss.IndexFlatIP(vectors.shape[1])
+ index.add(vectors)
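+
+ # A quick sanity check of the index might look like this (a hypothetical,
+ # commented-out sketch; the query string is an arbitrary example):
+ #     q = numpy.expand_dims(model.encode("additive manufacturing"), axis=0)
+ #     faiss.normalize_L2(q)
+ #     scores, ids = index.search(q, 3)
+ #     print(data.loc[ids[0]]["bib_dict"].values)
+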
+
+ def preprocess(query: str, k: int) -> tuple[str, str]:
+     """
+     Searches the dataset for the top k papers most relevant to the query and returns a prompt and references
+     Args:
+         query (str): The user's query
+         k (int): The number of results to return
+     Returns:
+         tuple[str, str]: A tuple containing the prompt and references
+     """
+     encoded_query = numpy.expand_dims(model.encode(query), axis=0)
+     faiss.normalize_L2(encoded_query)
+     _, indices = index.search(encoded_query, k)
+     top_k = data.loc[indices[0]]
+
+     prompt = (
+         "You are an AI assistant who delights in helping people learn about research from the Design Research Collective, which is a research lab at Carnegie Mellon University led by Professor Chris McComb. "
+         "Your main task is to provide a concise ANSWER to the USER_QUERY that includes as many of the RESEARCH_ABSTRACTS as possible. "
+         "The RESEARCH_ABSTRACTS are provided in the `.bibtex` format. Your ANSWER should contain citations to the RESEARCH_ABSTRACTS using (AUTHOR, YEAR) format. "
+         "DO NOT list references at the end of the answer.\n\n"
+         "RESEARCH_ABSTRACTS:\n```bibtex\n{{ABSTRACTS_GO_HERE}}\n```\n\n"
+         "USER_QUERY:\n{{QUERY_GOES_HERE}}\n\n"
+         "ANSWER:\n"
+     )
+
+     references = []
+     research_abstracts = ""
+
+     for i in range(k):
+         bib = top_k["bib_dict"].values[i]
+         year = str(int(bib["pub_year"]))
+         url = (
+             "https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
+             + top_k["author_pub_id"].values[i]
+         )
+         last_names = [author.split(" ")[-1] for author in bib["author"].split(" and ")]
+         first_author_last_name = last_names[0]
+
+         research_abstracts += top_k["bibtex"].values[i] + "\n"
+         references.append(f'<a href="{url}">{first_author_last_name} {year}</a>')
+
+     prompt = prompt.replace("{{ABSTRACTS_GO_HERE}}", research_abstracts)
+     prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
+
+     print(prompt)  # log the assembled prompt for debugging
+
+     return prompt, "; ".join(references)
+
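+ # Illustrative usage (hypothetical query; not executed at import time):
+ #     prompt, refs = preprocess("What is a PINN?", PUBLICATIONS_TO_RETRIEVE)
+ #     # refs is a "; "-joined string of '<a href="...">LastName YEAR</a>' links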
+
+ @spaces.GPU
+ def reply(message: str, history: list[list[str]]):
+     """
+     This function is responsible for crafting a response
+     Args:
+         message (str): The user's message
+         history (list[list[str]]): The conversation history as [user, assistant] pairs
+     Yields:
+         str: The AI's response, streamed incrementally
+     """
+
+     # Apply preprocessing
+     message, bypass = preprocess(message, PUBLICATIONS_TO_RETRIEVE)
+
+     # Reshape the history into the chat-message format expected by the tokenizer
+     history_transformer_format = [
+         {"role": role, "content": message_pair[idx]}
+         for message_pair in history
+         for idx, role in enumerate(["user", "assistant"])
+         if message_pair[idx] is not None
+     ] + [{"role": "user", "content": message}]
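+     # For example (illustrative values): history = [["hi", "howdy"]] becomes
+     # [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "howdy"}],
+     # with the retrieval-augmented message appended as the final user turn.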
+
+     # Stream a response from the model
+     text = tokenizer.apply_chat_template(
+         history_transformer_format, tokenize=False, add_generation_prompt=True
+     )
+     model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")
+
+     # Run generation on a background thread so tokens can be yielded as they arrive
+     generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
+     t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_message = ""
+     for new_token in streamer:
+         if new_token != "<":  # skip stray special-token fragments
+             partial_message += new_token
+             time.sleep(0.01)  # pace the stream so the UI updates smoothly
+             yield partial_message
+
+     # Append the formatted references after the generated answer
+     yield partial_message + "\n\n" + bypass
+
+
+ # Create and run the gradio interface
+ gradio.ChatInterface(
+     reply,
+     examples=EXAMPLE_QUERIES,
+     chatbot=gradio.Chatbot(
+         show_label=False,
+         show_share_button=False,
+         show_copy_button=False,
+         value=[[None, GREETING]],
+         avatar_images=[
+             "https://cdn.dribbble.com/users/316121/screenshots/2333676/11-04_scotty-plaid_dribbble.png",
+             "https://media.thetab.com/blogs.dir/90/files/2021/06/screenshot-2021-06-10-at-110730-1024x537.png",
+         ],
+         height="60vh",
+         bubble_full_width=False,
+     ),
+     retry_btn=None,
+     undo_btn=None,
+     clear_btn=None,
+     theme=gradio.themes.Default(font=[gradio.themes.GoogleFont("Zilla Slab")]),
+ ).launch(debug=True)
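+
+ # Note: the retry_btn/undo_btn/clear_btn arguments and Chatbot's bubble_full_width
+ # assume a Gradio 4.x runtime; they were removed in Gradio 5.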