Spaces:
Configuration error
Configuration error
import json | |
import os | |
import yaml | |
import requests | |
import pathlib | |
from aiohttp import web | |
from server import PromptServer | |
from .image import tensor2pil, pil2tensor, image2base64, pil2byte | |
from .log import log_node_error | |
root_path = pathlib.Path(__file__).parent.parent.parent | |
config_path = os.path.join(root_path,'config.yaml') | |
default_key = [{'name':'Default', 'key':''}] | |
class StabilityAPI: | |
def __init__(self): | |
self.api_url = "https://api.stability.ai" | |
self.api_keys = None | |
self.api_current = 0 | |
self.user_info = {} | |
self.getAPIKeys() | |
def getErrors(self, code): | |
errors = { | |
400: "Bad Request", | |
403: "ApiKey Forbidden", | |
413: "Your request was larger than 10MiB.", | |
429: "You have made more than 150 requests in 10 seconds.", | |
500: "Internal Server Error", | |
} | |
return errors.get(code, "Unknown Error") | |
def getAPIKeys(self): | |
if os.path.isfile(config_path): | |
with open(config_path, 'r') as f: | |
data = yaml.load(f, Loader=yaml.FullLoader) | |
if not data: | |
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0} | |
with open(config_path, 'w') as f: | |
yaml.dump(data, f) | |
if 'STABILITY_API_KEY' not in data: | |
data['STABILITY_API_KEY'] = default_key | |
data['STABILITY_API_DEFAULT'] = 0 | |
with open(config_path, 'w') as f: | |
yaml.dump(data, f) | |
api_keys = data['STABILITY_API_KEY'] | |
self.api_current = data['STABILITY_API_DEFAULT'] | |
self.api_keys = api_keys | |
return api_keys | |
else: | |
# create a yaml file | |
with open(config_path, 'w') as f: | |
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0} | |
yaml.dump(data, f) | |
return data['STABILITY_API_KEY'] | |
pass | |
def setAPIKeys(self, api_keys): | |
if len(api_keys) > 0: | |
self.api_keys = api_keys | |
# load and save the yaml file | |
with open(config_path, 'r') as f: | |
data = yaml.load(f, Loader=yaml.FullLoader) | |
data['STABILITY_API_KEY'] = api_keys | |
with open(config_path, 'w') as f: | |
yaml.dump(data, f) | |
return True | |
def setAPIDefault(self, current): | |
if current is not None: | |
self.api_current = current | |
# load and save the yaml file | |
with open(config_path, 'r') as f: | |
data = yaml.load(f, Loader=yaml.FullLoader) | |
data['STABILITY_API_DEFAULT'] = current | |
with open(config_path, 'w') as f: | |
yaml.dump(data, f) | |
return True | |
def generate_sd3_image(self, prompt, negative_prompt, aspect_ratio, model, seed, mode='text-to-image', image=None, strength=1, output_format='png', node_name='easy stableDiffusion3API'): | |
url = f"{self.api_url}/v2beta/stable-image/generate/sd3" | |
api_key = self.api_keys[self.api_current]['key'] | |
files = None | |
data = { | |
"prompt": prompt, | |
"mode": mode, | |
"model": model, | |
"seed": seed, | |
"output_format": output_format, | |
} | |
if model == 'sd3': | |
data['negative_prompt'] = negative_prompt | |
if mode == 'text-to-image': | |
files = {"none": ''} | |
data['aspect_ratio'] = aspect_ratio | |
elif mode == 'image-to-image': | |
pil_image = tensor2pil(image) | |
image_byte = pil2byte(pil_image) | |
files = {"image": ("output.png", image_byte, 'image/png')} | |
data['strength'] = strength | |
response = requests.post(url, | |
headers={"authorization": f"{api_key}", "accept": "application/json"}, | |
files=files, | |
data=data, | |
) | |
if response.status_code == 200: | |
PromptServer.instance.send_sync('stable-diffusion-api-generate-succeed',{"model":model}) | |
json_data = response.json() | |
image_base64 = json_data['image'] | |
image_data = image2base64(image_base64) | |
output_t = pil2tensor(image_data) | |
return output_t | |
else: | |
if 'application/json' in response.headers['Content-Type']: | |
error_info = response.json() | |
log_node_error(node_name, error_info.get('name', 'No name provided')) | |
log_node_error(node_name, error_info.get('errors', ['No details provided'])) | |
error_status_text = self.getErrors(response.status_code) | |
PromptServer.instance.send_sync('easyuse-toast',{"type": "error", "content": error_status_text}) | |
raise Exception(f"Failed to generate image: {error_status_text}") | |
# get user account | |
async def getUserAccount(self, cache=True): | |
url = f"{self.api_url}/v1/user/account" | |
api_key = self.api_keys[self.api_current]['key'] | |
name = self.api_keys[self.api_current]['name'] | |
if cache and name in self.user_info: | |
return self.user_info[name] | |
else: | |
response = requests.get(url, headers={"Authorization": f"Bearer {api_key}"}) | |
if response.status_code == 200: | |
user_info = response.json() | |
self.user_info[name] = user_info | |
return user_info | |
else: | |
PromptServer.instance.send_sync('easyuse-toast',{'type': 'error', 'content': self.getErrors(response.status_code)}) | |
return None | |
# get user balance | |
async def getUserBalance(self): | |
url = f"{self.api_url}/v1/user/balance" | |
api_key = self.api_keys[self.api_current]['key'] | |
response = requests.get(url, headers={ | |
"Authorization": f"Bearer {api_key}" | |
}) | |
if response.status_code == 200: | |
return response.json() | |
else: | |
PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)}) | |
return None | |
stableAPI = StabilityAPI() | |
async def get_stability_api_keys(request): | |
stableAPI.getAPIKeys() | |
return web.json_response({"keys": stableAPI.api_keys, "current": stableAPI.api_current}) | |
async def set_stability_api_keys(request): | |
post = await request.post() | |
api_keys = post.get("api_keys") | |
current = post.get('current') | |
if api_keys is not None: | |
api_keys = json.loads(api_keys) | |
stableAPI.setAPIKeys(api_keys) | |
if current is not None: | |
print(current) | |
stableAPI.setAPIDefault(int(current)) | |
account = await stableAPI.getUserAccount() | |
balance = await stableAPI.getUserBalance() | |
return web.json_response({'account': account, 'balance': balance}) | |
else: | |
return web.json_response({'status': 'ok'}) | |
else: | |
return web.Response(status=400) | |
async def set_stability_api_default(request): | |
post = await request.post() | |
current = post.get("current") | |
if current is not None and current < len(stableAPI.api_keys): | |
stableAPI.api_current = current | |
return web.json_response({'status': 'ok'}) | |
else: | |
return web.Response(status=400) | |
async def get_account_info(request): | |
account = await stableAPI.getUserAccount() | |
balance = await stableAPI.getUserBalance() | |
return web.json_response({'account': account, 'balance': balance}) | |
async def get_balance_info(request): | |
balance = await stableAPI.getUserBalance() | |
return web.json_response({'balance': balance}) | |