File size: 7,103 Bytes
99cc3fd
 
 
be6c945
99cc3fd
 
 
 
 
 
 
be6c945
1adf5f1
99cc3fd
 
 
 
85bf964
99cc3fd
7f4899e
99cc3fd
 
 
7f4899e
f39631c
 
ba95cd5
1adf5f1
99cc3fd
7f4899e
 
 
2b00c83
80f7f4c
2b00c83
 
7f4899e
 
 
 
f39631c
7f4899e
f39631c
 
 
 
be6c945
7f4899e
f39631c
7f4899e
f39631c
1adf5f1
f39631c
 
 
 
 
 
 
 
1adf5f1
e59e5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f39631c
be6c945
e59e5c8
1adf5f1
f39631c
 
 
 
1adf5f1
 
f39631c
1adf5f1
f39631c
 
 
 
 
 
 
 
 
 
1adf5f1
f39631c
99cc3fd
f39631c
 
 
 
 
 
 
 
 
 
 
99cc3fd
f39631c
 
1adf5f1
f39631c
 
 
1adf5f1
f39631c
 
be6c945
1adf5f1
be6c945
1adf5f1
be6c945
 
1adf5f1
be6c945
 
 
 
 
1adf5f1
be6c945
f39631c
1adf5f1
be6c945
1adf5f1
 
e59e5c8
be6c945
 
1adf5f1
be6c945
 
1adf5f1
 
be6c945
7f4899e
 
 
 
 
 
 
1adf5f1
 
7f4899e
1adf5f1
 
7f4899e
1adf5f1
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import streamlit as st
import torch
import os
import time
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document

# --- Hugging Face Token ---
# Read from Streamlit secrets; a missing "HF_TOKEN" key raises at startup.
# The token is passed to the from_pretrained() calls when loading the model.
HF_TOKEN = st.secrets["HF_TOKEN"]

# --- Page Config ---
st.set_page_config(page_title="DigiTwin RAG", page_icon="📂", layout="centered")
st.title("📂 DigiTs the Twin")

# --- Sidebar ---
# Collects the knowledge files to index and the model to chat with.
# NOTE: load_model() also maps a "Llama" choice, which is not offered here.
with st.sidebar:
    st.header("📄 Upload Knowledge Files")
    uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
    model_choice = st.selectbox("🧠 Choose Model", ["Qwen", "Mistral"])
    if uploaded_files:
        st.success(f"{len(uploaded_files)} file(s) uploaded")

# --- Load Model & Tokenizer ---
@st.cache_resource
def load_model(selected_model):
    """Load the fine-tuned causal LM and its tokenizer for a sidebar choice.

    Wrapped in ``st.cache_resource``, so each distinct ``selected_model`` value
    is loaded once and then reused on later reruns instead of reloading weights.

    Args:
        selected_model: Model display name. ``"Qwen"`` and ``"Llama"`` map to
            their own checkpoints; any other value (including ``"Mistral"``)
            selects the Mistral fine-tune.

    Returns:
        ``(model, tokenizer, model_id)`` where ``model_id`` is the Hugging Face
        repository id that was loaded.
    """
    model_ids = {
        "Qwen": "amiguel/GM_Qwen1.8B_Finetune",
        "Llama": "amiguel/Llama3_8B_Instruct_FP16",
    }
    model_id = model_ids.get(selected_model, "amiguel/GM_Mistral7B_Finetune")

    # Use bfloat16 only when a CUDA device is present and supports it; checking
    # availability first keeps CPU-only hosts on the float32 path instead of
    # querying a GPU that does not exist.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=dtype,
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    return model, tokenizer, model_id

# Load (or fetch from the resource cache) the model chosen in the sidebar.
model, tokenizer, model_id = load_model(model_choice)

# --- System Prompt ---
# Persona text placed in the system turn of the ChatML (Qwen-style) prompt
# produced by build_prompt(); the Alpaca-style Mistral branch uses its own
# shorter preamble instead.
SYSTEM_PROMPT = (
    "You are DigiTwin, a digital expert and senior topside engineer "
    "specializing in inspection and maintenance of offshore piping systems, "
    "structural elements, mechanical equipment, floating production units, "
    "pressure vessels (with emphasis on Visual Internal Inspection - VII), "
    "and pressure safety devices (PSDs). "
    "Rely on uploaded documents and context to provide practical, "
    "standards-driven, and technically accurate responses. "
    "Your guidance reflects deep field experience, industry regulations, "
    "and proven methodologies in asset integrity and reliability engineering."
)

