José Eliel Camargo Molina committed on
Commit
bf7477e
·
1 Parent(s): ca2bf21

latex fixed

Browse files
Files changed (1) hide show
  1. app.py +305 -265
app.py CHANGED
@@ -1,429 +1,469 @@
1
- import streamlit as st
2
- from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
3
- import torch
4
- from pathlib import Path
5
- import urllib.request
6
-
7
- # To latex stuff
8
- ####################################
9
 
10
- import itertools
11
  import re
 
 
 
12
 
 
13
  rep_tex_dict = {
14
- "SU3":{"-3":r"\bar{\textbf{3}}","3":r"\textbf{3}"},
15
- "SU2":{"-2":r"\textbf{2}","2":r"\textbf{2}","-3":r"\textbf{3}","3":r"\textbf{3}"},
16
  }
17
 
18
- def fieldobj_to_tex(obj,lor_index,pos):
19
  su3 = None
20
  su2 = None
21
- u1 = None
22
  hel = None
23
  sp = None
24
 
25
- #print(obj)
26
  obj_mod = obj.copy()
27
  for tok in obj:
28
- if "SU3" in tok:
29
- su3 = tok.split("=")[-1]
30
- obj_mod.remove(tok)
31
- if "SU2" in tok:
32
- su2 = tok.split("=")[-1]
33
- obj_mod.remove(tok)
34
- if "U1" in tok:
35
- u1 = tok.split("=")[-1]
36
- obj_mod.remove(tok)
37
- if "HELICITY" in tok:
38
- hel = tok.split("=")[-1]
39
- if hel == "1" : hel = "+1"
40
- if "SPIN" in tok: sp = tok.split("=")[-1]
41
- #print(obj)
 
42
  assert sp is not None
43
 
44
- outtex= ""
45
- if sp == "0" : outtex += "\phi"
46
- if sp == "1" : outtex += "A"+pos+lor_index
47
- if sp == "1/2" : outtex += "\psi"
 
 
 
48
 
49
  outtex += r"_{("
50
- if su3 is not None:
51
- outtex += rep_tex_dict["SU3"][su3]+" ,"
 
52
  else:
53
- outtex += r"\textbf{1}"+" ,"
54
- if su2 is not None:
55
- outtex += rep_tex_dict["SU2"][su2]+" ,"
 
56
  else:
57
- outtex += r"\textbf{1}"+" ,"
58
- if u1 is not None:
59
- outtex += u1+" ,"
 
60
  else:
61
- outtex += r"\textbf{0}"+" ,"
62
- if hel is not None: outtex += "h:"+ hel + " ,"
63
- if outtex[-1] == ",": outtex = outtex[:-1]+")}"
 
 
 
 
64
  return outtex
65
 
66
- def derobj_to_tex(obj,lor_index,pos):
67
- if pos == "^":
68
- outtex = "D^{"+lor_index+"}_{("
69
  elif pos == "_":
70
- outtex = "D_{"+lor_index+"("
71
  else:
72
- raise ValueError("pos must be ^ or _")
73
- if "SU3" not in obj and "SU2" not in obj and "U1" not in obj:
74
- if pos == "^":
75
- return "\partial^{"+lor_index+"}"
76
- elif pos == "_":
77
- return "\partial_{"+lor_index+"}"
78
-
79
- if "SU3" in obj: outtex += "SU3,"
80
- if "SU2" in obj: outtex += "SU2,"
81
- if "U1" in obj: outtex += "U1,"
82
- if outtex[-1] == ",": outtex = outtex[:-1]+")}"
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return outtex
85
 
86
- def gamobj_to_tex(obj,lor_index,pos):
87
- outtex = "\sigma"+pos+lor_index
88
- return outtex
89
-
90
- def obj_to_tex(obj,lor_index="\mu",pos="^"):
91
- if isinstance(obj,tuple): obj = list(obj)
92
- if isinstance(obj,str): obj = [i for i in obj.split(" ") if i != ""]
93
- # remove any space char in the first element of the list
94
- if obj[0] == "+" :
95
- return "\quad\quad+"
96
- if obj[0] == "-" :
97
- return "\quad\quad-"
98
- if obj[0] == "i" :
 
 
 
99
  return "i"
100
- if obj[0] == "FIELD" :
101
- return fieldobj_to_tex(obj,lor_index,pos)
 
 
 
102
  if obj[0] == "DERIVATIVE":
103
- return derobj_to_tex(obj,lor_index,pos)
 
104
  if obj[0] == "SIGMA":
105
- return gamobj_to_tex(obj,lor_index,pos)
 
 
 
 
 
 
 
 
 
 
 
 
106
  if obj[0] == "COMMUTATOR_A":
107
- return "[ "+derobj_to_tex(obj,lor_index,pos)
108
  if obj[0] == "COMMUTATOR_B":
