Update src/main.py
src/main.py (+38, -7)
@@ -90,23 +90,49 @@ def notify_success(project_id: str):
         description=message,
         token=HF_ACCESS_TOKEN,
     )
-
+
+def notify_url(url: str):
+    message = URL_TEMPLATE.format(
+        url=url,
+    )
+    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
+        repo_id=config.input_dataset,
+        repo_type="dataset",
+        title="✨ Endpoint is ready!",
+        description=message,
+        token=HF_ACCESS_TOKEN,
+    )
+
 def deploy_model(id: str):
+    api = HfApi(token=HF_ACCESS_TOKEN)
     url = "https://api.endpoints.huggingface.cloud/v2/endpoint/Platma"
     data = {"compute": {"accelerator": "gpu", "instanceSize": "x1", "instanceType": "nvidia-l4",
-                        "scaling": {"maxReplica": 1, "minReplica": 1, "scaleToZeroTimeout":15}},
+                        "scaling": {"maxReplica": 1, "minReplica": 1, "scaleToZeroTimeout": 15}},
             "model": {"framework": "pytorch", "image": {
                 "custom": {"health_route": "/health",
                            "url": "ghcr.io/huggingface/text-generation-inference:sha-f852190",
-                           "env": {"MAX_BATCH_PREFILL_TOKENS": "2048", "MAX_INPUT_LENGTH": "…
-                                   "MAX_TOTAL_TOKENS": "…
+                           "env": {"MAX_BATCH_PREFILL_TOKENS": "2048", "MAX_INPUT_LENGTH": "2048",
+                                   "MAX_TOTAL_TOKENS": "2512",
                                    "MODEL_ID": "/repository"}}},
                 "repository": f"Platma/{id}",
                 "secrets": {},
                 "task": "text-generation"},
-            "name": "…
-    headers = {"Authorization": f"Bearer {HF_ACCESS_TOKEN}"}
-    r = requests.post(url,…
+            "name": f"platma-{id}", "provider": {"region": "us-east-1", "vendor": "aws"}, "type": "protected"}
+    headers = {"Authorization": f"Bearer {HF_ACCESS_TOKEN}", "Content-Type": "application/json"}
+    r = requests.post(url, json=data, headers=headers)
+    print(r)
+    r = api.get_inference_endpoint(name=f"platma-{id}")
+    while True:
+        print("Fetching url")
+        if r.status == 'running':
+            print(r)
+            notify_url(r.url)
+            break
+        else:
+            if r.status == 'error':
+                break
+            time.sleep(10)
+            r = api.get_inference_endpoint(name=f"platma-{id}")
     print(r)
 
 NOTIFICATION_TEMPLATE = """\
@@ -117,5 +143,10 @@ Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_
 (This is an automated message)
 """
 
+URL_TEMPLATE = """\
+Here is your endpoint: {url}
+(This is an automated message)
+"""
+
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
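For reference, the raw requests.post call plus the hand-rolled polling loop in this commit can also be written against huggingface_hub's Inference Endpoints client. The sketch below is not part of the commit: it assumes a recent huggingface_hub (>= 0.20, which ships create_inference_endpoint and InferenceEndpoint.wait), copies the instance, image, and env values from the diff, targets the "Platma" namespace taken from the URL in the diff, and leaves the endpoint type at its default ("protected").

# Sketch only (not the committed code): the same deployment expressed with
# huggingface_hub's Inference Endpoints client, assuming huggingface_hub >= 0.20.
from huggingface_hub import create_inference_endpoint

def deploy_model_sketch(id: str, token: str) -> str:
    # Wraps the same REST API that the diff calls directly with requests.post.
    endpoint = create_inference_endpoint(
        name=f"platma-{id}",           # same naming scheme as the commit
        namespace="Platma",            # the org from the /v2/endpoint/Platma URL
        repository=f"Platma/{id}",
        framework="pytorch",
        task="text-generation",
        accelerator="gpu",
        instance_size="x1",
        instance_type="nvidia-l4",
        region="us-east-1",
        vendor="aws",
        min_replica=1,
        max_replica=1,
        scale_to_zero_timeout=15,
        custom_image={
            "health_route": "/health",
            "url": "ghcr.io/huggingface/text-generation-inference:sha-f852190",
            "env": {
                "MAX_BATCH_PREFILL_TOKENS": "2048",
                "MAX_INPUT_LENGTH": "2048",
                "MAX_TOTAL_TOKENS": "2512",
                "MODEL_ID": "/repository",
            },
        },
        token=token,
    )
    # wait() re-fetches the endpoint on an interval until it reaches "running",
    # and raises on failure or timeout, replacing the manual while/time.sleep(10)
    # loop and the separate 'error' branch in the diff.
    endpoint.wait(timeout=600, refresh_every=10)
    return endpoint.url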