import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Use the GPU when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

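# Hugging Face model repository to load.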
model_id = "thrishala/mental_health_chatbot"

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # fp16 only makes sense on GPU
        low_cpu_mem_usage=True,
        max_memory={0: "15GB"} if device == "cuda" else {"cpu": "15GB"},  # max_memory keys are GPU indices or "cpu"
        offload_folder="offload",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 512

    # Warm-up generation so the first user request does not pay the initialization cost.
    dummy_input = tokenizer("This is a test.", return_tensors="pt").to(model.device)
    model.generate(input_ids=dummy_input.input_ids, max_new_tokens=1)

except Exception as e:
    print(f"Error loading model: {e}")
    exit()

def generate_text(prompt, max_new_tokens=128):
    """Non-streaming helper: return the full completion for a prompt."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,  # return a GenerateOutput so .sequences is available
        )

    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    return generated_text

def generate_text_streaming(prompt, max_new_tokens=128):
    """Greedily generate one token at a time and yield the text produced so far."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    generated_ids = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,  # needed to access output.sequences
            )

            # The newly generated token is the last id of the returned sequence.
            new_token_id = output.sequences[:, -1:]
            generated_ids.append(new_token_id.item())

            # Decode the whole generated suffix each step; decoding single tokens
            # in isolation can drop spaces between subword pieces.
            yield tokenizer.decode(generated_ids, skip_special_tokens=True)

            # Feed the new token back in and stop once EOS is produced.
            input_ids = torch.cat([input_ids, new_token_id], dim=-1)

            if new_token_id.item() == tokenizer.eos_token_id:
                break

def respond(message, history, system_message, max_tokens):
    # Build a simple chat-style prompt from the system message and the history.
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        # Yield the accumulated response; ChatInterface replaces the displayed
        # message with each yielded value, which produces the streaming effect.
        for partial_text in generate_text_streaming(prompt, max_tokens):
            yield partial_text
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."

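# Gradio chat UI: the system-message textbox and the max-token slider are passed
# to respond() as additional inputs, in the order they are listed below.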
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()