""" | |
Gradio RAG -> MCQ app for HuggingFace Spaces | |
- Upload a PDF | |
- Chunk + embed using Together embeddings | |
- Store vectors in Chroma (local) and retrieve | |
- Call Together chat/completion to generate Vietnamese MCQs in JSON | |
Drop this file into a new HuggingFace Space (Gradio, Python). Add a requirements.txt (see README below) and set the secret TOGETHER_API_KEY in Space settings. | |
""" | |
import os
import json
import uuid
import shutil
from typing import List

import pdfplumber
import chromadb
import gradio as gr
from together import Together
# scratch directory for uploaded PDFs and the generated JSON output
tmp_dir = "./tmp"
# ---------- Config - can be overridden from UI ----------
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
DEFAULT_EMBEDDING_MODEL = "togethercomputer/m2-bert-80M-8k-retrieval"
DEFAULT_LLM_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_CHUNK_SIZE = 1200
DEFAULT_CHUNK_OVERLAP = 200
DEFAULT_K_RETRIEVE = 4
EMBED_BATCH = 64

# instantiate the Together client (requires TOGETHER_API_KEY in env / HF Secrets)
if TOGETHER_API_KEY:
    client = Together(api_key=TOGETHER_API_KEY)
else:
    # allow local testing if the user wants to set the env var later
    client = None

# -------- PDF -> text ----------
def extract_text_from_pdf(path: str) -> str:
    text_parts = []
    with pdfplumber.open(path) as pdf:
        for page in pdf.pages:
            page_text = page.extract_text()
            if page_text:
                text_parts.append(page_text)
    return "\n\n".join(text_parts)

# -------- simple chunker ----------
def chunk_text(text: str, chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_CHUNK_OVERLAP) -> List[str]:
    chunks = []
    start = 0
    L = len(text)
    while start < L:
        end = min(L, start + chunk_size)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= L:
            # reached the end of the text; stepping back by `overlap` here would loop forever
            break
        start = end - overlap
        if start < 0:
            start = 0
    return chunks
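
# Quick illustrative check of the overlap behaviour (values chosen just for the example):
#   chunk_text("abcdefghijklmnop", chunk_size=10, overlap=3)
#   -> ["abcdefghij", "hijklmnop"]   # the second chunk starts 3 chars before the first one ends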

# -------- embeddings (Together) with batching ----------
def embed_texts(texts: List[str], model=DEFAULT_EMBEDDING_MODEL):
    if client is None:
        raise RuntimeError("Together client not initialized. Set TOGETHER_API_KEY in environment or Space secrets.")
    embeddings = []
    for i in range(0, len(texts), EMBED_BATCH):
        batch = texts[i:i + EMBED_BATCH]
        resp = client.embeddings.create(input=batch, model=model)
        # resp.data is a list; each item has .embedding
        for item in resp.data:
            embeddings.append(item.embedding)
    return embeddings
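
# Note: the loop above assumes resp.data returns one embedding per input, in the same order
# as the batch that was sent; the returned list therefore lines up index-for-index with
# `texts`, which is what add_documents_to_vectorstore relies on.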

# -------- chroma vectorstore setup / helpers ----------
def build_chroma_collection(name="pdf_docs", persist_directory="./chroma_db"):
    # On Spaces, writes to the repo may be limited; Chroma will attempt to use the path provided.
    # Uses the chromadb >= 0.4 client API (PersistentClient); the older
    # Settings(chroma_db_impl="duckdb+parquet", ...) configuration is deprecated there.
    client_chroma = chromadb.PersistentClient(path=persist_directory)
    collection = client_chroma.get_or_create_collection(name)
    return client_chroma, collection


def add_documents_to_vectorstore(collection, chunks: List[str], embeddings: List[List[float]]):
    ids = [f"doc_{i}" for i in range(len(chunks))]
    metadatas = [{"chunk_index": i} for i in range(len(chunks))]
    # Ids are deterministic per upload, so the caller recreates the collection for each new
    # PDF; adding again with the same ids would otherwise collide with the existing entries.
    collection.add(ids=ids, documents=chunks, metadatas=metadatas, embeddings=embeddings)

# -------- retrieve top-k using chroma ----------
def retrieve_relevant_chunks(collection, query: str, k=DEFAULT_K_RETRIEVE, embedding_model=DEFAULT_EMBEDDING_MODEL):
    q_emb = embed_texts([query], model=embedding_model)[0]
    result = collection.query(query_embeddings=[q_emb], n_results=k, include=["documents", "metadatas", "distances"])
    docs = result["documents"][0]
    metas = result["metadatas"][0]
    distances = result["distances"][0]
    return list(zip(docs, metas, distances))
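
# Shape of the returned value (values are illustrative, not real output):
#   [("<chunk text>", {"chunk_index": 3}, 0.12), ...]   # (document, metadata, distance), k items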

# -------- prompt template (Vietnamese) ----------
MCQ_PROMPT_VI = """
Bạn là một chuyên gia soạn câu hỏi trắc nghiệm (MCQ). SỬ DỤNG CHỈ các đoạn ngữ cảnh được cung cấp dưới đây (KHÔNG suy diễn/không thêm thông tin ngoài ngữ cảnh).
Tạo **một** câu hỏi trắc nghiệm có 4 lựa chọn (A, B, C, D), chỉ ra đáp án đúng (A/B/C/D) và viết 1 câu giải thích ngắn (1-2 câu).
**Bắt buộc:** output PHẢI LÀ **JSON duy nhất** theo schema sau (không có văn bản nào khác ngoài JSON):
{{
  "question_id": "<mã duy nhất>",
  "question": "<câu hỏi bằng tiếng Việt>",
  "options": [
    {{ "label": "A", "text": "..." }},
    {{ "label": "B", "text": "..." }},
    {{ "label": "C", "text": "..." }},
    {{ "label": "D", "text": "..." }}
  ],
  "answer": "A",
  "explanation": "<giải thích ngắn bằng tiếng Việt>",
  "source_chunks": [ "<chunk_index hoặc đoạn trích ngắn>", ... ]
}}
Ví dụ đầu ra (một mẫu JSON đúng; chỉ để mô tả định dạng):
{{
  "question_id": "q_0001",
  "question": "Nguyên tố nào là thành phần chính của vỏ trái đất?",
  "options": [
    {{ "label": "A", "text": "Sắt" }},
    {{ "label": "B", "text": "Oxi" }},
    {{ "label": "C", "text": "Cacbon" }},
    {{ "label": "D", "text": "Nitơ" }}
  ],
  "answer": "B",
  "explanation": "Oxi là nguyên tố phong phú nhất trong vỏ trái đất, chủ yếu trong các oxit và khoáng vật.",
  "source_chunks": [ "chunk_3" ]
}}
Đây là các đoạn ngữ cảnh (chỉ được phép dùng những đoạn này để soạn câu hỏi):
{context}
Hãy viết câu hỏi rõ ràng, không gây mơ hồ. Đảm bảo distractor (đáp án sai) là hợp lý và gây nhầm lẫn cho người học.
"""

# -------- call Together chat/completion ----------
def generate_mcq_with_rag(question_seed: str, retrieved_chunks, llm_model=DEFAULT_LLM_MODEL, temperature=0.0):
    if client is None:
        raise RuntimeError("Together client not initialized. Set TOGETHER_API_KEY in environment or Space secrets.")
    context = ""
    for i, (doc_text, meta, dist) in enumerate(retrieved_chunks):
        snippet = doc_text.replace("\n", " ").strip()
        context += f"[chunk_{meta.get('chunk_index', i)}] {snippet}\n\n"
    prompt = MCQ_PROMPT_VI.format(context=context)
    full_user = f"Yêu cầu (chủ đề / seed): {question_seed}\n\n{prompt}"
    messages = [
        {"role": "system", "content": "Bạn là một chuyên gia soạn câu hỏi trắc nghiệm bằng tiếng Việt. Chỉ trả về JSON, KHÔNG có lời giải thích thêm."},
        {"role": "user", "content": full_user},
    ]
    resp = client.chat.completions.create(
        model=llm_model,
        messages=messages,
        temperature=temperature,
    )
    out = resp.choices[0].message.content
    # try to parse JSON, fall back to extracting the first {...} span
    try:
        parsed = json.loads(out)
    except Exception:
        start = out.find("{")
        end = out.rfind("}")
        if start != -1 and end != -1:
            try:
                parsed = json.loads(out[start:end + 1])
            except Exception:
                parsed = None
        else:
            parsed = None
    # ensure question_id exists
    if parsed and isinstance(parsed, dict):
        if not parsed.get("question_id"):
            parsed["question_id"] = f"q_{uuid.uuid4().hex[:8]}"
    return parsed, out
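
# The function returns (parsed, raw): `parsed` is a dict when the model output could be read
# as JSON (possibly after trimming to the outermost {...}), otherwise None; `raw` is always
# the unmodified model text, which the caller surfaces for debugging when parsing fails.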

# -------- high-level runner used by Gradio ----------
def generate_mcqs_from_pdf(pdf_path: str, seeds: List[str], questions_per_seed=1, chunk_size=DEFAULT_CHUNK_SIZE,
                           chunk_overlap=DEFAULT_CHUNK_OVERLAP, k_retrieve=DEFAULT_K_RETRIEVE,
                           embedding_model=DEFAULT_EMBEDDING_MODEL, llm_model=DEFAULT_LLM_MODEL,
                           temperature=0.0, persist_directory="./chroma_db"):
    text = extract_text_from_pdf(pdf_path)
    chunks = chunk_text(text, chunk_size=chunk_size, overlap=chunk_overlap)
    # embed
    chunk_embeddings = embed_texts(chunks, model=embedding_model)
    # build vectorstore (drop and recreate the collection so old uploads don't leak into retrieval)
    chroma_client, collection = build_chroma_collection(name="pdf_docs", persist_directory=persist_directory)
    try:
        chroma_client.delete_collection("pdf_docs")
    except Exception:
        # some backends may refuse to drop the collection; ignore and continue
        pass
    collection = chroma_client.get_or_create_collection("pdf_docs")
    add_documents_to_vectorstore(collection, chunks, chunk_embeddings)
    results = []
    for seed in seeds:
        for i in range(questions_per_seed):
            retrieved = retrieve_relevant_chunks(collection, seed, k=k_retrieve, embedding_model=embedding_model)
            parsed, raw = generate_mcq_with_rag(seed, retrieved, llm_model=llm_model, temperature=temperature)
            if parsed is None:
                item = {"seed": seed, "ok": False, "raw": raw}
            else:
                item = {"seed": seed, "ok": True, "mcq": parsed}
            results.append(item)
    return results
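
# Example of calling the runner directly for local testing, without the Gradio UI
# ("lecture.pdf" and the seed are hypothetical placeholders):
#   results = generate_mcqs_from_pdf("lecture.pdf", seeds=["kế thừa"], questions_per_seed=2)
#   print(json.dumps(results, ensure_ascii=False, indent=2))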

# -------- Gradio UI ----------
def save_uploaded_file(uploaded) -> str:
    """
    uploaded may be:
    - Path string (when running locally in some setups)
    - File-like object with .name
    - tuple/list returned by gradio in some versions
    Returns saved file path.
    """
    if uploaded is None:
        raise ValueError("No file uploaded.")
    # normalize to a source path
    if isinstance(uploaded, str) and os.path.exists(uploaded):
        src = uploaded
    elif hasattr(uploaded, "name") and os.path.exists(uploaded.name):
        src = uploaded.name
    elif isinstance(uploaded, (tuple, list)) and len(uploaded) > 0:
        # sometimes gradio returns (tempfile_path, original_name)
        cand = uploaded[0]
        if isinstance(cand, str) and os.path.exists(cand):
            src = cand
        else:
            # fallback: try bytes below
            src = None
    else:
        src = None
    dest_path = os.path.join(tmp_dir, os.path.basename(src) if src else "uploaded_doc")
    if src:
        shutil.copy(src, dest_path)
        return dest_path
    # last resort: treat 'uploaded' as bytes-like
    try:
        data = uploaded.read()
    except Exception:
        data = uploaded if isinstance(uploaded, (bytes, bytearray)) else None
    if data is None:
        raise ValueError("Could not handle uploaded file type.")
    with open(dest_path, "wb") as f:
        f.write(data)
    return dest_path


def ui_run(pdf_file, seeds_text, questions_per_seed, k_retrieve, chunk_size, chunk_overlap,
           embedding_model, llm_model, temperature):
    if pdf_file is None:
        return "", None
    # clear the tmp folder, recreate it, and save the uploaded file there
    try:
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
    except Exception:
        pass
    os.makedirs(tmp_dir, exist_ok=True)
    print(f"created {tmp_dir}")
    try:
        local_path = save_uploaded_file(pdf_file)
    except Exception as e:
        # both outputs must be returned so Gradio can fill the text box and the file slot
        return f"Failed saving uploaded file: {e}", None
    print(f"uploaded file {local_path}")
    seeds = [s.strip() for s in seeds_text.split(",") if s.strip()]
    if not seeds:
        seeds = ["Lấy câu hỏi tổng quát về tài liệu"]
    print(f"seeds: {seeds}")
    print("generating mcqs")
    try:
        results = generate_mcqs_from_pdf(
            pdf_path=local_path,
            seeds=seeds,
            questions_per_seed=int(questions_per_seed),
            chunk_size=int(chunk_size),       # gr.Number / gr.Slider may hand back floats
            chunk_overlap=int(chunk_overlap),
            k_retrieve=int(k_retrieve),
            embedding_model=embedding_model,
            llm_model=llm_model,
            temperature=temperature,
            persist_directory="./chroma_db"
        )
    except Exception as e:
        return f"Error while generating MCQs: {e}", None
    print("mcqs generated")
    out_json = json.dumps(results, ensure_ascii=False, indent=2)
    # write output file for download
    out_file = os.path.join(tmp_dir, "mcq_output.json")
    with open(out_file, "w", encoding="utf-8") as f:
        f.write(out_json)
    print("json output dumped")
    return out_json, out_file


with gr.Blocks(title="RAG -> MCQ (Tiếng Việt)") as demo:
    gr.Markdown("# RAG -> MCQ Generator (Tiếng Việt)\nUpload PDF, set seeds (phân tách bằng dấu phẩy), và nhấn Generate.\nOutputs: JSON trả về các câu hỏi trắc nghiệm.")
    with gr.Row():
        with gr.Column(scale=1):
            pdf_in = gr.File(label="Upload PDF")
            seeds_in = gr.Textbox(label="Seeds (chủ đề), phân tách bằng dấu phẩy", value="lập trình hướng đối tượng, kế thừa")
            questions_per_seed = gr.Slider(label="Questions per seed", minimum=1, maximum=5, step=1, value=1)
            k_retrieve = gr.Slider(label="K retrieve (số đoạn liên quan)", minimum=1, maximum=10, step=1, value=DEFAULT_K_RETRIEVE)
            chunk_size = gr.Number(label="Chunk size (chars)", value=DEFAULT_CHUNK_SIZE)
            chunk_overlap = gr.Number(label="Chunk overlap (chars)", value=DEFAULT_CHUNK_OVERLAP)
            embedding_model = gr.Textbox(label="Embedding model", value=DEFAULT_EMBEDDING_MODEL)
            llm_model = gr.Textbox(label="LLM model", value=DEFAULT_LLM_MODEL)
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.05, value=0.0)
            btn = gr.Button("Generate MCQs")
        with gr.Column(scale=1):
            out_text = gr.Textbox(label="Raw JSON output", lines=20)
            out_file = gr.File(label="Download JSON")
    btn.click(fn=ui_run, inputs=[pdf_in, seeds_in, questions_per_seed, k_retrieve, chunk_size, chunk_overlap,
                                 embedding_model, llm_model, temperature], outputs=[out_text, out_file])

if __name__ == "__main__":
    demo.launch()