jbilcke-hf HF Staff committed on
Commit
fc0385d
·
1 Parent(s): ece1c33

fix for the custom prompt prefix

Browse files
vms/ui/project/services/importing/file_upload.py CHANGED
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
22
  class FileUploadHandler:
23
  """Handles processing of uploaded files"""
24
 
25
- def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
26
  """Process uploaded file (ZIP, TAR, MP4, or image)
27
 
28
  Args:
@@ -48,7 +48,7 @@ class FileUploadHandler:
48
  file_ext = file_path.suffix.lower()
49
 
50
  if file_ext == '.zip':
51
- return self.process_zip_file(file_path, enable_splitting)
52
  elif file_ext == '.tar':
53
  return self.process_tar_file(file_path, enable_splitting)
54
  elif file_ext == '.mp4' or file_ext == '.webm':
@@ -63,7 +63,7 @@ class FileUploadHandler:
63
  logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
64
  raise gr.Error(f"Error processing file: {str(e)}")
65
 
66
- def process_zip_file(self, file_path: Path, enable_splitting: bool) -> str:
67
  """Process uploaded ZIP file containing media files or WebDataset tar files
68
 
69
  Args:
@@ -138,7 +138,7 @@ class FileUploadHandler:
138
  logger.info(f"Copied caption file for {file}")
139
  elif is_image_file(file_path):
140
  caption = txt_path.read_text()
141
- caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
142
  target_path.with_suffix('.txt').write_text(caption)
143
  logger.info(f"Processed caption for {file}")
144
 
 
22
  class FileUploadHandler:
23
  """Handles processing of uploaded files"""
24
 
25
+ def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
26
  """Process uploaded file (ZIP, TAR, MP4, or image)
27
 
28
  Args:
 
48
  file_ext = file_path.suffix.lower()
49
 
50
  if file_ext == '.zip':
51
+ return self.process_zip_file(file_path, enable_splitting, custom_prompt_prefix)
52
  elif file_ext == '.tar':
53
  return self.process_tar_file(file_path, enable_splitting)
54
  elif file_ext == '.mp4' or file_ext == '.webm':
 
63
  logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
64
  raise gr.Error(f"Error processing file: {str(e)}")
65
 
66
+ def process_zip_file(self, file_path: Path, enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
67
  """Process uploaded ZIP file containing media files or WebDataset tar files
68
 
69
  Args:
 
138
  logger.info(f"Copied caption file for {file}")
139
  elif is_image_file(file_path):
140
  caption = txt_path.read_text()
141
+ caption = add_prefix_to_caption(caption, custom_prompt_prefix or DEFAULT_PROMPT_PREFIX)
142
  target_path.with_suffix('.txt').write_text(caption)
143
  logger.info(f"Processed caption for {file}")
144
 
vms/ui/project/services/importing/hub_dataset.py CHANGED
@@ -169,7 +169,8 @@ class HubDatasetBrowser:
169
  dataset_id: str,
170
  file_type: str,
171
  enable_splitting: bool,
172
- progress_callback: Optional[Callable] = None
 
173
  ) -> str:
174
  """Download all files of a specific type from the dataset
175
 
@@ -329,7 +330,8 @@ class HubDatasetBrowser:
329
  self,
330
  dataset_id: str,
331
  enable_splitting: bool,
332
- progress_callback: Optional[Callable] = None
 
333
  ) -> Tuple[str, str]:
334
  """Download a dataset and process its video/image content
335
 
@@ -555,7 +557,7 @@ class HubDatasetBrowser:
555
  txt_path = file_path.with_suffix('.txt')
556
  if txt_path.exists():
557
  caption = txt_path.read_text()
558
- caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
559
  target_path.with_suffix('.txt').write_text(caption)
560
  logger.info(f"Processed caption for {file_path.name}")
561
 
 
169
  dataset_id: str,
170
  file_type: str,
171
  enable_splitting: bool,
172
+ progress_callback: Optional[Callable] = None,
173
+ custom_prompt_prefix: str = None
174
  ) -> str:
175
  """Download all files of a specific type from the dataset
176
 
 
330
  self,
331
  dataset_id: str,
332
  enable_splitting: bool,
333
+ progress_callback: Optional[Callable] = None,
334
+ custom_prompt_prefix: str = None
335
  ) -> Tuple[str, str]:
336
  """Download a dataset and process its video/image content
337
 
 
557
  txt_path = file_path.with_suffix('.txt')
558
  if txt_path.exists():
559
  caption = txt_path.read_text()
560
+ caption = add_prefix_to_caption(caption, custom_prompt_prefix or DEFAULT_PROMPT_PREFIX)
561
  target_path.with_suffix('.txt').write_text(caption)
562
  logger.info(f"Processed caption for {file_path.name}")
563
 
