import os
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import faiss
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
import fitz  # PyMuPDF
from huggingface_hub import login

# Authenticate with Hugging Face to access gated models
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if hf_token is None:
    raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
login(token=hf_token)

# Load embedding model
embed_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
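# bge-base-en-v1.5 produces 768-dimensional vectors; its model card suggests
# cosine similarity over normalized embeddings. Plain L2 distance on the raw
# vectors (used below) also works, but can rank results slightly differently.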

# Load the LLM and tokenizer with 4-bit quantization (requires bitsandbytes)
model_id = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    token=hf_token,
)
llm = pipeline("text-generation", model=model, tokenizer=tokenizer)
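# The text-generation pipeline returns the prompt plus the completion by
# default (return_full_text=True); generate_answer() below therefore splits
# the output on "Answer:" to isolate the model's reply.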

# Globals for FAISS index and document texts
index = None
doc_texts = []

# PDF/Text extraction
def extract_text(file):
    if file.name.endswith(".pdf"):
        text = ""
        with fitz.open(file.name) as doc:
            for page in doc:
                text += page.get_text()
        return text
    elif file.name.endswith(".txt"):
        # Re-open by path: the temp file Gradio hands over may already be closed
        with open(file.name, "r", encoding="utf-8") as f:
            return f.read()
    else:
        return "❌ Invalid file type. Please upload a .pdf or .txt file."

# File processing: chunk text, create embeddings, build FAISS index
def process_file(file):
    global index, doc_texts
    if file is None:  # the change event also fires when the file is cleared
        return "⚠️ Please upload a file."
    text = extract_text(file)
    if text.startswith("❌"):
        return text

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    doc_texts = text_splitter.split_text(text)
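    # chunk_size/chunk_overlap are measured in characters by default; the
    # splitter prefers paragraph, then line, then word boundaries, and the
    # 50-character overlap keeps text that straddles a cut in both chunks.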
    embeddings = embed_model.encode(doc_texts)

    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(np.array(embeddings))
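    # IndexFlatL2 performs exact brute-force nearest-neighbor search, which is
    # fine at this scale. An untested sketch for cosine similarity instead:
    #   faiss.normalize_L2(embeddings)  # in-place, expects float32
    #   index = faiss.IndexFlatIP(dim)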

    return "βœ… File processed successfully. You can now ask questions!"

# Generate answer using retrieved context and LLM
def generate_answer(question):
    global index, doc_texts
    if index is None or len(doc_texts) == 0:
        return "⚠️ Please upload and process a file first."
    
    question_embedding = embed_model.encode([question])
    _, I = index.search(np.array(question_embedding), k=3)
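    # index.search returns (distances, indices), each of shape (n_queries, k);
    # I[0] holds the positions of the 3 chunks nearest to the question.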
    context = "\n".join([doc_texts[i] for i in I[0]])

    prompt = f"""[System: You are a helpful assistant. Answer strictly based on the context.]

Context:
{context}

Question: {question}
Answer:"""

    result = llm(prompt, max_new_tokens=300, do_sample=True, temperature=0.7)
    return result[0]["generated_text"].split("Answer:")[-1].strip()

# Gradio UI
with gr.Blocks(title="RAG Chatbot") as demo:
    gr.Markdown("## πŸ“š RAG Chatbot - Upload PDF/TXT and Ask Questions")

    with gr.Row():
        file_input = gr.File(label="πŸ“ Upload .pdf or .txt", file_types=[".pdf", ".txt"])
        upload_status = gr.Textbox(label="πŸ“₯ Upload Status", interactive=False)

    with gr.Row():
        question_box = gr.Textbox(label="❓ Ask a Question", placeholder="Type your question here...")
        answer_box = gr.Textbox(label="πŸ’¬ Answer", interactive=False)

    file_input.change(fn=process_file, inputs=file_input, outputs=upload_status)
    question_box.submit(fn=generate_answer, inputs=question_box, outputs=answer_box)

demo.launch()