import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_id = "thrishala/mental_health_chatbot"
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={device: "15GB"},
        offload_folder="offload",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 512

    # Dummy warm-up call so the first user request does not pay the one-time setup cost.
    dummy_input = tokenizer("This is a test.", return_tensors="pt").to(model.device)
    model.generate(input_ids=dummy_input.input_ids, max_new_tokens=1)
except Exception as e:
    print(f"Error loading model: {e}")
    exit()
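
# One-shot helper: greedily generate up to max_new_tokens and return the decoded text
# (prompt included). Not called by the Gradio handler below, but useful for quick
# manual checks, e.g. (hypothetical prompt):
#   generate_text("User: I can't sleep lately.\nAssistant:", max_new_tokens=64)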
def generate_text(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,  # Needed so the output exposes .sequences
        )

    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    return generated_text
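
# Streaming helper: produces the reply one token at a time by calling generate() with
# max_new_tokens=1 in a loop and appending each new token to the running input_ids.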
def generate_text_streaming(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,
            )

            # The newly generated token is the last id of the returned sequence.
            new_token_id = output.sequences[:, -1:]
            generated_token = tokenizer.decode(new_token_id[0], skip_special_tokens=True)
            yield generated_token

            # Feed the new token back in and stop once the model emits end-of-sequence.
            input_ids = torch.cat([input_ids, new_token_id], dim=-1)
            if new_token_id[0, 0].item() == tokenizer.eos_token_id:
                break
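
# Note: looping over single-token generate() calls re-runs the model over the full
# prefix on every step. A lighter-weight alternative (a sketch, not what this file
# uses) is transformers' TextIteratorStreamer with generation in a background thread:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   def stream_with_streamer(prompt, max_new_tokens=128):
#       inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#       streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#       kwargs = dict(**inputs, max_new_tokens=max_new_tokens, do_sample=False, streamer=streamer)
#       Thread(target=model.generate, kwargs=kwargs).start()
#       for text_chunk in streamer:
#           yield text_chunk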
def respond(message, history, system_message, max_tokens):
    # Build a plain-text prompt from the system message and the chat history so far.
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    response = ""
    try:
        for token in generate_text_streaming(prompt, max_tokens):
            # ChatInterface redraws the reply on every yield, so yield the accumulated text.
            response += token
            yield response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max new tokens"),
    ],
)
if __name__ == "__main__":
    demo.launch()
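
# Running this file directly (e.g. `python app.py`) serves the interface on a local URL;
# on a Hugging Face Space the same entry point is started automatically.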