Spaces:
Runtime error
Runtime error
| from google import genai | |
| import json | |
| import os | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
class GeminiInference:
    """Wrapper around Gemini API calls made through ``google.genai``.

    Besides issuing requests, the class offers helpers for pulling a JSON
    payload out of a (possibly Markdown-fenced) model reply and for turning
    detection results into [0, 1]-normalized bounding boxes.
    """

    DEFAULT_MODEL = "gemini-2.5-pro"
    _FENCE_OPEN = "```json"
    _FENCE = "```"
    # Characters trimmed from both ends of an extracted payload.
    _STRIP_CHARS = "` \t\r\n"

    def __init__(self, api_key_source=None, model=DEFAULT_MODEL):
        """Store the API key and model name used for later requests.

        Args:
            api_key_source: API key string. When ``None``, the value of the
                ``GEMINI_API_KEY`` environment variable is used.
            model: Name of the Gemini model passed to ``generate_content``.
        """
        # Look the variable up here rather than in the signature: a default
        # argument is evaluated only once, when the function is defined, so
        # it would miss variables set later and would be bypassed entirely
        # by a caller passing ``None`` explicitly.
        if api_key_source is None:
            api_key_source = os.getenv('GEMINI_API_KEY')
        self.api_key_source = api_key_source
        self.model_name = model

    def get_response(self, file_path, prompt):
        """Upload an image file, query Gemini with it, and return the reply text.

        Args:
            file_path: Path of the file handed to ``client.files.upload``.
            prompt: Prompt sent together with the uploaded file.

        Returns:
            The ``text`` attribute of the ``generate_content`` response.
        """
        client = genai.Client(api_key=self.api_key_source)
        uploaded = client.files.upload(file=file_path)
        response = client.models.generate_content(
            model=self.model_name,
            contents=[uploaded, prompt],
        )
        return response.text

    def get_response_text(self, prompt):
        """Query Gemini with a text-only prompt and return the reply text.

        Args:
            prompt: Prompt sent as the only content item.

        Returns:
            The ``text`` attribute of the ``generate_content`` response.
        """
        client = genai.Client(api_key=self.api_key_source)
        response = client.models.generate_content(
            model=self.model_name,
            contents=[prompt],
        )
        return response.text

    @classmethod
    def _extract_json(cls, text):
        """Return the JSON portion of *text* with any Markdown fence removed.

        If a ```json fence is present, only the text between it and the next
        closing fence is kept, so commentary after the block does not end up
        in the payload. Otherwise the whole text is used. Surrounding
        backticks and whitespace are trimmed in both cases. Falsy input
        yields an empty string.
        """
        if not text:
            return ""
        start = text.find(cls._FENCE_OPEN)
        if start == -1:
            return text.strip(cls._STRIP_CHARS)
        body = text[start + len(cls._FENCE_OPEN):]
        end = body.find(cls._FENCE)
        if end != -1:
            body = body[:end]
        return body.strip(cls._STRIP_CHARS)

    def parse(self, text):
        """Return the JSON string contained in a model reply.

        Args:
            text: Raw reply text, optionally wrapped in a ```json fence.

        Returns:
            The unfenced, trimmed payload as a string (not decoded).
        """
        return self._extract_json(text)

    def parse_response(self, text):
        """Decode a detection reply into [0, 1]-normalized boxes.

        The reply is expected to be a JSON object or array of objects that
        carry a ``'label'`` and a ``'box_2d'`` given on a 0-1000 scale. Each
        coordinate is divided by 1000.

        Args:
            text: Raw reply text, optionally wrapped in a ```json fence.

        Returns:
            A list of ``{'label': ..., 'box_2d': [...]}`` dicts. The list is
            empty when the reply is empty, is not valid JSON, or contains no
            usable entries.
        """
        print("GeminiInference.parse_response:", text)
        if not text:
            # Always return a list so callers can iterate the result.
            return []

        try:
            data = json.loads(self._extract_json(text))
        except (ValueError, TypeError) as e:
            print("JSON パースエラー:", e)
            return []

        if isinstance(data, dict):
            data = [data]
        if not isinstance(data, list):
            return []

        parsed = []
        for obj in data:
            # Skip anything that is not a complete detection record.
            if not isinstance(obj, dict):
                continue
            if 'box_2d' not in obj or 'label' not in obj:
                continue
            try:
                norm = [c / 1000.0 for c in obj['box_2d']]
            except TypeError:
                # Non-iterable or non-numeric coordinates: ignore the entry.
                continue
            parsed.append({'label': obj['label'], 'box_2d': norm})
        return parsed
