seawolf2357 committed on
Commit
96a2022 · verified · 1 Parent(s): 89bcb15

Update app.py

Files changed (1)
  1. app.py +50 -60
app.py CHANGED
@@ -1,41 +1,33 @@
 import gradio as gr
-import requests
-import os
-import json
 from datasets import load_dataset
 from sentence_transformers import SentenceTransformer, util
 
-# Load the sentence embedding model
-model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
-# Load the datasets
-datasets = [
-    ("all-processed", "all-processed"),
-    ("chatdoctor-icliniq", "chatdoctor-icliniq"),
-    ("chatdoctor_healthcaremagic", "chatdoctor_healthcaremagic"),
-]
 
-all_datasets = {}
-for dataset_name, config in datasets:
-    all_datasets[dataset_name] = load_dataset("lavita/medical-qa-datasets", config)
 
 def find_most_similar_data(query):
     query_embedding = model.encode(query, convert_to_tensor=True)
     most_similar = None
     highest_similarity = -1
-
-    for dataset_name, dataset in all_datasets.items():
-        for split in dataset.keys():
-            for item in dataset[split]:
-                if 'question' in item and 'answer' in item:
-                    item_text = f"Question: {item['question']} Answer: {item['answer']}"
-                    item_embedding = model.encode(item_text, convert_to_tensor=True)
-                    similarity = util.pytorch_cos_sim(query_embedding, item_embedding).item()
-
-                    if similarity > highest_similarity:
-                        highest_similarity = similarity
-                        most_similar = item_text
-
     return most_similar
 
  def respond_with_prefix(message, history, max_tokens=10000, temperature=0.7, top_p=0.95):
@@ -86,51 +78,49 @@ def respond_with_prefix(message, history, max_tokens=10000, temperature=0.7, top
 7. Write at least 1,000 characters per chapter, not just for the piece as a whole; with three chapters included, the total must be at least 3,000 characters.
 8. Please write 10 "#tags".
 """
-
-    modified_message = system_prefix + message  # apply the prefix to the user message
 
-    # Find the most similar data in the datasets
     similar_data = find_most_similar_data(message)
     if similar_data:
-        modified_message += "\n\n" + similar_data  # add the similar data to the message
-
-    data = {
-        "model": "jinjavis:latest",
-        "prompt": modified_message,
-        "max_tokens": max_tokens,
-        "temperature": temperature,
-        "top_p": top_p
-    }
-
-    # API request
-    response = requests.post("http://hugpu.ai:7877/api/generate", json=data, stream=True)
-
-    partial_message = ""
-    for line in response.iter_lines():
-        if line:
-            try:
-                result = json.loads(line)
-                if result.get("done", False):
-                    break
-                new_text = result.get('response', '')
-                partial_message += new_text
-                yield partial_message
-            except json.JSONDecodeError as e:
-                print(f"Failed to decode JSON: {e}")
-                yield "An error occurred while processing your request."
 
-demo = gr.ChatInterface(
     fn=respond_with_prefix,
     additional_inputs=[
-        gr.Slider(minimum=1, maximum=120000, value=4000, label="Max Tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-P")
     ],
     theme="Nymbo/Nymbo_Theme"
 )
 
 if __name__ == "__main__":
     demo.queue(max_size=4).launch()
 
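For reference, the block removed above relies on a newline-delimited JSON stream: each line from the endpoint is a JSON object carrying a partial 'response' and a terminal 'done' flag. A minimal standalone sketch of that consumer is shown below; the endpoint URL and model name are taken from the removed lines and are not verified here.

# Sketch: consume a newline-delimited JSON stream of the shape the removed code
# assumes ({"response": "...", "done": false} per line). Endpoint and model name
# come from the removed lines above; they are assumptions, not verified.
import json
import requests

def stream_generate(prompt, url="http://hugpu.ai:7877/api/generate"):
    payload = {"model": "jinjavis:latest", "prompt": prompt}
    with requests.post(url, json=payload, stream=True, timeout=120) as resp:
        resp.raise_for_status()
        text = ""
        for line in resp.iter_lines():
            if not line:
                continue
            chunk = json.loads(line)
            if chunk.get("done"):
                break
            text += chunk.get("response", "")
            yield text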
 import gradio as gr
+from openai import OpenAI
 from datasets import load_dataset
 from sentence_transformers import SentenceTransformer, util
 
+# Initialize the OpenAI client
+client = OpenAI(api_key=os.getenv("OPENAI"))  # replace with an actual API key
 
+# Load sentence embedding model
+model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
+# Load the PharmKG dataset
+pharmkg_dataset = load_dataset("vinven7/PharmKG")
 
 def find_most_similar_data(query):
     query_embedding = model.encode(query, convert_to_tensor=True)
     most_similar = None
     highest_similarity = -1
+
+    for split in pharmkg_dataset.keys():
+        for item in pharmkg_dataset[split]:
+            if 'Input' in item and 'Output' in item:
+                item_text = f"Input: {item['Input']} Output: {item['Output']}"
+                item_embedding = model.encode(item_text, convert_to_tensor=True)
+                similarity = util.pytorch_cos_sim(query_embedding, item_embedding).item()
+
+                if similarity > highest_similarity:
+                    highest_similarity = similarity
+                    most_similar = item_text
+
     return most_similar
 
 def respond_with_prefix(message, history, max_tokens=10000, temperature=0.7, top_p=0.95):
 
 7. Write at least 1,000 characters per chapter, not just for the piece as a whole; with three chapters included, the total must be at least 3,000 characters.
 8. Please write 10 "#tags".
 """
+
+    # Find the most similar data from the PharmKG dataset
     similar_data = find_most_similar_data(message)
+    context = f"{system_prefix}\n\n{message}"
     if similar_data:
+        context += f"\n\nRelated Information: {similar_data}"
 
+    try:
+        response = client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[
+                {"role": "system", "content": system_prefix},
+                {"role": "user", "content": message}
+            ],
+            response_format={"type": "text"},
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            frequency_penalty=0,
+            presence_penalty=0,
+            stream=True
+        )
 
+        partial_message = ""
+        for chunk in response:
+            if chunk.choices[0].delta.content:
+                partial_message += chunk.choices[0].delta.content
+                yield partial_message
 
+    except Exception as e:
+        yield f"An error occurred: {str(e)}"
 
+demo = gr.ChatInterface(
     fn=respond_with_prefix,
     additional_inputs=[
+        gr.Slider(minimum=1, maximum=4096, value=2048, label="Max Tokens"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Temperature"),
+        gr.Slider(minimum=0.1, maximum=1.0, value=1.0, label="Top-P")
     ],
     theme="Nymbo/Nymbo_Theme"
 )
 
 if __name__ == "__main__":
     demo.queue(max_size=4).launch()
+
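Two details of the new version stand out: os.getenv is called without an `import os`, and find_most_similar_data re-encodes every dataset row on each request. Below is a minimal, self-contained sketch of the same retrieval-plus-streaming flow with the corpus embedded once at startup. The dataset field names ('Input'/'Output'), the OPENAI environment variable, and the gpt-4o-mini model follow the committed code; the helper names, the single-encode strategy, and the simplified ChatInterface are illustrative assumptions, not part of the commit.

# Illustrative sketch only: embed the PharmKG corpus once at startup and reuse it,
# instead of re-encoding every row per request as the committed code does.
import os

import gradio as gr
from datasets import load_dataset
from openai import OpenAI
from sentence_transformers import SentenceTransformer, util

client = OpenAI(api_key=os.getenv("OPENAI"))  # same env var as the committed code
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

pharmkg_dataset = load_dataset("vinven7/PharmKG")

# Build the corpus once; for a very large dataset you would cache these
# embeddings to disk or restrict this to a single split.
corpus_texts = [
    f"Input: {item['Input']} Output: {item['Output']}"
    for split in pharmkg_dataset
    for item in pharmkg_dataset[split]
    if "Input" in item and "Output" in item
]
corpus_embeddings = model.encode(corpus_texts, convert_to_tensor=True)

def find_most_similar_data(query):
    # Cosine similarity of the query against the precomputed corpus embeddings.
    if not corpus_texts:
        return None
    query_embedding = model.encode(query, convert_to_tensor=True)
    scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
    return corpus_texts[int(scores.argmax())]

def respond(message, history, max_tokens=2048, temperature=1.0, top_p=1.0):
    # Append retrieved context to the user message, then stream the completion.
    content = message
    similar = find_most_similar_data(message)
    if similar:
        content += f"\n\nRelated Information: {similar}"
    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": content}],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )
    partial = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            partial += delta
            yield partial

demo = gr.ChatInterface(fn=respond)

if __name__ == "__main__":
    demo.queue(max_size=4).launch()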