|
import base64
import html
import mimetypes
import os
from argparse import Namespace

import gradio as gr
import requests

from libra.eval import libra_eval
|
|
|
|
|
# Built-in example images, keyed by the label shown in the "Select Current/Prior
# Image" radio buttons. The UI renders a preview of each value with
# generate_image_html(), and generate_radiology_description() passes the chosen
# value to libra_eval() whenever no file has been uploaded.
DEFAULT_IMAGES = dict(
    [
        ("Image 1", "https://drive.google.com/uc?export=view&id=10bvR7a4WSyDAtWsNQUjPSs1GlcSxtP81"),
        ("Image 2", "https://drive.google.com/uc?export=view&id=1yzKM1eo8yBAGRcm7ayqUhxASXQHNANUa"),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def image_url_to_base64(image_url: str, timeout: float = 10.0) -> str:
    """Download a remote image and return it as a Base64 ``data:`` URI.

    The MIME type in the URI is taken from the response's ``Content-Type``
    header when that header names an ``image/*`` type; otherwise
    ``image/jpeg`` is used (the value this function always used before).

    Args:
        image_url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout a stalled host would block the caller forever; this
            function is invoked while the UI is being built at import time.

    Returns:
        ``"data:<mime>;base64,<payload>"`` on success. On any failure an HTML
        ``<p>`` snippet describing the (HTML-escaped) error is returned instead
        of raising, so callers must check for the ``data:`` prefix.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        payload = base64.b64encode(response.content).decode("utf-8")
        # Keep only the media type, dropping parameters such as "; charset=...".
        content_type = (
            response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()
        )
        mime = content_type if content_type.startswith("image/") else "image/jpeg"
        return f"data:{mime};base64,{payload}"
    except Exception as e:
        # Deliberately best-effort: a failed preview must not crash the app
        # while it is starting. The message is escaped because it is rendered
        # as HTML and may contain characters such as '&' or '<' from the URL.
        return f"<p style='color: red;'>Failed to load image: {html.escape(str(e))}</p>"
|
|
|
def generate_image_html(image_url: str) -> str:
    """Build an ``<img>`` tag that previews *image_url* inside a ``gr.HTML`` block.

    Remote ``http://`` / ``https://`` images are downloaded and inlined as a
    Base64 data URI via image_url_to_base64(). Local file paths are read and
    inlined the same way. A ``file://`` URL (the previous approach) is refused
    by browsers when the page is served over HTTP, so it never rendered.

    Args:
        image_url: A remote HTTP(S) URL or a local filesystem path.

    Returns:
        An ``<img>`` tag on success. If the image cannot be loaded, returns an
        HTML ``<p>`` error message instead of an ``<img>`` with a broken ``src``.
    """
    style = "width: 200px; height: auto; display: inline-block; margin: 10px; border-radius: 10px;"

    if image_url.lower().startswith(("http://", "https://")):
        src = image_url_to_base64(image_url)
    else:
        try:
            with open(image_url, "rb") as fh:
                payload = base64.b64encode(fh.read()).decode("ascii")
        except OSError as e:
            return f"<p style='color: red;'>Failed to load image: {html.escape(str(e))}</p>"
        # Fall back to JPEG when the extension gives no hint, mirroring the
        # default used for remote images.
        mime = mimetypes.guess_type(image_url)[0] or "image/jpeg"
        src = f"data:{mime};base64,{payload}"

    if not src.startswith("data:"):
        # image_url_to_base64() reports failures as an HTML snippet rather than
        # a URI; show that message directly instead of using it as the src.
        return src
    return f'<img src="{src}" style="{style}" />'
|
|
|
def generate_radiology_description(
    prompt: str,
    selected_current: str,
    uploaded_current: str,
    selected_prior: str,
    uploaded_prior: str,
    temperature: float,
    top_p: float,
    num_beams: int,
    max_new_tokens: int
) -> str:
    """Core inference callback wired to the "Generate Description" button.

    Steps:
        1. Resolve the current and prior images. An uploaded file takes
           precedence; otherwise the radio-button choice is looked up in
           DEFAULT_IMAGES.
        2. Call ``libra_eval`` with both images and the prompt.
        3. Return the generated text, or a human-readable error message.

    The model path defaults to the original cluster location. It can be
    overridden with the ``LIBRA_MODEL_PATH`` environment variable.

    Args:
        prompt: Text query sent to the model.
        selected_current / selected_prior: DEFAULT_IMAGES keys chosen in the UI.
        uploaded_current / uploaded_prior: File paths of uploaded images; falsy
            when nothing was uploaded.
        temperature, top_p: Sampling parameters forwarded to ``libra_eval``.
        num_beams, max_new_tokens: Beam count and token budget. They are cast
            to ``int`` because slider values may arrive as floats.

    Returns:
        The model output, a prompt to supply both images, or
        ``"An error occurred: ..."`` if generation raised.
    """
    # An uploaded file overrides the selected default image.
    current_image = uploaded_current or DEFAULT_IMAGES.get(selected_current)
    prior_image = uploaded_prior or DEFAULT_IMAGES.get(selected_prior)

    if not current_image or not prior_image:
        return "Please select or upload both current and prior images."

    model_path = os.environ.get(
        "LIBRA_MODEL_PATH",
        "/nfs/LLaVA-ai4bio/gla-biomed-playground/final_model/finetuned_model/llava-libra-test",
    )
    conv_mode = "libra_v1"

    try:
        # Numeric conversions live inside the try so a bad value surfaces as
        # the usual error message instead of an unhandled exception.
        return libra_eval(
            model_path=model_path,
            model_base=None,
            image_file=[current_image, prior_image],
            query=prompt,
            temperature=float(temperature),
            top_p=float(top_p),
            num_beams=int(num_beams),
            length_penalty=1.0,
            num_return_sequences=1,
            conv_mode=conv_mode,
            max_new_tokens=int(max_new_tokens),
        )
    except Exception as e:
        return f"An error occurred: {e}"
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # ---- Header ---------------------------------------------------------
    gr.Markdown("# Libra Radiology Report Generator")
    gr.Markdown("Use **Libra** to generate radiology image descriptions. Provide a **Current** and a **Prior** image below.")

    # Build the preview HTML for the default images once and reuse it in both
    # columns. generate_image_html() downloads remote images, so this halves
    # the number of fetches performed while the UI is constructed.
    default_image_previews = [generate_image_html(url) for url in DEFAULT_IMAGES.values()]

    # ---- Prompt ---------------------------------------------------------
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            value="Provide a detailed description of the findings in the radiology image.",
        )

    # ---- Image selection ------------------------------------------------
    # Each column offers a default image (radio) or an upload. When an upload
    # is present, generate_radiology_description() uses it instead of the radio.
    # `tool="editor"` is no longer passed to gr.Image: Gradio 3.x already uses
    # the editor for uploads by default, and Gradio 4 removed the parameter
    # (passing it raises TypeError).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Current Image")
            for preview in default_image_previews:
                gr.HTML(preview)
            selected_current = gr.Radio(
                label="Select Current Image",
                choices=list(DEFAULT_IMAGES.keys()),
                value="Image 1",
            )
            uploaded_current = gr.Image(
                label="Or Upload Current Image",
                type="filepath",
            )

        with gr.Column():
            gr.Markdown("### Prior Image")
            for preview in default_image_previews:
                gr.HTML(preview)
            selected_prior = gr.Radio(
                label="Select Prior Image",
                choices=list(DEFAULT_IMAGES.keys()),
                value="Image 2",
            )
            uploaded_prior = gr.Image(
                label="Or Upload Prior Image",
                type="filepath",
            )

    # ---- Generation parameters -------------------------------------------
    with gr.Row():
        temperature_slider = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7,
        )
        top_p_slider = gr.Slider(
            label="Top P",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.8,
        )
        num_beams_slider = gr.Slider(
            label="Number of Beams",
            minimum=1,
            maximum=20,
            step=1,
            value=2,
        )
        max_tokens_slider = gr.Slider(
            label="Max New Tokens",
            minimum=10,
            maximum=4096,
            step=10,
            value=128,
        )

    # ---- Output and action ---------------------------------------------
    output_text = gr.Textbox(
        label="Generated Description",
        lines=10,
    )

    generate_button = gr.Button("Generate Description")
    # The input order must match generate_radiology_description's parameters.
    generate_button.click(
        fn=generate_radiology_description,
        inputs=[
            prompt_input,
            selected_current,
            uploaded_current,
            selected_prior,
            uploaded_prior,
            temperature_slider,
            top_p_slider,
            num_beams_slider,
            max_tokens_slider,
        ],
        outputs=output_text,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed directly, not when
    # it is imported. `share=True` is forwarded to Gradio's launch().
    launch_options = {"share": True}
    demo.launch(**launch_options)