Spaces:
Build error
Build error
Commit
·
231b875
1
Parent(s):
b81532f
Add vizdoom
Browse files
app.py
CHANGED
|
@@ -39,6 +39,23 @@ def get_user_models(hf_username, env_tag, lib_tag):
|
|
| 39 |
return user_model_ids
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def get_metadata(model_id):
|
| 43 |
"""
|
| 44 |
Get model metadata (contains evaluation data)
|
|
@@ -225,18 +242,22 @@ def certification(hf_username, first_name, last_name):
|
|
| 225 |
},
|
| 226 |
{
|
| 227 |
"unit": "Unit 8 PII",
|
| 228 |
-
"env": "
|
| 229 |
-
"library": "
|
| 230 |
-
"min_result":
|
| 231 |
"best_result": 0,
|
| 232 |
"best_model_id": "",
|
| 233 |
"passed_": False
|
| 234 |
},
|
| 235 |
]
|
| 236 |
for unit in results_certification:
|
|
|
|
| 237 |
# Get user model
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
| 240 |
# Calculate the best result and get the best_model_id
|
| 241 |
best_result, best_model_id = calculate_best_result(user_models)
|
| 242 |
|
|
|
|
| 39 |
return user_model_ids
|
| 40 |
|
| 41 |
|
| 42 |
+
def get_user_sf_models(hf_username, env_tag, lib_tag):
|
| 43 |
+
models_sf = []
|
| 44 |
+
models = api.list_models(author=hf_username, filter=["reinforcement-learning", lib_tag])
|
| 45 |
+
|
| 46 |
+
user_model_ids = [x.modelId for x in models]
|
| 47 |
+
|
| 48 |
+
for model in user_model_ids:
|
| 49 |
+
meta = get_metadata(model)
|
| 50 |
+
if meta is None:
|
| 51 |
+
continue
|
| 52 |
+
result = meta["model-index"][0]["results"][0]["dataset"]["name"]
|
| 53 |
+
if result == env_tag:
|
| 54 |
+
models_sf.append(model)
|
| 55 |
+
|
| 56 |
+
return models_sf
|
| 57 |
+
|
| 58 |
+
|
| 59 |
def get_metadata(model_id):
|
| 60 |
"""
|
| 61 |
Get model metadata (contains evaluation data)
|
|
|
|
| 242 |
},
|
| 243 |
{
|
| 244 |
"unit": "Unit 8 PII",
|
| 245 |
+
"env": "doom_health_gathering_supreme",
|
| 246 |
+
"library": "sample-factory",
|
| 247 |
+
"min_result": 5,
|
| 248 |
"best_result": 0,
|
| 249 |
"best_model_id": "",
|
| 250 |
"passed_": False
|
| 251 |
},
|
| 252 |
]
|
| 253 |
for unit in results_certification:
|
| 254 |
+
if unit["unit"] != "Unit 8 PII":
|
| 255 |
# Get user model
|
| 256 |
+
user_models = get_user_models(hf_username, unit['env'], unit['library'])
|
| 257 |
+
# For sample factory vizdoom we don't have env tag for now
|
| 258 |
+
else:
|
| 259 |
+
user_models = get_user_sf_models(hf_username, unit['env'], unit['library'])
|
| 260 |
+
|
| 261 |
# Calculate the best result and get the best_model_id
|
| 262 |
best_result, best_model_id = calculate_best_result(user_models)
|
| 263 |
|