# army/app.py — author M17idd, commit 21dbe18 ("Update app.py"), 6.64 kB
# NOTE: the lines above were Hugging Face Spaces raw-view page residue
# (filename / author / commit / size); converted to a comment so the file parses.
import os
import time
import streamlit as st
from langchain.chat_models import ChatOpenAI
from transformers import AutoTokenizer, AutoModel
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document as LangchainDocument
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
import torch
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from typing import List
from pydantic import Field
from groq import Groq
# ----------------- Page settings -----------------
# Must be the first Streamlit call in the script.
st.set_page_config(page_title="چت‌بات ارتش - فقط از PDF", page_icon="🪖", layout="wide")

# ----------------- Load the FarsiBERT model -----------------
model_name = "HooshvareLab/bert-fa-zwnj-base"


@st.cache_resource
def _load_farsibert(name: str):
    """Load and cache the FarsiBERT tokenizer and encoder.

    Streamlit re-executes this whole script on every user interaction;
    without @st.cache_resource the tokenizer and model would be rebuilt
    (and possibly re-downloaded) on each rerun.
    """
    return AutoTokenizer.from_pretrained(name), AutoModel.from_pretrained(name)


# Module-level names kept: SimpleRetriever below reads these globals.
tokenizer, model = _load_farsibert(model_name)
# ----------------- لود PDF و ساخت ایندکس -----------------
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document as LangchainDocument
from sentence_transformers import SentenceTransformer
import numpy as np
@st.cache_resource
def build_pdf_index():
    """Load test1.pdf, split it into chunks, and embed every chunk.

    Returns:
        tuple[list[LangchainDocument], np.ndarray]: the chunk documents and
        a (num_chunks, dim) array of embeddings. Each batch is embedded with
        SentenceTransformer first (fast path) and falls back to mean-pooled
        FarsiBERT hidden states if encoding fails.

    Cached with @st.cache_resource so the PDF is processed only once per
    Streamlit session, not on every rerun.
    """
    with st.spinner('📄 در حال پردازش فایل PDF...'):
        # Load the PDF file
        loader = PyPDFLoader("test1.pdf")
        pages = loader.load()

        # Split text into overlapping chunks
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50
        )
        texts = []
        for page in pages:
            texts.extend(splitter.split_text(page.page_content))
        documents = [LangchainDocument(page_content=t) for t in texts]

        # Embedding models (function-local; independent of the module-level pair)
        tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/bert-fa-zwnj-base")
        bert_model = AutoModel.from_pretrained("HooshvareLab/bert-fa-zwnj-base")
        sentence_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

        embeddings = []
        batch_size = 16
        for i in range(0, len(documents), batch_size):
            batch_docs = documents[i:i + batch_size]
            batch_texts = [doc.page_content for doc in batch_docs]

            # Fast path: SentenceTransformer (much quicker than raw BERT)
            try:
                batch_embeddings = sentence_model.encode(
                    batch_texts, batch_size=batch_size, convert_to_numpy=True
                )
            except Exception as e:
                st.error(f"❌ خطا در SentenceTransformer: {e}")
                batch_embeddings = None

            # Fallback: mean-pooled FarsiBERT last hidden states.
            # BUG FIX: the original tested `batch_embeddings == []`, which on a
            # NumPy array does element-wise comparison and never reliably
            # detects the failure case. An explicit None sentinel is unambiguous.
            if batch_embeddings is None:
                inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
                with torch.no_grad():
                    outputs = bert_model(**inputs)
                batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

            embeddings.extend(batch_embeddings)

        # Guarantee a single NumPy array as output
        embeddings = np.array(embeddings)

    return documents, embeddings
# ----------------- LLM configuration -----------------
# SECURITY FIX: API keys (Groq and Together) were previously hard-coded here
# and therefore leaked in version control. Read them from the environment
# instead; the previously committed keys must be revoked and rotated.
groq_api_key = os.environ.get("GROQ_API_KEY")

# OpenAI-compatible client pointed at the Together AI endpoint.
llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key=os.environ.get("TOGETHER_API_KEY", ""),
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"
)
# ----------------- تعریف SimpleRetriever -----------------
# ----------------- SimpleRetriever definition -----------------
class SimpleRetriever(BaseRetriever):
    """Dot-product retriever over precomputed chunk embeddings.

    Embeds the query with the module-level FarsiBERT tokenizer/model
    (mean-pooled last hidden state) and returns the 5 documents whose
    stored embeddings have the highest dot product with the query.
    Note: similarities are unnormalized dot products, not cosine —
    presumably acceptable for this app; confirm if ranking looks off.
    """

    documents: List[Document] = Field(...)
    embeddings: List = Field(...)

    def _get_relevant_documents(self, query: str) -> List[Document]:
        # Embed the query the same way the BERT fallback embeds chunks.
        inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        query_embedding = outputs.last_hidden_state.mean(dim=1).numpy()

        # Score every stored chunk embedding against the query.
        similarities = [
            float((query_embedding * doc_embedding).sum())
            for doc_embedding in self.embeddings
        ]

        # BUG FIX: sorting (score, Document) tuples with plain sorted(...)
        # compares Document objects whenever scores tie, raising TypeError.
        # Sort on the score component only.
        ranked_docs = sorted(
            zip(similarities, self.documents),
            key=lambda pair: pair[0],
            reverse=True,
        )
        return [doc for _, doc in ranked_docs[:5]]
# ----------------- Build the index -----------------
# Runs on every Streamlit rerun, but build_pdf_index is cached with
# @st.cache_resource, so the PDF is only processed once per session.
documents, embeddings = build_pdf_index()
retriever = SimpleRetriever(documents=documents, embeddings=embeddings)
# ----------------- Build the chain -----------------
# "stuff" chain type: all retrieved chunks are stuffed into one prompt.
chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    input_key="question"
)
# ----------------- Chat state -----------------
# Initialize session-state slots exactly once per session.
for state_key, default in (('messages', []), ('pending_prompt', None)):
    if state_key not in st.session_state:
        st.session_state[state_key] = default

# ----------------- Render previous messages -----------------
for entry in st.session_state.messages:
    with st.chat_message(entry['role']):
        st.markdown(f"🗨️ {entry['content']}", unsafe_allow_html=True)

# ----------------- Chat input -----------------
user_text = st.chat_input("سوالی در مورد فایل بپرس...")
if user_text:
    # Record the question, mark it pending, and rerun so it renders above.
    st.session_state.messages.append({'role': 'user', 'content': user_text})
    st.session_state.pending_prompt = user_text
    st.rerun()
# ----------------- Model response -----------------
if st.session_state.pending_prompt:
    with st.chat_message('ai'):
        # Progress bar for user feedback. NOTE: chain.run is synchronous, so
        # the animation below is cosmetic and runs after the answer exists.
        progress_bar = st.progress(0, text="در حال پردازش...")
        try:
            response = chain.run(f"سوال: {st.session_state.pending_prompt}")
            answer = response.strip()
            # Cosmetic progress animation
            for i in range(0, 101, 20):
                progress_bar.progress(i)
                time.sleep(0.1)
        except Exception as e:
            answer = f"خطا در پاسخ‌دهی: {str(e)}"
        progress_bar.progress(100)
        # BUG FIX: the answer was only appended to session state and never
        # rendered in this run, so the user saw nothing until the next
        # interaction. Render it inside the chat message now.
        st.markdown(f"🗨️ {answer}", unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'ai', 'content': answer})
    # BUG FIX: pending_prompt was never cleared, so every later rerun
    # re-submitted the same question to the LLM.
    st.session_state.pending_prompt = None