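# Dependencies inferred from the imports below (package names are our best
# guess; the original script does not pin them):
#   pip install gradio langchain langchain-community transformers accelerate \
#       bitsandbytes sentence-transformers faiss-cpu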
import os
import gradio as gr
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Configuration
DOCS_DIR = "lore"
MODEL_NAME = "IlyaGusev/saiga_mistral_7b"  # well suited to Russian-language text
EMBEDDINGS_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# 1. Load the documents
def load_documents():
    docs = []
    for filename in os.listdir(DOCS_DIR):
        if filename.endswith(".txt"):
            loader = TextLoader(os.path.join(DOCS_DIR, filename), encoding="utf-8")
            docs.extend(loader.load())
    return docs

# 2. Build the knowledge base
def prepare_knowledge_base():
    documents = load_documents()

    # Split the text into overlapping chunks
    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    splits = text_splitter.split_documents(documents)

    # Build the FAISS vector store from the chunks
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
    vectorstore = FAISS.from_documents(splits, embeddings)

    return vectorstore
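
# Optional: re-embedding every document on each launch gets slow as the lore
# grows. A minimal on-disk caching sketch, assuming the save_local/load_local
# helpers of langchain_community's FAISS wrapper (INDEX_DIR and the function
# name are our own, not from the original script):
INDEX_DIR = "faiss_index"

def prepare_knowledge_base_cached():
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
    if os.path.isdir(INDEX_DIR):
        # Reload the previously saved index instead of re-embedding
        return FAISS.load_local(INDEX_DIR, embeddings,
                                allow_dangerous_deserialization=True)
    vectorstore = prepare_knowledge_base()
    vectorstore.save_local(INDEX_DIR)
    return vectorstore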

# 3. Initialize the language model
def load_llm():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        # 4-bit quantization to save memory (needs bitsandbytes; the bare
        # load_in_4bit=True kwarg is deprecated in recent transformers)
        quantization_config=BitsAndBytesConfig(load_in_4bit=True)
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
        do_sample=True,  # temperature has no effect under greedy decoding
        temperature=0.3
    )

    return HuggingFacePipeline(pipeline=pipe)
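
# Note: 4-bit loading assumes a CUDA GPU with bitsandbytes installed. A
# minimal full-precision fallback sketch for machines without that setup
# (the name load_llm_unquantized is ours, not from the original script):
def load_llm_unquantized():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype="auto"  # use the checkpoint's native dtype
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.3
    )
    return HuggingFacePipeline(pipeline=pipe)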

# 4. Build the question-answering chain
def create_qa_chain():
    vectorstore = prepare_knowledge_base()
    llm = load_llm()
    
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True
    )

# 5. Answer function
# Build the chain once at startup: rebuilding it inside get_answer would
# reload the model and re-embed all documents on every single question.
qa_chain = create_qa_chain()

def get_answer(question):
    result = qa_chain.invoke({"query": question})

    # Format the answer and list the unique source files
    answer = result["result"]
    sources = list(set(doc.metadata["source"] for doc in result["source_documents"]))

    return f"{answer}\n\nSources: {', '.join(sources)}"

# 6. Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 🧛 Lore bot: a guide to the supernatural")

    with gr.Row():
        question = gr.Textbox(label="Your question", placeholder="What are vampires' weaknesses?")
        submit_btn = gr.Button("Ask")

    answer = gr.Textbox(label="Answer", interactive=False)

    submit_btn.click(
        fn=get_answer,
        inputs=question,
        outputs=answer
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)