davanstrien HF Staff commited on
Commit
ef5559f
·
1 Parent(s): d89e107
Files changed (1) hide show
  1. app.py +40 -10
app.py CHANGED
@@ -1,17 +1,47 @@
1
- import gradio as gr
2
- # from transformers import pipeline
3
  from pathlib import Path
 
 
 
 
4
 
5
- # pipe = pipeline('image-classification','models/davanstrien/autotrain-encyclopaedia-illustrations-blog-post-3327992159')
6
 
7
- def upload_file(files):
8
- return files
9
 
10
- with gr.Blocks() as demo:
11
- file_output = gr.File()
12
- upload_button = gr.UploadButton("Click to Upload a Directory", file_types=["image", "video"], file_count="directory")
13
- upload_button.upload(upload_file, upload_button, file_output)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- demo.launch()
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
2
+ import gradio as gr
3
+ from transformers import pipeline
4
+ from transformers import AutoModelForImageClassification
5
+ import shutil
6
 
 
7
 
8
+ chk_point = "davanstrien/autotrain-encyclopaedia-illustrations-blog-post-3327992158"
 
9
 
10
+ model = AutoModelForImageClassification.from_pretrained(chk_point)
 
 
 
11
 
12
+ pipe = pipeline(
13
+ "image-classification",
14
+ "davanstrien/autotrain-encyclopaedia-illustrations-blog-post-3327992158",
15
+ )
16
+
17
+
18
+ def make_label_folders():
19
+ folders = model.config.label2id.keys()
20
+ for folder in folders:
21
+ folder = Path(folder)
22
+ if not folder.exists():
23
+ folder.mkdir()
24
+ return folders
25
 
 
26
 
27
+ def predictions_into_folders(files):
28
+ files = [file.name for file in files]
29
+ folders = make_label_folders()
30
+ predictions = pipe(files)
31
+ for file, prediction in zip(files, predictions):
32
+ label = prediction[0]["label"]
33
+ file_name = Path(file).name
34
+ shutil.copy(file, f"{label}/{file_name}")
35
+ for folder in folders:
36
+ shutil.make_archive(folder, "zip", ".", folder)
37
+ return [f"{folder}.zip" for folder in folders]
38
+
39
+
40
+ demo = gr.Interface(
41
+ predictions_into_folders,
42
+ gr.File(file_count="directory"),
43
+ gr.Files(),
44
+ cache_examples=True,
45
+ )
46
+
47
+ demo.launch()