davidberenstein1957 commited on
Commit
2cee398
·
verified ·
1 Parent(s): 7813401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -30
app.py CHANGED
@@ -1,25 +1,9 @@
1
- # Copyright 2024-present, David Berenstein, Inc.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import io
16
  import os
 
17
  import random
18
- import time
19
 
20
  import requests
21
  from PIL import Image
22
-
23
  from dataset_viber import AnnotatorInterFace
24
 
25
  HF_TOKEN = os.environ["HF_TOKEN"]
@@ -29,6 +13,10 @@ DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"
29
  MODEL_URL = (
30
  "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
31
  )
 
 
 
 
32
 
33
 
34
  def retrieve_sample(idx):
@@ -48,17 +36,10 @@ def get_rows():
48
 
49
 
50
  def generate_response(prompt):
51
- def _get_response(prompt):
52
- payload = {
53
- "inputs": prompt,
54
- }
55
- response = requests.post(MODEL_URL, headers=HEADERS, json=payload)
56
- if response.status_code != 200:
57
- time.sleep(10)
58
- return _get_response(prompt)
59
- return response
60
-
61
- response = _get_response(prompt)
62
  image = Image.open(io.BytesIO(response.content))
63
  return image
64
 
@@ -66,8 +47,7 @@ def generate_response(prompt):
66
  def next_input(_prompt, _completion_a, _completion_b):
67
  random_idx = random.randint(0, get_rows()) - 1
68
  img_url, prompt = retrieve_sample(random_idx)
69
- generated_image = generate_response(prompt)
70
- return (prompt, img_url, generated_image)
71
 
72
 
73
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import io
3
  import random
 
4
 
5
  import requests
6
  from PIL import Image
 
7
  from dataset_viber import AnnotatorInterFace
8
 
9
  HF_TOKEN = os.environ["HF_TOKEN"]
 
13
  MODEL_URL = (
14
  "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
15
  )
16
+ MODEL_URLS = [
17
+ MODEL_URL,
18
+ "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
19
+ ]
20
 
21
 
22
  def retrieve_sample(idx):
 
36
 
37
 
38
  def generate_response(prompt):
39
+ payload = {
40
+ "inputs": prompt,
41
+ }
42
+ response = requests.post(random.choice(MODEL_URLS), headers=HEADERS, json=payload)
 
 
 
 
 
 
 
43
  image = Image.open(io.BytesIO(response.content))
44
  return image
45
 
 
47
  def next_input(_prompt, _completion_a, _completion_b):
48
  random_idx = random.randint(0, get_rows()) - 1
49
  img_url, prompt = retrieve_sample(random_idx)
50
+ return (prompt, generate_response(prompt), generate_response(prompt+" "))
 
51
 
52
 
53
  if __name__ == "__main__":