Other
English
minecraft
action prediction
Kqte commited on
Commit
f7aaeb6
·
verified ·
1 Parent(s): df37595

Upload output_formatter.py

Browse files
Files changed (1) hide show
  1. evaluation/output_formatter.py +93 -0
evaluation/output_formatter.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ formats output text file from llampia
3
+ into markdown tables
4
+ """
5
+
6
+ import os
7
+
8
+ def format_edu(turn_str):
9
+ no = turn_str.split('<')[0].strip()
10
+ text = turn_str.split('>')[1].strip()
11
+ speaker = 'BUIL'
12
+ if 'Arch' in turn_str.split('>')[0]:
13
+ speaker = 'ARCH'
14
+ edu_str = ' ' + '**' + no + ' ' + speaker + '**' + ' ' + text
15
+ return edu_str
16
+
17
+ current_folder=os.getcwd()
18
+
19
+ output_path = '/path/to/test-output-file.txt'
20
+ table_path = current_folder + '/test_predictions.md'
21
+
22
+ with open(output_path, 'r') as txt:
23
+ text = txt.read().split('\n')
24
+
25
+ group = 0
26
+ pairs = []
27
+ pair = []
28
+ for t in text:
29
+ if t.startswith('New Turn'):
30
+ group = 1
31
+ pair.append(t)
32
+ elif t.startswith(' ### DS:'):
33
+ pair.append(t)
34
+ pairs.append(pair)
35
+ pair = []
36
+ group = 0
37
+ elif group == 1:
38
+ pair.append(t)
39
+
40
+ tables = []
41
+ rows = []
42
+ dial_str = ''
43
+ struct_str = ''
44
+ for pair in pairs:
45
+ #print(pair)
46
+ if pair[0].startswith('New Turn: 1 <'):
47
+ if len(rows) > 0:
48
+ tables.append(rows)
49
+ rows = []
50
+ dial_str = '**0 BUIL** Mission has started.'
51
+ dial_str += format_edu(pair[0].lstrip('New Turn: '))
52
+ for p in pair[1:]:
53
+ if p.startswith(' ###'):
54
+ struct_str = p.lstrip(' ### DS:')
55
+ rows.append([dial_str, struct_str])
56
+ # dial_str = ''
57
+ # struct_str = ''
58
+ else:
59
+ dial_str += format_edu(p)
60
+ elif pair[0].startswith('New Turn: '):
61
+ dial_str = format_edu(pair[0].lstrip('New Turn: '))
62
+ #print(dial_str)
63
+ for p in pair[1:]:
64
+ if p.startswith(' ###'):
65
+ struct_str = p.lstrip(' ### DS:')
66
+ # print(struct_str)
67
+ # print(dial_str)
68
+ rows.append([dial_str, struct_str])
69
+ # dial_str = ''
70
+ # struct_str = ''
71
+ else:
72
+ dial_str += format_edu(p)
73
+
74
+ all_md_tables = []
75
+
76
+ for table in tables:
77
+ table_rows = ['| Dialogue | Structure |', '| ----- | ----- |']
78
+ for tr in table:
79
+ st = '| ' + tr[0] + ' | ' + tr[1] + ' |'
80
+ table_rows.append(st)
81
+ all_md_tables.extend(table_rows)
82
+ all_md_tables.extend(' ')
83
+
84
+ md_tables_string = '\n'.join(all_md_tables)
85
+
86
+
87
+ f = open(table_path, 'w')
88
+ for r in all_md_tables:
89
+ print(r)
90
+ print(r, file=f)
91
+ print("markdown printed")
92
+
93
+