naman1102 committed on
Commit
85d83ea
·
1 Parent(s): 9d332ff

Update hf_utils.py

Browse files
Files changed (1) hide show
  1. hf_utils.py +15 -16
hf_utils.py CHANGED
@@ -17,29 +17,28 @@ def download_filtered_space_files(space_id: str, local_dir: str = "repo_files",
17
 
18
  print(f"Downloading Space '{space_id}' and filtering for: {', '.join(file_extensions)}")
19
 
20
- # Download the full snapshot to a temp directory
21
- repo_path = snapshot_download(repo_id=space_id, repo_type="space")
22
-
23
  # Clear out local_dir if it already exists
24
  if os.path.exists(local_dir):
25
  shutil.rmtree(local_dir)
26
 
27
- os.makedirs(local_dir, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
28
  copied_files = 0
29
-
30
- # Walk through the snapshot and copy only files with desired extensions
31
- for root, _, files in os.walk(repo_path):
32
  for file in files:
33
  if any(file.endswith(ext) for ext in file_extensions):
34
- src_file = os.path.join(root, file)
35
- rel_path = os.path.relpath(src_file, repo_path)
36
- dest_file = os.path.join(local_dir, rel_path)
37
- os.makedirs(os.path.dirname(dest_file), exist_ok=True)
38
-
39
- # Debug: Show exactly which file is being downloaded
40
- print(f"DEBUG: Downloading file: {rel_path}")
41
-
42
- shutil.copy2(src_file, dest_file)
43
  copied_files += 1
44
 
45
  print(f"Downloaded {copied_files} filtered file(s) to: {local_dir}")
 
17
 
18
  print(f"Downloading Space '{space_id}' and filtering for: {', '.join(file_extensions)}")
19
 
 
 
 
20
  # Clear out local_dir if it already exists
21
  if os.path.exists(local_dir):
22
  shutil.rmtree(local_dir)
23
 
24
+ # Convert file extensions to allow_patterns format (e.g., ['.py', '.md'] -> ['*.py', '*.md'])
25
+ allow_patterns = [f"*{ext}" for ext in file_extensions]
26
+
27
+ # Download directly to local_dir with filtering during download
28
+ repo_path = snapshot_download(
29
+ repo_id=space_id,
30
+ repo_type="space",
31
+ local_dir=local_dir,
32
+ allow_patterns=allow_patterns
33
+ )
34
+
35
+ # Count downloaded files for feedback
36
  copied_files = 0
37
+ for root, _, files in os.walk(local_dir):
 
 
38
  for file in files:
39
  if any(file.endswith(ext) for ext in file_extensions):
40
+ rel_path = os.path.relpath(os.path.join(root, file), local_dir)
41
+ print(f"DEBUG: Downloaded file: {rel_path}")
 
 
 
 
 
 
 
42
  copied_files += 1
43
 
44
  print(f"Downloaded {copied_files} filtered file(s) to: {local_dir}")