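# Gradio chat UI for the thrishala/mental_health_chatbot model: the model is
# loaded once on CPU and served through gr.ChatInterface with a configurable
# system message and generation length.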
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "thrishala/mental_health_chatbot"  # Hugging Face Hub model repo

try:
    # Load the model entirely on CPU. float16 halves memory use, but some CPU
    # ops lack float16 kernels; fall back to torch.float32 if generation
    # raises dtype errors.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={"cpu": "15GB"},
        offload_folder="offload",
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = 512  # Truncate over-long prompts instead of erroring
    
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        num_return_sequences=1,
        do_sample=False,  # Greedy decoding for deterministic replies
        truncation=True,
        max_new_tokens=128,  # Default cap; overridden per request in respond()
    )

except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)

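# Chat callback for gr.ChatInterface: rebuilds the conversation as a plain-text
# prompt, generates a completion, and yields only the new assistant turn.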
def respond(
    message,
    history,
    system_message,
    max_tokens,
):
    # history arrives as (user, assistant) pairs, Gradio's default tuple format.
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"
    
    try:
        response = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,  # Stop at the end-of-sequence token
        )[0]["generated_text"]

        # The pipeline returns the prompt plus the completion, so slice off the
        # prompt and keep only the text before any hallucinated "User:" turn.
        bot_response = response[len(prompt):].split("User:")[0].strip()
        yield bot_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."

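# The additional inputs are passed to respond() after (message, history),
# in the order they are listed here.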
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        # Values near the top of this range may exceed the model's context
        # window once the prompt is included.
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
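    # Pass share=True to launch() for a temporary public URL while testing.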
    demo.launch()