chrisvoncsefalvay committed on
Commit
26dc4f5
·
0 Parent(s):

Initial Gradio app for Dental VQA Model Comparison

Browse files
.gitignore ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ venv/
25
+ ENV/
26
+ env/
27
+ .venv/
28
+
29
+ # IDE
30
+ .vscode/
31
+ *.swp
32
+ *.swo
33
+ *~
34
+
35
+ # OS
36
+ .DS_Store
37
+ Thumbs.db
38
+
39
+ # Hugging Face
40
+ .gradio/
41
+ gradio_cached_examples/
42
+
43
+ # Scratch files
44
+ .scratches/
45
+
46
+ # PyCharm
47
+ .idea/
48
+
49
+ # Local environment files
50
+ .env
51
+ .env.local
52
+
53
+ # Model cache
54
+ models/
55
+ .cache/
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/dental-vqa-comparison.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/material_theme_project_new.xml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="MaterialThemeProjectNewConfig">
4
+ <option name="metadata">
5
+ <MTProjectMetadataState>
6
+ <option name="migrated" value="true" />
7
+ <option name="pristineConfig" value="false" />
8
+ <option name="userId" value="-31f020fc:19626a96fac:-7ffa" />
9
+ </MTProjectMetadataState>
10
+ </option>
11
+ </component>
12
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/dental-vqa-comparison.iml" filepath="$PROJECT_DIR$/.idea/dental-vqa-comparison.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Dental VQA Model Comparison
3
+ emoji: 🦷
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ models:
12
+ - yasserrmd/DentaInstruct-1.2B
13
+ ---
14
+
15
+ # Dental VQA Model Comparison
16
+
17
+ An interactive Gradio interface for comparing dental visual question answering models, currently featuring the DentaInstruct-1.2B model for educational information about dental health and oral care.
18
+
19
+ ## Features
20
+
21
+ - Interactive chat interface for dental health questions
22
+ - Adjustable generation parameters (temperature, max tokens, etc.)
23
+ - Example questions to get started
24
+ - Mobile-responsive design
25
+ - Clear disclaimers about educational use only
26
+
27
+ ## Important Disclaimer
28
+
29
+ ⚠️ **This model is for educational purposes only.** It is NOT a substitute for professional dental care. Do not use this model for clinical diagnosis or treatment advice. Always consult a qualified dental professional.
30
+
31
+ ## Model Information
32
+
33
+ - **Base Model**: LFM2-1.2B
34
+ - **Parameters**: 1.17B
35
+ - **Training Data**: Dental subset of MIRIAD dataset
36
+ - **Purpose**: Educational dental information
37
+
38
+ ## Usage
39
+
40
+ Ask questions about:
41
+ - Dental procedures and treatments
42
+ - Oral health and hygiene
43
+ - Common dental conditions
44
+ - Preventive dental care
45
+ - Dental anatomy and terminology
46
+
47
+ ## Credits
48
+
49
+ - **Model**: Created by @yasserrmd
50
+ - **Interface**: Space by @chrisvoncsefalvay
51
+
52
+ ## License
53
+
54
+ Apache-2.0
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
# Which checkpoint to serve and where inference runs.
MODEL_ID = "yasserrmd/DentaInstruct-1.2B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and weights are loaded once, when the module is imported.
print(f"Loading model {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

if DEVICE == "cuda":
    # GPU: half precision, with placement delegated to device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    # CPU: full precision, then an explicit move onto the CPU device.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map=None,
    )
    model = model.to(DEVICE)