109
- return ", "+derobj_to_tex(obj,lor_index,pos)+' ]'
110
-
111
- def split_with_delimiter_preserved(string, delimiters,ignore_dots=False):
112
- if "." in string and ignore_dots == False:
113
- #print(string)
114
- raise ValueError("Unexpected ending to the generated Lagrangian")
115
- pattern = '(' + '|'.join(map(re.escape, delimiters)) + ')'
116
- pattern = re.split(pattern, string)
117
- pattern = [" + " if i == "+ " else i for i in pattern ]
118
- pattern = [i for i in pattern if i != ""]
119
- return pattern
120
-
121
- def split_with_delimiter_preserved(string, delimiters,ignore_dots=False):
122
- if "." in string and ignore_dots == False:
123
- #print(string)
124
- raise ValueError("Unexpected ending to the generated Lagrangian")
125
  pattern = '(' + '|'.join(map(re.escape, delimiters)) + ')'
126
- pattern = re.split(pattern, string)
127
- pattern = [" + " if i == "+ " else i for i in pattern ]
128
- pattern = [i for i in pattern if i != ""]
129
- return pattern
 
 
130
 
131
  def clean_split(inlist, delimiters):
 
 
 
 
132
  i = 0
133
  merged_list = []
134
  while i < len(inlist):
135
  if inlist[i] in delimiters:
136
  if i < len(inlist) - 1:
137
  merged_list.append(inlist[i] + inlist[i+1])
138
- i += 1 # Skip the next element as it has been merged
139
  else:
140
- merged_list.append(inlist[i]) # If it's the last element, append it without merging
141
  else:
142
  merged_list.append(inlist[i])
143
  i += 1
144
  return merged_list
145
 
146
-
147
  def get_obj_dict(inlist):
148
  outdict = {}
149
  for iitem in inlist:
150
- idict = {"ID":None,"LATEX":None}
151
- id = [i for i in iitem.split() if "ID" in i]
152
- if len(id) == 1:
153
- idict["ID"] = id[0]
154
- if "FIELD" in iitem:
155
- idict["LATEX"] = obj_to_tex(iitem,"\\mu","^")
156
- if iitem == "+" or iitem == "-" or iitem == "i":
157
- idict["LATEX"] = obj_to_tex(iitem )
158
  outdict[iitem] = idict
159
  return outdict
160
 
161
  def get_con_dict(inlist):
 
 
 
 
 
162
  outdict = {}
163
  for iitem in inlist:
164
- iitem = iitem.split()
165
- iitem = [i for i in iitem if i != ""]
166
- sym = [i for i in iitem if ("SU" in i or "LORENTZ" in i)]
167
  assert len(sym) == 1, "More than one symmetry in contraction"
168
- ids = [i for i in iitem if ("SU" not in i and "LZ" not in i)]
169
- if sym[0] not in outdict.keys():
170
  outdict[sym[0]] = [ids]
171
  else:
172
  outdict[sym[0]].append(ids)
173
  return outdict
174
 
175
- def term_to_tex(term,verbose=False):
176
- # Clean term
177
- term = term.replace(".","").replace(" = ", "=").replace(" =- ", "=-").replace(" / ", "/").replace("COMMUTATOR_A DERIVATIVE", "COMMUTATOR_ADERIVATIVE").replace("COMMUTATOR_B DERIVATIVE", "COMMUTATOR_BDERIVATIVE")
178
- term = split_with_delimiter_preserved(term,[" FIELD "," DERIVATIVE "," SIGMA "," COMMUTATOR_A "," COMMUTATOR_B "," CONTRACTIONS "])
179
- term = clean_split(term, [" FIELD "," DERIVATIVE "," SIGMA "," COMMUTATOR_ADERIVATIVE "," COMMUTATOR_BDERIVATIVE "," CONTRACTIONS "])
180
-
181
- if verbose: print(term)
182
-
183
- if term == [" + "] or term == [" - "] or term == [" i "]:
184
- return term[0]
185
-
186
- # Get Dictionary of objects
187
- objdict = get_obj_dict([i for i in term if " CONTRACTIONS " not in i])
188
-
 
 
 
 
 
 
189
  if verbose:
190
- for i,j in objdict.items():
191
- print(i,"\t\t",j)
192
 
 
 
 
 
 
 
 
 
 
193
 
194
- # Do contractions
195
- contractions = [i for i in term if " CONTRACTIONS " in i]
196
- assert len(contractions) < 2, "More than one contraction in term"
197
- if (len(contractions) == 1) and contractions != [" CONTRACTIONS "]:
 
 
 
 
 
 
 
 
 
 
198
 
199
- contractions = contractions[0]
200
- contractions = split_with_delimiter_preserved(contractions,[" LORENTZ "," SU2 "," SU3 "])
201
- contractions = clean_split(contractions, [" LORENTZ "," SU2 "," SU3 "])
202
- contractions = [i for i in contractions if i != " CONTRACTIONS"]
203
- condict = get_con_dict(contractions)
204
- if verbose: print(condict)
205
- if "LZ" in condict.keys():
206
  firstlz = True
207
  cma = True
208
- for con in condict["LZ"]:
209
- for kobj , iobj in objdict.items():
210
- if iobj["ID"] is None : continue
 
211
  if iobj["ID"] in con:
212
- if cma: lsymb = "\\mu"
213
- else: lsymb = "\\nu"
214
-
 
215
  if firstlz:
216
- iobj["LATEX"] = obj_to_tex(kobj,lsymb,"^")
217
  firstlz = False
218
  else:
219
- iobj["LATEX"] = obj_to_tex(kobj,lsymb,"_")
220
  cma = False
221
  firstlz = True
222
 
223
- outstr = " ".join([objdict[i]["LATEX"] for i in term if " CONTRACTIONS " not in i])
224
-
225
  return outstr
226
- def display_in_latex(instring,verbose=False):
227
- #latex_string = r"$\overgroup{\Large{" + instring + "}}$"
228
- latex_string = r"$\Large{" + instring + "}$"
229
- if verbose: print(latex_string)
230
- display(Latex(latex_string))
231
- return instring
232
-
233
 
234
- def str_tex(instr,num=0):
235
-
236
- #print("INPUT:",iinstr)
237
- #print("TERM:")
238
- #outstr = ""
239
- #instr = split_with_delimiter_preserved(iinstr,[" + ","+ "," - "])
240
-
241
  if num != 0:
242
  instr = instr[:num]
243
-
244
- inlist = [term.replace(".","") for term in instr]
245
  outstr = ""
246
  coup = 0
247
  mass = 0
248
- outstr = "\\begin{aligned}"
249
- for i, iterm in enumerate(inlist):
250
- if i ==0:
251
- outstr += " \mathcal{L}= \quad \\\\ & "
252
- else:
 
253
  nqf = iterm.count("FIELD SPIN = 0")
254
- nD = iterm.count(" DERIVATIVE ")
255
  if nqf != 0 and nqf != 2 and nD == 0:
256
  coup += 1
257
- outstr += " \lambda_{"+str(coup)+"} \,"
258
  if nqf == 2 and nD == 0:
259
  mass += 1
260
- outstr += " m^2_{"+str(mass)+"} \,"
261
- outstr += term_to_tex(iterm,False) + " \quad "
262
- if i%4 == 0: outstr += " \\\\ \\\\ & "
 
263
  return outstr
264
 
265
  def master_str_tex(iinstr):
266
- instr = split_with_delimiter_preserved(iinstr,[" + ","+ "," - "])
 
 
 
 
 
267
  try:
268
  outstr = str_tex(instr)
269
  except Exception as e:
270
- outstr = str_tex(instr,-1)
271
- outstr += " \cdots"
 
272
  print(e)
273
- outstr += "\\end{aligned}"
274
- return outstr#########
275
-
276
-
277
 
 
 
278
  device = 'cpu'
279
  model_name = "JoseEliel/BART-Lagrangian"
280
 
281
  @st.cache_resource
282
  def load_model():
283
-
284
  model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
285
-
286
  return model
287
 
288
- model = load_model()
289
-
290
  @st.cache_resource
291
  def load_tokenizer():
292
  return PreTrainedTokenizerFast.from_pretrained(model_name)
293
 
 
294
  hf_tokenizer = load_tokenizer()
295
 
 
 
296
  def process_input(input_text):
 
297
  input_text = input_text.replace("[SOS]", "").replace("[EOS]", "").replace("FIELD", "SPLITFIELD")
298
- fields = input_text.split('SPLIT')[1:]
299
- fields = [x.strip().split(' ') for x in fields]
300
  fields = sorted(fields)
301
- fields = "[SOS] " + " ".join([" ".join(x) for x in fields]) + " [EOS]"
302
  return fields
303
 
304
  def process_output(output_text):
305
- return output_text.replace("[SOS]", "").replace("[EOS]", "").replace(".","")
 
306
 
307
- def process_output_pretty_print(output_text):
308
- pretty_output = output_text.replace(" / ", "/")
309
- pretty_output = pretty_output.replace("=- ", "= -")
310
- pretty_output = pretty_output.replace("+", "\n+")
311
- return pretty_output
312
 
313
  def generate_lagrangian(input_text):
 
 
 
314
  input_text = process_input(input_text)
315
  inputs = hf_tokenizer([input_text], return_tensors='pt').to(device)
316
- with st.spinner(text="Generating Lagrangian..."):
317
- lagrangian_ids = model.generate(inputs['input_ids'], max_length=512)
318
  lagrangian = hf_tokenizer.decode(lagrangian_ids[0].tolist(), skip_special_tokens=False)
319
  lagrangian = process_output(lagrangian)
320
  return lagrangian
321
 
322
  def generate_field(sp, su2, su3, u1):
323
- # Initialize components list
324
-
325
- if sp == "0":
326
- components = [f"FIELD SPIN={sp}"]
327
- else:
 
328
  components = [f"FIELD SPIN={sp} HEL=1/2"]
329
-
330
- # Conditionally add each component
331
  if su2 != "$1$":
332
  components.append(f"SU2={su2}")
333
- if su3 == "$\\bar{3}$":
334
  components.append("SU3=-3")
335
- if su3 != "$1$" and su3 != "$\\bar{3}$":
336
- components.append(f"SU3={su3}")
337
  if u1 != "0":
338
  components.append(f"U1={u1}")
339
-
340
- # Join components into final string
341
- return " ".join(components).replace("$","")
342
 
 
 
343
  def main():
344
- # Streamlit UI (Adjusted without 'className')
345
  st.title("$\\mathscr{L}$agrangian Generator")
346
  st.markdown(" ### For a set of chosen fields, this model generates the corresponding Lagrangian which encodes all interactions and dynamics of the fields.")
