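# TinyRP Chat: a minimal Gradio app for chatting with the DarwinAnim8or/TinyRP
# model on CPU, with a few roleplay character presets.
# To run locally (assuming gradio, transformers and torch are installed):
#   python app.py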
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Model configuration
MODEL_NAME = "DarwinAnim8or/TinyRP"

# Global variables for model
tokenizer = None
model = None
def load_model():
    """Load model and tokenizer"""
    global tokenizer, model
    try:
        print("Loading model for CPU inference...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Truncate from the left so the most recent turns survive when the
        # prompt exceeds max_length in chat_respond()
        tokenizer.truncation_side = "left"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,  # full precision is the safe default on CPU
            device_map="cpu",
            trust_remote_code=True
        )
        print(f"✅ Model loaded successfully: {MODEL_NAME}")
        return True
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return False
# Sample character presets
CHARACTERS = {
    "Custom Character": "",
    "Adventurous Knight": "You are Sir Gareth, a brave and noble knight on a quest to save the kingdom. You speak with honor and courage, always ready to help those in need.",
    "Mysterious Wizard": "You are Eldara, an ancient and wise wizard who speaks in riddles and knows secrets of the mystical arts. You are helpful but often cryptic.",
    "Friendly Tavern Keeper": "You are Bram, a cheerful tavern keeper who loves telling stories and meeting new travelers. Your tavern is a warm, welcoming place.",
    "Curious Scientist": "You are Dr. Maya Chen, a brilliant scientist fascinated by discovery. You explain complex concepts simply and love new experiments.",
    "Space Explorer": "You are Captain Nova, a fearless space explorer who has traveled to distant galaxies. You're brave, curious, and ready for adventure."
}
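# For reference, chat_respond() below flattens the chat into a ChatML prompt.
# With the knight preset, the text sent to the model looks roughly like this
# (illustrative sketch; the system line comes from the character description box):
#
#   <|im_start|>system
#   You are Sir Gareth, a brave and noble knight...<|im_end|>
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant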
def chat_respond(message, history, character_desc, max_tokens, temperature, top_p, rep_penalty):
    """Main chat response function"""
    if not message.strip():
        return history

    if model is None:
        response = "❌ Model not loaded. Please check the model path."
        history.append([message, response])
        return history

    try:
        # Build ChatML conversation
        conversation = ""

        # Add character as system message
        if character_desc.strip():
            conversation += f"<|im_start|>system\n{character_desc}<|im_end|>\n"

        # Add history
        for user_msg, bot_msg in history:
            conversation += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
            conversation += f"<|im_start|>assistant\n{bot_msg}<|im_end|>\n"

        # Add current message
        conversation += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

        # Tokenize; with truncation_side="left" (set in load_model) the oldest
        # turns are dropped first when the prompt exceeds 900 tokens
        inputs = tokenizer.encode(conversation, return_tensors="pt", max_length=900, truncation=True)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=rep_penalty,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode, keeping special tokens so we can split on the ChatML markers
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=False)

        # Extract the text after the final assistant marker
        if "<|im_start|>assistant\n" in full_text:
            response = full_text.split("<|im_start|>assistant\n")[-1]
            response = response.replace("<|im_end|>", "").strip()
        else:
            response = "Sorry, couldn't generate a response."

        # Strip any stray ChatML markers the model may have produced
        response = response.replace("<|im_start|>", "").replace("<|im_end|>", "")
        response = response.strip()

        if not response:
            response = "No response generated."
    except Exception as e:
        response = f"Error: {str(e)}"

    # Add to history
    history.append([message, response])
    return history
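# Quick smoke test without the UI (hypothetical usage; assumes the model
# loaded successfully):
#
#   history = chat_respond("Hello there!", [], CHARACTERS["Adventurous Knight"],
#                          max_tokens=80, temperature=0.9, top_p=0.85, rep_penalty=1.1)
#   print(history[-1][1])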
def load_character(character_name):
    """Load character preset"""
    return CHARACTERS.get(character_name, "")

def clear_chat():
    """Clear chat history"""
    return []

# Load model on startup
model_loaded = load_model()
# Create interface
with gr.Blocks(title="TinyRP Chat") as demo:
    gr.Markdown("# 🎭 TinyRP Character Chat")
    gr.Markdown("Chat with AI characters using local CPU inference!")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500, label="Conversation")
            msg_box = gr.Textbox(label="Message", placeholder="Type here...")

        with gr.Column(scale=1):
            gr.Markdown("### Character")
            char_dropdown = gr.Dropdown(
                choices=list(CHARACTERS.keys()),
                value="Custom Character",
                label="Preset"
            )
            char_text = gr.Textbox(
                label="Description",
                lines=4,
                placeholder="Character description..."
            )
            load_btn = gr.Button("Load Character")

            gr.Markdown("### Settings")
            max_tokens = gr.Slider(16, 256, 80, label="Max tokens")
            temperature = gr.Slider(0.1, 2.0, 0.9, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, 0.85, label="Top-p")
            rep_penalty = gr.Slider(1.0, 1.5, 1.1, label="Rep penalty")
            clear_btn = gr.Button("Clear Chat")

    # Character samples
    gr.Markdown("### Sample Characters")
    with gr.Row():
        for name in ["Adventurous Knight", "Mysterious Wizard", "Space Explorer"]:
            gr.Markdown(f"**{name}**: {CHARACTERS[name][:80]}...")
    # Event handlers - run the chat, then clear the input box
    msg_box.submit(
        fn=chat_respond,
        inputs=[msg_box, chatbot, char_text, max_tokens, temperature, top_p, rep_penalty],
        outputs=[chatbot]
    ).then(
        fn=lambda: "",
        outputs=[msg_box]
    )

    load_btn.click(
        fn=load_character,
        inputs=[char_dropdown],
        outputs=[char_text]
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot]
    )

if __name__ == "__main__":
    demo.launch()