DHEIVER commited on
Commit
89dcdea
·
verified ·
1 Parent(s): 85bcefb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -20
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForImageClassification, AutoImageProcessor
3
  import torch
4
  from PIL import Image
5
  import numpy as np
@@ -17,20 +17,19 @@ def load_models():
17
  # Carregando modelos específicos para feridas
18
  wound_classifier = pipeline(
19
  "image-classification",
20
- model="stevhliu/wound-classification", # Modelo específico para classificação de feridas
21
  device=0 if torch.cuda.is_available() else -1
22
  )
23
 
24
  tissue_classifier = pipeline(
25
  "image-classification",
26
- model="viktorcikojevic/wound-tissue-type", # Modelo para classificação do tipo de tecido
27
  device=0 if torch.cuda.is_available() else -1
28
  )
29
 
30
  return wound_classifier, tissue_classifier
31
 
32
  def preprocess_image(image):
33
- # Normalização e pré-processamento da imagem
34
  if isinstance(image, np.ndarray):
35
  image = Image.fromarray(image)
36
  image = image.convert('RGB')
@@ -49,14 +48,13 @@ def classify_wound(image):
49
  # Classificação do tipo de tecido
50
  tissue_results = tissue_classifier(processed_image)
51
 
52
- # Formatando resultados da classificação de feridas
53
  wound_formatted = []
54
  for result in wound_results:
55
  label = WOUND_TYPES.get(result['label'], result['label'])
56
  score = result['score']
57
  wound_formatted.append((label, score))
58
 
59
- # Formatando resultados da classificação de tecidos
60
  tissue_formatted = []
61
  for result in tissue_results:
62
  label = result['label'].replace('_', ' ').title()
@@ -65,7 +63,7 @@ def classify_wound(image):
65
 
66
  return wound_formatted, tissue_formatted
67
 
68
- # Criando a interface do Gradio
69
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
70
  gr.Markdown("""
71
  # 🏥 Classificador Especializado de Feridas
@@ -77,8 +75,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
77
  with gr.Column():
78
  input_image = gr.Image(
79
  label="Upload da Imagem",
80
- type="pil",
81
- tool="select"
82
  )
83
  submit_btn = gr.Button("Analisar Ferida", variant="primary", size="lg")
84
 
@@ -121,16 +118,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
121
  inputs=input_image,
122
  outputs=[wound_output, tissue_output]
123
  )
124
-
125
- # Exemplos
126
- gr.Examples(
127
- examples=[
128
- ["image1.jpg"]
129
- ],
130
- inputs=input_image,
131
- outputs=[wound_output, tissue_output],
132
- cache_examples=True
133
- )
134
 
135
  if __name__ == "__main__":
136
- demo.launch(share=True)
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  import torch
4
  from PIL import Image
5
  import numpy as np
 
17
  # Carregando modelos específicos para feridas
18
  wound_classifier = pipeline(
19
  "image-classification",
20
+ model="stevhliu/wound-classification",
21
  device=0 if torch.cuda.is_available() else -1
22
  )
23
 
24
  tissue_classifier = pipeline(
25
  "image-classification",
26
+ model="viktorcikojevic/wound-tissue-type",
27
  device=0 if torch.cuda.is_available() else -1
28
  )
29
 
30
  return wound_classifier, tissue_classifier
31
 
32
  def preprocess_image(image):
 
33
  if isinstance(image, np.ndarray):
34
  image = Image.fromarray(image)
35
  image = image.convert('RGB')
 
48
  # Classificação do tipo de tecido
49
  tissue_results = tissue_classifier(processed_image)
50
 
51
+ # Formatando resultados
52
  wound_formatted = []
53
  for result in wound_results:
54
  label = WOUND_TYPES.get(result['label'], result['label'])
55
  score = result['score']
56
  wound_formatted.append((label, score))
57
 
 
58
  tissue_formatted = []
59
  for result in tissue_results:
60
  label = result['label'].replace('_', ' ').title()
 
63
 
64
  return wound_formatted, tissue_formatted
65
 
66
+ # Interface Gradio
67
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
68
  gr.Markdown("""
69
  # 🏥 Classificador Especializado de Feridas
 
75
  with gr.Column():
76
  input_image = gr.Image(
77
  label="Upload da Imagem",
78
+ type="pil"
 
79
  )
80
  submit_btn = gr.Button("Analisar Ferida", variant="primary", size="lg")
81
 
 
118
  inputs=input_image,
119
  outputs=[wound_output, tissue_output]
120
  )
 
 
 
 
 
 
 
 
 
 
121
 
122
  if __name__ == "__main__":
123
+ demo.launch()