File size: 4,371 Bytes
8ed8457
 
 
57a1258
56d8f41
8ed8457
56d8f41
fe01251
8ed8457
 
 
 
fe01251
8ed8457
56d8f41
fe01251
 
 
 
 
 
 
 
 
 
 
56d8f41
 
 
8ed8457
 
 
fe01251
56d8f41
fe01251
 
 
8ed8457
56d8f41
 
 
 
fe01251
56d8f41
 
 
8ed8457
 
56d8f41
 
 
8ed8457
56d8f41
57a1258
 
56d8f41
57a1258
 
 
8ed8457
 
57a1258
56d8f41
57a1258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ed8457
57a1258
56d8f41
57a1258
 
 
 
56d8f41
57a1258
 
56d8f41
fe01251
57a1258
 
 
 
 
56d8f41
57a1258
 
 
 
 
56d8f41
8ed8457
 
 
 
 
 
 
 
 
 
56d8f41
 
 
 
 
8ed8457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56d8f41
 
 
8ed8457
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextStreamer
import torch
import gc
import os

# Enable better CPU performance
torch.set_num_threads(4)
device = "cpu"

def load_model():
    model_name = "forestav/unsloth_vision_radiography_finetune"
    base_model_name = "unsloth/Llama-3.2-11B-Vision-Instruct"  # Correct base model
    
    print("Loading tokenizer and processor...")
    # Load tokenizer from base model
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code=True
    )
    
    # Load processor from base model
    processor = AutoProcessor.from_pretrained(
        base_model_name,
        trust_remote_code=True
    )
    
    print("Loading model...")
    # Load model with CPU optimizations
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        offload_folder="offload",
        offload_state_dict=True,
        trust_remote_code=True
    )
    
    print("Quantizing model...")
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )
    
    return model, tokenizer, processor

# Create offload directory if it doesn't exist
os.makedirs("offload", exist_ok=True)

# Initialize model and tokenizer globally
print("Starting model initialization...")
try:
    model, tokenizer, processor = load_model()
    print("Model loaded and quantized successfully!")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise

def analyze_image(image, instruction):
    try:
        # Clear memory
        gc.collect()
        
        if instruction.strip() == "":
            instruction = "You are an expert radiographer. Describe accurately what you see in this image."
        
        # Prepare the messages
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]
        
        # Process the image and text
        inputs = processor(
            images=image,
            text=tokenizer.apply_chat_template(messages, add_generation_prompt=True),
            return_tensors="pt"
        )
        
        # Generate with conservative settings for CPU
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                temperature=1.0,
                min_p=0.1,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                num_beams=1
            )
        
        # Decode the response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Clean up
        del outputs
        gc.collect()
            
        return response
    except Exception as e:
        return f"Error processing image: {str(e)}\nPlease try again with a smaller image or different settings."

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
    # Medical Image Analysis Assistant
    Upload a medical image and receive a professional description from an AI radiographer.
    """)
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Medical Image",
                max_pixels=1500000  # Limit image size
            )
            instruction_input = gr.Textbox(
                label="Custom Instruction (optional)",
                placeholder="You are an expert radiographer. Describe accurately what you see in this image.",
                lines=2
            )
            submit_btn = gr.Button("Analyze Image")
        
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", lines=10)
    
    # Handle the submission
    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, instruction_input],
        outputs=output_text
    )
    
    gr.Markdown("""
    ### Notes:
    - The model runs on CPU and may take several moments to process each image
    - For best results, upload images smaller than 1.5MP
    - Please be patient during processing
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()