Ashraf committed on
Commit c9c6b30 · verified · 1 Parent(s): 0d8156c

Create app.py

Files changed (1):
  app.py  +203 -0
app.py ADDED
@@ -0,0 +1,203 @@
+ import json
+ import os
+ import tempfile
+ import time
+
+ import torch
+ import whisper
+ from deep_translator import GoogleTranslator
+ from flask import Flask, Response, jsonify, request
+ from flask_cors import CORS
+ from PIL import Image
+ from transformers import AutoModelForImageTextToText, AutoProcessor
+
+ app = Flask(__name__)
+ CORS(app)
+
+ # Load MedGemma model (4B) on startup.
+ #
+ # Alternative (commented out): authenticate with the Hugging Face Hub and run a
+ # quantized GGUF build through llama-cpp-python instead of transformers.
+ # from huggingface_hub import login, hf_hub_download
+ # from llama_cpp import Llama
+ #
+ # hugging_face_token = os.getenv("kk")
+ # if not hugging_face_token:
+ #     raise EnvironmentError("HUGGINGFACE_TOKEN environment variable not set.")
+ # login(hugging_face_token)
+ #
+ # model_name = "unsloth/medgemma-4b-it-GGUF"
+ # model_file = "medgemma-4b-it-Q8_0.gguf"  # other quantization levels are available in the model repo
+ # model_path = hf_hub_download(model_name, filename=model_file)
+ # llm = Llama(
+ #     model_path=model_path,  # update this to a local model path if preferred
+ #     n_ctx=8192,
+ #     n_threads=12,
+ #     temperature=0.7,
+ # )
+
+ model_id = "google/medgemma-4b-pt"
+
+ model_medg = AutoModelForImageTextToText.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )
+ processor = AutoProcessor.from_pretrained(model_id)
+
+
+ @app.route('/analyze-image', methods=['POST'])
+ def analyze_image():
+     image = None
+     image_path = None
+
+     # Get optional image
+     file = request.files.get('file')
+     if file:
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
+             file.save(tmp.name)
+             image_path = tmp.name
+         try:
+             image = Image.open(image_path).convert("RGB")
+         except Exception as e:
+             return jsonify({'error': f'Invalid image: {str(e)}'}), 400
+
+     # Get optional chat history (JSON string in a form field, or in the JSON body)
+     try:
+         chat_history = request.form.get("chat_history")
+         if chat_history:
+             chat_history = json.loads(chat_history)
+         else:
+             chat_history = (request.get_json(silent=True) or {}).get("chat_history", [])
+     except Exception as e:
+         return jsonify({'error': f'Invalid or missing chat_history: {str(e)}'}), 400
+
+     # Build text prompt (from chat history)
+     prompt_parts = []
+     for msg in chat_history:
+         role = msg.get("role", "").strip().lower()
+         content = msg.get("content", "").strip()
+         if role == "system":
+             prompt_parts.append(content)
+         elif role == "user":
+             prompt_parts.append(f"User: {content}")
+         elif role == "assistant":
+             prompt_parts.append(f"Assistant: {content}")
+
+     combined_prompt = "\n".join(prompt_parts).strip()
+
+     if not image and not combined_prompt:
+         return jsonify({'error': 'You must provide either an image or a prompt.'}), 400
+
+     # Final model prompt
+     model_prompt = f" {combined_prompt or 'Response:'}"
+
+     # Prepare input to model
+     inputs = processor(
+         text=model_prompt,
+         images=image if image else None,
+         return_tensors="pt"
+     ).to(model_medg.device, dtype=torch.bfloat16)
+
+     input_len = inputs["input_ids"].shape[-1]
+     print(model_prompt)
+
+     with torch.inference_mode():
+         generation = model_medg.generate(
+             **inputs,
+             max_new_tokens=1000,
+             do_sample=False
+         )
+         generation = generation[0][input_len:]
+
+     decoded = processor.decode(generation, skip_special_tokens=True)
+
+     # Clean up the temporary image file
+     if image_path and os.path.exists(image_path):
+         os.remove(image_path)
+
+     return jsonify({"result": decoded.strip()})
+
+
+ @app.route('/med-llm', methods=['POST'])
+ def med_llm():
+     uploaded_file = request.files.get('file')
+     if not uploaded_file:
+         return jsonify({'error': 'No file uploaded'}), 400
+
+     mime_type = uploaded_file.mimetype
+     print(f"Received file type: {mime_type}")
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix='.bin') as tmp:
+         uploaded_file.save(tmp.name)
+     os.remove(tmp.name)  # The file is not analysed yet, so remove it after saving
+
+     # Mocked responses until real inference is wired up for this route
+     if mime_type.startswith('image/'):
+         mock_response = "📷 MedGemma Image Analysis: Detected mild cardiomegaly."
+     elif mime_type.startswith('audio/'):
+         mock_response = "🎧 MedGemma Audio Analysis: Suggests potential wheezing."
+     else:
+         mock_response = "Unsupported file type."
+
+     return jsonify({'result': mock_response})
+
+
+ model = whisper.load_model("base")  # You can change to 'small', 'medium', etc.
+
+
+ @app.route('/transcribe-stream', methods=['POST'])
+ def transcribe_stream():
+     # Get the audio file from the request
+     audio_file = request.files.get('audio')
+     if not audio_file:
+         return "Missing audio file", 400
+
+     # Save to a temp file
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+         audio_path = tmp.name
+         audio_file.save(audio_path)
+
+     def generate():
+         # Transcribe using Whisper (non-streaming), then emit the text word by word
+         result = model.transcribe(audio_path)
+         for word in result['text'].split():
+             yield f"data: {word}\n\n"
+             time.sleep(0.3)  # Simulate streaming
+
+         os.remove(audio_path)  # Clean up
+
+     return Response(generate(), mimetype='text/event-stream')
+
+
+ def translate_to_arabic(text):
+     try:
+         return GoogleTranslator(source='auto', target='ar').translate(text)
+     except Exception as e:
+         print(f"Translation failed: {e}")
+         return text
+
+
+ @app.route('/chat-translate', methods=['POST'])
+ def chat_translate():
+     try:
+         data = request.get_json()
+         chat_history = data.get('chat_history', [])
+         translate = request.args.get("translate") == "true"
+
+         # Join messages into a prompt; the last "model" message becomes the result
+         prompt_parts = []
+         result = ""
+         for msg in chat_history:
+             role = msg.get("role", "").strip().lower()
+             content = msg.get("content", "").strip()
+             if role == "system":
+                 prompt_parts.append(f"System:\n{content}")
+             elif role == "user":
+                 prompt_parts.append(f"User:\n{content}")
+             elif role == "model":
+                 prompt_parts.append(f"Assistant:\n{content}")
+                 result = content  # Use the last model response as the result
+
+         full_prompt = "\n\n".join(prompt_parts)
+
+         # Simulated LLM response
+         # result = "Simulated model answer. Possible findings: * Infection * Fluid accumulation. Next steps: * Follow-up test."
+
+         if translate:
+             result = translate_to_arabic(result)
+
+         return jsonify({"result": result})
+
+     except Exception as e:
+         return jsonify({"error": str(e)}), 500
+
+
+ if __name__ == '__main__':
+     app.run('0.0.0.0', port=5002)
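
For quick manual testing, a minimal client sketch for the routes defined above. It assumes the server is running locally on port 5002 (matching app.run above); the endpoint paths, the file/chat_history field names, the translate query parameter and the role names come from the code, while chest.png and the sample messages are placeholders.

# Hypothetical client sketch (not part of app.py): exercises /analyze-image and /chat-translate.
import json
import requests

BASE = "http://localhost:5002"  # assumed local deployment

# /analyze-image: optional image upload plus a JSON-encoded chat history form field
with open("chest.png", "rb") as f:  # placeholder image file
    resp = requests.post(
        f"{BASE}/analyze-image",
        files={"file": f},
        data={"chat_history": json.dumps(
            [{"role": "user", "content": "Describe this chest X-ray."}]
        )},
    )
print(resp.json())  # {"result": "..."}

# /chat-translate?translate=true: returns the last "model" message, translated to Arabic
resp = requests.post(
    f"{BASE}/chat-translate",
    params={"translate": "true"},
    json={"chat_history": [
        {"role": "user", "content": "What do these findings mean?"},
        {"role": "model", "content": "Findings suggest mild cardiomegaly."},
    ]},
)
print(resp.json())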