Mikiko Bazeley committed on
Commit
e29cdd1
·
1 Parent(s): f4f8888

Created separate page for Flux models

Browse files
pages/1_Comparing_LLMs.py DELETED
@@ -1,185 +0,0 @@
1
- from dotenv import load_dotenv
2
- import os
3
- from PIL import Image
4
- import streamlit as st
5
- import fireworks.client
6
-
7
- st.set_page_config(page_title="LLM Comparison Tool", page_icon="🎇")
8
- st.title("LLM-as-a-judge: Comparing LLMs using Fireworks")
9
- st.write("A light introduction to how easy it is to swap LLMs and how to use the Fireworks Python client")
10
-
11
- # Clear the cache before starting
12
- st.cache_data.clear()
13
-
14
- # Specify the path to the .env file in the env/ directory
15
- dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'env', '.env')
16
-
17
- # Load the .env file from the specified path
18
- load_dotenv(dotenv_path)
19
-
20
- # Get the Fireworks API key from the environment variable
21
- fireworks_api_key = os.getenv("FIREWORKS_API_KEY")
22
-
23
- if not fireworks_api_key:
24
- raise ValueError("No API key found in the .env file. Please add your FIREWORKS_API_KEY to the .env file.")
25
-
26
- # Load the image
27
- logo_image = Image.open("img/fireworksai_logo.png")
28
- ash_image = Image.open("img/ash.png")
29
- bulbasaur_image = Image.open("img/bulbasaur.png")
30
- squirtel_image = Image.open("img/squirtel.png")
31
- charmander_image = Image.open("img/charmander.png")
32
-
33
- st.divider()
34
- # Streamlit app
35
- st.subheader("Fireworks Playground")
36
-
37
- st.write("Fireworks AI is a platform that offers serverless and scalable AI models.")
38
- st.write("👉 Learn more here: [Fireworks Serverless Models](https://fireworks.ai/models?show=Serverless)")
39
- st.divider()
40
-
41
- # Sidebar for selecting models
42
- with st.sidebar:
43
- st.image(logo_image)
44
-
45
- st.write("Select three models to compare their outputs:")
46
-
47
- st.image(bulbasaur_image, width=80)
48
- option_1 = st.selectbox("Select Model 1", [
49
- "Text: Meta Llama 3.1 Instruct - 70B",
50
- "Text: Meta Llama 3.1 Instruct - 8B",
51
- "Text: Meta Llama 3.2 Instruct - 3B",
52
- "Text: Gemma 2 Instruct - 9B",
53
- "Text: Mixtral MoE Instruct - 8x22B",
54
- "Text: Mixtral MoE Instruct - 8x7B",
55
- "Text: MythoMax L2 - 13B"
56
- ], index=2) # Default to Meta Llama 3.2 Instruct - 3B
57
-
58
- st.image(charmander_image, width=80)
59
- option_2 = st.selectbox("Select Model 2", [
60
- "Text: Meta Llama 3.1 Instruct - 70B",
61
- "Text: Meta Llama 3.1 Instruct - 8B",
62
- "Text: Meta Llama 3.2 Instruct - 3B",
63
- "Text: Gemma 2 Instruct - 9B",
64
- "Text: Mixtral MoE Instruct - 8x22B",
65
- "Text: Mixtral MoE Instruct - 8x7B",
66
- "Text: MythoMax L2 - 13B"
67
- ], index=5) # Default to Mixtral MoE Instruct - 8x7B
68
-
69
- st.image(squirtel_image, width=80)
70
- option_3 = st.selectbox("Select Model 3", [
71
- "Text: Meta Llama 3.1 Instruct - 70B",
72
- "Text: Meta Llama 3.1 Instruct - 8B",
73
- "Text: Meta Llama 3.2 Instruct - 3B",
74
- "Text: Gemma 2 Instruct - 9B",
75
- "Text: Mixtral MoE Instruct - 8x22B",
76
- "Text: Mixtral MoE Instruct - 8x7B",
77
- "Text: MythoMax L2 - 13B"
78
- ], index=0) # Default to Meta Llama 3.1 Instruct - 70B
79
-
80
- # Dropdown to select the LLM that will perform the comparison
81
- st.image(ash_image, width=80)
82
- comparison_llm = st.selectbox("Select Comparison Model", [
83
- "Text: Meta Llama 3.1 Instruct - 70B",
84
- "Text: Meta Llama 3.1 Instruct - 8B",
85
- "Text: Meta Llama 3.2 Instruct - 3B",
86
- "Text: Gemma 2 Instruct - 9B",
87
- "Text: Mixtral MoE Instruct - 8x22B",
88
- "Text: Mixtral MoE Instruct - 8x7B",
89
- "Text: MythoMax L2 - 13B"
90
- ], index=5) # Default to Mixtral MoE Instruct - 8x7B
91
-
92
- os.environ["FIREWORKS_API_KEY"] = fireworks_api_key
93
-
94
- # Helper text for the prompt
95
- st.markdown("### Enter your prompt below to generate responses:")
96
-
97
- prompt = st.text_input("Prompt", label_visibility="collapsed")
98
- st.divider()
99
-
100
- # Function to generate a response from a text model
101
- def generate_text_response(model_name, prompt):
102
- return fireworks.client.ChatCompletion.create(
103
- model=model_name,
104
- messages=[{
105
- "role": "user",
106
- "content": prompt,
107
- }]
108
- )
109
-
110
- # Function to compare the three responses using the selected LLM
111
- def compare_responses(response_1, response_2, response_3, comparison_model):
112
- comparison_prompt = f"Compare the following three responses:\n\nResponse 1: {response_1}\n\nResponse 2: {response_2}\n\nResponse 3: {response_3}\n\nProvide a succinct comparison."
113
-
114
- comparison_response = fireworks.client.ChatCompletion.create(
115
- model=comparison_model, # Use the selected LLM for comparison
116
- messages=[{
117
- "role": "user",
118
- "content": comparison_prompt,
119
- }]
120
- )
121
-
122
- return comparison_response.choices[0].message.content
123
-
124
-
125
- # If Generate button is clicked
126
- if st.button("Generate"):
127
- if not fireworks_api_key.strip() or not prompt.strip():
128
- st.error("Please provide the missing fields.")
129
- else:
130
- try:
131
- with st.spinner("Please wait..."):
132
- fireworks.client.api_key = fireworks_api_key
133
-
134
- # Create three columns for side-by-side comparison
135
- col1, col2, col3 = st.columns(3)
136
-
137
- # Model 1
138
- with col1:
139
- st.subheader(f"Model 1: {option_1}")
140
- st.image(bulbasaur_image)
141
- if option_1.startswith("Text"):
142
- model_map = {
143
- "Text: Meta Llama 3.1 Instruct - 70B": "accounts/fireworks/models/llama-v3p1-70b-instruct",
144
- "Text: Meta Llama 3.1 Instruct - 8B": "accounts/fireworks/models/llama-v3p1-8b-instruct",
145
- "Text: Meta Llama 3.2 Instruct - 3B": "accounts/fireworks/models/llama-v3p2-3b-instruct",
146
- "Text: Gemma 2 Instruct - 9B": "accounts/fireworks/models/gemma2-9b-it",
147
- "Text: Mixtral MoE Instruct - 8x22B": "accounts/fireworks/models/mixtral-8x22b-instruct",
148
- "Text: Mixtral MoE Instruct - 8x7B": "accounts/fireworks/models/mixtral-8x7b-instruct",
149
- "Text: MythoMax L2 - 13B": "accounts/fireworks/models/mythomax-l2-13b"
150
- }
151
- response_1 = generate_text_response(model_map[option_1], prompt)
152
- st.success(response_1.choices[0].message.content)
153
-
154
- # Model 2
155
- with col2:
156
- st.subheader(f"Model 2: {option_2}")
157
- st.image(charmander_image)
158
- response_2 = generate_text_response(model_map[option_2], prompt)
159
- st.success(response_2.choices[0].message.content)
160
-
161
- # Model 3
162
- with col3:
163
- st.subheader(f"Model 3: {option_3}")
164
- st.image(squirtel_image)
165
- response_3 = generate_text_response(model_map[option_3], prompt)
166
- st.success(response_3.choices[0].message.content)
167
-
168
- # Visual divider between model responses and comparison
169
- st.divider()
170
-
171
- # Generate a comparison of the three responses using the selected LLM
172
- comparison = compare_responses(
173
- response_1.choices[0].message.content,
174
- response_2.choices[0].message.content,
175
- response_3.choices[0].message.content,
176
- model_map[comparison_llm]
177
- )
178
-
179
- # Display the comparison
180
- st.subheader("Comparison of the Three Responses:")
181
- st.image(ash_image)
182
- st.write(comparison)
183
-
184
- except Exception as e:
185
- st.exception(f"Exception: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/3_Image_Generation.py CHANGED
@@ -1,28 +1,36 @@
1
- from dotenv import load_dotenv
2
  import os
 
 
3
  from PIL import Image
 
4
  import streamlit as st
5
  import fireworks.client
6
  from fireworks.client.image import ImageInference, Answer
7
 
 
8
  st.set_page_config(page_title="Image Generation Tool", page_icon="🎇")
9
- st.title("Image Generation Comparison using Fireworks")
10
- st.write("An introduction to how easy it is to generate images using the Fireworks Python client.")
11
 
12
- # Clear the cache before starting
13
- st.cache_data.clear()
14
 
15
- # Specify the path to the .env file in the env/ directory
16
- dotenv_path = os.path.join(os.path.dirname(__file__), '..', 'env', '.env')
17
 
18
- # Load the .env file from the specified path
19
- load_dotenv(dotenv_path)
20
-
21
- # Get the Fireworks API key from the environment variable
22
  fireworks_api_key = os.getenv("FIREWORKS_API_KEY")
23
 
 
 
 
 
24
  if not fireworks_api_key:
25
- raise ValueError("No API key found in the .env file. Please add your FIREWORKS_API_KEY to the .env file.")
 
 
 
 
 
 
26
 
27
  # Load image for logo
28
  logo_image = Image.open("img/fireworksai_logo.png")
@@ -41,29 +49,25 @@ with st.sidebar:
41
 
42
  st.write("Select three image generation models to compare:")
43
 
44
- # Updated model options with the correct paths and additional FLUX models
45
  model_options = {
46
  "Stable Diffusion XL": "stable-diffusion-xl-1024-v1-0",
47
  "Playground v2 1024": "playground-v2-1024px-aesthetic",
48
  "Playground v2.5 1024": "playground-v2-5-1024px-aesthetic",
49
- "Segmind Stable Diffusion 1B (SSD-1B)": "SSD-1B",
50
- "FLUX.1 [dev]": "flux-1-dev",
51
- "FLUX.1 [schnell]": "flux-1-schnell"
52
  }
53
 
54
  option_1 = st.selectbox("Select Image Model 1", list(model_options.keys()), index=0)
55
  option_2 = st.selectbox("Select Image Model 2", list(model_options.keys()), index=1)
56
  option_3 = st.selectbox("Select Image Model 3", list(model_options.keys()), index=2)
57
 
58
- os.environ["FIREWORKS_API_KEY"] = fireworks_api_key
59
-
60
  # Helper text for the prompt
61
  st.markdown("### Enter your prompt below to generate images:")
62
 
63
  prompt = st.text_input("Prompt", label_visibility="collapsed")
64
  st.divider()
65
 
66
- # Function to generate a response from an image model
67
  def generate_image_response(model_path, prompt):
68
  # Initialize the ImageInference client
69
  inference_client = ImageInference(model=model_path)
@@ -117,4 +121,3 @@ if st.button("Generate"):
117
 
118
  except Exception as e:
119
  st.exception(f"Exception: {e}")
120
-
 
 
1
  import os
2
+ import requests
3
+ from dotenv import load_dotenv
4
  from PIL import Image
5
+ from io import BytesIO
6
  import streamlit as st
7
  import fireworks.client
8
  from fireworks.client.image import ImageInference, Answer
9
 
10
+ # Set page configuration - must be the first Streamlit command
11
  st.set_page_config(page_title="Image Generation Tool", page_icon="🎇")
 
 
12
 
13
+ # Set the full path to the .env file (NOTE(review): this resolves to pages/.env, while the other pages load '../env/.env' — confirm which location is intended)
14
+ dotenv_path = os.path.join(os.path.dirname(__file__), '.env')
15
 
16
+ # Load environment variables from the .env file, overriding existing environment variables if necessary
17
+ load_dotenv(dotenv_path, override=True)
18
 
19
+ # Get the Fireworks API key from the .env file
 
 
 
20
  fireworks_api_key = os.getenv("FIREWORKS_API_KEY")
21
 
22
+ # SECURITY(review): the st.write below renders the secret FIREWORKS_API_KEY in the browser UI — remove this debugging line before deploying
23
+ st.write(f"API Key loaded in Streamlit: {fireworks_api_key}")
24
+
25
+ # Ensure the API key is loaded
26
  if not fireworks_api_key:
27
+ raise ValueError("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
28
+
29
+ st.title("Image Generation Comparison using Fireworks")
30
+ st.write("An introduction to how easy it is to generate images using the Fireworks Python client.")
31
+
32
+ # Clear the cache before starting
33
+ st.cache_data.clear()
34
 
35
  # Load image for logo
36
  logo_image = Image.open("img/fireworksai_logo.png")
 
49
 
50
  st.write("Select three image generation models to compare:")
51
 
52
+ # Updated model options (FLUX models removed)
53
  model_options = {
54
  "Stable Diffusion XL": "stable-diffusion-xl-1024-v1-0",
55
  "Playground v2 1024": "playground-v2-1024px-aesthetic",
56
  "Playground v2.5 1024": "playground-v2-5-1024px-aesthetic",
57
+ "Segmind Stable Diffusion 1B (SSD-1B)": "SSD-1B"
 
 
58
  }
59
 
60
  option_1 = st.selectbox("Select Image Model 1", list(model_options.keys()), index=0)
61
  option_2 = st.selectbox("Select Image Model 2", list(model_options.keys()), index=1)
62
  option_3 = st.selectbox("Select Image Model 3", list(model_options.keys()), index=2)
63
 
 
 
64
  # Helper text for the prompt
65
  st.markdown("### Enter your prompt below to generate images:")
66
 
67
  prompt = st.text_input("Prompt", label_visibility="collapsed")
68
  st.divider()
69
 
70
+ # Function to generate an image using the Fireworks API models
71
  def generate_image_response(model_path, prompt):
72
  # Initialize the ImageInference client
73
  inference_client = ImageInference(model=model_path)
 
121
 
122
  except Exception as e:
123
  st.exception(f"Exception: {e}")
 
pages/5_FLUX_image_generation.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from dotenv import load_dotenv
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ import streamlit as st
7
+
8
+ # Set page configuration
9
+ st.set_page_config(page_title="FLUX Image Generation Tool", page_icon="🎇")
10
+
11
+ # Correct the path to the .env file to reflect its location
12
+ dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')
13
+
14
+ # Load environment variables from the .env file
15
+ load_dotenv(dotenv_path, override=True)
16
+
17
+ # Get the Fireworks API key from the .env file
18
+ fireworks_api_key = os.getenv("FIREWORKS_API_KEY")
19
+
20
+ if not fireworks_api_key:
21
+ st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
22
+
23
+ # Function to make requests to the FLUX models
24
+ def generate_flux_image(model_path, prompt, steps):
25
+ url = f"https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model_path}/text_to_image"
26
+ headers = {
27
+ "Authorization": f"Bearer {fireworks_api_key}",
28
+ "Content-Type": "application/json",
29
+ "Accept": "image/jpeg"
30
+ }
31
+ data = {
32
+ "prompt": prompt,
33
+ "aspect_ratio": "16:9",
34
+ "guidance_scale": 3.5,
35
+ "num_inference_steps": steps,
36
+ "seed": 0
37
+ }
38
+
39
+ # Send the request
40
+ response = requests.post(url, headers=headers, json=data)
41
+
42
+ if response.status_code == 200:
43
+ # Convert the response to an image
44
+ img_data = response.content
45
+ img = Image.open(BytesIO(img_data))
46
+ return img
47
+ else:
48
+ raise RuntimeError(f"Error with FLUX model {model_path}: {response.text}")
49
+
50
+ # Streamlit UI
51
+ st.title("FLUX Image Generation")
52
+ st.write("Generate images using the FLUX models.")
53
+
54
+ # User input for the prompt
55
+ prompt = st.text_input("Enter your prompt for image generation:")
56
+
57
+ # Dropdown to select the model
58
+ model_choice = st.selectbox("Select the model:", ["flux-1-dev", "flux-1-schnell"])
59
+
60
+ # Button to generate images
61
+ if st.button("Generate Image"):
62
+ if not prompt.strip():
63
+ st.error("Please provide a prompt.")
64
+ else:
65
+ try:
66
+ with st.spinner("Generating image..."):
67
+ # Determine steps based on model
68
+ steps = 30 if model_choice == "flux-1-dev" else 4
69
+
70
+ # Generate image
71
+ generated_image = generate_flux_image(model_choice, prompt, steps)
72
+
73
+ # Display the image
74
+ st.image(generated_image, caption=f"Generated using {model_choice}", use_column_width=True)
75
+
76
+ except Exception as e:
77
+ st.error(f"An error occurred: {e}")
pages/test_endpoint.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from dotenv import load_dotenv
4
+ from PIL import Image
5
+ from io import BytesIO
6
+
7
+ # Correct the path to the .env file to reflect its location
8
+ dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')
9
+
10
+ # Load environment variables from the .env file
11
+ load_dotenv(dotenv_path, override=True)
12
+
13
+ # Get the API key from the .env file
14
+ api_key = os.getenv("FIREWORKS_API_KEY")
15
+
16
+ if not api_key:
17
+ raise ValueError("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
18
+
19
+ # User input for the prompt
20
+ prompt = input("Enter a prompt for image generation: ")
21
+
22
+ # Validate the prompt input
23
+ if not prompt.strip():
24
+ raise ValueError("Prompt cannot be empty!")
25
+
26
+ # Set the model endpoint for either flux-1-dev or flux-1-schnell
27
+ # For dev: "flux-1-dev" (30 steps)
28
+ # For schnell: "flux-1-schnell" (4 steps)
29
+ model_path = "flux-1-schnell"
30
+ # model_path = "flux-1-dev" # Uncomment if you want to switch to the dev model
31
+
32
+ # API URL for the model
33
+ url = f"https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model_path}/text_to_image"
34
+
35
+ # Headers for the API request
36
+ headers = {
37
+ "Authorization": f"Bearer {api_key}",
38
+ "Content-Type": "application/json",
39
+ "Accept": "image/jpeg"
40
+ }
41
+
42
+ # Data payload to send with the request
43
+ data = {
44
+ "prompt": prompt, # Use the user-provided prompt
45
+ "aspect_ratio": "16:9",
46
+ "guidance_scale": 3.5,
47
+ "num_inference_steps": 30 if model_path == "flux-1-dev" else 4, # 30 steps for dev, 4 for schnell
48
+ "seed": 0
49
+ }
50
+
51
+ # Make the POST request to the API
52
+ response = requests.post(url, headers=headers, json=data)
53
+
54
+ # Check the status of the response
55
+ if response.status_code == 200:
56
+ # If the request is successful, convert the response to an image
57
+ img_data = response.content
58
+ img = Image.open(BytesIO(img_data))
59
+ # Save the image
60
+ img.save("output_image.jpg")
61
+ print("Image saved successfully as output_image.jpg.")
62
+ else:
63
+ # If there's an error, print the status code and response text
64
+ print(f"Error: {response.status_code}, {response.text}")