amiguel commited on
Commit
99cc3fd
Β·
verified Β·
1 Parent(s): cf96989

Upload digitwin_rag_qwen_app.py

Browse files
Files changed (1) hide show
  1. digitwin_rag_qwen_app.py +131 -0
digitwin_rag_qwen_app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ import time
5
+ from threading import Thread
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
7
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+ from langchain_community.embeddings import HuggingFaceEmbeddings
10
+ from langchain.vectorstores import FAISS
11
+ from langchain.schema import Document
12
+
13
+ # --- HF Token ---
14
+ HF_TOKEN = st.secrets["HF_TOKEN"]
15
+
16
+ # --- Page Config ---
17
+ st.set_page_config(page_title="DigiTwin RAG", page_icon="πŸ“‚", layout="centered")
18
+ st.title("πŸ“‚ DigiTwin RAG Chat (GM Qwen 1.8B)")
19
+
20
+ # --- Upload Files Sidebar ---
21
+ with st.sidebar:
22
+ st.header("πŸ“„ Upload Knowledge Files")
23
+ uploaded_files = st.file_uploader("Upload PDFs or .txt files", accept_multiple_files=True, type=["pdf", "txt"])
24
+ if uploaded_files:
25
+ st.success(f"{len(uploaded_files)} file(s) uploaded")
26
+
27
+ # --- Model Loading ---
28
+ @st.cache_resource
29
+ def load_model():
30
+ tokenizer = AutoTokenizer.from_pretrained("amiguel/GM_Qwen1.8B_Finetune", trust_remote_code=True, token=HF_TOKEN)
31
+ model = AutoModelForCausalLM.from_pretrained(
32
+ "amiguel/GM_Qwen1.8B_Finetune",
33
+ device_map="auto",
34
+ torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
35
+ trust_remote_code=True,
36
+ token=HF_TOKEN
37
+ )
38
+ return model, tokenizer
39
+
40
+ model, tokenizer = load_model()
41
+
42
+ # --- Prompt Helper ---
43
+ SYSTEM_PROMPT = (
44
+ "You are DigiTwin, an expert advisor in asset integrity and reliability engineering. "
45
+ "Use the provided context from uploaded documents to answer precisely and professionally."
46
+ )
47
+
48
+ def build_prompt(messages, context=""):
49
+ prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}\n
50
+ Context:
51
+ {context}<|im_end|>
52
+ "
53
+ for msg in messages:
54
+ role = msg["role"]
55
+ prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>
56
+ "
57
+ prompt += "<|im_start|>assistant
58
+ "
59
+ return prompt
60
+
61
+ # --- RAG Embedding and Search ---
62
+ @st.cache_resource
63
+ def embed_uploaded_files(files):
64
+ raw_docs = []
65
+ for f in files:
66
+ file_path = f"/tmp/{f.name}"
67
+ with open(file_path, "wb") as out_file:
68
+ out_file.write(f.read())
69
+
70
+ loader = PyPDFLoader(file_path) if f.name.endswith(".pdf") else TextLoader(file_path)
71
+ raw_docs.extend(loader.load())
72
+
73
+ splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
74
+ chunks = splitter.split_documents(raw_docs)
75
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
76
+ db = FAISS.from_documents(chunks, embedding=embeddings)
77
+ return db
78
+
79
+ retriever = embed_uploaded_files(uploaded_files) if uploaded_files else None
80
+
81
+ # --- Streaming Response ---
82
+ def generate_response(prompt_text):
83
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
84
+ inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
85
+ thread = Thread(target=model.generate, kwargs={
86
+ "input_ids": inputs["input_ids"],
87
+ "attention_mask": inputs["attention_mask"],
88
+ "max_new_tokens": 1024,
89
+ "temperature": 0.7,
90
+ "top_p": 0.9,
91
+ "repetition_penalty": 1.1,
92
+ "do_sample": True,
93
+ "streamer": streamer
94
+ })
95
+ thread.start()
96
+ return streamer
97
+
98
+ # --- Avatars & Messages ---
99
+ USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
100
+ BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
101
+
102
+ if "messages" not in st.session_state:
103
+ st.session_state.messages = []
104
+
105
+ for msg in st.session_state.messages:
106
+ avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
107
+ with st.chat_message(msg["role"], avatar=avatar):
108
+ st.markdown(msg["content"])
109
+
110
+ # --- Chat UI ---
111
+ if prompt := st.chat_input("Ask something based on uploaded documents..."):
112
+ st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
113
+ st.session_state.messages.append({"role": "user", "content": prompt})
114
+
115
+ context = ""
116
+ if retriever:
117
+ docs = retriever.similarity_search(prompt, k=3)
118
+ context = "\n\n".join([d.page_content for d in docs])
119
+
120
+ full_prompt = build_prompt(st.session_state.messages, context=context)
121
+
122
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
123
+ start_time = time.time()
124
+ streamer = generate_response(full_prompt)
125
+ container = st.empty()
126
+ answer = ""
127
+ for chunk in streamer:
128
+ answer += chunk
129
+ container.markdown(answer + "β–Œ", unsafe_allow_html=True)
130
+ container.markdown(answer)
131
+ st.session_state.messages.append({"role": "assistant", "content": answer})