Loversofdeath commited on
Commit
95f2e49
·
verified ·
1 Parent(s): eecd39b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -70
app.py CHANGED
@@ -1,97 +1,91 @@
1
  import os
2
  import gradio as gr
3
- from langchain_community.document_loaders import TextLoader
4
- from langchain.text_splitter import CharacterTextSplitter
5
- from langchain_community.vectorstores import FAISS
6
- from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain.chains import RetrievalQA
8
- from langchain_community.llms import HuggingFacePipeline
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
 
11
  # Конфигурация
12
  DOCS_DIR = "lore"
13
- MODEL_NAME = "IlyaGusev/saiga_mistral_7b" # Оптимальная модель для русского
14
  EMBEDDINGS_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 
 
15
 
16
- # 1. Загрузка документов
17
  def load_documents():
18
  docs = []
19
  for filename in os.listdir(DOCS_DIR):
20
  if filename.endswith(".txt"):
21
- loader = TextLoader(os.path.join(DOCS_DIR, filename), encoding="utf-8")
22
- docs.extend(loader.load())
 
 
 
 
 
 
23
  return docs
24
 
25
- # 2. Подготовка базы знаний
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def prepare_knowledge_base():
27
  documents = load_documents()
28
-
29
- # Разбиваем текст на чанки
30
- text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
31
- splits = text_splitter.split_documents(documents)
32
-
33
- # Создаем векторное хранилище
34
- embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
35
- vectorstore = FAISS.from_documents(splits, embeddings)
36
-
37
- return vectorstore
38
-
39
- # 3. Инициализация языковой модели
40
- def load_llm():
41
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
42
- model = AutoModelForCausalLM.from_pretrained(
43
- MODEL_NAME,
44
- device_map="auto",
45
- load_in_4bit=True # Экономия памяти
46
- )
47
-
48
- pipe = pipeline(
49
- "text-generation",
50
- model=model,
51
- tokenizer=tokenizer,
52
- max_new_tokens=200,
53
- temperature=0.3
54
  )
55
-
56
- return HuggingFacePipeline(pipeline=pipe)
 
57
 
58
- # 4. Создание цепочки для вопросов-ответов
59
  def create_qa_chain():
60
- vectorstore = prepare_knowledge_base()
61
- llm = load_llm()
62
-
 
 
 
 
 
63
  return RetrievalQA.from_chain_type(
64
  llm=llm,
65
  chain_type="stuff",
66
- retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
67
- return_source_documents=True
 
68
  )
69
 
70
- # 5. Функция для ответов
71
  def get_answer(question):
72
- qa_chain = create_qa_chain()
73
- result = qa_chain({"query": question})
74
-
75
- # Форматируем ответ
76
- answer = result["result"]
77
- sources = list(set(doc.metadata["source"] for doc in result["source_documents"]))
78
-
79
- return f"{answer}\n\nИсточники: {', '.join(sources)}"
80
 
81
- # 6. Интерфейс Gradio
82
- with gr.Blocks() as demo:
83
- gr.Markdown("## 🧛 Лор-бот: справочник по сверхъестественному")
84
-
85
- with gr.Row():
86
- question = gr.Textbox(label="Ваш вопрос", placeholder="Какие слабости у вампиров?")
87
- submit_btn = gr.Button("Спросить")
88
-
89
- answer = gr.Textbox(label="Ответ", interactive=False)
90
-
91
- submit_btn.click(
92
- fn=get_answer,
93
- inputs=question,
94
- outputs=answer
95
- )
96
 
97
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
2
  import gradio as gr
3
+ from langchain.document_loaders import TextLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.vectorstores import FAISS
7
  from langchain.chains import RetrievalQA
8
+ from langchain.llms import HuggingFaceHub
 
9
 
10
  # Конфигурация
11
  DOCS_DIR = "lore"
 
12
  EMBEDDINGS_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
13
+ LLM_REPO = "IlyaGusev/saiga_mistral_7b"
14
+ HF_TOKEN = os.getenv("HF_TOKEN") # Добавьте в Secrets Space
15
 
16
+ # 1. Загрузка документов с обработкой ошибок
17
  def load_documents():
18
  docs = []
19
  for filename in os.listdir(DOCS_DIR):
20
  if filename.endswith(".txt"):
21
+ try:
22
+ loader = TextLoader(
23
+ os.path.join(DOCS_DIR, filename),
24
+ encoding="utf-8"
25
+ )
26
+ docs.extend(loader.load())
27
+ except Exception as e:
28
+ print(f"Ошибка загрузки {filename}: {str(e)}")
29
  return docs
30
 
31
+ # 2. Инициализация эмбеддингов с проверкой
32
+ def get_embeddings():
33
+ try:
34
+ return HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
35
+ except ImportError:
36
+ raise ImportError(
37
+ "Требуемые пакеты не установлены. "
38
+ "Добавьте в requirements.txt:\n"
39
+ "sentence-transformers\n"
40
+ "torch\n"
41
+ "transformers"
42
+ )
43
+
44
+ # 3. Подготовка базы знаний
45
  def prepare_knowledge_base():
46
  documents = load_documents()
47
+ text_splitter = RecursiveCharacterTextSplitter(
48
+ chunk_size=500,
49
+ chunk_overlap=50,
50
+ separators=["\n\n", "\n", " ", ""]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
+ splits = text_splitter.split_documents(documents)
53
+ embeddings = get_embeddings()
54
+ return FAISS.from_documents(splits, embeddings)
55
 
56
+ # 4. Создание цепочки QA
57
  def create_qa_chain():
58
+ llm = HuggingFaceHub(
59
+ repo_id=LLM_REPO,
60
+ huggingfacehub_api_token=HF_TOKEN,
61
+ model_kwargs={
62
+ "temperature": 0.3,
63
+ "max_new_tokens": 200
64
+ }
65
+ )
66
  return RetrievalQA.from_chain_type(
67
  llm=llm,
68
  chain_type="stuff",
69
+ retriever=prepare_knowledge_base().as_retriever(
70
+ search_kwargs={"k": 2}
71
+ )
72
  )
73
 
74
+ # 5. Интерфейс с обработкой ошибок
75
  def get_answer(question):
76
+ try:
77
+ qa = create_qa_chain()
78
+ result = qa.run(question)
79
+ return result[:500] # Обрезаем слишком длинные ответы
80
+ except Exception as e:
81
+ return f"⚠️ Ошибка: {str(e)}"
 
 
82
 
83
+ # Запуск приложения
84
+ with gr.Blocks(title="📚 Лор-бот") as app:
85
+ gr.Markdown("## 🧛 Вопрос-ответ по лору")
86
+ question = gr.Textbox(label="Ваш вопрос", placeholder="Какие слабости у вампиров?")
87
+ output = gr.Textbox(label="Ответ", interactive=False)
88
+ btn = gr.Button("Спросить")
89
+ btn.click(get_answer, inputs=question, outputs=output)
 
 
 
 
 
 
 
 
90
 
91
+ app.launch(server_name="0.0.0.0", server_port=7860)