Spaces:
Sleeping
Sleeping
import functools
import inspect
import re
from pathlib import Path

import gradio as gr
from mistral_common.protocol.instruct.messages import AssistantMessage, SystemMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.transformer import Transformer
@functools.lru_cache(maxsize=1)
def setup_mistral():
    """Load the Mistral Nemo-Instruct model and tokenizer once per process.

    Weights and the ``tekken.json`` tokenizer file are read from
    ``~/mistral_models/Nemo-Instruct``. The result is memoized, so later
    calls return the same ``(model, tokenizer)`` pair instead of reloading
    the weights on every chat message. A call that raises is not cached,
    so a failed load is retried on the next call.

    Returns:
        tuple: ``(model, tokenizer)`` -- a ``Transformer`` and a
        ``MistralTokenizer``.
    """
    models_dir = Path.home() / "mistral_models" / "Nemo-Instruct"
    # from_file receives a str path, as with the original f-string argument.
    tokenizer = MistralTokenizer.from_file(str(models_dir / "tekken.json"))
    model = Transformer.from_folder(models_dir)
    return model, tokenizer
# Canned answers for a few fixed questions, grouped by answer. Every pattern
# is lower-case because the incoming message is case-folded before matching
# (mixed-case patterns could never match). Patterns are plain substrings, so
# e.g. "who is ur dev" also covers "who is ur developer?".
_CUSTOM_RESPONSE_GROUPS = (
    (
        "xylaria",
        (
            "what is ur name?",
            "what is your name?",
            "what's ur name?",
            "what's your name?",
            "whats ur name?",
            "wat is ur name?",
            "wut is ur name?",
            "wut ur name?",
            "wats ur name",
        ),
    ),
    (
        "3",
        (
            "how many r is in strawberry?",
            "how many 'r' is in strawberry?",
            "how many 'r' in strawberry?",
            "how many r's are in strawberry?",
            "how many rs are in strawberry?",
            "how many rs's are in strawberry?",
            "how many r in strawbary?",
            "how many 'r' in strawbary?",
            "how many r in strawbry?",
            "how many r is in strawbry?",
            "how many 'r' is in strawbry?",
        ),
    ),
    (
        "sk md saad amin",
        (
            "who is your developer?",
            "who is ur developer?",
            "who's your developer?",
            "who's ur developer?",
            "who is your devloper?",
            "who is ur devloper",
            "who's ur devloper?",
            "who is your dev?",
            "who is ur dev",
            "who's your dev?",
            "who's ur dev?",
            "who ur dev?",
            "what is ur dev",
            "whats ur dev?",
        ),
    ),
)


def check_custom_responses(message: str) -> "str | None":
    """Return a canned answer if *message* contains one of the known questions.

    Matching is case-insensitive and substring-based: the message is
    case-folded, and the first group (in ``_CUSTOM_RESPONSE_GROUPS`` order)
    with a pattern occurring anywhere in it decides the answer.

    Args:
        message: The raw user message.

    Returns:
        The canned response string, or ``None`` when no pattern matches
        (including for an empty or ``None`` message).
    """
    if not message:
        return None
    text = message.casefold()
    for response, patterns in _CUSTOM_RESPONSE_GROUPS:
        if any(pattern in text for pattern in patterns):
            return response
    return None
# Phrases that mark a request for image generation. They are matched as whole
# words (case-insensitively, any run of whitespace between words): plain
# substring matching also fired inside words such as "withdraw", "drawback"
# or "drawer", which made the chat refuse ordinary text questions.
_IMAGE_TRIGGERS = (
    "generate an image",
    "create an image",
    "draw",
    "make a picture",
    "generate a picture",
    "create a picture",
    "generate art",
    "create art",
    "make art",
    "visualize",
    # NOTE(review): "show me" still matches non-image requests such as
    # "show me how to sort a list" -- confirm this trigger is intended.
    "show me",
)
_IMAGE_TRIGGER_RE = re.compile(
    r"\b(?:"
    + "|".join(
        r"\s+".join(re.escape(word) for word in phrase.split())
        for phrase in _IMAGE_TRIGGERS
    )
    + r")\b",
    re.IGNORECASE,
)


def is_image_request(message: str) -> bool:
    """Return True if *message* looks like a request to generate an image.

    A message qualifies when any phrase in ``_IMAGE_TRIGGERS`` appears in it
    as whole words. Inflected forms (e.g. "drawing", "visualization") do not
    count as a match.

    Args:
        message: The raw user message; empty or ``None`` yields False.

    Returns:
        bool: Whether an image-generation trigger phrase was found.
    """
    if not message:
        return False
    return _IMAGE_TRIGGER_RE.search(message) is not None
def generate_image(prompt: str, client=None):
    """Generate an image for *prompt* through a text-to-image client.

    No ``image_client`` is defined in this module's visible source, so the
    client should be passed explicitly. When *client* is omitted, a
    module-level ``image_client`` is used if one exists.

    Args:
        prompt: Text description of the desired image.
        client: Object exposing ``text_to_image(prompt, parameters=...)``.

    Returns:
        Whatever ``client.text_to_image`` returns, or ``None`` when no
        client is available or the call raises.
    """
    if client is None:
        client = globals().get("image_client")
    if client is None:
        print("Image generation error: no image client configured")
        return None
    try:
        return client.text_to_image(
            prompt,
            parameters={
                "negative_prompt": "(worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth",
                "num_inference_steps": 30,
                "guidance_scale": 7.5,
                "sampling_steps": 15,
                "upscaler": "4x-UltraSharp",
                "denoising_strength": 0.5,
            },
        )
    except Exception as e:
        # Best-effort: report the failure and fall back to None.
        print(f"Image generation error: {e}")
        return None
