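"""Gradio chat app for a Hugging Face Space.

Wraps the "thrishala/mental_health_chatbot" causal LM in a CPU-only
text-generation pipeline and exposes it through gr.ChatInterface.
"""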
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "thrishala/mental_health_chatbot"
try:
    # Load the model on CPU in float16, capping RAM use and offloading
    # any overflow weights to disk.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={"cpu": "15GB"},
        offload_folder="offload",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 256  # Cap the prompt length fed to the model

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        num_return_sequences=1,
        do_sample=False,  # Greedy decoding for reproducible answers
        truncation=True,
        max_new_tokens=128,
    )
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)
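# Optional smoke test (a sketch; uncomment to verify generation works before
# launching the UI). The exact output depends on the model.
# print(pipe("User: Hello\nAssistant:", max_new_tokens=32)[0]["generated_text"])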
def respond(message, history, system_message, max_tokens):
    # Rebuild the conversation as a plain-text prompt. With
    # gr.Chatbot(type="messages"), history is a chronological list of
    # {"role": ..., "content": ...} dicts, not (user, bot) pairs.
    prompt = f"{system_message}\n"
    for msg in history:
        role = "User" if msg["role"] == "user" else "Assistant"
        prompt += f"{role}: {msg['content']}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        response = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )[0]["generated_text"]
        # The pipeline returns the prompt plus the completion; keep only the
        # text after the final "Assistant:" marker.
        bot_response = response.split("Assistant:")[-1].strip()
        # A ChatInterface function yields just the assistant reply; Gradio
        # manages the displayed history itself.
        yield bot_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred during generation."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
    ],
    chatbot=gr.Chatbot(type="messages"),  # Use the role/content messages format
)
if __name__ == "__main__":
    demo.launch()
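# To run locally (assuming gradio, transformers, torch, and accelerate are
# installed): `python app.py`. Gradio serves on http://127.0.0.1:7860 by default.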