Spaces:
Sleeping
Sleeping
File size: 5,109 Bytes
f787f97 7fb8860 f787f97 7fb8860 f787f97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
"""Gradio app that extracts medicine names from a prescription image.

The image is passed to the Qwen2-VL-OCR checkpoint together with a fixed
instruction prompt, and the decoded model reply is shown in a textbox.
"""

import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import re

# Checkpoint used for both the model weights and the processor.
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"

# Instruction sent alongside every uploaded image.
EXTRACTION_PROMPT = (
    "Extract and list ONLY the names of medicines/drugs from this "
    "prescription image. Output the medicine names as a numbered list "
    "without any additional information or descriptions."
)

# Upper bound on the number of tokens generated per request.
MAX_NEW_TOKENS = 256

# Message returned when the handler is invoked without an image.
NO_IMAGE_MESSAGE = "Please upload an image."


def load_model():
    """Load the model and its processor on CPU.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded in float32 with
        ``device_map="cpu"``.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map="cpu",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor


# Lazily-populated singleton so the checkpoint is loaded at most once per
# process instead of once per request.
model_instance = None
processor_instance = None


def get_model_and_processor():
    """Return the cached ``(model, processor)`` pair, loading it on first use."""
    global model_instance, processor_instance
    if model_instance is None or processor_instance is None:
        model_instance, processor_instance = load_model()
    return model_instance, processor_instance


def _generate_medicine_list(image, model, processor):
    """Run one extraction request and return the decoded model reply.

    Args:
        image: The prescription image, placed as-is in the chat message
            (the UI below supplies a PIL image).
        model: The loaded generation model.
        processor: The processor matching ``model``.

    Returns:
        The generated text with the prompt tokens removed, any literal
        ``<|im_end|>`` marker dropped, and surrounding whitespace stripped.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": EXTRACTION_PROMPT},
            ],
        }
    ]

    # Build the chat-formatted prompt string and the vision inputs.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

    # Each generated sequence starts with its prompt tokens; keep only the
    # part that follows them.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    decoded = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Defensive cleanup in case the end-of-turn marker survives decoding.
    return decoded.replace("<|im_end|>", "").strip()


def extract_medicine_names_optimized(image):
    """Extract medicine names from ``image`` using the cached model.

    Args:
        image: The uploaded prescription image, or ``None`` when nothing
            was uploaded.

    Returns:
        The model's reply as a string, or a prompt to upload an image when
        ``image`` is ``None``.
    """
    if image is None:
        return NO_IMAGE_MESSAGE
    model, processor = get_model_and_processor()
    return _generate_medicine_list(image, model, processor)


def extract_medicine_names(image):
    """Extract medicine names from ``image``.

    Kept for backward compatibility. It previously reloaded the checkpoint
    on every call; it now shares the cached model and the ``None`` guard
    with ``extract_medicine_names_optimized``.
    """
    return extract_medicine_names_optimized(image)


# Create Gradio interface
with gr.Blocks(title="Medicine Name Extractor") as app:
    gr.Markdown("# Medicine Name Extractor")
    gr.Markdown("Upload a medical prescription image to extract the names of medicines.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Prescription Image")
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(label="Extracted Medicine Names", lines=10)

    extract_btn.click(
        fn=extract_medicine_names_optimized,
        inputs=input_image,
        outputs=output_text,
    )

    gr.Markdown("### Notes")
    gr.Markdown("- This tool uses the Qwen2-VL-OCR model to extract text from prescription images")
    gr.Markdown("- For best results, ensure the prescription image is clear and readable")
    gr.Markdown("- Processing may take some time as the model runs on CPU")

# Launch the app
if __name__ == "__main__":
    app.launch()