Spaces:
Running
Running
Commit
·
fc0385d
1
Parent(s):
ece1c33
fix for the custom prompt prefix
Browse files- vms/ui/project/services/importing/file_upload.py +4 -4
- vms/ui/project/services/importing/hub_dataset.py +5 -3
- vms/ui/project/services/importing/import_service.py +8 -6
- vms/ui/project/services/training.py +5 -10
- vms/ui/project/tabs/import_tab/hub_tab.py +6 -4
- vms/ui/project/tabs/import_tab/upload_tab.py +1 -1
- vms/ui/project/tabs/train_tab.py +9 -1
vms/ui/project/services/importing/file_upload.py
CHANGED
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|
22 |
class FileUploadHandler:
|
23 |
"""Handles processing of uploaded files"""
|
24 |
|
25 |
-
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
|
26 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
27 |
|
28 |
Args:
|
@@ -48,7 +48,7 @@ class FileUploadHandler:
|
|
48 |
file_ext = file_path.suffix.lower()
|
49 |
|
50 |
if file_ext == '.zip':
|
51 |
-
return self.process_zip_file(file_path, enable_splitting)
|
52 |
elif file_ext == '.tar':
|
53 |
return self.process_tar_file(file_path, enable_splitting)
|
54 |
elif file_ext == '.mp4' or file_ext == '.webm':
|
@@ -63,7 +63,7 @@ class FileUploadHandler:
|
|
63 |
logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
|
64 |
raise gr.Error(f"Error processing file: {str(e)}")
|
65 |
|
66 |
-
def process_zip_file(self, file_path: Path, enable_splitting: bool) -> str:
|
67 |
"""Process uploaded ZIP file containing media files or WebDataset tar files
|
68 |
|
69 |
Args:
|
@@ -138,7 +138,7 @@ class FileUploadHandler:
|
|
138 |
logger.info(f"Copied caption file for {file}")
|
139 |
elif is_image_file(file_path):
|
140 |
caption = txt_path.read_text()
|
141 |
-
caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
|
142 |
target_path.with_suffix('.txt').write_text(caption)
|
143 |
logger.info(f"Processed caption for {file}")
|
144 |
|
|
|
22 |
class FileUploadHandler:
|
23 |
"""Handles processing of uploaded files"""
|
24 |
|
25 |
+
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
|
26 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
27 |
|
28 |
Args:
|
|
|
48 |
file_ext = file_path.suffix.lower()
|
49 |
|
50 |
if file_ext == '.zip':
|
51 |
+
return self.process_zip_file(file_path, enable_splitting, custom_prompt_prefix)
|
52 |
elif file_ext == '.tar':
|
53 |
return self.process_tar_file(file_path, enable_splitting)
|
54 |
elif file_ext == '.mp4' or file_ext == '.webm':
|
|
|
63 |
logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
|
64 |
raise gr.Error(f"Error processing file: {str(e)}")
|
65 |
|
66 |
+
def process_zip_file(self, file_path: Path, enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
|
67 |
"""Process uploaded ZIP file containing media files or WebDataset tar files
|
68 |
|
69 |
Args:
|
|
|
138 |
logger.info(f"Copied caption file for {file}")
|
139 |
elif is_image_file(file_path):
|
140 |
caption = txt_path.read_text()
|
141 |
+
caption = add_prefix_to_caption(caption, custom_prompt_prefix or DEFAULT_PROMPT_PREFIX)
|
142 |
target_path.with_suffix('.txt').write_text(caption)
|
143 |
logger.info(f"Processed caption for {file}")
|
144 |
|
vms/ui/project/services/importing/hub_dataset.py
CHANGED
@@ -169,7 +169,8 @@ class HubDatasetBrowser:
|
|
169 |
dataset_id: str,
|
170 |
file_type: str,
|
171 |
enable_splitting: bool,
|
172 |
-
progress_callback: Optional[Callable] = None
|
|
|
173 |
) -> str:
|
174 |
"""Download all files of a specific type from the dataset
|
175 |
|
@@ -329,7 +330,8 @@ class HubDatasetBrowser:
|
|
329 |
self,
|
330 |
dataset_id: str,
|
331 |
enable_splitting: bool,
|
332 |
-
progress_callback: Optional[Callable] = None
|
|
|
333 |
) -> Tuple[str, str]:
|
334 |
"""Download a dataset and process its video/image content
|
335 |
|
@@ -555,7 +557,7 @@ class HubDatasetBrowser:
|
|
555 |
txt_path = file_path.with_suffix('.txt')
|
556 |
if txt_path.exists():
|
557 |
caption = txt_path.read_text()
|
558 |
-
caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
|
559 |
target_path.with_suffix('.txt').write_text(caption)
|
560 |
logger.info(f"Processed caption for {file_path.name}")
|
561 |
|
|
|
169 |
dataset_id: str,
|
170 |
file_type: str,
|
171 |
enable_splitting: bool,
|
172 |
+
progress_callback: Optional[Callable] = None,
|
173 |
+
custom_prompt_prefix: str = None
|
174 |
) -> str:
|
175 |
"""Download all files of a specific type from the dataset
|
176 |
|
|
|
330 |
self,
|
331 |
dataset_id: str,
|
332 |
enable_splitting: bool,
|
333 |
+
progress_callback: Optional[Callable] = None,
|
334 |
+
custom_prompt_prefix: str = None
|
335 |
) -> Tuple[str, str]:
|
336 |
"""Download a dataset and process its video/image content
|
337 |
|
|
|
557 |
txt_path = file_path.with_suffix('.txt')
|
558 |
if txt_path.exists():
|
559 |
caption = txt_path.read_text()
|
560 |
+
caption = add_prefix_to_caption(caption, custom_prompt_prefix or DEFAULT_PROMPT_PREFIX)
|
561 |
target_path.with_suffix('.txt').write_text(caption)
|
562 |
logger.info(f"Processed caption for {file_path.name}")
|
563 |
|
vms/ui/project/services/importing/import_service.py
CHANGED
@@ -28,7 +28,7 @@ class ImportingService:
|
|
28 |
self.youtube_handler = YouTubeDownloader()
|
29 |
self.hub_browser = HubDatasetBrowser(self.hf_api)
|
30 |
|
31 |
-
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
|
32 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
33 |
|
34 |
Args:
|
@@ -45,7 +45,7 @@ class ImportingService:
|
|
45 |
|
46 |
print(f"process_uploaded_files(..., enable_splitting = {enable_splitting:})")
|
47 |
print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
|
48 |
-
return self.file_handler.process_uploaded_files(file_paths, enable_splitting)
|
49 |
|
50 |
def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
|
51 |
"""Download a video from YouTube
|
@@ -86,7 +86,8 @@ class ImportingService:
|
|
86 |
self,
|
87 |
dataset_id: str,
|
88 |
enable_splitting: bool,
|
89 |
-
progress_callback: Optional[Callable] = None
|
|
|
90 |
) -> Tuple[str, str]:
|
91 |
"""Download a dataset and process its video/image content
|
92 |
|
@@ -98,14 +99,15 @@ class ImportingService:
|
|
98 |
Returns:
|
99 |
Tuple of (loading_msg, status_msg)
|
100 |
"""
|
101 |
-
return await self.hub_browser.download_dataset(dataset_id, enable_splitting, progress_callback)
|
102 |
|
103 |
async def download_file_group(
|
104 |
self,
|
105 |
dataset_id: str,
|
106 |
file_type: str,
|
107 |
enable_splitting: bool,
|
108 |
-
progress_callback: Optional[Callable] = None
|
|
|
109 |
) -> str:
|
110 |
"""Download a group of files (videos or WebDatasets)
|
111 |
|
@@ -118,4 +120,4 @@ class ImportingService:
|
|
118 |
Returns:
|
119 |
Status message
|
120 |
"""
|
121 |
-
return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting, progress_callback)
|
|
|
28 |
self.youtube_handler = YouTubeDownloader()
|
29 |
self.hub_browser = HubDatasetBrowser(self.hf_api)
|
30 |
|
31 |
+
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
|
32 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
33 |
|
34 |
Args:
|
|
|
45 |
|
46 |
print(f"process_uploaded_files(..., enable_splitting = {enable_splitting:})")
|
47 |
print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
|
48 |
+
return self.file_handler.process_uploaded_files(file_paths, enable_splitting, custom_prompt_prefix)
|
49 |
|
50 |
def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
|
51 |
"""Download a video from YouTube
|
|
|
86 |
self,
|
87 |
dataset_id: str,
|
88 |
enable_splitting: bool,
|
89 |
+
progress_callback: Optional[Callable] = None,
|
90 |
+
custom_prompt_prefix: str = None
|
91 |
) -> Tuple[str, str]:
|
92 |
"""Download a dataset and process its video/image content
|
93 |
|
|
|
99 |
Returns:
|
100 |
Tuple of (loading_msg, status_msg)
|
101 |
"""
|
102 |
+
return await self.hub_browser.download_dataset(dataset_id, enable_splitting, progress_callback, custom_prompt_prefix)
|
103 |
|
104 |
async def download_file_group(
|
105 |
self,
|
106 |
dataset_id: str,
|
107 |
file_type: str,
|
108 |
enable_splitting: bool,
|
109 |
+
progress_callback: Optional[Callable] = None,
|
110 |
+
custom_prompt_prefix: str = None
|
111 |
) -> str:
|
112 |
"""Download a group of files (videos or WebDatasets)
|
113 |
|
|
|
120 |
Returns:
|
121 |
Status message
|
122 |
"""
|
123 |
+
return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting, progress_callback, custom_prompt_prefix)
|
vms/ui/project/services/training.py
CHANGED
@@ -579,6 +579,7 @@ class TrainingService:
|
|
579 |
precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
|
580 |
lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
|
581 |
progress: Optional[gr.Progress] = None,
|
|
|
582 |
) -> Tuple[str, str]:
|
583 |
"""Start training with finetrainers"""
|
584 |
|
@@ -669,16 +670,10 @@ class TrainingService:
|
|
669 |
else:
|
670 |
flow_weighting_scheme = "logit_normal"
|
671 |
|
672 |
-
#
|
673 |
-
|
674 |
-
if
|
675 |
-
|
676 |
-
if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
|
677 |
-
# Get the value and clean it
|
678 |
-
prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
|
679 |
-
if prefix:
|
680 |
-
# Clean the prefix - remove trailing comma, space or comma+space
|
681 |
-
custom_prompt_prefix = prefix.rstrip(', ')
|
682 |
|
683 |
# Create a proper dataset configuration JSON file
|
684 |
dataset_config_file = self.app.output_path / "dataset_config.json"
|
|
|
579 |
precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
|
580 |
lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
|
581 |
progress: Optional[gr.Progress] = None,
|
582 |
+
custom_prompt_prefix: Optional[str] = None,
|
583 |
) -> Tuple[str, str]:
|
584 |
"""Start training with finetrainers"""
|
585 |
|
|
|
670 |
else:
|
671 |
flow_weighting_scheme = "logit_normal"
|
672 |
|
673 |
+
# Use the custom prompt prefix passed as parameter
|
674 |
+
# Clean the prefix - remove trailing comma, space or comma+space
|
675 |
+
if custom_prompt_prefix:
|
676 |
+
custom_prompt_prefix = custom_prompt_prefix.rstrip(', ')
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
|
678 |
# Create a proper dataset configuration JSON file
|
679 |
dataset_config_file = self.app.output_path / "dataset_config.json"
|
vms/ui/project/tabs/import_tab/hub_tab.py
CHANGED
@@ -267,7 +267,7 @@ class HubTab(BaseTab):
|
|
267 |
"" # status_output
|
268 |
)
|
269 |
|
270 |
-
async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback):
|
271 |
"""Wrapper for download_file_group that integrates with progress tracking"""
|
272 |
try:
|
273 |
# Set up the progress callback adapter
|
@@ -289,7 +289,8 @@ class HubTab(BaseTab):
|
|
289 |
dataset_id,
|
290 |
file_type,
|
291 |
enable_splitting,
|
292 |
-
progress_callback=progress_adapter
|
|
|
293 |
)
|
294 |
|
295 |
return result
|
@@ -298,7 +299,7 @@ class HubTab(BaseTab):
|
|
298 |
logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
|
299 |
return f"Error: {str(e)}"
|
300 |
|
301 |
-
def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, progress=gr.Progress()) -> Tuple:
|
302 |
"""Handle download of a group of files (videos or WebDatasets) with progress tracking"""
|
303 |
try:
|
304 |
if not dataset_id:
|
@@ -323,7 +324,8 @@ class HubTab(BaseTab):
|
|
323 |
dataset_id,
|
324 |
file_type,
|
325 |
enable_splitting,
|
326 |
-
progress
|
|
|
327 |
))
|
328 |
|
329 |
# When download is complete, update the UI
|
|
|
267 |
"" # status_output
|
268 |
)
|
269 |
|
270 |
+
async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback, custom_prompt_prefix=None):
|
271 |
"""Wrapper for download_file_group that integrates with progress tracking"""
|
272 |
try:
|
273 |
# Set up the progress callback adapter
|
|
|
289 |
dataset_id,
|
290 |
file_type,
|
291 |
enable_splitting,
|
292 |
+
progress_callback=progress_adapter,
|
293 |
+
custom_prompt_prefix=custom_prompt_prefix
|
294 |
)
|
295 |
|
296 |
return result
|
|
|
299 |
logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
|
300 |
return f"Error: {str(e)}"
|
301 |
|
302 |
+
def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, custom_prompt_prefix: str = None, progress=gr.Progress()) -> Tuple:
|
303 |
"""Handle download of a group of files (videos or WebDatasets) with progress tracking"""
|
304 |
try:
|
305 |
if not dataset_id:
|
|
|
324 |
dataset_id,
|
325 |
file_type,
|
326 |
enable_splitting,
|
327 |
+
progress,
|
328 |
+
custom_prompt_prefix
|
329 |
))
|
330 |
|
331 |
# When download is complete, update the UI
|
vms/ui/project/tabs/import_tab/upload_tab.py
CHANGED
@@ -65,7 +65,7 @@ class UploadTab(BaseTab):
|
|
65 |
# File upload event with enable_splitting parameter
|
66 |
upload_event = self.components["files"].upload(
|
67 |
fn=self.app.importing.process_uploaded_files,
|
68 |
-
inputs=[self.components["files"], self.components["enable_automatic_video_split"]],
|
69 |
outputs=[self.components["import_status"]]
|
70 |
).success(
|
71 |
fn=self.app.tabs["import_tab"].on_import_success,
|
|
|
65 |
# File upload event with enable_splitting parameter
|
66 |
upload_event = self.components["files"].upload(
|
67 |
fn=self.app.importing.process_uploaded_files,
|
68 |
+
inputs=[self.components["files"], self.components["enable_automatic_video_split"], self.app.tabs["caption_tab"].components["custom_prompt_prefix"]],
|
69 |
outputs=[self.components["import_status"]]
|
70 |
).success(
|
71 |
fn=self.app.tabs["import_tab"].on_import_success,
|
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -906,6 +906,13 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
906 |
precomputation_items = int(self.components["precomputation_items"].value)
|
907 |
lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
|
908 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
909 |
# Start training (it will automatically use the checkpoint if provided)
|
910 |
try:
|
911 |
return self.app.training.start_training(
|
@@ -924,7 +931,8 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
924 |
precomputation_items=precomputation_items,
|
925 |
lr_warmup_steps=lr_warmup_steps,
|
926 |
progress=progress,
|
927 |
-
pretrained_lora_path=pretrained_lora_path
|
|
|
928 |
)
|
929 |
except Exception as e:
|
930 |
logger.exception("Error starting training")
|
|
|
906 |
precomputation_items = int(self.components["precomputation_items"].value)
|
907 |
lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
|
908 |
|
909 |
+
# Get custom prompt prefix from caption tab
|
910 |
+
custom_prompt_prefix = None
|
911 |
+
if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
|
912 |
+
caption_tab = self.app.tabs['caption_tab']
|
913 |
+
if hasattr(caption_tab, 'components') and 'custom_prompt_prefix' in caption_tab.components:
|
914 |
+
custom_prompt_prefix = caption_tab.components['custom_prompt_prefix'].value
|
915 |
+
|
916 |
# Start training (it will automatically use the checkpoint if provided)
|
917 |
try:
|
918 |
return self.app.training.start_training(
|
|
|
931 |
precomputation_items=precomputation_items,
|
932 |
lr_warmup_steps=lr_warmup_steps,
|
933 |
progress=progress,
|
934 |
+
pretrained_lora_path=pretrained_lora_path,
|
935 |
+
custom_prompt_prefix=custom_prompt_prefix
|
936 |
)
|
937 |
except Exception as e:
|
938 |
logger.exception("Error starting training")
|