import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_id = "thrishala/mental_health_chatbot"
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device,  # Place the model on the determined device
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={0: "15GB"} if device == "cuda" else {"cpu": "15GB"},  # max_memory keys are GPU indices or "cpu"
        offload_folder="offload",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 512  # Cap the tokenizer's maximum sequence length
except Exception as e:
    print(f"Error loading model: {e}")
    exit()
def generate_text(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Or True for sampling
            eos_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
def generate_text_streaming(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=1,  # Generate only 1 new token at a time
                do_sample=False,  # Or True for sampling
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,  # Return an output object with .sequences
            )

            new_token_id = output.sequences[:, -1:]  # The single token generated in this step
            generated_token = tokenizer.decode(new_token_id[0], skip_special_tokens=True)
            yield generated_token  # Yield the newly generated token

            input_ids = torch.cat([input_ids, new_token_id], dim=-1)  # Append the new token to the input

            if new_token_id.item() == tokenizer.eos_token_id:  # Stop once the end-of-sequence token is generated
                break
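# Note (alternative, not used here): transformers also provides TextIteratorStreamer,
# which streams tokens from a single generate() call run in a background thread. The
# per-token loop above is simpler to follow but re-runs the whole prompt through the
# model on every step.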
def respond(message, history, system_message, max_tokens):
    # Build a plain-text prompt from the system message and the chat history
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    response = ""
    try:
        for token in generate_text_streaming(prompt, max_tokens):  # Iterate over the generator
            response += token
            yield response  # ChatInterface expects the accumulated response so far on each yield
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max new tokens"),
    ],
)
if __name__ == "__main__":
    demo.launch()
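
# Optional console smoke test (a sketch, assuming the model above loaded successfully);
# left commented out so the Space's entry point stays unchanged:
#
#     test_prompt = "You are a friendly and helpful mental health chatbot.\nUser: Hello\nAssistant:"
#     for piece in generate_text_streaming(test_prompt, max_new_tokens=32):
#         print(piece, end="", flush=True)
#     print()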