347
 
348
  st.markdown(" #### This is a demo of our [BART](https://arxiv.org/abs/1910.13461)-based model with ca 360M parameters")
349
 
350
- st.markdown(" ##### :violet[Due to computational resources, we limit the number of fields to 3 and the maximum length of the generated Lagrangian to 512 tokens.]")
351
  st.markdown(" ##### Choose up to three different fields:")
352
 
 
 
353
  su2_options = ["$1$", "$2$", "$3$"]
354
  su3_options = ["$1$", "$3$", "$\\bar{3}$"]
355
- u1_options = ["-1","-2/3", "-1/2", "-1/3", "0","1/3" ,"1/2", "2/3", "1"]
356
  spin_options = ["0", "1/2"]
357
-
358
- # Initialize or update session state variables
359
- if 'count' not in st.session_state:
360
- st.session_state.count = 0 # Keeps track of button presses
361
- if 'field_strings' not in st.session_state:
362
- st.session_state.field_strings = [] # Stores the generated field strings
363
 
364
  with st.form("field_selection"):
365
  spin_selection = st.radio("Select spin value:", spin_options)
366
- su2_selection = st.radio("Select $\\mathrm{SU}(2)$ value:", su2_options)
367
- su3_selection = st.radio("Select $\\mathrm{SU}(3)$ value:", su3_options)
368
- u1_selection = st.radio("Select $\\mathrm{U}(1)$ value:", u1_options)
369
  submitted = st.form_submit_button("Add field")
370
  if submitted:
371
  if st.session_state.count < 3:
372
- field_string = generate_field(spin_selection, su2_selection, su3_selection, u1_selection)
373
- st.session_state.field_strings.append(field_string) # Save generated field string
374
- st.session_state.count += 1 # Increment button press count
375
- elif st.session_state.count >= 3:
376
- st.write("You have reached the maximum number of fields we allow in this demo.")
 
377
  clear_fields = st.button("Clear fields")
378
  if clear_fields:
379
  st.session_state.field_strings = []
380
  st.session_state.count = 0
381
- # Button to generate field text, allows up to 2 button presses
382
 
383
- st.write(f"Input Fields:")
 
384
  for i, fs in enumerate(st.session_state.field_strings, 1):
385
  texfield = obj_to_tex(fs)
386
- fieldname = f"Field {i}:"
387
- st.latex("\\text{" + fieldname + "} \quad" + texfield)
388
 
 
389
  if st.button("Generate Lagrangian"):
390
  input_fields = " ".join(st.session_state.field_strings)
391
- if input_fields == "":
392
- st.write("Please add fields before generating the Lagrangian.")
393
  return
394
- else:
395
- print("\n")
396
- # append input fields into csv file, create if not exist
397
- #with open('usesdata.csv', 'a') as f:
398
- # f.write(input_fields + "\n")
399
- # replace = with space
400
- input_fields = input_fields.replace("=", " ")
401
- # append and prepend input fields with SOS and EOS tokens
402
- input_fields = "[SOS] " + input_fields + " [EOS]"
403
- print(input_fields)
404
- generated_lagrangian = generate_lagrangian(input_fields)
405
- print(generated_lagrangian)
406
- print("\n")
407
- # Save generated lagrangian into same csv file, create if not exist
408
- #with open('usesdata.csv', 'a') as f:
409
- # f.write(generated_lagrangian + "\n")
410
-
411
- # add = to SU2 X, SU3 X, U1 X, SPIN X only when X is a number and not when its followd by anything not a number
412
- #generated_lagrangian = re.sub(r"(SU2)(\s)(\d)", r"\1=\3", generated_lagrangian)
413
- #latex_output = master_str_tex(generated_lagrangian[1:])
414
- #print(latex_output)
415
- #print("\n\n")
416
- # save latex output in file
417
- #with open('usesdata.csv', 'a') as f:
418
- # f.write(latex_output + "\n")
419
- #st.text_area("Generated Lagrangian", pretty_output, height=300)
420
- st.markdown("### Generated Lagrangian")
421
- st.text_area(generated_lagrangian, height=300)
422
-
423
-
424
- # write my contact info
425
  st.markdown("### Contact")
426
- st.markdown("If you have any questions or suggestions, please feel free to Email us. [Eliel](mailto:[email protected]) or [Yong Sheng](mailto:[email protected]).")
427
 
428
  if __name__ == "__main__":
429
  main()
 
 
 
 
 
 
 
 
 
1
 
 
2
  import re
3
+ import streamlit as st
4
+ import torch
5
+ from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
6
 
7
# LaTeX labels for the SU(3) / SU(2) irreducible representations.
rep_tex_dict = {
    "SU3": {
        "-3": r"\bar{\textbf{3}}",
        "3": r"\textbf{3}",
    },
    "SU2": {
        "-2": r"\textbf{2}",
        "2": r"\textbf{2}",
        "-3": r"\textbf{3}",
        "3": r"\textbf{3}",
    },
}
12
 
