Leri777's picture
Update app.py
b04bbf8 verified
raw
history blame
3.1 kB
import os
import gradio as gr
import torch
import logging
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
# Настройка логирования
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Загрузка переменных окружения
load_dotenv()
MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
HF_TOKEN = os.getenv("HF_TOKEN")
# Проверка доступности токена
if not HF_TOKEN:
logger.error("HF_TOKEN не задан. Пожалуйста, укажите токен доступа Hugging Face в файле .env.")
raise EnvironmentError("Отсутствует токен доступа Hugging Face.")
try:
# Инициализация пайплайна для работы с моделью
logger.info(f"Попытка загрузить модель: {MODEL_NAME}")
pipe = pipeline(
"text-generation",
model=MODEL_NAME,
use_auth_token=HF_TOKEN,
device=0 if torch.cuda.is_available() else -1
)
logger.info("Модель успешно загружена.")
except Exception as e:
logger.error(f"Ошибка при загрузке модели: {e}")
raise
def generate_response(prompt):
"""
Функция для генерации ответа с использованием модели.
Форматирует запрос в соответствии с требованиями модели.
"""
try:
# Форматирование инструкции согласно требованиям модели
formatted_prompt = f"<s>[INST] {prompt} [/INST]</s>"
logger.debug(f"Сформированный запрос: {formatted_prompt}")
response = pipe(formatted_prompt, max_length=150, num_return_sequences=1)
logger.debug(f"Полученный ответ: {response}")
return response[0]['generated_text'].replace(formatted_prompt, "").strip()
except Exception as e:
logger.error(f"Ошибка при генерации ответа: {e}")
return "Произошла ошибка при генерации ответа. Пожалуйста, попробуйте еще раз."
# Интерфейс Gradio для взаимодействия с моделью
def main():
with gr.Blocks() as demo:
gr.Markdown("# Mixtral-8x7B Chat Interface")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Введите ваш запрос", placeholder="Введите текст сюда...")
submit_btn = gr.Button("Сгенерировать ответ")
with gr.Column():
response = gr.Textbox(label="Ответ модели")
submit_btn.click(fn=generate_response, inputs=prompt, outputs=response)
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
main()