# generate() is later given pad_token_id, so fall back to EOS when the
# tokenizer ships without a dedicated padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
24
+
25
def format_prompt(message, history):
    """Build the text prompt sent to the model for one chat turn.

    Args:
        message: The new user message to answer.
        history: Earlier turns as an iterable of ``(user_msg, assistant_msg)``
            pairs. ``None`` or an empty iterable means a fresh conversation.
            A falsy ``assistant_msg`` is skipped.

    Returns:
        The prompt string. If the module-level ``tokenizer`` has a chat
        template configured, the conversation is rendered through
        ``apply_chat_template`` with the generation prompt appended.
        Otherwise a plain ``User:`` / ``Assistant:`` transcript ending in
        ``"Assistant: "`` is returned.
    """
    messages = []
    for user_msg, assistant_msg in history or ():
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Test for a configured template, not for the method: the method is
    # defined on the tokenizer base class, so a hasattr() check on it never
    # reaches the fallback below.
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Fallback: simple role-labelled transcript, one turn per line.
    speaker = {"user": "User", "assistant": "Assistant"}
    transcript = "".join(
        f"{speaker[turn['role']]}: {turn['content']}\n" for turn in messages
    )
    return transcript + "Assistant: "
52
+
53
def generate_response(
    message,
    history,
    temperature=0.7,
    max_new_tokens=512,
    top_p=0.95,
    repetition_penalty=1.1,
):
    """Generate the model's reply to ``message`` given the chat ``history``.

    Args:
        message: The new user message.
        history: Earlier turns, passed through to ``format_prompt``.
        temperature: Sampling temperature. A value of 0 or below switches to
            greedy decoding instead of sampling.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``
            because UI sliders may deliver a float.
        top_p: Nucleus-sampling threshold (used only when sampling).
        repetition_penalty: Penalty passed to ``model.generate``.

    Returns:
        The decoded reply text with special tokens removed and surrounding
        whitespace stripped.
    """
    max_prompt_tokens = 2048

    prompt = format_prompt(message, history)

    # A rendered chat template usually already begins with the BOS token;
    # letting the tokenizer add special tokens again would duplicate it.
    bos_token = getattr(tokenizer, "bos_token", None)
    prompt_has_bos = bool(bos_token) and prompt.startswith(bos_token)
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=not prompt_has_bos,
    )

    # Keep the *tail* of an over-long prompt. Truncating on the right would
    # drop the newest user message and the generation prompt, which are the
    # parts the model needs most.
    inputs = {
        name: tensor[:, -max_prompt_tokens:].to(model.device)
        for name, tensor in encoded.items()
    }

    # generate() rejects do_sample=True with a non-positive temperature, so
    # fall back to greedy decoding in that case.
    if temperature > 0:
        sampling_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
        }
    else:
        sampling_kwargs = {"do_sample": False}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            **sampling_kwargs,
        )

    # The output sequence starts with the prompt tokens; decode only what
    # was generated after them.
    prompt_length = inputs["input_ids"].shape[1]
    reply_ids = outputs[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
87
+
88
+ # Example questions
89
+ EXAMPLES = [
90
+ ["What are the main types of dental cavities?"],
91
+ ["Explain the process of root canal treatment"],
92
+ ["What is the difference between gingivitis and periodontitis?"],
93
+ ["How should I care for my teeth after a dental extraction?"],
94
+ ["What are the benefits of fluoride in dental care?"],
95
+ ["Explain the stages of tooth development in children"],
96
+ ["What causes tooth sensitivity and how can it be treated?"],
97
+ ["Describe the different types of dental fillings available"],
98
+ ]
99
+
100
+ # Custom CSS for styling
101
+ custom_css = """
102
+ .disclaimer {
103
+ background-color: #fff3cd;
104
+ border: 1px solid #ffc107;
105
+ border-radius: 5px;
106
+ padding: 10px;
107
+ margin-bottom: 15px;
108
+ }
109
+ """
110
+
111
+ # Create Gradio interface
112
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
113
+ gr.Markdown(
114
+ """
115
+ # Dental VQA Model Comparison
116
+
117
+ Interactive comparison of dental visual question answering models. Currently featuring DentaInstruct-1.2B for dental education and oral health information.
118
+ """
119
+ )
120
+
121
+ gr.HTML(
122
+ """
123
+ <div class="disclaimer">
124
+ <strong>⚠️ Important Disclaimer:</strong><br>
125
+ This model is for educational purposes only. It is NOT a substitute for professional dental care.
126
+ Do not use this model for clinical diagnosis or treatment advice. Always consult a qualified dental professional.
127
+ </div>
128
+ """
129
+ )
130
+
131
+ chatbot = gr.Chatbot(
132
+ height=400,
133
+ label="Conversation"
134
+ )
135
+
136
+ msg = gr.Textbox(
137
+ label="Your dental question",
138
+ placeholder="Ask a question about dental health, procedures, or oral care...",
139
+ lines=2
140
+ )
141
+
142
+ with gr.Row():
143
+ submit = gr.Button("Send", variant="primary")
144
+ clear = gr.Button("Clear")
145
+
146
+ with gr.Accordion("Advanced Settings", open=False):
147
+ temperature = gr.Slider(
148
+ minimum=0.1,
149
+ maximum=1.0,
150
+ value=0.7,
151
+ step=0.1,
152
+ label="Temperature",
153
+ info="Controls randomness in responses"
154
+ )
155
+
156
+ max_new_tokens = gr.Slider(
157
+ minimum=64,
158
+ maximum=1024,
159
+ value=512,
160
+ step=64,
161
+ label="Max New Tokens",
162
+ info="Maximum length of the response"
163
+ )
164
+
165
+ top_p = gr.Slider(
166
+ minimum=0.1,
167
+ maximum=1.0,
168
+ value=0.95,
169
+ step=0.05,
170
+ label="Top-p",
171
+ info="Nucleus sampling parameter"
172
+ )
173
+
174
+ repetition_penalty = gr.Slider(
175
+ minimum=1.0,
176
+ maximum=1.5,
177
+ value=1.1,
178
+ step=0.05,
179
+ label="Repetition Penalty",
180
+ info="Reduces repetition in responses"
181
+ )
182
+
183
+ gr.Examples(
184
+ examples=EXAMPLES,
185
+ inputs=msg,
186
+ label="Example Questions"
187
+ )
188
+
189
+ gr.Markdown(
190
+ """
191
+ ## About This Model
192
+
193
+ DentaInstruct-1.2B is a specialised language model fine-tuned on dental educational content.
194
+ It's designed to provide educational information about dental health, procedures, and oral care.
195
+
196
+ **Model Details:**
197
+ - Base Model: LFM2-1.2B
198
+ - Parameters: 1.17B
199
+ - Training Data: Dental subset of MIRIAD dataset
200
+ - Purpose: Educational dental information
201
+
202
+ **Created by:** @yasserrmd | **Space by:** @chrisvoncsefalvay
203
+ """
204
+ )
205
+
206
+ # Event handlers
207
def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
    """Handle one "send" event: answer ``message`` and extend the chat.

    Args:
        message: Text from the input box; may be empty or ``None``.
        chat_history: Existing turns as ``(user, assistant)`` pairs, or
            ``None`` when there is no conversation yet.
        temperature: Sampling temperature forwarded to ``generate_response``.
        max_new_tokens: Generation length limit forwarded to
            ``generate_response``.
        top_p: Nucleus-sampling threshold forwarded to ``generate_response``.
        repetition_penalty: Penalty forwarded to ``generate_response``.

    Returns:
        A ``(textbox_value, chat_history)`` tuple: an empty string to clear
        the input box, and the updated list of turns.
    """
    # Work on a copy so the caller's list is never mutated, and so a missing
    # history behaves like an empty one.
    turns = list(chat_history or [])

    # A blank submission should not trigger a model call or add an empty
    # turn to the conversation.
    if not message or not message.strip():
        return "", turns

    reply = generate_response(
        message,
        turns,
        temperature,
        max_new_tokens,
        top_p,
        repetition_penalty,
    )
    turns.append((message, reply))
    return "", turns
218
+
219
+ msg.submit(
220
+ respond,
221
+ [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
222
+ [msg, chatbot]
223
+ )
224
+
225
+ submit.click(
226
+ respond,
227
+ [msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
228
+ [msg, chatbot]
229
+ )
230
+
231
+ clear.click(lambda: None, None, chatbot, queue=False)
232
+
233
+ if __name__ == "__main__":
234
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ transformers==4.44.2
3
+ torch>=2.0.0
4
+ accelerate==0.33.0
5
+ sentencepiece==0.2.0
6
+ protobuf==5.27.3