File size: 3,655 Bytes
2d8c319
42d3ee2
8550dc5
42d3ee2
 
 
91b268b
42d3ee2
 
 
 
 
 
 
0e5b4a4
2e62dd1
42d3ee2
2d8c319
42d3ee2
 
 
0e5b4a4
42d3ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8550dc5
 
42d3ee2
 
 
8550dc5
42d3ee2
8550dc5
42d3ee2
 
 
8550dc5
42d3ee2
4d6816c
8550dc5
 
42d3ee2
 
8550dc5
2d88065
 
8550dc5
42d3ee2
 
 
 
 
 
 
 
 
 
 
8550dc5
42d3ee2
 
 
 
 
 
0e5b4a4
42d3ee2
 
 
 
0e5b4a4
42d3ee2
f8c1ecf
42d3ee2
 
 
 
f8c1ecf
42d3ee2
 
4d6816c
42d3ee2
 
 
 
2d88065
 
42d3ee2
0e5b4a4
42d3ee2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import gradio as gr
import torch
from huggingface_hub import login
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# HF Authentication
# Log in only when a token is actually configured: login(token=None) falls
# back to an interactive prompt, which nobody can answer on a headless server.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Configuration
DOCS_DIR = "study_materials"  # folder scanned for .pdf / .txt study files
MODEL_NAME = "microsoft/phi-2"  # causal LM used to generate answers
EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # embeds chunks for FAISS
MAX_TOKENS = 300  # passed as max_new_tokens to the generation pipeline
CHUNK_SIZE = 1000  # chunk_size handed to the text splitter

def load_documents():
    """Load every supported study file found in ``DOCS_DIR``.

    Files ending in ``.pdf`` are read with ``PyMuPDFLoader`` and files ending
    in ``.txt`` with ``TextLoader`` (the extension check is case-insensitive).
    Other files are ignored. A file that fails to load is reported on stdout
    and skipped so that one bad file does not block the rest.

    Returns:
        list: Documents produced by the loaders, in sorted filename order.
        Empty when the folder is missing, unreadable, or holds no loadable
        files.
    """
    try:
        # Sorted so the document order does not depend on the filesystem.
        filenames = sorted(os.listdir(DOCS_DIR))
    except OSError as e:
        # A missing folder is reported like any other load problem; the
        # caller then sees an empty list instead of an unrelated traceback.
        print(f"Error loading {DOCS_DIR}: {str(e)}")
        return []

    documents = []
    for filename in filenames:
        lowered = filename.lower()
        if lowered.endswith(".pdf"):
            loader_cls = PyMuPDFLoader
        elif lowered.endswith(".txt"):
            loader_cls = TextLoader
        else:
            continue

        path = os.path.join(DOCS_DIR, filename)
        try:
            documents.extend(loader_cls(path).load())
        except Exception as e:
            print(f"Error loading {filename}: {str(e)}")
    return documents

# The language model does not depend on the study documents, so it is built
# once per process and reused by every call to create_qa_system().
_LLM_CACHE = {}


def _get_llm():
    """Return the shared Phi-2 generation LLM, loading it on first use."""
    llm = _LLM_CACHE.get("llm")
    if llm is not None:
        return llm

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # No explicit `use_auth_token=True`: that argument is deprecated in
    # transformers and makes loading fail when no token is stored.
    # NOTE(review): a token saved by `login()` should still be sent by
    # default - confirm against the installed huggingface_hub version.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        trust_remote_code=True,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_TOKENS,
        temperature=0.7,
        do_sample=True,
        top_k=40,
        device_map="auto",
    )

    llm = HuggingFacePipeline(pipeline=pipe)
    _LLM_CACHE["llm"] = llm
    return llm


def create_qa_system():
    """Build a retrieval QA chain over the current study materials.

    The documents are reloaded, split and indexed on every call so files
    added to the folder since the last call are picked up. Only the language
    model is reused between calls.

    Returns:
        A ``RetrievalQA`` "stuff" chain that retrieves the 2 closest chunks
        and returns the source documents alongside the answer.

    Raises:
        gr.Error: If no documents could be loaded.
    """
    # Load and split documents
    documents = load_documents()
    if not documents:
        raise gr.Error("No documents found in 'study_materials' folder")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=200,
        separators=["\n\n", "\n", " "],
    )
    texts = text_splitter.split_documents(documents)

    # Create vector store
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL)
    db = FAISS.from_documents(texts, embeddings)

    return RetrievalQA.from_chain_type(
        llm=_get_llm(),
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True,
    )

def format_response(response):
    """Turn a QA chain response into the text shown to the user.

    Args:
        response: Mapping with a ``"result"`` string and a
            ``"source_documents"`` iterable whose items expose a
            ``metadata`` dict.

    Returns:
        str: The cleaned answer followed by a ``Sources`` line listing the
        distinct source file names in alphabetical order.
    """
    # Drop everything after the end-of-sequence marker, then keep only the
    # text following the last "Output:" marker of the prompt format.
    answer = response["result"].split("</s>")[0].split("\nOutput:")[-1].strip()

    # Documents without a "source" entry are skipped rather than raising.
    # Sorting makes the order stable; set iteration order is not.
    source_names = set()
    for doc in response["source_documents"]:
        source = doc.metadata.get("source")
        if source:
            source_names.add(os.path.basename(source))
    sources = sorted(source_names)
    return f"{answer}\n\n📚 Sources: {', '.join(sources)}"

def process_query(question, history):
    """Answer one user question from the study materials.

    Args:
        question: Text typed by the user.
        history: Chat history supplied by the UI; not used here.

    Returns:
        str: The formatted answer with its sources, or a short message
        starting with a warning sign when the question is blank or
        answering failed.
    """
    # Reject blank input up front so no model or index is built for nothing.
    if not question or not question.strip():
        return "⚠️ Please enter a question."

    try:
        qa = create_qa_system()
        # Prompt layout with "Instruct:" / "Output:" markers; the "Output:"
        # marker is what format_response splits on.
        formatted_q = f"Instruct: {question.strip()}\nOutput:"
        response = qa.invoke({"query": formatted_q})
        return format_response(response)
    except Exception as e:
        # UI boundary: log the full message, show the user a truncated one.
        print(f"Error: {str(e)}")
        return f"⚠️ Error: {str(e)[:100]}"

with gr.Blocks(title="Phi-2 Study Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("## 📚 Phi-2 Study Assistant\nUpload study materials to 'study_materials' and ask questions!")
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Your Question")
    clear = gr.ClearButton([msg, chatbot])

    def _respond(question, history):
        """Answer the question, add the exchange to the chat, clear the box.

        The submit event declares two outputs (textbox, chatbot), so two
        values are returned: an empty string and the updated history.
        """
        answer = process_query(question, history)
        # NOTE(review): assumes the Chatbot keeps history as [user, bot]
        # pairs, the default when no `type` is given - confirm against the
        # installed Gradio version.
        updated_history = list(history or []) + [[question, answer]]
        return "", updated_history

    msg.submit(_respond, [msg, chatbot], [msg, chatbot])

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)