gimmy256 committed
Commit 8845357 · verified · 1 Parent(s): e5ce63e

Upload app.py

Files changed (1)
src/app.py +256 -0
src/app.py ADDED
@@ -0,0 +1,256 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ import streamlit as st
+
+ st.title("Medical RAG and Reasoning App")
+ st.write("This app demonstrates Retrieval-Augmented Generation (RAG) for medical question answering.")
+
+ # # HuatuoGPT-o1 Medical RAG and Reasoning
+ #
+ # _Authored by: [Alan Ponnachan](https://huggingface.co/AlanPonnachan)_
+ #
+ # This notebook demonstrates an end-to-end example of using HuatuoGPT-o1, a medical Large Language Model (LLM) designed for advanced medical reasoning, for medical question answering with Retrieval-Augmented Generation (RAG). The goal is detailed, well-structured answers to medical queries.
+ #
+ # ## Introduction
+ #
+ # HuatuoGPT-o1 excels at identifying mistakes, exploring alternative strategies, and refining its answers. It is trained with verifiable medical problems and a specialized medical verifier to enhance its reasoning capabilities. This notebook showcases how to use HuatuoGPT-o1 in a RAG setting, where we retrieve relevant information from a medical knowledge base and then use the model to generate a reasoned response.
+
+ # ## Notebook Setup
+ #
+ # **Important:** Before running the code, ensure you are using a GPU runtime for faster performance. In Colab, go to **"Runtime" -> "Change runtime type"** and select **"GPU"** under "Hardware accelerator."
+ #
+ # Let's start by installing the necessary libraries. Run this from a shell
+ # (`get_ipython()` is only available inside a notebook, not in a plain script):
+ #
+ #     pip install transformers datasets sentence-transformers scikit-learn --upgrade -q
+
+
+ # ## Load the Dataset
+ #
+ # We'll use the **"ChatDoctor-HealthCareMagic-100k"** dataset from the Hugging Face Hub. This dataset contains 100,000 real-world patient-doctor interactions, providing a rich knowledge base for our RAG system.
+
+ from datasets import load_dataset
+
+ dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")
+
+
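+ # Optional: embedding all ~100k rows below can take a long time, especially on
+ # CPU. If startup is too slow, one way to cap the knowledge base before
+ # embedding is `Dataset.select` (the row count here is illustrative):
+ #
+ #     dataset["train"] = dataset["train"].select(range(10_000))
+
+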
+ # ## Initialize the Models
+ #
+ # We need to initialize two models:
+ #
+ # 1. **HuatuoGPT-o1**: The medical LLM for generating responses.
+ # 2. **Sentence Transformer**: An embedding model for creating vector representations of text, which we'll use for retrieval.
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+
+ # Initialize HuatuoGPT-o1
+ model_name = "FreedomIntelligence/HuatuoGPT-o1-7B"
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name, torch_dtype="auto", device_map="auto"
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Initialize Sentence Transformer
+ embed_model = SentenceTransformer("all-MiniLM-L6-v2")
+
+
+ # ## Prepare the Knowledge Base
+ #
+ # We'll create a knowledge base by generating embeddings for the combined question-answer pairs from the dataset.
+
+ import pandas as pd
+ import numpy as np
+
+ # Convert dataset to DataFrame
+ df = pd.DataFrame(dataset["train"])
+
+ # Combine question and answer for context
+ df["combined"] = df["input"] + " " + df["output"]
+
+ # Generate embeddings
+ st.write("Generating embeddings for the knowledge base...")
+ embeddings = embed_model.encode(
+     df["combined"].tolist(), show_progress_bar=True, batch_size=128
+ )
+ st.write("Embeddings generated!")
+
+
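+ # In a Streamlit script, everything above reruns on every user interaction, so
+ # the embedding pass would repeat each time. One way to avoid that is
+ # Streamlit's built-in caching; a minimal sketch (the function name below is
+ # illustrative, not from the original code):
+
+ @st.cache_data
+ def build_embeddings(texts: list) -> np.ndarray:
+     # Computed once per unique input, then served from Streamlit's cache.
+     return embed_model.encode(texts, show_progress_bar=True, batch_size=128)
+
+ # Usage: embeddings = build_embeddings(df["combined"].tolist())
+
+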
+ # ## Implement Retrieval
+ #
+ # This function retrieves the `k` most relevant contexts for a given query using cosine similarity.
+
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ def retrieve_relevant_contexts(query: str, k: int = 3) -> list:
+     """
+     Retrieves the k most relevant contexts for a given query.
+
+     Args:
+         query (str): The user's medical query.
+         k (int): The number of relevant contexts to retrieve.
+
+     Returns:
+         list: A list of dictionaries, each containing a relevant context.
+     """
+     # Generate query embedding
+     query_embedding = embed_model.encode([query])[0]
+
+     # Calculate cosine similarity against every knowledge-base embedding
+     similarities = cosine_similarity([query_embedding], embeddings)[0]
+
+     # Get indices of the top k most similar contexts, best first
+     top_k_indices = np.argsort(similarities)[-k:][::-1]
+
+     contexts = []
+     for idx in top_k_indices:
+         contexts.append(
+             {
+                 "question": df.iloc[idx]["input"],
+                 "answer": df.iloc[idx]["output"],
+                 "similarity": similarities[idx],
+             }
+         )
+
+     return contexts
+
+
+ # ## Implement Response Generation
+ #
+ # This function generates a detailed response using the retrieved contexts.
+
+ def generate_structured_response(query: str, contexts: list) -> str:
+     """
+     Generates a detailed response using the retrieved contexts.
+
+     Args:
+         query (str): The user's medical query.
+         contexts (list): A list of relevant contexts.
+
+     Returns:
+         str: The generated response.
+     """
+     # Prepare prompt with retrieved contexts
+     context_prompt = "\n".join(
+         [
+             f"Reference {i+1}:"
+             f"\nQuestion: {ctx['question']}"
+             f"\nAnswer: {ctx['answer']}"
+             for i, ctx in enumerate(contexts)
+         ]
+     )
+
+     prompt = f"""Based on the following references and your medical knowledge, provide a detailed response:
+
+ References:
+ {context_prompt}
+
+ Question: {query}
+
+ By considering:
+ 1. The key medical concepts in the question.
+ 2. How the reference cases relate to this question.
+ 3. What medical principles should be applied.
+ 4. Any potential complications or considerations.
+
+ Give the final response:
+ """
+
+     # Generate response
+     messages = [{"role": "user", "content": prompt}]
+     inputs = tokenizer(
+         tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         ),
+         return_tensors="pt",
+     ).to(model.device)
+
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=1024,
+         temperature=0.7,
+         num_beams=1,
+         do_sample=True,
+     )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Extract the final response portion: the decoded output includes the
+     # prompt itself, so keep only the text after its last line.
+     final_response = response.split("Give the final response:\n")[-1]
+
+     return final_response
+
+
+ # ## Putting It All Together
+ #
+ # Let's define a function to process a query end-to-end and then use it with an example.
+
+ def process_query(query: str, k: int = 3) -> tuple:
+     """
+     Processes a medical query end-to-end.
+
+     Args:
+         query (str): The user's medical query.
+         k (int): The number of relevant contexts to retrieve.
+
+     Returns:
+         tuple: The generated response and the retrieved contexts.
+     """
+     contexts = retrieve_relevant_contexts(query, k)
+     response = generate_structured_response(query, contexts)
+     return response, contexts
+
+ # Example query
+ query = "I've been experiencing persistent headaches and dizziness for the past week. What could be the cause?"
+
+ # Process query
+ response, contexts = process_query(query)
+
+ # Print results
+ st.write("\nQuery:", query)
+ st.write("\nRelevant Contexts:")
+ for i, ctx in enumerate(contexts, 1):
+     st.write(f"\nReference {i} (Similarity: {ctx['similarity']:.3f}):")
+     st.write(f"Q: {ctx['question']}")
+     st.write(f"A: {ctx['answer']}")
+
+ st.write("\nGenerated Response:")
+ st.write(response)
+
+
+ # ## Conclusion
+ #
+ # This notebook demonstrates a practical application of HuatuoGPT-o1 for medical question answering using RAG and reasoning. By combining retrieval from a relevant knowledge base with the advanced reasoning capabilities of HuatuoGPT-o1, we can build a system that provides detailed and well-structured answers to complex medical queries.
+ #
+ # You can further enhance this system by:
+ #
+ # * Experimenting with different values of `k` (number of retrieved contexts).
+ # * Fine-tuning HuatuoGPT-o1 on a specific medical domain.
+ # * Evaluating the system's performance using medical benchmarks.
+ # * Adding a user interface for easier interaction (a minimal sketch follows below).
+ # * Hardening the existing code by handling edge cases.
+ #
+ # Feel free to adapt and expand upon this example to create even more powerful and helpful medical AI applications!
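+
+ # A minimal sketch of the interactive UI mentioned in the bullets above
+ # (widget labels and variable names are illustrative):
+
+ user_query = st.text_input("Enter a medical question:")
+ if st.button("Get answer") and user_query:
+     answer, refs = process_query(user_query)
+     st.write("Generated Response:")
+     st.write(answer)
+     for i, ctx in enumerate(refs, 1):
+         st.write(f"Reference {i} (Similarity: {ctx['similarity']:.3f})")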