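"""Gradio app serving Qwen/Qwen2.5-Coder-7B-Instruct for text generation.

Loads the model with 4-bit quantization when a GPU is available (plain CPU
load otherwise) and exposes the sampling parameters as Gradio sliders.
Debug output is written to a rotating log file under /tmp.
"""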
import logging
from logging.handlers import RotatingFileHandler

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Logging setup
log_file = '/tmp/app_debug.log'
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
logger.debug("Application started")

# Define model parameters
MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
CONTEXT_LENGTH = 16000

# Configuration for 4-bit quantization: weights are stored in 4-bit precision,
# while matmuls are computed in bfloat16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load tokenizer and model
if torch.cuda.is_available():
logger.debug("GPU is available. Proceeding with GPU setup.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
quantization_config=quantization_config,
trust_remote_code=True,
)
else:
logger.warning("GPU is not available. Proceeding with CPU setup.")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
trust_remote_code=True,
low_cpu_mem_usage=True,
)

# Create the Hugging Face pipeline; do_sample=True is required for
# temperature/top-k/top-p to take effect (the per-call kwargs in predict()
# override these defaults, and max_new_tokens takes precedence over max_length)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=CONTEXT_LENGTH,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
)
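
# Optional smoke test (illustrative; prompt and token budget are arbitrary).
# Uncomment to verify the pipeline locally before serving:
# print(pipe("Write a hello-world program in Python.", max_new_tokens=64)[0]["generated_text"])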

# Prediction function wrapping the text-generation pipeline
def predict(
    message,
    temperature,
    max_new_tokens,
    top_k,
    repetition_penalty,
    top_p,
):
    try:
        result = pipe(
            message,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            return_full_text=False,  # return only the completion, not the prompt
        )
        return result[0]["generated_text"]
    except Exception:
        logger.exception("Error during prediction")
        return "An error occurred."

# Gradio UI (minimums for temperature and repetition penalty are strictly
# positive, as sampling requires)
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="User input"),
        gr.Slider(0.1, 1, 0.7, label="Temperature"),
        gr.Slider(128, 2048, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0.1, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    outputs="text",
    live=True,
)

logger.debug("Chat interface initialized; launching")
interface.launch()  # blocks until the server is shut down