Spaces:
Running
Running
# LLM_package.py の修正 | |
from google import genai | |
import json | |
import os | |
from PIL import Image | |
from dotenv import load_dotenv | |
load_dotenv() | |
class GeminiInference: | |
""" | |
Gemini API 呼び出しを扱うクラス。 | |
""" | |
def __init__(self, api_key_source=os.getenv('GEMINI_API_KEY')): | |
self.api_key_source = api_key_source | |
def get_response(self, file_path, prompt): | |
""" | |
画像ファイルに対して Geminin API 呼び出しを行い、レスポンステキストを返す。 | |
""" | |
client = genai.Client(api_key=self.api_key_source) | |
my_file = client.files.upload(file=file_path) | |
response = client.models.generate_content( | |
model="gemini-2.5-flash", | |
contents=[my_file, prompt], | |
) | |
return response.text | |
def get_response_text(self,prompt): | |
client = genai.Client(api_key=self.api_key_source) | |
response = client.models.generate_content( | |
model="gemini-2.5-flash", | |
contents=[prompt], | |
) | |
text = response.text | |
return text | |
def parse(self, text): | |
json_str = text | |
if '```json' in text: | |
json_str = text[text.find('```json') + len('```json'):] | |
json_str = json_str.strip('` \n') | |
return json_str | |
def parse_response(self, text): | |
""" | |
レスポンス JSON をパース。'label' と 'box_2d'([0-1000]正規化) を取り出し、[0,1]正規化に変換して返すリスト。 | |
""" | |
print("GeminiInference.parse_response:", text) | |
if not text: | |
return {'state': 'empty'} | |
json_str = text | |
if '```json' in text: | |
json_str = text[text.find('```json') + len('```json'):] | |
json_str = json_str.strip('` \n') | |
try: | |
data = json.loads(json_str) | |
except Exception as e: | |
print("JSON パースエラー:", e) | |
return [] | |
if isinstance(data, dict): | |
data = [data] | |
parsed = [] | |
for obj in data: | |
if 'box_2d' in obj and 'label' in obj: | |
coords = obj['box_2d'] | |
# box_2dの各要素が0-1000の範囲であることを確認し、0-1の範囲に正規化 | |
# ただし、Geminiの出力がすでに0-1で返される可能性も考慮し、 | |
# 値が1を超える場合は1000で割り、1を超えない場合はそのまま使用 | |
norm = [] | |
for c in coords: | |
if c > 1.0001: # 1をわずかに超える値であれば1000スケールと判断 | |
norm.append(c / 1000.0) | |
else: # 1以下の値であればすでに0-1スケールと判断 | |
norm.append(c) | |
parsed.append({'label': obj['label'], 'box_2d': norm}) | |
return parsed | |
class ObjectDetector: | |
def __init__(self, API_KEY=os.getenv('GEMINI_API_KEY')): # API_KEYをapi_keyに変更し、デフォルト値を設定 | |
self.model = GeminiInference(API_KEY) # ここもapi_keyを使用 | |
self.prompt_objects=None | |
self.text=None | |
self.scene=None | |
def detect_objects(self, image_path): | |
self.prompt= f""" | |
Detect all {self.prompt_objects} in the image. The box_2d should be [ymin, xmin, ymax, xmax] normalized to 0-1000. | |
Please provide the response as a JSON array of objects, where each object has a 'label' and 'box_2d' field. | |
Example: | |
[ | |
{{"label": "face", "box_2d": [100, 200, 300, 400]}}, | |
{{"label": "license_plate", "box_2d": [500, 600, 700, 800]}} | |
] | |
""" | |
print(self.prompt) | |
detected_objects_norm_0_1= self.model.parse_response(self.model.get_response(image_path, self.prompt)) | |
return detected_objects_norm_0_1 | |
""" | |
Detects the danger level of the image. | |
""" | |
def detect_auto(self, image_path): | |
analysis_prompt = f""" | |
画像の個人情報漏洩リスクを分析し、厳密にJSON形式で返答してください。なおこの時、資料があれば、資料を参考にしてください。またこの際にその写真のシーンの情報もあれば参考にして動的に | |
消去するオブジェクトを変化させてください: | |
{{ | |
"risk_level":0~100, | |
"risk_reason": "リスクの具体的理由", | |
"objects_to_remove": ["消去すべきオブジェクトリスト(英語で、例: 'face', 'license_plate')"] | |
}} | |
<写真を撮ったシーンの情報>{self.scene if self.scene else "なし"} | |
<資料> | |
{self.text if self.text else "なし"} | |
</資料> | |
""" | |
response = self.model.parse(self.model.get_response(image_path, analysis_prompt)) | |
print(f"Response: {response}") | |
return json.loads(response) | |
def detect_danger_level(self, image_path): | |
analysis_prompt = f""" | |
画像の個人情報漏洩リスクを分析し、厳密にJSON形式で返答してください。なおこの時、資料があれば、資料を参考にしてください: | |
{{ | |
"risk_level": 0~100, | |
}} | |
<資料> | |
{self.text if self.text else "なし"} | |
</資料> | |
""" | |
response_text = self.model.get_response(image_path, analysis_prompt) | |
response_json_str = self.model.parse(response_text) | |
print(f"Response from API (raw): {response_json_str}") | |
try: | |
# JSON文字列をPythonの辞書にパースする | |
response_data = json.loads(response_json_str) | |
# 辞書から 'risk_level' を取得し、整数に変換する | |
risk_level = int(response_data.get('risk_level', 0)) | |
except (json.JSONDecodeError, ValueError, TypeError, AttributeError) as e: | |
print(f"Failed to parse risk_level from response. Error: {e}") | |
print(f"Response content: {response_json_str}") | |
risk_level = 0 # パース失敗時はデフォルト値0を返す | |
return risk_level |