13
def fieldobj_to_tex(obj, lor_index, pos):
    """
    Render a FIELD token list (e.g. ["FIELD", "SPIN=0", "SU2=2"]) as LaTeX.

    The field symbol is chosen by spin (phi / A / psi) and carries a
    subscript listing its SU(3), SU(2) and U(1) representations and,
    if present, its helicity.

    obj: list of "KEY=VALUE" tokens; lor_index/pos: Lorentz index symbol
    and its position ("^" or "_"), used only for spin-1 fields.
    """
    su3 = su2 = u1 = hel = sp = None

    # (The previous version built an unused `obj_mod` copy and removed
    # matched tokens from it; that was dead code and has been dropped.)
    for tok in obj:
        if "SU3" in tok:
            su3 = tok.split("=")[-1]
        if "SU2" in tok:
            su2 = tok.split("=")[-1]
        if "U1" in tok:
            u1 = tok.split("=")[-1]
        if "HELICITY" in tok:
            # NOTE(review): generate_field emits "HEL=..." tokens; this only
            # matches the longer "HELICITY" spelling -- confirm which form
            # the model output actually uses.
            hel = tok.split("=")[-1]
            if hel == "1":
                hel = "+1"
        if "SPIN" in tok:
            sp = tok.split("=")[-1]
    assert sp is not None

    # Field symbol by spin.
    outtex = ""
    if sp == "0":
        outtex += r"\phi"
    if sp == "1":
        outtex += "A" + pos + lor_index
    if sp == "1/2":
        outtex += r"\psi"

    # Representation subscript: (SU3, SU2, U1[, helicity]).
    outtex += r"_{("
    if su3 is not None:
        outtex += rep_tex_dict["SU3"].get(su3, r"\textbf{1}") + " ,"
    else:
        outtex += r"\textbf{1},"
    if su2 is not None:
        outtex += rep_tex_dict["SU2"].get(su2, r"\textbf{1}") + " ,"
    else:
        outtex += r"\textbf{1},"
    if u1 is not None:
        outtex += u1 + " ,"
    else:
        outtex += r"\textbf{0},"
    if hel is not None:
        outtex += "h:" + hel + " ,"
    # Drop the trailing comma and close the subscript.
    if outtex[-1] == ",":
        outtex = outtex[:-1] + ")}"
    return outtex
70
 
71
def derobj_to_tex(obj, lor_index, pos):
    """
    Render a DERIVATIVE token list as LaTeX.

    Returns a plain partial derivative when no gauge-group token is
    present, otherwise a covariant derivative D whose group content is
    listed in the sub/superscript.  pos must be "^" or "_".
    """
    gauged = any(g in obj for g in ("SU3", "SU2", "U1"))

    if pos == "^":
        if not gauged:
            return f"\\partial^{lor_index}"
        prefix = f"D^{{{lor_index}}}_{{("
    elif pos == "_":
        if not gauged:
            return f"\\partial_{lor_index}"
        prefix = f"D_{{{lor_index}}}^{{("
    else:
        raise ValueError("pos must be ^ or _")

    # List the gauge groups the covariant derivative acts under.
    groups = [g for g in ("SU3", "SU2", "U1") if g in obj]
    return prefix + ",".join(groups) + ")}"
95
 
96
def gamobj_to_tex(obj, lor_index, pos):
    """Render a SIGMA token as a sigma matrix carrying the Lorentz index."""
    return "".join((r"\sigma", pos, lor_index))
98
+
99
def obj_to_tex(obj, lor_index=r"\mu", pos="^"):
    """
    Dispatch a tokenised Lagrangian object to its LaTeX renderer.

    obj may be a tuple, a whitespace-separated string, or a token list;
    the first token decides the object kind.  Unrecognised kinds render
    as an empty string.

    FIX: the default was the non-raw literal "\mu" -- an invalid escape
    sequence (SyntaxWarning on Python 3.12+); the raw string has the
    identical runtime value.
    """
    # Normalise tuple/string input to a list of tokens.
    if isinstance(obj, tuple):
        obj = list(obj)
    if isinstance(obj, str):
        obj = [i for i in obj.split(" ") if i != ""]

    # Bare sign / imaginary-unit tokens.
    if obj[0] == "+":
        return r"\quad\quad+"
    if obj[0] == "-":
        return r"\quad\quad-"
    if obj[0] == "i":
        return "i"

    if obj[0] == "FIELD":
        return fieldobj_to_tex(obj, lor_index, pos)
    if obj[0] == "DERIVATIVE":
        return derobj_to_tex(obj, lor_index, pos)
    if obj[0] == "SIGMA":
        return gamobj_to_tex(obj, lor_index, pos)

    # Commutator tokens fused with DERIVATIVE by term_to_tex's pre-pass.
    if obj[0] == "COMMUTATOR_ADERIVATIVE":
        new_obj = obj[:]
        new_obj[0] = "DERIVATIVE"
        return "[ " + derobj_to_tex(new_obj, lor_index, pos)
    if obj[0] == "COMMUTATOR_BDERIVATIVE":
        new_obj = obj[:]
        new_obj[0] = "DERIVATIVE"
        return ", " + derobj_to_tex(new_obj, lor_index, pos) + " ]"

    # Stand-alone commutator tokens.
    if obj[0] == "COMMUTATOR_A":
        return "[ " + derobj_to_tex(obj, lor_index, pos)
    if obj[0] == "COMMUTATOR_B":
        return ", " + derobj_to_tex(obj, lor_index, pos) + " ]"

    # Unrecognised token kind: render nothing.
    return ""
