import os
import gradio as gr
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
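# Assumed dependencies for this Space (a sketch; the original does not pin them):
# pip install langchain langchain-community faiss-cpu sentence-transformers \
#     transformers accelerate bitsandbytes gradio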
# Configuration
DOCS_DIR = "lore"
MODEL_NAME = "IlyaGusev/saiga_mistral_7b"  # a strong model choice for Russian
EMBEDDINGS_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# 1. Load the documents
def load_documents():
    docs = []
    for filename in os.listdir(DOCS_DIR):
        if filename.endswith(".txt"):
            loader = TextLoader(os.path.join(DOCS_DIR, filename), encoding="utf-8")
            docs.extend(loader.load())
    return docs
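# An equivalent alternative (a sketch, not used by the original) via the same
# langchain_community API:
# from langchain_community.document_loaders import DirectoryLoader
# docs = DirectoryLoader(DOCS_DIR, glob="*.txt", loader_cls=TextLoader,
#                        loader_kwargs={"encoding": "utf-8"}).load()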
# 2. Prepare the knowledge base
def prepare_knowledge_base():
    documents = load_documents()
    # Split the text into chunks
    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    splits = text_splitter.split_documents(documents)
    # Build the vector store
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
    vectorstore = FAISS.from_documents(splits, embeddings)
    return vectorstore
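# Optional: persist the index so restarts skip re-embedding (a sketch; the
# "faiss_index" path is an assumption, not from the original):
# vectorstore.save_local("faiss_index")
# vectorstore = FAISS.load_local("faiss_index", embeddings,
#                                allow_dangerous_deserialization=True)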
# 3. Initialize the language model
def load_llm():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True)  # saves memory
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
        do_sample=True,  # sampling must be on for temperature to take effect
        temperature=0.3
    )
    return HuggingFacePipeline(pipeline=pipe)
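# Quick smoke test (hypothetical, not part of the app flow):
# llm = load_llm()
# print(llm.invoke("What is a vampire's weakness?"))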
# 4. Build the question-answering chain
def create_qa_chain():
    vectorstore = prepare_knowledge_base()
    llm = load_llm()
    return RetrievalQA.from_chain_type(
        llm=llm,
        # "stuff" concatenates the k retrieved chunks into a single prompt;
        # with chunk_size=500 and k=2 this stays well within the context window
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True
    )
# 5. Answer function
# Build the chain once at startup: rebuilding it per question would reload
# the model and re-embed every document on each call.
qa_chain = create_qa_chain()

def get_answer(question):
    result = qa_chain.invoke({"query": question})
    # Format the answer, de-duplicating source file names
    answer = result["result"]
    sources = list(set(doc.metadata["source"] for doc in result["source_documents"]))
    return f"{answer}\n\nSources: {', '.join(sources)}"
# 6. Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🧛 Lore bot: a guide to the supernatural")
    with gr.Row():
        question = gr.Textbox(label="Your question", placeholder="What are vampires' weaknesses?")
        submit_btn = gr.Button("Ask")
    answer = gr.Textbox(label="Answer", interactive=False)
    submit_btn.click(
        fn=get_answer,
        inputs=question,
        outputs=answer
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
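# Optional (an assumption, not in the original): enable Gradio's request queue
# so long generations don't hit request timeouts on shared hardware:
# demo.queue().launch(server_name="0.0.0.0", server_port=7860)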