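# Phi-2 Study Assistant: a small RAG app. Documents placed in ./study_materials
# are chunked, embedded into a FAISS index, and queried through a LangChain
# RetrievalQA chain backed by microsoft/phi-2.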
import os
import gradio as gr
import torch
from huggingface_hub import login
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# HF Authentication: read the token from the HF_TOKEN secret and only log in
# when it is actually set (login(token=None) would fall back to an interactive
# prompt, which crashes in a headless Space).
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Configuration
DOCS_DIR = "study_materials"
MODEL_NAME = "microsoft/phi-2"
EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
MAX_TOKENS = 300
CHUNK_SIZE = 1000
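
# Document ingestion: walk DOCS_DIR and load every PDF (via PyMuPDF) and
# plain-text file. Files that fail to parse are skipped with a logged error
# instead of aborting the whole load.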
def load_documents():
    documents = []
    for filename in os.listdir(DOCS_DIR):
        path = os.path.join(DOCS_DIR, filename)
        try:
            if filename.endswith(".pdf"):
                documents.extend(PyMuPDFLoader(path).load())
            elif filename.endswith(".txt"):
                documents.extend(TextLoader(path).load())
        except Exception as e:
            print(f"Error loading {filename}: {e}")
    return documents
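
# Build the QA pipeline end to end: split documents into overlapping chunks,
# embed them with all-MiniLM-L6-v2, index the vectors in FAISS, load phi-2
# into a text-generation pipeline, and wire everything into a "stuff"
# RetrievalQA chain that retrieves the top-2 chunks per question.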
def create_qa_system():
    # Load and split documents
    documents = load_documents()
    if not documents:
        raise gr.Error("No documents found in 'study_materials' folder")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=200,
        separators=["\n\n", "\n", " "]
    )
    texts = text_splitter.split_documents(documents)
    # Create vector store
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
    db = FAISS.from_documents(texts, embeddings)
    # Load phi-2 (a public model; the token cached by login() is picked up
    # automatically, so the deprecated use_auth_token flag is not needed)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        trust_remote_code=True,
        device_map="auto",
        low_cpu_mem_usage=True
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_TOKENS,
        temperature=0.7,
        do_sample=True,
        top_k=40
        # no device_map here: the instantiated model already carries its placement
    )
    return RetrievalQA.from_chain_type(
        llm=HuggingFacePipeline(pipeline=pipe),
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True
    )
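
# Strip the echoed "Instruct: ... Output:" prompt from phi-2's raw completion
# and append a deduplicated list of source filenames.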
def format_response(response):
    answer = response["result"].split("</s>")[0].split("\nOutput:")[-1].strip()
    sources = list({os.path.basename(doc.metadata["source"]) for doc in response["source_documents"]})
    return f"{answer}\n\n📚 Sources: {', '.join(sources)}"
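
# Chat handler. NOTE: create_qa_system() re-indexes the documents and reloads
# phi-2 on every question; caching the chain in a module-level variable (or a
# functools.lru_cache-wrapped getter) would make repeat queries much faster.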
def process_query(question, history):
    try:
        qa = create_qa_system()
        formatted_q = f"Instruct: {question}\nOutput:"
        response = qa.invoke({"query": formatted_q})
        answer = format_response(response)
    except Exception as e:
        print(f"Error: {e}")
        answer = f"⚠️ Error: {str(e)[:100]}"
    # Return values map to the outputs wired in msg.submit below:
    # clear the textbox and append the new exchange to the chat history.
    return "", history + [(question, answer)]
with gr.Blocks(title="Phi-2 Study Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("## 📚 Phi-2 Study Assistant\nUpload study materials to 'study_materials' and ask questions!")
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Your Question")
    clear = gr.ClearButton([msg, chatbot])
    msg.submit(process_query, [msg, chatbot], [msg, chatbot])

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)