File size: 1,745 Bytes
89ded21 2cee398 ce11ffc f20ab91 ce11ffc 870c158 ce11ffc 2cee398 ce11ffc 2cee398 ce11ffc 89ded21 ce11ffc 40ef6c4 ce11ffc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import os
import io
import random
import requests
from PIL import Image
from dataset_viber import AnnotatorInterFace
# Hugging Face access token. It is read eagerly, so a missing variable raises
# KeyError as soon as this module is imported.
HF_TOKEN = os.environ["HF_TOKEN"]

# Authorization header sent with every HTTP request made by this module.
HEADERS = dict(Authorization=f"Bearer {HF_TOKEN}")

# Base URL of the Hugging Face datasets-server API.
DATASET_SERVER_URL = "https://datasets-server.huggingface.co"

# NOTE: despite the name, this is not just the dataset id. It is a URL-encoded
# dataset id followed by extra query parameters (config, split), and it is
# spliced verbatim into request query strings.
DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"

# Inference API endpoint for the text-to-image model being compared.
MODEL_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"

# Pool of endpoints; one is chosen at random for each generation request.
MODEL_URLS = [MODEL_URL]
def retrieve_sample(idx, timeout=30):
    """Fetch a single dataset row from the datasets-server ``/rows`` endpoint.

    Args:
        idx: Row offset sent as the ``offset`` query parameter
            (presumably must be >= 0; TODO confirm against the API docs).
        timeout: Seconds to wait for the HTTP response before giving up.
            Without it, ``requests`` may block indefinitely.

    Returns:
        A ``(img_url, prompt)`` tuple taken from the row's ``image.src``
        and ``prompt`` fields.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    api_url = f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}&offset={idx}&length=1"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of hitting a confusing KeyError while
    # indexing into the error payload below.
    response.raise_for_status()
    row = response.json()["rows"][0]["row"]
    return row["image"]["src"], row["prompt"]
def get_rows(timeout=30):
    """Return the row count reported by the datasets-server ``/size`` endpoint.

    Args:
        timeout: Seconds to wait for the HTTP response before giving up.
            Without it, ``requests`` may block indefinitely.

    Returns:
        The ``size.config.num_rows`` value from the JSON response.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Surface HTTP errors directly rather than as a KeyError on the payload.
    response.raise_for_status()
    # NOTE(review): this is the config-level count. It equals the train
    # split's size only if the config has a single split. TODO confirm.
    return response.json()["size"]["config"]["num_rows"]
def generate_response(prompt, timeout=120):
    """Generate an image for *prompt* by POSTing to one of ``MODEL_URLS``.

    Args:
        prompt: Text sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the endpoint before giving up.
            Without it, ``requests`` may block indefinitely.

    Returns:
        A ``PIL.Image.Image`` decoded from the raw response body.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    payload = {"inputs": prompt}
    # Pick one of the configured endpoints at random for this request.
    response = requests.post(
        random.choice(MODEL_URLS), headers=HEADERS, json=payload, timeout=timeout
    )
    # Without this check, a non-image error body would reach Image.open and
    # fail with an unhelpful "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next item to annotate: a random dataset prompt and two images.

    The three parameters exist to satisfy the ``fn_next_input`` callback
    signature (presumably the current prompt and its two completions). All
    three are ignored.

    Returns:
        A ``(prompt, image_a, image_b)`` tuple, where both images are
        generated from (nearly) the same prompt.
    """
    # randrange(n) yields 0..n-1, which are exactly the valid row offsets.
    # randint is inclusive on both ends, so randint(0, n) - 1 could yield -1.
    random_idx = random.randrange(get_rows())
    _img_url, prompt = retrieve_sample(random_idx)
    # The trailing space makes the second request's input differ from the
    # first. Presumably this keeps the Inference API from returning a cached,
    # identical image. TODO confirm.
    return (prompt, generate_response(prompt), generate_response(prompt + " "))
if __name__ == "__main__":
    # Build the image-generation preference annotation UI, fed by next_input,
    # and start serving it.
    target_dataset = "dataset-viber-image-generation-preference-inference-endpoints-battle-flux"
    interface = AnnotatorInterFace.for_image_generation_preference(
        fn_next_input=next_input,
        interactive=False,
        dataset_name=target_dataset,
    )
    interface.launch()
|