# # To Run:
# # python -m dsp.modules.hf_server --port 4242 --model "google/flan-t5-base"

# # To Query:
# # curl -d '{"prompt":".."}' -X POST "http://0.0.0.0:4242" -H 'Content-Type: application/json'
# # Or use the HF client. TODO: Add support for kwargs to the server.
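# # Or, from Python (a minimal sketch; assumes the server above is running on port 4242
# # and that the `requests` package is installed; whether individual kwargs such as
# # max_new_tokens are honored depends on what HFModel._generate forwards to the model):
# #
# #     import requests
# #
# #     resp = requests.post(
# #         "http://0.0.0.0:4242",
# #         json={"prompt": "Translate to French: Hello!", "kwargs": {"max_new_tokens": 64}},
# #         timeout=60,
# #     )
# #     print(resp.json())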


# from functools import lru_cache
# import argparse
# import time
# import warnings

# import uvicorn
# from fastapi import FastAPI
# from pydantic import BaseModel
# from starlette.middleware.cors import CORSMiddleware

# from dsp.modules.hf import HFModel


# # Request schema: a prompt plus optional generation kwargs forwarded to the model.
# class Query(BaseModel):
#     prompt: str
#     kwargs: dict = {}


# warnings.filterwarnings("ignore")

# app = FastAPI()
# app.add_middleware(
#     CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
# )

# parser = argparse.ArgumentParser(description="Server for Hugging Face models")
# parser.add_argument("--port", type=int, required=True, help="Server port")
# parser.add_argument("--model", type=str, required=True, help="Hugging Face model")
# args = parser.parse_args()
# # TODO: Convert this to a log message
# print(f"#> Loading the language model {args.model}")
# lm = HFModel(args.model)


# # Cache responses by (prompt, kwargs); lru_cache requires kwarg values to be hashable.
# @lru_cache(maxsize=None)
# def generate(prompt, **kwargs):
#     generate_start = time.time()
#     # TODO: Convert this to a log message
#     print(f'#> kwargs: "{kwargs}" (type={type(kwargs)})')
#     response = lm._generate(prompt, **kwargs)
#     # TODO: Convert this to a log message
#     print(f'#> Response: "{response}"')
#     latency = (time.time() - generate_start) * 1000.0
#     response["latency"] = latency
#     print(f"#> Latency: {latency / 1000.0:.3f} seconds")
#     return response


# # POST /: run generation for the given prompt and kwargs.
# @app.post("/")
# async def generate_post(query: Query):
#     return generate(query.prompt, **query.kwargs)


# if __name__ == "__main__":
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=args.port,
#         reload=False,
#         log_level="info",
#     )  # reload=True would also require passing the app as an import string ("module:app")