Arsh-llm-demo / app.py
arshiaafshani's picture
Update app.py
2524cd0 verified
raw
history blame
2.63 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
# Hub checkpoint that backs this demo.
model_name = "arshiaafshani/Arsh-llm"

# Probe for CUDA once: half precision on the first GPU when one is available,
# otherwise full precision on the CPU (device index -1).
_cuda_available = torch.cuda.is_available()
_model_dtype = torch.float16 if _cuda_available else torch.float32
_pipeline_device = 0 if _cuda_available else -1

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=_model_dtype)

# Text-generation pipeline shared by every chat request.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=_pipeline_device,
)
def respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    """Generate one assistant reply and append the exchange to the chat history.

    A single-turn prompt is built from ``system_message`` and ``message`` and
    sampled with the module-level ``pipe``. Earlier turns stored in
    ``chat_history`` are not fed back to the model.

    Args:
        message: The user's latest message.
        chat_history: List of ``(user_message, assistant_reply)`` tuples;
            mutated in place.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repeat_penalty: Repetition penalty; non-positive values are treated
            as 1.0 (no penalty).

    Returns:
        ``chat_history`` with the new ``(message, response)`` pair appended.
    """
    prompt = f"{system_message}\n\nUser: {message}\nAssistant:"

    # UI sliders may hand over whole numbers as ``int``; transformers'
    # temperature and repetition-penalty processors require strictly positive
    # floats, so coerce explicitly instead of passing the raw values through.
    penalty = float(repeat_penalty)
    if penalty <= 0:
        penalty = 1.0  # 1.0 is the neutral value, i.e. no penalty applied.

    output = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # Return only the continuation. Splitting the echoed prompt on
        # "Assistant:" picks the wrong fragment whenever the reply itself
        # contains that marker.
        return_full_text=False,
    )
    completion = output[0]["generated_text"]

    # The model may keep writing and invent the next "User:" turn; keep only
    # the text that belongs to this assistant reply.
    response = completion.split("\nUser:", 1)[0].strip()

    chat_history.append((message, response))
    return chat_history
# NOTE(review): layout nesting below was reconstructed from a listing that had
# lost its indentation -- confirm the Row/Column membership against the app.
with gr.Blocks() as demo:
    gr.Markdown("# Arsh-LLM Demo")

    # Prompt and sampling controls.
    with gr.Row():
        with gr.Column():
            system_msg = gr.Textbox(
                "You are Arsh, a helpful assistant by Arshia Afshani. You should answer the user carefully.",
                label="System Message",
            )
            max_tokens = gr.Slider(1, 4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            top_k = gr.Slider(0, 100, value=40, step=1, label="Top-k")
            # Minimum is 0.1 rather than 0.0: transformers' repetition-penalty
            # processor only accepts strictly positive values.
            repeat_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="Repetition Penalty")

    # Conversation area.
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear")

    def submit_message(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
        """Handle a submitted message: generate a reply and clear the input box.

        Returns a pair ``(new_textbox_value, new_chat_history)`` matching the
        ``[msg, chatbot]`` outputs wired up below.
        """
        chat_history = chat_history or []
        # Ignore blank submissions instead of running a generation on nothing.
        if not message or not message.strip():
            return "", chat_history
        response = respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty)
        return "", response

    # Pressing Enter in the message box runs generation and refreshes the chat.
    msg.submit(
        submit_message,
        [msg, chatbot, system_msg, max_tokens, temperature, top_p, top_k, repeat_penalty],
        [msg, chatbot],
    )
    # "Clear" resets the chatbot by setting its value to None.
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    # Listen on every network interface at port 7860 and request a public
    # share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)