# citrus / app.py
# Last commit by tomas.helmfridsson: "renamed LLM fixed 356M" (7f38362)
# ── app.py ───────────────────────────────────────────────────────────
import os, logging, textwrap
import gradio as gr
from transformers import pipeline, AutoTokenizer
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import concurrent.futures
# ── CONFIG ───────────────────────────────────────────────────────────
# Paths
DOCS_DIR: str = "document"      # folder scanned for *.pdf when the index is built
INDEX_DIR: str = "faiss_index"  # folder the FAISS index is saved to / loaded from

# Models
EMB_MODEL: str = "KBLab/sentence-bert-swedish-cased"  # sentence embedder for retrieval
# Generator models evaluated so far (kept for reference):
#   tiiuae/falcon-rw-1b                   - poor results
#   google/flan-t5-base                   - poor results
#   bigscience/bloom-560m                 - poor results
#   NbAiLab/nb-gpt-j-6B                   - restricted access
#   datificate/gpt2-small-swedish         - not available on Hugging Face
#   TinyLlama/TinyLlama-1.1B-Chat-v1.0
#   timpal0l/mdeberta-v3-base-squad2      - small, possibly usable for Swedish
#   AI-Sweden-Models/gpt-sw3-1.3B         - GPT-SW3 comes in 126M, 356M, 1.3B, 6.7B, 20B, 40B
#   AI-Sweden-Models/Llama-3-8B-instruct  - possibly too large
# Model overview: https://www.ai.se/en/ai-labs/natural-language-understanding/models-resources
LLM_MODEL: str = "AI-Sweden-Models/gpt-sw3-356M"

# Chunking, retrieval and generation defaults (the sliders start from these)
CHUNK_SIZE: int = 400        # chunk length passed to RecursiveCharacterTextSplitter
CHUNK_OVERLAP: int = 40      # overlap between neighbouring chunks
CTX_TOK_MAX: int = 750       # context token budget; leaves headroom for question + answer
MAX_NEW_TOKENS: int = 512    # upper bound on generated tokens
K: int = 5                   # number of chunks retrieved per question
DEFAULT_TEMP: float = 0.8    # sampling temperature
GEN_TIMEOUT: int = 180       # seconds to wait for generation before giving up

# ── LOGGING ──────────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger(__name__)
# ── 1) Index (build or load) ─────────────────────────────────────────
# Loads a previously saved FAISS index if one exists; otherwise splits every
# PDF in DOCS_DIR into chunks, embeds them and saves the new index.
emb = HuggingFaceEmbeddings(model_name=EMB_MODEL)
INDEX_PATH = os.path.join(INDEX_DIR, "index.faiss")
if os.path.isfile(INDEX_PATH):
    log.info(f"πŸ”„ Laddar index frΓ₯n {INDEX_DIR}")
    # save_local() pickles the docstore, and current langchain_community
    # refuses to unpickle it unless explicitly allowed. This assumes
    # INDEX_DIR only ever contains the index written by save_local() below
    # (never files from an untrusted source).
    vs = FAISS.load_local(INDEX_DIR, emb, allow_dangerous_deserialization=True)
else:
    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    docs, pdfs = [], []
    # Sorted so the chunk order in the index is reproducible between builds.
    for fn in sorted(os.listdir(DOCS_DIR)):
        if not fn.lower().endswith(".pdf"):
            continue
        chunks = splitter.split_documents(PyPDFLoader(os.path.join(DOCS_DIR, fn)).load())
        for c in chunks:
            # Store just the file name; the UI and chat show it as the source.
            c.metadata["source"] = fn
        docs.extend(chunks)
        pdfs.append(fn)
    if not docs:
        # FAISS.from_documents([]) fails deep inside the embedding step;
        # fail early with an actionable message instead.
        raise RuntimeError(
            f"No PDF text found in '{DOCS_DIR}'; cannot build the FAISS index."
        )
    vs = FAISS.from_documents(docs, emb)
    vs.save_local(INDEX_DIR)
    log.info(f"βœ… Byggt index – {len(pdfs)}β€―PDF / {len(docs)}β€―chunkar")