143
+
144
def split_with_delimiter_preserved(string, delimiters, ignore_dots=False):
    """
    Split `string` on any of `delimiters`, keeping each delimiter as its
    own list entry.  A bare "+ " token is normalised to " + " and empty
    fragments are dropped.

    Raises ValueError when a "." is present (a truncated model output),
    unless ignore_dots is set.
    """
    if "." in string and not ignore_dots:
        raise ValueError("Unexpected ending to the generated Lagrangian")
    pattern = "(" + "|".join(map(re.escape, delimiters)) + ")"
    pieces = re.split(pattern, string)
    return [" + " if piece == "+ " else piece for piece in pieces if piece != ""]
158
 
159
def clean_split(inlist, delimiters):
    """
    Merge each delimiter token with the element that follows it, so that
    e.g. " FIELD " + "SPIN=0 ..." becomes one " FIELD SPIN=0 ..." entry.
    A delimiter in the last position is kept as-is.
    """
    merged = []
    idx = 0
    n = len(inlist)
    while idx < n:
        cur = inlist[idx]
        if cur in delimiters and idx + 1 < n:
            # Fuse the delimiter with its payload and skip the payload.
            merged.append(cur + inlist[idx + 1])
            idx += 2
        else:
            merged.append(cur)
            idx += 1
    return merged
177
 
 
178
def get_obj_dict(inlist):
    """
    Map each object token string to a small record holding its "ID..."
    tag (None when absent) and its default LaTeX rendering.
    """
    outdict = {}
    for entry in inlist:
        tags = [tok for tok in entry.split() if tok.startswith("ID")]
        outdict[entry] = {
            "ID": tags[0] if tags else None,
            "LATEX": obj_to_tex(entry, "\\mu", "^"),
        }
    return outdict
191
 
192
def get_con_dict(inlist):
    """
    For a list of 'contractions' token strings (e.g. "LORENTZ ID0 ID1"),
    produce a dictionary mapping each symmetry label (LORENTZ, SU2, SU3)
    to the lists of object IDs contracted under it.
    """
    outdict = {}
    for iitem in inlist:
        tokens = [t for t in iitem.split() if t != ""]
        sym = [t for t in tokens if ("SU" in t or "LORENTZ" in t)]
        assert len(sym) == 1, "More than one symmetry in contraction"
        # FIX: the old filter only excluded "SU"/"LZ" substrings, so the
        # "LORENTZ" label itself (which contains neither) leaked into the
        # ID list; exclude the symmetry tokens explicitly.
        ids = [t for t in tokens
               if "SU" not in t and "LZ" not in t and "LORENTZ" not in t]
        outdict.setdefault(sym[0], []).append(ids)
    return outdict
210
 
211
def term_to_tex(term, verbose=False):
    """
    Convert one Lagrangian term (a raw token string) into LaTeX.

    FIX: verbose defaulted to True -- a debug leftover (the previous
    revision and the only caller, str_tex, use False) that printed the
    parse state on every render.

    Raises ValueError if the term contains more than one CONTRACTIONS
    section.
    """
    # Normalise spacing and fuse commutator markers with their derivative.
    term = term.replace(".", "").replace(" = ", "=").replace(" =- ", "=-")
    term = term.replace(" / ", "/")
    term = term.replace("COMMUTATOR_A DERIVATIVE", "COMMUTATOR_ADERIVATIVE")
    term = term.replace("COMMUTATOR_B DERIVATIVE", "COMMUTATOR_BDERIVATIVE")

    delims = [" FIELD ", " DERIVATIVE ", " SIGMA ",
              " COMMUTATOR_ADERIVATIVE ", " COMMUTATOR_BDERIVATIVE ",
              " CONTRACTIONS "]
    term = split_with_delimiter_preserved(term, delims)
    term = clean_split(term, delims)

    if verbose:
        print(term)

    # A lone sign / imaginary-unit term passes straight through.
    if term in [[" + "], [" - "], [" i "]]:
        return term[0]

    # Records (ID + default LaTeX) for everything outside CONTRACTIONS.
    objdict = get_obj_dict([t for t in term if " CONTRACTIONS " not in t])
    if verbose:
        for k, v in objdict.items():
            print(k, "\t\t", v)

    contractions = [t for t in term if " CONTRACTIONS " in t]
    if len(contractions) > 1:
        raise ValueError("More than one contraction in term")

    if len(contractions) == 1 and contractions != [" CONTRACTIONS "]:
        c_str = contractions[0]
        c_str = split_with_delimiter_preserved(c_str, [" LORENTZ ", " SU2 ", " SU3 "])
        c_str = clean_split(c_str, [" LORENTZ ", " SU2 ", " SU3 "])
        c_str = [i for i in c_str if i != " CONTRACTIONS"]
        condict = get_con_dict(c_str)
        if verbose:
            print(condict)

        # Assign alternating upper/lower Lorentz indices (mu for the first
        # contracted pair, nu afterwards) to the objects named in each
        # LORENTZ contraction.
        if "LORENTZ" in condict:
            firstlz = True
            cma = True
            for con in condict["LORENTZ"]:
                for kobj, iobj in objdict.items():
                    if iobj["ID"] is None:
                        continue
                    if iobj["ID"] in con:
                        lsymb = r"\mu" if cma else r"\nu"
                        if firstlz:
                            iobj["LATEX"] = obj_to_tex(kobj, lsymb, "^")
                            firstlz = False
                        else:
                            iobj["LATEX"] = obj_to_tex(kobj, lsymb, "_")
                            cma = False
                            firstlz = True

    outstr = " ".join([objdict[t]["LATEX"] for t in term if " CONTRACTIONS " not in t])
    return outstr
 
 
 
 
 
 
 
