namberino committed
Commit 805b803 · 1 Parent(s): fd5723e

Initial commit

Files changed (2)
  1. main.py +312 -0
  2. requirements.txt +5 -0
main.py ADDED
@@ -0,0 +1,312 @@
+ """
+ Gradio RAG -> MCQ app for HuggingFace Spaces
+ - Upload a PDF
+ - Chunk + embed using Together embeddings
+ - Store vectors in Chroma (local) and retrieve
+ - Call Together chat/completion to generate Vietnamese MCQs in JSON
+
+ Drop this file into a new HuggingFace Space (Gradio, Python). Add a requirements.txt (see README below) and set the secret TOGETHER_API_KEY in Space settings.
+ """
+
+ import os
+ import json
+ import uuid
+ import tempfile
+ import pdfplumber
+ from together import Together
+ import chromadb
+ from chromadb.config import Settings
+ import gradio as gr
+ from typing import List
+
+ # ---------- Config - can be overridden from UI ----------
+ TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
+ DEFAULT_EMBEDDING_MODEL = "togethercomputer/m2-bert-80M-8k-retrieval"
+ DEFAULT_LLM_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ DEFAULT_CHUNK_SIZE = 1200
+ DEFAULT_CHUNK_OVERLAP = 200
+ DEFAULT_K_RETRIEVE = 4
+ EMBED_BATCH = 64
+
+ # instantiate Together client (requires TOGETHER_API_KEY in env / HF Secrets)
+ if TOGETHER_API_KEY:
+     client = Together(api_key=TOGETHER_API_KEY)
+ else:
+     # allow local testing if the user sets the env var later
+     client = None
+
+ # -------- PDF -> text ----------
+ def extract_text_from_pdf(path: str) -> str:
+     text_parts = []
+     with pdfplumber.open(path) as pdf:
+         for page in pdf.pages:
+             page_text = page.extract_text()
+             if page_text:
+                 text_parts.append(page_text)
+     return "\n\n".join(text_parts)
+
+ # -------- simple chunker ----------
+ def chunk_text(text: str, chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_CHUNK_OVERLAP) -> List[str]:
+     chunks = []
+     start = 0
+     L = len(text)
+     while start < L:
+         end = min(L, start + chunk_size)
+         chunk = text[start:end].strip()
+         if chunk:
+             chunks.append(chunk)
+         if end >= L:
+             break
+         # step back by the overlap, but always move forward to avoid looping on the final chunk
+         start = max(end - overlap, start + 1)
+     return chunks
+
+ # -------- embeddings (Together) with batching ----------
+ def embed_texts(texts: List[str], model=DEFAULT_EMBEDDING_MODEL):
+     if client is None:
+         raise RuntimeError("Together client not initialized. Set TOGETHER_API_KEY in environment or Space secrets.")
+     embeddings = []
+     for i in range(0, len(texts), EMBED_BATCH):
+         batch = texts[i:i + EMBED_BATCH]
+         resp = client.embeddings.create(input=batch, model=model)
+         # resp.data is a list; each item has .embedding
+         for item in resp.data:
+             embeddings.append(item.embedding)
+     return embeddings
+
+ # -------- chroma vectorstore setup / helpers ----------
+ def build_chroma_collection(name="pdf_docs", persist_directory="./chroma_db"):
+     # On Spaces, writes to the repo may be limited; Chroma will attempt to use the path provided.
+     try:
+         # chromadb >= 0.4
+         client_chroma = chromadb.PersistentClient(path=persist_directory)
+     except AttributeError:
+         # older chromadb releases use the Settings-based constructor
+         client_chroma = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_directory))
+     # create or get collection
+     try:
+         collection = client_chroma.get_collection(name)
+     except Exception:
+         collection = client_chroma.create_collection(name)
+     return client_chroma, collection
+
+ def add_documents_to_vectorstore(collection, chunks: List[str], embeddings: List[List[float]]):
+     ids = [f"doc_{i}" for i in range(len(chunks))]
+     metadatas = [{"chunk_index": i} for i in range(len(chunks))]
+     # ids repeat across uploads, so the collection is recreated per upload (see generate_mcqs_from_pdf)
+     collection.add(ids=ids, documents=chunks, metadatas=metadatas, embeddings=embeddings)
+
+ # -------- retrieve top-k using chroma ----------
+ def retrieve_relevant_chunks(collection, query: str, k=DEFAULT_K_RETRIEVE, embedding_model=DEFAULT_EMBEDDING_MODEL):
+     q_emb = embed_texts([query], model=embedding_model)[0]
+     result = collection.query(query_embeddings=[q_emb], n_results=k, include=["documents", "metadatas", "distances"])
+     docs = result["documents"][0]
+     metas = result["metadatas"][0]
+     distances = result["distances"][0]
+     return list(zip(docs, metas, distances))
+
+ # -------- prompt template (Vietnamese) ----------
+ MCQ_PROMPT_VI = """
+ Bạn là một chuyên gia soạn câu hỏi trắc nghiệm (MCQ). CHỈ SỬ DỤNG các đoạn ngữ cảnh được cung cấp dưới đây (KHÔNG suy diễn/không thêm thông tin ngoài ngữ cảnh).
+ Tạo **một** câu hỏi trắc nghiệm có 4 lựa chọn (A, B, C, D), chỉ ra đáp án đúng (A/B/C/D) và viết 1 câu giải thích ngắn (1-2 câu).
+ **Bắt buộc:** output PHẢI LÀ **JSON duy nhất** theo schema sau (không có văn bản nào khác ngoài JSON):
+
+ {{
+   "question_id": "<mã duy nhất>",
+   "question": "<câu hỏi bằng tiếng Việt>",
+   "options": [
+     {{ "label": "A", "text": "..." }},
+     {{ "label": "B", "text": "..." }},
+     {{ "label": "C", "text": "..." }},
+     {{ "label": "D", "text": "..." }}
+   ],
+   "answer": "A",
+   "explanation": "<giải thích ngắn bằng tiếng Việt>",
+   "source_chunks": [ "<chunk_index hoặc đoạn trích ngắn>", ... ]
+ }}
+
+ Ví dụ đầu ra (một mẫu JSON đúng; chỉ để mô tả định dạng):
+ {{
+   "question_id": "q_0001",
+   "question": "Nguyên tố nào là thành phần chính của vỏ trái đất?",
+   "options": [
+     {{ "label": "A", "text": "Sắt" }},
+     {{ "label": "B", "text": "Oxi" }},
+     {{ "label": "C", "text": "Cacbon" }},
+     {{ "label": "D", "text": "Nitơ" }}
+   ],
+   "answer": "B",
+   "explanation": "Oxi là nguyên tố phong phú nhất trong vỏ trái đất, chủ yếu trong các oxit và khoáng vật.",
+   "source_chunks": [ "chunk_3" ]
+ }}
+
+ Đây là các đoạn ngữ cảnh (chỉ được phép dùng những đoạn này để soạn câu hỏi):
+ {context}
+
+ Hãy viết câu hỏi rõ ràng, không gây mơ hồ. Đảm bảo distractor (đáp án sai) là hợp lý và gây nhầm lẫn cho người học.
+ """
+
+ # -------- call Together chat/completion ----------
+ def generate_mcq_with_rag(question_seed: str, retrieved_chunks, llm_model=DEFAULT_LLM_MODEL, temperature=0.0):
+     if client is None:
+         raise RuntimeError("Together client not initialized. Set TOGETHER_API_KEY in environment or Space secrets.")
+
+     context = ""
+     for i, (doc_text, meta, dist) in enumerate(retrieved_chunks):
+         snippet = doc_text.replace("\n", " ").strip()
+         context += f"[chunk_{meta.get('chunk_index', i)}] {snippet}\n\n"
+
+     prompt = MCQ_PROMPT_VI.format(context=context)
+     full_user = f"Yêu cầu (chủ đề / seed): {question_seed}\n\n{prompt}"
+
+     messages = [
+         {"role": "system", "content": "Bạn là một chuyên gia soạn câu hỏi trắc nghiệm bằng tiếng Việt. Chỉ trả về JSON, KHÔNG có lời giải thích thêm."},
+         {"role": "user", "content": full_user},
+     ]
+
+     resp = client.chat.completions.create(
+         model=llm_model,
+         messages=messages,
+         temperature=temperature,
+     )
+     out = resp.choices[0].message.content
+
+     # try to parse JSON, fall back to extracting the first {...} block
+     try:
+         parsed = json.loads(out)
+     except Exception:
+         start = out.find("{")
+         end = out.rfind("}")
+         if start != -1 and end != -1:
+             try:
+                 parsed = json.loads(out[start:end + 1])
+             except Exception:
+                 parsed = None
+         else:
+             parsed = None
+
+     # ensure question_id exists
+     if parsed and isinstance(parsed, dict):
+         if not parsed.get("question_id"):
+             parsed["question_id"] = f"q_{uuid.uuid4().hex[:8]}"
+
+     return parsed, out
+
+ # -------- high-level runner used by Gradio ----------
+ def generate_mcqs_from_pdf(pdf_path: str, seeds: List[str], questions_per_seed=1, chunk_size=DEFAULT_CHUNK_SIZE,
+                            chunk_overlap=DEFAULT_CHUNK_OVERLAP, k_retrieve=DEFAULT_K_RETRIEVE,
+                            embedding_model=DEFAULT_EMBEDDING_MODEL, llm_model=DEFAULT_LLM_MODEL,
+                            temperature=0.0, persist_directory="./chroma_db"):
+     text = extract_text_from_pdf(pdf_path)
+     chunks = chunk_text(text, chunk_size=chunk_size, overlap=chunk_overlap)
+
+     # embed
+     chunk_embeddings = embed_texts(chunks, model=embedding_model)
+
+     # build vectorstore (recreate the collection to avoid stale data from earlier uploads)
+     chroma_client, collection = build_chroma_collection(name="pdf_docs", persist_directory=persist_directory)
+     try:
+         chroma_client.delete_collection("pdf_docs")
+         collection = chroma_client.create_collection("pdf_docs")
+     except Exception:
+         # some backends will raise; ignore and continue with the existing collection
+         pass
+     add_documents_to_vectorstore(collection, chunks, chunk_embeddings)
+
+     results = []
+     for seed in seeds:
+         for i in range(questions_per_seed):
+             retrieved = retrieve_relevant_chunks(collection, seed, k=k_retrieve, embedding_model=embedding_model)
+             parsed, raw = generate_mcq_with_rag(seed, retrieved, llm_model=llm_model, temperature=temperature)
+             if parsed is None:
+                 item = {"seed": seed, "ok": False, "raw": raw}
+             else:
+                 item = {"seed": seed, "ok": True, "mcq": parsed}
+             results.append(item)
+
+     return results
+
+ # -------- Gradio UI ----------
+
+ def ui_run(pdf_file, seeds_text, questions_per_seed, k_retrieve, chunk_size, chunk_overlap,
+            embedding_model, llm_model, temperature):
+     if pdf_file is None:
+         return "", None
+
+     # gr.File may hand over a filepath string or a temp-file object depending on the Gradio version
+     local_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
+     # temp dir for the downloadable output file
+     tmp_dir = tempfile.mkdtemp()
+
+     seeds = [s.strip() for s in seeds_text.split(",") if s.strip()]
+     if not seeds:
+         seeds = ["Lấy câu hỏi tổng quát về tài liệu"]
+
+     try:
+         results = generate_mcqs_from_pdf(
+             pdf_path=local_path,
+             seeds=seeds,
+             questions_per_seed=int(questions_per_seed),  # sliders/number inputs arrive as floats
+             chunk_size=int(chunk_size),
+             chunk_overlap=int(chunk_overlap),
+             k_retrieve=int(k_retrieve),
+             embedding_model=embedding_model,
+             llm_model=llm_model,
+             temperature=temperature,
+             persist_directory="./chroma_db"
+         )
+     except Exception as e:
+         return f"Lỗi khi sinh MCQ: {e}", None
+
+     out_json = json.dumps(results, ensure_ascii=False, indent=2)
+     # write output file for download
+     out_file = os.path.join(tmp_dir, "mcq_output.json")
+     with open(out_file, "w", encoding="utf-8") as f:
+         f.write(out_json)
+
+     return out_json, out_file
+
+ with gr.Blocks(title="RAG -> MCQ (Tiếng Việt)") as demo:
+     gr.Markdown("# RAG -> MCQ Generator (Tiếng Việt)\nUpload PDF, set seeds (phân tách bằng dấu phẩy), và nhấn Generate.\nOutputs: JSON trả về các câu hỏi trắc nghiệm.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             pdf_in = gr.File(label="Upload PDF", file_types=['.pdf'])
+             seeds_in = gr.Textbox(label="Seeds (chủ đề), phân tách bằng dấu phẩy", value="lập trình hướng đối tượng, kế thừa")
+             questions_per_seed = gr.Slider(label="Questions per seed", minimum=1, maximum=5, step=1, value=1)
+             k_retrieve = gr.Slider(label="K retrieve (số đoạn liên quan)", minimum=1, maximum=10, step=1, value=DEFAULT_K_RETRIEVE)
+             chunk_size = gr.Number(label="Chunk size (chars)", value=DEFAULT_CHUNK_SIZE)
+             chunk_overlap = gr.Number(label="Chunk overlap (chars)", value=DEFAULT_CHUNK_OVERLAP)
+             embedding_model = gr.Textbox(label="Embedding model", value=DEFAULT_EMBEDDING_MODEL)
+             llm_model = gr.Textbox(label="LLM model", value=DEFAULT_LLM_MODEL)
+             temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.05, value=0.0)
+             btn = gr.Button("Generate MCQs")
+         with gr.Column(scale=1):
+             out_text = gr.Textbox(label="Raw JSON output", lines=20)
+             out_file = gr.File(label="Download JSON")
+
+     btn.click(fn=ui_run, inputs=[pdf_in, seeds_in, questions_per_seed, k_retrieve, chunk_size, chunk_overlap,
+                                  embedding_model, llm_model, temperature], outputs=[out_text, out_file])
+
+ if __name__ == "__main__":
+     demo.launch()
+
+ # -------------------------------
+ # README / Deployment tips (keep in repo README.md)
+ # -------------------------------
+ # requirements.txt (suggested):
+ #   gradio
+ #   together
+ #   chromadb
+ #   pdfplumber
+ #   tiktoken
+ #
+ # Deployment to HuggingFace Spaces:
+ #   1) Create a new Space, choose "Gradio" and "Python".
+ #   2) Add this file as app.py (or keep the same name), and add a requirements.txt with the packages above.
+ #   3) In the Space settings -> Secrets, add TOGETHER_API_KEY with your Together API key.
+ #   4) Commit & run. The Space will install dependencies and start the Gradio app.
+ #
+ # Notes:
+ # - If `pip install together` fails, check Together's docs/GitHub for the current SDK package name
+ #   and update requirements accordingly.
+ # - HuggingFace Spaces have limits on disk persistence; if you want long-term persistence of the Chroma DB,
+ #   consider using an external vector DB (Pinecone, Milvus, etc.) or a hosted setup.
+ # - For heavy models you may prefer to let Together host the inference (this app calls Together's API).
+ # -------------------------------
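
For a quick local check before deploying, a minimal smoke test might look like the sketch below. It assumes TOGETHER_API_KEY is exported in the shell and that a sample.pdf sits next to main.py (both are illustrative assumptions, not part of this commit), and it calls generate_mcqs_from_pdf directly instead of going through the Gradio UI.

    # local_smoke_test.py -- hypothetical helper, not included in this commit
    # Assumes TOGETHER_API_KEY is set in the environment and sample.pdf exists next to main.py.
    import json

    from main import generate_mcqs_from_pdf

    results = generate_mcqs_from_pdf(
        pdf_path="sample.pdf",
        seeds=["lập trình hướng đối tượng"],
        questions_per_seed=1,
    )
    print(json.dumps(results, ensure_ascii=False, indent=2))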
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ together
+ chromadb
+ pdfplumber
+ tiktoken