# --- Prompt Builder ---
def build_prompt(messages, context="", model_name="Qwen"):
    """Render chat history and optional retrieved context into one prompt string.

    The template is chosen by whether ``"Mistral"`` occurs in ``model_name``:

    * Mistral: Alpaca-style text with ``### Instruction:`` / ``### Response:``
      sections. Message contents are stripped, roles other than ``"user"`` and
      ``"assistant"`` are skipped, and the context paragraph is emitted only
      when ``context`` is non-empty.
    * Anything else: ChatML (``<|im_start|>`` / ``<|im_end|>``). The system
      turn holds ``SYSTEM_PROMPT`` plus a ``Context:`` section (always present,
      possibly empty), followed by one turn per message verbatim.

    Args:
        messages: Sequence of dicts with ``"role"`` and ``"content"`` keys.
        context: Retrieved document text; may be empty.
        model_name: Model identifier, used only to pick the template.

    Returns:
        The prompt, ending with an open response/assistant header so the
        model continues from there.
    """
    if "Mistral" not in model_name:
        # ChatML: system turn, one turn per message, then an open assistant turn.
        pieces = [f"<|im_start|>system\n{SYSTEM_PROMPT}\n\nContext:\n{context}<|im_end|>\n"]
        pieces.extend(
            f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
            for turn in messages
        )
        pieces.append("<|im_start|>assistant\n")
        return "".join(pieces)

    # Alpaca-style prompt for the Mistral fine-tune.
    section_header = {
        "user": "### Instruction:\n",
        "assistant": "### Response:\n",
    }
    pieces = ["You are DigiTwin, an expert in offshore inspection, maintenance, and asset integrity.\n"]
    if context:
        pieces.append(f"Here is relevant context:\n{context}\n\n")
    for turn in messages:
        header = section_header.get(turn["role"])
        if header is None:
            continue
        pieces.append(f"{header}{turn['content'].strip()}\n")
    pieces.append("### Response:\n")
    return "".join(pieces)


# --- Embed Uploaded Documents ---
@st.cache_resource
def embed_uploaded_files(files):
    """Build a FAISS vector store over the text of uploaded PDF/TXT files.

    Each upload is copied to a private temporary directory and parsed with
    ``PyPDFLoader`` (``.pdf``, case-insensitive) or ``TextLoader`` (anything
    else). The resulting documents are split into 512-character chunks with
    64 characters of overlap and embedded with all-MiniLM-L6-v2.

    Args:
        files: Uploaded file objects from ``st.file_uploader``; each must
            expose ``.name`` and ``.getvalue()``.

    Returns:
        A FAISS vector store, or ``None`` when the files yield no text chunks
        (e.g. image-only PDFs), since an index cannot be built from nothing.
    """
    import tempfile

    raw_docs = []
    # A private per-call directory means uploads from concurrent sessions that
    # share a file name cannot overwrite each other, and the copies are deleted
    # once parsed. The loaders read eagerly, so nothing is needed afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for idx, f in enumerate(files):
            # The client-supplied name only selects the parser; it is never used
            # as a filesystem path, because it could contain directory parts.
            is_pdf = f.name.lower().endswith(".pdf")
            suffix = ".pdf" if is_pdf else ".txt"
            path = os.path.join(tmp_dir, f"upload_{idx}{suffix}")
            with open(path, "wb") as out_file:
                # getvalue() returns the whole upload regardless of stream position.
                out_file.write(f.getvalue())
            loader = PyPDFLoader(path) if is_pdf else TextLoader(path)
            docs = loader.load()
            for doc in docs:
                # Record the user-facing name rather than the throwaway temp path.
                doc.metadata["source"] = f.name
            raw_docs.extend(docs)

    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
    chunks = splitter.split_documents(raw_docs)
    if not chunks:
        # FAISS.from_documents cannot index an empty list of chunks.
        return None
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.from_documents(chunks, embedding=embeddings)
    return db

