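# WoundView: a Gradio app that segments a wound photo with a fine-tuned SAM2
# model, classifies the wound type with a fastai learner, breaks the masked
# region down by HSV color, and asks an OpenAI model for a risk score and
# care advice grounded in a wound-care PDF.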
import os
import urllib.request

# Official SAM2 checkpoint URLs (2024-07-28 release)
model_urls = {
    "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
    "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
    "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
    "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
}
def download_models():
    # Download any checkpoint that is not already present in the working directory.
    for filename, url in model_urls.items():
        if not os.path.exists(filename):
            print(f"Downloading {filename}...")
            urllib.request.urlretrieve(url, filename)
        else:
            print(f"{filename} already exists, skipping download.")

download_models()
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from fastai.vision.all import *  # fastai v2: provides load_learner and PILImage
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import tensorflow as tf
import re
import json
import ast
import openai
import tiktoken
import shutil
import concurrent
import textwrap
from time import sleep
from csv import writer
from tqdm import tqdm
from scipy import spatial
from pptx import Presentation
from PyPDF2 import PdfReader
from openai import OpenAI
from IPython.display import display, Markdown, Latex, HTML
from transformers import GPT2Tokenizer
from termcolor import colored
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
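# SAM2 model setup: build the small variant, then overwrite its weights with
# the fine-tuned checkpoint ("sam2_lr0.0001_wd0.01_900.torch" is expected
# alongside the app files).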
sam2_checkpoint = "sam2_hiera_small.pt"
model_cfg = "sam2_hiera_s.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)
checkpoint_path = "sam2_lr0.0001_wd0.01_900.torch"
predictor.model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
def display_thread(thread_id):
    # Print every message in an Assistants API thread (newest first).
    # `client` is created further down, before this is ever called.
    for message in client.beta.threads.messages.list(thread_id=thread_id):
        display(message.content[0].text.value)
def read_file(filepath, max_pages=None):
    if filepath.endswith('.pdf'):
        return read_pdf(filepath, max_pages)
    elif filepath.endswith('.txt'):
        return read_text_file(filepath)
    elif filepath.endswith('.docx'):
        return read_docx(filepath)
    elif filepath.endswith('.xlsx'):
        return read_xlsx(filepath)
    elif filepath.endswith('.pptx'):
        return read_pptx(filepath)
    else:
        raise ValueError("Unsupported file type")
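# Only read_pdf is defined in this file, and the app below only ever reads a
# PDF, so the other branches never run as written. Minimal sketches of the
# missing readers, assuming python-docx and openpyxl are installed
# (python-pptx and pandas are already imported above):
def read_text_file(filepath):
    with open(filepath, 'r', encoding='utf-8') as f:
        return f.read()

def read_docx(filepath):
    from docx import Document  # python-docx; not imported at the top of this file
    return "\n".join(p.text for p in Document(filepath).paragraphs)

def read_xlsx(filepath):
    return pd.read_excel(filepath).to_string()  # openpyxl backend for .xlsx

def read_pptx(filepath):
    slides_text = []
    for slide in Presentation(filepath).slides:
        for shape in slide.shapes:
            if shape.has_text_frame:
                slides_text.append(shape.text_frame.text)
    return "\n".join(slides_text)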
def read_pdf(filepath, max_pages=None):
    reader = PdfReader(filepath)
    pdf_text = ""
    page_number = 0
    for page in reader.pages:
        page_number += 1
        if max_pages and (page_number > max_pages):
            break
        page_text = page.extract_text()
        if page_text:
            # Collapse hard line breaks so each page reads as one block.
            page_text = re.sub(r'\n+', ' ', page_text)
            pdf_text += page_text + f"\nPage Number: {page_number}\n"
        else:
            pdf_text += f"\n[No extractable text on Page {page_number}]\n"
    return pdf_text
def calc_similarity(x, y):
    # Cosine similarity between two OpenAI embeddings responses.
    return 1 - spatial.distance.cosine(x.data[0].embedding, y.data[0].embedding)
def pretty_print(df):
    return display(HTML(df.to_html().replace("\\n", "<br>")))
def read_directory(directory):
    assert os.path.exists(directory)
    res_dict = {}
    for filename in os.listdir(directory):
        # Match dotted extensions so names like "report_pdf" are not picked up.
        if filename.endswith(('.pdf', '.txt', '.docx', '.pptx')):
            filepath = os.path.join(directory, filename)
            text = read_file(filepath, 2)  # first two pages only (PDFs)
            res_dict[filename] = (filepath, text)
    df = pd.DataFrame(res_dict).T
    df = df.reset_index()
    df.columns = ["Filename", "Filepath", "Text"]
    return df
# Initialize GPT tokenizer (GPT-2 vocab as a rough proxy for counting the
# tokens sent to the chat model) and disable its max-length warning.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = int(1e30)
def ask_chatbot(question, context, m):
    max_context_tokens = 16385  # conservative cap on the context passed to the model
    truncated_context = truncate_context(context, max_context_tokens)
    response = client.chat.completions.create(
        model=m,
        messages=[
            {"role": "system", "content": """You are an expert doctor who treats chronic wounds, and you know every single thing about wounds and how to treat them as well as how to prevent them from getting worse.
The user will provide the following inputs: Name, Gender, Age, Pre-existing Medical Conditions, Wound Part of Body, Wound Classification, Colors of the Wound (as percentages out of 100).
Please provide the medical advice in 2 concise paragraphs that must incorporate the following key features every time:
1. **Wound Risk Score (1-100):** You will be given a PDF and you shall review it and use it to aid in your risk score generation. The wound risk score should be between 1-100! Of course, any color percentages **less than 3** shouldn't be taken into consideration when making the score.
**Make sure to be specific!**
2. **Medical Advice:** Give the patient bulleted directions on how to monitor and care for their wound. **Make sure to include if the person needs to go see a doctor as soon as possible.**"""},
            {"role": "user", "content": truncated_context},
            {"role": "user", "content": question}
        ]
    )
    return response.choices[0].message.content
def truncate_context(context, max_tokens):
    tokens = tokenizer.encode(context)
    if len(tokens) > max_tokens:
        truncated_tokens = tokens[:max_tokens]
        return tokenizer.decode(truncated_tokens)
    return context
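# tiktoken is imported above but unused; a closer match to gpt-4o-mini's real
# tokenizer could swap it in here (a sketch, assuming o200k_base is the right
# encoding for the model in use):
#     enc = tiktoken.get_encoding("o200k_base")
#     return enc.decode(enc.encode(context)[:max_tokens])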
# Context PDF for the prompt, plus the OpenAI client.
file_content = read_file("Wound Healing Risk Assessment.pdf")
api_key = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)
model = "gpt-4o-mini"
assistant = client.beta.assistants.create(
    name="Wound Treater",
    instructions="""You are an expert doctor who treats chronic wounds, and you know every single thing about wounds and how to treat them as well as how to prevent them from getting worse.
The user will provide the following inputs: Name, Gender, Age, Pre-existing Medical Conditions, Wound Part of Body, Wound Classification, Colors of the Wound (as percentages out of 100).
Please provide the medical advice in 2 concise paragraphs that must incorporate the following key features every time:
1. **Wound Risk Score (1-100):** Generate a wound risk score from 1-100, 1 meaning no risk and 100 meaning see a medical professional immediately! Of course, any color percentages **less than 3** shouldn't be taken into consideration when making the score.
**Make sure to be specific and list the components of the wound risk score.**
2. **Medical Advice:** Give the patient directions on how to monitor and care for their wound. **Make sure to include if the person needs to go see a doctor as soon as possible.**""",
    model=model)
def get_assistant_response(name="None", gender="None", age="None", conditions="None", bodyPart="None", typeWound="None", red="None", orange="None", yellow="None", magenta="None", white="None", gray="None", black="None"):
    thread = client.beta.threads.create()
    input_text = (
        f"Name: {name}, Gender: {gender}, Age: {age}, "
        f"Pre-Existing Medical Conditions: {conditions}, Part of Body: {bodyPart}, "
        f"Type of Wound: {typeWound}, "
        f"Wound Colors (Red, Orange, Yellow, Magenta, White, Gray, Black): "
        f"[{red}, {orange}, {yellow}, {magenta}, {white}, {gray}, {black}]"
    )
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=input_text)
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant.id,
    )
    # Poll until the run reaches a terminal state instead of sleeping a fixed
    # 15 seconds, which could return before the assistant has answered.
    while run.status not in ("completed", "failed", "cancelled", "expired"):
        sleep(1)
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
    return input_text, client.beta.threads.messages.list(thread.id).data[0].content[0].text.value
def get_response_with_context(name="None", gender="None", age="None", conditions="None", bodyPart="None", typeWound="None", red="None", orange="None", yellow="None", magenta="None", white="None", gray="None", black="None"):
    input_text = (
        f"Name: {name}, Gender: {gender}, Age: {age}, "
        f"Pre-Existing Medical Conditions: {conditions}, Part of Body: {bodyPart}, "
        f"Type of Wound: {typeWound}, "
        f"Wound Colors (Red, Orange, Yellow, Magenta, White, Gray, Black): "
        f"[{red}, {orange}, {yellow}, {magenta}, {white}, {gray}, {black}]"
    )
    response = ask_chatbot(input_text, file_content, model)
    return input_text, response
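# Only get_response_with_context is wired into the UI below;
# get_assistant_response (the Assistants API path) is kept as an alternative.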
wounds = []
# fastai wound-classification model, exported with learn.export().
learn = load_learner('model.pkl')
def one_step_inference(image_path, threshold=0.5):
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    predictor.set_image(image_rgb)  # SAM2's predictor expects an RGB array
    high_res_features = [feat[-1].unsqueeze(0) for feat in predictor._features["high_res_feats"]]
    with torch.no_grad():
        # Run the mask decoder with no prompts: the fine-tuned model
        # segments the wound unprompted.
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=None, boxes=None, masks=None)
        low_res_masks, _, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
            repeat_image=False,
            high_res_features=high_res_features,
        )
    mask = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
    # The decoder emits raw logits; thresholding at 0.5 is stricter than the
    # stock SAM2 cutoff of 0.0 and presumably matches the fine-tuning setup.
    final_mask = (mask > threshold).cpu().detach().numpy()[0][0].astype("uint8")
    # Classify only the HSV pixels inside the predicted wound mask.
    selected_pixels = image_hsv[final_mask == 1]
    colors = classify_colors(selected_pixels)
    return colors["Red"], colors["Orange"], colors["Yellow"], colors["Magenta"], colors["White"], colors["Gray"], colors["Black"]
def classify_colors(hsv_pixels):
    # OpenCV HSV ranges: H in [0, 179], S and V in [0, 255]. Red wraps around
    # the hue axis, so it is split into two bands and merged after counting.
    color_ranges = {
        'Red': [(0, 50, 50), (10, 255, 255)],
        'Red2': [(170, 50, 50), (179, 255, 255)],
        'Orange': [(11, 50, 50), (25, 255, 255)],
        'Yellow': [(26, 50, 50), (35, 255, 255)],
        'Green': [(36, 50, 50), (85, 255, 255)],
        'Cyan': [(86, 50, 50), (95, 255, 255)],
        'Blue': [(96, 50, 50), (130, 255, 255)],
        'Purple': [(131, 50, 50), (160, 255, 255)],
        'Magenta': [(161, 50, 50), (169, 255, 255)],
        'White': [(0, 0, 200), (179, 55, 255)],  # high brightness, low saturation
        'Gray': [(0, 0, 50), (179, 50, 200)],    # low saturation, varying brightness
        'Black': [(0, 0, 0), (179, 50, 50)]      # low brightness
    }
    hsv_pixels = hsv_pixels.reshape(-1, 3)
    color_counts = {color: 0 for color in color_ranges}
    total_pixels = hsv_pixels.shape[0]
    for pixel in hsv_pixels:
        h, s, v = pixel
        # Assign each pixel to the first band it falls in (dict order matters
        # where bands overlap, e.g. White/Gray vs the hue bands).
        for color, (lower, upper) in color_ranges.items():
            if lower[0] <= h <= upper[0] and lower[1] <= s <= upper[1] and lower[2] <= v <= upper[2]:
                color_counts[color] += 1
                break
    # Merge the two red hue bands.
    color_counts["Red"] += color_counts["Red2"]
    del color_counts["Red2"]
    if total_pixels == 0:
        total_pixels = 1  # avoid division by zero on an empty mask
    color_percentages = {color: (count / total_pixels) * 100 for color, count in color_counts.items()}
    return color_percentages
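# Illustrative check with two hypothetical HSV pixels (OpenCV hue scale):
# h=5 lands in the Red band and h=30 in Yellow, so each comes out at 50%:
#     classify_colors(np.array([[5, 200, 200], [30, 200, 200]], dtype=np.uint8))
#     -> {'Red': 50.0, 'Orange': 0.0, 'Yellow': 50.0, ...}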
def predict_image(image_path):
    img = PILImage.create(image_path)
    pred, pred_idx, probs = learn.predict(img)
    return pred
def reveal_group():
    return gr.update(visible=True)

def hide_group():
    return gr.update(visible=False)
def add_wound(image, partOfBody):
    wounds.append({"image": image, "description": partOfBody})
    return image, partOfBody

def clear_inputs(image, partOfBody):
    # Reset the Add New form; reassigning the parameters locally would not
    # affect the components, only the returned values do.
    return None, ""
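# Gradio UI: four screens toggled by visibility (sign-up, home, add-new, AI chat).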
with gr.Blocks(theme=gr.themes.Glass()) as demo:
    gr.Markdown("<center><h1>Welcome to WoundView!</h1></center>")
    # Sign-up Group
    with gr.Group() as sign_up:
        gr.Markdown("<center><h2>New User</h2></center>")
        name = gr.Textbox(label="Full Name", placeholder="Enter your name here...")
        gender = gr.Radio(["Male", "Female"], label="Gender")
        age = gr.Number(label="Age")
        conditions = gr.CheckboxGroup(["Diabetes", "Peripheral Arterial Disease", "Venous Insufficiency", "Obesity", "Smoking"], label="Pre-Existing Medical Conditions")
        # Validation message (not currently wired to any event)
        gr.Markdown("<span style='color: red;'>Some fields were left empty. Please fill them out!</span>", visible=False)
        sign_up_btn = gr.Button(value="Sign Up", variant="secondary")
    # Home Group
    with gr.Group(visible=False) as home:
        gr.Markdown("<center><h2>Wounds</h2></center>")
        with gr.Row(visible=False) as wound_display:
            wound_image = gr.Image()
            with gr.Column():
                wound_title = gr.Markdown("<center><h2>Wound Description</h2></center>")
                with gr.Row():
                    gr.Markdown("<center>Part of Body:</center>")
                    wound_desc = gr.Textbox(container=False)
                with gr.Row():
                    gr.Markdown("<center>Type of Wound:</center>")
                    wound_classification = gr.Textbox(container=False)
                gr.Markdown("<center><h4>Colors:</h4></center>")
                with gr.Row():
                    gr.Markdown("<center>Red:</center>")
                    red_percent = gr.Textbox(container=False)
                with gr.Row():
                    gr.Markdown("<center>Orange:</center>")
                    orange_percent = gr.Textbox(container=False)
                with gr.Row():
                    gr.Markdown("<center>Yellow:</center>")
                    yellow_percent = gr.Textbox(container=False)
                with gr.Row():
                    gr.Markdown("<center>Magenta:</center>")
                    magenta_percent = gr.Textbox(container=False)
                with gr.Row():
                    gr.Markdown("<center>White:</center>")
                    white_percent = gr.Textbox(container=False)
                with gr.Row():
                    gr.Markdown("<center>Gray:</center>")
                    gray_percent = gr.Textbox(container=False)
                with gr.Row():
                    gr.Markdown("<center>Black:</center>")
                    black_percent = gr.Textbox(container=False)
        ai_chat_btn = gr.Button(value="AI ChatBot")
        add_new_btn = gr.Button(value="Add New")
    # Add New Group
    with gr.Group(visible=False) as add_new:
        gr.Markdown("<center><h2>Add New Wound</h2></center>")
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Picture of wound", type="filepath")
                examples = gr.Examples(examples=["3_photo.jpg", "0_photo.jpg", "12_photo.jpg", "13_photo.jpg", "1_photo.jpg", "4_photo.jpg", "67_photo.jpg", "71_photo.jpg"], inputs=image)
            partOfBody = gr.Radio(["Head", "Arm", "Hand", "Back", "Stomach", "Leg", "Foot"], label="What part of the body is the wound on?")
        with gr.Row():
            confirm_add_new_btn = gr.Button(value="Confirm")
            cancel_add_new_btn = gr.Button(value="Cancel")
    # AI Chat Group
    with gr.Group(visible=False) as ai_chat:
        gr.Markdown("<center><h2>AI Chat</h2></center>")
        with gr.Column() as gpt:
            gr.Markdown("<center><h3>Chat GPT</h3></center>")
            chatGPTInput = gr.Textbox(container=False)
            chatGPTOutput = gr.Textbox(container=False)
        cancel_ai_chat_btn = gr.Button(value="Cancel")
    # Button Click Events
    sign_up_btn.click(hide_group, outputs=sign_up).then(reveal_group, outputs=home)
    add_new_btn.click(
        hide_group, outputs=home
    ).then(
        reveal_group, outputs=add_new
    ).then(
        clear_inputs, inputs=[image, partOfBody], outputs=[image, partOfBody]
    )
    confirm_add_new_btn.click(
        add_wound, inputs=[image, partOfBody], outputs=[wound_image, wound_desc]
    ).then(
        reveal_group, outputs=home
    ).then(
        hide_group, outputs=add_new
    ).then(
        reveal_group, outputs=wound_display
    ).then(
        predict_image, inputs=image, outputs=wound_classification
    ).then(
        one_step_inference,
        inputs=image,
        outputs=[red_percent, orange_percent, yellow_percent, magenta_percent, white_percent, gray_percent, black_percent]
    )
    cancel_add_new_btn.click(hide_group, outputs=add_new).then(reveal_group, outputs=home)
    ai_chat_btn.click(
        hide_group, outputs=home
    ).then(
        reveal_group, outputs=ai_chat
    ).then(
        get_response_with_context,
        inputs=[name, gender, age, conditions, partOfBody, wound_classification, red_percent, orange_percent, yellow_percent, magenta_percent, white_percent, gray_percent, black_percent],
        outputs=[chatGPTInput, chatGPTOutput]
    )
    cancel_ai_chat_btn.click(hide_group, outputs=ai_chat).then(reveal_group, outputs=home)

demo.launch(share=True)