import gradio as gr
from pathlib import Path
from huggingface_hub import snapshot_download
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, SystemMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
def download_mistral_model():
"""Download Mistral model if not already present."""
print("Checking for Mistral model...")
mistral_models_path = Path.home().joinpath('mistral_models', 'Nemo-Instruct')
# Check if model files already exist
required_files = ["params.json", "consolidated.safetensors", "tekken.json"]
files_exist = all(
mistral_models_path.joinpath(file).exists()
for file in required_files
)
if not files_exist:
print("Downloading Mistral model (this may take a while)...")
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mistralai/Mistral-Nemo-Instruct-2407",
allow_patterns=required_files,
local_dir=mistral_models_path
)
print("Model downloaded successfully!")
else:
print("Mistral model already downloaded.")
return mistral_models_path
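# After a successful download, ~/mistral_models/Nemo-Instruct should contain:
#   params.json              - model hyperparameters
#   consolidated.safetensors - model weights (roughly 24 GB in bf16 for the
#                              12B Nemo model; size is an estimate)
#   tekken.json              - Tekken tokenizer definition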
def setup_mistral():
"""Initialize Mistral model and tokenizer."""
mistral_models_path = download_mistral_model()
print("Initializing Mistral model and tokenizer...")
tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
model = Transformer.from_folder(mistral_models_path)
return model, tokenizer
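# Minimal sketch of driving the model directly, without the Gradio UI. Kept
# in comments so it does not execute on import; assumes setup_mistral()
# succeeds and enough GPU memory is available for the weights:
#
#   model, tokenizer = setup_mistral()
#   request = ChatCompletionRequest(messages=[UserMessage(content="Hello!")])
#   tokens = tokenizer.encode_chat_completion(request).tokens
#   out_tokens, _ = generate(
#       [tokens], model, max_tokens=64, temperature=0.0,
#       eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
#   )
#   print(tokenizer.decode(out_tokens[0]))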
# Global variables for model and tokenizer
global_model = None
global_tokenizer = None
def initialize_globals():
"""Initialize global model and tokenizer if not already done."""
global global_model, global_tokenizer
if global_model is None or global_tokenizer is None:
global_model, global_tokenizer = setup_mistral()
def check_custom_responses(message: str) -> str | None:
    """Check for specific patterns and return a custom response, or None.

    The incoming message is lowercased before matching, so each pattern is
    stored once, in lowercase.
    """
    message_lower = message.lower()
    custom_responses = {
        # "what is your name?" variants, including common misspellings
        "what is ur name?": "xylaria",
        "what's ur name?": "xylaria",
        "whats ur name?": "xylaria",
        "what's your name?": "xylaria",
        "what is your name?": "xylaria",
        "wat is ur name?": "xylaria",
        "wut is ur name?": "xylaria",
        "wut ur name?": "xylaria",
        "wats ur name": "xylaria",
        # "how many r's in strawberry?" variants, including misspellings
        "how many 'r' is in strawberry?": "3",
        "how many r is in strawberry?": "3",
        "how many 'r' in strawberry?": "3",
        "how many r's are in strawberry?": "3",
        "how many rs are in strawberry?": "3",
        "how many rs's are in strawberry?": "3",
        "how many r in strawbary?": "3",
        "how many 'r' in strawbary?": "3",
        "how many r in strawbry?": "3",
        "how many r is in strawbry?": "3",
        "how many 'r' is in strawbry?": "3",
        # "who is your developer?" variants, including misspellings
        "who is your developer?": "sk md saad amin",
        "who is ur developer?": "sk md saad amin",
        "who is ur dev": "sk md saad amin",
        "who is your dev?": "sk md saad amin",
        "who's your developer?": "sk md saad amin",
        "who's ur developer?": "sk md saad amin",
        "who is your devloper?": "sk md saad amin",
        "who is ur devloper": "sk md saad amin",
        "who's ur dev?": "sk md saad amin",
        "who's your dev?": "sk md saad amin",
        "who ur dev?": "sk md saad amin",
        "who's ur devloper?": "sk md saad amin",
        "what is ur dev": "sk md saad amin",
        "whats ur dev?": "sk md saad amin",
    }
    for pattern, response in custom_responses.items():
        if pattern in message_lower:
            return response
    return None
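# Matching is a substring test on the lowercased message, for example:
#   check_custom_responses("Hey, WHAT IS UR NAME?")  -> "xylaria"
#   check_custom_responses("tell me a joke")         -> None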
def is_image_request(message: str) -> bool:
"""Detect if the message is requesting image generation."""
image_triggers = [
"generate an image",
"create an image",
"draw",
"make a picture",
"generate a picture",
"create a picture",
"generate art",
"create art",
"make art",
"visualize",
"show me",
]
message_lower = message.lower()
return any(trigger in message_lower for trigger in image_triggers)
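# Note: these are plain substring checks, so broad triggers such as "draw" or
# "show me" will also fire on messages like "show me the code"; tighten the
# trigger list if false positives become a problem.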
def create_mistral_messages(history, system_message, current_message):
"""Convert chat history to Mistral message format."""
messages = []
# Add system message if provided
if system_message:
messages.append(SystemMessage(content=system_message))
# Add conversation history
for user_msg, assistant_msg in history:
if user_msg:
messages.append(UserMessage(content=user_msg))
if assistant_msg:
messages.append(AssistantMessage(content=assistant_msg))
# Add current message
messages.append(UserMessage(content=current_message))
return messages
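# Example of the structure this builds (assuming Gradio's tuple-style chat
# history, which gr.ChatInterface passed at the time this was written; newer
# Gradio versions default to dict-style messages):
#   create_mistral_messages([("hi", "hello!")], "Be brief.", "how are you?")
#   -> [SystemMessage(content="Be brief."),
#       UserMessage(content="hi"),
#       AssistantMessage(content="hello!"),
#       UserMessage(content="how are you?")]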
def respond(message, history, system_message, max_tokens=16384, temperature=0.7, top_p=0.95):
"""Main response function using Mistral model."""
# First check for custom responses
custom_response = check_custom_responses(message)
if custom_response:
yield custom_response
return
# Check for image requests
if is_image_request(message):
yield "Sorry, image generation is not supported in this implementation."
return
try:
# Initialize global model and tokenizer if needed
initialize_globals()
# Prepare messages for Mistral
mistral_messages = create_mistral_messages(history, system_message, message)
# Create chat completion request
completion_request = ChatCompletionRequest(messages=mistral_messages)
# Encode the request
tokens = global_tokenizer.encode_chat_completion(completion_request).tokens
        # Generate the full response in one call. Note: as of the
        # mistral_inference versions this was written against, generate()
        # accepts max_tokens, temperature, and eos_id but no top_p keyword,
        # so the UI's top_p slider value is accepted here but not forwarded.
        out_tokens, _ = generate(
            [tokens],
            global_model,
            max_tokens=max_tokens,
            temperature=temperature,
            eos_id=global_tokenizer.instruct_tokenizer.tokenizer.eos_id,
        )
# Decode and yield response
response = global_tokenizer.decode(out_tokens[0])
yield response
except Exception as e:
yield f"An error occurred: {str(e)}"
# Custom CSS for the Gradio interface
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
body, .gradio-container {
font-family: 'Inter', sans-serif;
}
"""
# System message
system_message = """Xylaria (v1.2.9) is an AI assistant developed by Sk Md Saad Amin, designed to provide efficient, practical support across a range of domains, with an adaptable communication style."""
def main():
print("Starting Mistral Chat Interface...")
print("Initializing model (this may take a few minutes on first run)...")
# Initialize model and tokenizer at startup
initialize_globals()
# Create Gradio interface
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(
value=system_message,
visible=False,
),
            gr.Slider(
                minimum=1,
                maximum=16384,
                value=16384,
                step=1,
                label="Max new tokens"
            ),
gr.Slider(
minimum=0.1,
maximum=4.0,
value=0.7,
step=0.1,
label="Temperature"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)"
),
],
css=custom_css
)
print("Launch successful! Interface is ready to use.")
demo.launch()
if __name__ == "__main__":
main()