Upload 26 files
- eval_suite/__init__.py +0 -0
- eval_suite/image_utils.py +104 -0
- eval_suite/parse_prompt.py +54 -0
- eval_suite/prompts_raw/__init__.py +145 -0
- eval_suite/prompts_raw/fix_transcript.txt +8 -0
- eval_suite/prompts_raw/image_eval.txt +45 -0
- eval_suite/prompts_raw/text_eval_new.txt +47 -0
- eval_suite/prompts_raw/video_eval_new.txt +37 -0
- eval_suite/text_utils.py +80 -0
- eval_suite/utils.py +81 -0
- eval_suite/video_utils.py +167 -0
- mllm_tools/__init__.py +1 -0
- mllm_tools/__pycache__/__init__.cpython-312.pyc +0 -0
- mllm_tools/__pycache__/gemini.cpython-312.pyc +0 -0
- mllm_tools/__pycache__/litellm.cpython-312.pyc +0 -0
- mllm_tools/__pycache__/openai.cpython-312.pyc +0 -0
- mllm_tools/__pycache__/openrouter.cpython-312.pyc +0 -0
- mllm_tools/__pycache__/utils.cpython-312.pyc +0 -0
- mllm_tools/__pycache__/vertex_ai.cpython-312.pyc +0 -0
- mllm_tools/gemini.py +176 -0
- mllm_tools/github.py +305 -0
- mllm_tools/litellm.py +193 -0
- mllm_tools/openai.py +594 -0
- mllm_tools/openrouter.py +266 -0
- mllm_tools/utils.py +177 -0
- mllm_tools/vertex_ai.py +86 -0
eval_suite/__init__.py
ADDED
File without changes
eval_suite/image_utils.py
ADDED
@@ -0,0 +1,104 @@
import os
import tempfile

import numpy as np
from PIL import Image, ImageOps
from moviepy import VideoFileClip

from eval_suite.prompts_raw import _image_eval
from eval_suite.utils import extract_json, convert_score_fields, calculate_geometric_mean
from mllm_tools.utils import _prepare_text_image_inputs
from src.core.parse_video import image_with_most_non_black_space

def extract_key_frames(video_path, output_dir, num_chunks):
    """Extract key frames from a video by dividing it into chunks and selecting representative frames.

    Args:
        video_path (str): Path to the input video file
        output_dir (str): Directory where extracted frames will be saved
        num_chunks (int): Number of chunks to divide the video into

    Returns:
        list: List of paths to the extracted key frames
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Extract all frames from the video
    clip = VideoFileClip(video_path)
    frames = list(clip.iter_frames(fps=1))  # one frame every second

    total_frames = len(frames)
    if total_frames == 0:
        print("No frames extracted from the video.")
        return []

    # Determine the number of frames per chunk (guard against videos shorter than num_chunks seconds)
    frames_per_chunk = max(1, total_frames // num_chunks)
    num_chunks = min(num_chunks, (total_frames + frames_per_chunk - 1) // frames_per_chunk)

    key_frames = []

    # Process each chunk of frames
    for i in range(num_chunks):
        start_idx = i * frames_per_chunk
        end_idx = min((i + 1) * frames_per_chunk, total_frames)
        chunk_frames = frames[start_idx:end_idx]

        if chunk_frames:
            # Save the frame with most non-black space
            output_path = os.path.join(output_dir, f"key_frame_{i+1}.jpg")
            result = image_with_most_non_black_space(chunk_frames, output_path)
        else:
            print(f"No frames in chunk {i+1}. Skipping.")
            result = None

        if result is not None:
            key_frames.append(output_path)
    clip.close()

    return key_frames


def evaluate_sampled_images(model, video_path, description="No description provided", num_chunks=10, output_folder=None):
    """Evaluate sampled frames from a video using an image evaluation model.

    Args:
        model: The image evaluation model to use
        video_path (str): Path to the input video file
        description (str, optional): Description of the video content. Defaults to "No description provided"
        num_chunks (int, optional): Number of chunks to divide the video into. Defaults to 10
        output_folder (str, optional): Directory for temporary files. Defaults to None

    Returns:
        dict: Dictionary containing evaluation scores and individual frame assessments with keys:
            - evaluation: Dictionary of averaged scores for each criterion
            - image_chunks: List of individual frame evaluation results
    """
    with tempfile.TemporaryDirectory(dir=output_folder) as temp_dir:
        key_frames = extract_key_frames(video_path, temp_dir, num_chunks)

        prompt = _image_eval.format(description=description)

        responses = []
        for key_frame in key_frames:
            inputs = _prepare_text_image_inputs(prompt, key_frame)
            response = model(inputs)
            response_json = extract_json(response)
            response_json = convert_score_fields(response_json)
            responses.append(response_json)

        criteria = list(responses[0]["evaluation"].keys())
        scores_dict = {c: [] for c in criteria}
        for response in responses:
            for key, val in response["evaluation"].items():
                scores_dict[key].append(val["score"])

        res_score = {}
        for key, scores in scores_dict.items():
            res_score[key] = {"score": calculate_geometric_mean(scores)}

        return {
            "evaluation": res_score,
            "image_chunks": responses
        }
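A minimal usage sketch for evaluate_sampled_images (not part of the uploaded files); the model wrapper and file paths are placeholders, assuming any callable that accepts the message list built by _prepare_text_image_inputs, such as the GeminiWrapper defined in mllm_tools/gemini.py:

# Hypothetical example: score ten sampled frames of a rendered video.
from mllm_tools.gemini import GeminiWrapper
from eval_suite.image_utils import evaluate_sampled_images

model = GeminiWrapper(model_name="gemini-1.5-pro-002")      # placeholder model choice
result = evaluate_sampled_images(
    model,
    video_path="output/pythagorean_theorem.mp4",            # placeholder path
    description="Proof of the Pythagorean theorem",
    num_chunks=10,
)
print(result["evaluation"])          # geometric-mean score per criterion
print(len(result["image_chunks"]))   # one JSON evaluation per key frame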
eval_suite/parse_prompt.py
ADDED
@@ -0,0 +1,54 @@
import os
from tqdm import tqdm


def call_parse_prompt():
    """
    Locates the prompts_raw directory and generates an __init__.py file containing prompt texts.

    Searches for prompts_raw directory in current and parent directories. Once found, calls
    create_python_file_with_texts() to generate the __init__.py file.
    """
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)
    folder_path = os.path.join(current_folder_path, "prompts_raw")

    # If prompts_raw not found in current directory, search parent directories
    if not os.path.exists(folder_path):
        parent_dir = current_folder_path
        while parent_dir != os.path.dirname(parent_dir):  # Stop at root directory
            parent_dir = os.path.dirname(parent_dir)
            test_path = os.path.join(parent_dir, "prompts_raw")
            if os.path.exists(test_path):
                folder_path = test_path
                break

    output_file = os.path.join(folder_path, "__init__.py")
    create_python_file_with_texts(folder_path, output_file)


def create_python_file_with_texts(folder_path, output_file):
    """
    Creates a Python file containing prompt texts from .txt files.

    Args:
        folder_path (str): Path to directory containing prompt .txt files
        output_file (str): Path where the output __init__.py file will be created

    The function reads all .txt files in the given folder, converts their contents into
    Python variables, and writes them to the output file. Variable names are derived from
    file paths with special characters replaced.
    """
    with open(output_file, 'w', encoding='utf-8') as out_file:
        out_file.write("# This file is generated automatically through parse_prompt.py\n\n")
        txt_files = [file for root, dirs, files in os.walk(folder_path) for file in files if file.endswith(".txt")]
        for file in tqdm(txt_files, desc="Processing files"):
            file_path = os.path.join(folder_path, file)
            var_name = "_" + file_path.replace(folder_path, "").replace(os.sep, "_").replace(".txt", "").strip("_")
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().replace('"""', '\"\"\"')
            out_file.write(f'{var_name} = """{content}"""\n\n')


if __name__ == "__main__":
    call_parse_prompt()
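A small illustration of the naming convention this script produces, inferred from the replace/strip logic above: each prompt file becomes a module-level string variable with a leading underscore.

# Hypothetical example: regenerate eval_suite/prompts_raw/__init__.py
#   eval_suite/prompts_raw/image_eval.txt      ->  _image_eval = """..."""
#   eval_suite/prompts_raw/video_eval_new.txt  ->  _video_eval_new = """..."""
from eval_suite.parse_prompt import call_parse_prompt

call_parse_prompt()  # rewrites prompts_raw/__init__.py from the .txt files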
eval_suite/prompts_raw/__init__.py
ADDED
@@ -0,0 +1,145 @@
# This file is generated automatically through parse_prompt.py

_video_eval_new = """# Task: Video Frame Quality Evaluation

You are tasked with analyzing and scoring a chunk of a theorem explanation video. Note that you may not have the full context of the video. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.

## Evaluation Criteria

1. **Visual Consistency**
    - Style Consistency: Does the visual style remain consistent across frames?
    - Smoothness: Are the motions and transitions smooth?

## Scoring Instructions
1. Assign a score from **1 to 5** for each dimension:
    - **1**: Very poor quality, completely fails to meet the criteria.
    - **2**: Below average, significant issues present.
    - **3**: Acceptable, meets the basic criteria with minor issues.
    - **4**: Good, performs well with no major issues.
    - **5**: Excellent, fully meets or exceeds expectations.
2. Provide a comprehensive evaluation for each dimension.
3. Format your output in **JSON**

### JSON Output Format
```json
{{
    "overall_analysis": "[Provide a general assessment of the video's quality]",
    "evaluation": {{
        "visual_consistency": {{
            "comprehensive_evaluation": "[Analysis of visual consistency]",
            "score": [1-5]
        }}
    }}
}}
```

Description of the theorem:
{description}

Video chunk:"""

_text_eval_new = """You are a specialist in evaluating theorem explanation videos, known for giving clear and objective feedback. You will be given the transcript of a video. Your task is to evaluate and score the content of the video in several dimensions.

### Task Objective
1. Perform an overall analysis of the video.
    * Identify the topic of the video.
    * Note your general thoughts and impression of the video, and any findings and observations.
2. Conduct a comprehensive evaluation and score each criterion in the given dimensions.
    * Analyze how well or poorly the video meets each criterion.
    * Assign a score from **1 to 5** for each dimension:
        - **1**: Very poor quality, completely fails to meet the criteria.
        - **2**: Below average, significant issues present.
        - **3**: Acceptable, meets the basic criteria with minor issues.
        - **4**: Good, performs well with no major issues.
        - **5**: Excellent, fully meets or exceeds expectations.
3. Output the results in the specified JSON format.

### Evaluation Criteria
1. **Accuracy and Depth**
    - Does the narration explain the theorem accurately?
    - Does the video provide intuitive and/or rigorous explanations for why the theorem holds?
2. **Logical Flow**
    - Does the video follow a clear and logical structure?
    - Does the video present a coherent buildup of ideas?

### Notes
* You do not have access to the visual portion of the video as you are given only the textual portion. Do not reference or commentate on the visuals as they will be evaluated separately - just assume that there are reasonable visuals (e.g., geometric objects, graphs of functions, and calculations) to accompany the narration.
* The evaluation criteria are intended to be independent of each other. Do not restate the same violation in multiple criteria; only consider it in the most relevant criterion.

### Output Format
```json
{{
    "overall_analysis": "[Overall analysis]",
    "evaluation": {{
        "accuracy_and_depth": {{
            "comprehensive_evaluation": "[Analysis of accuracy and depth]",
            "score": [1-5]
        }},
        "logical_flow": {{
            "comprehensive_evaluation": "[Analysis of logical flow]",
            "score": [1-5]
        }}
    }}
}}
```

The transcript of the video is as follows:
{transcript}
"""

_fix_transcript = """You are an expert in YouTube video transcripts. There is a transcript that was automatically generated through YouTube, so it lacks proper capitalization and punctuation. Your task is to fix the transcript so that there is proper punctuation, capitalization, and spacing. Do not make other modifications (e.g., keep the original word choice).

You should enclose the fixed transcript with a <SCRIPT></SCRIPT> block, i.e.:
<SCRIPT>
(Fixed transcript here)
</SCRIPT>

Original transcript: {transcript}
"""

_image_eval = """# Task: Video Frame Quality Evaluation

You are tasked with analyzing and scoring a frame taken from a theorem explanation video. Note that you may not have the context of the video, so the captured frame may be a frame where some motion of visual elements is taking place. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.

## Evaluation Criteria

1. **Visual Relevance**
    - Does the video frame align with the theorem's concepts and derivations?

2. **Element Layout**
    - Placement and Size: Are the visual elements well-placed and appropriately sized within the frame?
    - Overlap: Are the visual elements free of unintentional overlap?
    - Clarity: Is the visual information conveyed in the frame clear and easy to understand?

## Scoring Instructions
1. Assign a score from **1 to 5** for each dimension:
    - **1**: Very poor quality, completely fails to meet the criteria.
    - **2**: Below average, significant issues present.
    - **3**: Acceptable, meets the basic criteria with minor issues.
    - **4**: Good, performs well with no major issues.
    - **5**: Excellent, fully meets or exceeds expectations.
2. Provide a comprehensive evaluation for each dimension.
3. Format your output in **JSON**

### JSON Output Format
```json
{{
    "overall_analysis": "[Provide a general assessment of the image's quality]",
    "evaluation": {{
        "visual_relevance": {{
            "comprehensive_evaluation": "[Analysis of visual relevance]",
            "score": [1-5]
        }},
        "element_layout": {{
            "comprehensive_evaluation": "[Analysis of element layout]",
            "score": [1-5]
        }}
    }}
}}
```

Description of the theorem:
{description}

Image:"""
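The doubled braces in the JSON examples are there so that str.format leaves literal braces in place while still substituting the {description} or {transcript} slot; a quick sketch (the description text is a placeholder):

from eval_suite.prompts_raw import _image_eval

prompt = _image_eval.format(description="Law of cosines")  # placeholder description
assert "{{" not in prompt           # doubled braces collapse to single literal braces
assert "Law of cosines" in prompt   # the {description} slot is filled in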
eval_suite/prompts_raw/fix_transcript.txt
ADDED
@@ -0,0 +1,8 @@
You are an expert in YouTube video transcripts. There is a transcript that was automatically generated through YouTube, so it lacks proper capitalization and punctuation. Your task is to fix the transcript so that there is proper punctuation, capitalization, and spacing. Do not make other modifications (e.g., keep the original word choice).

You should enclose the fixed transcript with a <SCRIPT></SCRIPT> block, i.e.:
<SCRIPT>
(Fixed transcript here)
</SCRIPT>

Original transcript: {transcript}
eval_suite/prompts_raw/image_eval.txt
ADDED
@@ -0,0 +1,45 @@
# Task: Video Frame Quality Evaluation

You are tasked with analyzing and scoring a frame taken from a theorem explanation video. Note that you may not have the context of the video, so the captured frame may be a frame where some motion of visual elements is taking place. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.

## Evaluation Criteria

1. **Visual Relevance**
    - Does the video frame align with the theorem's concepts and derivations?

2. **Element Layout**
    - Placement and Size: Are the visual elements well-placed and appropriately sized within the frame?
    - Overlap: Are the visual elements free of unintentional overlap?
    - Clarity: Is the visual information conveyed in the frame clear and easy to understand?

## Scoring Instructions
1. Assign a score from **1 to 5** for each dimension:
    - **1**: Very poor quality, completely fails to meet the criteria.
    - **2**: Below average, significant issues present.
    - **3**: Acceptable, meets the basic criteria with minor issues.
    - **4**: Good, performs well with no major issues.
    - **5**: Excellent, fully meets or exceeds expectations.
2. Provide a comprehensive evaluation for each dimension.
3. Format your output in **JSON**

### JSON Output Format
```json
{{
    "overall_analysis": "[Provide a general assessment of the image's quality]",
    "evaluation": {{
        "visual_relevance": {{
            "comprehensive_evaluation": "[Analysis of visual relevance]",
            "score": [1-5]
        }},
        "element_layout": {{
            "comprehensive_evaluation": "[Analysis of element layout]",
            "score": [1-5]
        }}
    }}
}}
```

Description of the theorem:
{description}

Image:
eval_suite/prompts_raw/text_eval_new.txt
ADDED
@@ -0,0 +1,47 @@
You are a specialist in evaluating theorem explanation videos, known for giving clear and objective feedback. You will be given the transcript of a video. Your task is to evaluate and score the content of the video in several dimensions.

### Task Objective
1. Perform an overall analysis of the video.
    * Identify the topic of the video.
    * Note your general thoughts and impression of the video, and any findings and observations.
2. Conduct a comprehensive evaluation and score each criterion in the given dimensions.
    * Analyze how well or poorly the video meets each criterion.
    * Assign a score from **1 to 5** for each dimension:
        - **1**: Very poor quality, completely fails to meet the criteria.
        - **2**: Below average, significant issues present.
        - **3**: Acceptable, meets the basic criteria with minor issues.
        - **4**: Good, performs well with no major issues.
        - **5**: Excellent, fully meets or exceeds expectations.
3. Output the results in the specified JSON format.

### Evaluation Criteria
1. **Accuracy and Depth**
    - Does the narration explain the theorem accurately?
    - Does the video provide intuitive and/or rigorous explanations for why the theorem holds?
2. **Logical Flow**
    - Does the video follow a clear and logical structure?
    - Does the video present a coherent buildup of ideas?

### Notes
* You do not have access to the visual portion of the video as you are given only the textual portion. Do not reference or commentate on the visuals as they will be evaluated separately - just assume that there are reasonable visuals (e.g., geometric objects, graphs of functions, and calculations) to accompany the narration.
* The evaluation criteria are intended to be independent of each other. Do not restate the same violation in multiple criteria; only consider it in the most relevant criterion.

### Output Format
```json
{{
    "overall_analysis": "[Overall analysis]",
    "evaluation": {{
        "accuracy_and_depth": {{
            "comprehensive_evaluation": "[Analysis of accuracy and depth]",
            "score": [1-5]
        }},
        "logical_flow": {{
            "comprehensive_evaluation": "[Analysis of logical flow]",
            "score": [1-5]
        }}
    }}
}}
```

The transcript of the video is as follows:
{transcript}
eval_suite/prompts_raw/video_eval_new.txt
ADDED
@@ -0,0 +1,37 @@
# Task: Video Frame Quality Evaluation

You are tasked with analyzing and scoring a chunk of a theorem explanation video. Note that you may not have the full context of the video. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.

## Evaluation Criteria

1. **Visual Consistency**
    - Style Consistency: Does the visual style remain consistent across frames?
    - Smoothness: Are the motions and transitions smooth?

## Scoring Instructions
1. Assign a score from **1 to 5** for each dimension:
    - **1**: Very poor quality, completely fails to meet the criteria.
    - **2**: Below average, significant issues present.
    - **3**: Acceptable, meets the basic criteria with minor issues.
    - **4**: Good, performs well with no major issues.
    - **5**: Excellent, fully meets or exceeds expectations.
2. Provide a comprehensive evaluation for each dimension.
3. Format your output in **JSON**

### JSON Output Format
```json
{{
    "overall_analysis": "[Provide a general assessment of the video's quality]",
    "evaluation": {{
        "visual_consistency": {{
            "comprehensive_evaluation": "[Analysis of visual consistency]",
            "score": [1-5]
        }}
    }}
}}
```

Description of the theorem:
{description}

Video chunk:
eval_suite/text_utils.py
ADDED
@@ -0,0 +1,80 @@
from typing import Union

import pysrt

from mllm_tools.litellm import LiteLLMWrapper
from mllm_tools.gemini import GeminiWrapper
from mllm_tools.utils import _prepare_text_inputs
from eval_suite.prompts_raw import _fix_transcript, _text_eval_new
from eval_suite.utils import extract_json, convert_score_fields


def parse_srt_to_text(srt_path) -> str:
    """
    Parse an SRT subtitle file into plain text.

    Args:
        srt_path: Path to the SRT subtitle file.

    Returns:
        str: The subtitle text with duplicates removed and ellipses replaced.
    """
    subs = pysrt.open(srt_path)
    full_text = []
    for sub in subs:
        sub.text = sub.text.replace("...", ".")
        for line in sub.text.splitlines():
            # .srt can contain repeated lines
            if full_text and full_text[-1] == line:
                continue
            full_text.append(line)
    return "\n".join(full_text)


def fix_transcript(text_eval_model: Union[LiteLLMWrapper, GeminiWrapper], transcript: str) -> str:
    """
    Fix and clean up a transcript using an LLM model.

    Args:
        text_eval_model: The LLM model wrapper to use for fixing the transcript.
        transcript: The input transcript text to fix.

    Returns:
        str: The fixed and cleaned transcript text.
    """
    print("Fixing transcript...")

    prompt = _fix_transcript.format(transcript=transcript)
    response = text_eval_model(_prepare_text_inputs(prompt))
    fixed_script = response.split("<SCRIPT>", maxsplit=1)[1].split("</SCRIPT>")[0]

    return fixed_script


def evaluate_text(text_eval_model: LiteLLMWrapper, transcript: str, retry_limit: int) -> dict:
    """
    Evaluate transcript text using an LLM model with retry logic.

    Args:
        text_eval_model: The LLM model wrapper to use for evaluation.
        transcript: The transcript text to evaluate.
        retry_limit: Maximum number of retry attempts on failure.

    Returns:
        dict: The evaluation results as a JSON object.

    Raises:
        ValueError: If all retry attempts fail.
    """
    # prompt = _text_eval.format(transcript=transcript)
    prompt = _text_eval_new.format(transcript=transcript)
    for attempt in range(retry_limit):
        try:
            evaluation = text_eval_model(_prepare_text_inputs(prompt))
            evaluation_json = extract_json(evaluation)
            evaluation_json = convert_score_fields(evaluation_json)
            return evaluation_json
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e.__class__.__name__}: {e}")
            if attempt + 1 == retry_limit:
                raise ValueError("Reached maximum retry limit. Evaluation failed.") from None
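A minimal end-to-end sketch of the transcript path (the subtitle path and model name are placeholders):

from mllm_tools.litellm import LiteLLMWrapper
from eval_suite.text_utils import parse_srt_to_text, fix_transcript, evaluate_text

model = LiteLLMWrapper(model_name="gemini/gemini-1.5-pro")   # placeholder model
raw_text = parse_srt_to_text("output/theorem_video.srt")     # placeholder path
clean_text = fix_transcript(model, raw_text)
scores = evaluate_text(model, clean_text, retry_limit=3)
print(scores["evaluation"]["accuracy_and_depth"]["score"])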
eval_suite/utils.py
ADDED
@@ -0,0 +1,81 @@
import json
import re
from math import prod
from typing import List

def extract_json(response: str) -> dict:
    """
    Extract JSON content from a string response.

    Args:
        response (str): String containing JSON content, possibly within code blocks.

    Returns:
        dict: Extracted and parsed JSON content.

    Raises:
        ValueError: If no valid JSON content could be extracted.
    """
    try:
        evaluation_json = json.loads(response)
    except json.JSONDecodeError:
        # If JSON parsing fails, try to extract the content between ```json and ```
        match = re.search(r'```json\n(.*?)\n```', response, re.DOTALL)
        if not match:
            # If no match for ```json, try to extract content between ``` and ```
            match = re.search(r'```\n(.*?)\n```', response, re.DOTALL)

        if match:
            evaluation_content = match.group(1)
            evaluation_json = json.loads(evaluation_content)
        else:
            raise ValueError("Failed to extract valid JSON content")
    return evaluation_json


def convert_score_fields(data: dict) -> dict:
    """
    Convert score fields in a dictionary to integers recursively.

    Args:
        data (dict): Dictionary containing score fields to convert.

    Returns:
        dict: Dictionary with score fields converted to integers.

    Raises:
        ValueError: If a score value cannot be converted to integer.
    """
    # Create a new dictionary with the converted values
    converted_data = {}
    for key, value in data.items():
        if key == "score":
            if isinstance(value, int):
                converted_data[key] = value
            elif isinstance(value, str) and value.isdigit():
                converted_data[key] = int(value)
            else:
                raise ValueError(f"Invalid score value: {value!r}")
        elif isinstance(value, dict):
            converted_data[key] = convert_score_fields(value)
        else:
            converted_data[key] = value
    return converted_data


def calculate_geometric_mean(scores: List[int]) -> float:
    """
    Calculate the geometric mean of a list of scores.

    Args:
        scores (List[int]): List of integer scores, may contain None values.

    Returns:
        float: Geometric mean of non-None scores. Returns 0.0 if list is empty
            or contains only None values.
    """
    scores = [s for s in scores if s is not None]
    if not scores:
        return 0.0
    product = prod(scores)
    return product ** (1 / len(scores))
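A short sketch of how these helpers compose on a typical model reply (the reply text is made up for illustration):

from eval_suite.utils import extract_json, convert_score_fields, calculate_geometric_mean

reply = '```json\n{"evaluation": {"visual_relevance": {"score": "4"}}}\n```'
data = convert_score_fields(extract_json(reply))
assert data["evaluation"]["visual_relevance"]["score"] == 4   # "4" coerced to int
print(calculate_geometric_mean([4, 5, None, 3]))              # None scores are ignored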
eval_suite/video_utils.py
ADDED
@@ -0,0 +1,167 @@
import os
import cv2
import tempfile

from dotenv import load_dotenv

from mllm_tools.utils import _prepare_text_video_inputs
from eval_suite.prompts_raw import _video_eval_new
from eval_suite.utils import extract_json, convert_score_fields

load_dotenv()


def reduce_video_framerate(input_path, target_fps=1, output_path=None):
    """
    Reduces the frame rate of a video by only keeping frames at the target interval.

    Args:
        input_path (str): Path to the input video
        target_fps (int): Target frames per second (default: 1)
        output_path (str, optional): Path to save the processed video. If None, uses a temporary file.

    Returns:
        str: Path to the processed video

    Raises:
        ValueError: If input video cannot be opened or has invalid FPS
        RuntimeError: If video writer initialization fails or output video creation fails
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open input video: {input_path}")

    original_fps = cap.get(cv2.CAP_PROP_FPS)
    if original_fps <= 0:
        raise ValueError(f"Invalid FPS ({original_fps}) detected in input video")

    frame_interval = int(original_fps / target_fps)

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Use provided output path or create temporary file
    if output_path is None:
        temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        output_path = temp_output.name

    # Ensure output directory exists
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Try different codecs in order of preference
    codecs = [
        ('avc1', '.mp4'),  # H.264 codec
        ('mp4v', '.mp4'),  # MP4V codec
        ('XVID', '.avi'),  # XVID codec
        ('MJPG', '.avi'),  # Motion JPEG codec
    ]

    success = False
    for codec, ext in codecs:
        if output_path.endswith('.mp4') and not ext.endswith('.mp4'):
            # If we're switching to AVI format, change the extension
            output_path = output_path[:-4] + ext

        fourcc = cv2.VideoWriter_fourcc(*codec)
        out = cv2.VideoWriter(output_path, fourcc, target_fps, (width, height))

        if out.isOpened():
            success = True
            print(f"Successfully initialized video writer with codec: {codec}")
            break
        else:
            out.release()
            if os.path.exists(output_path):
                os.remove(output_path)

    if not success:
        raise RuntimeError("Could not initialize video writer with any available codec")

    frame_count = 0
    frames_written = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Only write frames at the specified interval
        if frame_count % frame_interval == 0:
            out.write(frame)
            frames_written += 1
        frame_count += 1

    cap.release()
    out.release()

    # Verify the output
    verify_cap = cv2.VideoCapture(output_path)
    if not verify_cap.isOpened():
        raise RuntimeError(f"Failed to create output video at {output_path}")

    actual_fps = verify_cap.get(cv2.CAP_PROP_FPS)
    total_frames = verify_cap.get(cv2.CAP_PROP_FRAME_COUNT)
    verify_cap.release()

    if actual_fps <= 0:
        print("Warning: Output video reports invalid FPS. This might be a codec issue.")
        actual_fps = target_fps  # Use target FPS for duration calculation

    print(f"Created video with {frames_written} frames at {actual_fps} FPS")
    print(f"Total duration: {total_frames/actual_fps:.2f} seconds")
    print(f"Video saved to: {output_path}")

    return output_path


def evaluate_video_chunk_new(model, video_path, transcript="No transcript provided", description="No description provided",
                             save_processed_video=None, target_fps=None, retry_limit=5):
    """
    Evaluate a single video chunk using a multimodal model.

    Args:
        model: The multimodal model to use for evaluation
        video_path (str): Path to the video file to evaluate
        transcript (str, optional): Video transcript text. Defaults to "No transcript provided"
        description (str, optional): Video description text. Defaults to "No description provided"
        save_processed_video (str, optional): Path to save processed video. If None, uses temporary file
        target_fps (int, optional): Target frames per second for video processing. If None, no processing
        retry_limit (int, optional): Maximum number of retry attempts. Defaults to 5

    Returns:
        dict: Evaluation results as a JSON object with scores converted to integers

    Raises:
        FileNotFoundError: If video file does not exist
        Exception: If evaluation fails after all retry attempts
    """
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    # Only process video if target_fps is specified
    if target_fps is not None:
        processed_video_path = reduce_video_framerate(video_path, target_fps=target_fps, output_path=save_processed_video)
        video_to_use = processed_video_path
    else:
        video_to_use = video_path

    prompt = _video_eval_new.format(description=description)
    inputs = _prepare_text_video_inputs(prompt, video_to_use)

    try:
        for attempt in range(retry_limit):
            try:
                response = model(inputs)
                response_json = extract_json(response)
                response_json = convert_score_fields(response_json)

                return response_json
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {e}")
                if attempt + 1 == retry_limit:
                    print("Reached maximum retry limit. Evaluation failed.")
                    raise
    finally:
        # Clean up the temporary processed video if we created one
        if target_fps is not None and save_processed_video is None and os.path.exists(processed_video_path):
            os.unlink(processed_video_path)
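A minimal call sketch for evaluate_video_chunk_new (the path and model wrapper are placeholders); with target_fps=1 the chunk is first re-encoded at one frame per second before being sent to the model:

from mllm_tools.gemini import GeminiWrapper
from eval_suite.video_utils import evaluate_video_chunk_new

model = GeminiWrapper(model_name="gemini-1.5-pro-002")   # placeholder model
result = evaluate_video_chunk_new(
    model,
    video_path="output/scene_01.mp4",                    # placeholder path
    description="Proof of the triangle inequality",
    target_fps=1,   # downsample before upload; omit to send the original file
)
print(result["evaluation"]["visual_consistency"]["score"])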
mllm_tools/__init__.py
ADDED
@@ -0,0 +1 @@
# Empty file to make this directory a Python package
mllm_tools/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (154 Bytes).
mllm_tools/__pycache__/gemini.cpython-312.pyc
ADDED
Binary file (8.04 kB).
mllm_tools/__pycache__/litellm.cpython-312.pyc
ADDED
Binary file (7.58 kB).
mllm_tools/__pycache__/openai.cpython-312.pyc
ADDED
Binary file (21.9 kB).
mllm_tools/__pycache__/openrouter.cpython-312.pyc
ADDED
Binary file (11.1 kB).
mllm_tools/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (7.32 kB).
mllm_tools/__pycache__/vertex_ai.cpython-312.pyc
ADDED
Binary file (3.64 kB).
mllm_tools/gemini.py
ADDED
@@ -0,0 +1,176 @@
from typing import List, Dict, Any, Union, Optional
import io
import os
import base64
from PIL import Image
import mimetypes
import google.generativeai as genai
import tempfile
import time
from urllib.parse import urlparse
import requests
from io import BytesIO

class GeminiWrapper:
    """Wrapper for Gemini to support multiple models and logging"""

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro-002",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False
    ):
        """
        Initialize the Gemini wrapper

        Args:
            model_name: Name of the model to use
            temperature: Temperature for completion
            print_cost: Whether to print the cost of the completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging
        """
        self.model_name = model_name.split('/')[-1] if '/' in model_name else model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0

        api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("No API_KEY found. Please set the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable.")
        genai.configure(api_key=api_key)

        generation_config = {
            "temperature": self.temperature,
            "top_p": 0.95,
            "response_mime_type": "text/plain",
        }
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]
        self.model = genai.GenerativeModel(
            model_name=self.model_name,
            safety_settings=safety_settings,
            generation_config=generation_config,
        )

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension

        Args:
            file_path: Path to the file

        Returns:
            MIME type as a string (e.g., "image/jpeg", "audio/mp3")
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def _download_file(self, url: str) -> str:
        """
        Download a file from a URL and save it as a temporary file

        Args:
            url: URL of the file to download

        Returns:
            Path to the temporary file
        """
        response = requests.get(url)
        if response.status_code == 200:
            temp_file = tempfile.NamedTemporaryFile(delete=False)
            temp_file.write(response.content)
            temp_file.close()
            return temp_file.name
        else:
            raise ValueError(f"Failed to download file from URL: {url}")

    def _save_image_to_temp(self, image: Image.Image) -> str:
        """
        Save a PIL Image to a temporary file

        Args:
            image: PIL Image object

        Returns:
            Path to the temporary file
        """
        temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        image.save(temp_file, format="PNG")
        temp_file.close()
        return temp_file.name

    def _upload_to_gemini(self, file_path: str, mime_type: Optional[str] = None):
        """
        Uploads the given file to Gemini.

        Args:
            file_path: Path to the file
            mime_type: MIME type of the file

        Returns:
            Uploaded file object
        """
        return genai.upload_file(file_path, mime_type=mime_type)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process messages and return completion

        Args:
            messages: List of message dictionaries with 'type' and 'content' keys
            metadata: Optional metadata to pass to Gemini completion

        Returns:
            Generated text response
        """
        contents = []
        for msg in messages:
            if msg["type"] == "text":
                contents.append(msg["content"])
            elif msg["type"] in ["image", "audio", "video"]:
                if isinstance(msg["content"], Image.Image):
                    file_path = self._save_image_to_temp(msg["content"])
                    mime_type = "image/png"
                elif isinstance(msg["content"], str):
                    if msg["content"].startswith("http"):
                        file_path = self._download_file(msg["content"])
                        mime_type = self._get_mime_type(msg["content"])
                    else:
                        file_path = msg["content"]
                        mime_type = self._get_mime_type(file_path)
                else:
                    raise ValueError("Unsupported content type")

                uploaded_file = self._upload_to_gemini(file_path, mime_type)

                while uploaded_file.state.name == "PROCESSING":
                    print('.', end='')
                    time.sleep(3)
                    uploaded_file = genai.get_file(uploaded_file.name)
                if uploaded_file.state.name == "FAILED":
                    raise ValueError(uploaded_file.state.name)
                print("Uploaded successfully")
                contents.append(uploaded_file)
            else:
                raise ValueError("Unsupported message type")

        response = self.model.generate_content(contents, request_options={"timeout": 600})
        try:
            return response.text
        except Exception as e:
            print(e)
            print(response.prompt_feedback)
            return str(response.prompt_feedback)

if __name__ == "__main__":
    pass
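A short sketch of the message format the wrapper expects; the image path is a placeholder and the call assumes GEMINI_API_KEY (or GOOGLE_API_KEY) is set:

from mllm_tools.gemini import GeminiWrapper

model = GeminiWrapper(model_name="gemini-1.5-flash-002", temperature=0.2)
answer = model([
    {"type": "text", "content": "Describe what this frame shows."},
    {"type": "image", "content": "frames/key_frame_1.jpg"},   # placeholder path
])
print(answer)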
mllm_tools/github.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# filepath: d:\Theory2Manim-2\Theory2Manim\mllm_tools\github.py
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from typing import List, Dict, Any, Union, Optional
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import base64
|
8 |
+
from PIL import Image
|
9 |
+
import mimetypes
|
10 |
+
import litellm
|
11 |
+
from litellm import completion, completion_cost
|
12 |
+
from dotenv import load_dotenv
|
13 |
+
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
class GitHubModelsWrapper:
|
17 |
+
"""Wrapper for GitHub Models using LiteLLM to support multiple GitHub hosted models"""
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
model_name: str = "github/gpt-4o",
|
22 |
+
temperature: float = 0.7,
|
23 |
+
print_cost: bool = False,
|
24 |
+
verbose: bool = False,
|
25 |
+
use_langfuse: bool = True,
|
26 |
+
github_token: Optional[str] = None
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
Initialize the GitHub Models wrapper
|
30 |
+
|
31 |
+
Args:
|
32 |
+
model_name: Name of the GitHub model to use (e.g. "github/gpt-4o", "github/gpt-4o-mini",
|
33 |
+
"github/o1-preview", "github/claude-3-5-sonnet", "github/phi-3.5-mini-instruct")
|
34 |
+
temperature: Temperature for completion
|
35 |
+
print_cost: Whether to print the cost of the completion
|
36 |
+
verbose: Whether to print verbose output
|
37 |
+
use_langfuse: Whether to enable Langfuse logging
|
38 |
+
github_token: GitHub token for authentication (if not provided, will use GITHUB_TOKEN env var)
|
39 |
+
"""
|
40 |
+
self.model_name = model_name
|
41 |
+
self.temperature = temperature
|
42 |
+
self.print_cost = print_cost
|
43 |
+
self.verbose = verbose
|
44 |
+
self.accumulated_cost = 0
|
45 |
+
|
46 |
+
# Set up GitHub token
|
47 |
+
self.github_token = github_token or os.getenv('GITHUB_TOKEN')
|
48 |
+
if not self.github_token:
|
49 |
+
raise ValueError("GitHub token is required. Please set GITHUB_TOKEN environment variable or pass github_token parameter.")
|
50 |
+
|
51 |
+
# Set environment variable for LiteLLM
|
52 |
+
os.environ['GITHUB_TOKEN'] = self.github_token
|
53 |
+
|
54 |
+
if self.verbose:
|
55 |
+
os.environ['LITELLM_LOG'] = 'DEBUG'
|
56 |
+
|
57 |
+
# Set langfuse callback only if enabled
|
58 |
+
if use_langfuse:
|
59 |
+
litellm.success_callback = ["langfuse"]
|
60 |
+
litellm.failure_callback = ["langfuse"]
|
61 |
+
|
62 |
+
def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
|
63 |
+
"""
|
64 |
+
Encode local file or PIL Image to base64 string
|
65 |
+
|
66 |
+
Args:
|
67 |
+
file_path: Path to local file or PIL Image object
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
Base64 encoded file string
|
71 |
+
"""
|
72 |
+
if isinstance(file_path, Image.Image):
|
73 |
+
buffered = io.BytesIO()
|
74 |
+
file_path.save(buffered, format="PNG")
|
75 |
+
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
76 |
+
else:
|
77 |
+
with open(file_path, "rb") as file:
|
78 |
+
return base64.b64encode(file.read()).decode("utf-8")
|
79 |
+
|
80 |
+
def _get_mime_type(self, file_path: str) -> str:
|
81 |
+
"""
|
82 |
+
Get the MIME type of a file based on its extension
|
83 |
+
|
84 |
+
Args:
|
85 |
+
file_path: Path to the file
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
MIME type as a string (e.g., "image/jpeg", "audio/mp3")
|
89 |
+
"""
|
90 |
+
mime_type, _ = mimetypes.guess_type(file_path)
|
91 |
+
if mime_type is None:
|
92 |
+
raise ValueError(f"Unsupported file type: {file_path}")
|
93 |
+
return mime_type
|
94 |
+
|
95 |
+
def _supports_vision(self, model_name: str) -> bool:
|
96 |
+
"""
|
97 |
+
Check if the model supports vision/image processing
|
98 |
+
|
99 |
+
Args:
|
100 |
+
model_name: Name of the model
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
True if model supports vision, False otherwise
|
104 |
+
"""
|
105 |
+
vision_models = [
|
106 |
+
"gpt-4o",
|
107 |
+
"gpt-4o-mini",
|
108 |
+
"claude-3-5-sonnet",
|
109 |
+
"claude-3-haiku"
|
110 |
+
]
|
111 |
+
|
112 |
+
# Extract model name without the github/ prefix
|
113 |
+
clean_model_name = model_name.replace("github/", "")
|
114 |
+
return any(vision_model in clean_model_name for vision_model in vision_models)
|
115 |
+
|
116 |
+
def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
|
117 |
+
"""
|
118 |
+
Process messages and return completion
|
119 |
+
|
120 |
+
Args:
|
121 |
+
messages: List of message dictionaries with 'type' and 'content' keys
|
122 |
+
metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
Generated text response
|
126 |
+
"""
|
127 |
+
if metadata is None:
|
128 |
+
metadata = {}
|
129 |
+
metadata["trace_name"] = f"github-models-completion-{self.model_name}"
|
130 |
+
|
131 |
+
# Convert messages to LiteLLM format
|
132 |
+
formatted_messages = []
|
133 |
+
|
134 |
+
for msg in messages:
|
135 |
+
if msg["type"] == "text":
|
136 |
+
formatted_messages.append({
|
137 |
+
"role": "user",
|
138 |
+
"content": [{"type": "text", "text": msg["content"]}]
|
139 |
+
})
|
140 |
+
elif msg["type"] == "image":
|
141 |
+
# Check if model supports vision
|
142 |
+
if not self._supports_vision(self.model_name):
|
143 |
+
raise ValueError(f"Model {self.model_name} does not support image processing")
|
144 |
+
|
145 |
+
# Check if content is a local file path or PIL Image
|
146 |
+
if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
|
147 |
+
try:
|
148 |
+
if isinstance(msg["content"], Image.Image):
|
149 |
+
mime_type = "image/png"
|
150 |
+
else:
|
151 |
+
mime_type = self._get_mime_type(msg["content"])
|
152 |
+
base64_data = self._encode_file(msg["content"])
|
153 |
+
data_url = f"data:{mime_type};base64,{base64_data}"
|
154 |
+
except ValueError as e:
|
155 |
+
print(f"Error processing file {msg['content']}: {e}")
|
156 |
+
continue
|
157 |
+
else:
|
158 |
+
data_url = msg["content"]
|
159 |
+
|
160 |
+
# Format for vision-capable models
|
161 |
+
formatted_messages.append({
|
162 |
+
"role": "user",
|
163 |
+
"content": [
|
164 |
+
{
|
165 |
+
"type": "image_url",
|
166 |
+
"image_url": {
|
                                "url": data_url,
                                "detail": "high"
                            }
                        }
                    ]
                })
            else:
                raise ValueError(f"Unsupported message type: {msg['type']}. GitHub models currently support 'text' and 'image' types.")

        try:
            # Check if it's an o-series model (like o1-preview, o1-mini)
            if (re.match(r".*o1.*", self.model_name)):
                # O-series models don't support temperature and have reasoning_effort
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    reasoning_effort="medium",  # Options: "low", "medium", "high"
                    metadata=metadata,
                    max_retries=3
                )
            else:
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    temperature=self.temperature,
                    metadata=metadata,
                    max_retries=3
                )

            if self.print_cost:
                try:
                    # Note: GitHub Models may not provide cost information
                    cost = completion_cost(completion_response=response)
                    if cost is not None:
                        self.accumulated_cost += cost
                        print(f"Cost: ${float(cost):.10f}")
                        print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
                    else:
                        print("Cost information not available for GitHub Models")
                except Exception as e:
                    print(f"Could not calculate cost: {e}")

            content = response.choices[0].message.content
            if content is None:
                print(f"Got null response from GitHub model. Full response: {response}")
                return ""
            return content

        except Exception as e:
            print(f"Error in GitHub model completion: {e}")
            return str(e)

def create_github_model_wrapper(model_name: str = "github/gpt-4o", **kwargs) -> GitHubModelsWrapper:
    """
    Convenience function to create a GitHub Models wrapper

    Args:
        model_name: GitHub model name (e.g., "github/gpt-4o", "github/claude-3-5-sonnet")
        **kwargs: Additional arguments passed to GitHubModelsWrapper

    Returns:
        Configured GitHubModelsWrapper instance

    Example:
        >>> # Create a wrapper for GPT-4o
        >>> wrapper = create_github_model_wrapper("github/gpt-4o", temperature=0.3)
        >>>
        >>> # Use it for text generation
        >>> response = wrapper([{"type": "text", "content": "Explain quantum computing"}])
        >>>
        >>> # Use it for vision (if model supports it)
        >>> response = wrapper([
        ...     {"type": "text", "content": "What's in this image?"},
        ...     {"type": "image", "content": "path/to/image.jpg"}
        ... ])
    """
    return GitHubModelsWrapper(model_name=model_name, **kwargs)

# Available GitHub Models (as of the documentation)
AVAILABLE_MODELS = {
    # GPT Models
    "gpt-4o": "github/gpt-4o",
    "gpt-4o-mini": "github/gpt-4o-mini",
    "o1-preview": "github/o1-preview",
    "o1-mini": "github/o1-mini",
    "gpt-4.1": "github/gpt-4.1",


    # Phi Models
    "phi-3-5-mini-instruct": "github/phi-3.5-mini-instruct",
    "phi-3-5-moe-instruct": "github/phi-3.5-moe-instruct",

    # Llama Models
    "llama-3.1-405b-instruct": "github/llama-3.1-405b-instruct",
    "llama-3.1-70b-instruct": "github/llama-3.1-70b-instruct",
    "llama-3.1-8b-instruct": "github/llama-3.1-8b-instruct",

    # Mistral Models
    "mistral-large": "github/mistral-large",
    "mistral-large-2407": "github/mistral-large-2407",
    "mistral-nemo": "github/mistral-nemo",
    "mistral-small": "github/mistral-small",

    # Cohere Models
    "cohere-command-r": "github/cohere-command-r",
    "cohere-command-r-plus": "github/cohere-command-r-plus",

    # AI21 Models
    "ai21-jamba-1.5-large": "github/ai21-jamba-1.5-large",
    "ai21-jamba-1.5-mini": "github/ai21-jamba-1.5-mini"
}

def list_available_models() -> Dict[str, str]:
    """
    Get a dictionary of available GitHub models

    Returns:
        Dictionary mapping friendly names to full model names
    """
    return AVAILABLE_MODELS.copy()

if __name__ == "__main__":
    # Example usage
    print("Available GitHub Models:")
    for friendly_name, full_name in AVAILABLE_MODELS.items():
        print(f"  {friendly_name}: {full_name}")

    # Example of creating a wrapper (requires GITHUB_TOKEN environment variable)
    try:
        wrapper = create_github_model_wrapper("github/gpt-4o-mini", temperature=0.3)
        print("\nGitHub Models wrapper created successfully!")

        # Test with a simple text prompt
        response = wrapper([{"type": "text", "content": "Hello! Can you confirm you're working?"}])
        print(f"Response: {response}")

    except Exception as e:
        print(f"Error creating wrapper: {e}")
        print("Make sure to set GITHUB_TOKEN environment variable")
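A minimal usage sketch for the GitHub Models wrapper defined above (illustrative only, not part of the uploaded file; it assumes GITHUB_TOKEN is set and that the listed model identifiers are accepted by the GitHub Models endpoint):

# Illustrative usage of mllm_tools/github.py -- assumes GITHUB_TOKEN is exported.
from mllm_tools.github import create_github_model_wrapper, list_available_models

print(list_available_models())  # friendly name -> "github/..." identifier

wrapper = create_github_model_wrapper("github/gpt-4o-mini", temperature=0.3)
reply = wrapper([{"type": "text", "content": "Summarize the Pythagorean theorem in one sentence."}])
print(reply)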
mllm_tools/litellm.py
ADDED
@@ -0,0 +1,193 @@
import json
import re
from typing import List, Dict, Any, Union, Optional
import io
import os
import base64
from PIL import Image
import mimetypes
import litellm
from litellm import completion, completion_cost
from dotenv import load_dotenv

load_dotenv()

class LiteLLMWrapper:
    """Wrapper for LiteLLM to support multiple models and logging"""

    def __init__(
        self,
        model_name: str = "gpt-4-vision-preview",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = True,
    ):
        """
        Initialize the LiteLLM wrapper

        Args:
            model_name: Name of the model to use (e.g. "azure/gpt-4", "vertex_ai/gemini-pro")
            temperature: Temperature for completion
            print_cost: Whether to print the cost of the completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0

        if self.verbose:
            os.environ['LITELLM_LOG'] = 'DEBUG'

        # Set langfuse callback only if enabled
        if use_langfuse:
            litellm.success_callback = ["langfuse"]
            litellm.failure_callback = ["langfuse"]

    def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
        """
        Encode local file or PIL Image to base64 string

        Args:
            file_path: Path to local file or PIL Image object

        Returns:
            Base64 encoded file string
        """
        if isinstance(file_path, Image.Image):
            buffered = io.BytesIO()
            file_path.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        else:
            with open(file_path, "rb") as file:
                return base64.b64encode(file.read()).decode("utf-8")

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension

        Args:
            file_path: Path to the file

        Returns:
            MIME type as a string (e.g., "image/jpeg", "audio/mp3")
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process messages and return completion

        Args:
            messages: List of message dictionaries with 'type' and 'content' keys
            metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking

        Returns:
            Generated text response
        """
        if metadata is None:
            print("No metadata provided, using empty metadata")
            metadata = {}
        metadata["trace_name"] = f"litellm-completion-{self.model_name}"
        # Convert messages to LiteLLM format
        formatted_messages = []
        for msg in messages:
            if msg["type"] == "text":
                formatted_messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": msg["content"]}]
                })
            elif msg["type"] in ["image", "audio", "video"]:
                # Check if content is a local file path or PIL Image
                if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
                    try:
                        if isinstance(msg["content"], Image.Image):
                            mime_type = "image/png"
                        else:
                            mime_type = self._get_mime_type(msg["content"])
                        base64_data = self._encode_file(msg["content"])
                        data_url = f"data:{mime_type};base64,{base64_data}"
                    except ValueError as e:
                        print(f"Error processing file {msg['content']}: {e}")
                        continue
                else:
                    data_url = msg["content"]

                # Append the formatted message based on the model
                if "gemini" in self.model_name:
                    formatted_messages.append({
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": data_url
                            }
                        ]
                    })
                elif "gpt" in self.model_name:
                    # GPT and other models expect a different format
                    if msg["type"] == "image":
                        # Default format for images and videos in GPT
                        formatted_messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": f"image_url",
                                    f"{msg['type']}_url": {
                                        "url": data_url,
                                        "detail": "high"
                                    }
                                }
                            ]
                        })
                    else:
                        raise ValueError("For GPT, only text and image inferencing are supported")
                else:
                    raise ValueError("Only support Gemini and Gpt for Multimodal capability now")

        try:
            # if it's openai o series model, set temperature to None and reasoning_effort to "medium"
            if (re.match(r"^o\d+.*$", self.model_name) or re.match(r"^openai/o.*$", self.model_name)):
                self.temperature = None
                self.reasoning_effort = "medium"
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    temperature=self.temperature,
                    reasoning_effort=self.reasoning_effort,
                    metadata=metadata,
                    max_retries=99
                )
            else:
                response = completion(
                    model=self.model_name,
                    messages=formatted_messages,
                    temperature=self.temperature,
                    metadata=metadata,
                    max_retries=99
                )
            if self.print_cost:
                # pass your response from completion to completion_cost
                cost = completion_cost(completion_response=response)
                formatted_string = f"Cost: ${float(cost):.10f}"
                # print(formatted_string)
                self.accumulated_cost += cost
                print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")

            content = response.choices[0].message.content
            if content is None:
                print(f"Got null response from model. Full response: {response}")
            return content

        except Exception as e:
            print(f"Error in model completion: {e}")
            return str(e)

if __name__ == "__main__":
    pass
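A minimal usage sketch for LiteLLMWrapper (illustrative only, not part of the uploaded file; the frame path is hypothetical and the provider API key for the chosen model must already be configured):

# Illustrative usage of mllm_tools/litellm.py -- the image path below is hypothetical.
from mllm_tools.litellm import LiteLLMWrapper

llm = LiteLLMWrapper(model_name="gpt-4o", temperature=0.7, print_cost=True, use_langfuse=False)
response = llm([
    {"type": "text", "content": "Describe what is shown in this rendered frame."},
    {"type": "image", "content": "media/frames/frame_0001.png"},
])
print(response)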
mllm_tools/openai.py
ADDED
@@ -0,0 +1,594 @@
# filepath: d:\Theory2Manim-2\Theory2Manim\mllm_tools\openai.py
import json
import re
from typing import List, Dict, Any, Union, Optional
import io
import os
import base64
from PIL import Image
import mimetypes
import litellm
from litellm import completion, completion_cost
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Note: Environment variables should be loaded from .env file or set manually using os.environ

class OpenAIWrapper:
    """Wrapper for OpenAI using LiteLLM to support all OpenAI models with unified interface"""

    def __init__(
        self,
        model_name: str = "gpt-4o",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = True,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        use_github_token: bool = True,
        github_token: Optional[str] = os.getenv('GITHUB_TOKEN')
    ):
        """
        Initialize the OpenAI wrapper

        Args:
            model_name: Name of the OpenAI model to use (e.g. "gpt-4o", "gpt-4o-mini",
                "gpt-3.5-turbo", "o1-preview", "o1-mini", "dall-e-3")
            temperature: Temperature for completion (ignored for o1 models)
            print_cost: Whether to print the cost of the completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging
            api_key: OpenAI API key (if not provided, will use OPENAI_API_KEY env var)
            organization: OpenAI organization ID (optional)
            base_url: Custom base URL for OpenAI API (optional, for proxies)
            use_github_token: Whether to use GitHub AI model inference endpoint
            github_token: GitHub token (if not provided, will use GITHUB_TOKEN env var)
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0
        self.use_github_token = use_github_token

        # Configure API based on whether using GitHub token or OpenAI API
        if use_github_token:
            # Set up GitHub token and endpoint
            self.github_token = github_token or os.getenv('GITHUB_TOKEN')
            if not self.github_token:
                raise ValueError("GitHub token is required when use_github_token=True. Please set GITHUB_TOKEN environment variable or pass github_token parameter.")

            # Set GitHub AI inference endpoint
            self.base_url = "https://models.github.ai/inference"
            self.api_key = self.github_token

            # Set environment variables for LiteLLM to use GitHub endpoint
            os.environ['OPENAI_API_KEY'] = self.github_token
            os.environ['OPENAI_BASE_URL'] = self.base_url

            # Adjust model name for GitHub endpoint (add openai/ prefix if not present)
            if not self.model_name.startswith("openai/"):
                self.model_name = f"openai/{self.model_name}"

        else:
            # Original OpenAI API setup
            self.api_key = api_key or os.getenv('OPENAI_API_KEY')
            if not self.api_key:
                raise ValueError("OpenAI API key is required. Please set OPENAI_API_KEY environment variable or pass api_key parameter.")

            # Set environment variables for LiteLLM
            os.environ['OPENAI_API_KEY'] = self.api_key

        # Set optional custom base URL
        if base_url:
            os.environ['OPENAI_BASE_URL'] = base_url
            self.base_url = base_url
        else:
            self.base_url = os.getenv('OPENAI_BASE_URL')

        # Set optional organization (only for OpenAI, not GitHub)
        if not use_github_token:
            if organization:
                os.environ['OPENAI_ORGANIZATION'] = organization
                self.organization = organization
            else:
                self.organization = os.getenv('OPENAI_ORGANIZATION')
        else:
            self.organization = None

        if self.verbose:
            os.environ['LITELLM_LOG'] = 'DEBUG'

        # Set langfuse callback only if enabled
        if use_langfuse:
            litellm.success_callback = ["langfuse"]
            litellm.failure_callback = ["langfuse"]

    def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
        """
        Encode local file or PIL Image to base64 string

        Args:
            file_path: Path to local file or PIL Image object

        Returns:
            Base64 encoded file string
        """
        if isinstance(file_path, Image.Image):
            buffered = io.BytesIO()
            file_path.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        else:
            with open(file_path, "rb") as file:
                return base64.b64encode(file.read()).decode("utf-8")

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension

        Args:
            file_path: Path to the file

        Returns:
            MIME type as a string (e.g., "image/jpeg", "application/pdf")
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def _supports_vision(self, model_name: str) -> bool:
        """
        Check if the model supports vision/image processing

        Args:
            model_name: Name of the model

        Returns:
            True if model supports vision, False otherwise
        """
        vision_models = [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-vision-preview",
            "gpt-4-turbo",
            "gpt-4-turbo-vision"
        ]

        return any(vision_model in model_name for vision_model in vision_models)

    def _supports_files(self, model_name: str) -> bool:
        """
        Check if the model supports file processing (PDFs, documents)

        Args:
            model_name: Name of the model

        Returns:
            True if model supports file processing, False otherwise
        """
        file_models = [
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo"
        ]

        return any(file_model in model_name for file_model in file_models)

    def _is_o1_model(self, model_name: str) -> bool:
        """
        Check if the model is an o1 series model (reasoning models)

        Args:
            model_name: Name of the model

        Returns:
            True if it's an o1 model, False otherwise
        """
        return "o1" in model_name

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None, **kwargs) -> str:
        """
        Process messages and return completion

        Args:
            messages: List of message dictionaries with 'type' and 'content' keys
            metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking
            **kwargs: Additional parameters for completion (max_tokens, stream, etc.)

        Returns:
            Generated text response
        """
        if metadata is None:
            metadata = {}
        metadata["trace_name"] = f"openai-completion-{self.model_name}"

        # Convert messages to LiteLLM format
        formatted_messages = []

        for msg in messages:
            if msg["type"] == "text":
                formatted_messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": msg["content"]}]
                })
            elif msg["type"] == "image":
                # Check if model supports vision
                if not self._supports_vision(self.model_name):
                    raise ValueError(f"Model {self.model_name} does not support image processing")

                # Check if content is a local file path or PIL Image
                if isinstance(msg["content"], Image.Image) or (isinstance(msg["content"], str) and os.path.isfile(msg["content"])):
                    try:
                        if isinstance(msg["content"], Image.Image):
                            mime_type = "image/png"
                        else:
                            mime_type = self._get_mime_type(msg["content"])
                        base64_data = self._encode_file(msg["content"])
                        data_url = f"data:{mime_type};base64,{base64_data}"
                    except ValueError as e:
                        print(f"Error processing file {msg['content']}: {e}")
                        continue
                else:
                    # Assume it's already a URL or base64 string
                    data_url = msg["content"]

                # Format for vision-capable models
                formatted_messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": data_url,
                                "detail": "high"
                            }
                        }
                    ]
                })
            elif msg["type"] == "file":
                # Check if model supports file processing
                if not self._supports_files(self.model_name):
                    raise ValueError(f"Model {self.model_name} does not support file processing")

                # Handle file content (PDF, documents, etc.)
                if os.path.isfile(msg["content"]):
                    try:
                        mime_type = self._get_mime_type(msg["content"])
                        base64_data = self._encode_file(msg["content"])

                        # Use the file format for document processing
                        formatted_messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": "file",
                                    "file": {
                                        "filename": os.path.basename(msg["content"]),
                                        "file_data": f"data:{mime_type};base64,{base64_data}",
                                    }
                                }
                            ]
                        })
                    except ValueError as e:
                        print(f"Error processing file {msg['content']}: {e}")
                        continue
                else:
                    raise ValueError(f"File not found: {msg['content']}")
            else:
                raise ValueError(f"Unsupported message type: {msg['type']}. OpenAI models support 'text', 'image', and 'file' types.")

        try:
            # Prepare completion parameters
            completion_params = {
                "model": self.model_name,
                "messages": formatted_messages,
                "metadata": metadata,
                "max_retries": 3
            }

            # Add additional kwargs
            completion_params.update(kwargs)

            # Check if it's an o1 series model (reasoning models)
            if self._is_o1_model(self.model_name):
                # O1 models don't support temperature and have reasoning_effort
                if "reasoning_effort" not in completion_params:
                    completion_params["reasoning_effort"] = "medium"  # Options: "low", "medium", "high"
                # Remove temperature if it was added via kwargs
                completion_params.pop("temperature", None)
            else:
                # Regular models support temperature
                if "temperature" not in completion_params:
                    completion_params["temperature"] = self.temperature

            response = completion(**completion_params)

            if self.print_cost:
                try:
                    cost = completion_cost(completion_response=response)
                    if cost is not None:
                        self.accumulated_cost += cost
                        print(f"Cost: ${float(cost):.10f}")
                        print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
                    else:
                        print("Cost information not available")
                except Exception as e:
                    print(f"Could not calculate cost: {e}")

            content = response.choices[0].message.content
            if content is None:
                print(f"Got null response from OpenAI model. Full response: {response}")
                return ""
            return content

        except Exception as e:
            print(f"Error in OpenAI model completion: {e}")
            return str(e)

    def stream_completion(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None, **kwargs):
        """
        Process messages and return streaming completion

        Args:
            messages: List of message dictionaries with 'type' and 'content' keys
            metadata: Optional metadata to pass to litellm completion
            **kwargs: Additional parameters for completion

        Yields:
            Streaming response chunks
        """
        kwargs["stream"] = True

        # Use the same message formatting as regular completion
        if metadata is None:
            metadata = {}
        metadata["trace_name"] = f"openai-streaming-{self.model_name}"

        try:
            # Convert messages to the same format as __call__
            formatted_messages = []

            for msg in messages:
                if msg["type"] == "text":
                    formatted_messages.append({
                        "role": "user",
                        "content": [{"type": "text", "text": msg["content"]}]
                    })
                elif msg["type"] == "image":
                    if not self._supports_vision(self.model_name):
                        raise ValueError(f"Model {self.model_name} does not support image processing")

                    if isinstance(msg["content"], Image.Image) or (isinstance(msg["content"], str) and os.path.isfile(msg["content"])):
                        try:
                            if isinstance(msg["content"], Image.Image):
                                mime_type = "image/png"
                            else:
                                mime_type = self._get_mime_type(msg["content"])
                            base64_data = self._encode_file(msg["content"])
                            data_url = f"data:{mime_type};base64,{base64_data}"
                        except ValueError as e:
                            print(f"Error processing file {msg['content']}: {e}")
                            continue
                    else:
                        data_url = msg["content"]

                    formatted_messages.append({
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": data_url,
                                    "detail": "high"
                                }
                            }
                        ]
                    })

            # Prepare completion parameters
            completion_params = {
                "model": self.model_name,
                "messages": formatted_messages,
                "metadata": metadata,
                "max_retries": 3,
                "stream": True
            }

            # Add additional kwargs
            completion_params.update(kwargs)

            # Handle o1 models
            if self._is_o1_model(self.model_name):
                if "reasoning_effort" not in completion_params:
                    completion_params["reasoning_effort"] = "medium"
                completion_params.pop("temperature", None)
            else:
                if "temperature" not in completion_params:
                    completion_params["temperature"] = self.temperature

            response = completion(**completion_params)

            # Yield streaming chunks
            for chunk in response:
                yield chunk

        except Exception as e:
            print(f"Error in OpenAI streaming completion: {e}")
            yield {"error": str(e)}

def create_openai_wrapper(model_name: str = "gpt-4o", use_github: bool = False, **kwargs) -> OpenAIWrapper:
    """
    Convenience function to create an OpenAI wrapper

    Args:
        model_name: OpenAI model name (e.g., "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo")
        use_github: Whether to use GitHub's AI model inference endpoint
        **kwargs: Additional arguments passed to OpenAIWrapper

    Returns:
        Configured OpenAIWrapper instance

    Example:
        >>> # Create a wrapper for GPT-4o using regular OpenAI
        >>> wrapper = create_openai_wrapper("gpt-4o", temperature=0.3)
        >>>
        >>> # Create a wrapper for GPT-4o using GitHub AI models
        >>> wrapper = create_openai_wrapper("gpt-4o", use_github=True, temperature=0.3)
        >>>
        >>> # Use it for text generation
        >>> response = wrapper([{"type": "text", "content": "Explain quantum computing"}])
        >>>
        >>> # Use it for vision (if model supports it)
        >>> response = wrapper([
        ...     {"type": "text", "content": "What's in this image?"},
        ...     {"type": "image", "content": "path/to/image.jpg"}
        ... ])
        >>>
        >>> # Use it for file processing (PDFs, etc.)
        >>> response = wrapper([
        ...     {"type": "text", "content": "Summarize this document"},
        ...     {"type": "file", "content": "path/to/document.pdf"}
        ... ])
    """
    return OpenAIWrapper(model_name=model_name, use_github_token=use_github, **kwargs)

# Available OpenAI Models
AVAILABLE_MODELS = {
    # GPT-4 Models
    "gpt-4o": "gpt-4o",
    "gpt-4o-mini": "gpt-4o-mini",
    "gpt-4-turbo": "gpt-4-turbo",
    "gpt-4": "gpt-4",
    "gpt-4-vision-preview": "gpt-4-vision-preview",

    # O1 Reasoning Models
    "o1-preview": "o1-preview",
    "o1-mini": "o1-mini",

    # GPT-3.5 Models
    "gpt-3.5-turbo": "gpt-3.5-turbo",
    "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-instruct",

    # Image Generation Models
    "dall-e-3": "dall-e-3",
    "dall-e-2": "dall-e-2",

    # Embedding Models
    "text-embedding-3-large": "text-embedding-3-large",
    "text-embedding-3-small": "text-embedding-3-small",
    "text-embedding-ada-002": "text-embedding-ada-002",

    # Audio Models
    "whisper-1": "whisper-1",
    "tts-1": "tts-1",
    "tts-1-hd": "tts-1-hd",
}

def create_github_openai_wrapper(model_name: str = "gpt-4o", **kwargs) -> OpenAIWrapper:
    """
    Convenience function to create an OpenAI wrapper using GitHub's AI model inference

    Args:
        model_name: OpenAI model name (e.g., "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo")
        **kwargs: Additional arguments passed to OpenAIWrapper

    Returns:
        Configured OpenAIWrapper instance using GitHub endpoint

    Example:
        >>> # Create a wrapper for GPT-4o using GitHub AI models
        >>> wrapper = create_github_openai_wrapper("gpt-4o", temperature=0.3)
        >>>
        >>> # Use it for text generation
        >>> response = wrapper([{"type": "text", "content": "What is the capital of France?"}])
    """
    return OpenAIWrapper(model_name=model_name, use_github_token=True, **kwargs)

def list_available_models() -> Dict[str, str]:
    """
    Get a dictionary of available OpenAI models

    Returns:
        Dictionary mapping model names to their identifiers
    """
    return AVAILABLE_MODELS.copy()

def get_model_capabilities(model_name: str) -> Dict[str, bool]:
    """
    Get the capabilities of a specific model

    Args:
        model_name: Name of the model

    Returns:
        Dictionary of capabilities (vision, files, reasoning, etc.)
    """
    wrapper = OpenAIWrapper(model_name=model_name)

    return {
        "vision": wrapper._supports_vision(model_name),
        "files": wrapper._supports_files(model_name),
        "reasoning": wrapper._is_o1_model(model_name),
        "streaming": not wrapper._is_o1_model(model_name),  # O1 models don't support streaming
        "temperature": not wrapper._is_o1_model(model_name),  # O1 models don't support temperature
    }

if __name__ == "__main__":
    # Example usage
    print("Available OpenAI Models:")
    for model_name, model_id in AVAILABLE_MODELS.items():
        capabilities = get_model_capabilities(model_name)
        print(f"  {model_name} ({model_id}): {capabilities}")

    print("\n" + "="*50)
    print("Testing OpenAI wrapper...")

    # Example 1: Regular OpenAI (requires OPENAI_API_KEY environment variable)
    try:
        print("\n1. Testing regular OpenAI wrapper:")
        wrapper = create_openai_wrapper("gpt-4o-mini", temperature=0.3)
        print("Regular OpenAI wrapper created successfully!")

        # Test with a simple text prompt
        response = wrapper([{"type": "text", "content": "Hello! Can you confirm you're working?"}])
        print(f"Response: {response}")

    except Exception as e:
        print(f"Error creating regular OpenAI wrapper: {e}")
        print("Make sure to set OPENAI_API_KEY environment variable")

    # Example 2: GitHub AI models (requires GITHUB_TOKEN environment variable)
    try:
        print("\n2. Testing GitHub AI models wrapper:")
        github_wrapper = create_github_openai_wrapper("gpt-4o", temperature=1.0)
        print("GitHub OpenAI wrapper created successfully!")

        # Test with a simple text prompt
        response = github_wrapper([{
            "type": "text",
            "content": "What is the capital of France?"
        }])
        print(f"GitHub Response: {response}")

    except Exception as e:
        print(f"Error creating GitHub wrapper: {e}")
        print("Make sure to set GITHUB_TOKEN environment variable")
    # Example 3: Manual GitHub configuration
    try:
        print("\n3. Testing manual GitHub configuration:")
        manual_wrapper = OpenAIWrapper(
            model_name="openai/gpt-4o",
            use_github_token=True,
            temperature=1.0,
            verbose=False
        )
        print("Manual GitHub wrapper created successfully!")

    except Exception as e:
        print(f"Error creating manual GitHub wrapper: {e}")
        print("Make sure to set GITHUB_TOKEN environment variable")
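A minimal usage sketch for the OpenAI wrapper (illustrative only, not part of the uploaded file; note that OpenAIWrapper defaults to use_github_token=True, so GITHUB_TOKEN must be set unless use_github=False and OPENAI_API_KEY is provided):

# Illustrative usage of mllm_tools/openai.py -- assumes GITHUB_TOKEN is exported.
from mllm_tools.openai import create_openai_wrapper, get_model_capabilities

print(get_model_capabilities("gpt-4o"))  # e.g. vision/files True, reasoning False

wrapper = create_openai_wrapper("gpt-4o-mini", use_github=True, temperature=0.3)
reply = wrapper(
    [{"type": "text", "content": "Give one sentence on what Manim is."}],
    max_tokens=100,  # extra kwargs are forwarded to litellm.completion
)
print(reply)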
mllm_tools/openrouter.py
ADDED
@@ -0,0 +1,266 @@
import os
import re
from typing import List, Dict, Any, Optional, Union
import io
import base64
from PIL import Image
import mimetypes
from litellm import completion, completion_cost
from dotenv import load_dotenv

load_dotenv()

class OpenRouterWrapper:
    """
    OpenRouter wrapper using LiteLLM for various language models.
    Compatible with the existing wrapper interface.
    """

    def __init__(
        self,
        model_name: str = "openrouter/deepseek/deepseek-chat-v3-0324:free",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = True,
        site_url: str = "",
        app_name: str = "Theory2Manim"
    ):
        """
        Initialize OpenRouter wrapper.

        Args:
            model_name: OpenRouter model name (with openrouter/ prefix)
            temperature: Temperature for completion
            print_cost: Whether to print the cost of the completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging
            site_url: Optional site URL for tracking
            app_name: Optional app name for tracking
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0

        # Setup OpenRouter environment variables
        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            raise ValueError("No OPENROUTER_API_KEY found. Please set the environment variable.")

        os.environ["OPENROUTER_API_KEY"] = api_key
        os.environ["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"

        if site_url or os.getenv("OR_SITE_URL"):
            os.environ["OR_SITE_URL"] = site_url or os.getenv("OR_SITE_URL", "")
        if app_name:
            os.environ["OR_APP_NAME"] = app_name

        if self.verbose:
            os.environ['LITELLM_LOG'] = 'DEBUG'

        # Set langfuse callback only if enabled
        if use_langfuse:
            import litellm
            litellm.success_callback = ["langfuse"]
            litellm.failure_callback = ["langfuse"]

    def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
        """
        Encode local file or PIL Image to base64 string

        Args:
            file_path: Path to local file or PIL Image object

        Returns:
            Base64 encoded file string
        """
        if isinstance(file_path, Image.Image):
            buffered = io.BytesIO()
            file_path.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        else:
            with open(file_path, "rb") as file:
                return base64.b64encode(file.read()).decode("utf-8")

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension

        Args:
            file_path: Path to the file

        Returns:
            MIME type as a string (e.g., "image/jpeg", "audio/mp3")
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process messages and return completion

        Args:
            messages: List of message dictionaries with 'type' and 'content' keys
            metadata: Optional metadata to pass to completion

        Returns:
            Generated text response
        """
        if metadata is None:
            metadata = {}
        metadata["trace_name"] = f"openrouter-completion-{self.model_name}"

        # Convert messages to LiteLLM format
        formatted_messages = []
        for msg in messages:
            if msg["type"] == "text":
                formatted_messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": msg["content"]}]
                })
            elif msg["type"] in ["image", "audio", "video"]:
                # Check if content is a local file path or PIL Image
                if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
                    try:
                        if isinstance(msg["content"], Image.Image):
                            mime_type = "image/png"
                        else:
                            mime_type = self._get_mime_type(msg["content"])
                        base64_data = self._encode_file(msg["content"])
                        data_url = f"data:{mime_type};base64,{base64_data}"
                    except ValueError as e:
                        print(f"Error processing file {msg['content']}: {e}")
                        continue
                else:
                    data_url = msg["content"]

                # Format for vision models
                if msg["type"] == "image":
                    formatted_messages.append({
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": data_url,
                                    "detail": "high"
                                }
                            }
                        ]
                    })
                else:
                    # For audio/video, treat as text for now
                    formatted_messages.append({
                        "role": "user",
                        "content": [{"type": "text", "text": f"[{msg['type'].upper()}]: {msg['content']}"}]
                    })

        try:
            response = completion(
                model=self.model_name,
                messages=formatted_messages,
                temperature=self.temperature,
                metadata=metadata,
                max_retries=99
            )
            if self.print_cost:
                # Calculate and print cost
                cost = completion_cost(completion_response=response)
                self.accumulated_cost += cost
                print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")

            content = response.choices[0].message.content
            if content is None:
                print(f"Got null response from model. Full response: {response}")
                return "Error: Received null response from model"

            # Check if the response contains error messages about unmapped models
            if "This model isn't mapped yet" in content or "model isn't mapped" in content.lower():
                error_msg = f"Error: Model {self.model_name} is not supported by LiteLLM. Please use a supported model."
                print(error_msg)
                return error_msg

            return content

        except Exception as e:
            print(f"Error in OpenRouter completion: {e}")
            return f"Error: {str(e)}"


class OpenRouterClient:
    """
    Legacy OpenRouter client for backward compatibility.
    """

    def __init__(self, api_key: str, site_url: str = "", app_name: str = "Theory2Manim"):
        """
        Initialize OpenRouter client.

        Args:
            api_key: OpenRouter API key
            site_url: Optional site URL for tracking
            app_name: Optional app name for tracking
        """
        os.environ["OPENROUTER_API_KEY"] = api_key
        os.environ["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"

        if site_url:
            os.environ["OR_SITE_URL"] = site_url
        if app_name:
            os.environ["OR_APP_NAME"] = app_name

    def complete(
        self,
        messages: List[Dict[str, str]],
        model: str = "openrouter/openai/gpt-3.5-turbo",
        transforms: Optional[List[str]] = None,
        route: Optional[str] = None,
        **kwargs
    ) -> Any:
        """
        Generate completion using OpenRouter model.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            model: Model name (with openrouter/ prefix)
            transforms: Optional transforms to apply
            route: Optional route specification
            **kwargs: Additional parameters for completion

        Returns:
            Completion response
        """
        params = {
            "model": model,
            "messages": messages,
            **kwargs
        }

        if transforms:
            params["transforms"] = transforms
        if route:
            params["route"] = route

        return completion(**params)

# Convenience functions for common models
def ds_r1(messages: List[Dict[str, str]], **kwargs) -> Any:
    """Use DeepSeek R1 (free) via OpenRouter"""
    client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
    return client.complete(messages, "deepseek/deepseek-r1:free", **kwargs)

def ds_v3(messages: List[Dict[str, str]], **kwargs) -> Any:
    """Use DeepSeek Chat v3 (free) via OpenRouter"""
    client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
    return client.complete(messages, "deepseek/deepseek-chat-v3-0324:free", **kwargs)

def qwen3(messages: List[Dict[str, str]], **kwargs) -> Any:
    """Use Qwen3 235B (free) via OpenRouter"""
    client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
    return client.complete(messages, "qwen/qwen3-235b-a22b:free", **kwargs)
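A minimal usage sketch for OpenRouterWrapper (illustrative only, not part of the uploaded file; it assumes OPENROUTER_API_KEY is set and that the free DeepSeek route used as the default model is still offered by OpenRouter):

# Illustrative usage of mllm_tools/openrouter.py -- assumes OPENROUTER_API_KEY is exported.
from mllm_tools.openrouter import OpenRouterWrapper

router = OpenRouterWrapper(
    model_name="openrouter/deepseek/deepseek-chat-v3-0324:free",
    temperature=0.7,
    use_langfuse=False,
)
print(router([{"type": "text", "content": "List three Manim mobject types."}]))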
mllm_tools/utils.py
ADDED
@@ -0,0 +1,177 @@
from typing import Union, List, Dict, Any, Optional
from PIL import Image
import google.generativeai as genai
import tempfile
import os
from .gemini import GeminiWrapper
from .vertex_ai import VertexAIWrapper
from .openrouter import OpenRouterWrapper


def _prepare_text_inputs(texts: List[str]) -> List[Dict[str, str]]:
    """
    Converts a list of text strings into the input format for the Agent model.

    Args:
        texts (List[str]): The list of text strings to be processed.

    Returns:
        List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
    """
    inputs = []
    # Add each text string to the inputs
    if isinstance(texts, str):
        texts = [texts]
    for text in texts:
        inputs.append({
            "type": "text",
            "content": text
        })
    return inputs

def _prepare_text_image_inputs(texts: Union[str, List[str]], images: Union[str, Image.Image, List[Union[str, Image.Image]]]) -> List[Dict[str, str]]:
    """
    Converts text strings and images into the input format for the Agent model.

    Args:
        texts (Union[str, List[str]]): Text string(s) to be processed.
        images (Union[str, Image.Image, List[Union[str, Image.Image]]]): Image file path(s) or PIL Image object(s).
    Returns:
        List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
    """
    inputs = []
    # Add each text string to the inputs
    if isinstance(texts, str):
        texts = [texts]
    for text in texts:
        inputs.append({
            "type": "text",
            "content": text
        })
    if isinstance(images, (str, Image.Image)):
        images = [images]
    for image in images:
        inputs.append({
            "type": "image",
            "content": image
        })
    return inputs

def _prepare_text_video_inputs(texts: Union[str, List[str]], videos: Union[str, List[str]]) -> List[Dict[str, str]]:
    """
    Converts text strings and video file paths into the input format for the Agent model.

    Args:
        texts (Union[str, List[str]]): Text string(s) to be processed.
        videos (Union[str, List[str]]): Video file path(s).
    Returns:
        List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
    """
    inputs = []
    # Add each text string to the inputs
    if isinstance(texts, str):
        texts = [texts]
    for text in texts:
        inputs.append({
            "type": "text",
            "content": text
        })
    # Add each video file path to the inputs
    if isinstance(videos, str):
        videos = [videos]
    for video in videos:
        inputs.append({
            "type": "video",
            "content": video
        })
    return inputs

def _prepare_text_audio_inputs(texts: Union[str, List[str]], audios: Union[str, List[str]]) -> List[Dict[str, str]]:
    """
    Converts text strings and audio file paths into the input format for the Agent model.

    Args:
        texts (Union[str, List[str]]): Text string(s) to be processed.
        audios (Union[str, List[str]]): Audio file path(s).
    Returns:
        List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
    """
    inputs = []
    # Add each text string to the inputs
    if isinstance(texts, str):
        texts = [texts]
    for text in texts:
        inputs.append({
            "type": "text",
            "content": text
        })
    # Add each audio file path to the inputs
    if isinstance(audios, str):
        audios = [audios]
    for audio in audios:
        inputs.append({
            "type": "audio",
            "content": audio
        })
    return inputs

def _extract_code(text: str) -> str:
    """Helper to extract code block from model response, support Gemini style and OpenAI style"""
    try:
        # Find code between ```python and ``` tags
        start = text.split("```python\n")[-1]
        end = start.split("```")[0]
        return end.strip()
    except IndexError:
        return text

def _upload_to_gemini(input, mime_type=None):
    """Uploads the given file or PIL image to Gemini.

    See https://ai.google.dev/gemini-api/docs/prompting_with_media
    """
    if isinstance(input, str):
        # Input is a file path
        file = genai.upload_file(input, mime_type=mime_type)
    elif isinstance(input, Image.Image):
        # Input is a PIL image
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            input.save(tmp_file, format="JPEG")
            tmp_file_path = tmp_file.name
        file = genai.upload_file(tmp_file_path, mime_type=mime_type or "image/jpeg")
        os.remove(tmp_file_path)
    else:
        raise ValueError("Unsupported input type. Must be a file path or PIL Image.")

    #print(f"Uploaded file '{file.display_name}' as: {file.uri}")
    return file

def get_media_wrapper(model_name: str) -> Optional[Union[GeminiWrapper, VertexAIWrapper, OpenRouterWrapper]]:
    """Get appropriate wrapper for media handling based on model name"""
    if model_name.startswith('gemini/'):
        return GeminiWrapper(model_name=model_name.split('/')[-1])
    elif model_name.startswith('vertex_ai/'):
        return VertexAIWrapper(model_name=model_name.split('/')[-1])
    elif model_name.startswith('openrouter/'):
        return OpenRouterWrapper(model_name=model_name)
    return None

def prepare_media_messages(prompt: str, media_path: Union[str, Image.Image], model_name: str) -> List[Dict[str, Any]]:
    """Prepare messages for media input based on model type"""
    is_video = isinstance(media_path, str) and media_path.endswith('.mp4')

    if is_video and (model_name.startswith('gemini/') or model_name.startswith('vertex_ai/') or model_name.startswith('openrouter/')):
        return [
            {"type": "text", "content": prompt},
            {"type": "video", "content": media_path}
        ]
    else:
        # For images or non-video content
        if isinstance(media_path, str):
            media = Image.open(media_path)
        else:
            media = media_path
        return [
            {"type": "text", "content": prompt},
            {"type": "image", "content": media}
        ]
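A small sketch of how the input helpers above feed the wrappers (illustrative only, not part of the uploaded file; the frame path is hypothetical):

# Illustrative usage of the helpers in mllm_tools/utils.py.
from mllm_tools.utils import _prepare_text_inputs, _prepare_text_image_inputs

text_only = _prepare_text_inputs("Explain the scene layout.")
# -> [{"type": "text", "content": "Explain the scene layout."}]

text_and_image = _prepare_text_image_inputs(
    "What is shown in this frame?",
    "output/scene1/frame_0042.png",
)
# -> a text message followed by an image message, ready to pass to any wrapper above.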
mllm_tools/vertex_ai.py
ADDED
@@ -0,0 +1,86 @@
import os
from typing import List, Dict, Any, Optional
import vertexai
from vertexai.generative_models import GenerativeModel, Part
from google.auth import default
from google.auth.transport import requests


# TODO: check if this is the correct way to use Vertex AI
# TODO: add langfuse support
class VertexAIWrapper:
    """Wrapper for Vertex AI to support Gemini models."""

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False
    ):
        """Initialize the Vertex AI wrapper.

        Args:
            model_name: Name of the model to use (e.g. "gemini-1.5-pro")
            temperature: Temperature for generation between 0 and 1
            print_cost: Whether to print the cost of the completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose

        # Initialize Vertex AI
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if not project_id:
            raise ValueError("No GOOGLE_CLOUD_PROJECT found in environment variables")

        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """Process messages and return completion.

        Args:
            messages: List of message dictionaries containing type and content
            metadata: Optional metadata dictionary to pass to the model

        Returns:
            Generated text response from the model

        Raises:
            ValueError: If message type is not supported
        """
        parts = []

        for msg in messages:
            if msg["type"] == "text":
                parts.append(Part.from_text(msg["content"]))
            elif msg["type"] in ["image", "video"]:
                mime_type = "video/mp4" if msg["type"] == "video" else "image/jpeg"
                if isinstance(msg["content"], str):
                    # Handle GCS URI
                    parts.append(Part.from_uri(
                        msg["content"],
                        mime_type=mime_type
                    ))
                else:
                    # Handle file path or bytes
                    parts.append(Part.from_data(
                        msg["content"],
                        mime_type=mime_type
                    ))

        response = self.model.generate_content(
            parts,
            generation_config={
                "temperature": self.temperature,
                "top_p": 0.95,
            }
        )

        return response.text
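A minimal usage sketch for VertexAIWrapper (illustrative only, not part of the uploaded file; it assumes GOOGLE_CLOUD_PROJECT is set, Vertex AI is enabled for that project, and the GCS URI below is hypothetical):

# Illustrative usage of mllm_tools/vertex_ai.py.
from mllm_tools.vertex_ai import VertexAIWrapper

vertex = VertexAIWrapper(model_name="gemini-1.5-pro", temperature=0.4)
summary = vertex([
    {"type": "text", "content": "Summarize what happens in this clip."},
    {"type": "video", "content": "gs://my-bucket/videos/scene1.mp4"},
])
print(summary)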