sunbal7 committed on
Commit
b32efb7
verified · 1 Parent(s): ab46633

Update app.py

Files changed (1)
  1. app.py +43 -43
app.py CHANGED
@@ -4,7 +4,7 @@ st.set_page_config(page_title="RAG Book Analyzer", layout="wide")  # Must be the
 import torch
 import numpy as np
 import faiss
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from sentence_transformers import SentenceTransformer
 import fitz  # PyMuPDF for PDF extraction
 import docx2txt  # For DOCX extraction
@@ -13,7 +13,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
 # ------------------------
 # Configuration
 # ------------------------
-MODEL_NAME = "ibm-granite/granite-3.1-1b-a400m-instruct"
+MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
 EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
 CHUNK_SIZE = 512
 CHUNK_OVERLAP = 64
@@ -25,19 +25,13 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 @st.cache_resource
 def load_models():
     try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            MODEL_NAME,
-            trust_remote_code=True,
-            revision="main"
-        )
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_NAME,
-            trust_remote_code=True,
-            revision="main",
             device_map="auto" if DEVICE == "cuda" else None,
             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
             low_cpu_mem_usage=True
-        ).eval()
+        )
         embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
         return tokenizer, model, embedder
     except Exception as e:
@@ -79,10 +73,9 @@ def extract_text(file):
     return ""

 def build_index(chunks):
-    embeddings = embedder.encode(chunks, show_progress_bar=True)
+    embeddings = embedder.encode(chunks, show_progress_bar=False)
     dimension = embeddings.shape[1]
-    index = faiss.IndexFlatIP(dimension)
-    faiss.normalize_L2(embeddings)
+    index = faiss.IndexFlatL2(dimension)
     index.add(embeddings)
     return index

@@ -90,36 +83,38 @@ def build_index(chunks):
 # Summarization and Q&A Functions
 # ------------------------
 def generate_summary(text):
-    # Limit input text to avoid long sequences
-    prompt = f"<|user|>\nSummarize the following book in a concise and informative paragraph:\n\n{text[:4000]}\n<|assistant|>\n"
+    # Create prompt with Mistral format
+    prompt = f"<s>[INST] Summarize this book in a concise paragraph: {text[:3000]} [/INST]"
     inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-    outputs = model.generate(**inputs, max_new_tokens=300, temperature=0.5)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=300,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
     summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    # Remove any markers and extra lines; return the first non-empty paragraph.
-    summary = summary.replace("<|assistant|>", "").strip()
-    paragraphs = [p.strip() for p in summary.split("\n") if p.strip()]
-    return paragraphs[0] if paragraphs else summary
+    return summary.split("[/INST]")[-1].strip()

 def generate_answer(query, context):
-    prompt = f"<|user|>\nUsing the context below, answer the following question precisely. If unsure, say 'I don't know'.\n\nContext: {context}\n\nQuestion: {query}\n<|assistant|>\n"
-    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
+    # Create prompt with Mistral format
+    prompt = f"<s>[INST] Answer this question based on the context. If unsure, say 'I don't know'.\n\nQuestion: {query}\nContext: {context} [/INST]"
+    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
     outputs = model.generate(
         **inputs,
         max_new_tokens=300,
-        temperature=0.4,
+        temperature=0.5,
         top_p=0.9,
         repetition_penalty=1.2,
         do_sample=True
     )
     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    answer = answer.replace("<|assistant|>", "").strip()
-    paragraphs = [p.strip() for p in answer.split("\n") if p.strip()]
-    return paragraphs[0] if paragraphs else answer
+    return answer.split("[/INST]")[-1].strip()

 # ------------------------
 # Streamlit UI
 # ------------------------
-st.title("RAG-Based Book Analyzer")
+st.title("📚 RAG-Based Book Analyzer")
 st.write("Upload a book (PDF, TXT, DOCX) to get a summary and ask questions about its content.")

 uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])
@@ -127,11 +122,12 @@ uploaded_file = st.file_uploader("Upload File", type=["pdf", "txt", "docx"])
 if uploaded_file:
     text = extract_text(uploaded_file)
     if text:
-        st.success("File successfully processed!")
-        st.write("Generating summary...")
-        summary = generate_summary(text)
-        st.markdown("### Book Summary")
-        st.write(summary)
+        st.success("✅ File successfully processed!")
+
+        with st.spinner("Generating summary..."):
+            summary = generate_summary(text)
+            st.markdown("### Book Summary")
+            st.info(summary)

         # Process text into chunks and build FAISS index
         chunks = split_text(text)
@@ -139,15 +135,19 @@ if uploaded_file:
         st.session_state.chunks = chunks
         st.session_state.index = index

-        st.markdown("### Ask a Question about the Book:")
-        query = st.text_input("Your Question:")
+        st.markdown("### ❓ Ask a Question about the Book")
+        query = st.text_input("Enter your question:")
         if query:
-            # Retrieve top 3 relevant chunks as context
-            query_embedding = embedder.encode([query])
-            faiss.normalize_L2(query_embedding)
-            distances, indices = st.session_state.index.search(query_embedding, k=3)
-            retrieved_chunks = [chunks[i] for i in indices[0] if i < len(chunks)]
-            context = "\n".join(retrieved_chunks)
-            answer = generate_answer(query, context)
-            st.markdown("### Answer")
-            st.write(answer)
+            with st.spinner("Searching for answers..."):
+                # Retrieve top 3 relevant chunks as context
+                query_embedding = embedder.encode([query])
+                distances, indices = st.session_state.index.search(query_embedding, k=3)
+                retrieved_chunks = [st.session_state.chunks[i] for i in indices[0] if i < len(st.session_state.chunks)]
+                context = "\n\n".join(retrieved_chunks)
+
+                answer = generate_answer(query, context)
+                st.markdown("### 💬 Answer")
+                st.success(answer)
+
+                with st.expander("See context used"):
+                    st.write(context)
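
Note on the retrieval change in this commit: the previous build_index() L2-normalized the chunk embeddings and used faiss.IndexFlatIP (cosine similarity), and the query was normalized before searching; the new version indexes raw vectors with faiss.IndexFlatL2 (Euclidean distance). On unnormalized all-mpnet-base-v2 embeddings the two metrics can rank chunks differently. Below is a minimal sketch of how cosine retrieval could be kept if that ranking is preferred; the helper names build_cosine_index and search_cosine are illustrative and not part of the committed app.py.

import faiss
import numpy as np

def build_cosine_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    # Illustrative helper (not in app.py). FAISS expects contiguous float32 arrays;
    # SentenceTransformer.encode returns float32 by default.
    vecs = np.ascontiguousarray(embeddings, dtype="float32")
    faiss.normalize_L2(vecs)                  # unit-normalize in place so inner product == cosine
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)
    return index

def search_cosine(index, embedder, query: str, k: int = 3):
    # The query must be normalized the same way as the indexed vectors.
    q = np.ascontiguousarray(embedder.encode([query]), dtype="float32")
    faiss.normalize_L2(q)
    scores, ids = index.search(q, k)          # higher score = more similar
    return scores[0], ids[0]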