Update pages/type_text.py
Browse files- pages/type_text.py +28 -15
pages/type_text.py
CHANGED
@@ -111,10 +111,16 @@ st_models = {
|
|
111 |
'original model for general domain, best performance: all-mpnet-base-v2': 'all-mpnet-base-v2',
|
112 |
'fine-tuned model for medical domain: all-mpnet-base-v2': 'all-mpnet-base-v2',
|
113 |
}
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
## Create the select box
|
116 |
-
selected_st_model = st.selectbox('Choose a model:', list(st_models.keys()))
|
117 |
-
st.write("
|
118 |
|
119 |
## Get the selected model
|
120 |
SentTrans_model = st_models[selected_st_model]
|
@@ -126,14 +132,6 @@ def load_model():
|
|
126 |
return model
|
127 |
model = load_model()
|
128 |
|
129 |
-
#model = SentenceTransformer('all-MiniLM-L6-v2') # fastest
|
130 |
-
#model = SentenceTransformer('all-mpnet-base-v2') # best performance
|
131 |
-
#model = SentenceTransformers('all-distilroberta-v1')
|
132 |
-
#model = SentenceTransformer('sentence-transformers/msmarco-bert-base-dot-v5')
|
133 |
-
#model = SentenceTransformer('clips/mfaq')
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
|
138 |
|
139 |
INTdesc_embedding = model.encode(INTdesc_input)
|
@@ -148,13 +146,28 @@ HF_model_results = util.semantic_search(INTdesc_embedding, SBScorpus_embeddings)
|
|
148 |
HF_model_results_sorted = sorted(HF_model_results, key=lambda x: x[1], reverse=True)
|
149 |
HF_model_results_displayed = HF_model_results_sorted[0:numMAPPINGS_input]
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
@st.cache_resource
|
152 |
def load_pipe():
|
153 |
-
pipe = pipeline("text-generation", model=
|
154 |
return pipe
|
155 |
pipe = load_pipe()
|
156 |
|
157 |
-
#pipe = pipeline("text-generation", model="meta-llama/Llama-3.2-1B-Instruct", device_map=device,) # device_map="auto", torch_dtype=torch.bfloat16
|
158 |
|
159 |
dictA = {"Score": [], "SBS Code": [], "SBS Description V2.0": []}
|
160 |
dfALL = pd.DataFrame.from_dict(dictA)
|
@@ -174,13 +187,13 @@ if INTdesc_input is not None and createSBScodes_clicked == True:
|
|
174 |
|
175 |
st.dataframe(data=dfALL, hide_index=True)
|
176 |
|
177 |
-
display_format = "ask REASONING MODEL: Which, if any, of the
|
178 |
#st.write(display_format)
|
179 |
question = "Which one, if any, of the following Saudi Billing System descriptions A, B, C, D, or E corresponds best to " + INTdesc_input +"? "
|
180 |
shortlist = [SBScorpus[result[0]["corpus_id"]], SBScorpus[result[1]["corpus_id"]], SBScorpus[result[2]["corpus_id"]], SBScorpus[result[3]["corpus_id"]], SBScorpus[result[4]["corpus_id"]]]
|
181 |
prompt = question + " " +"A: "+ shortlist[0] + " " +"B: " + shortlist[1] + " " + "C: " + shortlist[2] + " " + "D: " + shortlist[3] + " " + "E: " + shortlist[4]
|
182 |
st.write(prompt)
|
183 |
-
|
184 |
messages = [
|
185 |
{"role": "system", "content": "You are a knowledgable AI assistant who always answers truthfully and precisely!"},
|
186 |
{"role": "user", "content": prompt},
|
|
|
111 |
'original model for general domain, best performance: all-mpnet-base-v2': 'all-mpnet-base-v2',
|
112 |
'fine-tuned model for medical domain: all-mpnet-base-v2': 'all-mpnet-base-v2',
|
113 |
}
|
114 |
+
|
115 |
+
# NOTE: other Sentence Transformer checkpoints tried during development:
# all-MiniLM-L6-v2 (fastest), all-mpnet-base-v2 (best performance),
# all-distilroberta-v1, msmarco-bert-base-dot-v5, clips/mfaq.
# (Dead commented-out code removed; one line also had a
# `SentenceTransformers` typo that would have raised NameError.)

## Create the select box for the Sentence Transformer model
selected_st_model = st.selectbox('Choose a Sentence Transformer model:', list(st_models.keys()))
st.write("Current selection:", selected_st_model)

## Get the selected model (checkpoint id looked up from the chosen label)
SentTrans_model = st_models[selected_st_model]
|
|
|
132 |
return model
|
133 |
model = load_model()
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
|
137 |
# Embed the free-text intervention description with the cached Sentence
# Transformer built by load_model() above. NOTE(review): assumes
# INTdesc_input is a non-None string here — confirm the caller guards it.
INTdesc_embedding = model.encode(INTdesc_input)
|
|
|
146 |
# Rank the semantic-search results by descending score and keep only the
# first numMAPPINGS_input entries for display.
# NOTE(review): the key `hit[1]` indexes the *second element* of each
# per-query result list, not a score field — confirm this is the intended
# sort key against util.semantic_search's return shape.
HF_model_results_sorted = sorted(HF_model_results, key=lambda hit: hit[1], reverse=True)
HF_model_results_displayed = HF_model_results_sorted[:numMAPPINGS_input]
|
148 |
|
149 |
+
## Define the Reasoning models (display label -> Hugging Face checkpoint id)
rs_models = {
    'original model for general domain, faster: meta-llama/Llama-3.2-1B-Instruct': 'meta-llama/Llama-3.2-1B-Instruct',
    'fine-tuned model for medical domain: meta-llama/Llama-3.2-1B-Instruct': 'meta-llama/Llama-3.2-1B-Instruct',
    'original model for general domain, slower: Qwen/Qwen2-1.5B-Instruct': 'Qwen/Qwen2-1.5B-Instruct',
    'fine-tuned model for medical domain: Qwen/Qwen2-1.5B-Instruct': 'Qwen/Qwen2-1.5B-Instruct',
}

## Create the select box for the Reasoning model
# BUG FIX: the options must come from rs_models, not st_models. With
# st_models.keys() every selectable label was a Sentence-Transformer label,
# so the lookup rs_models[selected_rs_model] below raised KeyError on any
# selection.
selected_rs_model = st.selectbox('Choose a Reasoning model:', list(rs_models.keys()))
st.write("Current selection:", selected_rs_model)

## Get the selected model (checkpoint id looked up from the chosen label)
Reasoning_model = rs_models[selected_rs_model]

## Use the model as a text-generation pipeline, cached so Streamlit reruns
## reuse the already-loaded weights instead of reloading them.
@st.cache_resource
def load_pipe():
    pipe = pipeline("text-generation", model=Reasoning_model, device_map=device,) # device_map="auto", torch_dtype=torch.bfloat16
    return pipe
pipe = load_pipe()
|
170 |
|
|
|
171 |
|
172 |
dictA = {"Score": [], "SBS Code": [], "SBS Description V2.0": []}
|
173 |
dfALL = pd.DataFrame.from_dict(dictA)
|
|
|
187 |
|
188 |
st.dataframe(data=dfALL, hide_index=True)
|
189 |
|
190 |
+
# Debug preview of the reasoning request (not rendered by default).
display_format = "ask REASONING MODEL: Which, if any, of the following SBS descriptions corresponds best to " + INTdesc_input +"? "
#st.write(display_format)
question = "Which one, if any, of the following Saudi Billing System descriptions A, B, C, D, or E corresponds best to " + INTdesc_input +"? "
# Top-5 candidate descriptions, in semantic-search ranking order.
shortlist = [SBScorpus[result[i]["corpus_id"]] for i in range(5)]
# Label the candidates A..E and append them to the question.
labelled_options = " ".join(letter + ": " + text for letter, text in zip("ABCDE", shortlist))
prompt = question + " " + labelled_options
st.write(prompt)
|
196 |
+
|
197 |
messages = [
|
198 |
{"role": "system", "content": "You are a knowledgable AI assistant who always answers truthfully and precisely!"},
|
199 |
{"role": "user", "content": prompt},
|