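# app.py — Gradio chat UI for the thrishala/mental_health_chatbot model,
# intended to run as a Hugging Face Space.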
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_id = "thrishala/mental_health_chatbot"
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device,  # Place the model on the detected device
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        # max_memory keys are GPU indices (e.g. 0) or "cpu", not "cuda"
        max_memory={0: "15GB"} if device == "cuda" else {"cpu": "15GB"},
        offload_folder="offload",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 512  # Cap the tokenizer's maximum sequence length
except Exception as e:
    print(f"Error loading model: {e}")
    exit()
def generate_text(prompt, max_new_tokens=128):
    # Truncate to tokenizer.model_max_length (512) so overlong chat
    # histories do not exceed what the model can handle
    input_ids = tokenizer.encode(
        prompt, return_tensors="pt", truncation=True
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding; set True to sample instead
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens so the reply
    # does not echo the prompt back to the user
    generated_text = tokenizer.decode(
        output[0][input_ids.shape[-1]:], skip_special_tokens=True
    )
    return generated_text
def respond(message, history, system_message, max_tokens):
    # Rebuild the conversation as a plain-text prompt from the
    # (user, assistant) pairs that gr.ChatInterface passes as history
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"
    try:
        bot_response = generate_text(prompt, max_tokens)
        yield bot_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=128, step=10, label="Max new tokens"),
    ],
)
if __name__ == "__main__":
    demo.launch()
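# A minimal smoke test of the prompt format, assuming the model loaded;
# the "User:"/"Assistant:" turn markers mirror respond() above:
#
#   print(generate_text(
#       "You are a friendly and helpful mental health chatbot.\n"
#       "User: Hello\nAssistant:",
#       max_new_tokens=32,
#   ))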