amiguel commited on
Commit
e6d99a3
Β·
verified Β·
1 Parent(s): 64615ed

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # πŸš€ Streamlit v2 for GM_Qwen1.8B_Finetune
2
+ import streamlit as st
3
+ import torch
4
+ import os
5
+ import time
6
+ from threading import Thread
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
8
+ from huggingface_hub import login
9
+
10
+ # --- Hugging Face Token ---
11
+ HF_TOKEN = os.environ.get("HF_TOKEN") # or hardcode "hf_xxxxxxx"
12
+ login(token=HF_TOKEN)
13
+
14
+ # --- Streamlit page config ---
15
+ st.set_page_config(
16
+ page_title="Fine-tune DigiTwin - ValLabs πŸš€",
17
+ page_icon="πŸš€",
18
+ layout="centered"
19
+ )
20
+
21
+ st.title("πŸš€ Fine-tune DigiTwin - ValLabs πŸš€")
22
+
23
+ # Avatars
24
+ USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
25
+ BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
26
+
27
+ # --- Load model and tokenizer ---
28
+ @st.cache_resource
29
+ def load_model():
30
+ tokenizer = AutoTokenizer.from_pretrained(
31
+ "amiguel/GM_Qwen1.8B_Finetune",
32
+ trust_remote_code=True,
33
+ token=HF_TOKEN
34
+ )
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ "amiguel/GM_Qwen1.8B_Finetune",
37
+ device_map="auto",
38
+ torch_dtype=torch.bfloat16,
39
+ trust_remote_code=True,
40
+ token=HF_TOKEN
41
+ )
42
+ return model, tokenizer
43
+
44
+ model, tokenizer = load_model()
45
+
46
+ # --- Session state for chat history ---
47
+ if "messages" not in st.session_state:
48
+ st.session_state.messages = []
49
+
50
+ # --- Streamer function ---
51
+ def generate_response(prompt, model, tokenizer):
52
+ streamer = TextIteratorStreamer(
53
+ tokenizer,
54
+ skip_prompt=True,
55
+ skip_special_tokens=True
56
+ )
57
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
58
+
59
+ generation_kwargs = {
60
+ "input_ids": inputs["input_ids"],
61
+ "attention_mask": inputs["attention_mask"],
62
+ "max_new_tokens": 1024,
63
+ "temperature": 0.7,
64
+ "top_p": 0.9,
65
+ "repetition_penalty": 1.1,
66
+ "do_sample": True,
67
+ "streamer": streamer
68
+ }
69
+
70
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
71
+ thread.start()
72
+ return streamer
73
+
74
+ # --- Display previous chat history ---
75
+ for message in st.session_state.messages:
76
+ avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
77
+ with st.chat_message(message["role"], avatar=avatar):
78
+ st.markdown(message["content"])
79
+
80
+ # --- User input ---
81
+ if prompt := st.chat_input("Ask me anything about your inspection knowledge..."):
82
+
83
+ # Display user prompt
84
+ with st.chat_message("user", avatar=USER_AVATAR):
85
+ st.markdown(prompt)
86
+ st.session_state.messages.append({"role": "user", "content": prompt})
87
+
88
+ # Generate assistant response
89
+ if model and tokenizer:
90
+ try:
91
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
92
+ start_time = time.time()
93
+ streamer = generate_response(prompt, model, tokenizer)
94
+
95
+ response_container = st.empty()
96
+ full_response = ""
97
+
98
+ for chunk in streamer:
99
+ full_response += chunk
100
+ response_container.markdown(full_response + "β–Œ", unsafe_allow_html=True)
101
+
102
+ end_time = time.time()
103
+ input_tokens = len(tokenizer(prompt)["input_ids"])
104
+ output_tokens = len(tokenizer(full_response)["input_ids"])
105
+ speed = output_tokens / (end_time - start_time)
106
+
107
+ # (Optional) token-based cost estimation if running commercial APIs
108
+ st.caption(
109
+ f"πŸ”‘ Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
110
+ f"πŸ•’ Speed: {speed:.1f} tokens/sec"
111
+ )
112
+
113
+ response_container.markdown(full_response)
114
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
115
+
116
+ except Exception as e:
117
+ st.error(f"⚑ Generation error: {str(e)}")
118
+ else:
119
+ st.error("πŸ€– Model not loaded!")