# RAG/app.py
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# --- HF Token ---
HF_TOKEN = st.secrets["HF_TOKEN"]
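# A minimal fallback sketch (assumption: outside Hugging Face Spaces the token
# may live in the environment instead of st.secrets):
#   HF_TOKEN = st.secrets.get("HF_TOKEN", os.getenv("HF_TOKEN"))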
# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
st.title("📂 DigiTwin RAG Chat (GM Qwen 1.8B)")
# --- Upload Files Sidebar ---
with st.sidebar:
    st.header("📄 Upload Knowledge Files")
    uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
    if uploaded_files:
        st.success(f"{len(uploaded_files)} file(s) uploaded")
# --- Model Loading ---
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_Qwen1.8B_Finetune", trust_remote_code=True, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        "amiguel/GM_Qwen1.8B_Finetune",
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    return model, tokenizer
model, tokenizer = load_model()
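# Sanity-check sketch (hypothetical, not wired into the app): confirm the
# checkpoint generates before serving traffic. Note that device_map="auto"
# requires the accelerate package to be installed.
#   ids = tokenizer("ping", return_tensors="pt").input_ids.to(model.device)
#   print(tokenizer.decode(model.generate(ids, max_new_tokens=4)[0]))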
# --- Prompt Helper ---
SYSTEM_PROMPT = (
    "You are DigiTwin, an expert advisor in asset integrity and reliability engineering. "
    "Use the provided context from uploaded documents to answer precisely and professionally."
)
def build_prompt(messages, context=""):
    prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
    for msg in messages:
        role = msg["role"]
        prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
    prompt += "<|im_start|>assistant\n"
    return prompt
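# Illustrative output of build_prompt for a single user turn (ChatML, the
# same marker format the code above emits; the question is made up):
#   <|im_start|>system
#   You are DigiTwin, ...
#
#   Context:
#   <retrieved chunks><|im_end|>
#   <|im_start|>user
#   How do I assess pipeline corrosion?<|im_end|>
#   <|im_start|>assistant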
# --- RAG Embedding and Search ---
@st.cache_resource
def embed_uploaded_files(files):
    raw_docs = []
    for f in files:
        file_path = f"/tmp/{f.name}"
        with open(file_path, "wb") as out_file:
            out_file.write(f.read())
        loader = PyPDFLoader(file_path) if f.name.endswith(".pdf") else TextLoader(file_path)
        raw_docs.extend(loader.load())
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
    chunks = splitter.split_documents(raw_docs)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.from_documents(chunks, embedding=embeddings)
    return db
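# Optional persistence sketch (assumption: a writable "faiss_index" directory),
# using LangChain's stock FAISS save/load helpers to skip re-embedding:
#   db.save_local("faiss_index")
#   db = FAISS.load_local("faiss_index", embeddings,
#                         allow_dangerous_deserialization=True)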
retriever = embed_uploaded_files(uploaded_files) if uploaded_files else None
# --- Streaming Response ---
def generate_response(prompt_text):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    thread = Thread(target=model.generate, kwargs={
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "streamer": streamer,
    })
    thread.start()
    return streamer
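# Usage: generation runs on the background thread, and the returned streamer
# is an iterator of decoded text chunks, e.g.
#   for piece in generate_response(full_prompt):
#       print(piece, end="")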
# --- Avatars & Messages ---
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
if "messages" not in st.session_state:
st.session_state.messages = []
for msg in st.session_state.messages:
avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
with st.chat_message(msg["role"], avatar=avatar):
st.markdown(msg["content"])
# --- Chat UI ---
if prompt := st.chat_input("Ask something based on uploaded documents..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    context = ""
    if retriever:
        docs = retriever.similarity_search(prompt, k=3)
        context = "\n\n".join([d.page_content for d in docs])
    full_prompt = build_prompt(st.session_state.messages, context=context)
    with st.chat_message("assistant", avatar=BOT_AVATAR):
        start_time = time.time()
        streamer = generate_response(full_prompt)
        container = st.empty()
        answer = ""
        for chunk in streamer:
            answer += chunk
            container.markdown(answer + "▌", unsafe_allow_html=True)
        container.markdown(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})