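"""Gradio chat demo for arshiaafshani/Arsh-llm.

Loads the model with transformers, wraps it in a text-generation
pipeline, and exposes a chat UI with sampling controls.
"""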
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

# Load model and tokenizer (fp16 on GPU to halve memory, fp32 on CPU)
model_name = "arshiaafshani/Arsh-llm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

# Create a text-generation pipeline (device 0 = first GPU, -1 = CPU)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1
)
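
# Note: on multi-GPU machines, loading the model with device_map="auto"
# (via the accelerate package) could replace the manual device selection
# above.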

def respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    """Generate a reply for `message` and append the turn to `chat_history`."""
    # Build a plain-text prompt in a simple User/Assistant transcript format
    prompt = f"{system_message}\n\nUser: {message}\nAssistant:"

    # Sample a continuation from the model
    output = pipe(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repeat_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # The pipeline returns prompt + completion; keep only the text after the
    # final "Assistant:" marker and cut off any hallucinated follow-up turn
    response = output[0]['generated_text'].split("Assistant:")[-1]
    response = response.split("User:")[0].strip()

    # Append the (user, assistant) pair in the Chatbot tuple format
    chat_history.append((message, response))
    return chat_history
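
# If the tokenizer ships a chat template, building the prompt with
# tokenizer.apply_chat_template would be more robust than the hand-rolled
# transcript above (a sketch, assuming this model provides a template):
#
#   messages = [{"role": "system", "content": system_message},
#               {"role": "user", "content": message}]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True)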

with gr.Blocks() as demo:
    gr.Markdown("# Arsh-LLM Demo")

    with gr.Row():
        with gr.Column():
            system_msg = gr.Textbox(
                "You are Arsh, a helpful assistant by Arshia Afshani. You should answer the user carefully.",
                label="System Message",
            )
            max_tokens = gr.Slider(1, 4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            top_k = gr.Slider(0, 100, value=40, step=1, label="Top-k")
            repeat_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="Repetition Penalty")  # must be > 0

    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear")

    def submit_message(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
        # Clear the input box and hand back the updated history
        chat_history = chat_history or []
        new_history = respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty)
        return "", new_history

    msg.submit(
        submit_message,
        [msg, chatbot, system_msg, max_tokens, temperature, top_p, top_k, repeat_penalty],
        [msg, chatbot]
    )
    clear.click(lambda: None, None, chatbot, queue=False)  # reset the chat display

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
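
# To run locally (assuming this file is saved as app.py):
#   pip install gradio transformers torch
#   python app.py
# The UI is served at http://localhost:7860; share=True additionally
# creates a temporary public gradio.live link.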