johnpaulbin commited on
Commit
9ecb07b
·
1 Parent(s): 49e7ae7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +371 -0
app.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ '''A native client simulating the plugin to use for testing the server'''
3
+ import asyncio
4
+ import itertools
5
+ import struct
6
+ import json
7
+ import time
8
+ import sys
9
+ import csv
10
+ from pathlib import Path
11
+ from pprint import pprint
12
+ from tqdm import tqdm
13
+
14
+
15
+ class Timer:
16
+ """Little helper class top measure runtime of async function calls and dump
17
+ all of those to a CSV.
18
+ """
19
+ def __init__(self):
20
+ self.measurements = []
21
+
22
+ async def measure(self, coro, *details):
23
+ start = time.perf_counter()
24
+ result = await coro
25
+ end = time.perf_counter()
26
+ self.measurements.append([end - start, *details])
27
+ return result
28
+
29
+ def dump(self, fh):
30
+ # TODO stats? For now I just export to Excel or something
31
+ writer = csv.writer(fh)
32
+ writer.writerows(self.measurements)
33
+
34
+
35
+ class Client:
36
+ """asyncio based native messaging client. Main interface is just calling
37
+ `request()` with the right parameters and awaiting the future it returns.
38
+ """
39
+ def __init__(self, *args):
40
+ self.serial = itertools.count(1)
41
+ self.futures = {}
42
+ self.args = args
43
+
44
+ async def __aenter__(self):
45
+ self.proc = await asyncio.create_subprocess_exec(*self.args, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE)
46
+ self.read_task = asyncio.create_task(self.reader())
47
+ return self
48
+
49
+ async def __aexit__(self, *args):
50
+ self.proc.stdin.close()
51
+ await self.proc.wait()
52
+
53
+ def request(self, command, data, *, update=lambda data: None):
54
+ message_id = next(self.serial)
55
+ message = json.dumps({"command": command, "id": message_id, "data": data}).encode()
56
+ # print(f"Sending: {message}", file=sys.stderr)
57
+ future = asyncio.get_running_loop().create_future()
58
+ self.futures[message_id] = future, update
59
+ self.proc.stdin.write(struct.pack("@I", len(message)))
60
+ self.proc.stdin.write(message)
61
+ return future
62
+
63
+ async def reader(self):
64
+ while True:
65
+ try:
66
+ raw_length = await self.proc.stdout.readexactly(4)
67
+ length = struct.unpack("@I", raw_length)[0]
68
+ raw_message = await self.proc.stdout.readexactly(length)
69
+
70
+ # print(f"Receiving: {raw_message.decode()}", file=sys.stderr)
71
+ message = json.loads(raw_message)
72
+
73
+ # Not cool if there is no response message "id" here
74
+ if not "id" in message:
75
+ continue
76
+
77
+ # print(f"Receiving response to {message['id']}", file=sys.stderr)
78
+ future, update = self.futures[message["id"]]
79
+
80
+ if "success" in message:
81
+ del self.futures[message["id"]]
82
+ if message["success"]:
83
+ future.set_result(message["data"])
84
+ else:
85
+ future.set_exception(Exception(message["error"]))
86
+ elif "update" in message:
87
+ update(message["data"])
88
+ except asyncio.IncompleteReadError:
89
+ break # Stop read loop if EOF is reached
90
+ except asyncio.CancelledError:
91
+ break # Also stop reading if we're cancelled
92
+
93
+
94
+ class TranslateLocally(Client):
95
+ """TranslateLocally wrapper around Client that translates
96
+ our defined messages into functions with arguments.
97
+ """
98
+ async def list_models(self, *, include_remote=False):
99
+ return await self.request("ListModels", {"includeRemote": bool(include_remote)})
100
+
101
+ async def translate(self, text, src=None, trg=None, *, model=None, pivot=None, html=False):
102
+ if src and trg:
103
+ if model or pivot:
104
+ raise InvalidArgumentException("Cannot combine src + trg and model + pivot arguments")
105
+ spec = {"src": str(src), "trg": str(trg)}
106
+ elif model:
107
+ if pivot:
108
+ spec = {"model": str(model), "pivot": str(pivot)}
109
+ else:
110
+ spec = {"model": str(model)}
111
+ else:
112
+ raise InvalidArgumentException("Missing src + trg or model argument")
113
+
114
+ result = await self.request("Translate", {**spec, "text": str(text), "html": bool(html)})
115
+ return result["target"]["text"]
116
+
117
+ async def download_model(self, model_id, *, update=lambda data: None):
118
+ return await self.request("DownloadModel", {"modelID": str(model_id)}, update=update)
119
+
120
+
121
+ def first(iterable, *default):
122
+ """Returns the first value of anything iterable, or throws StopIteration
123
+ if it is empty. Or, if you specify a default argument, it will return that.
124
+ """
125
+ return next(iter(iterable), *default) # passing as rest argument so it can be nothing and trigger StopIteration exception
126
+
127
+
128
+ def get_build():
129
+ """Instantiate an asyncio TranslateLocally client that connects to
130
+ tranlateLocally in your local build directory.
131
+ """
132
+ paths = [
133
+ Path("./translateLocally"),
134
+ Path(__file__).resolve().parent / Path("../build/translateLocally")
135
+ ];
136
+
137
+ for path in paths:
138
+ if path.exists():
139
+ return TranslateLocally(path.resolve(), "-p", "--debug")
140
+ raise RuntimeError("Could not find translateLocally binary")
141
+
142
+
143
+
144
+ async def download_with_progress(tl, model, position):
145
+ """tl.download but with a tqdm powered progress bar."""
146
+ with tqdm(position=position, desc=model["modelName"], unit="b", unit_scale=True, leave=False) as bar:
147
+ def update(data):
148
+ assert data["read"] <= data["size"]
149
+ bar.total = data["size"]
150
+ diff = data["read"] - bar.n
151
+ bar.update(diff)
152
+ return await tl.download_model(model["id"], update=update)
153
+
154
+
155
+ async def test():
156
+ """Test TranslateLocally functionality."""
157
+ async with get_build() as tl:
158
+ models = await tl.list_models(include_remote=True)
159
+ pprint(models)
160
+
161
+ # Models necessary for tests, both direct & pivot
162
+ necessary_models = {("en", "de"), ("en", "es"), ("es", "en")}
163
+
164
+ # From all models available, pick one for every necessary language pair
165
+ # (preferring tiny ones) so we can make sure these are downloaded.
166
+ selected_models = {
167
+ (src,trg): first(sorted(
168
+ (
169
+ model
170
+ for model in models
171
+ if src in model["srcTags"] and trg == model["trgTag"]
172
+ ),
173
+ key=lambda model: 0 if model["type"] == "tiny" else 1
174
+ ))
175
+ for src, trg in necessary_models
176
+ }
177
+
178
+ pprint(selected_models)
179
+
180
+ # Download them. Even if they're already model['local'] == True, to test
181
+ # that in that case this is a no-op.
182
+ await asyncio.gather(*(
183
+ download_with_progress(tl, model, position)
184
+ for position, model in enumerate(selected_models.values())
185
+ ))
186
+ print() # tqdm messes a lot with the print position, this makes it less bad
187
+
188
+ # Test whether the model list has been updated to reflect that the
189
+ # downloaded models are now local.
190
+ models = await tl.list_models(include_remote=True)
191
+ assert all(
192
+ model["local"]
193
+ for selected_model in selected_models.values()
194
+ for model in models
195
+ if model["id"] == selected_model["id"]
196
+ )
197
+
198
+ # Perform some translations, switching between the models
199
+ translations = await asyncio.gather(
200
+ tl.translate("Hello world!", "en", "de"),
201
+ tl.translate("Let's translate another sentence to German.", "en", "de"),
202
+ tl.translate("Sticks and stones may break my bones but words WILL NEVER HURT ME!", "en", "es"),
203
+ tl.translate("I <i>like</i> to drive my car. But I don't have one.", "en", "de", html=True),
204
+ tl.translate("¿Por qué no funciona bien?", "es", "de"),
205
+ tl.translate("This will be the last sentence of the day.", "en", "de"),
206
+ )
207
+
208
+ pprint(translations)
209
+
210
+ assert translations == [
211
+ "Hallo Welt!",
212
+ "Übersetzen wir einen weiteren Satz mit Deutsch.",
213
+ "Palos y piedras pueden romper mis huesos, pero las palabras NUNCA HURT ME.",
214
+ "Ich <i>fahre gerne</i> mein Auto. Aber ich habe keine.", #<i>fahre</i>???
215
+ "Warum funktioniert es nicht gut?",
216
+ "Dies wird der letzte Satz des Tages sein.",
217
+ ]
218
+
219
+ # Test bad input
220
+ try:
221
+ await tl.translate("This is impossible to translate", "en", "xx")
222
+ assert False, "How are we able to translate to 'xx'???"
223
+ except Exception as e:
224
+ assert "Could not find the necessary translation models" in str(e)
225
+
226
+ print("Fin")
227
+
228
+
229
+ async def test_third_party():
230
+ """Test whether TranslateLocally can switch between different types of
231
+ models. This test assumes you have the OPUS repository in your list:
232
+ https://object.pouta.csc.fi/OPUS-MT-models/app/models.json
233
+ """
234
+ async with get_build() as tl:
235
+ models_to_try = [
236
+ 'en-de-tiny',
237
+ 'en-de-base',
238
+ 'eng-fin-tiny', # model has broken model_info.json so won't work anyway :(
239
+ 'eng-ukr-tiny',
240
+ ]
241
+
242
+ models = await tl.list_models(include_remote=True)
243
+
244
+ # Select a model from the model list for each of models_to_try, but
245
+ # leave it out if there is no model available.
246
+ selected_models = {
247
+ shortname: model
248
+ for shortname in models_to_try
249
+ if (model := first((model for model in models if model["shortname"] == shortname), None))
250
+ }
251
+
252
+ await asyncio.gather(*(
253
+ download_with_progress(tl, model, position)
254
+ for position, model in enumerate(selected_models.values())
255
+ ))
256
+
257
+ # TODO: Temporary filter to figure out 'failed' downloads. eng-fin-tiny
258
+ # has a broken JSON file so it will download correctly, but still not
259
+ # be available or show up in this list. We should probably make the
260
+ # download fail in that scenario.
261
+ models = await tl.list_models(include_remote=False)
262
+ for shortname in list(selected_models.keys()):
263
+ if not any(True for model in models if model["shortname"] == shortname):
264
+ print(f"Skipping {shortname} because it didn't show up in model list after downloading", file=sys.stderr)
265
+ del selected_models[shortname]
266
+
267
+ translations = await asyncio.gather(*[
268
+ tl.translate("This is a very simple test sentence", model=model["id"])
269
+ for model in selected_models.values()
270
+ ])
271
+
272
+ pprint(list(zip(selected_models.keys(), translations)))
273
+
274
+
275
+ async def test_latency():
276
+ timer = Timer()
277
+
278
+ # Our line generator: just read Crime & Punishment from stdin :D
279
+ lines = (line.strip() for line in sys.stdin)
280
+
281
+ async with get_build() as tl:
282
+ for epoch in range(100):
283
+ print(f"Epoch {epoch}...", file=sys.stderr)
284
+ for batch_size in [1, 5, 10, 20, 50, 100]:
285
+ await asyncio.gather(*(
286
+ timer.measure(
287
+ tl.translate(line, "en", "de"),
288
+ epoch,
289
+ batch_size,
290
+ len(line.split(' ')))
291
+ for n, line in zip(range(batch_size), lines)
292
+ ))
293
+
294
+ timer.dump(sys.stdout)
295
+
296
+
297
+ async def test_concurrency():
298
+ async with get_build() as tl:
299
+ fetch_one = tl.list_models(include_remote=True)
300
+ fetch_two = tl.list_models(include_remote=False)
301
+ fetch_three = tl.list_models(include_remote=True)
302
+ await asyncio.gather(fetch_one, fetch_two, fetch_three)
303
+
304
+
305
+ async def test_shutdown():
306
+ tasks = []
307
+ async with get_build() as tl:
308
+ for n in range(10):
309
+ print(f"Requesting translation {n}")
310
+ tasks.append(tl.request("Translate", {
311
+ "src": "en",
312
+ "trg": "de",
313
+ "text": f"This is simple sentence number {n}!",
314
+ "html": False
315
+ }))
316
+ print("Shutting down")
317
+ print("Shutdown complete")
318
+ for translation in asyncio.as_completed(tasks):
319
+ print(await translation)
320
+ print("Fin.")
321
+
322
+
323
+ async def test_concurrent_download():
324
+ """Test parallel downloads."""
325
+ async with get_build() as tl:
326
+ models = await tl.list_models(include_remote=True)
327
+ remote = [model for model in models if not model["local"]]
328
+ downloads = [
329
+ tl.download_model(model["id"])
330
+ for model, _ in zip(remote, range(3))
331
+ ]
332
+ await asyncio.gather(*downloads)
333
+
334
+
335
+ @app.route('/list_models', methods=['GET'])
336
+ def list_models():
337
+ include_remote = request.args.get('include_remote', type=bool, default=False)
338
+ loop = asyncio.new_event_loop()
339
+ asyncio.set_event_loop(loop)
340
+ result = loop.run_until_complete(get_build().list_models(include_remote=include_remote))
341
+ return jsonify(result)
342
+
343
+ @app.route('/translate', methods=['POST'])
344
+ def translate():
345
+ data = request.get_json()
346
+ text = data.get('text')
347
+ src = data.get('src')
348
+ trg = data.get('trg')
349
+ model = data.get('model')
350
+ pivot = data.get('pivot')
351
+ html = data.get('html', False)
352
+
353
+ loop = asyncio.new_event_loop()
354
+ asyncio.set_event_loop(loop)
355
+ result = loop.run_until_complete(get_build().translate(text, src, trg, model=model, pivot=pivot, html=html))
356
+ return jsonify(result)
357
+
358
+ @app.route('/download_model', methods=['POST'])
359
+ def download_model():
360
+ data = request.get_json()
361
+ model_id = data.get('model_id')
362
+
363
+ loop = asyncio.new_event_loop()
364
+ asyncio.set_event_loop(loop)
365
+ result = loop.run_until_complete(get_build().download_model(model_id))
366
+ return jsonify(result)
367
+
368
+ # More endpoints for other methods
369
+
370
+ if __name__ == '__main__':
371
+ app.run(host='0.0.0.0', port=7860)