mrdilaw commited on
Commit
bf1c46c
·
verified ·
1 Parent(s): a04a2fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -52
app.py CHANGED
@@ -1,64 +1,44 @@
1
- import gradio as gr
2
- from loadimg import load_img
3
- import spaces
4
- from transformers import AutoModelForImageSegmentation
5
  import torch
6
  from torchvision import transforms
 
 
 
 
7
 
8
- torch.set_float32_matmul_precision("high")
 
 
9
 
10
- # تحميل النموذج
11
- birefnet = AutoModelForImageSegmentation.from_pretrained(
12
- "ZhengPeng7/BiRefNet", trust_remote_code=True
13
- )
14
- birefnet.to("cpu")
15
 
16
- # تجهيز الصورة قبل الإدخال
17
- transform_image = transforms.Compose([
18
- transforms.Resize((1024, 1024)),
19
- transforms.ToTensor(),
20
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
21
- ])
22
 
23
- # المعالجة الأساسية
24
- @spaces.GPU
25
  def process(image):
26
- image_size = image.size
27
- input_images = transform_image(image).unsqueeze(0).to("cpu")
28
  with torch.no_grad():
29
- preds = birefnet(input_images)[-1].sigmoid().cpu()
30
- pred = preds[0].squeeze()
31
- pred_pil = transforms.ToPILImage()(pred)
32
- mask = pred_pil.resize(image_size)
 
 
 
 
33
  image.putalpha(mask)
34
  return image
35
 
36
- # واجهة المستخدم
37
- def from_upload(image):
38
- im = load_img(image, output_type="pil").convert("RGB")
39
- origin = im.copy()
40
- processed = process(im)
41
- return (processed, origin)
42
-
43
- def from_url(url):
44
- im = load_img(url, output_type="pil").convert("RGB")
45
- origin = im.copy()
46
- processed = process(im)
47
- return (processed, origin)
48
-
49
- def process_file(f):
50
- name_path = f.rsplit(".", 1)[0] + ".png"
51
- im = load_img(f, output_type="pil").convert("RGB")
52
- transparent = process(im)
53
- transparent.save(name_path)
54
- return name_path
55
-
56
- # واجهات التبويبات
57
- tab1 = gr.Interface(from_upload, inputs=gr.Image(), outputs=[gr.Image(label="Processed"), gr.Image(label="Original")], title="Upload Image")
58
- tab2 = gr.Interface(from_url, inputs=gr.Textbox(label="Paste Image URL"), outputs=[gr.Image(label="Processed"), gr.Image(label="Original")], title="From URL")
59
- tab3 = gr.Interface(process_file, inputs=gr.Image(type="filepath"), outputs=gr.File(), title="Save Transparent PNG")
60
-
61
- demo = gr.TabbedInterface([tab1, tab2, tab3], ["Upload", "URL", "Save PNG"], title="Background Removal with BiRefNet")
62
 
63
- if __name__ == "__main__":
64
- demo.launch(show_error=True)
 
 
 
 
 
1
  import torch
2
  from torchvision import transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
6
+ import numpy as np
7
 
8
+ # تحميل نموذج BiRefNet
9
+ birefnet = AutoModelForImageClassification.from_pretrained("briaai/RMBG-1.4")
10
+ birefnet.to("cpu") # ✅ تشغيل على CPU
11
 
12
+ # تحميل المحول (feature extractor)
13
+ extractor = AutoFeatureExtractor.from_pretrained("briaai/RMBG-1.4")
 
 
 
14
 
15
+ # دالة تحويل الصورة لتنسيق النموذج
16
+ def transform_image(image):
17
+ inputs = extractor(images=image, return_tensors="pt")
18
+ return inputs["pixel_values"][0]
 
 
19
 
20
+ # دالة معالجة الصورة
 
21
  def process(image):
22
+ input_images = transform_image(image).unsqueeze(0).to("cpu") # ✅ تشغيل على CPU
 
23
  with torch.no_grad():
24
+ output = birefnet(input_images).logits.squeeze(0)[0]
25
+ mask = torch.sigmoid(output).cpu().numpy()
26
+ mask = (mask * 255).astype(np.uint8)
27
+ mask = Image.fromarray(mask).resize(image.size)
28
+
29
+ # إزالة الخلفية
30
+ image = image.convert("RGBA")
31
+ mask = mask.convert("L")
32
  image.putalpha(mask)
33
  return image
34
 
35
+ # واجهة Gradio
36
+ demo = gr.Interface(
37
+ fn=process,
38
+ inputs=gr.Image(type="pil"),
39
+ outputs=gr.Image(type="pil"),
40
+ title="إزالة خلفية الصور باستخدام BiRefNet (CPU)",
41
+ description="ارفع صورة وسيتم إزالة الخلفية تلقائيًا باستخدام نموذج BiRefNet على وحدة المعالجة المركزية فقط."
42
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ demo.launch()