|
|
|
import base64
import logging
import os

import gradio as gr
import requests
import torch

from libra.eval import libra_eval
|
|
|
_logger = logging.getLogger(__name__)


def generate_radiology_description(
    prompt: str,
    uploaded_current: str,
    uploaded_prior: str,
    temperature: float,
    top_p: float,
    num_beams: int,
    max_new_tokens: int,
) -> str:
    """Generate a radiology description from a current and a prior image.

    Both image paths are forwarded to ``libra_eval``, current first, together
    with the prompt and the decoding settings. This function is the Gradio
    click handler. It therefore never raises: failures are reported back to
    the UI as a string.

    Args:
        prompt: Text query sent to the model.
        uploaded_current: File path of the current image (the UI uses
            ``gr.Image(type="filepath")``); empty/None when nothing was uploaded.
        uploaded_prior: File path of the prior image; empty/None when nothing
            was uploaded.
        temperature: Sampling temperature passed to ``libra_eval``.
        top_p: Top-p value passed to ``libra_eval``.
        num_beams: Beam count; coerced to ``int`` because slider values are
            not guaranteed to arrive as ints.
        max_new_tokens: Generation length cap; coerced to ``int`` likewise.

    Returns:
        The ``libra_eval`` output. If either image is missing, returns a
        request to upload both images. If inference fails, returns a string
        of the form ``"An error occurred: <message>"``.
    """
    if not uploaded_current or not uploaded_prior:
        return "Please upload both current and prior images."

    model_path = "X-iZhang/libra-v1.0-7b"
    conv_mode = "libra_v1"

    try:
        # NOTE(review): libra_eval is handed a model *path*, so it presumably
        # (re)loads the checkpoint on every click; consider caching a model.
        _logger.info("Calling libra_eval")
        output = libra_eval(
            model_path=model_path,
            model_base=None,
            image_file=[uploaded_current, uploaded_prior],
            query=prompt,
            temperature=temperature,
            top_p=top_p,
            # Generation expects integer counts; sliders may deliver floats.
            num_beams=int(num_beams),
            length_penalty=1.0,
            num_return_sequences=1,
            conv_mode=conv_mode,
            max_new_tokens=int(max_new_tokens),
        )
    except Exception as e:
        # UI boundary: keep the traceback in the logs, show a short message.
        _logger.exception("libra_eval failed")
        return f"An error occurred: {e}"

    _logger.info("libra_eval result: %s", output)
    return output
|
|
|
|
|
# Shared range for the 0.1-1.0 sampling sliders (temperature and top-p).
_UNIT_INTERVAL = {"minimum": 0.1, "maximum": 1.0, "step": 0.1}

with gr.Blocks() as demo:
    # Page header and usage hint.
    gr.Markdown("# Libra Radiology Report Generator (Local Upload Only)")
    gr.Markdown(
        "Upload **Current** and **Prior** images below to generate a "
        "radiology description using the Libra model."
    )

    prompt_input = gr.Textbox(
        label="Prompt", value="Describe the key findings in these two images."
    )

    # Side-by-side image inputs; type="filepath" hands the handler a path string.
    with gr.Row():
        uploaded_current = gr.Image(label="Upload Current Image", type="filepath")
        uploaded_prior = gr.Image(label="Upload Prior Image", type="filepath")

    # Decoding controls, laid out on a single row.
    with gr.Row():
        temperature_slider = gr.Slider(label="Temperature", value=0.7, **_UNIT_INTERVAL)
        top_p_slider = gr.Slider(label="Top P", value=0.8, **_UNIT_INTERVAL)
        num_beams_slider = gr.Slider(
            label="Number of Beams", minimum=1, maximum=20, step=1, value=2
        )
        max_tokens_slider = gr.Slider(
            label="Max New Tokens", minimum=10, maximum=4096, step=10, value=128
        )

    output_text = gr.Textbox(label="Generated Description", lines=10)

    # Order must match the handler's positional parameters.
    handler_inputs = [
        prompt_input,
        uploaded_current,
        uploaded_prior,
        temperature_slider,
        top_p_slider,
        num_beams_slider,
        max_tokens_slider,
    ]

    generate_button = gr.Button("Generate Description")
    generate_button.click(
        generate_radiology_description,
        inputs=handler_inputs,
        outputs=output_text,
    )


if __name__ == "__main__":
    demo.launch()