abdullah63 committed
Commit 3b55f38 · verified · 1 Parent(s): 5a89d46

Update app.py

Files changed (1):
  1. app.py +93 -39
app.py CHANGED
@@ -1,64 +1,118 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


  def respond(
      message,
      history: list[tuple[str, str]],
      system_message,
      max_tokens,
      temperature,
-     top_p,
  ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})

-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
-             label="Top-p (nucleus sampling)",
          ),
      ],
  )

-
  if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ import faiss
+ import numpy as np
+ import pandas as pd
+ from sentence_transformers import SentenceTransformer
+ import google.generativeai as genai
+ import re
+ import os

+ # Load data and FAISS index
+ def load_data_and_index():
+     docs_df = pd.read_pickle("docs_with_embeddings (1).pkl")  # Adjust path for HF Spaces
+     embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)
+     dimension = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dimension)
+     index.add(embeddings)
+     return docs_df, index

+ docs_df, index = load_data_and_index()

+ # Load SentenceTransformer
+ minilm = SentenceTransformer('all-MiniLM-L6-v2')
+
+ # Configure Gemini API using Hugging Face Secrets
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
+ if not GEMINI_API_KEY:
+     raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
+ genai.configure(api_key=GEMINI_API_KEY)
+ model = genai.GenerativeModel('gemini-2.0-flash')
+
+ # Preprocess text function
+ def preprocess_text(text):
+     text = text.lower()
+     text = text.replace('\n', ' ').replace('\t', ' ')
+     text = re.sub(r'[^\w\s.,;:>-]', ' ', text)
+     text = ' '.join(text.split()).strip()
+     return text
+
+ # Retrieve documents
+ def retrieve_docs(query, k=5):
+     query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
+     distances, indices = index.search(np.array([query_embedding]), k)
+     retrieved_docs = docs_df.iloc[indices[0]][['label', 'text', 'source']]
+     retrieved_docs['distance'] = distances[0]
+     return retrieved_docs
+
+ # RAG pipeline integrated into respond function
  def respond(
      message,
      history: list[tuple[str, str]],
      system_message,
      max_tokens,
      temperature,
+     top_p,  # Keeping top_p as an input, though this code doesn’t pass it to Gemini
  ):
+     # Preprocess the user message
+     preprocessed_query = preprocess_text(message)
+
+     # Retrieve relevant documents
+     retrieved_docs = retrieve_docs(preprocessed_query, k=5)
+     context = "\n".join(retrieved_docs['text'].tolist())
+
+     # Construct the prompt with system message, history, and RAG context
+     prompt = f"{system_message}\n\n"
+     for user_msg, assistant_msg in history:
+         if user_msg:
+             prompt += f"User: {user_msg}\n"
+         if assistant_msg:
+             prompt += f"Assistant: {assistant_msg}\n"
+     prompt += (
+         f"Query: {message}\n"
+         f"Relevant Context: {context}\n"
+         f"Generate a short, concise, and to-the-point response to the query based only on the provided context."
+     )
+
+     # Generate response with Gemini
+     response = model.generate_content(
+         prompt,
+         generation_config=genai.types.GenerationConfig(
+             max_output_tokens=max_tokens,
+             temperature=temperature
+         )
+     )
+     answer = response.text.strip()
+     if not answer.endswith('.'):
+         last_period = answer.rfind('.')
+         if last_period != -1:
+             answer = answer[:last_period + 1]
+         else:
+             answer += "."
+
+     # Yield the full response at once (no streaming in this version)
+     yield answer

+ # Gradio Chat Interface
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
+         gr.Textbox(
+             value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
+             label="System message"
+         ),
+         gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.75, step=0.1, label="Temperature"),
          gr.Slider(
              minimum=0.1,
              maximum=1.0,
              value=0.95,
              step=0.05,
+             label="Top-p (nucleus sampling)",  # Included but not forwarded to Gemini
          ),
      ],
+     title="🏥 Medical Chat Assistant",
+     description="A chat-based medical assistant that diagnoses patient queries using AI and past records."
  )

  if __name__ == "__main__":
+     demo.launch()
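
One note on the retrieval index: the commit builds an L2 (Euclidean) FAISS index, which works, but all-MiniLM-L6-v2 embeddings are usually compared by cosine similarity. A minimal sketch of the cosine-equivalent setup, assuming the same `embeddings` array and `minilm` encoder as in app.py (the query string is just an example):

# Sketch: cosine similarity = inner product over L2-normalized vectors.
# Assumes `embeddings` is the float32 array built in load_data_and_index().
faiss.normalize_L2(embeddings)                  # in-place row normalization
index = faiss.IndexFlatIP(embeddings.shape[1])  # inner-product index
index.add(embeddings)

# The query vector must be normalized the same way before searching:
query_vec = minilm.encode(["chest pain and shortness of breath"]).astype(np.float32)
faiss.normalize_L2(query_vec)
scores, indices = index.search(query_vec, 5)    # higher score = more similar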
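
Separately, retrieve_docs assigns a new `distance` column to a slice of docs_df, which triggers pandas' SettingWithCopyWarning. Copying the slice first keeps the same behaviour without the warning:

retrieved_docs = docs_df.iloc[indices[0]][['label', 'text', 'source']].copy()
retrieved_docs['distance'] = distances[0]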
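
The new respond() yields the whole answer in one piece. If incremental output is wanted, the google-generativeai client does support server-side streaming via generate_content(..., stream=True); a sketch of a streaming tail for respond(), reusing the `model`, `prompt`, `max_tokens`, and `temperature` already in scope there:

# Sketch only: stream partial text back to gr.ChatInterface.
response = model.generate_content(
    prompt,
    generation_config=genai.types.GenerationConfig(
        max_output_tokens=max_tokens,
        temperature=temperature,
    ),
    stream=True,
)
partial = ""
for chunk in response:
    partial += chunk.text  # chunk.text can raise if a chunk is safety-blocked; guard in production
    yield partial          # each yield re-renders the in-progress message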
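
Finally, the top_p slider is collected but never forwarded. genai.types.GenerationConfig does accept a top_p field, so the slider could be wired through rather than left decorative; a sketch under the same names as in respond():

config = genai.types.GenerationConfig(
    max_output_tokens=max_tokens,
    temperature=temperature,
    top_p=top_p,  # nucleus sampling, passed through from the slider
)
response = model.generate_content(prompt, generation_config=config)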