retriever = vs.as_retriever(search_kwargs={"k": K})
# ── 2) LLM pipeline & tokenizer ──────────────────────────────────────
log.info("πŸš€ Initierar LLM …")
# device=-1 runs the pipeline on CPU; max_new_tokens is the default cap,
# which chat_fn overrides per request.
gen_pipe = pipeline("text-generation", model=LLM_MODEL, device=-1, max_new_tokens=MAX_NEW_TOKENS)
# Reuse the tokenizer the pipeline already loaded for LLM_MODEL instead of
# loading the same checkpoint a second time. It is used for token budgeting
# and as the pad/eos token source during generation.
tokenizer = gen_pipe.tokenizer
log.info("βœ… LLM klar")
# ── 3) Helper functions ──────────────────────────────────────────────
def build_prompt(query: str, docs, ctx_tok_max=None):
    """Build the Swedish RAG prompt for *query* from retrieved *docs*.

    Chunks are taken in the order given (retrieval rank) and added to the
    context as long as the running token count stays within the budget. The
    first chunk that would exceed the budget stops the loop.

    Args:
        query: The user's question, inserted verbatim.
        docs: Retrieved documents; only ``page_content`` is read.
        ctx_tok_max: Token budget for the context section. When omitted,
            the module-level ``CTX_TOK_MAX`` is used.

    Returns:
        The complete prompt string, ending with the answer cue line.
    """
    budget = CTX_TOK_MAX if ctx_tok_max is None else int(ctx_tok_max)
    context_parts = []
    total_ctx_tok = 0
    for d in docs:
        tok_len = len(tokenizer.encode(d.page_content))
        if total_ctx_tok + tok_len > budget:
            break
        context_parts.append(d.page_content)
        total_ctx_tok += tok_len
    context = "\n\n---\n\n".join(context_parts)
    # Assemble the lines explicitly instead of calling textwrap.dedent() on
    # an f-string. dedent would run *after* interpolation, so it could rewrite
    # the retrieved text (e.g. whitespace-only lines). If the template were
    # indented, it would also fail to strip that indentation.
    return "\n".join([
        "Du Γ€r en hjΓ€lpsam assistent som svarar pΓ₯ svenska.",
        "Kontext (hΓ€mtat ur PDF‑dokument):",
        context,
        f"FrΓ₯ga: {query}",
        "Svar (svenska):",
    ])
def test_retrieval(q):  # quick retrieval check without the LLM
    """Return the retrieved chunks for *q* as a numbered text listing.

    Each entry shows the chunk's source PDF and its first 160 characters.
    A fixed "no hits" message is returned when nothing is retrieved.
    """
    lines = []
    for rank, hit in enumerate(retriever.invoke(q), start=1):
        lines.append(f"{rank}. ({hit.metadata['source']}) {hit.page_content[:160]}…")
    return "\n\n".join(lines) or "🚫 Inga trΓ€ffar"