def create_mistral_messages(history, system_message, current_message):
    """Build the list of Mistral chat messages for one request.

    Args:
        history: Prior turns, in either of two formats:
            * "tuples": an iterable of ``(user_text, assistant_text)`` pairs,
              where either side may be empty/None and is then skipped;
            * "messages": an iterable of dicts with ``"role"`` and
              ``"content"`` keys. Only non-empty string content with role
              ``"user"`` or ``"assistant"`` is forwarded.
            ``None`` is treated as an empty history.
        system_message: Optional system prompt, prepended when truthy.
        current_message: The new user message, always appended last.

    Returns:
        list: ``SystemMessage`` / ``UserMessage`` / ``AssistantMessage``
        objects in conversation order.
    """
    messages = []
    if system_message:
        messages.append(SystemMessage(content=system_message))

    for turn in history or ():
        if isinstance(turn, dict):
            # "messages"-style history: one dict per message. Tuple-unpacking
            # a dict would yield its keys ("role", "content") instead.
            role = turn.get("role")
            content = turn.get("content")
            if not isinstance(content, str) or not content:
                continue
            if role == "user":
                messages.append(UserMessage(content=content))
            elif role == "assistant":
                messages.append(AssistantMessage(content=content))
        else:
            user_msg, assistant_msg = turn
            if user_msg:
                messages.append(UserMessage(content=user_msg))
            if assistant_msg:
                messages.append(AssistantMessage(content=assistant_msg))

    messages.append(UserMessage(content=current_message))
    return messages
def respond(message, history, system_message, max_tokens=16343, temperature=0.7, top_p=0.95):
    """Gradio chat handler: answer *message* with a canned reply or the model.

    Order of precedence:
      1. a canned answer from ``check_custom_responses``;
      2. a fixed refusal when ``is_image_request`` flags the message;
      3. otherwise, a completion generated by the Mistral model.

    Yields the whole reply once, as a single string (no token streaming).
    Failures while loading the model or generating are reported in the chat
    as ``"An error occurred: ..."`` rather than raised.

    Args:
        message: The new user message.
        history: Previous turns supplied by Gradio.
        system_message: System prompt (a hidden textbox in the UI).
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold. It is forwarded only when the
            installed ``generate`` declares a ``top_p`` parameter.
    """
    custom_response = check_custom_responses(message)
    if custom_response:
        yield custom_response
        return

    if is_image_request(message):
        yield "Sorry, image generation is not supported in this implementation."
        return

    try:
        model, tokenizer = setup_mistral()
        mistral_messages = create_mistral_messages(history, system_message, message)
        completion_request = ChatCompletionRequest(messages=mistral_messages)
        tokens = tokenizer.encode_chat_completion(completion_request).tokens

        # Gradio sliders may deliver floats; the token budget must be an int.
        generate_kwargs = {
            "max_tokens": int(max_tokens),
            "temperature": float(temperature),
            "eos_id": tokenizer.instruct_tokenizer.tokenizer.eos_id,
        }
        # NOTE(review): mistral_inference's generate() may not accept a top_p
        # keyword (passing it would raise TypeError on every request) --
        # forward it only when the installed signature declares it.
        if "top_p" in inspect.signature(generate).parameters:
            generate_kwargs["top_p"] = float(top_p)

        out_tokens, _ = generate([tokens], model, **generate_kwargs)
        reply = tokenizer.decode(out_tokens[0])
    except Exception as e:
        # UI boundary: log, then surface the failure in the chat.
        print(f"Generation error: {e}")
        reply = f"An error occurred: {e}"
    yield reply
# Page-wide styling: load the "Inter" font from Google Fonts and apply it.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
body, .gradio-container {
font-family: 'Inter', sans-serif;
}
"""

# Default system prompt, handed to respond() through a hidden textbox.
system_message = """Xylaria (v1.2.9) is an AI assistant developed by Sk Md Saad Amin, designed to provide efficient, practical support in various domains with adaptable communication."""

# (minimum, maximum, initial value, step, label) for each sampling control,
# listed in the order of respond()'s max_tokens, temperature, top_p parameters.
_SAMPLING_SLIDERS = (
    (1, 16343, 16343, 1, "Max new tokens"),
    (0.1, 4.0, 0.7, 0.1, "Temperature"),
    (0.1, 1.0, 0.95, 0.05, "Top-p (nucleus sampling)"),
)

# Order must match respond()'s parameters after (message, history).
_extra_inputs = [gr.Textbox(value=system_message, visible=False)]
_extra_inputs.extend(
    gr.Slider(minimum=low, maximum=high, value=start, step=step, label=label)
    for low, high, start, step, label in _SAMPLING_SLIDERS
)

demo = gr.ChatInterface(respond, additional_inputs=_extra_inputs, css=custom_css)

if __name__ == "__main__":
    demo.launch()