Update main.py
main.py CHANGED
@@ -36,6 +36,18 @@ abstract_is_null = [
 data = data[~pandas.Series(abstract_is_null)]
 data.reset_index(inplace=True)
 
+# Load the model for later use in embeddings
+model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
+
+# Create an LLM pipeline that we can send queries to
+tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
+streamer = transformers.TextIteratorStreamer(
+    tokenizer, skip_prompt=True, skip_special_tokens=True
+)
+chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
+    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
+)
+
 # Create a FAISS index for fast similarity search
 metric = faiss.METRIC_INNER_PRODUCT
 vectors = numpy.stack(data["embedding"].tolist(), axis=0)
@@ -45,34 +57,35 @@ faiss.normalize_L2(vectors)
 index.train(vectors)
 index.add(vectors)
 
-# Load the model for later use in embeddings
-model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
-
-def search(query: str, k: int) -> tuple[str, str]:
+def preprocess(query: str, k: int) -> tuple[str, str]:
     """
-    Searches the dataset for the top k most relevant papers to the query
+    Searches the dataset for the top k most relevant papers to the query and returns a prompt and references
     Args:
         query (str): The user's query
         k (int): The number of results to return
     Returns:
-        tuple[str, str]: A tuple containing the
+        tuple[str, str]: A tuple containing the prompt and references
     """
-
-    faiss.normalize_L2(
-    D, I = index.search(
+    encoded_query = numpy.expand_dims(model.encode(query), axis=0)
+    faiss.normalize_L2(encoded_query)
+    D, I = index.search(encoded_query, k)
     top_five = data.loc[I[0]]
 
-
+    prompt = (
         "You are an AI assistant who delights in helping people learn about research from the Design "
-        "Research Collective. Your main task is to provide an ANSWER to the USER_QUERY based on the
-        "RESEARCH_ABSTRACTS
+        "Research Collective. Your main task is to provide an ANSWER to the USER_QUERY based on the "
+        "RESEARCH_ABSTRACTS.\n\n"
+        "RESEARCH_ABSTRACTS:\n{{ABSTRACTS_GO_HERE}}\n\n"
+        "USER_QUERY:\n{{QUERY_GOES_HERE}}\n\n"
+        "ANSWER:\n"
     )
 
     references = "\n\n## References\n\n"
+    research_abstracts = ""
 
     for i in range(k):
-
+        research_abstracts += top_five["bib_dict"].values[i]["abstract"] + "\n"
         references += (
             str(i + 1)
             + ". "
@@ -93,36 +106,12 @@ def search(query: str, k: int) -> tuple[str, str]:
             + top_five["author_pub_id"].values[i]
             + ").\n"
         )
+
+    prompt = prompt.replace("{{ABSTRACTS_GO_HERE}}", research_abstracts)
+    prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
 
-
-
-    )
-
-    return search_results, references
-
-
-# Create an LLM pipeline that we can send queries to
-tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
-streamer = transformers.TextIteratorStreamer(
-    tokenizer, skip_prompt=True, skip_special_tokens=True
-)
-chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
-    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
-)
-
-
-def preprocess(message: str) -> tuple[str, str]:
-    """
-    Applies a preprocessing step to the user's message before the LLM receives it
-    Args:
-        message (str): The user's message
-    Returns:
-        tuple[str, str]: A tuple containing the preprocessed message and a bypass variable
-    """
-    block_search_results, formatted_search_results = search(message, 5)
-    return block_search_results + message + "\nANSWER: ", formatted_search_results
-
-
+    return prompt, references
+
 def postprocess(response: str, bypass_from_preprocessing: str) -> str:
     """
     Applies a postprocessing step to the LLM's response before the user receives it
@@ -147,7 +136,7 @@ def reply(message: str, history: list[str]) -> str:
     """
 
     # Apply preprocessing
-    message, bypass = preprocess(message)
+    message, bypass = preprocess(message, 5)
 
     # This is some handling that is applied to the history variable to put it in a good format
     history_transformer_format = [
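Note: the construction of `index` itself falls outside the hunks shown (only `metric`, `vectors`, `faiss.normalize_L2`, `index.train`, and `index.add` are visible). A minimal sketch consistent with those calls is below; the `index_factory` string is an assumption, not necessarily the Space's actual code. Normalizing the corpus vectors, and later the encoded query in `preprocess`, makes `METRIC_INNER_PRODUCT` behave as cosine similarity.

```python
import faiss
import numpy

# Hypothetical sketch of the index setup the diff only partly shows.
metric = faiss.METRIC_INNER_PRODUCT
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")

# "Flat" (exact search) is an assumption; the real factory string is not in the diff.
index = faiss.index_factory(vectors.shape[1], "Flat", metric)

faiss.normalize_L2(vectors)  # unit vectors: inner product == cosine similarity
index.train(vectors)  # a no-op for flat indexes, required for quantized ones
index.add(vectors)
```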
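After this change, `preprocess` builds both the retrieval-augmented prompt and the references block, and `reply` calls it with k=5. The diff stops before the generation step, but the presence of `TextIteratorStreamer` suggests tokens are streamed from a background generation thread. The following is a hedged sketch of that flow, not the file's actual `reply` body; `max_new_tokens` and the threading details are assumptions.

```python
import threading

def reply_sketch(message: str) -> str:
    # Build the retrieval-augmented prompt plus the Markdown references
    prompt, references = preprocess(message, 5)

    inputs = tokenizer(prompt, return_tensors="pt").to(chatmodel.device)

    # TextIteratorStreamer yields text chunks as they are generated, so
    # generation runs on a worker thread while we consume the stream here.
    worker = threading.Thread(
        target=chatmodel.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512),
    )
    worker.start()
    response = "".join(streamer)
    worker.join()

    # postprocess receives the references that bypassed the LLM
    return postprocess(response, references)
```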