def chat_fn(q, temp, max_new_tokens, k, ctx_tok_max, history):
    """Answer *q* with retrieval-augmented generation and update the chat.

    Args:
        q: The user's question from the textbox.
        temp: Sampling temperature (slider value).
        max_new_tokens: Upper bound on generated tokens (slider value).
        k: Number of chunks to retrieve (slider value; converted with int()).
        ctx_tok_max: Token budget for the prompt's context section.
        history: Message-dict chat history from ``gr.State`` (may be None).

    Returns:
        ``(history, history)``: the same list for the Chatbot and the State.
    """
    history = history or []
    # Skip blank submissions instead of retrieving/generating for "".
    if not (q or "").strip():
        return history, history
    history.append({"role": "user", "content": q})

    # Retrieve chunks together with their FAISS scores.
    docs_and_scores = vs.similarity_search_with_score(q, k=int(k))
    if not docs_and_scores:
        history.append({"role": "assistant", "content": "🚫 Hittade inget relevant."})
        return history, history
    docs = [doc for doc, _score in docs_and_scores]

    # Show the retrieved chunks and their scores in the chat.
    chunk_info = "\n\n".join(
        f"{i}. ({d.metadata['source']}) score={score:.3f}\n{d.page_content[:160]}…"
        for i, (d, score) in enumerate(docs_and_scores, start=1)
    )
    history.append({"role": "system", "content": f"πŸ”Ž Chunkar som anvΓ€nds:\n{chunk_info}"})

    def build_prompt_dynamic(query, docs, ctx_tok_max):
        # Same prompt layout as build_prompt, with the token budget taken
        # from the UI slider. Chunks are added in rank order until the next
        # one would exceed the budget.
        context_parts = []
        total_ctx_tok = 0
        for d in docs:
            tok_len = len(tokenizer.encode(d.page_content))
            if total_ctx_tok + tok_len > int(ctx_tok_max):
                break
            context_parts.append(d.page_content)
            total_ctx_tok += tok_len
        context = "\n\n---\n\n".join(context_parts)
        # Joined line by line (not dedent() on an f-string) so the retrieved
        # text is inserted verbatim.
        return "\n".join([
            "Du Γ€r en hjΓ€lpsam assistent som svarar pΓ₯ svenska.",
            "Kontext (hΓ€mtat ur PDF‑dokument):",
            context,
            f"FrΓ₯ga: {query}",
            "Svar (svenska):",
        ])

    prompt = build_prompt_dynamic(q, docs, ctx_tok_max)
    log.info(
        "Prompt tokens=%d temp=%s max_new_tokens=%s k=%s ctx_tok_max=%s",
        len(tokenizer.encode(prompt)), temp, max_new_tokens, k, ctx_tok_max,
    )

    def generate():
        return gen_pipe(
            prompt,
            temperature=float(temp),
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            return_full_text=False,  # only the continuation, not the prompt
        )[0]["generated_text"]

    # Generation runs in a worker thread so we can stop *waiting* after
    # GEN_TIMEOUT. Deliberately not a `with` block: its exit calls
    # shutdown(wait=True), which would block until generation finished and
    # make the timeout useless. The worker thread itself cannot be killed;
    # it keeps running in the background until the model returns.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(generate)
        ans = future.result(timeout=GEN_TIMEOUT)
    except concurrent.futures.TimeoutError:
        ans = f"⏰ Ingen respons frΓ₯n modellen inom {GEN_TIMEOUT} sekunder."
    except Exception as e:
        log.exception("Genererings‑fel")
        ans = f"❌ Fel vid generering: {type(e).__name__}: {e}\n\nPrompt:\n{prompt}"
    finally:
        executor.shutdown(wait=False)

    # docs is non-empty here (checked above); cite the top-ranked chunk's PDF.
    src_hint = docs[0].metadata["source"]
    history.append({"role": "assistant", "content": f"**(KΓ€lla: {src_hint})**\n\n{ans}"})
    return history, history
# ── 4) Gradio UI ─────────────────────────────────────────────────────
# The header lists only the PDFs actually present (the label says PDFs).
# It must not crash when DOCS_DIR is missing, which can happen when the app
# starts from an already saved FAISS index.
pdf_files = (
    sorted(fn for fn in os.listdir(DOCS_DIR) if fn.lower().endswith(".pdf"))
    if os.path.isdir(DOCS_DIR)
    else []
)

with gr.Blocks() as demo:
    gr.Markdown("# πŸ“š Svensk RAG‑chat")
    gr.Markdown(f"**PDF‑filer:** {', '.join(pdf_files) or '–'}")
    gr.Markdown(f"**LLM-modell som anvΓ€nds:** `{LLM_MODEL}`", elem_id="llm-info")

    # Retrieval-only check: lists the top chunks without calling the LLM.
    with gr.Row():
        q_test = gr.Textbox(label="πŸ”Ž Test Retrieval")
        b_test = gr.Button("Testa")
    o_test = gr.Textbox(label="Chunkar")

    # Question plus the per-request settings passed to chat_fn.
    with gr.Row():
        q_in = gr.Textbox(label="FrΓ₯ga", placeholder="Ex: Vad Γ€r fΓΆrvaltningsΓΆverlΓ€mnande?")
        temp = gr.Slider(0, 1, value=DEFAULT_TEMP, step=0.05, label="Temperatur")
        max_new_tokens = gr.Slider(32, 1024, value=MAX_NEW_TOKENS, step=8, label="Max svarslΓ€ngd (tokens)")
        k = gr.Slider(1, 10, value=K, step=1, label="Antal chunkar (K)")
        ctx_tok_max = gr.Slider(100, 2000, value=CTX_TOK_MAX, step=50, label="Max kontexttokens")
    b_send = gr.Button("Skicka")
    b_stop = gr.Button("Stoppa")  # stop button: cancels the pending send event
    chat = gr.Chatbot(type="messages", label="Chat")
    chat_hist = gr.State([])

    b_test.click(test_retrieval, inputs=[q_test], outputs=[o_test])
    send_event = b_send.click(
        chat_fn,
        inputs=[q_in, temp, max_new_tokens, k, ctx_tok_max, chat_hist],
        outputs=[chat, chat_hist],
    )
    b_stop.click(None, cancels=[send_event])

if __name__ == "__main__":
    # share=True additionally creates a public share link; remove it to keep the app private.
    demo.launch(share=True)