File size: 1,783 Bytes
6dae430
02ce4ee
9afc14b
6dae430
9afc14b
 
 
6dae430
9afc14b
02ce4ee
6dae430
 
 
9afc14b
6dae430
 
 
 
9afc14b
02ce4ee
9afc14b
 
 
6dae430
 
9afc14b
 
6dae430
9afc14b
 
 
 
 
6dae430
 
9afc14b
 
6dae430
9afc14b
 
6dae430
9afc14b
6dae430
9afc14b
6dae430
 
 
9afc14b
02ce4ee
6dae430
 
 
 
 
 
 
 
 
 
 
9afc14b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
import os
from transformers import pipeline, AutoTokenizer

# Hugging Face Hub repository id of the chat model served by this app.
MODEL_ID = "explorewithai/Loxa-4B"

# Load the tokenizer once and hand that same instance to the pipeline, so the
# chat template used to build prompts in `respond` and the tokenizer the
# pipeline encodes those prompts with are guaranteed to be the same object.
# `trust_remote_code=True` is passed to both loaders so they resolve the
# repository identically (previously only the pipeline received it).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
pipe = pipeline(
    "text-generation",
    model=MODEL_ID,
    tokenizer=tokenizer,
    trust_remote_code=True,
)

# System prompt is read from the MEO environment variable; None when unset.
meo_system = os.environ.get("MEO")

def _build_messages(message, history, system_prompt):
    """Convert one Gradio chat turn plus its history into chat-template messages.

    Args:
        message: The user's new message text.
        history: Prior conversation, either as ``[user, bot]`` pairs (Gradio
            "tuples" format) or as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format).
        system_prompt: Optional system prompt; omitted when falsy.

    Returns:
        A list of ``{"role", "content"}`` dicts ending with the new user turn.
    """
    messages = []
    # Only emit a system turn when a prompt is configured. An unset MEO
    # variable yields None, and a {"content": None} entry makes chat
    # templates fail or render the literal text "None".
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    for turn in history:
        if isinstance(turn, dict):
            # "messages"-style history: keep only the keys templates expect.
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        # "tuples"-style history: either half of a pair can be None (e.g. a
        # bot message with no preceding user text), so skip empty halves.
        user_msg, bot_msg = turn
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})

    messages.append({"role": "user", "content": message})
    return messages


def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply for one chat turn.

    Called by ``gr.ChatInterface`` with the new message, the prior history,
    and the values of the ``additional_inputs`` sliders, in order.

    Args:
        message: The user's new message text.
        history: Prior turns, as ``[user, bot]`` pairs or role/content dicts.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated reply text (the prompt is not echoed).
    """
    messages = _build_messages(message, history, meo_system)

    # add_generation_prompt=True appends the template's assistant-turn header,
    # so the model starts writing the reply instead of continuing the user's
    # message or emitting its own turn marker.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    outputs = pipe(
        prompt,
        # Slider values may arrive as floats; a token budget must be an int.
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        return_full_text=False,  # only the generated continuation, not the prompt
    )
    return outputs[0]["generated_text"]

# Chat UI definition. Gradio forwards the current value of each entry in
# `additional_inputs`, in list order, to `respond` after (message, history):
# max_tokens, then temperature, then top_p.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=512,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()