Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
a4b32da
1
Parent(s):
780769f
release code
Browse files- app.py +5 -8
- model/matchmaker.py +11 -5
- model/matchmaker_video.py +12 -5
- model/model_manager.py +0 -8
- model/model_registry.py +0 -185
- model/models/__init__.py +2 -18
- model/models/openai_api_models.py +0 -2
- model/models/other_api_models.py +3 -27
- model/models/replicate_api_models.py +7 -148
- serve/Ksort.py +29 -0
- serve/leaderboard.py +4 -0
- serve/update_skill.py +11 -3
- serve/update_skill_video.py +11 -3
- serve/upload.py +13 -5
app.py
CHANGED
|
@@ -7,12 +7,8 @@ from model.model_manager import ModelManager
|
|
| 7 |
from pathlib import Path
|
| 8 |
from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
|
| 9 |
|
|
|
|
| 10 |
def build_combine_demo(models, elo_results_file, leaderboard_table_file):
|
| 11 |
-
# gr.themes.Default(),
|
| 12 |
-
# gr.themes.Soft(),
|
| 13 |
-
# gr.Theme.from_hub('gary109/HaleyCH_Theme'),
|
| 14 |
-
# gr.Theme.from_hub('EveryPizza/Cartoony-Gradio-Theme')
|
| 15 |
-
# gr.themes.Default(primary_hue="red", secondary_hue="pink")
|
| 16 |
with gr.Blocks(
|
| 17 |
title="Play with Open Vision Models",
|
| 18 |
theme=gr.themes.Default(),
|
|
@@ -22,21 +18,21 @@ def build_combine_demo(models, elo_results_file, leaderboard_table_file):
|
|
| 22 |
with gr.Tab("Image Generation", id=0):
|
| 23 |
with gr.Tabs() as tabs_ig:
|
| 24 |
with gr.Tab("Generation Leaderboard", id=0):
|
| 25 |
-
# build_leaderboard_tab(elo_results_file['t2i_generation'], leaderboard_table_file['t2i_generation'])
|
| 26 |
build_leaderboard_tab()
|
| 27 |
-
|
| 28 |
with gr.Tab("Generation Arena (battle)", id=1):
|
| 29 |
build_side_by_side_ui_anony(models)
|
|
|
|
| 30 |
with gr.Tab("Video Generation", id=1):
|
| 31 |
with gr.Tabs() as tabs_ig:
|
| 32 |
with gr.Tab("Generation Leaderboard", id=0):
|
| 33 |
-
# build_leaderboard_tab(elo_results_file['t2i_generation'], leaderboard_table_file['t2i_generation'])
|
| 34 |
build_leaderboard_video_tab()
|
| 35 |
|
| 36 |
with gr.Tab("Generation Arena (battle)", id=1):
|
| 37 |
build_side_by_side_video_ui_anony(models)
|
|
|
|
| 38 |
with gr.Tab("Contributor", id=2):
|
| 39 |
build_leaderboard_contributor()
|
|
|
|
| 40 |
return demo
|
| 41 |
|
| 42 |
|
|
@@ -44,6 +40,7 @@ def load_elo_results(elo_results_dir):
|
|
| 44 |
from collections import defaultdict
|
| 45 |
elo_results_file = defaultdict(lambda: None)
|
| 46 |
leaderboard_table_file = defaultdict(lambda: None)
|
|
|
|
| 47 |
if elo_results_dir is not None:
|
| 48 |
elo_results_dir = Path(elo_results_dir)
|
| 49 |
elo_results_file = {}
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from serve.constants import SERVER_PORT, ROOT_PATH, ELO_RESULTS_DIR
|
| 9 |
|
| 10 |
+
|
| 11 |
def build_combine_demo(models, elo_results_file, leaderboard_table_file):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
with gr.Blocks(
|
| 13 |
title="Play with Open Vision Models",
|
| 14 |
theme=gr.themes.Default(),
|
|
|
|
| 18 |
with gr.Tab("Image Generation", id=0):
|
| 19 |
with gr.Tabs() as tabs_ig:
|
| 20 |
with gr.Tab("Generation Leaderboard", id=0):
|
|
|
|
| 21 |
build_leaderboard_tab()
|
|
|
|
| 22 |
with gr.Tab("Generation Arena (battle)", id=1):
|
| 23 |
build_side_by_side_ui_anony(models)
|
| 24 |
+
|
| 25 |
with gr.Tab("Video Generation", id=1):
|
| 26 |
with gr.Tabs() as tabs_ig:
|
| 27 |
with gr.Tab("Generation Leaderboard", id=0):
|
|
|
|
| 28 |
build_leaderboard_video_tab()
|
| 29 |
|
| 30 |
with gr.Tab("Generation Arena (battle)", id=1):
|
| 31 |
build_side_by_side_video_ui_anony(models)
|
| 32 |
+
|
| 33 |
with gr.Tab("Contributor", id=2):
|
| 34 |
build_leaderboard_contributor()
|
| 35 |
+
|
| 36 |
return demo
|
| 37 |
|
| 38 |
|
|
|
|
| 40 |
from collections import defaultdict
|
| 41 |
elo_results_file = defaultdict(lambda: None)
|
| 42 |
leaderboard_table_file = defaultdict(lambda: None)
|
| 43 |
+
|
| 44 |
if elo_results_dir is not None:
|
| 45 |
elo_results_dir = Path(elo_results_dir)
|
| 46 |
elo_results_file = {}
|
model/matchmaker.py
CHANGED
|
@@ -24,35 +24,41 @@ def create_ssh_matchmaker_client(server, port, user, password):
|
|
| 24 |
transport.set_keepalive(60)
|
| 25 |
|
| 26 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
|
|
|
|
|
|
| 27 |
def is_connected():
|
| 28 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
| 29 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
| 30 |
return False
|
| 31 |
-
# 检查SSH连接是否正常
|
| 32 |
if not ssh_matchmaker_client.get_transport().is_active():
|
| 33 |
return False
|
| 34 |
-
# 检查SFTP连接是否正常
|
| 35 |
try:
|
| 36 |
-
sftp_matchmaker_client.listdir('.')
|
| 37 |
except Exception as e:
|
| 38 |
print(f"Error checking SFTP connection: {e}")
|
| 39 |
return False
|
| 40 |
return True
|
|
|
|
|
|
|
| 41 |
def ucb_score(trueskill_diff, t, n):
|
| 42 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 43 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 44 |
return ucb
|
| 45 |
|
|
|
|
| 46 |
def update_trueskill(ratings, ranks):
|
| 47 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 48 |
return new_ratings
|
| 49 |
|
|
|
|
| 50 |
def serialize_rating(rating):
|
| 51 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 52 |
|
|
|
|
| 53 |
def deserialize_rating(rating_dict):
|
| 54 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 55 |
|
|
|
|
| 56 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
| 57 |
global sftp_matchmaker_client
|
| 58 |
if not is_connected():
|
|
@@ -66,6 +72,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
| 66 |
with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
|
| 67 |
f.write(json_data)
|
| 68 |
|
|
|
|
| 69 |
def load_json_via_sftp():
|
| 70 |
global sftp_matchmaker_client
|
| 71 |
if not is_connected():
|
|
@@ -107,7 +114,7 @@ def matchmaker(num_players, k_group=4, not_run=[]):
|
|
| 107 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
| 108 |
|
| 109 |
# Exclude self, select opponent with highest UCB score
|
| 110 |
-
ucb_scores[selected_player] = -float('inf')
|
| 111 |
ucb_scores[not_run] = -float('inf')
|
| 112 |
opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
|
| 113 |
|
|
@@ -117,4 +124,3 @@ def matchmaker(num_players, k_group=4, not_run=[]):
|
|
| 117 |
random.shuffle(model_ids)
|
| 118 |
|
| 119 |
return model_ids
|
| 120 |
-
|
|
|
|
| 24 |
transport.set_keepalive(60)
|
| 25 |
|
| 26 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
def is_connected():
|
| 30 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
| 31 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
| 32 |
return False
|
|
|
|
| 33 |
if not ssh_matchmaker_client.get_transport().is_active():
|
| 34 |
return False
|
|
|
|
| 35 |
try:
|
| 36 |
+
sftp_matchmaker_client.listdir('.')
|
| 37 |
except Exception as e:
|
| 38 |
print(f"Error checking SFTP connection: {e}")
|
| 39 |
return False
|
| 40 |
return True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
def ucb_score(trueskill_diff, t, n):
|
| 44 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 45 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 46 |
return ucb
|
| 47 |
|
| 48 |
+
|
| 49 |
def update_trueskill(ratings, ranks):
|
| 50 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 51 |
return new_ratings
|
| 52 |
|
| 53 |
+
|
| 54 |
def serialize_rating(rating):
|
| 55 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 56 |
|
| 57 |
+
|
| 58 |
def deserialize_rating(rating_dict):
|
| 59 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 60 |
|
| 61 |
+
|
| 62 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
| 63 |
global sftp_matchmaker_client
|
| 64 |
if not is_connected():
|
|
|
|
| 72 |
with sftp_matchmaker_client.open(SSH_SKILL, 'w') as f:
|
| 73 |
f.write(json_data)
|
| 74 |
|
| 75 |
+
|
| 76 |
def load_json_via_sftp():
|
| 77 |
global sftp_matchmaker_client
|
| 78 |
if not is_connected():
|
|
|
|
| 114 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
| 115 |
|
| 116 |
# Exclude self, select opponent with highest UCB score
|
| 117 |
+
ucb_scores[selected_player] = -float('inf')
|
| 118 |
ucb_scores[not_run] = -float('inf')
|
| 119 |
opponents = np.argsort(ucb_scores)[-k_group + 1:].tolist()
|
| 120 |
|
|
|
|
| 124 |
random.shuffle(model_ids)
|
| 125 |
|
| 126 |
return model_ids
|
|
|
model/matchmaker_video.py
CHANGED
|
@@ -13,6 +13,7 @@ trueskill_env = TrueSkill()
|
|
| 13 |
ssh_matchmaker_client = None
|
| 14 |
sftp_matchmaker_client = None
|
| 15 |
|
|
|
|
| 16 |
def create_ssh_matchmaker_client(server, port, user, password):
|
| 17 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
| 18 |
ssh_matchmaker_client = paramiko.SSHClient()
|
|
@@ -24,35 +25,41 @@ def create_ssh_matchmaker_client(server, port, user, password):
|
|
| 24 |
transport.set_keepalive(60)
|
| 25 |
|
| 26 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
|
|
|
|
|
|
| 27 |
def is_connected():
|
| 28 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
| 29 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
| 30 |
return False
|
| 31 |
-
# 检查SSH连接是否正常
|
| 32 |
if not ssh_matchmaker_client.get_transport().is_active():
|
| 33 |
return False
|
| 34 |
-
# 检查SFTP连接是否正常
|
| 35 |
try:
|
| 36 |
-
sftp_matchmaker_client.listdir('.')
|
| 37 |
except Exception as e:
|
| 38 |
print(f"Error checking SFTP connection: {e}")
|
| 39 |
return False
|
| 40 |
return True
|
|
|
|
|
|
|
| 41 |
def ucb_score(trueskill_diff, t, n):
|
| 42 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 43 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 44 |
return ucb
|
| 45 |
|
|
|
|
| 46 |
def update_trueskill(ratings, ranks):
|
| 47 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 48 |
return new_ratings
|
| 49 |
|
|
|
|
| 50 |
def serialize_rating(rating):
|
| 51 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 52 |
|
|
|
|
| 53 |
def deserialize_rating(rating_dict):
|
| 54 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 55 |
|
|
|
|
| 56 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
| 57 |
global sftp_matchmaker_client
|
| 58 |
if not is_connected():
|
|
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
| 66 |
with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
| 67 |
f.write(json_data)
|
| 68 |
|
|
|
|
| 69 |
def load_json_via_sftp():
|
| 70 |
global sftp_matchmaker_client
|
| 71 |
if not is_connected():
|
|
@@ -95,7 +103,7 @@ def matchmaker_video(num_players, k_group=4):
|
|
| 95 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
| 96 |
|
| 97 |
# Exclude self, select opponent with highest UCB score
|
| 98 |
-
ucb_scores[selected_player] = -float('inf')
|
| 99 |
|
| 100 |
excluded_players_1 = [num_players-1, num_players-4]
|
| 101 |
excluded_players_2 = [num_players-2, num_players-3, num_players-5]
|
|
@@ -126,4 +134,3 @@ def matchmaker_video(num_players, k_group=4):
|
|
| 126 |
random.shuffle(model_ids)
|
| 127 |
|
| 128 |
return model_ids
|
| 129 |
-
|
|
|
|
| 13 |
ssh_matchmaker_client = None
|
| 14 |
sftp_matchmaker_client = None
|
| 15 |
|
| 16 |
+
|
| 17 |
def create_ssh_matchmaker_client(server, port, user, password):
|
| 18 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
| 19 |
ssh_matchmaker_client = paramiko.SSHClient()
|
|
|
|
| 25 |
transport.set_keepalive(60)
|
| 26 |
|
| 27 |
sftp_matchmaker_client = ssh_matchmaker_client.open_sftp()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
def is_connected():
|
| 31 |
global ssh_matchmaker_client, sftp_matchmaker_client
|
| 32 |
if ssh_matchmaker_client is None or sftp_matchmaker_client is None:
|
| 33 |
return False
|
|
|
|
| 34 |
if not ssh_matchmaker_client.get_transport().is_active():
|
| 35 |
return False
|
|
|
|
| 36 |
try:
|
| 37 |
+
sftp_matchmaker_client.listdir('.')
|
| 38 |
except Exception as e:
|
| 39 |
print(f"Error checking SFTP connection: {e}")
|
| 40 |
return False
|
| 41 |
return True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
def ucb_score(trueskill_diff, t, n):
|
| 45 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 46 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 47 |
return ucb
|
| 48 |
|
| 49 |
+
|
| 50 |
def update_trueskill(ratings, ranks):
|
| 51 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 52 |
return new_ratings
|
| 53 |
|
| 54 |
+
|
| 55 |
def serialize_rating(rating):
|
| 56 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 57 |
|
| 58 |
+
|
| 59 |
def deserialize_rating(rating_dict):
|
| 60 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 61 |
|
| 62 |
+
|
| 63 |
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
| 64 |
global sftp_matchmaker_client
|
| 65 |
if not is_connected():
|
|
|
|
| 73 |
with sftp_matchmaker_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
| 74 |
f.write(json_data)
|
| 75 |
|
| 76 |
+
|
| 77 |
def load_json_via_sftp():
|
| 78 |
global sftp_matchmaker_client
|
| 79 |
if not is_connected():
|
|
|
|
| 103 |
ucb_scores = ucb_score(trueskill_diff, total_comparisons, n)
|
| 104 |
|
| 105 |
# Exclude self, select opponent with highest UCB score
|
| 106 |
+
ucb_scores[selected_player] = -float('inf')
|
| 107 |
|
| 108 |
excluded_players_1 = [num_players-1, num_players-4]
|
| 109 |
excluded_players_2 = [num_players-2, num_players-3, num_players-5]
|
|
|
|
| 134 |
random.shuffle(model_ids)
|
| 135 |
|
| 136 |
return model_ids
|
|
|
model/model_manager.py
CHANGED
|
@@ -58,10 +58,8 @@ class ModelManager:
|
|
| 58 |
def generate_image_ig_api(self, prompt, model_name):
|
| 59 |
pipe = self.load_model_pipe(model_name)
|
| 60 |
result = pipe(prompt=prompt)
|
| 61 |
-
|
| 62 |
return result
|
| 63 |
|
| 64 |
-
|
| 65 |
def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
|
| 66 |
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
| 67 |
from .matchmaker import matchmaker
|
|
@@ -73,13 +71,11 @@ class ModelManager:
|
|
| 73 |
else:
|
| 74 |
model_names = [model_A, model_B, model_C, model_D]
|
| 75 |
|
| 76 |
-
|
| 77 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 78 |
futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
|
| 79 |
else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
|
| 80 |
results = [future.result() for future in futures]
|
| 81 |
|
| 82 |
-
|
| 83 |
return results[0], results[1], results[2], results[3], \
|
| 84 |
model_names[0], model_names[1], model_names[2], model_names[3]
|
| 85 |
|
|
@@ -156,7 +152,6 @@ class ModelManager:
|
|
| 156 |
return results[0], results[1], results[2], results[3], \
|
| 157 |
model_names[0], model_names[1], model_names[2], model_names[3], prompt
|
| 158 |
|
| 159 |
-
|
| 160 |
def generate_image_ig_parallel(self, prompt, model_A, model_B):
|
| 161 |
model_names = [model_A, model_B]
|
| 162 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
@@ -165,14 +160,12 @@ class ModelManager:
|
|
| 165 |
results = [future.result() for future in futures]
|
| 166 |
return results[0], results[1]
|
| 167 |
|
| 168 |
-
|
| 169 |
@spaces.GPU(duration=200)
|
| 170 |
def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
|
| 171 |
pipe = self.load_model_pipe(model_name)
|
| 172 |
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
|
| 173 |
return result
|
| 174 |
|
| 175 |
-
|
| 176 |
def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
| 177 |
model_names = [model_A, model_B]
|
| 178 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
@@ -182,7 +175,6 @@ class ModelManager:
|
|
| 182 |
results = [future.result() for future in futures]
|
| 183 |
return results[0], results[1]
|
| 184 |
|
| 185 |
-
|
| 186 |
def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
| 187 |
if model_A == "" and model_B == "":
|
| 188 |
model_names = random.sample([model for model in self.model_ie_list], 2)
|
|
|
|
| 58 |
def generate_image_ig_api(self, prompt, model_name):
|
| 59 |
pipe = self.load_model_pipe(model_name)
|
| 60 |
result = pipe(prompt=prompt)
|
|
|
|
| 61 |
return result
|
| 62 |
|
|
|
|
| 63 |
def generate_image_ig_parallel_anony(self, prompt, model_A, model_B, model_C, model_D):
|
| 64 |
if model_A == "" and model_B == "" and model_C == "" and model_D == "":
|
| 65 |
from .matchmaker import matchmaker
|
|
|
|
| 71 |
else:
|
| 72 |
model_names = [model_A, model_B, model_C, model_D]
|
| 73 |
|
|
|
|
| 74 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 75 |
futures = [executor.submit(self.generate_image_ig, prompt, model) if model.startswith("huggingface")
|
| 76 |
else executor.submit(self.generate_image_ig_api, prompt, model) for model in model_names]
|
| 77 |
results = [future.result() for future in futures]
|
| 78 |
|
|
|
|
| 79 |
return results[0], results[1], results[2], results[3], \
|
| 80 |
model_names[0], model_names[1], model_names[2], model_names[3]
|
| 81 |
|
|
|
|
| 152 |
return results[0], results[1], results[2], results[3], \
|
| 153 |
model_names[0], model_names[1], model_names[2], model_names[3], prompt
|
| 154 |
|
|
|
|
| 155 |
def generate_image_ig_parallel(self, prompt, model_A, model_B):
|
| 156 |
model_names = [model_A, model_B]
|
| 157 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
|
|
| 160 |
results = [future.result() for future in futures]
|
| 161 |
return results[0], results[1]
|
| 162 |
|
|
|
|
| 163 |
@spaces.GPU(duration=200)
|
| 164 |
def generate_image_ie(self, textbox_source, textbox_target, textbox_instruct, source_image, model_name):
|
| 165 |
pipe = self.load_model_pipe(model_name)
|
| 166 |
result = pipe(src_image = source_image, src_prompt = textbox_source, target_prompt = textbox_target, instruct_prompt = textbox_instruct)
|
| 167 |
return result
|
| 168 |
|
|
|
|
| 169 |
def generate_image_ie_parallel(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
| 170 |
model_names = [model_A, model_B]
|
| 171 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
|
|
| 175 |
results = [future.result() for future in futures]
|
| 176 |
return results[0], results[1]
|
| 177 |
|
|
|
|
| 178 |
def generate_image_ie_parallel_anony(self, textbox_source, textbox_target, textbox_instruct, source_image, model_A, model_B):
|
| 179 |
if model_A == "" and model_B == "":
|
| 180 |
model_names = random.sample([model for model in self.model_ie_list], 2)
|
model/model_registry.py
CHANGED
|
@@ -68,188 +68,3 @@ def get_video_model_description_md(model_list):
|
|
| 68 |
model_description_md += "\n"
|
| 69 |
ct += 1
|
| 70 |
return model_description_md
|
| 71 |
-
|
| 72 |
-
register_model_info(
|
| 73 |
-
["imagenhub_LCM_generation", "fal_LCM_text2image"],
|
| 74 |
-
"LCM",
|
| 75 |
-
"https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7",
|
| 76 |
-
"Latent Consistency Models.",
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
register_model_info(
|
| 80 |
-
["fal_LCM(v1.5/XL)_text2image"],
|
| 81 |
-
"LCM(v1.5/XL)",
|
| 82 |
-
"https://fal.ai/models/fast-lcm-diffusion-turbo",
|
| 83 |
-
"Latent Consistency Models (v1.5/XL)",
|
| 84 |
-
)
|
| 85 |
-
|
| 86 |
-
register_model_info(
|
| 87 |
-
["imagenhub_PlayGroundV2_generation", 'playground_PlayGroundV2_generation'],
|
| 88 |
-
"Playground v2",
|
| 89 |
-
"https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic",
|
| 90 |
-
"Playground v2 – 1024px Aesthetic Model",
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
register_model_info(
|
| 94 |
-
["imagenhub_PlayGroundV2.5_generation", 'playground_PlayGroundV2.5_generation'],
|
| 95 |
-
"Playground v2.5",
|
| 96 |
-
"https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic",
|
| 97 |
-
"Playground v2.5 is the state-of-the-art open-source model in aesthetic quality",
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
register_model_info(
|
| 101 |
-
["imagenhub_OpenJourney_generation"],
|
| 102 |
-
"Openjourney",
|
| 103 |
-
"https://huggingface.co/prompthero/openjourney",
|
| 104 |
-
"Openjourney is an open source Stable Diffusion fine tuned model on Midjourney images, by PromptHero.",
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
register_model_info(
|
| 108 |
-
["imagenhub_SDXLTurbo_generation", "fal_SDXLTurbo_text2image"],
|
| 109 |
-
"SDXLTurbo",
|
| 110 |
-
"https://huggingface.co/stabilityai/sdxl-turbo",
|
| 111 |
-
"SDXL-Turbo is a fast generative text-to-image model.",
|
| 112 |
-
)
|
| 113 |
-
|
| 114 |
-
register_model_info(
|
| 115 |
-
["imagenhub_SDXL_generation", "fal_SDXL_text2image"],
|
| 116 |
-
"SDXL",
|
| 117 |
-
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
|
| 118 |
-
"SDXL is a Latent Diffusion Model that uses two fixed, pretrained text encoders.",
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
register_model_info(
|
| 122 |
-
["imagenhub_PixArtAlpha_generation"],
|
| 123 |
-
"PixArtAlpha",
|
| 124 |
-
"https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS",
|
| 125 |
-
"Pixart-α consists of pure transformer blocks for latent diffusion.",
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
register_model_info(
|
| 129 |
-
["imagenhub_PixArtSigma_generation", "fal_PixArtSigma_text2image"],
|
| 130 |
-
"PixArtSigma",
|
| 131 |
-
"https://github.com/PixArt-alpha/PixArt-sigma",
|
| 132 |
-
"Improved version of Pixart-α.",
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
register_model_info(
|
| 136 |
-
["imagenhub_SDXLLightning_generation", "fal_SDXLLightning_text2image"],
|
| 137 |
-
"SDXL-Lightning",
|
| 138 |
-
"https://huggingface.co/ByteDance/SDXL-Lightning",
|
| 139 |
-
"SDXL-Lightning is a lightning-fast text-to-image generation model.",
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
register_model_info(
|
| 143 |
-
["imagenhub_StableCascade_generation", "fal_StableCascade_text2image"],
|
| 144 |
-
"StableCascade",
|
| 145 |
-
"https://huggingface.co/stabilityai/stable-cascade",
|
| 146 |
-
"StableCascade is built upon the Würstchen architecture and working at a much smaller latent space.",
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
# regist image edition models
|
| 150 |
-
register_model_info(
|
| 151 |
-
["imagenhub_CycleDiffusion_edition"],
|
| 152 |
-
"CycleDiffusion",
|
| 153 |
-
"https://github.com/ChenWu98/cycle-diffusion?tab=readme-ov-file",
|
| 154 |
-
"A latent space for stochastic diffusion models.",
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
register_model_info(
|
| 158 |
-
["imagenhub_Pix2PixZero_edition"],
|
| 159 |
-
"Pix2PixZero",
|
| 160 |
-
"https://pix2pixzero.github.io/",
|
| 161 |
-
"A zero-shot Image-to-Image translation model.",
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
register_model_info(
|
| 165 |
-
["imagenhub_Prompt2prompt_edition"],
|
| 166 |
-
"Prompt2prompt",
|
| 167 |
-
"https://prompt-to-prompt.github.io/",
|
| 168 |
-
"Image Editing with Cross-Attention Control.",
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
register_model_info(
|
| 173 |
-
["imagenhub_InstructPix2Pix_edition"],
|
| 174 |
-
"InstructPix2Pix",
|
| 175 |
-
"https://www.timothybrooks.com/instruct-pix2pix",
|
| 176 |
-
"An instruction-based image editing model.",
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
register_model_info(
|
| 180 |
-
["imagenhub_MagicBrush_edition"],
|
| 181 |
-
"MagicBrush",
|
| 182 |
-
"https://osu-nlp-group.github.io/MagicBrush/",
|
| 183 |
-
"Manually Annotated Dataset for Instruction-Guided Image Editing.",
|
| 184 |
-
)
|
| 185 |
-
|
| 186 |
-
register_model_info(
|
| 187 |
-
["imagenhub_PNP_edition"],
|
| 188 |
-
"PNP",
|
| 189 |
-
"https://github.com/MichalGeyer/plug-and-play",
|
| 190 |
-
"Plug-and-Play Diffusion Features for Text-Driven Image-to-Image Translation.",
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
register_model_info(
|
| 194 |
-
["imagenhub_InfEdit_edition"],
|
| 195 |
-
"InfEdit",
|
| 196 |
-
"https://sled-group.github.io/InfEdit/",
|
| 197 |
-
"Inversion-Free Image Editing with Natural Language.",
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
register_model_info(
|
| 201 |
-
["imagenhub_CosXLEdit_edition"],
|
| 202 |
-
"CosXLEdit",
|
| 203 |
-
"https://huggingface.co/stabilityai/cosxl",
|
| 204 |
-
"An instruction-based image editing model from SDXL.",
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
register_model_info(
|
| 208 |
-
["fal_stable-cascade_text2image"],
|
| 209 |
-
"StableCascade",
|
| 210 |
-
"https://fal.ai/models/stable-cascade/api",
|
| 211 |
-
"StableCascade is a generative model that can generate high-quality images from text prompts.",
|
| 212 |
-
)
|
| 213 |
-
|
| 214 |
-
register_model_info(
|
| 215 |
-
["fal_AnimateDiff_text2video"],
|
| 216 |
-
"AnimateDiff",
|
| 217 |
-
"https://fal.ai/models/fast-animatediff-t2v",
|
| 218 |
-
"AnimateDiff is a text-driven models that produce diverse and personalized animated images.",
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
register_model_info(
|
| 222 |
-
["fal_AnimateDiffTurbo_text2video"],
|
| 223 |
-
"AnimateDiff Turbo",
|
| 224 |
-
"https://fal.ai/models/fast-animatediff-t2v-turbo",
|
| 225 |
-
"AnimateDiff Turbo is a lightning version of AnimateDiff.",
|
| 226 |
-
)
|
| 227 |
-
|
| 228 |
-
register_model_info(
|
| 229 |
-
["videogenhub_LaVie_generation"],
|
| 230 |
-
"LaVie",
|
| 231 |
-
"https://github.com/Vchitect/LaVie",
|
| 232 |
-
"LaVie is a video generation model with cascaded latent diffusion models.",
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
register_model_info(
|
| 236 |
-
["videogenhub_VideoCrafter2_generation"],
|
| 237 |
-
"VideoCrafter2",
|
| 238 |
-
"https://ailab-cvc.github.io/videocrafter2/",
|
| 239 |
-
"VideoCrafter2 is a T2V model that disentangling motion from appearance.",
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
register_model_info(
|
| 243 |
-
["videogenhub_ModelScope_generation"],
|
| 244 |
-
"ModelScope",
|
| 245 |
-
"https://arxiv.org/abs/2308.06571",
|
| 246 |
-
"ModelScope is a a T2V synthesis model that evolves from a T2I synthesis model.",
|
| 247 |
-
)
|
| 248 |
-
|
| 249 |
-
register_model_info(
|
| 250 |
-
["videogenhub_OpenSora_generation"],
|
| 251 |
-
"OpenSora",
|
| 252 |
-
"https://github.com/hpcaitech/Open-Sora",
|
| 253 |
-
"A community-driven opensource implementation of Sora.",
|
| 254 |
-
)
|
| 255 |
-
|
|
|
|
| 68 |
model_description_md += "\n"
|
| 69 |
ct += 1
|
| 70 |
return model_description_md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/models/__init__.py
CHANGED
|
@@ -37,15 +37,6 @@ IMAGE_GENERATION_MODELS = [
|
|
| 37 |
"replicate_FLUX.1-dev_text2image",
|
| 38 |
]
|
| 39 |
|
| 40 |
-
|
| 41 |
-
IMAGE_EDITION_MODELS = ['imagenhub_CycleDiffusion_edition', 'imagenhub_Pix2PixZero_edition', 'imagenhub_Prompt2prompt_edition',
|
| 42 |
-
'imagenhub_SDEdit_edition', 'imagenhub_InstructPix2Pix_edition',
|
| 43 |
-
'imagenhub_MagicBrush_edition', 'imagenhub_PNP_edition',
|
| 44 |
-
'imagenhub_InfEdit_edition', 'imagenhub_CosXLEdit_edition']
|
| 45 |
-
# VIDEO_GENERATION_MODELS = ['fal_AnimateDiff_text2video',
|
| 46 |
-
# 'fal_AnimateDiffTurbo_text2video',
|
| 47 |
-
# 'videogenhub_LaVie_generation', 'videogenhub_VideoCrafter2_generation',
|
| 48 |
-
# 'videogenhub_ModelScope_generation', 'videogenhub_OpenSora_generation']
|
| 49 |
VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
|
| 50 |
'replicate_Animate-Diff_text2video',
|
| 51 |
'replicate_OpenSora_text2video',
|
|
@@ -59,22 +50,15 @@ VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
|
|
| 59 |
'other_Sora_text2video',
|
| 60 |
]
|
| 61 |
|
|
|
|
| 62 |
def load_pipeline(model_name):
|
| 63 |
"""
|
| 64 |
Load a model pipeline based on the model name
|
| 65 |
Args:
|
| 66 |
model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
|
| 67 |
-
the source can be either imagenhub or playground
|
| 68 |
-
the name is the name of the model used to load the model
|
| 69 |
-
the type is the type of the model, either generation or edition
|
| 70 |
"""
|
| 71 |
model_source, model_name, model_type = model_name.split("_")
|
| 72 |
-
|
| 73 |
-
# pipe = load_imagenhub_model(model_name, model_type)
|
| 74 |
-
# elif model_source == "fal":
|
| 75 |
-
# pipe = load_fal_model(model_name, model_type)
|
| 76 |
-
# elif model_source == "videogenhub":
|
| 77 |
-
# pipe = load_videogenhub_model(model_name)
|
| 78 |
if model_source == "replicate":
|
| 79 |
pipe = load_replicate_model(model_name, model_type)
|
| 80 |
elif model_source == "huggingface":
|
|
|
|
| 37 |
"replicate_FLUX.1-dev_text2image",
|
| 38 |
]
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
VIDEO_GENERATION_MODELS = ['replicate_Zeroscope-v2-xl_text2video',
|
| 41 |
'replicate_Animate-Diff_text2video',
|
| 42 |
'replicate_OpenSora_text2video',
|
|
|
|
| 50 |
'other_Sora_text2video',
|
| 51 |
]
|
| 52 |
|
| 53 |
+
|
| 54 |
def load_pipeline(model_name):
|
| 55 |
"""
|
| 56 |
Load a model pipeline based on the model name
|
| 57 |
Args:
|
| 58 |
model_name (str): The name of the model to load, should be of the form {source}_{name}_{type}
|
|
|
|
|
|
|
|
|
|
| 59 |
"""
|
| 60 |
model_source, model_name, model_type = model_name.split("_")
|
| 61 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
if model_source == "replicate":
|
| 63 |
pipe = load_replicate_model(model_name, model_type)
|
| 64 |
elif model_source == "huggingface":
|
model/models/openai_api_models.py
CHANGED
|
@@ -12,7 +12,6 @@ class OpenaiModel():
|
|
| 12 |
self.model_type = model_type
|
| 13 |
|
| 14 |
def __call__(self, *args, **kwargs):
|
| 15 |
-
|
| 16 |
if self.model_type == "text2image":
|
| 17 |
assert "prompt" in kwargs, "prompt is required for text2image model"
|
| 18 |
|
|
@@ -47,7 +46,6 @@ class OpenaiModel():
|
|
| 47 |
raise ValueError("model_type must be text2image or image2image")
|
| 48 |
|
| 49 |
|
| 50 |
-
|
| 51 |
def load_openai_model(model_name, model_type):
|
| 52 |
return OpenaiModel(model_name, model_type)
|
| 53 |
|
|
|
|
| 12 |
self.model_type = model_type
|
| 13 |
|
| 14 |
def __call__(self, *args, **kwargs):
|
|
|
|
| 15 |
if self.model_type == "text2image":
|
| 16 |
assert "prompt" in kwargs, "prompt is required for text2image model"
|
| 17 |
|
|
|
|
| 46 |
raise ValueError("model_type must be text2image or image2image")
|
| 47 |
|
| 48 |
|
|
|
|
| 49 |
def load_openai_model(model_name, model_type):
|
| 50 |
return OpenaiModel(model_name, model_type)
|
| 51 |
|
model/models/other_api_models.py
CHANGED
|
@@ -4,6 +4,7 @@ import os
|
|
| 4 |
from PIL import Image
|
| 5 |
import io, time
|
| 6 |
|
|
|
|
| 7 |
class OtherModel():
|
| 8 |
def __init__(self, model_name, model_type):
|
| 9 |
self.model_name = model_name
|
|
@@ -75,6 +76,8 @@ class OtherModel():
|
|
| 75 |
|
| 76 |
else:
|
| 77 |
raise ValueError("model_type must be text2image")
|
|
|
|
|
|
|
| 78 |
def load_other_model(model_name, model_type):
|
| 79 |
return OtherModel(model_name, model_type)
|
| 80 |
|
|
@@ -86,30 +89,3 @@ if __name__ == "__main__":
|
|
| 86 |
result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow ")
|
| 87 |
print(result)
|
| 88 |
exit()
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# key = os.environ.get('MIDJOURNEY_KEY')
|
| 92 |
-
# prompt = "a good girl"
|
| 93 |
-
|
| 94 |
-
# conn = http.client.HTTPSConnection("xdai.online")
|
| 95 |
-
# payload = json.dumps({
|
| 96 |
-
# "messages": [
|
| 97 |
-
# {
|
| 98 |
-
# "role": "user",
|
| 99 |
-
# "content": "{}".format(prompt)
|
| 100 |
-
# }
|
| 101 |
-
# ],
|
| 102 |
-
# "stream": True,
|
| 103 |
-
# "model": "luma-video",
|
| 104 |
-
# # "model": "pika-text-to-video",
|
| 105 |
-
# })
|
| 106 |
-
# headers = {
|
| 107 |
-
# 'Authorization': "Bearer {}".format(key),
|
| 108 |
-
# 'Content-Type': 'application/json'
|
| 109 |
-
# }
|
| 110 |
-
# conn.request("POST", "/v1/chat/completions", payload, headers)
|
| 111 |
-
# res = conn.getresponse()
|
| 112 |
-
# data = res.read()
|
| 113 |
-
# info = data.decode("utf-8")
|
| 114 |
-
# print(data.decode("utf-8"))
|
| 115 |
-
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
import io, time
|
| 6 |
|
| 7 |
+
|
| 8 |
class OtherModel():
|
| 9 |
def __init__(self, model_name, model_type):
|
| 10 |
self.model_name = model_name
|
|
|
|
| 76 |
|
| 77 |
else:
|
| 78 |
raise ValueError("model_type must be text2image")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
def load_other_model(model_name, model_type):
|
| 82 |
return OtherModel(model_name, model_type)
|
| 83 |
|
|
|
|
| 89 |
result = pipe(prompt="An Impressionist illustration depicts a river winding through a meadow ")
|
| 90 |
print(result)
|
| 91 |
exit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/models/replicate_api_models.py
CHANGED
|
@@ -40,11 +40,11 @@ Replicate_MODEl_NAME_MAP = {
|
|
| 40 |
"FLUX.1-dev": "black-forest-labs/flux-dev",
|
| 41 |
}
|
| 42 |
|
|
|
|
| 43 |
class ReplicateModel():
|
| 44 |
def __init__(self, model_name, model_type):
|
| 45 |
self.model_name = model_name
|
| 46 |
self.model_type = model_type
|
| 47 |
-
# os.environ['FAL_KEY'] = os.environ['FalAPI']
|
| 48 |
|
| 49 |
def __call__(self, *args, **kwargs):
|
| 50 |
if self.model_type == "text2image":
|
|
@@ -179,155 +179,14 @@ class ReplicateModel():
|
|
| 179 |
else:
|
| 180 |
raise ValueError("model_type must be text2image or image2image")
|
| 181 |
|
|
|
|
| 182 |
def load_replicate_model(model_name, model_type):
|
| 183 |
return ReplicateModel(model_name, model_type)
|
| 184 |
|
| 185 |
|
| 186 |
if __name__ == "__main__":
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
from moviepy.editor import VideoFileClip
|
| 193 |
-
|
| 194 |
-
# model_name = 'replicate_zeroscope-v2-xl_text2video'
|
| 195 |
-
# model_name = 'replicate_Damo-Text-to-Video_text2video'
|
| 196 |
-
# model_name = 'replicate_Animate-Diff_text2video'
|
| 197 |
-
# model_name = 'replicate_open-sora_text2video'
|
| 198 |
-
# model_name = 'replicate_lavie_text2video'
|
| 199 |
-
# model_name = 'replicate_video-crafter_text2video'
|
| 200 |
-
# model_name = 'replicate_stable-video-diffusion_text2video'
|
| 201 |
-
# model_source, model_name, model_type = model_name.split("_")
|
| 202 |
-
# pipe = load_replicate_model(model_name, model_type)
|
| 203 |
-
# prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
|
| 204 |
-
# result = pipe(prompt=prompt)
|
| 205 |
-
|
| 206 |
-
# # 文件复制
|
| 207 |
-
source_folder = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
|
| 208 |
-
destination_folder = '/mnt/data/lizhikai/ksort_video_cache/Advance/'
|
| 209 |
-
|
| 210 |
-
special_char = 'output'
|
| 211 |
-
for dirpath, dirnames, filenames in os.walk(source_folder):
|
| 212 |
-
for dirname in dirnames:
|
| 213 |
-
des_dirname = "output-"+dirname[-3:]
|
| 214 |
-
print(des_dirname)
|
| 215 |
-
if special_char in dirname:
|
| 216 |
-
model_name = ["Pika-v1.0"]
|
| 217 |
-
for name in model_name:
|
| 218 |
-
source_file_path = os.path.join(source_folder, os.path.join(dirname, name+".mp4"))
|
| 219 |
-
print(source_file_path)
|
| 220 |
-
destination_file_path = os.path.join(destination_folder, os.path.join(des_dirname, name+".mp4"))
|
| 221 |
-
print(destination_file_path)
|
| 222 |
-
shutil.copy(source_file_path, destination_file_path)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
# 视频裁剪
|
| 226 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
|
| 227 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
|
| 228 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
|
| 229 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
|
| 230 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
|
| 231 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
|
| 232 |
-
# special_char = 'output'
|
| 233 |
-
# num = 0
|
| 234 |
-
# for dirpath, dirnames, filenames in os.walk(root_dir):
|
| 235 |
-
# for dirname in dirnames:
|
| 236 |
-
# # 如果文件夹名称中包含指定的特殊字符
|
| 237 |
-
# if special_char in dirname:
|
| 238 |
-
# num = num+1
|
| 239 |
-
# print(num)
|
| 240 |
-
# if num < 0:
|
| 241 |
-
# continue
|
| 242 |
-
# video_path = os.path.join(root_dir, (os.path.join(dirname, f"{dirname}.mp4")))
|
| 243 |
-
# out_video_path = os.path.join(root_dir, (os.path.join(dirname, f"Pika-v1.0.mp4")))
|
| 244 |
-
# print(video_path)
|
| 245 |
-
# print(out_video_path)
|
| 246 |
-
|
| 247 |
-
# video = VideoFileClip(video_path)
|
| 248 |
-
# width, height = video.size
|
| 249 |
-
# center_x, center_y = width // 2, height // 2
|
| 250 |
-
# new_width, new_height = 512, 512
|
| 251 |
-
# cropped_video = video.crop(x_center=center_x, y_center=center_y, width=min(width, height), height=min(width, height))
|
| 252 |
-
# resized_video = cropped_video.resize(newsize=(new_width, new_height))
|
| 253 |
-
# resized_video.write_videofile(out_video_path, codec='libx264', fps=video.fps)
|
| 254 |
-
# os.remove(video_path)
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
# file_path = '/home/lizhikai/webvid_prompt100.txt'
|
| 259 |
-
# str_list = []
|
| 260 |
-
# with open(file_path, 'r', encoding='utf-8') as file:
|
| 261 |
-
# for line in file:
|
| 262 |
-
# str_list.append(line.strip())
|
| 263 |
-
# if len(str_list) == 100:
|
| 264 |
-
# break
|
| 265 |
-
|
| 266 |
-
# 生成代码
|
| 267 |
-
# def generate_image_ig_api(prompt, model_name):
|
| 268 |
-
# model_source, model_name, model_type = model_name.split("_")
|
| 269 |
-
# pipe = load_replicate_model(model_name, model_type)
|
| 270 |
-
# result = pipe(prompt=prompt)
|
| 271 |
-
# return result
|
| 272 |
-
# model_names = ['replicate_Zeroscope-v2-xl_text2video',
|
| 273 |
-
# # 'replicate_Damo-Text-to-Video_text2video',
|
| 274 |
-
# 'replicate_Animate-Diff_text2video',
|
| 275 |
-
# 'replicate_OpenSora_text2video',
|
| 276 |
-
# 'replicate_LaVie_text2video',
|
| 277 |
-
# 'replicate_VideoCrafter2_text2video',
|
| 278 |
-
# 'replicate_Stable-Video-Diffusion_text2video',
|
| 279 |
-
# ]
|
| 280 |
-
# save_names = []
|
| 281 |
-
# for name in model_names:
|
| 282 |
-
# model_source, model_name, model_type = name.split("_")
|
| 283 |
-
# save_names.append(model_name)
|
| 284 |
-
|
| 285 |
-
# # 遍历根目录及其子目录
|
| 286 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
|
| 287 |
-
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
|
| 288 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
|
| 289 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
|
| 290 |
-
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
|
| 291 |
-
# special_char = 'output'
|
| 292 |
-
# num = 0
|
| 293 |
-
# for dirpath, dirnames, filenames in os.walk(root_dir):
|
| 294 |
-
# for dirname in dirnames:
|
| 295 |
-
# # 如果文件夹名称中包含指定的特殊字符
|
| 296 |
-
# if special_char in dirname:
|
| 297 |
-
# num = num+1
|
| 298 |
-
# print(num)
|
| 299 |
-
# if num < 0:
|
| 300 |
-
# continue
|
| 301 |
-
# str_list = []
|
| 302 |
-
# prompt_path = os.path.join(root_dir, (os.path.join(dirname, "prompt.txt")))
|
| 303 |
-
# print(prompt_path)
|
| 304 |
-
# with open(prompt_path, 'r', encoding='utf-8') as file:
|
| 305 |
-
# for line in file:
|
| 306 |
-
# str_list.append(line.strip())
|
| 307 |
-
# prompt = str_list[0]
|
| 308 |
-
# print(prompt)
|
| 309 |
-
|
| 310 |
-
# with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 311 |
-
# futures = [executor.submit(generate_image_ig_api, prompt, model) for model in model_names]
|
| 312 |
-
# results = [future.result() for future in futures]
|
| 313 |
-
|
| 314 |
-
# # 下载视频并保存
|
| 315 |
-
# repeat_num = 5
|
| 316 |
-
# for j, url in enumerate(results):
|
| 317 |
-
# while 1:
|
| 318 |
-
# time.sleep(1)
|
| 319 |
-
# response = requests.get(url, stream=True)
|
| 320 |
-
# if response.status_code == 200:
|
| 321 |
-
# file_path = os.path.join(os.path.join(root_dir, dirname), f'{save_names[j]}.mp4')
|
| 322 |
-
# with open(file_path, 'wb') as file:
|
| 323 |
-
# for chunk in response.iter_content(chunk_size=8192):
|
| 324 |
-
# file.write(chunk)
|
| 325 |
-
# print(f"视频 {j} 已保存到 {file_path}")
|
| 326 |
-
# break
|
| 327 |
-
# else:
|
| 328 |
-
# repeat_num = repeat_num - 1
|
| 329 |
-
# if repeat_num == 0:
|
| 330 |
-
# print(f"视频 {j} 保存失败")
|
| 331 |
-
# # raise ValueError("Video request failed.")
|
| 332 |
-
# continue
|
| 333 |
-
|
|
|
|
| 40 |
"FLUX.1-dev": "black-forest-labs/flux-dev",
|
| 41 |
}
|
| 42 |
|
| 43 |
+
|
| 44 |
class ReplicateModel():
|
| 45 |
def __init__(self, model_name, model_type):
|
| 46 |
self.model_name = model_name
|
| 47 |
self.model_type = model_type
|
|
|
|
| 48 |
|
| 49 |
def __call__(self, *args, **kwargs):
|
| 50 |
if self.model_type == "text2image":
|
|
|
|
| 179 |
else:
|
| 180 |
raise ValueError("model_type must be text2image or image2image")
|
| 181 |
|
| 182 |
+
|
| 183 |
def load_replicate_model(model_name, model_type):
|
| 184 |
return ReplicateModel(model_name, model_type)
|
| 185 |
|
| 186 |
|
| 187 |
if __name__ == "__main__":
|
| 188 |
+
model_name = 'replicate_zeroscope-v2-xl_text2video'
|
| 189 |
+
model_source, model_name, model_type = model_name.split("_")
|
| 190 |
+
pipe = load_replicate_model(model_name, model_type)
|
| 191 |
+
prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
|
| 192 |
+
result = pipe(prompt=prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
serve/Ksort.py
CHANGED
|
@@ -8,6 +8,7 @@ from .utils import disable_btn, enable_btn, invisible_btn
|
|
| 8 |
from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
|
| 9 |
import json
|
| 10 |
|
|
|
|
| 11 |
def reset_level(Top_btn):
|
| 12 |
if Top_btn == "Top 1":
|
| 13 |
level = 0
|
|
@@ -19,6 +20,7 @@ def reset_level(Top_btn):
|
|
| 19 |
level = 3
|
| 20 |
return level
|
| 21 |
|
|
|
|
| 22 |
def reset_rank(windows, rank, vote_level):
|
| 23 |
if windows == "Model A":
|
| 24 |
rank[0] = vote_level
|
|
@@ -30,6 +32,7 @@ def reset_rank(windows, rank, vote_level):
|
|
| 30 |
rank[3] = vote_level
|
| 31 |
return rank
|
| 32 |
|
|
|
|
| 33 |
def reset_btn_rank(windows, rank, btn, vote_level):
|
| 34 |
if windows == "Model A" and btn == "1":
|
| 35 |
rank[0] = 0
|
|
@@ -73,6 +76,7 @@ def reset_btn_rank(windows, rank, btn, vote_level):
|
|
| 73 |
vote_level = 3
|
| 74 |
return (rank, vote_level)
|
| 75 |
|
|
|
|
| 76 |
def reset_vote_text(rank):
|
| 77 |
rank_str = ""
|
| 78 |
for i in range(len(rank)):
|
|
@@ -83,24 +87,28 @@ def reset_vote_text(rank):
|
|
| 83 |
rank_str = rank_str + " "
|
| 84 |
return rank_str
|
| 85 |
|
|
|
|
| 86 |
def clear_rank(rank, vote_level):
|
| 87 |
for i in range(len(rank)):
|
| 88 |
rank[i] = None
|
| 89 |
vote_level = 0
|
| 90 |
return rank, vote_level
|
| 91 |
|
|
|
|
| 92 |
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
|
| 93 |
for i in range(len(rank)):
|
| 94 |
rank[i] = None
|
| 95 |
vote_level = 0
|
| 96 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level
|
| 97 |
|
|
|
|
| 98 |
def reset_submit(rank):
|
| 99 |
for i in range(len(rank)):
|
| 100 |
if rank[i] == None:
|
| 101 |
return disable_btn
|
| 102 |
return enable_btn
|
| 103 |
|
|
|
|
| 104 |
def reset_mode(mode):
|
| 105 |
|
| 106 |
if mode == "Best":
|
|
@@ -116,8 +124,12 @@ def reset_mode(mode):
|
|
| 116 |
(gr.Textbox(value="Best", visible=False, interactive=False),)
|
| 117 |
else:
|
| 118 |
raise ValueError("Undefined mode")
|
|
|
|
|
|
|
| 119 |
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
|
| 120 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3
|
|
|
|
|
|
|
| 121 |
def get_json_filename(conv_id):
|
| 122 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
|
| 123 |
if not os.path.exists(output_dir):
|
|
@@ -127,6 +139,7 @@ def get_json_filename(conv_id):
|
|
| 127 |
print(output_file)
|
| 128 |
return output_file
|
| 129 |
|
|
|
|
| 130 |
def get_img_filename(conv_id, i):
|
| 131 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
|
| 132 |
if not os.path.exists(output_dir):
|
|
@@ -135,6 +148,7 @@ def get_img_filename(conv_id, i):
|
|
| 135 |
print(output_file)
|
| 136 |
return output_file
|
| 137 |
|
|
|
|
| 138 |
def vote_submit(states, textbox, rank, request: gr.Request):
|
| 139 |
conv_id = states[0].conv_id
|
| 140 |
|
|
@@ -149,6 +163,7 @@ def vote_submit(states, textbox, rank, request: gr.Request):
|
|
| 149 |
}
|
| 150 |
fout.write(json.dumps(data) + "\n")
|
| 151 |
|
|
|
|
| 152 |
def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
|
| 153 |
conv_id = states[0].conv_id
|
| 154 |
output_dir = create_remote_directory(conv_id)
|
|
@@ -167,6 +182,7 @@ def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
|
|
| 167 |
from .update_skill import update_skill
|
| 168 |
update_skill(rank, [x.model_name for x in states])
|
| 169 |
|
|
|
|
| 170 |
def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
|
| 171 |
conv_id = states[0].conv_id
|
| 172 |
output_dir = create_remote_directory(conv_id, video=True)
|
|
@@ -186,6 +202,7 @@ def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_in
|
|
| 186 |
from .update_skill_video import update_skill_video
|
| 187 |
update_skill_video(rank, [x.model_name for x in states])
|
| 188 |
|
|
|
|
| 189 |
def submit_response_igm(
|
| 190 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
|
| 191 |
):
|
|
@@ -205,6 +222,8 @@ def submit_response_igm(
|
|
| 205 |
gr.Markdown(state2.model_name, visible=True),
|
| 206 |
gr.Markdown(state3.model_name, visible=True)
|
| 207 |
) + (disable_btn,)
|
|
|
|
|
|
|
| 208 |
def submit_response_vg(
|
| 209 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
|
| 210 |
):
|
|
@@ -223,6 +242,8 @@ def submit_response_vg(
|
|
| 223 |
gr.Markdown(state2.model_name, visible=True),
|
| 224 |
gr.Markdown(state3.model_name, visible=True)
|
| 225 |
) + (disable_btn,)
|
|
|
|
|
|
|
| 226 |
def submit_response_rank_igm(
|
| 227 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
| 228 |
):
|
|
@@ -246,6 +267,8 @@ def submit_response_rank_igm(
|
|
| 246 |
)
|
| 247 |
else:
|
| 248 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
|
|
|
|
|
|
| 249 |
def submit_response_rank_vg(
|
| 250 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
| 251 |
):
|
|
@@ -269,6 +292,7 @@ def submit_response_rank_vg(
|
|
| 269 |
else:
|
| 270 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
| 271 |
|
|
|
|
| 272 |
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
|
| 273 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
| 274 |
generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
|
|
@@ -318,6 +342,7 @@ def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig
|
|
| 318 |
|
| 319 |
return chatbot + [rank_str] + ["right"] + [rank]
|
| 320 |
|
|
|
|
| 321 |
def text_response_rank_vg(vote_textbox):
|
| 322 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
| 323 |
rank = [None, None, None, None]
|
|
@@ -336,6 +361,7 @@ def text_response_rank_vg(vote_textbox):
|
|
| 336 |
|
| 337 |
return [rank_str] + ["right"] + [rank]
|
| 338 |
|
|
|
|
| 339 |
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
|
| 340 |
base_image = Image.fromarray(image).convert("RGBA")
|
| 341 |
base_image = base_image.resize((512, 512), Image.ANTIALIAS)
|
|
@@ -369,12 +395,15 @@ def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text
|
|
| 369 |
|
| 370 |
base_image = base_image.convert("RGB")
|
| 371 |
return base_image
|
|
|
|
|
|
|
| 372 |
def add_green_border(image):
|
| 373 |
border_color = (0, 255, 0) # RGB for green
|
| 374 |
border_size = 10 # Size of the border
|
| 375 |
img_with_border = ImageOps.expand(image, border=border_size, fill=border_color)
|
| 376 |
return img_with_border
|
| 377 |
|
|
|
|
| 378 |
def check_textbox(textbox):
|
| 379 |
if textbox=="":
|
| 380 |
return False
|
|
|
|
| 8 |
from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data
|
| 9 |
import json
|
| 10 |
|
| 11 |
+
|
| 12 |
def reset_level(Top_btn):
|
| 13 |
if Top_btn == "Top 1":
|
| 14 |
level = 0
|
|
|
|
| 20 |
level = 3
|
| 21 |
return level
|
| 22 |
|
| 23 |
+
|
| 24 |
def reset_rank(windows, rank, vote_level):
|
| 25 |
if windows == "Model A":
|
| 26 |
rank[0] = vote_level
|
|
|
|
| 32 |
rank[3] = vote_level
|
| 33 |
return rank
|
| 34 |
|
| 35 |
+
|
| 36 |
def reset_btn_rank(windows, rank, btn, vote_level):
|
| 37 |
if windows == "Model A" and btn == "1":
|
| 38 |
rank[0] = 0
|
|
|
|
| 76 |
vote_level = 3
|
| 77 |
return (rank, vote_level)
|
| 78 |
|
| 79 |
+
|
| 80 |
def reset_vote_text(rank):
|
| 81 |
rank_str = ""
|
| 82 |
for i in range(len(rank)):
|
|
|
|
| 87 |
rank_str = rank_str + " "
|
| 88 |
return rank_str
|
| 89 |
|
| 90 |
+
|
| 91 |
def clear_rank(rank, vote_level):
|
| 92 |
for i in range(len(rank)):
|
| 93 |
rank[i] = None
|
| 94 |
vote_level = 0
|
| 95 |
return rank, vote_level
|
| 96 |
|
| 97 |
+
|
| 98 |
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
|
| 99 |
for i in range(len(rank)):
|
| 100 |
rank[i] = None
|
| 101 |
vote_level = 0
|
| 102 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level
|
| 103 |
|
| 104 |
+
|
| 105 |
def reset_submit(rank):
|
| 106 |
for i in range(len(rank)):
|
| 107 |
if rank[i] == None:
|
| 108 |
return disable_btn
|
| 109 |
return enable_btn
|
| 110 |
|
| 111 |
+
|
| 112 |
def reset_mode(mode):
|
| 113 |
|
| 114 |
if mode == "Best":
|
|
|
|
| 124 |
(gr.Textbox(value="Best", visible=False, interactive=False),)
|
| 125 |
else:
|
| 126 |
raise ValueError("Undefined mode")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
|
| 130 |
return generate_ig0, generate_ig1, generate_ig2, generate_ig3
|
| 131 |
+
|
| 132 |
+
|
| 133 |
def get_json_filename(conv_id):
|
| 134 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
|
| 135 |
if not os.path.exists(output_dir):
|
|
|
|
| 139 |
print(output_file)
|
| 140 |
return output_file
|
| 141 |
|
| 142 |
+
|
| 143 |
def get_img_filename(conv_id, i):
|
| 144 |
output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
|
| 145 |
if not os.path.exists(output_dir):
|
|
|
|
| 148 |
print(output_file)
|
| 149 |
return output_file
|
| 150 |
|
| 151 |
+
|
| 152 |
def vote_submit(states, textbox, rank, request: gr.Request):
|
| 153 |
conv_id = states[0].conv_id
|
| 154 |
|
|
|
|
| 163 |
}
|
| 164 |
fout.write(json.dumps(data) + "\n")
|
| 165 |
|
| 166 |
+
|
| 167 |
def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
|
| 168 |
conv_id = states[0].conv_id
|
| 169 |
output_dir = create_remote_directory(conv_id)
|
|
|
|
| 182 |
from .update_skill import update_skill
|
| 183 |
update_skill(rank, [x.model_name for x in states])
|
| 184 |
|
| 185 |
+
|
| 186 |
def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
|
| 187 |
conv_id = states[0].conv_id
|
| 188 |
output_dir = create_remote_directory(conv_id, video=True)
|
|
|
|
| 202 |
from .update_skill_video import update_skill_video
|
| 203 |
update_skill_video(rank, [x.model_name for x in states])
|
| 204 |
|
| 205 |
+
|
| 206 |
def submit_response_igm(
|
| 207 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, user_name, user_institution, request: gr.Request
|
| 208 |
):
|
|
|
|
| 222 |
gr.Markdown(state2.model_name, visible=True),
|
| 223 |
gr.Markdown(state3.model_name, visible=True)
|
| 224 |
) + (disable_btn,)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
def submit_response_vg(
|
| 228 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
|
| 229 |
):
|
|
|
|
| 242 |
gr.Markdown(state2.model_name, visible=True),
|
| 243 |
gr.Markdown(state3.model_name, visible=True)
|
| 244 |
) + (disable_btn,)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
def submit_response_rank_igm(
|
| 248 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
| 249 |
):
|
|
|
|
| 267 |
)
|
| 268 |
else:
|
| 269 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
| 270 |
+
|
| 271 |
+
|
| 272 |
def submit_response_rank_vg(
|
| 273 |
state0, state1, state2, state3, model_selector0, model_selector1, model_selector2, model_selector3, textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
|
| 274 |
):
|
|
|
|
| 292 |
else:
|
| 293 |
return (enable_btn,) * 16 + (enable_btn,) * 3 + ("wrong",) + (gr.Markdown("", visible=False),) * 4
|
| 294 |
|
| 295 |
+
|
| 296 |
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
|
| 297 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
| 298 |
generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
|
|
|
|
| 342 |
|
| 343 |
return chatbot + [rank_str] + ["right"] + [rank]
|
| 344 |
|
| 345 |
+
|
| 346 |
def text_response_rank_vg(vote_textbox):
|
| 347 |
rank_list = [char for char in vote_textbox if char.isdigit()]
|
| 348 |
rank = [None, None, None, None]
|
|
|
|
| 361 |
|
| 362 |
return [rank_str] + ["right"] + [rank]
|
| 363 |
|
| 364 |
+
|
| 365 |
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
|
| 366 |
base_image = Image.fromarray(image).convert("RGBA")
|
| 367 |
base_image = base_image.resize((512, 512), Image.ANTIALIAS)
|
|
|
|
| 395 |
|
| 396 |
base_image = base_image.convert("RGB")
|
| 397 |
return base_image
|
| 398 |
+
|
| 399 |
+
|
| 400 |
def add_green_border(image):
|
| 401 |
border_color = (0, 255, 0) # RGB for green
|
| 402 |
border_size = 10 # Size of the border
|
| 403 |
img_with_border = ImageOps.expand(image, border=border_size, fill=border_color)
|
| 404 |
return img_with_border
|
| 405 |
|
| 406 |
+
|
| 407 |
def check_textbox(textbox):
|
| 408 |
if textbox=="":
|
| 409 |
return False
|
serve/leaderboard.py
CHANGED
|
@@ -40,12 +40,14 @@ def make_leaderboard_md():
|
|
| 40 |
"""
|
| 41 |
return leaderboard_md
|
| 42 |
|
|
|
|
| 43 |
def make_leaderboard_video_md():
|
| 44 |
leaderboard_md = f"""
|
| 45 |
# 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
|
| 46 |
"""
|
| 47 |
return leaderboard_md
|
| 48 |
|
|
|
|
| 49 |
def model_hyperlink(model_name, link):
|
| 50 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
| 51 |
|
|
@@ -89,11 +91,13 @@ def make_disclaimer_md():
|
|
| 89 |
'''
|
| 90 |
return disclaimer_md
|
| 91 |
|
|
|
|
| 92 |
def make_arena_leaderboard_data(results):
|
| 93 |
import pandas as pd
|
| 94 |
df = pd.DataFrame(results)
|
| 95 |
return df
|
| 96 |
|
|
|
|
| 97 |
def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
|
| 98 |
with open(score_result_file, "r") as json_file:
|
| 99 |
data = json.load(json_file)
|
|
|
|
| 40 |
"""
|
| 41 |
return leaderboard_md
|
| 42 |
|
| 43 |
+
|
| 44 |
def make_leaderboard_video_md():
|
| 45 |
leaderboard_md = f"""
|
| 46 |
# 🏆 K-Sort Arena Leaderboard (Text-to-Video Generation)
|
| 47 |
"""
|
| 48 |
return leaderboard_md
|
| 49 |
|
| 50 |
+
|
| 51 |
def model_hyperlink(model_name, link):
|
| 52 |
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
| 53 |
|
|
|
|
| 91 |
'''
|
| 92 |
return disclaimer_md
|
| 93 |
|
| 94 |
+
|
| 95 |
def make_arena_leaderboard_data(results):
|
| 96 |
import pandas as pd
|
| 97 |
df = pd.DataFrame(results)
|
| 98 |
return df
|
| 99 |
|
| 100 |
+
|
| 101 |
def build_leaderboard_tab(score_result_file = 'sorted_score_list.json'):
|
| 102 |
with open(score_result_file, "r") as json_file:
|
| 103 |
data = json.load(json_file)
|
serve/update_skill.py
CHANGED
|
@@ -9,9 +9,11 @@ trueskill_env = TrueSkill()
|
|
| 9 |
sys.path.append('../')
|
| 10 |
from model.models import IMAGE_GENERATION_MODELS
|
| 11 |
|
|
|
|
| 12 |
ssh_skill_client = None
|
| 13 |
sftp_skill_client = None
|
| 14 |
|
|
|
|
| 15 |
def create_ssh_skill_client(server, port, user, password):
|
| 16 |
global ssh_skill_client, sftp_skill_client
|
| 17 |
ssh_skill_client = paramiko.SSHClient()
|
|
@@ -23,32 +25,37 @@ def create_ssh_skill_client(server, port, user, password):
|
|
| 23 |
transport.set_keepalive(60)
|
| 24 |
|
| 25 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
|
|
|
|
|
|
| 26 |
def is_connected():
|
| 27 |
global ssh_skill_client, sftp_skill_client
|
| 28 |
if ssh_skill_client is None or sftp_skill_client is None:
|
| 29 |
return False
|
| 30 |
-
# 检查SSH连接是否正常
|
| 31 |
if not ssh_skill_client.get_transport().is_active():
|
| 32 |
return False
|
| 33 |
-
# 检查SFTP连接是否正常
|
| 34 |
try:
|
| 35 |
-
sftp_skill_client.listdir('.')
|
| 36 |
except Exception as e:
|
| 37 |
print(f"Error checking SFTP connection: {e}")
|
| 38 |
return False
|
| 39 |
return True
|
|
|
|
|
|
|
| 40 |
def ucb_score(trueskill_diff, t, n):
|
| 41 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 42 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 43 |
return ucb
|
| 44 |
|
|
|
|
| 45 |
def update_trueskill(ratings, ranks):
|
| 46 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 47 |
return new_ratings
|
| 48 |
|
|
|
|
| 49 |
def serialize_rating(rating):
|
| 50 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 51 |
|
|
|
|
| 52 |
def deserialize_rating(rating_dict):
|
| 53 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 54 |
|
|
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
| 66 |
with sftp_skill_client.open(SSH_SKILL, 'w') as f:
|
| 67 |
f.write(json_data)
|
| 68 |
|
|
|
|
| 69 |
def load_json_via_sftp():
|
| 70 |
global sftp_skill_client
|
| 71 |
if not is_connected():
|
|
|
|
| 9 |
sys.path.append('../')
|
| 10 |
from model.models import IMAGE_GENERATION_MODELS
|
| 11 |
|
| 12 |
+
|
| 13 |
ssh_skill_client = None
|
| 14 |
sftp_skill_client = None
|
| 15 |
|
| 16 |
+
|
| 17 |
def create_ssh_skill_client(server, port, user, password):
|
| 18 |
global ssh_skill_client, sftp_skill_client
|
| 19 |
ssh_skill_client = paramiko.SSHClient()
|
|
|
|
| 25 |
transport.set_keepalive(60)
|
| 26 |
|
| 27 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
def is_connected():
|
| 31 |
global ssh_skill_client, sftp_skill_client
|
| 32 |
if ssh_skill_client is None or sftp_skill_client is None:
|
| 33 |
return False
|
|
|
|
| 34 |
if not ssh_skill_client.get_transport().is_active():
|
| 35 |
return False
|
|
|
|
| 36 |
try:
|
| 37 |
+
sftp_skill_client.listdir('.')
|
| 38 |
except Exception as e:
|
| 39 |
print(f"Error checking SFTP connection: {e}")
|
| 40 |
return False
|
| 41 |
return True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
def ucb_score(trueskill_diff, t, n):
|
| 45 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 46 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 47 |
return ucb
|
| 48 |
|
| 49 |
+
|
| 50 |
def update_trueskill(ratings, ranks):
|
| 51 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 52 |
return new_ratings
|
| 53 |
|
| 54 |
+
|
| 55 |
def serialize_rating(rating):
|
| 56 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 57 |
|
| 58 |
+
|
| 59 |
def deserialize_rating(rating_dict):
|
| 60 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 61 |
|
|
|
|
| 73 |
with sftp_skill_client.open(SSH_SKILL, 'w') as f:
|
| 74 |
f.write(json_data)
|
| 75 |
|
| 76 |
+
|
| 77 |
def load_json_via_sftp():
|
| 78 |
global sftp_skill_client
|
| 79 |
if not is_connected():
|
serve/update_skill_video.py
CHANGED
|
@@ -9,9 +9,11 @@ trueskill_env = TrueSkill()
|
|
| 9 |
sys.path.append('../')
|
| 10 |
from model.models import VIDEO_GENERATION_MODELS
|
| 11 |
|
|
|
|
| 12 |
ssh_skill_client = None
|
| 13 |
sftp_skill_client = None
|
| 14 |
|
|
|
|
| 15 |
def create_ssh_skill_client(server, port, user, password):
|
| 16 |
global ssh_skill_client, sftp_skill_client
|
| 17 |
ssh_skill_client = paramiko.SSHClient()
|
|
@@ -23,32 +25,37 @@ def create_ssh_skill_client(server, port, user, password):
|
|
| 23 |
transport.set_keepalive(60)
|
| 24 |
|
| 25 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
|
|
|
|
|
|
| 26 |
def is_connected():
|
| 27 |
global ssh_skill_client, sftp_skill_client
|
| 28 |
if ssh_skill_client is None or sftp_skill_client is None:
|
| 29 |
return False
|
| 30 |
-
# 检查SSH连接是否正常
|
| 31 |
if not ssh_skill_client.get_transport().is_active():
|
| 32 |
return False
|
| 33 |
-
# 检查SFTP连接是否正常
|
| 34 |
try:
|
| 35 |
-
sftp_skill_client.listdir('.')
|
| 36 |
except Exception as e:
|
| 37 |
print(f"Error checking SFTP connection: {e}")
|
| 38 |
return False
|
| 39 |
return True
|
|
|
|
|
|
|
| 40 |
def ucb_score(trueskill_diff, t, n):
|
| 41 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 42 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 43 |
return ucb
|
| 44 |
|
|
|
|
| 45 |
def update_trueskill(ratings, ranks):
|
| 46 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 47 |
return new_ratings
|
| 48 |
|
|
|
|
| 49 |
def serialize_rating(rating):
|
| 50 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 51 |
|
|
|
|
| 52 |
def deserialize_rating(rating_dict):
|
| 53 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 54 |
|
|
@@ -66,6 +73,7 @@ def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
|
|
| 66 |
with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
| 67 |
f.write(json_data)
|
| 68 |
|
|
|
|
| 69 |
def load_json_via_sftp():
|
| 70 |
global sftp_skill_client
|
| 71 |
if not is_connected():
|
|
|
|
| 9 |
sys.path.append('../')
|
| 10 |
from model.models import VIDEO_GENERATION_MODELS
|
| 11 |
|
| 12 |
+
|
| 13 |
ssh_skill_client = None
|
| 14 |
sftp_skill_client = None
|
| 15 |
|
| 16 |
+
|
| 17 |
def create_ssh_skill_client(server, port, user, password):
|
| 18 |
global ssh_skill_client, sftp_skill_client
|
| 19 |
ssh_skill_client = paramiko.SSHClient()
|
|
|
|
| 25 |
transport.set_keepalive(60)
|
| 26 |
|
| 27 |
sftp_skill_client = ssh_skill_client.open_sftp()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
def is_connected():
|
| 31 |
global ssh_skill_client, sftp_skill_client
|
| 32 |
if ssh_skill_client is None or sftp_skill_client is None:
|
| 33 |
return False
|
|
|
|
| 34 |
if not ssh_skill_client.get_transport().is_active():
|
| 35 |
return False
|
|
|
|
| 36 |
try:
|
| 37 |
+
sftp_skill_client.listdir('.')
|
| 38 |
except Exception as e:
|
| 39 |
print(f"Error checking SFTP connection: {e}")
|
| 40 |
return False
|
| 41 |
return True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
def ucb_score(trueskill_diff, t, n):
|
| 45 |
exploration_term = np.sqrt((2 * np.log(t + 1e-5)) / (n + 1e-5))
|
| 46 |
ucb = -trueskill_diff + 1.0 * exploration_term
|
| 47 |
return ucb
|
| 48 |
|
| 49 |
+
|
| 50 |
def update_trueskill(ratings, ranks):
|
| 51 |
new_ratings = trueskill_env.rate(ratings, ranks)
|
| 52 |
return new_ratings
|
| 53 |
|
| 54 |
+
|
| 55 |
def serialize_rating(rating):
|
| 56 |
return {'mu': rating.mu, 'sigma': rating.sigma}
|
| 57 |
|
| 58 |
+
|
| 59 |
def deserialize_rating(rating_dict):
|
| 60 |
return trueskill_env.Rating(mu=rating_dict['mu'], sigma=rating_dict['sigma'])
|
| 61 |
|
|
|
|
| 73 |
with sftp_skill_client.open(SSH_VIDEO_SKILL, 'w') as f:
|
| 74 |
f.write(json_data)
|
| 75 |
|
| 76 |
+
|
| 77 |
def load_json_via_sftp():
|
| 78 |
global sftp_skill_client
|
| 79 |
if not is_connected():
|
serve/upload.py
CHANGED
|
@@ -9,15 +9,18 @@ import random
|
|
| 9 |
import concurrent.futures
|
| 10 |
from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO
|
| 11 |
|
|
|
|
| 12 |
ssh_client = None
|
| 13 |
sftp_client = None
|
| 14 |
sftp_client_imgs = None
|
| 15 |
|
|
|
|
| 16 |
def open_sftp(i=0):
|
| 17 |
global ssh_client
|
| 18 |
sftp_client = ssh_client.open_sftp()
|
| 19 |
return sftp_client
|
| 20 |
|
|
|
|
| 21 |
def create_ssh_client(server, port, user, password):
|
| 22 |
global ssh_client, sftp_client, sftp_client_imgs
|
| 23 |
ssh_client = paramiko.SSHClient()
|
|
@@ -40,22 +43,22 @@ def is_connected():
|
|
| 40 |
global ssh_client, sftp_client
|
| 41 |
if ssh_client is None or sftp_client is None:
|
| 42 |
return False
|
| 43 |
-
# 检查SSH连接是否正常
|
| 44 |
if not ssh_client.get_transport().is_active():
|
| 45 |
return False
|
| 46 |
-
# 检查SFTP连接是否正常
|
| 47 |
try:
|
| 48 |
-
sftp_client.listdir('.')
|
| 49 |
except Exception as e:
|
| 50 |
print(f"Error checking SFTP connection: {e}")
|
| 51 |
return False
|
| 52 |
return True
|
| 53 |
|
|
|
|
| 54 |
def get_image_from_url(image_url):
|
| 55 |
response = requests.get(image_url)
|
| 56 |
response.raise_for_status() # success
|
| 57 |
return Image.open(io.BytesIO(response.content))
|
| 58 |
|
|
|
|
| 59 |
# def get_random_mscoco_prompt():
|
| 60 |
# global sftp_client
|
| 61 |
# if not is_connected():
|
|
@@ -70,6 +73,7 @@ def get_image_from_url(image_url):
|
|
| 70 |
# print("\n")
|
| 71 |
# return content
|
| 72 |
|
|
|
|
| 73 |
def get_random_mscoco_prompt():
|
| 74 |
|
| 75 |
file_path = './coco_prompt.txt'
|
|
@@ -79,6 +83,7 @@ def get_random_mscoco_prompt():
|
|
| 79 |
random_line = random.choice(lines).strip()
|
| 80 |
return random_line
|
| 81 |
|
|
|
|
| 82 |
def get_random_video_prompt(root_dir):
|
| 83 |
subdirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
|
| 84 |
if not subdirs:
|
|
@@ -96,6 +101,7 @@ def get_random_video_prompt(root_dir):
|
|
| 96 |
raise NotImplementedError
|
| 97 |
return selected_dir, prompt
|
| 98 |
|
|
|
|
| 99 |
def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
|
| 100 |
def is_directory(sftp, path):
|
| 101 |
try:
|
|
@@ -150,6 +156,7 @@ def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
|
|
| 150 |
ssh.close()
|
| 151 |
return prompt, local_path[1:]
|
| 152 |
|
|
|
|
| 153 |
def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
|
| 154 |
def is_directory(sftp, path):
|
| 155 |
try:
|
|
@@ -204,6 +211,7 @@ def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
|
|
| 204 |
ssh.close()
|
| 205 |
return prompt, [Image.open(path) for path in local_path[1:]]
|
| 206 |
|
|
|
|
| 207 |
def create_remote_directory(remote_directory, video=False):
|
| 208 |
global ssh_client
|
| 209 |
if not is_connected():
|
|
@@ -220,6 +228,7 @@ def create_remote_directory(remote_directory, video=False):
|
|
| 220 |
print(f"Directory {remote_directory} created successfully.")
|
| 221 |
return log_dir
|
| 222 |
|
|
|
|
| 223 |
def upload_images(i, image_list, output_file_list, sftp_client):
|
| 224 |
with sftp_client as sftp:
|
| 225 |
if isinstance(image_list[i], str):
|
|
@@ -233,7 +242,6 @@ def upload_images(i, image_list, output_file_list, sftp_client):
|
|
| 233 |
print(f"Successfully uploaded image to {output_file_list[i]}")
|
| 234 |
|
| 235 |
|
| 236 |
-
|
| 237 |
def upload_ssh_all(states, output_dir, data, data_path):
|
| 238 |
global sftp_client
|
| 239 |
global sftp_client_imgs
|
|
@@ -246,7 +254,6 @@ def upload_ssh_all(states, output_dir, data, data_path):
|
|
| 246 |
output_file_list.append(output_file)
|
| 247 |
image_list.append(states[i].output)
|
| 248 |
|
| 249 |
-
|
| 250 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 251 |
futures = [executor.submit(upload_images, i, image_list, output_file_list, sftp_client_imgs[i]) for i in range(len(output_file_list))]
|
| 252 |
|
|
@@ -257,6 +264,7 @@ def upload_ssh_all(states, output_dir, data, data_path):
|
|
| 257 |
print(f"Successfully uploaded JSON data to {data_path}")
|
| 258 |
# create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
| 259 |
|
|
|
|
| 260 |
def upload_ssh_data(data, data_path):
|
| 261 |
global sftp_client
|
| 262 |
global sftp_client_imgs
|
|
|
|
| 9 |
import concurrent.futures
|
| 10 |
from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO
|
| 11 |
|
| 12 |
+
|
| 13 |
ssh_client = None
|
| 14 |
sftp_client = None
|
| 15 |
sftp_client_imgs = None
|
| 16 |
|
| 17 |
+
|
| 18 |
def open_sftp(i=0):
|
| 19 |
global ssh_client
|
| 20 |
sftp_client = ssh_client.open_sftp()
|
| 21 |
return sftp_client
|
| 22 |
|
| 23 |
+
|
| 24 |
def create_ssh_client(server, port, user, password):
|
| 25 |
global ssh_client, sftp_client, sftp_client_imgs
|
| 26 |
ssh_client = paramiko.SSHClient()
|
|
|
|
| 43 |
global ssh_client, sftp_client
|
| 44 |
if ssh_client is None or sftp_client is None:
|
| 45 |
return False
|
|
|
|
| 46 |
if not ssh_client.get_transport().is_active():
|
| 47 |
return False
|
|
|
|
| 48 |
try:
|
| 49 |
+
sftp_client.listdir('.')
|
| 50 |
except Exception as e:
|
| 51 |
print(f"Error checking SFTP connection: {e}")
|
| 52 |
return False
|
| 53 |
return True
|
| 54 |
|
| 55 |
+
|
| 56 |
def get_image_from_url(image_url):
|
| 57 |
response = requests.get(image_url)
|
| 58 |
response.raise_for_status() # success
|
| 59 |
return Image.open(io.BytesIO(response.content))
|
| 60 |
|
| 61 |
+
|
| 62 |
# def get_random_mscoco_prompt():
|
| 63 |
# global sftp_client
|
| 64 |
# if not is_connected():
|
|
|
|
| 73 |
# print("\n")
|
| 74 |
# return content
|
| 75 |
|
| 76 |
+
|
| 77 |
def get_random_mscoco_prompt():
|
| 78 |
|
| 79 |
file_path = './coco_prompt.txt'
|
|
|
|
| 83 |
random_line = random.choice(lines).strip()
|
| 84 |
return random_line
|
| 85 |
|
| 86 |
+
|
| 87 |
def get_random_video_prompt(root_dir):
|
| 88 |
subdirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
|
| 89 |
if not subdirs:
|
|
|
|
| 101 |
raise NotImplementedError
|
| 102 |
return selected_dir, prompt
|
| 103 |
|
| 104 |
+
|
| 105 |
def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
|
| 106 |
def is_directory(sftp, path):
|
| 107 |
try:
|
|
|
|
| 156 |
ssh.close()
|
| 157 |
return prompt, local_path[1:]
|
| 158 |
|
| 159 |
+
|
| 160 |
def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
|
| 161 |
def is_directory(sftp, path):
|
| 162 |
try:
|
|
|
|
| 211 |
ssh.close()
|
| 212 |
return prompt, [Image.open(path) for path in local_path[1:]]
|
| 213 |
|
| 214 |
+
|
| 215 |
def create_remote_directory(remote_directory, video=False):
|
| 216 |
global ssh_client
|
| 217 |
if not is_connected():
|
|
|
|
| 228 |
print(f"Directory {remote_directory} created successfully.")
|
| 229 |
return log_dir
|
| 230 |
|
| 231 |
+
|
| 232 |
def upload_images(i, image_list, output_file_list, sftp_client):
|
| 233 |
with sftp_client as sftp:
|
| 234 |
if isinstance(image_list[i], str):
|
|
|
|
| 242 |
print(f"Successfully uploaded image to {output_file_list[i]}")
|
| 243 |
|
| 244 |
|
|
|
|
| 245 |
def upload_ssh_all(states, output_dir, data, data_path):
|
| 246 |
global sftp_client
|
| 247 |
global sftp_client_imgs
|
|
|
|
| 254 |
output_file_list.append(output_file)
|
| 255 |
image_list.append(states[i].output)
|
| 256 |
|
|
|
|
| 257 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 258 |
futures = [executor.submit(upload_images, i, image_list, output_file_list, sftp_client_imgs[i]) for i in range(len(output_file_list))]
|
| 259 |
|
|
|
|
| 264 |
print(f"Successfully uploaded JSON data to {data_path}")
|
| 265 |
# create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
|
| 266 |
|
| 267 |
+
|
| 268 |
def upload_ssh_data(data, data_path):
|
| 269 |
global sftp_client
|
| 270 |
global sftp_client_imgs
|