Update app.py
app.py CHANGED
@@ -5,13 +5,13 @@ import gradio as gr
 # Environment and Model/Client Initialization
 # ------------------------------------------------------------------------------
 try:
-    # Assume we’re in Google Colab
+    # Assume we’re in Google Colab or another local environment with PyTorch
     from google.colab import userdata
     HF_TOKEN = userdata.get('HF_TOKEN')
     import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
 
-    #
+    # Performance tweak
     torch.backends.cudnn.benchmark = True
 
     model_name = "HuggingFaceH4/zephyr-7b-beta"
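The hunk imports AutoModelForCausalLM, but the actual model loading happens below the visible lines. A minimal sketch of that elided step, assuming a single Colab GPU, half-precision weights, and the accelerate package for device placement (none of which is confirmed by the diff):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

HF_TOKEN = None  # stand-in for the token pulled from google.colab userdata above
model_name = "HuggingFaceH4/zephyr-7b-beta"

tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # assumption: fit the 7B model into Colab GPU memory
    device_map="auto",          # assumption: requires accelerate to be installed
    token=HF_TOKEN,
)
model.eval()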
@@ -33,7 +33,6 @@ except ImportError:
     from transformers import AutoTokenizer
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    # If an HF_TOKEN is provided and valid, it can be passed; otherwise, omit it.
     hf_token = os.getenv("HF_TOKEN", None)
     if hf_token:
         client = InferenceClient(model_name, token=hf_token)
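For reference, the remote fallback this hunk trims reduces to the shape below; the else branch is not visible in the diff, so treating it as an anonymous client is an assumption:

import os
from huggingface_hub import InferenceClient

model_name = "HuggingFaceH4/zephyr-7b-beta"
hf_token = os.getenv("HF_TOKEN", None)
if hf_token:
    client = InferenceClient(model_name, token=hf_token)
else:
    client = InferenceClient(model_name)  # assumption: anonymous calls, subject to rate limits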
@@ -70,8 +69,8 @@ Start the conversation by expressing your current feelings or challenges from th
 def build_prompt(history: list[tuple[str, str]], system_message: str, message: str, max_response_words: int) -> str:
     """
     Build a text prompt (for local inference) that starts with the system message,
-    includes conversation history with "Doctor:" and "Patient:"
-    a new "Doctor:" line prompting the patient.
+    includes conversation history with "Doctor:" and "Patient:" labels,
+    and ends with a new "Doctor:" line prompting the patient.
     """
     prompt = system_message.format(max_response_words=max_response_words) + "\n"
     for user_msg, assistant_msg in history:
@@ -81,7 +80,6 @@ def build_prompt(history: list[tuple[str, str]], system_message: str, message: s
     prompt += f"Doctor: {message}\nPatient: "
     return prompt
 
-
 def build_messages(history: list[tuple[str, str]], system_message: str, message: str, max_response_words: int):
     """
     Build a messages list (for InferenceClient) using OpenAI-style formatting.
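Taken together, the two build_prompt hunks produce a prompt of the shape sketched below. The history loop body is outside the visible lines, so the per-turn formatting is an assumption, and build_prompt_sketch plus the stand-in system message are illustrative names only:

def build_prompt_sketch(history, system_message, message, max_response_words):
    prompt = system_message.format(max_response_words=max_response_words) + "\n"
    for user_msg, assistant_msg in history:
        prompt += f"Doctor: {user_msg}\nPatient: {assistant_msg}\n"  # assumed turn format
    prompt += f"Doctor: {message}\nPatient: "
    return prompt

system_message = "Reply only as the patient, in at most {max_response_words} words."
history = [("How have you been sleeping?", "Badly, maybe four hours a night.")]
print(build_prompt_sketch(history, system_message, "What keeps you awake at night?", 80))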
@@ -96,7 +94,6 @@ def build_messages(history: list[tuple[str, str]], system_message: str, message:
     messages.append({"role": "user", "content": f"Doctor: {message}\nPatient:"})
     return messages
 
-
 def truncate_response(text: str, max_words: int) -> str:
     """
     Truncate the response text to the specified maximum number of words.
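The matching build_messages output is an OpenAI-style list ending with the user turn shown in the hunk; the roles given to earlier turns below are assumptions, since that loop is not part of the diff:

messages = [
    {"role": "system", "content": "Reply only as the patient, in at most 80 words."},  # stand-in
    {"role": "user", "content": "Doctor: How have you been sleeping?"},                # assumed
    {"role": "assistant", "content": "Badly, maybe four hours a night."},              # assumed
    {"role": "user", "content": "Doctor: What keeps you awake at night?\nPatient:"},
]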
@@ -106,7 +103,6 @@ def truncate_response(text: str, max_words: int) -> str:
         return " ".join(words[:max_words]) + "..."
     return text
 
-
 # ------------------------------------------------------------------------------
 # Response Function
 # ------------------------------------------------------------------------------
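Only the docstring and the two return statements of truncate_response are visible across the hunks; a reconstruction consistent with them (the split and the length check are inferred, not shown in the diff) behaves like this:

def truncate_response(text: str, max_words: int) -> str:
    """Truncate the response text to the specified maximum number of words."""
    words = text.split()        # inferred: split on whitespace
    if len(words) > max_words:  # inferred: truncate only when over the limit
        return " ".join(words[:max_words]) + "..."
    return text

print(truncate_response("one two three four five", 3))  # -> "one two three..."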
@@ -120,7 +116,7 @@ def respond(
     max_response_words: int,
 ):
     """
-    Generate a response. For local inference, use
+    Generate a response. For local inference, use model.generate() on a prompt.
     For non-local inference, use client.chat_completion() with streaming tokens.
     """
     if inference_mode == "local":
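The corrected docstring says the local branch calls model.generate() on a prompt. That branch sits outside the visible hunks, so the following is a hypothetical reading of it; the sampling flags and the decode slicing are assumptions:

import torch

def respond_local_sketch(model, tokenizer, prompt, max_tokens, temperature, top_p, max_response_words):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    # Decode only the newly generated tokens, then cap the word count like the app does.
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return truncate_response(text.strip(), max_response_words)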
@@ -142,7 +138,7 @@ def respond(
         messages = build_messages(history, system_message, message, max_response_words)
         response = ""
         try:
-            #
+            # Generate response using streaming chat_completion
             for chunk in client.chat_completion(
                 messages,
                 max_tokens=max_tokens,
@@ -150,7 +146,6 @@ def respond(
                 temperature=temperature,
                 top_p=top_p,
             ):
-                # The chunk returns a dictionary; get the token from the delta.
                 token = chunk.choices[0].delta.get("content", "")
                 response += token
                 truncated_response = truncate_response(response, max_response_words)
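One caveat on the kept token-extraction line: current huggingface_hub examples read the streamed delta as an attribute (chunk.choices[0].delta.content), so dict-style .get() may fail on some client versions. A version-tolerant rewrite of the streaming loop, offered as an assumption rather than what the Space ships (stream=True is also assumed, since that argument falls between the visible lines):

response = ""
for chunk in client.chat_completion(
    messages,
    max_tokens=max_tokens,
    stream=True,
    temperature=temperature,
    top_p=top_p,
):
    delta = chunk.choices[0].delta
    token = getattr(delta, "content", None)        # attribute access per current examples
    if token is None and isinstance(delta, dict):  # fall back for dict-style deltas
        token = delta.get("content", "")
    response += token or ""
    truncated_response = truncate_response(response, max_response_words)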
@@ -159,7 +154,6 @@ def respond(
             print(f"An error occurred: {e}")
             return "I'm sorry, I encountered an error. Please try again."
 
-
 # ------------------------------------------------------------------------------
 # Optional Initial Message and Gradio Interface
 # ------------------------------------------------------------------------------
@@ -167,7 +161,7 @@ initial_user_message = (
     "I’m sorry you’ve been feeling overwhelmed. Could you tell me more about your arguments with your partner and how that’s affecting you?"
 )
 
-#
+# Remove chatbot_kwargs (unsupported in the current ChatInterface) to avoid error.
 demo = gr.ChatInterface(
     fn=respond,
     additional_inputs=[
@@ -179,9 +173,7 @@ demo = gr.ChatInterface(
     ],
     title="Patient Interview Practice Chatbot",
     description="Simulate a patient interview. You (the user) act as the doctor, and the chatbot replies with the patient's perspective only.",
-    chatbot_kwargs={"type": "messages"},
 )
 
 if __name__ == "__main__":
-    # In Spaces, do not set share=True.
     demo.launch()
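On the removed line: gr.ChatInterface has no chatbot_kwargs parameter, which is what the new comment above it alludes to. If message-style history is still wanted, recent Gradio releases accept a type argument on the constructor itself; whether the Space's pinned Gradio version supports it is an assumption, and it also changes the history passed to respond from tuples to role/content dicts, so the helper functions would need adjusting too:

demo = gr.ChatInterface(
    fn=respond,
    type="messages",  # assumption: available on newer Gradio; history then arrives as dicts
    # additional_inputs trimmed for brevity in this sketch
    title="Patient Interview Practice Chatbot",
    description="Simulate a patient interview. You (the user) act as the doctor, and the chatbot replies with the patient's perspective only.",
)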