import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import subprocess

# Install flash-attn at runtime. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling
# the CUDA extensions during the pip build step; the rest of the environment is
# preserved so pip and the CUDA paths remain available to the subprocess.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
DESCRIPTION = """\
# Lexora 3B ITA 💬 🇮🇹
"""
# Updated CSS to ensure full height and proper markdown rendering
CUSTOM_CSS = """
.gradio-container {
height: 100vh !important;
max-height: 100vh !important;
padding: 0 !important;
background-color: #0f1117;
}
.contain {
height: 100vh !important;
max-height: 100vh !important;
display: flex;
flex-direction: column;
}
.main-container {
flex-grow: 1;
height: calc(100vh - 100px) !important;
overflow: hidden !important;
}
.chat-container {
height: 100% !important;
overflow: hidden !important;
display: flex;
flex-direction: column;
}
.chat-messages {
flex-grow: 1;
overflow-y: auto !important;
padding: 1rem;
}
.message-wrap {
height: auto !important;
max-height: none !important;
}
.message {
padding: 1rem !important;
margin: 0.5rem 0 !important;
border-radius: 0.5rem !important;
}
.user-message {
background-color: #2b2d31 !important;
}
.bot-message {
background-color: #1e1f23 !important;
}
.examples-container {
margin-top: auto;
}
/* Markdown styling */
.bot-message p {
margin-bottom: 0.5rem;
}
.bot-message h1, .bot-message h2, .bot-message h3,
.bot-message h4, .bot-message h5, .bot-message h6 {
margin-top: 1rem;
margin-bottom: 0.5rem;
}
.bot-message code {
background-color: #2d2d2d;
padding: 0.2rem 0.4rem;
border-radius: 0.2rem;
font-family: monospace;
}
.bot-message pre {
background-color: #2d2d2d;
padding: 1rem;
border-radius: 0.5rem;
overflow-x: auto;
margin: 1rem 0;
}
.bot-message pre code {
background-color: transparent;
padding: 0;
border-radius: 0;
}
.bot-message ul, .bot-message ol {
padding-left: 1.5rem;
margin-bottom: 0.5rem;
}
.bot-message blockquote {
border-left: 3px solid #4a4a4a;
padding-left: 1rem;
margin: 0.5rem 0;
color: #a0a0a0;
}
.bot-message table {
border-collapse: collapse;
width: 100%;
margin: 1rem 0;
}
.bot-message th, .bot-message td {
border: 1px solid #4a4a4a;
padding: 0.5rem;
text-align: left;
}
.bot-message th {
background-color: #2a2a2a;
}
.bot-message img {
max-width: 100%;
height: auto;
margin: 1rem 0;
}
.bot-message a {
color: #3291ff;
text-decoration: none;
}
.bot-message a:hover {
text-decoration: underline;
}
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_id = "DeepMount00/Lexora-Lite-3B"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Load the weights in bfloat16 with FlashAttention-2 attention kernels;
# device_map="auto" places the model on the available GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
model.config.sliding_window = 4096  # limit the attention window in the model config
model.eval()
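# spaces.GPU asks Hugging Face Spaces to attach a GPU for each call to the decorated
# function; duration=90 is the per-call time budget in seconds.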
@spaces.GPU(duration=90)
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_message: str = "",
max_new_tokens: int = 1024,
temperature: float = 0.001,
top_p: float = 1.0,
top_k: int = 50,
repetition_penalty: float = 1.0,
) -> Iterator[str]:
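    """Stream a chat completion for `message` given the running `chat_history`.

    The (user, assistant) pairs in the history are flattened into the model's chat
    template, generation runs on a background thread, and partial text is yielded
    as it arrives from the streamer so the UI can update incrementally.
    """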
conversation = [{"role": "system", "content": system_message}]
for user, assistant in chat_history:
conversation.extend(
[
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
)
conversation.append({"role": "user", "content": message})
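    # Apply the chat template and keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens.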
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
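    # TextIteratorStreamer yields decoded text pieces as generate() produces tokens on a
    # background thread; the timeout guards against a stalled generation.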
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
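# The additional inputs below map, in order, onto generate()'s keyword arguments:
# system_message, max_new_tokens, temperature, top_p, top_k, repetition_penalty.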
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(
value="",
label="System message",
render=False,
),
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0,
maximum=4.0,
step=0.1,
value=0.001,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=1.0,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.0,
),
],
stop_btn=None,
examples=[
["Ciao! Come stai?"],
["Puoi scrivere una lista markdown?"],
["Scrivi un esempio di codice Python"],
],
cache_examples=False,
render_markdown=True, # Enable Markdown rendering
)
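# Top-level layout: description header, a Duplicate button, and the chat interface
# inside a full-height column styled by CUSTOM_CSS.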
with gr.Blocks(css=CUSTOM_CSS, fill_height=True, theme=gr.themes.Base()) as demo:
with gr.Column(elem_classes="contain"):
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
with gr.Column(elem_classes="main-container"):
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()