283
 
284
def str_tex(instr, num=0):
    """
    Render a list of Lagrangian term strings as the body of a LaTeX
    `aligned` block (opening tag included, closing tag added by the
    caller).

    num: if non-zero, truncate the term list first (num=-1 drops the
    last, possibly incomplete, term).
    """
    if num != 0:
        instr = instr[:num]

    terms = [t.replace(".", "") for t in instr]
    coup = 0
    mass = 0
    out = r"\begin{aligned}"
    for i, iterm in enumerate(terms):
        if i == 0:
            out += r" \mathcal{L}= \quad \\ & "
        else:
            nqf = iterm.count("FIELD SPIN = 0")
            nD = iterm.count(" DERIVATIVE ")
            # Scalar self-interaction (neither kinetic nor mass-like):
            # prefix a coupling constant.
            if nqf not in (0, 2) and nD == 0:
                coup += 1
                out += rf" \lambda_{{{coup}}} \,"
            # Two scalars and no derivative: prefix a mass parameter.
            if nqf == 2 and nD == 0:
                mass += 1
                out += rf" m^2_{{{mass}}} \,"
        out += term_to_tex(iterm, False) + r" \quad "
        # Line break every four terms to keep the display readable.
        if i % 4 == 0:
            out += r" \\ \\ & "
    return out
313
 
314
def master_str_tex(iinstr):
    """
    Split the raw model output into terms and render the full LaTeX
    Lagrangian.  If rendering fails (e.g. the generation was truncated),
    retry without the last term and append an ellipsis.
    """
    terms = split_with_delimiter_preserved(iinstr, [" + ", "+ ", " - "])
    try:
        body = str_tex(terms)
    except Exception as err:
        body = str_tex(terms, -1)
        body += " \\cdots"
        print(err)
    return body + r"\end{aligned}"
 
 
330
 
331
# ---------------------------------------------------------------------------------
# Model loading
device = 'cpu'  # demo host has no GPU; everything runs on CPU
model_name = "JoseEliel/BART-Lagrangian"

@st.cache_resource
def load_model():
    # Download/load the seq2seq model once; st.cache_resource keeps it
    # alive across Streamlit reruns.
    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
    return model

@st.cache_resource
def load_tokenizer():
    # Tokenizer matching the model checkpoint, cached like the model.
    return PreTrainedTokenizerFast.from_pretrained(model_name)

model = load_model()
hf_tokenizer = load_tokenizer()
347
 
348
# ---------------------------------------------------------------------------------
# Text processing wrappers
def process_input(input_text):
    """
    Canonicalise the field list: strip [SOS]/[EOS], sort the FIELD
    groups (so equivalent field sets map to one input), and re-wrap
    with the special tokens.
    """
    marked = (input_text
              .replace("[SOS]", "")
              .replace("[EOS]", "")
              .replace("FIELD", "SPLITFIELD"))
    # Each chunk after "SPLIT" is one "FIELD ..." group.
    groups = [chunk.strip().split(" ") for chunk in marked.split("SPLIT")[1:]]
    groups.sort()
    return "[SOS] " + " ".join(" ".join(g) for g in groups) + " [EOS]"
358
 
359
def process_output(output_text):
    """Strip the special tokens and stray periods from the model output."""
    for junk in ("[SOS]", "[EOS]", "."):
        output_text = output_text.replace(junk, "")
    return output_text
362
 
363
def reformat_expression(s):
    # Insert " = " between a charge/spin label (SU2, SU3, U1, SPIN, HEL) and
    # the number that follows it, stripping any spaces inside the number:
    # e.g. "SU2 - 1" -> "SU2 = -1".
    # NOTE(review): \d+ only captures the integer part, so "SPIN 1/2"
    # becomes "SPIN = 1/2" (the "/2" tail is left untouched) -- confirm
    # this is the intended handling of fractional values.
    return re.sub(r"(SU[23]|U1|SPIN|HEL)\s+([+-]?\s*\d+)",
                  lambda m: f"{m.group(1)} = {m.group(2).replace(' ', '')}",
                  s)
368
 
