Platma committed on
Commit
7794767
·
verified ·
1 Parent(s): 89ca032

Update src/main.py

Browse files
Files changed (1) hide show
  1. src/main.py +38 -7
src/main.py CHANGED
@@ -90,23 +90,49 @@ def notify_success(project_id: str):
90
  description=message,
91
  token=HF_ACCESS_TOKEN,
92
  )
93
-
 
 
 
 
 
 
 
 
 
 
 
 
94
  def deploy_model(id: str):
 
95
  url = "https://api.endpoints.huggingface.cloud/v2/endpoint/Platma"
96
  data = {"compute": {"accelerator": "gpu", "instanceSize": "x1", "instanceType": "nvidia-l4",
97
- "scaling": {"maxReplica": 1, "minReplica": 1, "scaleToZeroTimeout":15}},
98
  "model": {"framework": "pytorch", "image": {
99
  "custom": {"health_route": "/health",
100
  "url": "ghcr.io/huggingface/text-generation-inference:sha-f852190",
101
- "env": {"MAX_BATCH_PREFILL_TOKENS": "2048", "MAX_INPUT_LENGTH": "1024",
102
- "MAX_TOTAL_TOKENS": "1512",
103
  "MODEL_ID": "/repository"}}},
104
  "repository": f"Platma/{id}",
105
  "secrets": {},
106
  "task": "text-generation"},
107
- "name": "1726061674-dip", "provider": {"region": "us-east-1", "vendor": "aws"}, "type": "protected"}
108
- headers = {"Authorization": f"Bearer {HF_ACCESS_TOKEN}"}
109
- r = requests.post(url, data=data, headers=headers)
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  print(r)
111
 
112
  NOTIFICATION_TEMPLATE = """\
@@ -117,5 +143,10 @@ Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_
117
  (This is an automated message)
118
  """
119
 
 
 
 
 
 
120
  if __name__ == "__main__":
121
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
90
  description=message,
91
  token=HF_ACCESS_TOKEN,
92
  )
93
+
94
def notify_url(url: str):
    """Announce a ready endpoint by opening a discussion on the input dataset.

    The discussion body is rendered from ``URL_TEMPLATE`` with the given
    endpoint *url*. Returns whatever ``HfApi.create_discussion`` returns.
    """
    body = URL_TEMPLATE.format(url=url)
    hub = HfApi(token=HF_ACCESS_TOKEN)
    return hub.create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Endpoint is ready!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
105
+
106
def deploy_model(id: str):
    """Create a protected GPU inference endpoint for ``Platma/{id}`` and wait
    until it is running, then notify via :func:`notify_url`.

    The endpoint is created through the HF Inference Endpoints REST API and
    then polled with ``HfApi.get_inference_endpoint`` every 10 seconds until
    it reaches a terminal state (``running`` or ``error``), or a poll budget
    is exhausted.

    :param id: model repository name under the ``Platma`` namespace
              (also used to derive the endpoint name ``platma-{id}``).
    :raises requests.HTTPError: if the creation request is rejected.
    """
    api = HfApi(token=HF_ACCESS_TOKEN)
    endpoint_name = f"platma-{id}"  # used for creation and for every poll
    create_url = "https://api.endpoints.huggingface.cloud/v2/endpoint/Platma"
    data = {"compute": {"accelerator": "gpu", "instanceSize": "x1", "instanceType": "nvidia-l4",
                        "scaling": {"maxReplica": 1, "minReplica": 1, "scaleToZeroTimeout": 15}},
            "model": {"framework": "pytorch", "image": {
                          "custom": {"health_route": "/health",
                                     "url": "ghcr.io/huggingface/text-generation-inference:sha-f852190",
                                     "env": {"MAX_BATCH_PREFILL_TOKENS": "2048", "MAX_INPUT_LENGTH": "2048",
                                             "MAX_TOTAL_TOKENS": "2512",
                                             "MODEL_ID": "/repository"}}},
                      "repository": f"Platma/{id}",
                      "secrets": {},
                      "task": "text-generation"},
            "name": endpoint_name, "provider": {"region": "us-east-1", "vendor": "aws"},
            "type": "protected"}
    headers = {"Authorization": f"Bearer {HF_ACCESS_TOKEN}", "Content-Type": "application/json"}
    r = requests.post(create_url, json=data, headers=headers)
    print(r)
    # Fail fast if creation was rejected; polling a non-existent endpoint
    # would otherwise fail later with a much less useful error.
    r.raise_for_status()

    # Bounded poll loop so a stuck (never-running, never-error) endpoint
    # cannot hang this worker forever: 360 polls * 10 s ≈ 1 hour.
    endpoint = None
    for _ in range(360):
        print("Fetching url")
        endpoint = api.get_inference_endpoint(name=endpoint_name)
        if endpoint.status == 'running':
            print(endpoint)
            notify_url(endpoint.url)
            break
        if endpoint.status == 'error':
            # Surface the failure instead of breaking silently.
            print(f"Endpoint {endpoint_name} entered 'error' state; giving up")
            break
        time.sleep(10)
    print(endpoint)
137
 
138
  NOTIFICATION_TEMPLATE = """\
 
143
  (This is an automated message)
144
  """
145
 
146
# Discussion body used by notify_url when an endpoint becomes ready;
# {url} is filled in with the endpoint's public URL.
URL_TEMPLATE = """\
Here is your endpoint: {url}
(This is an automated message)
"""
+
151
# Entry point: serve the app on all interfaces at port 8000.
# NOTE(review): `app` and `uvicorn` are assumed to be defined/imported
# earlier in the file (not visible in this chunk) — confirm.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)