import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
# Load model and tokenizer; use half precision on GPU to halve memory,
# full precision on CPU (float16 on CPU is slow and poorly supported)
model_name = "arshiaafshani/Arsh-llm"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
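# If the Space crashes at load time from running out of memory (one common
# cause of a "Runtime error" badge), sharding/offloading the weights is an
# option. This is a sketch, not part of the original app, and it assumes the
# accelerate package is installed:
#
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.float16,
#     device_map="auto",  # let accelerate place layers across GPU/CPU
# )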
# Create the text-generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
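# Streaming variant (a sketch, not in the original app): transformers'
# TextIteratorStreamer yields text as it is generated, which respond() could
# forward to the UI by becoming a generator. Assumes the same model and
# tokenizer as above.
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def stream_reply(prompt):
#     streamer = TextIteratorStreamer(
#         tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     Thread(target=model.generate,
#            kwargs=dict(**inputs, max_new_tokens=256, streamer=streamer)).start()
#     partial = ""
#     for new_text in streamer:
#         partial += new_text
#         yield partial  # Gradio re-renders the chatbot on each yield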
def respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
    # Build the prompt from the system message plus all prior turns, so the
    # model sees the whole conversation rather than only the latest message
    history_text = "".join(f"User: {u}\nAssistant: {a}\n" for u, a in chat_history)
    prompt = f"{system_message}\n\n{history_text}User: {message}\nAssistant:"
    # Generate response
    output = pipe(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repeat_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Keep only the text generated after the final "Assistant:" marker
    response = output[0]["generated_text"].split("Assistant:")[-1].strip()
    # Update chat history
    chat_history.append((message, response))
    return chat_history
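# Alternative prompt construction (a sketch, assuming the model's tokenizer
# ships a chat template): apply_chat_template formats the conversation the
# way the model was trained on, instead of the hand-rolled format above.
#
# def build_prompt(message, chat_history, system_message):
#     messages = [{"role": "system", "content": system_message}]
#     for user_msg, bot_msg in chat_history:
#         messages.append({"role": "user", "content": user_msg})
#         messages.append({"role": "assistant", "content": bot_msg})
#     messages.append({"role": "user", "content": message})
#     return tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )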
with gr.Blocks() as demo:
    gr.Markdown("# Arsh-LLM Demo")
    with gr.Row():
        with gr.Column():
            system_msg = gr.Textbox(
                "You are Arsh, a helpful assistant by Arshia Afshani. You should answer the user carefully.",
                label="System Message",
            )
            max_tokens = gr.Slider(1, 4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            top_k = gr.Slider(0, 100, value=40, step=1, label="Top-k")
            # repetition_penalty must be strictly positive (1.0 = no penalty),
            # so the slider starts at 1.0 rather than 0.0
            repeat_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear")
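    # Note (version assumption, not in the source): the (user, bot) tuple
    # history used here matches Gradio 4's default Chatbot; Gradio 5 defaults
    # to type="messages", so pin the Gradio version accordingly.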
    def submit_message(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
        chat_history = chat_history or []
        updated_history = respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty)
        # Clear the input box and return the updated history
        return "", updated_history
    msg.submit(
        submit_message,
        [msg, chatbot, system_msg, max_tokens, temperature, top_p, top_k, repeat_penalty],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    # share=True is not supported on Hugging Face Spaces and is ignored there,
    # so it is omitted; the Space serves directly on 0.0.0.0:7860
    demo.launch(server_name="0.0.0.0", server_port=7860)
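# A minimal requirements.txt for this Space (an assumption; exact pins are
# not given in the source) would need at least:
#
#   gradio
#   torch
#   transformers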