# Vector store over the uploads (or None when nothing is uploaded or indexable).
retriever = embed_uploaded_files(uploaded_files) if uploaded_files else None
if uploaded_files and retriever is None:
    st.sidebar.warning("No extractable text found in the uploaded files.")

# --- Streaming Generator ---
def generate_response(
    prompt_text,
    *,
    max_new_tokens=1024,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
):
    """Start sampling a completion in a background thread and return its streamer.

    ``prompt_text`` is tokenized with the module-level ``tokenizer`` and moved to
    ``model.device``. ``model.generate`` then runs on a separate thread, so the
    caller can iterate the returned streamer for decoded text while tokens are
    produced. Prompt tokens and special tokens are not echoed.

    Args:
        prompt_text: Fully rendered prompt string.
        max_new_tokens, temperature, top_p, repetition_penalty: Sampling
            settings forwarded to ``model.generate``. The defaults are the
            app's standard settings.

    Returns:
        A ``TextIteratorStreamer``. Iteration ends when generation finishes
        or fails.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "streamer": streamer,
    }

    def _run_generation():
        # Runs on the worker thread. If generate() raises (e.g. out of GPU
        # memory), the streamer never receives its stop signal. The consumer's
        # `for` loop would then block forever, so close the stream first.
        try:
            model.generate(**generation_kwargs)
        except Exception:
            streamer.end()
            raise

    # daemon=True: an unfinished generation does not hold up interpreter shutdown.
    Thread(target=_run_generation, daemon=True).start()
    return streamer

# --- Avatars ---
# Remote images shown beside user and assistant chat bubbles.
USER_AVATAR = (
    "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/"
    "9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
)
BOT_AVATAR = (
    "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/"
    "991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
)

# --- Initialize Chat Memory ---
# st.session_state persists across reruns, so the history list is created once
# per browser session and appended to afterwards.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# --- Display Message History ---
# The script re-executes on every interaction, so replay every stored turn.
for past_turn in st.session_state.messages:
    speaker = past_turn["role"]
    avatar_url = USER_AVATAR if speaker == "user" else BOT_AVATAR
    with st.chat_message(speaker, avatar=avatar_url):
        st.markdown(past_turn["content"])

# --- Chat Interface ---
# Handles one user turn: retrieve context, build the prompt, stream the reply,
# store it in the session history and show token/speed debug info.
if prompt := st.chat_input("Ask something based on uploaded documents..."):
    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Top-3 most similar chunks, when documents have been indexed.
    context = ""
    docs = []
    if retriever:
        docs = retriever.similarity_search(prompt, k=3)
        context = "\n\n".join(doc.page_content for doc in docs)

    # Limit to last 6 messages for memory (includes the prompt just appended).
    recent_messages = st.session_state.messages[-6:]
    full_prompt = build_prompt(recent_messages, context, model_name=model_id)

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        start = time.time()
        container = st.empty()
        answer = ""
        # Initialised up front so an empty generation (no chunks) still yields
        # a blank reply instead of a NameError below.
        cleaned = ""

        for chunk in generate_response(full_prompt):
            answer += chunk
            cleaned = answer

            # For the Mistral model, drop literal <|im_start|>/<|im_end|>
            # markers and surrounding whitespace from the displayed text.
            if "Mistral" in model_id:
                cleaned = cleaned.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()

            # Model output is rendered as plain markdown (no raw HTML): it is
            # influenced by uploaded documents and is untrusted. The trailing
            # block character acts as a typing cursor while streaming.
            container.markdown(cleaned + "▌")

        end = time.time()
        # Final render without the streaming cursor.
        container.markdown(cleaned)
        st.session_state.messages.append({"role": "assistant", "content": cleaned})

        input_tokens = len(tokenizer(full_prompt)["input_ids"])
        # Count only generated text, not BOS/special tokens added by the tokenizer.
        output_tokens = len(tokenizer(cleaned, add_special_tokens=False)["input_ids"])
        elapsed = end - start
        speed = output_tokens / elapsed if elapsed > 0 else 0.0

        with st.expander("📊 Debug Info"):
            st.caption(
                f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                f"🕒 Speed: {speed:.1f} tokens/sec"
            )
            for i, doc in enumerate(docs, start=1):
                st.markdown(f"**Chunk #{i}**")
                st.code(doc.page_content.strip()[:500])