File size: 4,839 Bytes
99cc3fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c5f444
99cc3fd
 
6c5f444
 
99cc3fd
 
6c5f444
99cc3fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document

# --- HF Token ---
# Hugging Face access token, read from Streamlit secrets; passed to
# from_pretrained() when the model and tokenizer are loaded.
HF_TOKEN = st.secrets["HF_TOKEN"]

# --- Page Config ---
# The icon/title emoji were stored as mojibake (UTF-8 bytes decoded with a
# legacy code page); they are restored to the intended U+1F4C2 character.
st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
st.title("📂 DigiTwin RAG Chat (GM Qwen 1.8B)")

# --- Upload Files Sidebar ---
with st.sidebar:
    # Header emoji restored from mojibake to the intended U+1F4C4 character.
    st.header("📄 Upload Knowledge Files")
    # Accept any number of PDF / plain-text files to use as the knowledge base.
    uploaded_files = st.file_uploader(
        "Upload PDFs or .txt files",
        accept_multiple_files=True,
        type=["pdf", "txt"],
    )
    if uploaded_files:
        st.success(f"{len(uploaded_files)} file(s) uploaded")

# --- Model Loading ---
@st.cache_resource
def load_model():
    """Load the fine-tuned Qwen causal LM and its tokenizer.

    Wrapped in ``st.cache_resource`` so the load is not repeated on every
    script rerun.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by the respective
        ``from_pretrained`` calls.
    """
    model_id = "amiguel/GM_Qwen1.8B_Finetune"

    # Check CUDA availability first: on CPU-only hosts some torch releases
    # raise inside is_bf16_supported() because it queries the current CUDA
    # device. Short-circuiting keeps the float32 fallback reachable there.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    return model, tokenizer


model, tokenizer = load_model()

# --- Prompt Helper ---
SYSTEM_PROMPT = (
    "You are DigiTwin, an expert advisor in asset integrity and reliability engineering. "
    "Use the provided context from uploaded documents to answer precisely and professionally."
)


def build_prompt(messages, context=""):
    """Assemble the full text prompt sent to the model.

    The prompt is a system turn (``SYSTEM_PROMPT`` followed by a ``Context:``
    section), then one turn per entry of ``messages`` in order, then an open
    assistant turn for the model to complete. Each turn is wrapped in
    ``<|im_start|>`` / ``<|im_end|>`` markers.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.
        context: Text placed under ``Context:`` in the system turn; may be
            empty.

    Returns:
        str: The concatenated prompt.
    """
    system_turn = f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"
    history_turns = [
        f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n" for turn in messages
    ]
    # The trailing open assistant turn is what the model continues from.
    return "".join([system_turn, *history_turns, "<|im_start|>assistant\n"])


# --- RAG Embedding and Search ---
@st.cache_resource
def embed_uploaded_files(files):
    """Build a FAISS vector store from the uploaded PDF / text files.

    Each upload is written to a private temporary file (the loaders take a
    path), parsed, and the temporary file is removed again. The resulting
    documents are split into overlapping chunks, embedded with
    ``sentence-transformers/all-MiniLM-L6-v2`` and indexed.

    Args:
        files: Uploaded file objects exposing ``.name`` and ``.getvalue()``
            (as returned by ``st.file_uploader``).

    Returns:
        The FAISS store built from the chunks.
    """
    import tempfile  # local import: only this function needs it

    raw_docs = []
    for f in files:
        # Only the (lower-cased) extension of the client-supplied name is
        # used; the name itself never becomes part of a filesystem path, so
        # it cannot escape the temp directory or collide with another upload.
        suffix = os.path.splitext(f.name)[1].lower()

        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # getvalue() returns the whole buffer regardless of the current
            # stream position, unlike read().
            tmp.write(f.getvalue())
            tmp_path = tmp.name

        try:
            loader = PyPDFLoader(tmp_path) if suffix == ".pdf" else TextLoader(tmp_path)
            docs = loader.load()
        finally:
            # Do not leave uploaded content behind on disk.
            os.remove(tmp_path)

        # Record the original upload name rather than the random temp path.
        for doc in docs:
            doc.metadata["source"] = f.name
        raw_docs.extend(docs)

    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
    chunks = splitter.split_documents(raw_docs)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.from_documents(chunks, embedding=embeddings)
    return db


# Vector store used for similarity search; None until files are uploaded.
retriever = embed_uploaded_files(uploaded_files) if uploaded_files else None

# --- Streaming Response ---
def generate_response(prompt_text):
    """Start generation for ``prompt_text`` and return a streamer of its output.

    ``model.generate`` is launched on a separate thread with sampling enabled,
    and the returned ``TextIteratorStreamer`` is handed back immediately so the
    caller can iterate over text pieces as they are produced.

    Args:
        prompt_text: Fully assembled prompt string.

    Returns:
        TextIteratorStreamer: Iterable of generated text, with the prompt and
        special tokens skipped.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    token_stream = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        streamer=token_stream,
    )

    # Run generate() off the calling thread so the caller can consume the
    # streamer while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()
    return token_stream

# --- Avatars & Messages ---
# Avatar images shown next to user and assistant chat bubbles.
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Create the conversation history the first time the session runs.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation so it stays visible across reruns.
for past in st.session_state.messages:
    speaker = past["role"]
    if speaker == "user":
        icon = USER_AVATAR
    else:
        icon = BOT_AVATAR
    with st.chat_message(speaker, avatar=icon):
        st.markdown(past["content"])

# --- Chat UI ---
if prompt := st.chat_input("Ask something based on uploaded documents..."):
    # Show the user's turn and record it in the session history.
    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Pull the three most similar chunks from the vector store, if one exists.
    context = ""
    if retriever is not None:
        docs = retriever.similarity_search(prompt, k=3)
        context = "\n\n".join(d.page_content for d in docs)

    full_prompt = build_prompt(st.session_state.messages, context=context)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        streamer = generate_response(full_prompt)
        container = st.empty()
        answer = ""
        for chunk in streamer:
            answer += chunk
            # Re-render the partial answer with a trailing block cursor
            # (restored from mojibake to U+258C). Raw HTML is left disabled
            # so model output is rendered the same way as in the final
            # markdown call below.
            container.markdown(answer + "▌")
        # Final render without the cursor.
        container.markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})