ybhavsar2009 commited on
Commit
5672a87
·
verified ·
1 Parent(s): 31017bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -34
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- #os.chdir("/home/user/app/segment-anything-2")
3
  import urllib.request
4
 
5
  model_urls = {
@@ -17,17 +16,11 @@ def download_models():
17
  else:
18
  print(f"{filename} already exists, skipping download.")
19
 
20
- # Call it at startup
21
  download_models()
22
 
23
- """wget -O sam2_hiera_tiny.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"
24
- wget -O sam2_hiera_small.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"
25
- wget -O sam2_hiera_base_plus.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"
26
- wget -O sam2_hiera_large.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
27
- """
28
-
29
- #import os
30
  import gradio as gr
 
 
31
  import numpy as np
32
  import pandas as pd
33
  import cv2
@@ -58,24 +51,15 @@ from PyPDF2 import PdfReader
58
  from openai import OpenAI
59
  from IPython.display import display, Markdown, Latex, HTML
60
  from transformers import GPT2Tokenizer
61
- #from google.colab import files # for uploading files
62
- from termcolor import colored # for colored text output
63
-
64
- #%matplotlib inline
65
- #%config InlineBackend.figure_format='retina'
66
-
67
- #os.chdir("segment-anything-2")
68
 
69
  from sam2.build_sam import build_sam2
70
  from sam2.sam2_image_predictor import SAM2ImagePredictor
71
-
72
  sam2_checkpoint = "sam2_hiera_small.pt"
73
  model_cfg = "sam2_hiera_s.yaml"
74
-
75
  device = "cuda" if torch.cuda.is_available() else "cpu"
76
  sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
77
  predictor = SAM2ImagePredictor(sam2_model)
78
-
79
  checkpoint_path = "sam2_lr0.0001_wd0.01_900.torch"
80
  predictor.model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
81
 
@@ -107,9 +91,7 @@ def read_pdf(filepath, max_pages=None):
107
  break
108
  page_text = page.extract_text()
109
 
110
- # Check if page_text is None before proceeding
111
  if page_text:
112
- # Replace multiple newlines with a space to make it readable
113
  page_text = re.sub(r'\n+', ' ', page_text)
114
  pdf_text += page_text + f"\nPage Number: {page_number}\n"
115
  else:
@@ -144,7 +126,7 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
144
  tokenizer.model_max_length = int(1e30)
145
 
146
  def ask_chatbot(question, context, m):
147
- max_context_tokens = 16385 # Adjust based on the maximum allowable context tokens
148
  truncated_context = truncate_context(context, max_context_tokens)
149
  response = client.chat.completions.create(
150
  model=m,
@@ -232,7 +214,7 @@ def one_step_inference(image_path, threshold=0.5):
232
  sparse_prompt_embeddings=sparse_embeddings,
233
  dense_prompt_embeddings=dense_embeddings,
234
  multimask_output=False,
235
- repeat_image=False, # Fixed argument
236
  high_res_features=high_res_features,)
237
 
238
  mask = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
@@ -245,9 +227,7 @@ def one_step_inference(image_path, threshold=0.5):
245
 
246
  return colors["Red"], colors["Orange"], colors["Yellow"], colors["Magenta"], colors["White"], colors["Gray"], colors["Black"]
247
 
248
- # Replace this function in your original code
249
  def classify_colors(hsv_pixels):
250
- # Define color ranges in HSV
251
  color_ranges = {
252
  'Red': [(0, 50, 50), (10, 255, 255)], # Red wraps around
253
  'Red2': [(170, 50, 50), (179, 255, 255)],
@@ -262,17 +242,13 @@ def classify_colors(hsv_pixels):
262
  'Gray': [(0, 0, 50), (179, 50, 200)], # Low saturation, varying brightness
263
  'Black': [(0, 0, 0), (179, 50, 50)] # Low brightness
264
  }
265
- # Flatten the HSV pixels to process as a single list
266
  hsv_pixels = hsv_pixels.reshape(-1, 3)
267
- # Initialize counts for each color
268
  color_counts = {color: 0 for color in color_ranges}
269
- # Total number of pixels
270
  total_pixels = hsv_pixels.shape[0]
271
- # Classify each pixel
272
  for pixel in hsv_pixels:
273
  h, s, v = pixel
274
  for color, ranges in color_ranges.items():
275
- if isinstance(ranges[0], tuple): # Handles multiple ranges (e.g., red)
276
  lower = ranges[0]
277
  upper = ranges[1]
278
  if (lower[0] <= h <= upper[0] or lower[0] > upper[0] and (h >= lower[0] or h <= upper[0])) \
@@ -284,7 +260,7 @@ def classify_colors(hsv_pixels):
284
  if lower[0] <= h <= upper[0] and lower[1] <= s <= upper[1] and lower[2] <= v <= upper[2]:
285
  color_counts[color] += 1
286
  break
287
- # Calculate percentages
288
  color_counts["Red"] += color_counts["Red2"]
289
  del color_counts["Red2"]
290
  if(total_pixels == 0):
@@ -304,7 +280,6 @@ def reveal_group():
304
  def hide_group():
305
  return gr.update(visible=False)
306
 
307
- # Add new wound to the list
308
  def add_wound(image, partOfBody):
309
  wounds.append({"image": image, "description": partOfBody})
310
  return image, partOfBody
@@ -314,8 +289,47 @@ def clear_inputs(image, partOfBody):
314
  partOfBody=""
315
  return image, partOfBody
316
 
317
- # Initialize Gradio app
318
- with gr.Blocks(theme=gr.themes.Glass(dark_mode=True)) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  gr.Markdown("<center><h1>Welcome to WoundView!</h1></center>")
320
 
321
  # Sign-up Group
 
1
  import os
 
2
  import urllib.request
3
 
4
  model_urls = {
 
16
  else:
17
  print(f"{filename} already exists, skipping download.")
18
 
 
19
  download_models()
20
 
 
 
 
 
 
 
 
21
  import gradio as gr
22
+ from gradio.themes.base import Base
23
+ from gradio.themes.utils import colors
24
  import numpy as np
25
  import pandas as pd
26
  import cv2
 
51
  from openai import OpenAI
52
  from IPython.display import display, Markdown, Latex, HTML
53
  from transformers import GPT2Tokenizer
54
+ from termcolor import colored
 
 
 
 
 
 
55
 
56
  from sam2.build_sam import build_sam2
57
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
58
  sam2_checkpoint = "sam2_hiera_small.pt"
59
  model_cfg = "sam2_hiera_s.yaml"
 
60
  device = "cuda" if torch.cuda.is_available() else "cpu"
61
  sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
62
  predictor = SAM2ImagePredictor(sam2_model)
 
63
  checkpoint_path = "sam2_lr0.0001_wd0.01_900.torch"
64
  predictor.model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
65
 
 
91
  break
92
  page_text = page.extract_text()
93
 
 
94
  if page_text:
 
95
  page_text = re.sub(r'\n+', ' ', page_text)
96
  pdf_text += page_text + f"\nPage Number: {page_number}\n"
97
  else:
 
126
  tokenizer.model_max_length = int(1e30)
127
 
128
  def ask_chatbot(question, context, m):
129
+ max_context_tokens = 16385
130
  truncated_context = truncate_context(context, max_context_tokens)
131
  response = client.chat.completions.create(
132
  model=m,
 
214
  sparse_prompt_embeddings=sparse_embeddings,
215
  dense_prompt_embeddings=dense_embeddings,
216
  multimask_output=False,
217
+ repeat_image=False,
218
  high_res_features=high_res_features,)
219
 
220
  mask = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
 
227
 
228
  return colors["Red"], colors["Orange"], colors["Yellow"], colors["Magenta"], colors["White"], colors["Gray"], colors["Black"]
229
 
 
230
  def classify_colors(hsv_pixels):
 
231
  color_ranges = {
232
  'Red': [(0, 50, 50), (10, 255, 255)], # Red wraps around
233
  'Red2': [(170, 50, 50), (179, 255, 255)],
 
242
  'Gray': [(0, 0, 50), (179, 50, 200)], # Low saturation, varying brightness
243
  'Black': [(0, 0, 0), (179, 50, 50)] # Low brightness
244
  }
 
245
  hsv_pixels = hsv_pixels.reshape(-1, 3)
 
246
  color_counts = {color: 0 for color in color_ranges}
 
247
  total_pixels = hsv_pixels.shape[0]
 
248
  for pixel in hsv_pixels:
249
  h, s, v = pixel
250
  for color, ranges in color_ranges.items():
251
+ if isinstance(ranges[0], tuple):
252
  lower = ranges[0]
253
  upper = ranges[1]
254
  if (lower[0] <= h <= upper[0] or lower[0] > upper[0] and (h >= lower[0] or h <= upper[0])) \
 
260
  if lower[0] <= h <= upper[0] and lower[1] <= s <= upper[1] and lower[2] <= v <= upper[2]:
261
  color_counts[color] += 1
262
  break
263
+
264
  color_counts["Red"] += color_counts["Red2"]
265
  del color_counts["Red2"]
266
  if(total_pixels == 0):
 
280
  def hide_group():
281
  return gr.update(visible=False)
282
 
 
283
  def add_wound(image, partOfBody):
284
  wounds.append({"image": image, "description": partOfBody})
285
  return image, partOfBody
 
289
  partOfBody=""
290
  return image, partOfBody
291
 
292
+ class CustomDarkGlass(Base):
293
+ def __init__(self):
294
+ super().__init__(
295
+ primary_hue=colors.blue,
296
+ secondary_hue=colors.slate,
297
+ neutral_hue=colors.gray,
298
+ font=[Base.fonts.SANS],
299
+ )
300
+ self.set(
301
+ # Backgrounds
302
+ body_background_fill="#1f2937", # dark slate / near-black
303
+ body_text_color="#ffffff", # white text
304
+ block_background_fill="#374151", # medium dark gray
305
+ block_border_color="#4b5563", # border gray
306
+ block_shadow=None,
307
+ block_title_text_color="#ffffff",
308
+
309
+ # Input components
310
+ input_background_fill="#4b5563", # slightly lighter
311
+ input_border_color="#6b7280", # input border gray
312
+ input_text_color="#ffffff", # white input text
313
+
314
+ # Buttons
315
+ button_primary_background_fill="#2563eb", # blue-600
316
+ button_primary_text_color="#ffffff",
317
+ button_secondary_background_fill="#4b5563",
318
+ button_secondary_text_color="#ffffff",
319
+
320
+ # Checkbox / Radio
321
+ checkbox_label_text_color="#ffffff",
322
+ checkbox_background_color_selected="#2563eb",
323
+ checkbox_border_color="#6b7280",
324
+
325
+ # Slider / Progress
326
+ slider_color="#2563eb",
327
+ progress_color="#2563eb",
328
+ )
329
+
330
+ theme = CustomDarkGlass()
331
+
332
+ with gr.Blocks(theme=theme) as demo:
333
  gr.Markdown("<center><h1>Welcome to WoundView!</h1></center>")
334
 
335
  # Sign-up Group