import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Pick the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "thrishala/mental_health_chatbot"

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device,  # Place the whole model on the selected device
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # fp16 only on GPU
        low_cpu_mem_usage=True,
        # max_memory keys must be GPU indices or "cpu", not the string "cuda"
        max_memory={0: "15GB"} if device == "cuda" else {"cpu": "15GB"},
        offload_folder="offload",  # Spill weights to disk if memory runs short
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 512  # Cap the tokenizer's maximum sequence length

except Exception as e:
    print(f"Error loading model: {e}")
    exit(1)

def generate_text(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding; set True to sample instead
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # Silence the missing-pad-token warning
        )

    # Decode only the newly generated tokens so the reply does not echo the prompt
    generated_text = tokenizer.decode(
        output[0][input_ids.shape[-1]:], skip_special_tokens=True
    )
    return generated_text.strip()
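
# Optional quick smoke test (illustrative sketch): generate_text can be called
# directly before launching the Gradio UI, e.g.:
#     reply = generate_text("User: Hello, how are you?\nAssistant:", max_new_tokens=32)
#     print(reply)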

def respond(message, history, system_message, max_tokens):
    # Rebuild the conversation as a single prompt: system message first, then
    # alternating User/Assistant turns from the chat history
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        bot_response = generate_text(prompt, max_tokens)
        yield bot_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()