import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "thrishala/mental_health_chatbot"

try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
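        # Note: float16 halves memory use, but not every CPU kernel supports
        # half precision; fall back to torch.float32 if dtype errors appear.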
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={"cpu": "15GB"},
        offload_folder="offload",
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_id)
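    # Many causal LMs ship without a dedicated pad token; generation below
    # reuses the EOS token via pad_token_id=tokenizer.eos_token_id.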
    tokenizer.model_max_length = 256  # Cap tokenized prompts at 256 tokens
    
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        num_return_sequences=1,
        do_sample=False,
        truncation=True,
        max_new_tokens=128
    )

except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)

def respond(message, history, system_message, max_tokens):
    # Rebuild the full conversation as a plain-text prompt. With
    # type="messages", history is a list of {"role", "content"} dicts.
    prompt = f"{system_message}\n"
    for turn in history:
        role = "User" if turn["role"] == "user" else "Assistant"
        prompt += f"{role}: {turn['content']}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        response = pipe(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )[0]["generated_text"]

        # The pipeline echoes the prompt, so keep only the text after the
        # final "Assistant:" marker.
        bot_response = response.split("Assistant:")[-1].strip()
        yield bot_response

    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred during generation."
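# Illustrative shape of the prompt respond() builds (example values only):
#
#   You are a friendly and helpful mental health chatbot.
#   User: I have trouble sleeping.
#   Assistant: I'm sorry to hear that. ...
#   User: What might help?
#   Assistant: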

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=128, value=128, step=1, label="Max new tokens"),
    ],
    chatbot=gr.Chatbot(type="messages"),  # history arrives as role/content dicts
)
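# gr.ChatInterface treats generator functions as streaming handlers: each
# yield from respond() replaces the in-progress assistant message.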

if __name__ == "__main__":
    demo.launch()
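# When running somewhere without direct port access (e.g., a notebook), a
# temporary public link can be requested instead: demo.launch(share=True)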