import os
import gradio as gr
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Configuration
DOCS_DIR = "lore"
MODEL_NAME = "IlyaGusev/saiga_mistral_7b"  # a model well suited to Russian
EMBEDDINGS_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# 1. Document loading
def load_documents():
    docs = []
    for filename in os.listdir(DOCS_DIR):
        if filename.endswith(".txt"):
            loader = TextLoader(os.path.join(DOCS_DIR, filename), encoding="utf-8")
            docs.extend(loader.load())
    return docs
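
# Optional: the loop above can be replaced with langchain_community's
# DirectoryLoader. A minimal, behaviour-equivalent sketch (not part of
# the original app):
#
#     from langchain_community.document_loaders import DirectoryLoader
#     docs = DirectoryLoader(
#         DOCS_DIR, glob="*.txt", loader_cls=TextLoader,
#         loader_kwargs={"encoding": "utf-8"},
#     ).load()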

# 2. Knowledge-base preparation
def prepare_knowledge_base():
    documents = load_documents()
    # Split the text into chunks
    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    splits = text_splitter.split_documents(documents)
    # Create the vector store
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
    vectorstore = FAISS.from_documents(splits, embeddings)
    return vectorstore
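
# Optional: cache the FAISS index on disk so restarts skip re-embedding.
# A minimal sketch; this helper and the "faiss_index" path are assumptions,
# not part of the original app.
def prepare_knowledge_base_cached(index_dir="faiss_index"):
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
    if os.path.exists(index_dir):
        # Recent langchain_community versions require opting in when
        # deserializing a locally saved (pickled) index
        return FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
    vectorstore = prepare_knowledge_base()
    vectorstore.save_local(index_dir)
    return vectorstore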

# 3. Language-model initialization
def load_llm():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        # 4-bit quantization to save memory (requires bitsandbytes);
        # the bare load_in_4bit=True kwarg is deprecated in recent transformers
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
        do_sample=True,  # temperature is ignored unless sampling is enabled
        temperature=0.3,
    )
    return HuggingFacePipeline(pipeline=pipe)

# 4. Question-answering chain
def create_qa_chain():
    vectorstore = prepare_knowledge_base()
    llm = load_llm()
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # stuff the retrieved chunks directly into one prompt
        retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True,
    )
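
# Optional: the "stuff" chain accepts a custom prompt via chain_type_kwargs.
# A hypothetical sketch (the template wording is an assumption, not from this
# app); pass chain_type_kwargs={"prompt": prompt} to from_chain_type above:
#
#     from langchain.prompts import PromptTemplate
#     prompt = PromptTemplate(
#         input_variables=["context", "question"],
#         template="Ответь на вопрос, используя только контекст.\n"
#                  "Контекст: {context}\nВопрос: {question}\nОтвет:",
#     )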

# 5. Answer function
# Build the chain once at startup; creating it inside get_answer would
# reload the model and re-embed the documents on every question.
qa_chain = create_qa_chain()

def get_answer(question):
    result = qa_chain.invoke({"query": question})
    # Format the answer and append the deduplicated source files
    answer = result["result"]
    sources = list(set(doc.metadata["source"] for doc in result["source_documents"]))
    return f"{answer}\n\nИсточники: {', '.join(sources)}"

# 6. Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🧛 Лор-бот: справочник по сверхъестественному")
    with gr.Row():
        question = gr.Textbox(label="Ваш вопрос", placeholder="Какие слабости у вампиров?")
        submit_btn = gr.Button("Спросить")
    answer = gr.Textbox(label="Ответ", interactive=False)
    submit_btn.click(
        fn=get_answer,
        inputs=question,
        outputs=answer,
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
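
# The imports above imply roughly these dependencies for the Space
# (a sketch; exact versions are not pinned anywhere in the original file):
#   gradio, langchain, langchain-community, transformers, accelerate,
#   bitsandbytes, sentence-transformers, faiss-cpu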