vms/ui/project/services/importing/import_service.py CHANGED
@@ -28,7 +28,7 @@ class ImportingService:
28
  self.youtube_handler = YouTubeDownloader()
29
  self.hub_browser = HubDatasetBrowser(self.hf_api)
30
 
31
- def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
32
  """Process uploaded file (ZIP, TAR, MP4, or image)
33
 
34
  Args:
@@ -45,7 +45,7 @@ class ImportingService:
45
 
46
  print(f"process_uploaded_files(..., enable_splitting = {enable_splitting:})")
47
  print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
48
- return self.file_handler.process_uploaded_files(file_paths, enable_splitting)
49
 
50
  def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
51
  """Download a video from YouTube
@@ -86,7 +86,8 @@ class ImportingService:
86
  self,
87
  dataset_id: str,
88
  enable_splitting: bool,
89
- progress_callback: Optional[Callable] = None
 
90
  ) -> Tuple[str, str]:
91
  """Download a dataset and process its video/image content
92
 
@@ -98,14 +99,15 @@ class ImportingService:
98
  Returns:
99
  Tuple of (loading_msg, status_msg)
100
  """
101
- return await self.hub_browser.download_dataset(dataset_id, enable_splitting, progress_callback)
102
 
103
  async def download_file_group(
104
  self,
105
  dataset_id: str,
106
  file_type: str,
107
  enable_splitting: bool,
108
- progress_callback: Optional[Callable] = None
 
109
  ) -> str:
110
  """Download a group of files (videos or WebDatasets)
111
 
@@ -118,4 +120,4 @@ class ImportingService:
118
  Returns:
119
  Status message
120
  """
121
- return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting, progress_callback)
 
28
  self.youtube_handler = YouTubeDownloader()
29
  self.hub_browser = HubDatasetBrowser(self.hf_api)
30
 
31
+ def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
32
  """Process uploaded file (ZIP, TAR, MP4, or image)
33
 
34
  Args:
 
45
 
46
  print(f"process_uploaded_files(..., enable_splitting = {enable_splitting:})")
47
  print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
48
+ return self.file_handler.process_uploaded_files(file_paths, enable_splitting, custom_prompt_prefix)
49
 
50
  def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
51
  """Download a video from YouTube
 
86
  self,
87
  dataset_id: str,
88
  enable_splitting: bool,
89
+ progress_callback: Optional[Callable] = None,
90
+ custom_prompt_prefix: str = None
91
  ) -> Tuple[str, str]:
92
  """Download a dataset and process its video/image content
93
 
 
99
  Returns:
100
  Tuple of (loading_msg, status_msg)
101
  """
102
+ return await self.hub_browser.download_dataset(dataset_id, enable_splitting, progress_callback, custom_prompt_prefix)
103
 
104
  async def download_file_group(
105
  self,
106
  dataset_id: str,
107
  file_type: str,
108
  enable_splitting: bool,
109
+ progress_callback: Optional[Callable] = None,
110
+ custom_prompt_prefix: str = None
111
  ) -> str:
112
  """Download a group of files (videos or WebDatasets)
113
 
 
120
  Returns:
121
  Status message
122
  """
123
+ return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting, progress_callback, custom_prompt_prefix)
vms/ui/project/services/training.py CHANGED
@@ -579,6 +579,7 @@ class TrainingService:
579
  precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
580
  lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
581
  progress: Optional[gr.Progress] = None,
 
582
  ) -> Tuple[str, str]:
583
  """Start training with finetrainers"""
584
 
@@ -669,16 +670,10 @@ class TrainingService:
669
  else:
670
  flow_weighting_scheme = "logit_normal"
671
 
672
- # Get the custom prompt prefix from the tabs
673
- custom_prompt_prefix = None
674
- if hasattr(self, 'app') and self.app is not None:
675
- if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
676
- if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
677
- # Get the value and clean it
678
- prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
679
- if prefix:
680
- # Clean the prefix - remove trailing comma, space or comma+space
681
- custom_prompt_prefix = prefix.rstrip(', ')
682
 
683
  # Create a proper dataset configuration JSON file
684
  dataset_config_file = self.app.output_path / "dataset_config.json"
 
579
  precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
580
  lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
581
  progress: Optional[gr.Progress] = None,
582
+ custom_prompt_prefix: Optional[str] = None,
583
  ) -> Tuple[str, str]:
584
  """Start training with finetrainers"""
585
 
 
670
  else:
671
  flow_weighting_scheme = "logit_normal"
672
 
673
+ # Use the custom prompt prefix passed as parameter
674
+ # Clean the prefix - remove trailing comma, space or comma+space
675
+ if custom_prompt_prefix:
676
+ custom_prompt_prefix = custom_prompt_prefix.rstrip(', ')
 
 
 
 
 
 
677
 
