Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import base64 | |
import json | |
import os | |
from cosmos_predict1.utils import log | |
from cosmos_predict1.utils.validator import JsonDict, Validator | |
""" | |
Base class for parsing environment variables using validators. | |
Class will go through its list of validators and retrieve values from same named environment variables. | |
Validators provide: | |
- default value | |
- typed parsing | |
- enforments of mandatory values | |
Additionally the environment variables can be passed as single base64 encoded string. | |
we cannot enforce that a component isn't directly using the environment variables. | |
so evaluation of params should throw error to make sure actual env var is correct. | |
""" | |
class EnvParser:
    """Populate typed settings from environment variables or a serialized copy.

    Subclasses declare class attributes that are ``Validator`` instances. The
    attribute ``foo`` is read from the environment variable ``FOO``. Values can
    alternatively be loaded from a JSON file (``from_json``) or from a
    base64-encoded JSON string (``from_b64``), e.g. one produced by ``to_b64``.
    After loading, every validator without a default must have received a
    value, otherwise ``ValueError`` is raised.
    """

    def __init__(self, b64_str=None):
        """Load values from ``b64_str`` when it is non-empty, otherwise from the environment."""
        if b64_str:
            # The payload may carry credentials; do not echo it at a high log level.
            log.debug("b64_str received, parsing env params from it")
            self.from_b64(b64_str)
        else:
            self.from_env()

    def from_env(self):
        """Read each validator's value from the same-named, upper-cased env var.

        Unset or empty variables are skipped so the validator keeps its default.

        Raises:
            ValueError: if a mandatory value is still missing afterwards.
        """
        for key in self.get_val_dict():
            env_name = key.upper()
            val = os.getenv(env_name)
            log.debug(f"getting env var {env_name}: {val}")
            if val:
                setattr(self, key, val)
        self.check_mandatory_values()

    def _apply_params(self, params):
        """Assign a mapping of serialized params to attributes, then validate."""
        # to_json() writes upper-case keys (the env-var spelling), while the
        # validators are attributes named by their declared key. Match
        # case-insensitively so a to_json/from_json round trip works; unknown
        # keys are still set verbatim, as before.
        names = {key.upper(): key for key in self.get_val_dict()}
        for key, value in params.items():
            setattr(self, names.get(key.upper(), key), value)
        self.check_mandatory_values()

    def from_json(self, file_name):
        """Load values from a JSON object stored in ``file_name``.

        Raises:
            ValueError: if a mandatory value is missing after loading.
        """
        with open(file_name, "r", encoding="utf-8") as f:
            log.info(f"Reading env params from {file_name}")
            params = json.load(f)
        self._apply_params(params)

    def to_b64(self):
        """Return the current values as a base64-encoded JSON string (also printed)."""
        json_str = self.to_json()
        # b64encode needs a bytes-like object.
        b64_str = base64.b64encode(json_str.encode()).decode()
        print(b64_str)
        return b64_str

    def from_b64(self, b64_str):
        """Load values from a base64-encoded JSON object.

        Raises:
            ValueError: if a mandatory value is missing after loading.
        """
        json_str = base64.b64decode(b64_str).decode()
        self._apply_params(json.loads(json_str))

    def check_mandatory_values(self):
        """Raise ``ValueError`` for any validator with neither a value nor a default."""
        for key, validator in self.get_val_dict().items():
            if getattr(self, key) is None and validator.default is None:
                raise ValueError(f"Missing mandatory env var: {key}")

    @classmethod
    def get_val_dict(cls):
        """Return ``{attribute_name: Validator}`` for this class and its bases.

        Bases are visited first, so a subclass redefinition wins. A subclass
        attribute that is not a Validator hides an inherited one of the same
        name, mirroring normal attribute lookup.
        """
        log.debug(f"getting val dict of {cls.__name__}")
        val_dict = {}
        for klass in reversed(cls.__mro__):
            for key, value in vars(klass).items():
                if isinstance(value, Validator):
                    val_dict[key] = value
                else:
                    val_dict.pop(key, None)
        return val_dict

    def dump_validators(self):
        """Log every validator's current value at debug level."""
        for key in self.get_val_dict():
            log.debug(f"{key}: {getattr(self, key)}")

    def to_json(self, file_name=None):
        """Serialize current values to JSON with upper-cased keys.

        The string is printed, optionally written to ``file_name``, and returned.
        """
        params = {key.upper(): getattr(self, key) for key in self.get_val_dict()}
        json_str = json.dumps(params, indent=4)
        print(json_str)
        if file_name:
            with open(file_name, "w", encoding="utf-8") as f:
                log.info(f"Writing env params to {file_name}")
                f.write(json_str)
        return json_str

    def to_string_dict(self):
        """Return ``{attribute_name: str}`` with defaults applied.

        ``JsonDict`` values are JSON-encoded; all others go through ``str``.
        """
        result = {}
        for key, validator in self.get_val_dict().items():
            value = getattr(self, key)
            if value is None:
                value = validator.default
            result[key] = json.dumps(value) if isinstance(validator, JsonDict) else str(value)
        return result

    def __str__(self):
        return ", ".join(f"{key}={value}" for key, value in self.__dict__.items())