shukdevdatta123 commited on
Commit
c7cc8ee
·
verified ·
1 Parent(s): c447daa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -82
app.py CHANGED
@@ -1,38 +1,15 @@
1
  import gradio as gr
2
- from transformers.image_utils import load_image
3
- from threading import Thread
4
- import time
5
  import torch
6
  from PIL import Image
 
 
7
  from transformers import (
8
  Qwen2VLForConditionalGeneration,
9
  AutoProcessor,
10
  TextIteratorStreamer,
11
  )
12
 
13
- # ---------------------------
14
- # Helper Functions
15
- # ---------------------------
16
- def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
17
- """
18
- Returns an HTML snippet for a thin animated progress bar with a label.
19
- """
20
- return f'''
21
- <div style="display: flex; align-items: center;">
22
- <span style="margin-right: 10px; font-size: 14px;">{label}</span>
23
- <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
24
- <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
25
- </div>
26
- </div>
27
- <style>
28
- @keyframes loading {{
29
- 0% {{ transform: translateX(-100%); }}
30
- 100% {{ transform: translateX(100%); }}
31
- }}
32
- </style>
33
- '''
34
-
35
- # Model and Processor Setup - CPU version
36
  MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
37
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
38
  model = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -41,23 +18,18 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
41
  torch_dtype=torch.float32 # Using float32 for CPU compatibility
42
  ).to("cpu").eval()
43
 
44
- # Main Inference Function
45
- def extract_medicines(image_files):
46
  """Extract medicine names from prescription images."""
47
- if not image_files:
48
  return "Please upload a prescription image."
49
 
50
- # Handle file inputs correctly
51
- image_paths = [file.name for file in image_files] if isinstance(image_files, list) else [image_files.name]
52
- images = [load_image(path) for path in image_paths]
53
-
54
- # Specific prompt to extract only medicine names
55
  text = "Extract ONLY the names of medications/medicines from this prescription image. Format the output as a numbered list of medicine names only, without dosages or instructions."
56
 
57
  messages = [{
58
  "role": "user",
59
  "content": [
60
- *[{"type": "image", "image": image} for image in images],
61
  {"type": "text", "text": text},
62
  ],
63
  }]
@@ -65,59 +37,34 @@ def extract_medicines(image_files):
65
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
66
  inputs = processor(
67
  text=[prompt_full],
68
- images=images,
69
  return_tensors="pt",
70
  padding=True,
71
  ).to("cpu")
72
 
73
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
74
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
 
75
 
76
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
77
- thread.start()
78
 
79
- buffer = ""
80
- yield progress_bar_html("Extracting Medicine Names")
 
81
 
82
- for new_text in streamer:
83
- buffer += new_text
84
- buffer = buffer.replace("<|im_end|>", "")
85
- time.sleep(0.01)
86
- yield buffer
87
 
88
- # Gradio Interface
89
- with gr.Blocks() as demo:
90
- gr.Markdown("# Medicine Name Extractor")
91
- gr.Markdown("Upload prescription images to extract medicine names")
92
-
93
- with gr.Row():
94
- with gr.Column():
95
- image_input = gr.File(
96
- label="Upload Prescription Image(s)",
97
- file_count="multiple",
98
- file_types=["image"]
99
- )
100
- extract_btn = gr.Button("Extract Medicine Names", variant="primary")
101
-
102
- with gr.Column():
103
- output = gr.Markdown(label="Extracted Medicine Names")
104
-
105
- extract_btn.click(
106
- fn=extract_medicines,
107
- inputs=image_input,
108
- outputs=output
109
- )
110
-
111
- # Note: For examples to work with current Gradio versions, you need a different approach
112
- # than what I previously provided. Remove examples for now to fix the immediate error.
113
-
114
- gr.Markdown("""
115
- ### Notes:
116
- - This app is optimized to run on CPU
117
- - Upload clear images of prescriptions for best results
118
- - Only medicine names will be extracted
119
- - Processing might take a minute or two on CPU
120
- """)
121
 
122
- demo.queue()
123
- demo.launch(debug=True)
 
1
  import gradio as gr
 
 
 
2
  import torch
3
  from PIL import Image
4
+ import time
5
+ from threading import Thread
6
  from transformers import (
7
  Qwen2VLForConditionalGeneration,
8
  AutoProcessor,
9
  TextIteratorStreamer,
10
  )
11
 
12
+ # Load model and processor - CPU version
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
14
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
15
  model = Qwen2VLForConditionalGeneration.from_pretrained(
 
18
  torch_dtype=torch.float32 # Using float32 for CPU compatibility
19
  ).to("cpu").eval()
20
 
21
+ def extract_medicines(image):
 
22
  """Extract medicine names from prescription images."""
23
+ if image is None:
24
  return "Please upload a prescription image."
25
 
26
+ # Process the image
 
 
 
 
27
  text = "Extract ONLY the names of medications/medicines from this prescription image. Format the output as a numbered list of medicine names only, without dosages or instructions."
28
 
29
  messages = [{
30
  "role": "user",
31
  "content": [
32
+ {"type": "image", "image": Image.open(image)},
33
  {"type": "text", "text": text},
34
  ],
35
  }]
 
37
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
38
  inputs = processor(
39
  text=[prompt_full],
40
+ images=[Image.open(image)],
41
  return_tensors="pt",
42
  padding=True,
43
  ).to("cpu")
44
 
45
+ # Generate response
46
+ with torch.no_grad():
47
+ output = model.generate(**inputs, max_new_tokens=512)
48
 
49
+ # Decode and return response
50
+ response = processor.decode(output[0], skip_special_tokens=True)
51
 
52
+ # Clean up the response to get just the model's answer
53
+ if "<|assistant|>" in response:
54
+ response = response.split("<|assistant|>")[1].strip()
55
 
56
+ return response
 
 
 
 
57
 
58
+ # Create a simple Gradio interface
59
+ demo = gr.Interface(
60
+ fn=extract_medicines,
61
+ inputs=gr.Image(type="filepath", label="Upload Prescription Image"),
62
+ outputs=gr.Textbox(label="Extracted Medicine Names"),
63
+ title="Medicine Name Extractor",
64
+ description="Upload prescription images to extract medicine names",
65
+ examples=[["examples/prescription1.jpg"]], # Update with your actual example paths or remove if not available
66
+ cache_examples=True,
67
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ if __name__ == "__main__":
70
+ demo.launch(debug=True)