678
  # Create a proper dataset configuration JSON file
679
  dataset_config_file = self.app.output_path / "dataset_config.json"
vms/ui/project/tabs/import_tab/hub_tab.py CHANGED
@@ -267,7 +267,7 @@ class HubTab(BaseTab):
267
  "" # status_output
268
  )
269
 
270
- async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback):
271
  """Wrapper for download_file_group that integrates with progress tracking"""
272
  try:
273
  # Set up the progress callback adapter
@@ -289,7 +289,8 @@ class HubTab(BaseTab):
289
  dataset_id,
290
  file_type,
291
  enable_splitting,
292
- progress_callback=progress_adapter
 
293
  )
294
 
295
  return result
@@ -298,7 +299,7 @@ class HubTab(BaseTab):
298
  logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
299
  return f"Error: {str(e)}"
300
 
301
- def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, progress=gr.Progress()) -> Tuple:
302
  """Handle download of a group of files (videos or WebDatasets) with progress tracking"""
303
  try:
304
  if not dataset_id:
@@ -323,7 +324,8 @@ class HubTab(BaseTab):
323
  dataset_id,
324
  file_type,
325
  enable_splitting,
326
- progress
 
327
  ))
328
 
329
  # When download is complete, update the UI
 
267
  "" # status_output
268
  )
269
 
270
+ async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback, custom_prompt_prefix=None):
271
  """Wrapper for download_file_group that integrates with progress tracking"""
272
  try:
273
  # Set up the progress callback adapter
 
289
  dataset_id,
290
  file_type,
291
  enable_splitting,
292
+ progress_callback=progress_adapter,
293
+ custom_prompt_prefix=custom_prompt_prefix
294
  )
295
 
296
  return result
 
299
  logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
300
  return f"Error: {str(e)}"
301
 
302
+ def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, custom_prompt_prefix: str = None, progress=gr.Progress()) -> Tuple:
303
  """Handle download of a group of files (videos or WebDatasets) with progress tracking"""
304
  try:
305
  if not dataset_id:
 
324
  dataset_id,
325
  file_type,
326
  enable_splitting,
327
+ progress,
328
+ custom_prompt_prefix
329
  ))
330
 
331
  # When download is complete, update the UI
vms/ui/project/tabs/import_tab/upload_tab.py CHANGED
@@ -65,7 +65,7 @@ class UploadTab(BaseTab):
65
  # File upload event with enable_splitting parameter
66
  upload_event = self.components["files"].upload(
67
  fn=self.app.importing.process_uploaded_files,
68
- inputs=[self.components["files"], self.components["enable_automatic_video_split"]],
69
  outputs=[self.components["import_status"]]
70
  ).success(
71
  fn=self.app.tabs["import_tab"].on_import_success,
 
65
  # File upload event with enable_splitting parameter
66
  upload_event = self.components["files"].upload(
67
  fn=self.app.importing.process_uploaded_files,
68
+ inputs=[self.components["files"], self.components["enable_automatic_video_split"], self.app.tabs["caption_tab"].components["custom_prompt_prefix"]],
69
  outputs=[self.components["import_status"]]
70
  ).success(
71
  fn=self.app.tabs["import_tab"].on_import_success,
vms/ui/project/tabs/train_tab.py CHANGED
@@ -906,6 +906,13 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
906
  precomputation_items = int(self.components["precomputation_items"].value)
907
  lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
908
 
 
 
 
 
 
 
 
909
  # Start training (it will automatically use the checkpoint if provided)
910
  try:
911
  return self.app.training.start_training(
@@ -924,7 +931,8 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
924
  precomputation_items=precomputation_items,
925
  lr_warmup_steps=lr_warmup_steps,
926
  progress=progress,
927
- pretrained_lora_path=pretrained_lora_path
 
928
  )
929
  except Exception as e:
930
  logger.exception("Error starting training")
 
906
  precomputation_items = int(self.components["precomputation_items"].value)
907
  lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
908
 
909
+ # Get custom prompt prefix from caption tab
910
+ custom_prompt_prefix = None
911
+ if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
912
+ caption_tab = self.app.tabs['caption_tab']
913
+ if hasattr(caption_tab, 'components') and 'custom_prompt_prefix' in caption_tab.components:
914
+ custom_prompt_prefix = caption_tab.components['custom_prompt_prefix'].value
915
+
916
  # Start training (it will automatically use the checkpoint if provided)
917
  try:
918
  return self.app.training.start_training(
 
931
  precomputation_items=precomputation_items,
932
  lr_warmup_steps=lr_warmup_steps,
933
  progress=progress,
934
+ pretrained_lora_path=pretrained_lora_path,
935
+ custom_prompt_prefix=custom_prompt_prefix
936
  )
937
  except Exception as e:
938
  logger.exception("Error starting training")