# triAGI-Coder / app.py — acecalisto3 (commit 09f8eba)
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM
import os
import json
import time
import logging
from threading import Lock
# Configure root logging once at import time; all module code below logs
# through the root logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# NOTE(review): the original eagerly built a Mixtral-8x7B-Instruct
# text-generation pipeline here at import time, ran one throwaway
# "Who are you?" prompt, and discarded the result. That downloads and loads
# a ~47B-parameter model the app never uses (EnhancedChatbot below loads its
# own model), so the dead code is removed.
class EnhancedChatbot:
    """Thread-safe wrapper around a text-generation model with a JSON-persisted config.

    Generation calls are serialized through ``model_lock`` so concurrent
    Gradio requests do not interleave on the model.
    """

    # NOTE(review): the original referenced module-level CONFIG_FILE,
    # MODEL_NAME and CACHE_DIR that were never defined anywhere in the file
    # (NameError at runtime). They are supplied here as class-level defaults
    # instead; instances may override them by assignment.
    CONFIG_FILE = "config.json"
    DEFAULT_MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    CACHE_DIR = "model_cache"

    def __init__(self):
        self.model = None                 # set by load_model()
        self.config = self.load_config()
        self.model_lock = Lock()          # serializes generation calls
        self.load_model()

    def load_config(self):
        """Return the persisted config dict, or built-in defaults if no file exists."""
        if os.path.exists(self.CONFIG_FILE):
            with open(self.CONFIG_FILE, 'r') as f:
                return json.load(f)
        return {
            "model_name": self.DEFAULT_MODEL_NAME,
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.95,
            "system_message": "You are a friendly and helpful AI assistant.",
            "gpu_layers": 0
        }

    def save_config(self):
        """Persist the current config to disk as pretty-printed JSON."""
        with open(self.CONFIG_FILE, 'w') as f:
            json.dump(self.config, f, indent=2)

    def load_model(self):
        """(Re)load the text-generation pipeline named in the config.

        Logs and re-raises any loading error so callers can react.
        """
        try:
            # NOTE(review): the original passed ctransformers-style kwargs
            # (model_type="llama", gpu_layers=...) to transformers'
            # AutoModelForCausalLM.from_pretrained, which rejects them, and
            # later called the bare model object like a pipeline. Use a real
            # text-generation pipeline instead. "gpu_layers" stays in the
            # config for backward compatibility but is not used here.
            self.model = pipeline(
                "text-generation",
                model=self.config["model_name"],
                model_kwargs={"cache_dir": self.CACHE_DIR},
            )
            logging.info(f"Model loaded successfully: {self.config['model_name']}")
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise

    def generate_response(self, message, history, system_message, max_tokens, temperature, top_p):
        """Build a chat prompt from the (user, assistant) history and return the reply text.

        history is a sequence of (user_message, assistant_message) pairs, as
        supplied by gr.ChatInterface.
        """
        prompt = f"{system_message}\n\n"
        for user_msg, assistant_msg in history:
            prompt += f"Human: {user_msg}\nAssistant: {assistant_msg}\n"
        prompt += f"Human: {message}\nAssistant: "
        start_time = time.time()
        with self.model_lock:
            outputs = self.model(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                return_full_text=False,  # return only the newly generated reply
            )
        response_time = time.time() - start_time
        logging.info(f"Response generated in {response_time:.2f} seconds")
        # The pipeline returns [{"generated_text": ...}]; the original called
        # .strip() on the raw call result, which is not a string.
        return outputs[0]["generated_text"].strip()
# Single shared chatbot instance used by every Gradio callback below.
# NOTE: constructing it loads the model at import time.
chatbot = EnhancedChatbot()
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio ChatInterface callback: yield the assistant's reply.

    Yields a single complete response. Any error is logged with its traceback
    and turned into a user-facing apology instead of crashing the UI.
    """
    try:
        response = chatbot.generate_response(message, history, system_message, max_tokens, temperature, top_p)
        yield response
    except Exception:
        # logging.exception records the full traceback; the original
        # logging.error(f"... {str(e)}") dropped it.
        logging.exception("Error generating response")
        yield "I apologize, but I encountered an error while processing your request. Please try again."
def update_model_config(model_name, gpu_layers):
    """Persist a new model name / GPU-layer count and hot-reload the model.

    If the new model fails to load, the previous config is restored and
    reloaded so a typo'd model name does not leave the app broken; the
    returned status string reports the failure to the UI.
    """
    previous = dict(chatbot.config)  # snapshot for rollback
    chatbot.config["model_name"] = model_name
    chatbot.config["gpu_layers"] = int(gpu_layers)  # Gradio sliders emit floats
    chatbot.save_config()
    try:
        chatbot.load_model()
    except Exception as e:
        chatbot.config = previous
        chatbot.save_config()
        chatbot.load_model()  # previous model was loadable before
        return f"Failed to load model {model_name}: {e}. Reverted to previous settings."
    return f"Model updated to {model_name} with {gpu_layers} GPU layers."
def update_system_message(system_message):
    """Persist a new system message (prefixed to every generated prompt)."""
    chatbot.config["system_message"] = system_message
    # Original had a stray space: `chatbot .save_config()`.
    chatbot.save_config()
    return f"System message updated: {system_message}"
# Build the Gradio UI: a chat tab plus a settings tab for model/system-message
# updates. Runs at import time; launched from the __main__ guard below.
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced AI Chatbot")

    with gr.Tab("Chat"):
        chatbot_interface = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value=chatbot.config["system_message"], label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=chatbot.config["max_tokens"], step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=chatbot.config["temperature"], step=0.1, label="Temperature"),
                gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=chatbot.config["top_p"],
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                ),
            ],
        )

    with gr.Tab("Settings"):
        with gr.Group():
            gr.Markdown("### Model Settings")
            model_name_input = gr.Textbox(value=chatbot.config["model_name"], label="Model name")
            gpu_layers_input = gr.Slider(minimum=0, maximum=8, value=chatbot.config["gpu_layers"], step=1, label="GPU layers")
            update_model_button = gr.Button("Update model")
            # NOTE(review): the original passed outputs="text", which is not a
            # Gradio component; route the callback's status string into a
            # real, read-only Textbox instead.
            model_update_status = gr.Textbox(label="Status", interactive=False)
            update_model_button.click(update_model_config, inputs=[model_name_input, gpu_layers_input], outputs=model_update_status)

        with gr.Group():
            gr.Markdown("### System Message Settings")
            system_message_input = gr.Textbox(value=chatbot.config["system_message"], label="System message")
            update_system_message_button = gr.Button("Update system message")
            system_update_status = gr.Textbox(label="Status", interactive=False)
            update_system_message_button.click(update_system_message, inputs=[system_message_input], outputs=system_update_status)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()