Loversofdeath commited on
Commit
8c5a7b2
·
verified ·
1 Parent(s): 4385c2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -68
app.py CHANGED
@@ -1,90 +1,97 @@
1
  import os
2
- import re
3
- import torch # Добавлен импорт torch
4
  from langchain_community.document_loaders import TextLoader
5
  from langchain.text_splitter import CharacterTextSplitter
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_community.embeddings import HuggingFaceEmbeddings
8
- from langchain_core.prompts import PromptTemplate
9
  from langchain.chains import RetrievalQA
10
- from transformers import pipeline
11
- import gradio as gr
12
 
13
- # 1. Загрузка всех файлов из папки lore/
14
- def load_all_lore_files():
 
 
 
 
 
15
  docs = []
16
- for filename in os.listdir("lore"):
17
  if filename.endswith(".txt"):
18
- loader = TextLoader(os.path.join("lore", filename), encoding="utf-8")
19
  docs.extend(loader.load())
20
  return docs
21
 
22
- # 2. Очистка от спецсимволов
23
- def clean_text(text):
24
- return re.sub(r"\[=.*?\/?]", "", text)
 
 
 
 
 
 
 
 
 
 
25
 
26
- # 3. Настройка эмбеддингов
27
- def create_embeddings():
28
- return HuggingFaceEmbeddings(
29
- model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 
 
 
30
  )
31
-
32
- # 4. Создание векторной базы
33
- def create_vectorstore(docs, embeddings):
34
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
35
- split_docs = text_splitter.split_documents(docs)
36
- for doc in split_docs:
37
- doc.page_content = clean_text(doc.page_content)
38
- return FAISS.from_documents(split_docs, embeddings)
39
-
40
- # 5. Загрузка модели ответа (с проверкой доступности GPU)
41
- def create_llm_pipeline():
42
- return pipeline(
43
  "text-generation",
44
- model="IlyaGusev/saiga2_7b_lora",
45
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
46
- device="cuda:0" if torch.cuda.is_available() else "cpu"
47
- )
48
-
49
- # 6. Объединение в цепочку
50
- def build_chain():
51
- docs = load_all_lore_files()
52
- embeddings = create_embeddings()
53
- vectorstore = create_vectorstore(docs, embeddings)
54
-
55
- retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
56
-
57
- prompt = PromptTemplate(
58
- template="""
59
- Ты — помощник, который отвечает на вопросы по вымышленному лору. Отвечай кратко, точно и на русском языке.
60
- Если в лоре нет нужной информации, честно скажи, что не знаешь.
61
-
62
- Контекст:
63
- {context}
64
-
65
- Вопрос:
66
- {question}
67
-
68
- Ответ:
69
- """,
70
- input_variables=["context", "question"]
71
  )
 
 
72
 
 
 
 
 
 
73
  return RetrievalQA.from_chain_type(
74
- llm=create_llm_pipeline(),
75
- retriever=retriever,
76
- chain_type_kwargs={"prompt": prompt}
 
77
  )
78
 
79
- # 7. Интерфейс
80
- qa_chain = build_chain()
 
 
 
 
 
 
 
 
81
 
82
- def ask_question(question):
83
- return qa_chain.run(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- gr.Interface(
86
- fn=ask_question,
87
- inputs=gr.Textbox(label="Спроси что-нибудь по лору"),
88
- outputs=gr.Textbox(label="Ответ"),
89
- title="Лор-бот"
90
- ).launch()
 
1
  import os
2
+ import gradio as gr
 
3
  from langchain_community.document_loaders import TextLoader
4
  from langchain.text_splitter import CharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
7
  from langchain.chains import RetrievalQA
8
+ from langchain_community.llms import HuggingFacePipeline
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
 
11
+ # Конфигурация
12
+ DOCS_DIR = "lore"
13
+ MODEL_NAME = "IlyaGusev/saiga_mistral_7b" # Оптимальная модель для русского
14
+ EMBEDDINGS_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
15
+
16
+ # 1. Загрузка документов
17
+ def load_documents():
18
  docs = []
19
+ for filename in os.listdir(DOCS_DIR):
20
  if filename.endswith(".txt"):
21
+ loader = TextLoader(os.path.join(DOCS_DIR, filename), encoding="utf-8")
22
  docs.extend(loader.load())
23
  return docs
24
 
25
+ # 2. Подготовка базы знаний
26
+ def prepare_knowledge_base():
27
+ documents = load_documents()
28
+
29
+ # Разбиваем текст на чанки
30
+ text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
31
+ splits = text_splitter.split_documents(documents)
32
+
33
+ # Создаем векторное хранилище
34
+ embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
35
+ vectorstore = FAISS.from_documents(splits, embeddings)
36
+
37
+ return vectorstore
38
 
39
+ # 3. Инициализация языковой модели
40
+ def load_llm():
41
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
42
+ model = AutoModelForCausalLM.from_pretrained(
43
+ MODEL_NAME,
44
+ device_map="auto",
45
+ load_in_4bit=True # Экономия памяти
46
  )
47
+
48
+ pipe = pipeline(
 
 
 
 
 
 
 
 
 
 
49
  "text-generation",
50
+ model=model,
51
+ tokenizer=tokenizer,
52
+ max_new_tokens=200,
53
+ temperature=0.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  )
55
+
56
+ return HuggingFacePipeline(pipeline=pipe)
57
 
58
+ # 4. Создание цепочки для вопросов-ответов
59
+ def create_qa_chain():
60
+ vectorstore = prepare_knowledge_base()
61
+ llm = load_llm()
62
+
63
  return RetrievalQA.from_chain_type(
64
+ llm=llm,
65
+ chain_type="stuff",
66
+ retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
67
+ return_source_documents=True
68
  )
69
 
70
+ # 5. Функция для ответов
71
+ def get_answer(question):
72
+ qa_chain = create_qa_chain()
73
+ result = qa_chain({"query": question})
74
+
75
+ # Форматируем ответ
76
+ answer = result["result"]
77
+ sources = list(set(doc.metadata["source"] for doc in result["source_documents"]))
78
+
79
+ return f"{answer}\n\nИсточники: {', '.join(sources)}"
80
 
81
+ # 6. Интерфейс Gradio
82
+ with gr.Blocks() as demo:
83
+ gr.Markdown("## 🧛 Лор-бот: справочник по сверхъестественному")
84
+
85
+ with gr.Row():
86
+ question = gr.Textbox(label="Ваш вопрос", placeholder="Какие слабости у вампиров?")
87
+ submit_btn = gr.Button("Спросить")
88
+
89
+ answer = gr.Textbox(label="Ответ", interactive=False)
90
+
91
+ submit_btn.click(
92
+ fn=get_answer,
93
+ inputs=question,
94
+ outputs=answer
95
+ )
96
 
97
+ demo.launch(server_name="0.0.0.0", server_port=7860)