File size: 6,270 Bytes
a345b5c
2443b6b
48a4145
2443b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae047d1
2443b6b
 
 
 
 
 
ae047d1
2443b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709c305
 
 
2443b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a345b5c
 
 
 
 
 
 
 
 
2443b6b
 
709c305
e31e87c
 
709c305
 
6c31327
709c305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45c8412
709c305
6c31327
 
709c305
45c8412
709c305
 
 
6c31327
709c305
 
 
 
 
 
 
941164e
45c8412
 
 
 
 
 
 
 
 
 
f8d2233
 
45c8412
f8d2233
45c8412
f8d2233
 
 
 
 
 
 
 
a345b5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# LLM_package.py の修正
from google import genai

import json
import os
from PIL import Image
from dotenv import load_dotenv
load_dotenv()


class GeminiInference:
    """
    Gemini API 呼び出しを扱うクラス。
    """
    def __init__(self, api_key_source=os.getenv('GEMINI_API_KEY')):
        self.api_key_source = api_key_source

    def get_response(self, file_path, prompt):
        """
        画像ファイルに対して Geminin API 呼び出しを行い、レスポンステキストを返す。
        """
        client = genai.Client(api_key=self.api_key_source)
        my_file = client.files.upload(file=file_path)
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=[my_file, prompt],
        )
        return response.text
    def get_response_text(self,prompt):
        client = genai.Client(api_key=self.api_key_source)
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents=[prompt],
        )
        text = response.text
        return text
    def parse(self, text):
        json_str = text
        if '```json' in text:
            json_str = text[text.find('```json') + len('```json'):]

        json_str = json_str.strip('` \n')
        return json_str
    def parse_response(self, text):
        """
        レスポンス JSON をパース。'label' と 'box_2d'([0-1000]正規化) を取り出し、[0,1]正規化に変換して返すリスト。
        """
        print("GeminiInference.parse_response:", text)
        if not text:
            return {'state': 'empty'}
        json_str = text
        if '```json' in text:
            json_str = text[text.find('```json') + len('```json'):]
        json_str = json_str.strip('` \n')
        try:
            data = json.loads(json_str)
        except Exception as e:
            print("JSON パースエラー:", e)
            return []
        if isinstance(data, dict):
            data = [data]
        parsed = []
        for obj in data:
            if 'box_2d' in obj and 'label' in obj:
                coords = obj['box_2d']
                # box_2dの各要素が0-1000の範囲であることを確認し、0-1の範囲に正規化
                # ただし、Geminiの出力がすでに0-1で返される可能性も考慮し、
                # 値が1を超える場合は1000で割り、1を超えない場合はそのまま使用
                norm = []
                for c in coords:
                    if c > 1.0001: # 1をわずかに超える値であれば1000スケールと判断
                        norm.append(c / 1000.0)
                    else: # 1以下の値であればすでに0-1スケールと判断
                        norm.append(c)
                parsed.append({'label': obj['label'], 'box_2d': norm})
        return parsed
class ObjectDetector:
    def __init__(self, API_KEY=os.getenv('GEMINI_API_KEY')): # API_KEYをapi_keyに変更し、デフォルト値を設定
        self.model  = GeminiInference(API_KEY) # ここもapi_keyを使用
        self.prompt_objects=None
        self.text=None
        self.scene=None
        
    def detect_objects(self, image_path):
        self.prompt= f"""
    Detect all {self.prompt_objects} in the image. The box_2d should be [ymin, xmin, ymax, xmax] normalized to 0-1000.
    Please provide the response as a JSON array of objects, where each object has a 'label' and 'box_2d' field.
    Example:
    [
        {{"label": "face", "box_2d": [100, 200, 300, 400]}},
        {{"label": "license_plate", "box_2d": [500, 600, 700, 800]}}
    ]
    """
        print(self.prompt)
        detected_objects_norm_0_1= self.model.parse_response(self.model.get_response(image_path, self.prompt))
        return detected_objects_norm_0_1
    """
    Detects the danger level of the image.
    """
    def detect_auto(self, image_path):
        analysis_prompt = f"""
            画像の個人情報漏洩リスクを分析し、厳密にJSON形式で返答してください。なおこの時、資料があれば、資料を参考にしてください。またこの際にその写真のシーンの情報もあれば参考にして動的に
            消去するオブジェクトを変化させてください:
                {{
                    "risk_level":0~100,
                    "risk_reason": "リスクの具体的理由",
                    "objects_to_remove": ["消去すべきオブジェクトリスト(英語で、例: 'face', 'license_plate')"]
                }}
                <写真を撮ったシーンの情報>{self.scene if self.scene else "なし"}
                <資料>
                {self.text if self.text else "なし"}
                </資料>
                """
        
        response = self.model.parse(self.model.get_response(image_path, analysis_prompt))
        print(f"Response: {response}")
        return json.loads(response)
    def detect_danger_level(self, image_path):
        analysis_prompt = f"""
            画像の個人情報漏洩リスクを分析し、厳密にJSON形式で返答してください。なおこの時、資料があれば、資料を参考にしてください:
                {{
                    "risk_level": 0~100,
                }}
                <資料>
                {self.text if self.text else "なし"}
                </資料>
                """
        response_text = self.model.get_response(image_path, analysis_prompt)
        response_json_str = self.model.parse(response_text)

        print(f"Response from API (raw): {response_json_str}")
        try:
            # JSON文字列をPythonの辞書にパースする
            response_data = json.loads(response_json_str)
            # 辞書から 'risk_level' を取得し、整数に変換する
            risk_level = int(response_data.get('risk_level', 0))
        except (json.JSONDecodeError, ValueError, TypeError, AttributeError) as e:
            print(f"Failed to parse risk_level from response. Error: {e}")
            print(f"Response content: {response_json_str}")
            risk_level = 0 # パース失敗時はデフォルト値0を返す
        return risk_level