import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
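# Gradio chat app that streams replies from the causal language model
# "thrishala/mental_health_chatbot", one token at a time.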

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_id = "thrishala/mental_health_chatbot"

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device,  # Use the determined device
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # fp16 is poorly supported on CPU
        low_cpu_mem_usage=True,
        max_memory={0: "15GB"} if device == "cuda" else {"cpu": "15GB"},  # keys are GPU indices or "cpu"
        offload_folder="offload",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 512  # Set maximum length

except Exception as e:
    print(f"Error loading model: {e}")
    exit()

def generate_text(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Or True for sampling
            eos_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

def generate_text_streaming(prompt, max_new_tokens=128):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=1,  # Generate only 1 new token at a time
                do_sample=False,  # Or True for sampling
                eos_token_id=tokenizer.eos_token_id,
                return_dict_in_generate=True,  # Return a ModelOutput with .sequences
            )

            new_token_id = output.sequences[:, -1:]  # The single freshly generated token

            if new_token_id[0, 0] == tokenizer.eos_token_id:  # Stop at end-of-sequence
                break

            # Note: decoding one id at a time can drop leading whitespace for some tokenizers
            generated_token = tokenizer.decode(new_token_id[0], skip_special_tokens=True)
            yield generated_token  # Yield only the new token

            input_ids = torch.cat([input_ids, new_token_id], dim=-1)  # Append it for the next step
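# Alternative sketch (not wired into the app below): transformers ships a
# TextIteratorStreamer that streams decoded text from a single generate() call,
# instead of re-running generate() once per token as above. Shown only as a
# hedged illustration of the same streaming idea.
def generate_text_streaming_with_streamer(prompt, max_new_tokens=128):
    from threading import Thread

    from transformers import TextIteratorStreamer

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread; the streamer is fed as tokens arrive.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        ),
    )
    thread.start()

    for text in streamer:  # Yields decoded text chunks as they are generated
        yield text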

def respond(message, history, system_message, max_tokens):
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    response = ""
    try:
        # ChatInterface replaces the displayed message with each yielded value,
        # so yield the accumulated response rather than individual tokens.
        for token in generate_text_streaming(prompt, max_tokens):
            response += token
            yield response

    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=32, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()
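
# Note: device_map, max_memory and offload_folder in from_pretrained rely on the
# `accelerate` package being installed alongside gradio, transformers, and torch.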