Spaces:
Running
Running
José Eliel Camargo Molina
committed on
Commit
·
bf7477e
1
Parent(s):
ca2bf21
latex fixed
Browse files
app.py
CHANGED
@@ -1,429 +1,469 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
|
3 |
-
import torch
|
4 |
-
from pathlib import Path
|
5 |
-
import urllib.request
|
6 |
-
|
7 |
-
# To latex stuff
|
8 |
-
####################################
|
9 |
|
10 |
-
import itertools
|
11 |
import re
|
|
|
|
|
|
|
12 |
|
|
|
13 |
rep_tex_dict = {
|
14 |
-
"SU3":{"-3":r"\bar{\textbf{3}}","3":r"\textbf{3}"},
|
15 |
-
"SU2":{"-2":r"\textbf{2}","2":r"\textbf{2}","-3":r"\textbf{3}","3":r"\textbf{3}"},
|
16 |
}
|
17 |
|
18 |
-
def fieldobj_to_tex(obj,lor_index,pos):
|
19 |
su3 = None
|
20 |
su2 = None
|
21 |
-
u1 =
|
22 |
hel = None
|
23 |
sp = None
|
24 |
|
25 |
-
#print(obj)
|
26 |
obj_mod = obj.copy()
|
27 |
for tok in obj:
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
42 |
assert sp is not None
|
43 |
|
44 |
-
outtex= ""
|
45 |
-
if sp == "0"
|
46 |
-
|
47 |
-
if sp == "1
|
|
|
|
|
|
|
48 |
|
49 |
outtex += r"_{("
|
50 |
-
|
51 |
-
|
|
|
52 |
else:
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
56 |
else:
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
60 |
else:
|
61 |
-
|
62 |
-
|
63 |
-
if
|
|
|
|
|
|
|
|
|
64 |
return outtex
|
65 |
|
66 |
-
def derobj_to_tex(obj,lor_index,pos):
|
67 |
-
if pos == "^":
|
68 |
-
|
69 |
elif pos == "_":
|
70 |
-
|
71 |
else:
|
72 |
-
|
73 |
-
if "SU3" not in obj and "SU2" not in obj and "U1" not in obj:
|
74 |
-
if pos == "^":
|
75 |
-
return "\partial^{"+lor_index+"}"
|
76 |
-
elif pos == "_":
|
77 |
-
return "\partial_{"+lor_index+"}"
|
78 |
-
|
79 |
-
if "SU3" in obj: outtex += "SU3,"
|
80 |
-
if "SU2" in obj: outtex += "SU2,"
|
81 |
-
if "U1" in obj: outtex += "U1,"
|
82 |
-
if outtex[-1] == ",": outtex = outtex[:-1]+")}"
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
return outtex
|
85 |
|
86 |
-
def gamobj_to_tex(obj,lor_index,pos):
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
if isinstance(obj,tuple):
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
99 |
return "i"
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
102 |
if obj[0] == "DERIVATIVE":
|
103 |
-
|
|
|
104 |
if obj[0] == "SIGMA":
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
if obj[0] == "COMMUTATOR_A":
|
107 |
-
|
108 |
if obj[0] == "COMMUTATOR_B":
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
if "." in string and ignore_dots == False:
|
123 |
-
#print(string)
|
124 |
-
raise ValueError("Unexpected ending to the generated Lagrangian")
|
125 |
pattern = '(' + '|'.join(map(re.escape, delimiters)) + ')'
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
130 |
|
131 |
def clean_split(inlist, delimiters):
|
|
|
|
|
|
|
|
|
132 |
i = 0
|
133 |
merged_list = []
|
134 |
while i < len(inlist):
|
135 |
if inlist[i] in delimiters:
|
136 |
if i < len(inlist) - 1:
|
137 |
merged_list.append(inlist[i] + inlist[i+1])
|
138 |
-
i += 1
|
139 |
else:
|
140 |
-
merged_list.append(inlist[i])
|
141 |
else:
|
142 |
merged_list.append(inlist[i])
|
143 |
i += 1
|
144 |
return merged_list
|
145 |
|
146 |
-
|
147 |
def get_obj_dict(inlist):
|
148 |
outdict = {}
|
149 |
for iitem in inlist:
|
150 |
-
idict = {"ID":None,"LATEX":None}
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
outdict[iitem] = idict
|
159 |
return outdict
|
160 |
|
161 |
def get_con_dict(inlist):
|
|
|
|
|
|
|
|
|
|
|
162 |
outdict = {}
|
163 |
for iitem in inlist:
|
164 |
-
|
165 |
-
|
166 |
-
sym = [
|
167 |
assert len(sym) == 1, "More than one symmetry in contraction"
|
168 |
-
ids = [
|
169 |
-
if sym[0] not in outdict
|
170 |
outdict[sym[0]] = [ids]
|
171 |
else:
|
172 |
outdict[sym[0]].append(ids)
|
173 |
return outdict
|
174 |
|
175 |
-
def term_to_tex(term,verbose=
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
if verbose:
|
190 |
-
|
191 |
-
print(i,"\t\t",j)
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
-
#
|
195 |
-
contractions = [
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
-
|
200 |
-
|
201 |
-
contractions = clean_split(contractions, [" LORENTZ "," SU2 "," SU3 "])
|
202 |
-
contractions = [i for i in contractions if i != " CONTRACTIONS"]
|
203 |
-
condict = get_con_dict(contractions)
|
204 |
-
if verbose: print(condict)
|
205 |
-
if "LZ" in condict.keys():
|
206 |
firstlz = True
|
207 |
cma = True
|
208 |
-
for con in condict["
|
209 |
-
for kobj
|
210 |
-
if iobj["ID"] is None
|
|
|
211 |
if iobj["ID"] in con:
|
212 |
-
if cma:
|
213 |
-
|
214 |
-
|
|
|
215 |
if firstlz:
|
216 |
-
iobj["LATEX"] = obj_to_tex(kobj,lsymb,"^")
|
217 |
firstlz = False
|
218 |
else:
|
219 |
-
iobj["LATEX"] = obj_to_tex(kobj,lsymb,"_")
|
220 |
cma = False
|
221 |
firstlz = True
|
222 |
|
223 |
-
|
224 |
-
|
225 |
return outstr
|
226 |
-
def display_in_latex(instring,verbose=False):
|
227 |
-
#latex_string = r"$\overgroup{\Large{" + instring + "}}$"
|
228 |
-
latex_string = r"$\Large{" + instring + "}$"
|
229 |
-
if verbose: print(latex_string)
|
230 |
-
display(Latex(latex_string))
|
231 |
-
return instring
|
232 |
-
|
233 |
|
234 |
-
def str_tex(instr,num=0):
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
#outstr = ""
|
239 |
-
#instr = split_with_delimiter_preserved(iinstr,[" + ","+ "," - "])
|
240 |
-
|
241 |
if num != 0:
|
242 |
instr = instr[:num]
|
243 |
-
|
244 |
-
inlist = [term.replace(".","") for term in instr]
|
245 |
outstr = ""
|
246 |
coup = 0
|
247 |
mass = 0
|
248 |
-
outstr = "
|
249 |
-
for i, iterm in enumerate(inlist):
|
250 |
-
if i ==0:
|
251 |
-
outstr += " \mathcal{L}= \quad
|
252 |
-
else:
|
|
|
253 |
nqf = iterm.count("FIELD SPIN = 0")
|
254 |
-
nD
|
255 |
if nqf != 0 and nqf != 2 and nD == 0:
|
256 |
coup += 1
|
257 |
-
outstr += " \lambda_{
|
258 |
if nqf == 2 and nD == 0:
|
259 |
mass += 1
|
260 |
-
outstr += " m^2_{
|
261 |
-
outstr += term_to_tex(iterm,False) + " \quad "
|
262 |
-
if i%4 == 0:
|
|
|
263 |
return outstr
|
264 |
|
265 |
def master_str_tex(iinstr):
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
267 |
try:
|
268 |
outstr = str_tex(instr)
|
269 |
except Exception as e:
|
270 |
-
|
271 |
-
outstr
|
|
|
272 |
print(e)
|
273 |
-
outstr += "
|
274 |
-
return outstr
|
275 |
-
|
276 |
-
|
277 |
|
|
|
|
|
278 |
device = 'cpu'
|
279 |
model_name = "JoseEliel/BART-Lagrangian"
|
280 |
|
281 |
@st.cache_resource
|
282 |
def load_model():
|
283 |
-
|
284 |
model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
|
285 |
-
|
286 |
return model
|
287 |
|
288 |
-
model = load_model()
|
289 |
-
|
290 |
@st.cache_resource
|
291 |
def load_tokenizer():
|
292 |
return PreTrainedTokenizerFast.from_pretrained(model_name)
|
293 |
|
|
|
294 |
hf_tokenizer = load_tokenizer()
|
295 |
|
|
|
|
|
296 |
def process_input(input_text):
|
|
|
297 |
input_text = input_text.replace("[SOS]", "").replace("[EOS]", "").replace("FIELD", "SPLITFIELD")
|
298 |
-
fields = input_text.split(
|
299 |
-
fields = [x.strip().split(
|
300 |
fields = sorted(fields)
|
301 |
-
fields = "[SOS] " + " ".join([" ".join(x) for x in fields]) + " [EOS]"
|
302 |
return fields
|
303 |
|
304 |
def process_output(output_text):
|
305 |
-
|
|
|
306 |
|
307 |
-
def
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
|
313 |
def generate_lagrangian(input_text):
|
|
|
|
|
|
|
314 |
input_text = process_input(input_text)
|
315 |
inputs = hf_tokenizer([input_text], return_tensors='pt').to(device)
|
316 |
-
with st.spinner(
|
317 |
-
lagrangian_ids = model.generate(inputs['input_ids'], max_length=
|
318 |
lagrangian = hf_tokenizer.decode(lagrangian_ids[0].tolist(), skip_special_tokens=False)
|
319 |
lagrangian = process_output(lagrangian)
|
320 |
return lagrangian
|
321 |
|
322 |
def generate_field(sp, su2, su3, u1):
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
|
|
328 |
components = [f"FIELD SPIN={sp} HEL=1/2"]
|
329 |
-
|
330 |
-
# Conditionally add each component
|
331 |
if su2 != "$1$":
|
332 |
components.append(f"SU2={su2}")
|
333 |
-
if su3 == "$\\bar{3}$":
|
334 |
components.append("SU3=-3")
|
335 |
-
|
336 |
-
components.append(f"SU3={su3}")
|
337 |
if u1 != "0":
|
338 |
components.append(f"U1={u1}")
|
339 |
-
|
340 |
-
# Join components into final string
|
341 |
-
return " ".join(components).replace("$","")
|
342 |
|
|
|
|
|
343 |
def main():
|
344 |
-
# Streamlit UI (Adjusted without 'className')
|
345 |
st.title("$\\mathscr{L}$agrangian Generator")
|
346 |
st.markdown(" ### For a set of chosen fields, this model generates the corresponding Lagrangian which encodes all interactions and dynamics of the fields.")
|
347 |
|
348 |
st.markdown(" #### This is a demo of our [BART](https://arxiv.org/abs/1910.13461)-based model with ca 360M parameters")
|
349 |
|
350 |
-
st.markdown(" ##### :violet[Due to computational resources, we limit the number of fields to 3
|
351 |
st.markdown(" ##### Choose up to three different fields:")
|
352 |
|
|
|
|
|
353 |
su2_options = ["$1$", "$2$", "$3$"]
|
354 |
su3_options = ["$1$", "$3$", "$\\bar{3}$"]
|
355 |
-
u1_options = ["-1","-2/3", "-1/2", "-1/3", "0","1/3"
|
356 |
spin_options = ["0", "1/2"]
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
st.session_state.field_strings = [] # Stores the generated field strings
|
363 |
|
364 |
with st.form("field_selection"):
|
365 |
spin_selection = st.radio("Select spin value:", spin_options)
|
366 |
-
su2_selection = st.radio("Select
|
367 |
-
su3_selection = st.radio("Select
|
368 |
-
u1_selection
|
369 |
submitted = st.form_submit_button("Add field")
|
370 |
if submitted:
|
371 |
if st.session_state.count < 3:
|
372 |
-
|
373 |
-
st.session_state.field_strings.append(
|
374 |
-
st.session_state.count += 1
|
375 |
-
|
376 |
-
st.write("
|
|
|
377 |
clear_fields = st.button("Clear fields")
|
378 |
if clear_fields:
|
379 |
st.session_state.field_strings = []
|
380 |
st.session_state.count = 0
|
381 |
-
# Button to generate field text, allows up to 2 button presses
|
382 |
|
383 |
-
|
|
|
384 |
for i, fs in enumerate(st.session_state.field_strings, 1):
|
385 |
texfield = obj_to_tex(fs)
|
386 |
-
|
387 |
-
st.latex("\\text{" + fieldname + "} \quad" + texfield)
|
388 |
|
|
|
389 |
if st.button("Generate Lagrangian"):
|
390 |
input_fields = " ".join(st.session_state.field_strings)
|
391 |
-
if input_fields == "":
|
392 |
-
st.write("Please add
|
393 |
return
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
print(generated_lagrangian)
|
406 |
-
print("\n")
|
407 |
-
# Save generated lagrangian into same csv file, create if not exist
|
408 |
-
#with open('usesdata.csv', 'a') as f:
|
409 |
-
# f.write(generated_lagrangian + "\n")
|
410 |
-
|
411 |
-
# add = to SU2 X, SU3 X, U1 X, SPIN X only when X is a number and not when its followd by anything not a number
|
412 |
-
#generated_lagrangian = re.sub(r"(SU2)(\s)(\d)", r"\1=\3", generated_lagrangian)
|
413 |
-
#latex_output = master_str_tex(generated_lagrangian[1:])
|
414 |
-
#print(latex_output)
|
415 |
-
#print("\n\n")
|
416 |
-
# save latex output in file
|
417 |
-
#with open('usesdata.csv', 'a') as f:
|
418 |
-
# f.write(latex_output + "\n")
|
419 |
-
#st.text_area("Generated Lagrangian", pretty_output, height=300)
|
420 |
-
st.markdown("### Generated Lagrangian")
|
421 |
-
st.text_area(generated_lagrangian, height=300)
|
422 |
-
|
423 |
-
|
424 |
-
# write my contact info
|
425 |
st.markdown("### Contact")
|
426 |
-
st.markdown("
|
427 |
|
428 |
if __name__ == "__main__":
|
429 |
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
|
|
2 |
import re
|
3 |
+
import streamlit as st
|
4 |
+
import torch
|
5 |
+
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
|
6 |
|
7 |
+
# LaTeX spellings for the non-trivial SU(3)/SU(2) representations.
rep_tex_dict = {
    "SU3": {"-3": r"\bar{\textbf{3}}", "3": r"\textbf{3}"},
    "SU2": {"-2": r"\textbf{2}", "2": r"\textbf{2}", "-3": r"\textbf{3}", "3": r"\textbf{3}"},
}

def fieldobj_to_tex(obj, lor_index, pos):
    """
    Render a FIELD token list as LaTeX.

    *obj* is a token list such as ["FIELD", "SPIN=0", "SU3=3", "U1=1/3"].
    The spin selects the symbol (\\phi, \\psi, or a vector "A" carrying
    the Lorentz index at position *pos*); the gauge representations and
    helicity are collected into a subscript tuple "(su3, su2, u1[, h])".
    """
    su3 = None
    su2 = None
    u1 = None
    hel = None
    sp = None

    obj_mod = obj.copy()
    for tok in obj:
        if "SU3" in tok:
            su3 = tok.split("=")[-1]
            obj_mod.remove(tok)
        if "SU2" in tok:
            su2 = tok.split("=")[-1]
            obj_mod.remove(tok)
        if "U1" in tok:
            u1 = tok.split("=")[-1]
            obj_mod.remove(tok)
        if "HELICITY" in tok:
            # NOTE(review): fields built by generate_field use "HEL=",
            # not "HELICITY" — confirm which spelling the model emits.
            hel = tok.split("=")[-1]
            if hel == "1":
                hel = "+1"
        if "SPIN" in tok:
            sp = tok.split("=")[-1]
    assert sp is not None

    outtex = ""
    if sp == "0":
        outtex += r"\phi"
    if sp == "1":
        outtex += "A" + pos + lor_index
    if sp == "1/2":
        outtex += r"\psi"

    outtex += r"_{("
    # SU(3) representation (bold 1 for singlets).
    if su3 is not None:
        outtex += rep_tex_dict["SU3"].get(su3, r"\textbf{1}") + " ,"
    else:
        outtex += r"\textbf{1},"
    # SU(2) representation.
    if su2 is not None:
        outtex += rep_tex_dict["SU2"].get(su2, r"\textbf{1}") + " ,"
    else:
        outtex += r"\textbf{1},"
    # U(1) charge.
    if u1 is not None:
        outtex += u1 + " ,"
    else:
        outtex += r"\textbf{0},"
    # Helicity label, only when present.
    if hel is not None:
        outtex += "h:" + hel + " ,"
    # Close the subscript, dropping the trailing comma.
    if outtex[-1] == ",":
        outtex = outtex[:-1] + ")}"
    return outtex
|
70 |
|
71 |
+
def derobj_to_tex(obj, lor_index, pos):
    """
    Render a DERIVATIVE token list as LaTeX.

    Without any gauge-charge token the result is a plain \\partial with
    the given Lorentz index; otherwise a covariant "D" whose opposite
    script lists the charged groups.
    """
    if pos == "^":
        outtex = f"D^{{{lor_index}}}_{{("
    elif pos == "_":
        outtex = f"D_{{{lor_index}}}^{{("
    else:
        raise ValueError("pos must be ^ or _")

    # Uncharged: ordinary partial derivative.
    if "SU3" not in obj and "SU2" not in obj and "U1" not in obj:
        if pos == "^":
            return f"\\partial^{lor_index}"
        else:
            return f"\\partial_{lor_index}"

    if "SU3" in obj:
        outtex += "SU3,"
    if "SU2" in obj:
        outtex += "SU2,"
    if "U1" in obj:
        outtex += "U1,"
    # Replace the trailing comma with the closing bracket.
    if outtex[-1] == ",":
        outtex = outtex[:-1] + ")}"
    return outtex
|
95 |
|
96 |
+
def gamobj_to_tex(obj, lor_index, pos):
    """Render a SIGMA token as a sigma matrix carrying the Lorentz index."""
    return r"\sigma" + pos + lor_index
|
98 |
+
|
99 |
+
def obj_to_tex(obj, lor_index=r"\mu", pos="^"):
    """
    Dispatch a single Lagrangian object to its LaTeX renderer.

    *obj* may be a string, tuple, or list of tokens; the first token
    decides the renderer. Unrecognized objects render as "".
    """
    # Normalise input into a list of non-empty tokens.
    if isinstance(obj, tuple):
        obj = list(obj)
    if isinstance(obj, str):
        obj = [part for part in obj.split(" ") if part != ""]

    # Bare arithmetic tokens.
    if obj[0] == "+":
        return r"\quad\quad+"
    if obj[0] == "-":
        return r"\quad\quad-"
    if obj[0] == "i":
        return "i"

    if obj[0] == "FIELD":
        return fieldobj_to_tex(obj, lor_index, pos)
    if obj[0] == "DERIVATIVE":
        return derobj_to_tex(obj, lor_index, pos)
    if obj[0] == "SIGMA":
        return gamobj_to_tex(obj, lor_index, pos)

    # Commutator markers fused with a derivative render as the opening
    # or closing half of "[ D , D ]".
    if obj[0] == "COMMUTATOR_ADERIVATIVE":
        new_obj = obj[:]
        new_obj[0] = "DERIVATIVE"
        return "[ " + derobj_to_tex(new_obj, lor_index, pos)
    if obj[0] == "COMMUTATOR_BDERIVATIVE":
        new_obj = obj[:]
        new_obj[0] = "DERIVATIVE"
        return ", " + derobj_to_tex(new_obj, lor_index, pos) + " ]"

    # Plain commutator halves.
    if obj[0] == "COMMUTATOR_A":
        return "[ " + derobj_to_tex(obj, lor_index, pos)
    if obj[0] == "COMMUTATOR_B":
        return ", " + derobj_to_tex(obj, lor_index, pos) + " ]"

    # Unhandled token: contribute nothing.
    return ""
|
143 |
+
|
144 |
+
def split_with_delimiter_preserved(string, delimiters, ignore_dots=False):
    """
    Split *string* on any of *delimiters*, keeping each delimiter as its
    own element of the returned list.

    A '.' in the input marks a truncated generation and raises
    ValueError unless *ignore_dots* is set. A bare "+ " delimiter is
    normalised to " + ".
    """
    if "." in string and not ignore_dots:
        raise ValueError("Unexpected ending to the generated Lagrangian")
    # Capturing group makes re.split keep the delimiters.
    pattern = '(' + '|'.join(map(re.escape, delimiters)) + ')'
    parts = re.split(pattern, string)
    parts = [" + " if p == "+ " else p for p in parts]
    return [p for p in parts if p != ""]
|
158 |
|
159 |
def clean_split(inlist, delimiters):
    """
    Merge each delimiter element with the element that follows it, so
    e.g. ["FIELD ", "SPIN=0"] becomes ["FIELD SPIN=0"].

    A trailing delimiter with nothing after it is kept as-is.
    """
    merged = []
    i = 0
    n = len(inlist)
    while i < n:
        item = inlist[i]
        if item in delimiters and i < n - 1:
            # Fuse the delimiter with its payload and skip both.
            merged.append(item + inlist[i + 1])
            i += 2
        else:
            merged.append(item)
            i += 1
    return merged
|
177 |
|
|
|
178 |
def get_obj_dict(inlist):
    """
    Map each object token string to a dict with its "ID" (first ID...
    token, or None) and its default "LATEX" rendering.

    The LATEX entry may later be overwritten by contraction handling.
    """
    outdict = {}
    for iitem in inlist:
        idict = {"ID": None, "LATEX": None}
        # Pick out an ID... token if present.
        the_ids = [part for part in iitem.split() if part.startswith("ID")]
        if the_ids:
            idict["ID"] = the_ids[0]
        # Default rendering with an upper \mu index.
        idict["LATEX"] = obj_to_tex(iitem, "\\mu", "^")
        outdict[iitem] = idict
    return outdict
|
191 |
|
192 |
def get_con_dict(inlist):
    """
    Group contraction tokens by symmetry.

    Each entry of *inlist* looks like "LORENTZ ID1 ID2" or "SU2 ID3 ID4";
    the result maps the symmetry name to a list of ID groups, e.g.
    {"LORENTZ": [["ID1", "ID2"]]}.

    Bug fix: the previous ID filter kept tokens lacking "SU"/"LZ"
    substrings, which let the literal "LORENTZ" token (which contains
    neither) leak into the ID lists. The symmetry token itself is now
    excluded explicitly.
    """
    outdict = {}
    for iitem in inlist:
        tokens = [t for t in iitem.split() if t != ""]
        sym = [t for t in tokens if ("SU" in t or "LORENTZ" in t)]
        assert len(sym) == 1, "More than one symmetry in contraction"
        # Everything except the symmetry token is an object ID.
        ids = [t for t in tokens if t not in sym]
        if sym[0] not in outdict:
            outdict[sym[0]] = [ids]
        else:
            outdict[sym[0]].append(ids)
    return outdict
|
210 |
|
211 |
+
def term_to_tex(term, verbose=True):
    """
    Convert one Lagrangian term (raw model token string) into LaTeX.

    The string is normalised, split into object tokens (fields,
    derivatives, sigmas, commutator halves) plus an optional
    CONTRACTIONS tail, and each object is rendered via obj_to_tex.
    Lorentz-contracted pairs receive matching ^/_ indices, alternating
    between \\mu and \\nu.
    """
    # Normalise the raw model output.
    term = term.replace(".", "").replace(" = ", "=").replace(" =- ", "=-")
    term = term.replace(" / ", "/")
    # Fuse commutator markers with the derivative that follows them.
    term = term.replace("COMMUTATOR_A DERIVATIVE", "COMMUTATOR_ADERIVATIVE")
    term = term.replace("COMMUTATOR_B DERIVATIVE", "COMMUTATOR_BDERIVATIVE")

    obj_delims = [" FIELD ", " DERIVATIVE ", " SIGMA ",
                  " COMMUTATOR_ADERIVATIVE ", " COMMUTATOR_BDERIVATIVE ",
                  " CONTRACTIONS "]
    term = split_with_delimiter_preserved(term, obj_delims)
    term = clean_split(term, obj_delims)

    if verbose:
        print(term)

    # Bare +, - or i tokens pass straight through.
    if term in [[" + "], [" - "], [" i "]]:
        return term[0]

    # Render every non-contraction object.
    objdict = get_obj_dict([t for t in term if " CONTRACTIONS " not in t])
    if verbose:
        for key, val in objdict.items():
            print(key, "\t\t", val)

    contractions = [t for t in term if " CONTRACTIONS " in t]
    if len(contractions) > 1:
        raise ValueError("More than one contraction in term")

    if len(contractions) == 1 and contractions != [" CONTRACTIONS "]:
        con_delims = [" LORENTZ ", " SU2 ", " SU3 "]
        c_str = split_with_delimiter_preserved(contractions[0], con_delims)
        c_str = clean_split(c_str, con_delims)
        c_str = [tok for tok in c_str if tok != " CONTRACTIONS"]
        condict = get_con_dict(c_str)
        if verbose:
            print(condict)

        # Lorentz contractions: the first object of each pair carries the
        # upper index, the second the lower one; pairs alternate \mu/\nu.
        if "LORENTZ" in condict:
            firstlz = True
            cma = True
            for con in condict["LORENTZ"]:
                for kobj, iobj in objdict.items():
                    if iobj["ID"] is None:
                        continue
                    if iobj["ID"] in con:
                        lsymb = r"\mu" if cma else r"\nu"
                        if firstlz:
                            iobj["LATEX"] = obj_to_tex(kobj, lsymb, "^")
                            firstlz = False
                        else:
                            iobj["LATEX"] = obj_to_tex(kobj, lsymb, "_")
                            cma = False
                            firstlz = True

    # Stitch the rendered objects back together.
    return " ".join([objdict[t]["LATEX"] for t in term if " CONTRACTIONS " not in t])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
|
284 |
+
def str_tex(instr, num=0):
    """
    Convert a list of term strings into a multi-line LaTeX Lagrangian.

    Coupling constants (\\lambda_i) are inserted in front of pure
    interaction terms and mass factors (m^2_i) in front of quadratic
    scalar terms; a line break is inserted every four terms. A non-zero
    *num* truncates the input list first (e.g. num=-1 drops the last,
    possibly incomplete, term).
    """
    if num != 0:
        instr = instr[:num]

    inlist = [term.replace(".", "") for term in instr]
    coup = 0   # coupling-constant counter
    mass = 0   # mass-term counter
    outstr = r"\begin{aligned}"
    for i, iterm in enumerate(inlist):
        if i == 0:
            outstr += r" \mathcal{L}= \quad \\ & "
        else:
            # Classify the term by its scalar-field and derivative content.
            nqf = iterm.count("FIELD SPIN = 0")
            nD = iterm.count(" DERIVATIVE ")
            if nqf != 0 and nqf != 2 and nD == 0:
                coup += 1
                outstr += rf" \lambda_{{{coup}}} \,"
            if nqf == 2 and nD == 0:
                mass += 1
                outstr += rf" m^2_{{{mass}}} \,"
        outstr += term_to_tex(iterm, False) + r" \quad "
        if i % 4 == 0:
            outstr += r" \\ \\ & "
    return outstr
|
313 |
|
314 |
def master_str_tex(iinstr):
    """
    Render a full generated Lagrangian string as LaTeX.

    Splits on +/- separators and renders every term; if rendering fails
    (typically because the model output was truncated), the last term is
    dropped and an ellipsis is appended instead.
    """
    instr = split_with_delimiter_preserved(iinstr, [" + ", "+ ", " - "])
    try:
        outstr = str_tex(instr)
    except Exception as e:
        # Retry without the final (possibly truncated) term.
        outstr = str_tex(instr, -1)
        outstr += " \\cdots"
        print(e)
    outstr += r"\end{aligned}"
    return outstr
|
|
|
|
|
330 |
|
331 |
+
# ---------------------------------------------------------------------------------
# Model loading
device = 'cpu'  # inference runs on CPU in this Space
model_name = "JoseEliel/BART-Lagrangian"

@st.cache_resource
def load_model():
    """Download/load the BART Lagrangian model once and cache it for all sessions."""
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    return model

@st.cache_resource
def load_tokenizer():
    """Download/load the matching tokenizer once and cache it."""
    return PreTrainedTokenizerFast.from_pretrained(model_name)

model = load_model()
hf_tokenizer = load_tokenizer()
|
347 |
|
348 |
+
# ---------------------------------------------------------------------------------
|
349 |
+
# Text processing wrappers
|
350 |
def process_input(input_text):
    """
    Canonicalise the user field string: strip the special tokens, split
    it into individual fields, and sort them so the model always sees
    the same ordering for the same field content.
    """
    input_text = input_text.replace("[SOS]", "").replace("[EOS]", "").replace("FIELD", "SPLITFIELD")
    fields = input_text.split("SPLIT")[1:]
    fields = sorted(chunk.strip().split(" ") for chunk in fields)
    return "[SOS] " + " ".join(" ".join(f) for f in fields) + " [EOS]"
|
358 |
|
359 |
def process_output(output_text):
    """Strip the special tokens and stray periods from the model output."""
    for junk in ("[SOS]", "[EOS]", "."):
        output_text = output_text.replace(junk, "")
    return output_text
|
362 |
|
363 |
+
def reformat_expression(s):
    """
    Insert ' = ' between a charge label and its (possibly signed) number,
    removing any space inside the number: "SU2 - 1" -> "SU2 = -1".
    """
    return re.sub(
        r"(SU[23]|U1|SPIN|HEL)\s+([+-]?\s*\d+)",
        lambda m: f"{m.group(1)} = {m.group(2).replace(' ', '')}",
        s,
    )
|
368 |
|
369 |
def generate_lagrangian(input_text):
    """
    Run the model on the user-selected fields and return the decoded,
    cleaned Lagrangian string.
    """
    input_text = process_input(input_text)
    inputs = hf_tokenizer([input_text], return_tensors='pt').to(device)
    with st.spinner("Generating Lagrangian..."):
        lagrangian_ids = model.generate(inputs['input_ids'], max_length=2048)
    decoded = hf_tokenizer.decode(lagrangian_ids[0].tolist(), skip_special_tokens=False)
    return process_output(decoded)
|
380 |
|
381 |
def generate_field(sp, su2, su3, u1):
    """
    Build a single field token string from the chosen spin and charges.

    Selections arrive wrapped in '$...$' from the radio widgets; the
    markers are stripped from the final string. Singlet (SU=1) and
    uncharged (U1=0) choices are omitted entirely.

    Bug fix: the anti-triplet comparison used "$\\bar{{3}}$" — doubled
    braces in a plain (non-f) string literal — which can never equal the
    widget option "$\\bar{3}$", so anti-triplets were emitted as
    'SU3=\\bar{3}' instead of the model's expected 'SU3=-3'. The
    comparison now uses the actual option string.
    """
    components = [f"FIELD SPIN={sp}"]
    # Spin-1/2 fields carry an explicit helicity label.
    if sp == "1/2":
        components = [f"FIELD SPIN={sp} HEL=1/2"]

    if su2 != "$1$":
        components.append(f"SU2={su2}")
    if su3 == "$\\bar{3}$":
        # The SU(3) anti-fundamental is encoded as -3.
        components.append("SU3=-3")
    elif su3 != "$1$":
        components.append(f"SU3={su3.replace('$', '')}")
    if u1 != "0":
        components.append(f"U1={u1}")
    return " ".join(components).replace("$", "")
|
|
|
|
|
399 |
|
400 |
+
# ---------------------------------------------------------------------------------
|
401 |
+
# Streamlit GUI
|
402 |
def main():
    """Streamlit page: pick up to three fields, then generate and render a Lagrangian."""
    st.title("$\\mathscr{L}$agrangian Generator")
    st.markdown(" ### For a set of chosen fields, this model generates the corresponding Lagrangian which encodes all interactions and dynamics of the fields.")

    st.markdown(" #### This is a demo of our [BART](https://arxiv.org/abs/1910.13461)-based model with ca 360M parameters")

    st.markdown(" ##### :violet[Due to computational resources, we limit the number of fields to 3.]")
    st.markdown(" ##### Choose up to three different fields:")

    st.markdown("Choose up to three different fields:")
    su2_options = ["$1$", "$2$", "$3$"]
    su3_options = ["$1$", "$3$", "$\\bar{3}$"]
    u1_options = ["-1", "-2/3", "-1/2", "-1/3", "0", "1/3", "1/2", "2/3", "1"]
    spin_options = ["0", "1/2"]

    # Persist the chosen fields across Streamlit reruns.
    if "count" not in st.session_state:
        st.session_state.count = 0
    if "field_strings" not in st.session_state:
        st.session_state.field_strings = []

    with st.form("field_selection"):
        spin_selection = st.radio("Select spin value:", spin_options)
        su2_selection = st.radio("Select SU(2) value:", su2_options)
        su3_selection = st.radio("Select SU(3) value:", su3_options)
        u1_selection = st.radio("Select U(1) value:", u1_options)
        submitted = st.form_submit_button("Add field")
        if submitted:
            if st.session_state.count < 3:
                new_field = generate_field(spin_selection, su2_selection, su3_selection, u1_selection)
                st.session_state.field_strings.append(new_field)
                st.session_state.count += 1
            else:
                st.write("Maximum of 3 fields for this demo.")

    clear_fields = st.button("Clear fields")
    if clear_fields:
        st.session_state.field_strings = []
        st.session_state.count = 0

    # Show the fields chosen so far, rendered in LaTeX.
    st.write("Input Fields:")
    for i, fs in enumerate(st.session_state.field_strings, 1):
        texfield = obj_to_tex(fs)
        st.latex(r"\text{Field " + str(i) + ":} \quad " + texfield)

    if st.button("Generate Lagrangian"):
        input_fields = " ".join(st.session_state.field_strings)
        if input_fields.strip() == "":
            st.write("Please add at least one field before generating the Lagrangian.")
            return

        # The model expects space-separated tokens without '=' signs.
        input_fields = input_fields.replace("=", " ")
        input_fields = "[SOS] " + input_fields + " [EOS]"
        generated_lagrangian = generate_lagrangian(input_fields)
        generated_lagrangian = reformat_expression(generated_lagrangian)
        print(generated_lagrangian)

        # Render the result as LaTeX (skip the leading separator character).
        latex_output = master_str_tex(generated_lagrangian[1:])
        st.latex(latex_output)

    st.markdown("### Contact")
    st.markdown("For questions/suggestions, email us: [Eliel](mailto:[email protected]) or [Yong Sheng](mailto:[email protected]).")

if __name__ == "__main__":
    main()
|