# # To Run:
# # python -m dsp.modules.hf_server --port 4242 --model "google/flan-t5-base"
# # To Query:
# # curl -d '{"prompt":".."}' -X POST "http://0.0.0.0:4242" -H 'Content-Type: application/json'
# # Or use the HF client. TODO: Add support for kwargs to the server.
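# # For example, a minimal Python query using the requests library (an illustrative
# # sketch, not the HF client itself; the exact response fields depend on the loaded model):
# #   import requests
# #   resp = requests.post("http://0.0.0.0:4242", json={"prompt": "Translate to German: Hello"})
# #   print(resp.json())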
# from functools import lru_cache
# import argparse
# import time
# import random
# import os
# import sys
# import uvicorn
# import warnings
# from fastapi import FastAPI
# from pydantic import BaseModel
# from argparse import ArgumentParser
# from starlette.middleware.cors import CORSMiddleware
# from dsp.modules.hf import HFModel
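# # Request body: a prompt string plus optional generation kwargs forwarded to the model.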
# class Query(BaseModel):
#     prompt: str
#     kwargs: dict = {}
# warnings.filterwarnings("ignore")
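# # FastAPI app with fully permissive CORS so any origin can query the server.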
# app = FastAPI()
# app.add_middleware(
#     CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
# )
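# # Command-line arguments: the port to serve on and the Hugging Face model to load.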
# parser = argparse.ArgumentParser(description="Server for Hugging Face models")
# parser.add_argument("--port", type=int, required=True, help="Server port")
# parser.add_argument("--model", type=str, required=True, help="Hugging Face model")
# args = parser.parse_args()
# # TODO: Convert this to a log message
# print(f"#> Loading the language model {args.model}")
# lm = HFModel(args.model)
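# # Cache responses so repeated calls with the same prompt and kwargs skip regeneration.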
# @lru_cache(maxsize=None)
# def generate(prompt, **kwargs):
#     global lm
#     generateStart = time.time()
#     # TODO: Convert this to a log message
#     print(f'#> kwargs: "{kwargs}" (type={type(kwargs)})')
#     response = lm._generate(prompt, **kwargs)
#     # TODO: Convert this to a log message
#     print(f'#> Response: "{response}"')
#     latency = (time.time() - generateStart) * 1000.0
#     response["latency"] = latency
#     print(f"#> Latency: {latency / 1000.0:.3f} seconds")
#     return response
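# # POST / accepts a JSON body matching Query and returns the generated response.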
# @app.post("/")
# async def generate_post(query: Query):
#     return generate(query.prompt, **query.kwargs)
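# # Serve the app with uvicorn when run as a script (see the run command above).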
# if __name__ == "__main__":
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=args.port,
#         reload=False,
#         log_level="info",
#     )  # can make reload=True later