# # To Run:
# # python -m dsp.modules.hf_server --port 4242 --model "google/flan-t5-base"
# # To Query:
# # curl -d '{"prompt":".."}' -X POST "http://0.0.0.0:4242" -H 'Content-Type: application/json'
# # Or use the HF client. TODO: Add support for kwargs to the server.
# from functools import lru_cache
# import argparse
# import time
# import random
# import os
# import sys
# import uvicorn
# import warnings
# from fastapi import FastAPI
# from pydantic import BaseModel
# from argparse import ArgumentParser
# from starlette.middleware.cors import CORSMiddleware
# from dsp.modules.hf import HFModel
# class Query(BaseModel):
#     prompt: str
#     kwargs: dict = {}
# warnings.filterwarnings("ignore")
# app = FastAPI()
# app.add_middleware(
#     CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
# )
# parser = argparse.ArgumentParser("Server for Hugging Face models")
# parser.add_argument("--port", type=int, required=True, help="Server port")
# parser.add_argument("--model", type=str, required=True, help="Hugging Face model")
# args = parser.parse_args()
# # TODO: Convert this to a log message
# print(f"#> Loading the language model {args.model}")
# lm = HFModel(args.model)
# @lru_cache(maxsize=None)
# def generate(prompt, **kwargs):
#     global lm
#     generateStart = time.time()
#     # TODO: Convert this to a log message
#     print(f'#> kwargs: "{kwargs}" (type={type(kwargs)})')
#     response = lm._generate(prompt, **kwargs)
#     # TODO: Convert this to a log message
#     print(f'#> Response: "{response}"')
#     latency = (time.time() - generateStart) * 1000.0
#     response["latency"] = latency
#     print(f'#> Latency:', '{:.3f}'.format(latency / 1000.0), 'seconds')
#     return response
# @app.post("/")
# async def generate_post(query: Query):
#     return generate(query.prompt, **query.kwargs)
# if __name__ == "__main__":
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=args.port,
#         reload=False,
#         log_level="info",
#     )  # can make reload=True later