369
def generate_lagrangian(input_text):
    """
    Calls the model to produce a Lagrangian for the user-given fields.

    input_text: "[SOS] ... [EOS]"-wrapped field tokens; it is first
    canonicalised (sorted) by process_input.
    Returns the decoded Lagrangian string with special tokens and
    periods stripped.
    """
    input_text = process_input(input_text)
    inputs = hf_tokenizer([input_text], return_tensors='pt').to(device)
    with st.spinner("Generating Lagrangian..."):
        # max_length caps the number of generated tokens; longer
        # Lagrangians are truncated (master_str_tex tolerates that).
        lagrangian_ids = model.generate(inputs['input_ids'], max_length=2048)
    lagrangian = hf_tokenizer.decode(lagrangian_ids[0].tolist(), skip_special_tokens=False)
    lagrangian = process_output(lagrangian)
    return lagrangian
380
 
381
def generate_field(sp, su2, su3, u1):
    """
    Builds a single field string with the chosen spin and gauge charges.

    sp: "0" or "1/2"; su2/su3: UI options like "$1$", "$2$", "$\\bar{3}$";
    u1: a plain fraction string like "-2/3".
    Returns e.g. "FIELD SPIN=1/2 HEL=1/2 SU2=2 U1=1/3".
    """
    components = [f"FIELD SPIN={sp}"]
    # Spin-1/2 fields carry an explicit helicity token.
    if sp == "1/2":
        components = [f"FIELD SPIN={sp} HEL=1/2"]

    if su2 != "$1$":
        components.append(f"SU2={su2}")
    # FIX: the comparison literal was "$\\bar{{3}}$" (doubled braces in a
    # plain string), which can never equal the UI option "$\\bar{3}$",
    # so the anti-triplet choice emitted "SU3=\bar{3}" instead of "SU3=-3".
    if su3 == "$\\bar{3}$":
        components.append("SU3=-3")
    elif su3 != "$1$":
        components.append(f"SU3={su3.replace('$', '')}")
    if u1 != "0":
        components.append(f"U1={u1}")
    # Strip the "$...$" math markers from the UI option strings.
    return " ".join(components).replace("$", "")
 
 
399
 
400
# ---------------------------------------------------------------------------------
# Streamlit GUI
def main():
    """
    Streamlit entry point: let the user assemble up to three fields,
    then generate and LaTeX-render the corresponding Lagrangian.
    """
    st.title("$\\mathscr{L}$agrangian Generator")
    st.markdown(" ### For a set of chosen fields, this model generates the corresponding Lagrangian which encodes all interactions and dynamics of the fields.")

    st.markdown(" #### This is a demo of our [BART](https://arxiv.org/abs/1910.13461)-based model with ca 360M parameters")

    st.markdown(" ##### :violet[Due to computational resources, we limit the number of fields to 3.]")
    # FIX: the prompt "Choose up to three different fields:" was rendered
    # twice (once as a heading, once as plain markdown); keep the heading only.
    st.markdown(" ##### Choose up to three different fields:")

    su2_options = ["$1$", "$2$", "$3$"]
    su3_options = ["$1$", "$3$", "$\\bar{3}$"]
    u1_options = ["-1", "-2/3", "-1/2", "-1/3", "0", "1/3", "1/2", "2/3", "1"]
    spin_options = ["0", "1/2"]

    # Per-session state: how many fields were added, and their token strings.
    if "count" not in st.session_state:
        st.session_state.count = 0
    if "field_strings" not in st.session_state:
        st.session_state.field_strings = []

    with st.form("field_selection"):
        spin_selection = st.radio("Select spin value:", spin_options)
        su2_selection = st.radio("Select SU(2) value:", su2_options)
        su3_selection = st.radio("Select SU(3) value:", su3_options)
        u1_selection = st.radio("Select U(1) value:", u1_options)
        submitted = st.form_submit_button("Add field")
        if submitted:
            if st.session_state.count < 3:
                fs = generate_field(spin_selection, su2_selection, su3_selection, u1_selection)
                st.session_state.field_strings.append(fs)
                st.session_state.count += 1
            else:
                st.write("Maximum of 3 fields for this demo.")

    clear_fields = st.button("Clear fields")
    if clear_fields:
        st.session_state.field_strings = []
        st.session_state.count = 0

    # Display the fields added so far, rendered as LaTeX.
    st.write("Input Fields:")
    for i, fs in enumerate(st.session_state.field_strings, 1):
        texfield = obj_to_tex(fs)
        st.latex(r"\text{Field " + str(i) + r":} \quad " + texfield)

    if st.button("Generate Lagrangian"):
        input_fields = " ".join(st.session_state.field_strings)
        if input_fields.strip() == "":
            st.write("Please add at least one field before generating the Lagrangian.")
            return

        # The model expects space-separated key/value tokens wrapped in
        # [SOS]/[EOS]; "=" is only used in the UI-side field strings.
        input_fields = input_fields.replace("=", " ")
        input_fields = "[SOS] " + input_fields + " [EOS]"
        generated_lagrangian = generate_lagrangian(input_fields)
        generated_lagrangian = reformat_expression(generated_lagrangian)
        print(generated_lagrangian)  # server-side log of the raw output

        # Render; drop the leading character before parsing (the output
        # starts with a spurious token separator).
        latex_output = master_str_tex(generated_lagrangian[1:])
        st.latex(latex_output)

    st.markdown("### Contact")
    st.markdown("For questions/suggestions, email us: [Eliel](mailto:[email protected]) or [Yong Sheng](mailto:[email protected]).")

if __name__ == "__main__":
    main()