class ObjectDetector:
    """Image analysis helpers built on top of ``GeminiInference``.

    Attributes set by ``__init__``:
        model: The ``GeminiInference`` instance used for every request.
        prompt_objects: Description of the objects to detect; interpolated
            into the detection prompt. Starts as ``None`` and is expected to
            be assigned by the caller before ``detect_objects`` is used.
        text: Optional reference material embedded in the risk prompts;
            ``"なし"`` is sent when it is empty.
    """

    def __init__(self, API_KEY=None):
        """Create the detector.

        Args:
            API_KEY: Gemini API key. When ``None``, ``GeminiInference`` is
                constructed without arguments so its own default applies.
        """
        # Forward the key only when one was given; passing ``None`` through
        # would override the wrapper's default key lookup.
        self.model = GeminiInference() if API_KEY is None else GeminiInference(API_KEY)
        self.prompt_objects = None
        self.text = None

    def detect_objects(self, image_path):
        """Ask the model for bounding boxes of ``self.prompt_objects``.

        The prompt is also stored on ``self.prompt`` and printed.

        Args:
            image_path: Path of the image passed to ``get_response``.

        Returns:
            Whatever ``self.model.parse_response`` returns for the reply.
        """
        self.prompt = f"""
Detect all {self.prompt_objects} in the image. The box_2d should be [ymin, xmin, ymax, xmax] normalized to 0-1000.
Please provide the response as a JSON array of objects, where each object has a 'label' and 'box_2d' field.
Example:
[
{{"label": "face", "box_2d": [100, 200, 300, 400]}},
{{"label": "license_plate", "box_2d": [500, 600, 700, 800]}}
]
"""
        print(self.prompt)
        reply = self.model.get_response(image_path, self.prompt)
        detected_objects_norm_0_1 = self.model.parse_response(reply)
        return detected_objects_norm_0_1

    def detect_auto(self, image_path):
        """Ask the model for a privacy-risk analysis of the image.

        The prompt requests a JSON object with ``risk_level``,
        ``risk_reason`` and ``objects_to_remove``.

        Args:
            image_path: Path of the image passed to ``get_response``.

        Returns:
            The decoded JSON reply.

        Raises:
            ValueError: If the reply is not valid JSON
                (``json.JSONDecodeError`` is a subclass of ``ValueError``).
        """
        analysis_prompt = f"""
画像の個人情報漏洩リスクを分析し、厳密にJSON形式で返答してください。なおこの時、資料があれば、資料を参考にしてください:
{{
"risk_level":0~100,
"risk_reason": "リスクの具体的理由",
"objects_to_remove": ["消去すべきオブジェクトリスト(英語で、例: 'face', 'license_plate')"]
}}
<資料>
{self.text if self.text else "なし"}
</資料>
"""
        response = self.model.parse(self.model.get_response(image_path, analysis_prompt))
        print(f"Response: {response}")
        return json.loads(response)

    def detect_danger_level(self, image_path):
        """Ask the model for the danger (privacy-risk) level of the image.

        Args:
            image_path: Path of the image passed to ``get_response``.

        Returns:
            The ``risk_level`` from the reply as an ``int``. ``0`` is
            returned when the reply is not a JSON object, lacks the field,
            or the value cannot be converted to an integer.
        """
        analysis_prompt = f"""
画像の個人情報漏洩リスクを分析し、厳密にJSON形式で返答してください。なおこの時、資料があれば、資料を参考にしてください:
{{
"risk_level": 0~100,
}}
<資料>
{self.text if self.text else "なし"}
</資料>
"""
        response = self.model.parse(self.model.get_response(image_path, analysis_prompt))
        print(f"Response: {response}")
        # ``parse`` returns a JSON *string*; it must be decoded before any
        # field can be read from it.
        try:
            payload = json.loads(response)
        except (ValueError, TypeError):
            return 0
        if not isinstance(payload, dict):
            return 0
        try:
            risk_level = int(payload.get('risk_level', 0))
        except (ValueError, TypeError):
            risk_level = 0
        return risk_level