thanhkt committed on
Commit 8fb7841 · verified · 1 Parent(s): 4efafe0

Upload 26 files

eval_suite/__init__.py ADDED
File without changes
eval_suite/image_utils.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+ import tempfile
3
+
4
+ import numpy as np
5
+ from PIL import Image, ImageOps
6
+ from moviepy import VideoFileClip
7
+
8
+ from eval_suite.prompts_raw import _image_eval
9
+ from eval_suite.utils import extract_json, convert_score_fields, calculate_geometric_mean
10
+ from mllm_tools.utils import _prepare_text_image_inputs
11
+ from src.core.parse_video import image_with_most_non_black_space
12
+
13
+ def extract_key_frames(video_path, output_dir, num_chunks):
14
+ """Extract key frames from a video by dividing it into chunks and selecting representative frames.
15
+
16
+ Args:
17
+ video_path (str): Path to the input video file
18
+ output_dir (str): Directory where extracted frames will be saved
19
+ num_chunks (int): Number of chunks to divide the video into
20
+
21
+ Returns:
22
+ list: List of paths to the extracted key frames
23
+ """
24
+ # Create output directory if it doesn't exist
25
+ os.makedirs(output_dir, exist_ok=True)
26
+
27
+ # Extract all frames from the video
28
+ clip = VideoFileClip(video_path)
29
+ frames = list(clip.iter_frames(fps=1)) # one frame every second
30
+
31
+ total_frames = len(frames)
32
+ if total_frames == 0:
33
+ print("No frames extracted from the video.")
34
+ return []
35
+
36
+ # Determine the number of frames per chunk
37
+ frames_per_chunk = max(1, total_frames // num_chunks)  # guard against zero when the video has fewer frames than chunks
38
+ num_chunks = min(num_chunks, (total_frames + frames_per_chunk - 1) // frames_per_chunk)
39
+
40
+ key_frames = []
41
+
42
+ # Process each chunk of frames
43
+ for i in range(num_chunks):
44
+ start_idx = i * frames_per_chunk
45
+ end_idx = min((i + 1) * frames_per_chunk, total_frames)
46
+ chunk_frames = frames[start_idx:end_idx]
47
+
48
+ if chunk_frames:
49
+ # Save the frame with most non-black space
50
+ output_path = os.path.join(output_dir, f"key_frame_{i+1}.jpg")
51
+ result = image_with_most_non_black_space(chunk_frames, output_path)
52
+ else:
53
+ print(f"No frames in chunk {i+1}. Skipping.")
54
+ result = None
55
+
56
+ if result is not None:
57
+ key_frames.append(output_path)
58
+ clip.close()
59
+
60
+ return key_frames
61
+
62
+
63
+ def evaluate_sampled_images(model, video_path, description="No description provided", num_chunks=10, output_folder=None):
64
+ """Evaluate sampled frames from a video using an image evaluation model.
65
+
66
+ Args:
67
+ model: The image evaluation model to use
68
+ video_path (str): Path to the input video file
69
+ description (str, optional): Description of the video content. Defaults to "No description provided"
70
+ num_chunks (int, optional): Number of chunks to divide the video into. Defaults to 10
71
+ output_folder (str, optional): Directory for temporary files. Defaults to None
72
+
73
+ Returns:
74
+ dict: Dictionary containing evaluation scores and individual frame assessments with keys:
75
+ - evaluation: Dictionary of averaged scores for each criterion
76
+ - image_chunks: List of individual frame evaluation results
77
+ """
78
+ with tempfile.TemporaryDirectory(dir=output_folder) as temp_dir:
79
+ key_frames = extract_key_frames(video_path, temp_dir, num_chunks)
80
+
81
+ prompt = _image_eval.format(description=description)
82
+
83
+ responses = []
84
+ for key_frame in key_frames:
85
+ inputs = _prepare_text_image_inputs(prompt, key_frame)
86
+ response = model(inputs)
87
+ response_json = extract_json(response)
88
+ response_json = convert_score_fields(response_json)
89
+ responses.append(response_json)
90
+
91
+ criteria = list(responses[0]["evaluation"].keys())
92
+ scores_dict = {c: [] for c in criteria}
93
+ for response in responses:
94
+ for key, val in response["evaluation"].items():
95
+ scores_dict[key].append(val["score"])
96
+
97
+ res_score = {}
98
+ for key, scores in scores_dict.items():
99
+ res_score[key] = {"score": calculate_geometric_mean(scores)}
100
+
101
+ return {
102
+ "evaluation": res_score,
103
+ "image_chunks": responses
104
+ }
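
For context, here is a minimal usage sketch of the two helpers above; the video path, output folder, and model wrapper are placeholders, not values taken from this commit.

```python
# Hypothetical usage of eval_suite/image_utils.py; paths and model name are illustrative only.
from mllm_tools.litellm import LiteLLMWrapper
from eval_suite.image_utils import extract_key_frames, evaluate_sampled_images

model = LiteLLMWrapper(model_name="gemini/gemini-1.5-pro-002")  # any vision-capable wrapper

# One representative (most non-black) frame per chunk, written to ./frames
frames = extract_key_frames("output/theorem_video.mp4", "frames", num_chunks=10)
print(f"Extracted {len(frames)} key frames")

# Score the sampled frames with the _image_eval prompt and aggregate per-criterion geometric means
result = evaluate_sampled_images(model, "output/theorem_video.mp4",
                                 description="Pythagorean theorem", num_chunks=10)
print(result["evaluation"])         # e.g. {"visual_relevance": {"score": ...}, "element_layout": {...}}
print(len(result["image_chunks"]))  # one raw per-frame evaluation per key frame
```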
eval_suite/parse_prompt.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ from tqdm import tqdm
3
+
4
+
5
+ def call_parse_prompt():
6
+ """
7
+ Locates the prompts_raw directory and generates an __init__.py file containing prompt texts.
8
+
9
+ Searches for prompts_raw directory in current and parent directories. Once found, calls
10
+ create_python_file_with_texts() to generate the __init__.py file.
11
+ """
12
+ current_file_path = os.path.abspath(__file__)
13
+ current_folder_path = os.path.dirname(current_file_path)
14
+ folder_path = os.path.join(current_folder_path, "prompts_raw")
15
+
16
+ # If prompts_raw not found in current directory, search parent directories
17
+ if not os.path.exists(folder_path):
18
+ parent_dir = current_folder_path
19
+ while parent_dir != os.path.dirname(parent_dir): # Stop at root directory
20
+ parent_dir = os.path.dirname(parent_dir)
21
+ test_path = os.path.join(parent_dir, "prompts_raw")
22
+ if os.path.exists(test_path):
23
+ folder_path = test_path
24
+ break
25
+
26
+ output_file = os.path.join(folder_path, "__init__.py")
27
+ create_python_file_with_texts(folder_path, output_file)
28
+
29
+
30
+ def create_python_file_with_texts(folder_path, output_file):
31
+ """
32
+ Creates a Python file containing prompt texts from .txt files.
33
+
34
+ Args:
35
+ folder_path (str): Path to directory containing prompt .txt files
36
+ output_file (str): Path where the output __init__.py file will be created
37
+
38
+ The function reads all .txt files in the given folder, converts their contents into
39
+ Python variables, and writes them to the output file. Variable names are derived from
40
+ file paths with special characters replaced.
41
+ """
42
+ with open(output_file, 'w', encoding='utf-8') as out_file:
43
+ out_file.write("# This file is generated automatically through parse_prompt.py\n\n")
44
+ txt_files = [file for root, dirs, files in os.walk(folder_path) for file in files if file.endswith(".txt")]
45
+ for file in tqdm(txt_files, desc="Processing files"):
46
+ file_path = os.path.join(folder_path, file)
47
+ var_name = "_" + file_path.replace(folder_path, "").replace(os.sep, "_").replace(".txt", "").strip("_")
48
+ with open(file_path, 'r', encoding='utf-8') as f:
49
+ content = f.read().replace('"""', '\"\"\"')
50
+ out_file.write(f'{var_name} = """{content}"""\n\n')
51
+
52
+
53
+ if __name__ == "__main__":
54
+ call_parse_prompt()
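
As a quick illustration of the naming rule used in create_python_file_with_texts above, the following standalone sketch reproduces how a .txt file name becomes a module-level variable (the folder and file names are examples):

```python
import os

# Example: eval_suite/prompts_raw/image_eval.txt -> variable "_image_eval"
folder_path = os.path.join("eval_suite", "prompts_raw")
file_path = os.path.join(folder_path, "image_eval.txt")
var_name = "_" + file_path.replace(folder_path, "").replace(os.sep, "_").replace(".txt", "").strip("_")
print(var_name)  # -> _image_eval
```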
eval_suite/prompts_raw/__init__.py ADDED
@@ -0,0 +1,145 @@
1
+ # This file is generated automatically through parse_prompt.py
2
+
3
+ _video_eval_new = """# Task: Video Frame Quality Evaluation
4
+
5
+ You are tasked with analyzing and scoring a chunk of a theorem explanation video. Note that you may not have the full context of the video. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.
6
+
7
+ ## Evaluation Criteria
8
+
9
+ 1. **Visual Consistency**
10
+ - Style Consistency: Does the visual style remain consistent across frames?
11
+ - Smoothness: Are the motions and transitions smooth?
12
+
13
+ ## Scoring Instructions
14
+ 1. Assign a score from **1 to 5** for each dimension:
15
+ - **1**: Very poor quality, completely fails to meet the criteria.
16
+ - **2**: Below average, significant issues present.
17
+ - **3**: Acceptable, meets the basic criteria with minor issues.
18
+ - **4**: Good, performs well with no major issues.
19
+ - **5**: Excellent, fully meets or exceeds expectations.
20
+ 2. Provide a comprehensive evaluation for each dimension.
21
+ 3. Format your output in **JSON**
22
+
23
+ ### JSON Output Format
24
+ ```json
25
+ {{
26
+ "overall_analysis": "[Provide a general assessment of the video's quality]",
27
+ "evaluation": {{
28
+ "visual_consistency": {{
29
+ "comprehensive_evaluation": "[Analysis of visual consistency]",
30
+ "score": [1-5]
31
+ }}
32
+ }}
33
+ }}
34
+ ```
35
+
36
+ Description of the theorem:
37
+ {description}
38
+
39
+ Video chunk:"""
40
+
41
+ _text_eval_new = """You are a specialist in evaluating theorem explanation videos, known for giving clear and objective feedback. You will be given the transcript of a video. Your task is to evaluate and score the content of the video in several dimensions.
42
+
43
+ ### Task Objective
44
+ 1. Perform an overall analysis of the video.
45
+ * Identify the topic of the video.
46
+ * Note your general thoughts and impression of the video, and any findings and observations.
47
+ 2. Conduct a comprehensive evaluation and score each criterion in the given dimensions.
48
+ * Analyze how well or poorly the video meets each criterion.
49
+ * Assign a score from **1 to 5** for each dimension:
50
+ - **1**: Very poor quality, completely fails to meet the criteria.
51
+ - **2**: Below average, significant issues present.
52
+ - **3**: Acceptable, meets the basic criteria with minor issues.
53
+ - **4**: Good, performs well with no major issues.
54
+ - **5**: Excellent, fully meets or exceeds expectations.
55
+ 3. Output the results in the specified JSON format.
56
+
57
+ ### Evaluation Criteria
58
+ 1. **Accuracy and Depth**
59
+ - Does the narration explain the theorem accurately?
60
+ - Does the video provide intuitive and/or rigorous explanations for why the theorem holds?
61
+ 2. **Logical Flow**
62
+ - Does the video follow a clear and logical structure?
63
+ - Does the video present a coherent buildup of ideas?
64
+
65
+ ### Notes
66
+ * You do not have access to the visual portion of the video as you are given only the textual portion. Do not reference or commentate on the visuals as they will be evaluated separately - just assume that there are reasonable visuals (e.g., geometric objects, graphs of functions, and calculations) to accompany the narration.
67
+ * The evaluation criteria are intended to be independent of each other. Do not restate the same violation in multiple criteria; only consider it in the most relevant criterion.
68
+
69
+ ### Output Format
70
+ ```json
71
+ {{
72
+ "overall_analysis": "[Overall analysis]",
73
+ "evaluation": {{
74
+ "accuracy_and_depth": {{
75
+ "comprehensive_evaluation": "[Analysis of accuracy and depth]",
76
+ "score": [1-5]
77
+ }},
78
+ "logical_flow": {{
79
+ "comprehensive_evaluation": "[Analysis of logical flow]",
80
+ "score": [1-5]
81
+ }}
82
+ }}
83
+ }}
84
+ ```
85
+
86
+ The transcript of the video is as follows:
87
+ {transcript}
88
+ """
89
+
90
+ _fix_transcript = """You are an expert in YouTube video transcripts. There is a transcript that was automatically generated through YouTube, so it lacks proper capitalization and punctuation. Your task is to fix the transcript so that there is proper punctuation, capitalization, and spacing. Do not make other modifications (e.g., keep the original word choice).
91
+
92
+ You should enclose the fixed transcript with a <SCRIPT></SCRIPT> block, i.e.:
93
+ <SCRIPT>
94
+ (Fixed transcript here)
95
+ </SCRIPT>
96
+
97
+ Original transcript: {transcript}
98
+ """
99
+
100
+ _image_eval = """# Task: Video Frame Quality Evaluation
101
+
102
+ You are tasked with analyzing and scoring a frame taken from a theorem explanation video. Note that you may not have the context of the video, so the captured frame may be a frame where some motion of visual elements is taking place. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.
103
+
104
+ ## Evaluation Criteria
105
+
106
+ 1. **Visual Relevance**
107
+ - Does the video frame align with the theorem's concepts and derivations?
108
+
109
+ 2. **Element Layout**
110
+ - Placement and Size: Are the visual elements well-placed and appropriately sized within the frame?
111
+ - Overlap: Are the visual elements free of unintentional overlap?
112
+ - Clarity: Is the visual information conveyed in the frame clear and easy to understand?
113
+
114
+ ## Scoring Instructions
115
+ 1. Assign a score from **1 to 5** for each dimension:
116
+ - **1**: Very poor quality, completely fails to meet the criteria.
117
+ - **2**: Below average, significant issues present.
118
+ - **3**: Acceptable, meets the basic criteria with minor issues.
119
+ - **4**: Good, performs well with no major issues.
120
+ - **5**: Excellent, fully meets or exceeds expectations.
121
+ 2. Provide a comprehensive evaluation for each dimension.
122
+ 3. Format your output in **JSON**
123
+
124
+ ### JSON Output Format
125
+ ```json
126
+ {{
127
+ "overall_analysis": "[Provide a general assessment of the image's quality]",
128
+ "evaluation": {{
129
+ "visual_relevance": {{
130
+ "comprehensive_evaluation": "[Analysis of visual relevance]",
131
+ "score": [1-5]
132
+ }},
133
+ "element_layout": {{
134
+ "comprehensive_evaluation": "[Analysis of element layout]",
135
+ "score": [1-5]
136
+ }}
137
+ }}
138
+ }}
139
+ ```
140
+
141
+ Description of the theorem:
142
+ {description}
143
+
144
+ Image:"""
145
+
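
Note that the templates above escape literal JSON braces as {{ and }} so that str.format() only fills the named placeholders; a small self-contained sketch of that behaviour (the template text here is illustrative, not one of the real prompts):

```python
template = '{{"evaluation": {{"score": [1-5]}}}}\nDescription: {description}'
print(template.format(description="Bayes' theorem"))
# {"evaluation": {"score": [1-5]}}
# Description: Bayes' theorem
```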
eval_suite/prompts_raw/fix_transcript.txt ADDED
@@ -0,0 +1,8 @@
1
+ You are an expert in YouTube video transcripts. There is a transcript that was automatically generated through YouTube, so it lacks proper capitalization and punctuation. Your task is to fix the transcript so that there is proper punctuation, capitalization, and spacing. Do not make other modifications (e.g., keep the original word choice).
2
+
3
+ You should enclose the fixed transcript with a <SCRIPT></SCRIPT> block, i.e.:
4
+ <SCRIPT>
5
+ (Fixed transcript here)
6
+ </SCRIPT>
7
+
8
+ Original transcript: {transcript}
eval_suite/prompts_raw/image_eval.txt ADDED
@@ -0,0 +1,45 @@
1
+ # Task: Video Frame Quality Evaluation
2
+
3
+ You are tasked with analyzing and scoring a frame taken from a theorem explanation video. Note that you may not have the context of the video, so the captured frame may be a frame where some motion of visual elements is taking place. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.
4
+
5
+ ## Evaluation Criteria
6
+
7
+ 1. **Visual Relevance**
8
+ - Does the video frame align with the theorem's concepts and derivations?
9
+
10
+ 2. **Element Layout**
11
+ - Placement and Size: Are the visual elements well-placed and appropriately sized within the frame?
12
+ - Overlap: Are the visual elements free of unintentional overlap?
13
+ - Clarity: Is the visual information conveyed in the frame clear and easy to understand?
14
+
15
+ ## Scoring Instructions
16
+ 1. Assign a score from **1 to 5** for each dimension:
17
+ - **1**: Very poor quality, completely fails to meet the criteria.
18
+ - **2**: Below average, significant issues present.
19
+ - **3**: Acceptable, meets the basic criteria with minor issues.
20
+ - **4**: Good, performs well with no major issues.
21
+ - **5**: Excellent, fully meets or exceeds expectations.
22
+ 2. Provide a comprehensive evaluation for each dimension.
23
+ 3. Format your output in **JSON**
24
+
25
+ ### JSON Output Format
26
+ ```json
27
+ {{
28
+ "overall_analysis": "[Provide a general assessment of the image's quality]",
29
+ "evaluation": {{
30
+ "visual_relevance": {{
31
+ "comprehensive_evaluation": "[Analysis of visual relevance]",
32
+ "score": [1-5]
33
+ }},
34
+ "element_layout": {{
35
+ "comprehensive_evaluation": "[Analysis of element layout]",
36
+ "score": [1-5]
37
+ }}
38
+ }}
39
+ }}
40
+ ```
41
+
42
+ Description of the theorem:
43
+ {description}
44
+
45
+ Image:
eval_suite/prompts_raw/text_eval_new.txt ADDED
@@ -0,0 +1,47 @@
1
+ You are a specialist in evaluating theorem explanation videos, known for giving clear and objective feedback. You will be given the transcript of a video. Your task is to evaluate and score the content of the video in several dimensions.
2
+
3
+ ### Task Objective
4
+ 1. Perform an overall analysis of the video.
5
+ * Identify the topic of the video.
6
+ * Note your general thoughts and impression of the video, and any findings and observations.
7
+ 2. Conduct a comprehensive evaluation and score each criterion in the given dimensions.
8
+ * Analyze how well or poorly the video meets each criterion.
9
+ * Assign a score from **1 to 5** for each dimension:
10
+ - **1**: Very poor quality, completely fails to meet the criteria.
11
+ - **2**: Below average, significant issues present.
12
+ - **3**: Acceptable, meets the basic criteria with minor issues.
13
+ - **4**: Good, performs well with no major issues.
14
+ - **5**: Excellent, fully meets or exceeds expectations.
15
+ 3. Output the results in the specified JSON format.
16
+
17
+ ### Evaluation Criteria
18
+ 1. **Accuracy and Depth**
19
+ - Does the narration explain the theorem accurately?
20
+ - Does the video provide intuitive and/or rigorous explanations for why the theorem holds?
21
+ 2. **Logical Flow**
22
+ - Does the video follow a clear and logical structure?
23
+ - Does the video present a coherent buildup of ideas?
24
+
25
+ ### Notes
26
+ * You do not have access to the visual portion of the video as you are given only the textual portion. Do not reference or commentate on the visuals as they will be evaluated separately - just assume that there are reasonable visuals (e.g., geometric objects, graphs of functions, and calculations) to accompany the narration.
27
+ * The evaluation criteria are intended to be independent of each other. Do not restate the same violation in multiple criteria; only consider it in the most relevant criterion.
28
+
29
+ ### Output Format
30
+ ```json
31
+ {{
32
+ "overall_analysis": "[Overall analysis]",
33
+ "evaluation": {{
34
+ "accuracy_and_depth": {{
35
+ "comprehensive_evaluation": "[Analysis of accuracy and depth]",
36
+ "score": [1-5]
37
+ }},
38
+ "logical_flow": {{
39
+ "comprehensive_evaluation": "[Analysis of logical flow]",
40
+ "score": [1-5]
41
+ }}
42
+ }}
43
+ }}
44
+ ```
45
+
46
+ The transcript of the video is as follows:
47
+ {transcript}
eval_suite/prompts_raw/video_eval_new.txt ADDED
@@ -0,0 +1,37 @@
1
+ # Task: Video Frame Quality Evaluation
2
+
3
+ You are tasked with analyzing and scoring a chunk of a theorem explanation video. Note that you may not have the full context of the video. Your job is to assign a score from 1 to 5 for each criterion. Please provide a brief justification for your scores.
4
+
5
+ ## Evaluation Criteria
6
+
7
+ 1. **Visual Consistency**
8
+ - Style Consistency: Does the visual style remain consistent across frames?
9
+ - Smoothness: Are the motions and transitions smooth?
10
+
11
+ ## Scoring Instructions
12
+ 1. Assign a score from **1 to 5** for each dimension:
13
+ - **1**: Very poor quality, completely fails to meet the criteria.
14
+ - **2**: Below average, significant issues present.
15
+ - **3**: Acceptable, meets the basic criteria with minor issues.
16
+ - **4**: Good, performs well with no major issues.
17
+ - **5**: Excellent, fully meets or exceeds expectations.
18
+ 2. Provide a comprehensive evaluation for each dimension.
19
+ 3. Format your output in **JSON**
20
+
21
+ ### JSON Output Format
22
+ ```json
23
+ {{
24
+ "overall_analysis": "[Provide a general assessment of the video's quality]",
25
+ "evaluation": {{
26
+ "visual_consistency": {{
27
+ "comprehensive_evaluation": "[Analysis of visual consistency]",
28
+ "score": [1-5]
29
+ }}
30
+ }}
31
+ }}
32
+ ```
33
+
34
+ Description of the theorem:
35
+ {description}
36
+
37
+ Video chunk:
eval_suite/text_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import Union
2
+
3
+ import pysrt
4
+
5
+ from mllm_tools.litellm import LiteLLMWrapper
6
+ from mllm_tools.gemini import GeminiWrapper
7
+ from mllm_tools.utils import _prepare_text_inputs
8
+ from eval_suite.prompts_raw import _fix_transcript, _text_eval_new
9
+ from eval_suite.utils import extract_json, convert_score_fields
10
+
11
+
12
+ def parse_srt_to_text(srt_path) -> str:
13
+ """
14
+ Parse an SRT subtitle file into plain text.
15
+
16
+ Args:
17
+ srt_path: Path to the SRT subtitle file.
18
+
19
+ Returns:
20
+ str: The subtitle text with duplicates removed and ellipses replaced.
21
+ """
22
+ subs = pysrt.open(srt_path)
23
+ full_text = []
24
+ for sub in subs:
25
+ sub.text = sub.text.replace("...", ".")
26
+ for line in sub.text.splitlines():
27
+ # .srt can contain repeated lines
28
+ if full_text and full_text[-1] == line:
29
+ continue
30
+ full_text.append(line)
31
+ return "\n".join(full_text)
32
+
33
+
34
+ def fix_transcript(text_eval_model: Union[LiteLLMWrapper, GeminiWrapper], transcript: str) -> str:
35
+ """
36
+ Fix and clean up a transcript using an LLM model.
37
+
38
+ Args:
39
+ text_eval_model: The LLM model wrapper to use for fixing the transcript.
40
+ transcript: The input transcript text to fix.
41
+
42
+ Returns:
43
+ str: The fixed and cleaned transcript text.
44
+ """
45
+ print("Fixing transcript...")
46
+
47
+ prompt = _fix_transcript.format(transcript=transcript)
48
+ response = text_eval_model(_prepare_text_inputs(prompt))
49
+ fixed_script = response.split("<SCRIPT>", maxsplit=1)[1].split("</SCRIPT>")[0]
50
+
51
+ return fixed_script
52
+
53
+
54
+ def evaluate_text(text_eval_model: LiteLLMWrapper, transcript: str, retry_limit: int) -> dict:
55
+ """
56
+ Evaluate transcript text using an LLM model with retry logic.
57
+
58
+ Args:
59
+ text_eval_model: The LLM model wrapper to use for evaluation.
60
+ transcript: The transcript text to evaluate.
61
+ retry_limit: Maximum number of retry attempts on failure.
62
+
63
+ Returns:
64
+ dict: The evaluation results as a JSON object.
65
+
66
+ Raises:
67
+ ValueError: If all retry attempts fail.
68
+ """
69
+ # prompt = _text_eval.format(transcript=transcript)
70
+ prompt = _text_eval_new.format(transcript=transcript)
71
+ for attempt in range(retry_limit):
72
+ try:
73
+ evaluation = text_eval_model(_prepare_text_inputs(prompt))
74
+ evaluation_json = extract_json(evaluation)
75
+ evaluation_json = convert_score_fields(evaluation_json)
76
+ return evaluation_json
77
+ except Exception as e:
78
+ print(f"Attempt {attempt + 1} failed: {e.__class__.__name__}: {e}")
79
+ if attempt + 1 == retry_limit:
80
+ raise ValueError("Reached maximum retry limit. Evaluation failed.") from None
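
A hypothetical end-to-end use of the helpers above, from raw subtitles to scores; the .srt path and model name are placeholders.

```python
from mllm_tools.litellm import LiteLLMWrapper
from eval_suite.text_utils import parse_srt_to_text, fix_transcript, evaluate_text

model = LiteLLMWrapper(model_name="gpt-4o")

raw_text = parse_srt_to_text("output/theorem_video.srt")   # de-duplicated subtitle text
clean_text = fix_transcript(model, raw_text)               # punctuation and capitalization restored
scores = evaluate_text(model, clean_text, retry_limit=3)   # JSON scores for the text criteria
print(scores["evaluation"]["accuracy_and_depth"]["score"])
```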
eval_suite/utils.py ADDED
@@ -0,0 +1,81 @@
1
+ import json
2
+ import re
3
+ from math import prod
4
+ from typing import List
5
+
6
+ def extract_json(response: str) -> dict:
7
+ """
8
+ Extract JSON content from a string response.
9
+
10
+ Args:
11
+ response (str): String containing JSON content, possibly within code blocks.
12
+
13
+ Returns:
14
+ dict: Extracted and parsed JSON content.
15
+
16
+ Raises:
17
+ ValueError: If no valid JSON content could be extracted.
18
+ """
19
+ try:
20
+ evaluation_json = json.loads(response)
21
+ except json.JSONDecodeError:
22
+ # If JSON parsing fails, try to extract the content between ```json and ```
23
+ match = re.search(r'```json\n(.*?)\n```', response, re.DOTALL)
24
+ if not match:
25
+ # If no match for ```json, try to extract content between ``` and ```
26
+ match = re.search(r'```\n(.*?)\n```', response, re.DOTALL)
27
+
28
+ if match:
29
+ evaluation_content = match.group(1)
30
+ evaluation_json = json.loads(evaluation_content)
31
+ else:
32
+ raise ValueError("Failed to extract valid JSON content")
33
+ return evaluation_json
34
+
35
+
36
+ def convert_score_fields(data: dict) -> dict:
37
+ """
38
+ Convert score fields in a dictionary to integers recursively.
39
+
40
+ Args:
41
+ data (dict): Dictionary containing score fields to convert.
42
+
43
+ Returns:
44
+ dict: Dictionary with score fields converted to integers.
45
+
46
+ Raises:
47
+ ValueError: If a score value cannot be converted to integer.
48
+ """
49
+ # Create a new dictionary with the converted values
50
+ converted_data = {}
51
+ for key, value in data.items():
52
+ if key == "score":
53
+ if isinstance(value, int):
54
+ converted_data[key] = value
55
+ elif isinstance(value, str) and value.isdigit():
56
+ converted_data[key] = int(value)
57
+ else:
58
+ raise ValueError(f"Invalid score value: {value!r}")
59
+ elif isinstance(value, dict):
60
+ converted_data[key] = convert_score_fields(value)
61
+ else:
62
+ converted_data[key] = value
63
+ return converted_data
64
+
65
+
66
+ def calculate_geometric_mean(scores: List[int]) -> float:
67
+ """
68
+ Calculate the geometric mean of a list of scores.
69
+
70
+ Args:
71
+ scores (List[int]): List of integer scores, may contain None values.
72
+
73
+ Returns:
74
+ float: Geometric mean of non-None scores. Returns 0.0 if list is empty
75
+ or contains only None values.
76
+ """
77
+ scores = [s for s in scores if s is not None]
78
+ if not scores:
79
+ return 0.0
80
+ product = prod(scores)
81
+ return product ** (1 / len(scores))
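
A quick illustration of the three helpers above on a hand-written model response; the values are made up.

```python
from eval_suite.utils import extract_json, convert_score_fields, calculate_geometric_mean

raw = '```json\n{"evaluation": {"visual_relevance": {"score": "4"}}}\n```'
parsed = convert_score_fields(extract_json(raw))
print(parsed)                                  # {'evaluation': {'visual_relevance': {'score': 4}}}
print(calculate_geometric_mean([4, 2, None]))  # sqrt(4 * 2) ≈ 2.83; None entries are dropped
```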
eval_suite/video_utils.py ADDED
@@ -0,0 +1,167 @@
1
+ import os
2
+ import cv2
3
+ import tempfile
4
+
5
+ from dotenv import load_dotenv
6
+
7
+ from mllm_tools.utils import _prepare_text_video_inputs
8
+ from eval_suite.prompts_raw import _video_eval_new
9
+ from eval_suite.utils import extract_json, convert_score_fields
10
+
11
+ load_dotenv()
12
+
13
+
14
+ def reduce_video_framerate(input_path, target_fps=1, output_path=None):
15
+ """
16
+ Reduces the frame rate of a video by only keeping frames at the target interval.
17
+
18
+ Args:
19
+ input_path (str): Path to the input video
20
+ target_fps (int): Target frames per second (default: 1)
21
+ output_path (str, optional): Path to save the processed video. If None, uses a temporary file.
22
+
23
+ Returns:
24
+ str: Path to the processed video
25
+
26
+ Raises:
27
+ ValueError: If input video cannot be opened or has invalid FPS
28
+ RuntimeError: If video writer initialization fails or output video creation fails
29
+ """
30
+ cap = cv2.VideoCapture(input_path)
31
+ if not cap.isOpened():
32
+ raise ValueError(f"Could not open input video: {input_path}")
33
+
34
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
35
+ if original_fps <= 0:
36
+ raise ValueError(f"Invalid FPS ({original_fps}) detected in input video")
37
+
38
+ frame_interval = max(1, int(original_fps / target_fps))  # avoid a zero interval when original_fps < target_fps
39
+
40
+ # Get video properties
41
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
42
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
43
+
44
+ # Use provided output path or create temporary file
45
+ if output_path is None:
46
+ temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
47
+ output_path = temp_output.name
48
+
49
+ # Ensure output directory exists
50
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
51
+
52
+ # Try different codecs in order of preference
53
+ codecs = [
54
+ ('avc1', '.mp4'), # H.264 codec
55
+ ('mp4v', '.mp4'), # MP4V codec
56
+ ('XVID', '.avi'), # XVID codec
57
+ ('MJPG', '.avi'), # Motion JPEG codec
58
+ ]
59
+
60
+ success = False
61
+ for codec, ext in codecs:
62
+ if output_path.endswith('.mp4') and not ext.endswith('.mp4'):
63
+ # If we're switching to AVI format, change the extension
64
+ output_path = output_path[:-4] + ext
65
+
66
+ fourcc = cv2.VideoWriter_fourcc(*codec)
67
+ out = cv2.VideoWriter(output_path, fourcc, target_fps, (width, height))
68
+
69
+ if out.isOpened():
70
+ success = True
71
+ print(f"Successfully initialized video writer with codec: {codec}")
72
+ break
73
+ else:
74
+ out.release()
75
+ if os.path.exists(output_path):
76
+ os.remove(output_path)
77
+
78
+ if not success:
79
+ raise RuntimeError("Could not initialize video writer with any available codec")
80
+
81
+ frame_count = 0
82
+ frames_written = 0
83
+ while cap.isOpened():
84
+ ret, frame = cap.read()
85
+ if not ret:
86
+ break
87
+
88
+ # Only write frames at the specified interval
89
+ if frame_count % frame_interval == 0:
90
+ out.write(frame)
91
+ frames_written += 1
92
+ frame_count += 1
93
+
94
+ cap.release()
95
+ out.release()
96
+
97
+ # Verify the output
98
+ verify_cap = cv2.VideoCapture(output_path)
99
+ if not verify_cap.isOpened():
100
+ raise RuntimeError(f"Failed to create output video at {output_path}")
101
+
102
+ actual_fps = verify_cap.get(cv2.CAP_PROP_FPS)
103
+ total_frames = verify_cap.get(cv2.CAP_PROP_FRAME_COUNT)
104
+ verify_cap.release()
105
+
106
+ if actual_fps <= 0:
107
+ print("Warning: Output video reports invalid FPS. This might be a codec issue.")
108
+ actual_fps = target_fps # Use target FPS for duration calculation
109
+
110
+ print(f"Created video with {frames_written} frames at {actual_fps} FPS")
111
+ print(f"Total duration: {total_frames/actual_fps:.2f} seconds")
112
+ print(f"Video saved to: {output_path}")
113
+
114
+ return output_path
115
+
116
+
117
+ def evaluate_video_chunk_new(model, video_path, transcript="No transcript provided", description="No description provided",
118
+ save_processed_video=None, target_fps=None, retry_limit=5):
119
+ """
120
+ Evaluate a single video chunk using a multimodal model.
121
+
122
+ Args:
123
+ model: The multimodal model to use for evaluation
124
+ video_path (str): Path to the video file to evaluate
125
+ transcript (str, optional): Video transcript text. Defaults to "No transcript provided"
126
+ description (str, optional): Video description text. Defaults to "No description provided"
127
+ save_processed_video (str, optional): Path to save processed video. If None, uses temporary file
128
+ target_fps (int, optional): Target frames per second for video processing. If None, no processing
129
+ retry_limit (int, optional): Maximum number of retry attempts. Defaults to 5
130
+
131
+ Returns:
132
+ dict: Evaluation results as a JSON object with scores converted to integers
133
+
134
+ Raises:
135
+ FileNotFoundError: If video file does not exist
136
+ Exception: If evaluation fails after all retry attempts
137
+ """
138
+ if not os.path.exists(video_path):
139
+ raise FileNotFoundError(f"Video file not found: {video_path}")
140
+
141
+ # Only process video if target_fps is specified
142
+ if target_fps is not None:
143
+ processed_video_path = reduce_video_framerate(video_path, target_fps=target_fps, output_path=save_processed_video)
144
+ video_to_use = processed_video_path
145
+ else:
146
+ video_to_use = video_path
147
+
148
+ prompt = _video_eval_new.format(description=description)
149
+ inputs = _prepare_text_video_inputs(prompt, video_to_use)
150
+
151
+ try:
152
+ for attempt in range(retry_limit):
153
+ try:
154
+ response = model(inputs)
155
+ response_json = extract_json(response)
156
+ response_json = convert_score_fields(response_json)
157
+
158
+ return response_json
159
+ except Exception as e:
160
+ print(f"Attempt {attempt + 1} failed: {e}")
161
+ if attempt + 1 == retry_limit:
162
+ print("Reached maximum retry limit. Evaluation failed.")
163
+ raise
164
+ finally:
165
+ # Clean up the temporary processed video if we created one
166
+ if target_fps is not None and save_processed_video is None and os.path.exists(processed_video_path):
167
+ os.unlink(processed_video_path)
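
A minimal sketch of scoring one rendered chunk with the function above, assuming a video-capable model wrapper; the file path is a placeholder.

```python
from mllm_tools.gemini import GeminiWrapper
from eval_suite.video_utils import evaluate_video_chunk_new

model = GeminiWrapper(model_name="gemini-1.5-pro-002")
result = evaluate_video_chunk_new(
    model,
    "output/scene1/chunk_000.mp4",
    description="Pythagorean theorem",
    target_fps=1,       # downsample before upload; omit (None) to send the original file
    retry_limit=3,
)
print(result["evaluation"]["visual_consistency"]["score"])
```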
mllm_tools/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Empty file to make this directory a Python package
mllm_tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (154 Bytes).
mllm_tools/__pycache__/gemini.cpython-312.pyc ADDED
Binary file (8.04 kB).
mllm_tools/__pycache__/litellm.cpython-312.pyc ADDED
Binary file (7.58 kB).
mllm_tools/__pycache__/openai.cpython-312.pyc ADDED
Binary file (21.9 kB).
mllm_tools/__pycache__/openrouter.cpython-312.pyc ADDED
Binary file (11.1 kB).
mllm_tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (7.32 kB).
mllm_tools/__pycache__/vertex_ai.cpython-312.pyc ADDED
Binary file (3.64 kB).
mllm_tools/gemini.py ADDED
@@ -0,0 +1,176 @@
1
+ from typing import List, Dict, Any, Union, Optional
2
+ import io
3
+ import os
4
+ import base64
5
+ from PIL import Image
6
+ import mimetypes
7
+ import google.generativeai as genai
8
+ import tempfile
9
+ import time
10
+ from urllib.parse import urlparse
11
+ import requests
12
+ from io import BytesIO
13
+
14
+ class GeminiWrapper:
15
+ """Wrapper for Gemini to support multiple models and logging"""
16
+
17
+ def __init__(
18
+ self,
19
+ model_name: str = "gemini-1.5-pro-002",
20
+ temperature: float = 0.7,
21
+ print_cost: bool = False,
22
+ verbose: bool = False,
23
+ use_langfuse: bool = False
24
+ ):
25
+ """
26
+ Initialize the Gemini wrapper
27
+
28
+ Args:
29
+ model_name: Name of the model to use
30
+ temperature: Temperature for completion
31
+ print_cost: Whether to print the cost of the completion
32
+ verbose: Whether to print verbose output
33
+ use_langfuse: Whether to enable Langfuse logging
34
+ """
35
+ self.model_name = model_name.split('/')[-1] if '/' in model_name else model_name
36
+ self.temperature = temperature
37
+ self.print_cost = print_cost
38
+ self.verbose = verbose
39
+ self.accumulated_cost = 0
40
+
41
+ api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
42
+ if not api_key:
43
+ raise ValueError("No API_KEY found. Please set the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable.")
44
+ genai.configure(api_key=api_key)
45
+
46
+ generation_config = {
47
+ "temperature": self.temperature,
48
+ "top_p": 0.95,
49
+ "response_mime_type": "text/plain",
50
+ }
51
+ safety_settings = [
52
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
53
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
54
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
55
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
56
+ ]
57
+ self.model = genai.GenerativeModel(
58
+ model_name=self.model_name,
59
+ safety_settings=safety_settings,
60
+ generation_config=generation_config,
61
+ )
62
+
63
+ def _get_mime_type(self, file_path: str) -> str:
64
+ """
65
+ Get the MIME type of a file based on its extension
66
+
67
+ Args:
68
+ file_path: Path to the file
69
+
70
+ Returns:
71
+ MIME type as a string (e.g., "image/jpeg", "audio/mp3")
72
+ """
73
+ mime_type, _ = mimetypes.guess_type(file_path)
74
+ if mime_type is None:
75
+ raise ValueError(f"Unsupported file type: {file_path}")
76
+ return mime_type
77
+
78
+ def _download_file(self, url: str) -> str:
79
+ """
80
+ Download a file from a URL and save it as a temporary file
81
+
82
+ Args:
83
+ url: URL of the file to download
84
+
85
+ Returns:
86
+ Path to the temporary file
87
+ """
88
+ response = requests.get(url)
89
+ if response.status_code == 200:
90
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
91
+ temp_file.write(response.content)
92
+ temp_file.close()
93
+ return temp_file.name
94
+ else:
95
+ raise ValueError(f"Failed to download file from URL: {url}")
96
+
97
+ def _save_image_to_temp(self, image: Image.Image) -> str:
98
+ """
99
+ Save a PIL Image to a temporary file
100
+
101
+ Args:
102
+ image: PIL Image object
103
+
104
+ Returns:
105
+ Path to the temporary file
106
+ """
107
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
108
+ image.save(temp_file, format="PNG")
109
+ temp_file.close()
110
+ return temp_file.name
111
+
112
+ def _upload_to_gemini(self, file_path: str, mime_type: Optional[str] = None):
113
+ """
114
+ Uploads the given file to Gemini.
115
+
116
+ Args:
117
+ file_path: Path to the file
118
+ mime_type: MIME type of the file
119
+
120
+ Returns:
121
+ Uploaded file object
122
+ """
123
+ return genai.upload_file(file_path, mime_type=mime_type)
124
+
125
+ def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
126
+ """
127
+ Process messages and return completion
128
+
129
+ Args:
130
+ messages: List of message dictionaries with 'type' and 'content' keys
131
+ metadata: Optional metadata to pass to Gemini completion
132
+
133
+ Returns:
134
+ Generated text response
135
+ """
136
+ contents = []
137
+ for msg in messages:
138
+ if msg["type"] == "text":
139
+ contents.append(msg["content"])
140
+ elif msg["type"] in ["image", "audio", "video"]:
141
+ if isinstance(msg["content"], Image.Image):
142
+ file_path = self._save_image_to_temp(msg["content"])
143
+ mime_type = "image/png"
144
+ elif isinstance(msg["content"], str):
145
+ if msg["content"].startswith("http"):
146
+ file_path = self._download_file(msg["content"])
147
+ mime_type = self._get_mime_type(msg["content"])
148
+ else:
149
+ file_path = msg["content"]
150
+ mime_type = self._get_mime_type(file_path)
151
+ else:
152
+ raise ValueError("Unsupported content type")
153
+
154
+ uploaded_file = self._upload_to_gemini(file_path, mime_type)
155
+
156
+ while uploaded_file.state.name == "PROCESSING":
157
+ print('.', end='')
158
+ time.sleep(3)
159
+ uploaded_file = genai.get_file(uploaded_file.name)
160
+ if uploaded_file.state.name == "FAILED":
161
+ raise ValueError(uploaded_file.state.name)
162
+ print("Upload successful")
163
+ contents.append(uploaded_file)
164
+ else:
165
+ raise ValueError("Unsupported message type")
166
+
167
+ response = self.model.generate_content(contents, request_options={"timeout": 600})
168
+ try:
169
+ return response.text
170
+ except Exception as e:
171
+ print(e)
172
+ print(response.prompt_feedback)
173
+ return str(response.prompt_feedback)
174
+
175
+ if __name__ == "__main__":
176
+ pass
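
For reference, the message format GeminiWrapper.__call__ expects; a sketch that assumes GEMINI_API_KEY or GOOGLE_API_KEY is set and that the image path exists.

```python
from mllm_tools.gemini import GeminiWrapper

model = GeminiWrapper(model_name="gemini-1.5-pro-002", temperature=0.7)
response = model([
    {"type": "text", "content": "Describe what this frame shows."},
    {"type": "image", "content": "frames/key_frame_1.jpg"},  # local path, URL, or PIL.Image
])
print(response)
```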
mllm_tools/github.py ADDED
@@ -0,0 +1,305 @@
1
+ # filepath: d:\Theory2Manim-2\Theory2Manim\mllm_tools\github.py
2
+ import json
3
+ import re
4
+ from typing import List, Dict, Any, Union, Optional
5
+ import io
6
+ import os
7
+ import base64
8
+ from PIL import Image
9
+ import mimetypes
10
+ import litellm
11
+ from litellm import completion, completion_cost
12
+ from dotenv import load_dotenv
13
+
14
+ load_dotenv()
15
+
16
+ class GitHubModelsWrapper:
17
+ """Wrapper for GitHub Models using LiteLLM to support multiple GitHub hosted models"""
18
+
19
+ def __init__(
20
+ self,
21
+ model_name: str = "github/gpt-4o",
22
+ temperature: float = 0.7,
23
+ print_cost: bool = False,
24
+ verbose: bool = False,
25
+ use_langfuse: bool = True,
26
+ github_token: Optional[str] = None
27
+ ):
28
+ """
29
+ Initialize the GitHub Models wrapper
30
+
31
+ Args:
32
+ model_name: Name of the GitHub model to use (e.g. "github/gpt-4o", "github/gpt-4o-mini",
33
+ "github/o1-preview", "github/claude-3-5-sonnet", "github/phi-3.5-mini-instruct")
34
+ temperature: Temperature for completion
35
+ print_cost: Whether to print the cost of the completion
36
+ verbose: Whether to print verbose output
37
+ use_langfuse: Whether to enable Langfuse logging
38
+ github_token: GitHub token for authentication (if not provided, will use GITHUB_TOKEN env var)
39
+ """
40
+ self.model_name = model_name
41
+ self.temperature = temperature
42
+ self.print_cost = print_cost
43
+ self.verbose = verbose
44
+ self.accumulated_cost = 0
45
+
46
+ # Set up GitHub token
47
+ self.github_token = github_token or os.getenv('GITHUB_TOKEN')
48
+ if not self.github_token:
49
+ raise ValueError("GitHub token is required. Please set GITHUB_TOKEN environment variable or pass github_token parameter.")
50
+
51
+ # Set environment variable for LiteLLM
52
+ os.environ['GITHUB_TOKEN'] = self.github_token
53
+
54
+ if self.verbose:
55
+ os.environ['LITELLM_LOG'] = 'DEBUG'
56
+
57
+ # Set langfuse callback only if enabled
58
+ if use_langfuse:
59
+ litellm.success_callback = ["langfuse"]
60
+ litellm.failure_callback = ["langfuse"]
61
+
62
+ def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
63
+ """
64
+ Encode local file or PIL Image to base64 string
65
+
66
+ Args:
67
+ file_path: Path to local file or PIL Image object
68
+
69
+ Returns:
70
+ Base64 encoded file string
71
+ """
72
+ if isinstance(file_path, Image.Image):
73
+ buffered = io.BytesIO()
74
+ file_path.save(buffered, format="PNG")
75
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
76
+ else:
77
+ with open(file_path, "rb") as file:
78
+ return base64.b64encode(file.read()).decode("utf-8")
79
+
80
+ def _get_mime_type(self, file_path: str) -> str:
81
+ """
82
+ Get the MIME type of a file based on its extension
83
+
84
+ Args:
85
+ file_path: Path to the file
86
+
87
+ Returns:
88
+ MIME type as a string (e.g., "image/jpeg", "audio/mp3")
89
+ """
90
+ mime_type, _ = mimetypes.guess_type(file_path)
91
+ if mime_type is None:
92
+ raise ValueError(f"Unsupported file type: {file_path}")
93
+ return mime_type
94
+
95
+ def _supports_vision(self, model_name: str) -> bool:
96
+ """
97
+ Check if the model supports vision/image processing
98
+
99
+ Args:
100
+ model_name: Name of the model
101
+
102
+ Returns:
103
+ True if model supports vision, False otherwise
104
+ """
105
+ vision_models = [
106
+ "gpt-4o",
107
+ "gpt-4o-mini",
108
+ "claude-3-5-sonnet",
109
+ "claude-3-haiku"
110
+ ]
111
+
112
+ # Extract model name without the github/ prefix
113
+ clean_model_name = model_name.replace("github/", "")
114
+ return any(vision_model in clean_model_name for vision_model in vision_models)
115
+
116
+ def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
117
+ """
118
+ Process messages and return completion
119
+
120
+ Args:
121
+ messages: List of message dictionaries with 'type' and 'content' keys
122
+ metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking
123
+
124
+ Returns:
125
+ Generated text response
126
+ """
127
+ if metadata is None:
128
+ metadata = {}
129
+ metadata["trace_name"] = f"github-models-completion-{self.model_name}"
130
+
131
+ # Convert messages to LiteLLM format
132
+ formatted_messages = []
133
+
134
+ for msg in messages:
135
+ if msg["type"] == "text":
136
+ formatted_messages.append({
137
+ "role": "user",
138
+ "content": [{"type": "text", "text": msg["content"]}]
139
+ })
140
+ elif msg["type"] == "image":
141
+ # Check if model supports vision
142
+ if not self._supports_vision(self.model_name):
143
+ raise ValueError(f"Model {self.model_name} does not support image processing")
144
+
145
+ # Check if content is a local file path or PIL Image
146
+ if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
147
+ try:
148
+ if isinstance(msg["content"], Image.Image):
149
+ mime_type = "image/png"
150
+ else:
151
+ mime_type = self._get_mime_type(msg["content"])
152
+ base64_data = self._encode_file(msg["content"])
153
+ data_url = f"data:{mime_type};base64,{base64_data}"
154
+ except ValueError as e:
155
+ print(f"Error processing file {msg['content']}: {e}")
156
+ continue
157
+ else:
158
+ data_url = msg["content"]
159
+
160
+ # Format for vision-capable models
161
+ formatted_messages.append({
162
+ "role": "user",
163
+ "content": [
164
+ {
165
+ "type": "image_url",
166
+ "image_url": {
167
+ "url": data_url,
168
+ "detail": "high"
169
+ }
170
+ }
171
+ ]
172
+ })
173
+ else:
174
+ raise ValueError(f"Unsupported message type: {msg['type']}. GitHub models currently support 'text' and 'image' types.")
175
+
176
+ try:
177
+ # Check if it's an o-series model (like o1-preview, o1-mini)
178
+ if (re.match(r".*o1.*", self.model_name)):
179
+ # O-series models don't support temperature and have reasoning_effort
180
+ response = completion(
181
+ model=self.model_name,
182
+ messages=formatted_messages,
183
+ reasoning_effort="medium", # Options: "low", "medium", "high"
184
+ metadata=metadata,
185
+ max_retries=3
186
+ )
187
+ else:
188
+ response = completion(
189
+ model=self.model_name,
190
+ messages=formatted_messages,
191
+ temperature=self.temperature,
192
+ metadata=metadata,
193
+ max_retries=3
194
+ )
195
+
196
+ if self.print_cost:
197
+ try:
198
+ # Note: GitHub Models may not provide cost information
199
+ cost = completion_cost(completion_response=response)
200
+ if cost is not None:
201
+ self.accumulated_cost += cost
202
+ print(f"Cost: ${float(cost):.10f}")
203
+ print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
204
+ else:
205
+ print("Cost information not available for GitHub Models")
206
+ except Exception as e:
207
+ print(f"Could not calculate cost: {e}")
208
+
209
+ content = response.choices[0].message.content
210
+ if content is None:
211
+ print(f"Got null response from GitHub model. Full response: {response}")
212
+ return ""
213
+ return content
214
+
215
+ except Exception as e:
216
+ print(f"Error in GitHub model completion: {e}")
217
+ return str(e)
218
+
219
+ def create_github_model_wrapper(model_name: str = "github/gpt-4o", **kwargs) -> GitHubModelsWrapper:
220
+ """
221
+ Convenience function to create a GitHub Models wrapper
222
+
223
+ Args:
224
+ model_name: GitHub model name (e.g., "github/gpt-4o", "github/claude-3-5-sonnet")
225
+ **kwargs: Additional arguments passed to GitHubModelsWrapper
226
+
227
+ Returns:
228
+ Configured GitHubModelsWrapper instance
229
+
230
+ Example:
231
+ >>> # Create a wrapper for GPT-4o
232
+ >>> wrapper = create_github_model_wrapper("github/gpt-4o", temperature=0.3)
233
+ >>>
234
+ >>> # Use it for text generation
235
+ >>> response = wrapper([{"type": "text", "content": "Explain quantum computing"}])
236
+ >>>
237
+ >>> # Use it for vision (if model supports it)
238
+ >>> response = wrapper([
239
+ ... {"type": "text", "content": "What's in this image?"},
240
+ ... {"type": "image", "content": "path/to/image.jpg"}
241
+ ... ])
242
+ """
243
+ return GitHubModelsWrapper(model_name=model_name, **kwargs)
244
+
245
+ # Available GitHub Models (as of the documentation)
246
+ AVAILABLE_MODELS = {
247
+ # GPT Models
248
+ "gpt-4o": "github/gpt-4o",
249
+ "gpt-4o-mini": "github/gpt-4o-mini",
250
+ "o1-preview": "github/o1-preview",
251
+ "o1-mini": "github/o1-mini",
252
+ "gpt-4.1": "github/gpt-4.1",
253
+
254
+
255
+ # Phi Models
256
+ "phi-3-5-mini-instruct": "github/phi-3.5-mini-instruct",
257
+ "phi-3-5-moe-instruct": "github/phi-3.5-moe-instruct",
258
+
259
+ # Llama Models
260
+ "llama-3.1-405b-instruct": "github/llama-3.1-405b-instruct",
261
+ "llama-3.1-70b-instruct": "github/llama-3.1-70b-instruct",
262
+ "llama-3.1-8b-instruct": "github/llama-3.1-8b-instruct",
263
+
264
+ # Mistral Models
265
+ "mistral-large": "github/mistral-large",
266
+ "mistral-large-2407": "github/mistral-large-2407",
267
+ "mistral-nemo": "github/mistral-nemo",
268
+ "mistral-small": "github/mistral-small",
269
+
270
+ # Cohere Models
271
+ "cohere-command-r": "github/cohere-command-r",
272
+ "cohere-command-r-plus": "github/cohere-command-r-plus",
273
+
274
+ # AI21 Models
275
+ "ai21-jamba-1.5-large": "github/ai21-jamba-1.5-large",
276
+ "ai21-jamba-1.5-mini": "github/ai21-jamba-1.5-mini"
277
+ }
278
+
279
+ def list_available_models() -> Dict[str, str]:
280
+ """
281
+ Get a dictionary of available GitHub models
282
+
283
+ Returns:
284
+ Dictionary mapping friendly names to full model names
285
+ """
286
+ return AVAILABLE_MODELS.copy()
287
+
288
+ if __name__ == "__main__":
289
+ # Example usage
290
+ print("Available GitHub Models:")
291
+ for friendly_name, full_name in AVAILABLE_MODELS.items():
292
+ print(f" {friendly_name}: {full_name}")
293
+
294
+ # Example of creating a wrapper (requires GITHUB_TOKEN environment variable)
295
+ try:
296
+ wrapper = create_github_model_wrapper("github/gpt-4o-mini", temperature=0.3)
297
+ print("\nGitHub Models wrapper created successfully!")
298
+
299
+ # Test with a simple text prompt
300
+ response = wrapper([{"type": "text", "content": "Hello! Can you confirm you're working?"}])
301
+ print(f"Response: {response}")
302
+
303
+ except Exception as e:
304
+ print(f"Error creating wrapper: {e}")
305
+ print("Make sure to set GITHUB_TOKEN environment variable")
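
A hedged sketch of picking a model from AVAILABLE_MODELS and guarding image input with the vision check above; it assumes a valid GITHUB_TOKEN and that the image path exists.

```python
from mllm_tools.github import GitHubModelsWrapper, list_available_models

models = list_available_models()
wrapper = GitHubModelsWrapper(model_name=models["gpt-4o-mini"], temperature=0.3)

if wrapper._supports_vision(wrapper.model_name):
    reply = wrapper([
        {"type": "text", "content": "Summarise this diagram."},
        {"type": "image", "content": "frames/key_frame_1.jpg"},
    ])
else:
    reply = wrapper([{"type": "text", "content": "Summarise the Pythagorean theorem."}])
print(reply)
```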
mllm_tools/litellm.py ADDED
@@ -0,0 +1,193 @@
1
+ import json
2
+ import re
3
+ from typing import List, Dict, Any, Union, Optional
4
+ import io
5
+ import os
6
+ import base64
7
+ from PIL import Image
8
+ import mimetypes
9
+ import litellm
10
+ from litellm import completion, completion_cost
11
+ from dotenv import load_dotenv
12
+
13
+ load_dotenv()
14
+
15
+ class LiteLLMWrapper:
16
+ """Wrapper for LiteLLM to support multiple models and logging"""
17
+
18
+ def __init__(
19
+ self,
20
+ model_name: str = "gpt-4-vision-preview",
21
+ temperature: float = 0.7,
22
+ print_cost: bool = False,
23
+ verbose: bool = False,
24
+ use_langfuse: bool = True,
25
+ ):
26
+ """
27
+ Initialize the LiteLLM wrapper
28
+
29
+ Args:
30
+ model_name: Name of the model to use (e.g. "azure/gpt-4", "vertex_ai/gemini-pro")
31
+ temperature: Temperature for completion
32
+ print_cost: Whether to print the cost of the completion
33
+ verbose: Whether to print verbose output
34
+ use_langfuse: Whether to enable Langfuse logging
35
+ """
36
+ self.model_name = model_name
37
+ self.temperature = temperature
38
+ self.print_cost = print_cost
39
+ self.verbose = verbose
40
+ self.accumulated_cost = 0
41
+
42
+ if self.verbose:
43
+ os.environ['LITELLM_LOG'] = 'DEBUG'
44
+
45
+ # Set langfuse callback only if enabled
46
+ if use_langfuse:
47
+ litellm.success_callback = ["langfuse"]
48
+ litellm.failure_callback = ["langfuse"]
49
+
50
+ def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
51
+ """
52
+ Encode local file or PIL Image to base64 string
53
+
54
+ Args:
55
+ file_path: Path to local file or PIL Image object
56
+
57
+ Returns:
58
+ Base64 encoded file string
59
+ """
60
+ if isinstance(file_path, Image.Image):
61
+ buffered = io.BytesIO()
62
+ file_path.save(buffered, format="PNG")
63
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
64
+ else:
65
+ with open(file_path, "rb") as file:
66
+ return base64.b64encode(file.read()).decode("utf-8")
67
+
68
+ def _get_mime_type(self, file_path: str) -> str:
69
+ """
70
+ Get the MIME type of a file based on its extension
71
+
72
+ Args:
73
+ file_path: Path to the file
74
+
75
+ Returns:
76
+ MIME type as a string (e.g., "image/jpeg", "audio/mp3")
77
+ """
78
+ mime_type, _ = mimetypes.guess_type(file_path)
79
+ if mime_type is None:
80
+ raise ValueError(f"Unsupported file type: {file_path}")
81
+ return mime_type
82
+
83
+ def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
84
+ """
85
+ Process messages and return completion
86
+
87
+ Args:
88
+ messages: List of message dictionaries with 'type' and 'content' keys
89
+ metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking
90
+
91
+ Returns:
92
+ Generated text response
93
+ """
94
+ if metadata is None:
95
+ print("No metadata provided, using empty metadata")
96
+ metadata = {}
97
+ metadata["trace_name"] = f"litellm-completion-{self.model_name}"
98
+ # Convert messages to LiteLLM format
99
+ formatted_messages = []
100
+ for msg in messages:
101
+ if msg["type"] == "text":
102
+ formatted_messages.append({
103
+ "role": "user",
104
+ "content": [{"type": "text", "text": msg["content"]}]
105
+ })
106
+ elif msg["type"] in ["image", "audio", "video"]:
107
+ # Check if content is a local file path or PIL Image
108
+ if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
109
+ try:
110
+ if isinstance(msg["content"], Image.Image):
111
+ mime_type = "image/png"
112
+ else:
113
+ mime_type = self._get_mime_type(msg["content"])
114
+ base64_data = self._encode_file(msg["content"])
115
+ data_url = f"data:{mime_type};base64,{base64_data}"
116
+ except ValueError as e:
117
+ print(f"Error processing file {msg['content']}: {e}")
118
+ continue
119
+ else:
120
+ data_url = msg["content"]
121
+
122
+ # Append the formatted message based on the model
123
+ if "gemini" in self.model_name:
124
+ formatted_messages.append({
125
+ "role": "user",
126
+ "content": [
127
+ {
128
+ "type": "image_url",
129
+ "image_url": data_url
130
+ }
131
+ ]
132
+ })
133
+ elif "gpt" in self.model_name:
134
+ # GPT and other models expect a different format
135
+ if msg["type"] == "image":
136
+ # Default format for images and videos in GPT
137
+ formatted_messages.append({
138
+ "role": "user",
139
+ "content": [
140
+ {
141
+ "type": f"image_url",
142
+ f"{msg['type']}_url": {
143
+ "url": data_url,
144
+ "detail": "high"
145
+ }
146
+ }
147
+ ]
148
+ })
149
+ else:
150
+ raise ValueError("For GPT, only text and image inferencing are supported")
151
+ else:
152
+ raise ValueError("Only support Gemini and Gpt for Multimodal capability now")
153
+
154
+ try:
155
+ # if it's openai o series model, set temperature to None and reasoning_effort to "medium"
156
+ if (re.match(r"^o\d+.*$", self.model_name) or re.match(r"^openai/o.*$", self.model_name)):
157
+ self.temperature = None
158
+ self.reasoning_effort = "medium"
159
+ response = completion(
160
+ model=self.model_name,
161
+ messages=formatted_messages,
162
+ temperature=self.temperature,
163
+ reasoning_effort=self.reasoning_effort,
164
+ metadata=metadata,
165
+ max_retries=99
166
+ )
167
+ else:
168
+ response = completion(
169
+ model=self.model_name,
170
+ messages=formatted_messages,
171
+ temperature=self.temperature,
172
+ metadata=metadata,
173
+ max_retries=99
174
+ )
175
+ if self.print_cost:
176
+ # pass your response from completion to completion_cost
177
+ cost = completion_cost(completion_response=response)
178
+ formatted_string = f"Cost: ${float(cost):.10f}"
179
+ # print(formatted_string)
180
+ self.accumulated_cost += cost
181
+ print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
182
+
183
+ content = response.choices[0].message.content
184
+ if content is None:
185
+ print(f"Got null response from model. Full response: {response}")
186
+ return content
187
+
188
+ except Exception as e:
189
+ print(f"Error in model completion: {e}")
190
+ return str(e)
191
+
192
+ if __name__ == "__main__":
193
+ pass
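
And a short sketch of the same unified message interface through LiteLLMWrapper; the model name, image path, and metadata values are placeholders.

```python
from mllm_tools.litellm import LiteLLMWrapper

model = LiteLLMWrapper(model_name="gemini/gemini-1.5-pro-002", print_cost=True, use_langfuse=False)
answer = model(
    [
        {"type": "text", "content": "Rate the layout of this frame from 1 to 5."},
        {"type": "image", "content": "frames/key_frame_2.jpg"},
    ],
    metadata={"run_id": "eval-run-42"},  # forwarded to litellm (and Langfuse when enabled)
)
print(answer)
```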
mllm_tools/openai.py ADDED
@@ -0,0 +1,594 @@
1
+ # filepath: d:\Theory2Manim-2\Theory2Manim\mllm_tools\openai.py
2
+ import json
3
+ import re
4
+ from typing import List, Dict, Any, Union, Optional
5
+ import io
6
+ import os
7
+ import base64
8
+ from PIL import Image
9
+ import mimetypes
10
+ import litellm
11
+ from litellm import completion, completion_cost
12
+ from dotenv import load_dotenv
13
+
14
+ # Load environment variables from .env file
15
+ load_dotenv()
16
+
17
+ # Note: Environment variables should be loaded from .env file or set manually using os.environ
18
+
19
+ class OpenAIWrapper:
20
+ """Wrapper for OpenAI using LiteLLM to support all OpenAI models with unified interface"""
21
+
22
+ def __init__(
23
+ self,
24
+ model_name: str = "gpt-4o",
25
+ temperature: float = 0.7,
26
+ print_cost: bool = False,
27
+ verbose: bool = False,
28
+ use_langfuse: bool = True,
29
+ api_key: Optional[str] = None,
30
+ organization: Optional[str] = None,
31
+ base_url: Optional[str] = None,
32
+ use_github_token: bool = True,
33
+ github_token: Optional[str] = os.getenv('GITHUB_TOKEN')
34
+ ):
35
+ """
36
+ Initialize the OpenAI wrapper
37
+
38
+ Args:
39
+ model_name: Name of the OpenAI model to use (e.g. "gpt-4o", "gpt-4o-mini",
40
+ "gpt-3.5-turbo", "o1-preview", "o1-mini", "dall-e-3")
41
+ temperature: Temperature for completion (ignored for o1 models)
42
+ print_cost: Whether to print the cost of the completion
43
+ verbose: Whether to print verbose output
44
+ use_langfuse: Whether to enable Langfuse logging
45
+ api_key: OpenAI API key (if not provided, will use OPENAI_API_KEY env var)
46
+ organization: OpenAI organization ID (optional)
47
+ base_url: Custom base URL for OpenAI API (optional, for proxies)
48
+ use_github_token: Whether to use GitHub AI model inference endpoint
49
+ github_token: GitHub token (if not provided, will use GITHUB_TOKEN env var)
50
+ """
51
+ self.model_name = model_name
52
+ self.temperature = temperature
53
+ self.print_cost = print_cost
54
+ self.verbose = verbose
55
+ self.accumulated_cost = 0
56
+ self.use_github_token = use_github_token
57
+
58
+ # Configure API based on whether using GitHub token or OpenAI API
59
+ if use_github_token:
60
+ # Set up GitHub token and endpoint
61
+ self.github_token = github_token or os.getenv('GITHUB_TOKEN')
62
+ if not self.github_token:
63
+ raise ValueError("GitHub token is required when use_github_token=True. Please set GITHUB_TOKEN environment variable or pass github_token parameter.")
64
+
65
+ # Set GitHub AI inference endpoint
66
+ self.base_url = "https://models.github.ai/inference"
67
+ self.api_key = self.github_token
68
+
69
+ # Set environment variables for LiteLLM to use GitHub endpoint
70
+ os.environ['OPENAI_API_KEY'] = self.github_token
71
+ os.environ['OPENAI_BASE_URL'] = self.base_url
72
+
73
+ # Adjust model name for GitHub endpoint (add openai/ prefix if not present)
74
+ if not self.model_name.startswith("openai/"):
75
+ self.model_name = f"openai/{self.model_name}"
76
+
77
+ else:
78
+ # Original OpenAI API setup
79
+ self.api_key = api_key or os.getenv('OPENAI_API_KEY')
80
+ if not self.api_key:
81
+ raise ValueError("OpenAI API key is required. Please set OPENAI_API_KEY environment variable or pass api_key parameter.")
82
+
83
+ # Set environment variables for LiteLLM
84
+ os.environ['OPENAI_API_KEY'] = self.api_key
85
+
86
+ # Set optional custom base URL
87
+ if base_url:
88
+ os.environ['OPENAI_BASE_URL'] = base_url
89
+ self.base_url = base_url
90
+ else:
91
+ self.base_url = os.getenv('OPENAI_BASE_URL')
92
+
93
+ # Set optional organization (only for OpenAI, not GitHub)
94
+ if not use_github_token:
95
+ if organization:
96
+ os.environ['OPENAI_ORGANIZATION'] = organization
97
+ self.organization = organization
98
+ else:
99
+ self.organization = os.getenv('OPENAI_ORGANIZATION')
100
+ else:
101
+ self.organization = None
102
+
103
+ if self.verbose:
104
+ os.environ['LITELLM_LOG'] = 'DEBUG'
105
+
106
+ # Set langfuse callback only if enabled
107
+ if use_langfuse:
108
+ litellm.success_callback = ["langfuse"]
109
+ litellm.failure_callback = ["langfuse"]
110
+
111
+ def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
112
+ """
113
+ Encode local file or PIL Image to base64 string
114
+
115
+ Args:
116
+ file_path: Path to local file or PIL Image object
117
+
118
+ Returns:
119
+ Base64 encoded file string
120
+ """
121
+ if isinstance(file_path, Image.Image):
122
+ buffered = io.BytesIO()
123
+ file_path.save(buffered, format="PNG")
124
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
125
+ else:
126
+ with open(file_path, "rb") as file:
127
+ return base64.b64encode(file.read()).decode("utf-8")
128
+
129
+ def _get_mime_type(self, file_path: str) -> str:
130
+ """
131
+ Get the MIME type of a file based on its extension
132
+
133
+ Args:
134
+ file_path: Path to the file
135
+
136
+ Returns:
137
+ MIME type as a string (e.g., "image/jpeg", "application/pdf")
138
+ """
139
+ mime_type, _ = mimetypes.guess_type(file_path)
140
+ if mime_type is None:
141
+ raise ValueError(f"Unsupported file type: {file_path}")
142
+ return mime_type
143
+
144
+ def _supports_vision(self, model_name: str) -> bool:
145
+ """
146
+ Check if the model supports vision/image processing
147
+
148
+ Args:
149
+ model_name: Name of the model
150
+
151
+ Returns:
152
+ True if model supports vision, False otherwise
153
+ """
154
+ vision_models = [
155
+ "gpt-4o",
156
+ "gpt-4o-mini",
157
+ "gpt-4-vision-preview",
158
+ "gpt-4-turbo",
159
+ "gpt-4-turbo-vision"
160
+ ]
161
+
162
+ return any(vision_model in model_name for vision_model in vision_models)
163
+
164
+ def _supports_files(self, model_name: str) -> bool:
165
+ """
166
+ Check if the model supports file processing (PDFs, documents)
167
+
168
+ Args:
169
+ model_name: Name of the model
170
+
171
+ Returns:
172
+ True if model supports file processing, False otherwise
173
+ """
174
+ file_models = [
175
+ "gpt-4o",
176
+ "gpt-4o-mini",
177
+ "gpt-4-turbo"
178
+ ]
179
+
180
+ return any(file_model in model_name for file_model in file_models)
181
+
182
+ def _is_o1_model(self, model_name: str) -> bool:
183
+ """
184
+ Check if the model is an o1 series model (reasoning models)
185
+
186
+ Args:
187
+ model_name: Name of the model
188
+
189
+ Returns:
190
+ True if it's an o1 model, False otherwise
191
+ """
192
+ return "o1" in model_name
193
+
194
+ def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None, **kwargs) -> str:
195
+ """
196
+ Process messages and return completion
197
+
198
+ Args:
199
+ messages: List of message dictionaries with 'type' and 'content' keys
200
+ metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking
201
+ **kwargs: Additional parameters for completion (max_tokens, stream, etc.)
202
+
203
+ Returns:
204
+ Generated text response
205
+ """
206
+ if metadata is None:
207
+ metadata = {}
208
+ metadata["trace_name"] = f"openai-completion-{self.model_name}"
209
+
210
+ # Convert messages to LiteLLM format
211
+ formatted_messages = []
212
+
213
+ for msg in messages:
214
+ if msg["type"] == "text":
215
+ formatted_messages.append({
216
+ "role": "user",
217
+ "content": [{"type": "text", "text": msg["content"]}]
218
+ })
219
+ elif msg["type"] == "image":
220
+ # Check if model supports vision
221
+ if not self._supports_vision(self.model_name):
222
+ raise ValueError(f"Model {self.model_name} does not support image processing")
223
+
224
+ # Check if content is a local file path or PIL Image
225
+ if isinstance(msg["content"], Image.Image) or (isinstance(msg["content"], str) and os.path.isfile(msg["content"])):
226
+ try:
227
+ if isinstance(msg["content"], Image.Image):
228
+ mime_type = "image/png"
229
+ else:
230
+ mime_type = self._get_mime_type(msg["content"])
231
+ base64_data = self._encode_file(msg["content"])
232
+ data_url = f"data:{mime_type};base64,{base64_data}"
233
+ except ValueError as e:
234
+ print(f"Error processing file {msg['content']}: {e}")
235
+ continue
236
+ else:
237
+ # Assume it's already a URL or base64 string
238
+ data_url = msg["content"]
239
+
240
+ # Format for vision-capable models
241
+ formatted_messages.append({
242
+ "role": "user",
243
+ "content": [
244
+ {
245
+ "type": "image_url",
246
+ "image_url": {
247
+ "url": data_url,
248
+ "detail": "high"
249
+ }
250
+ }
251
+ ]
252
+ })
253
+ elif msg["type"] == "file":
254
+ # Check if model supports file processing
255
+ if not self._supports_files(self.model_name):
256
+ raise ValueError(f"Model {self.model_name} does not support file processing")
257
+
258
+ # Handle file content (PDF, documents, etc.)
259
+ if os.path.isfile(msg["content"]):
260
+ try:
261
+ mime_type = self._get_mime_type(msg["content"])
262
+ base64_data = self._encode_file(msg["content"])
263
+
264
+ # Use the file format for document processing
265
+ formatted_messages.append({
266
+ "role": "user",
267
+ "content": [
268
+ {
269
+ "type": "file",
270
+ "file": {
271
+ "filename": os.path.basename(msg["content"]),
272
+ "file_data": f"data:{mime_type};base64,{base64_data}",
273
+ }
274
+ }
275
+ ]
276
+ })
277
+ except ValueError as e:
278
+ print(f"Error processing file {msg['content']}: {e}")
279
+ continue
280
+ else:
281
+ raise ValueError(f"File not found: {msg['content']}")
282
+ else:
283
+ raise ValueError(f"Unsupported message type: {msg['type']}. OpenAI models support 'text', 'image', and 'file' types.")
284
+
285
+ try:
286
+ # Prepare completion parameters
287
+ completion_params = {
288
+ "model": self.model_name,
289
+ "messages": formatted_messages,
290
+ "metadata": metadata,
291
+ "max_retries": 3
292
+ }
293
+
294
+ # Add additional kwargs
295
+ completion_params.update(kwargs)
296
+
297
+ # Check if it's an o1 series model (reasoning models)
298
+ if self._is_o1_model(self.model_name):
299
+ # O1 models don't support temperature and have reasoning_effort
300
+ if "reasoning_effort" not in completion_params:
301
+ completion_params["reasoning_effort"] = "medium" # Options: "low", "medium", "high"
302
+ # Remove temperature if it was added via kwargs
303
+ completion_params.pop("temperature", None)
304
+ else:
305
+ # Regular models support temperature
306
+ if "temperature" not in completion_params:
307
+ completion_params["temperature"] = self.temperature
308
+
309
+ response = completion(**completion_params)
310
+
311
+ if self.print_cost:
312
+ try:
313
+ cost = completion_cost(completion_response=response)
314
+ if cost is not None:
315
+ self.accumulated_cost += cost
316
+ print(f"Cost: ${float(cost):.10f}")
317
+ print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
318
+ else:
319
+ print("Cost information not available")
320
+ except Exception as e:
321
+ print(f"Could not calculate cost: {e}")
322
+
323
+ content = response.choices[0].message.content
324
+ if content is None:
325
+ print(f"Got null response from OpenAI model. Full response: {response}")
326
+ return ""
327
+ return content
328
+
329
+ except Exception as e:
330
+ print(f"Error in OpenAI model completion: {e}")
331
+ return str(e)
332
+
333
+ def stream_completion(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None, **kwargs):
334
+ """
335
+ Process messages and return streaming completion
336
+
337
+ Args:
338
+ messages: List of message dictionaries with 'type' and 'content' keys
339
+ metadata: Optional metadata to pass to litellm completion
340
+ **kwargs: Additional parameters for completion
341
+
342
+ Yields:
343
+ Streaming response chunks
344
+ """
345
+ kwargs["stream"] = True
346
+
347
+ # Use the same message formatting as regular completion
348
+ if metadata is None:
349
+ metadata = {}
350
+ metadata["trace_name"] = f"openai-streaming-{self.model_name}"
351
+
352
+ try:
353
+ # Convert messages to the same format as __call__
354
+ formatted_messages = []
355
+
356
+ for msg in messages:
357
+ if msg["type"] == "text":
358
+ formatted_messages.append({
359
+ "role": "user",
360
+ "content": [{"type": "text", "text": msg["content"]}]
361
+ })
362
+ elif msg["type"] == "image":
363
+ if not self._supports_vision(self.model_name):
364
+ raise ValueError(f"Model {self.model_name} does not support image processing")
365
+
366
+ if isinstance(msg["content"], Image.Image) or (isinstance(msg["content"], str) and os.path.isfile(msg["content"])):
367
+ try:
368
+ if isinstance(msg["content"], Image.Image):
369
+ mime_type = "image/png"
370
+ else:
371
+ mime_type = self._get_mime_type(msg["content"])
372
+ base64_data = self._encode_file(msg["content"])
373
+ data_url = f"data:{mime_type};base64,{base64_data}"
374
+ except ValueError as e:
375
+ print(f"Error processing file {msg['content']}: {e}")
376
+ continue
377
+ else:
378
+ data_url = msg["content"]
379
+
380
+ formatted_messages.append({
381
+ "role": "user",
382
+ "content": [
383
+ {
384
+ "type": "image_url",
385
+ "image_url": {
386
+ "url": data_url,
387
+ "detail": "high"
388
+ }
389
+ }
390
+ ]
391
+ })
392
+
393
+ # Prepare completion parameters
394
+ completion_params = {
395
+ "model": self.model_name,
396
+ "messages": formatted_messages,
397
+ "metadata": metadata,
398
+ "max_retries": 3,
399
+ "stream": True
400
+ }
401
+
402
+ # Add additional kwargs
403
+ completion_params.update(kwargs)
404
+
405
+ # Handle o1 models
406
+ if self._is_o1_model(self.model_name):
407
+ if "reasoning_effort" not in completion_params:
408
+ completion_params["reasoning_effort"] = "medium"
409
+ completion_params.pop("temperature", None)
410
+ else:
411
+ if "temperature" not in completion_params:
412
+ completion_params["temperature"] = self.temperature
413
+
414
+ response = completion(**completion_params)
415
+
416
+ # Yield streaming chunks
417
+ for chunk in response:
418
+ yield chunk
419
+
420
+ except Exception as e:
421
+ print(f"Error in OpenAI streaming completion: {e}")
422
+ yield {"error": str(e)}
423
+
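
A short sketch of consuming stream_completion above; the chunk fields assume LiteLLM's OpenAI-compatible streaming delta format.

# Sketch: consuming OpenAIWrapper.stream_completion (chunk structure assumed
# to follow LiteLLM's OpenAI-compatible streaming deltas).
wrapper = OpenAIWrapper(model_name="gpt-4o-mini", use_github_token=False)
for chunk in wrapper.stream_completion([{"type": "text", "content": "Tell me a one-line joke"}]):
    if isinstance(chunk, dict) and "error" in chunk:  # error dict yielded by the method above
        print(chunk["error"])
        break
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)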
424
+ def create_openai_wrapper(model_name: str = "gpt-4o", use_github: bool = False, **kwargs) -> OpenAIWrapper:
425
+ """
426
+ Convenience function to create an OpenAI wrapper
427
+
428
+ Args:
429
+ model_name: OpenAI model name (e.g., "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo")
430
+ use_github: Whether to use GitHub's AI model inference endpoint
431
+ **kwargs: Additional arguments passed to OpenAIWrapper
432
+
433
+ Returns:
434
+ Configured OpenAIWrapper instance
435
+
436
+ Example:
437
+ >>> # Create a wrapper for GPT-4o using regular OpenAI
438
+ >>> wrapper = create_openai_wrapper("gpt-4o", temperature=0.3)
439
+ >>>
440
+ >>> # Create a wrapper for GPT-4o using GitHub AI models
441
+ >>> wrapper = create_openai_wrapper("gpt-4o", use_github=True, temperature=0.3)
442
+ >>>
443
+ >>> # Use it for text generation
444
+ >>> response = wrapper([{"type": "text", "content": "Explain quantum computing"}])
445
+ >>>
446
+ >>> # Use it for vision (if model supports it)
447
+ >>> response = wrapper([
448
+ ... {"type": "text", "content": "What's in this image?"},
449
+ ... {"type": "image", "content": "path/to/image.jpg"}
450
+ ... ])
451
+ >>>
452
+ >>> # Use it for file processing (PDFs, etc.)
453
+ >>> response = wrapper([
454
+ ... {"type": "text", "content": "Summarize this document"},
455
+ ... {"type": "file", "content": "path/to/document.pdf"}
456
+ ... ])
457
+ """
458
+ return OpenAIWrapper(model_name=model_name, use_github_token=use_github, **kwargs)
459
+
460
+ # Available OpenAI Models
461
+ AVAILABLE_MODELS = {
462
+ # GPT-4 Models
463
+ "gpt-4o": "gpt-4o",
464
+ "gpt-4o-mini": "gpt-4o-mini",
465
+ "gpt-4-turbo": "gpt-4-turbo",
466
+ "gpt-4": "gpt-4",
467
+ "gpt-4-vision-preview": "gpt-4-vision-preview",
468
+
469
+ # O1 Reasoning Models
470
+ "o1-preview": "o1-preview",
471
+ "o1-mini": "o1-mini",
472
+
473
+ # GPT-3.5 Models
474
+ "gpt-3.5-turbo": "gpt-3.5-turbo",
475
+ "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-instruct",
476
+
477
+ # Image Generation Models
478
+ "dall-e-3": "dall-e-3",
479
+ "dall-e-2": "dall-e-2",
480
+
481
+ # Embedding Models
482
+ "text-embedding-3-large": "text-embedding-3-large",
483
+ "text-embedding-3-small": "text-embedding-3-small",
484
+ "text-embedding-ada-002": "text-embedding-ada-002",
485
+
486
+ # Audio Models
487
+ "whisper-1": "whisper-1",
488
+ "tts-1": "tts-1",
489
+ "tts-1-hd": "tts-1-hd",
490
+ }
491
+
492
+ def create_github_openai_wrapper(model_name: str = "gpt-4o", **kwargs) -> OpenAIWrapper:
493
+ """
494
+ Convenience function to create an OpenAI wrapper using GitHub's AI model inference
495
+
496
+ Args:
497
+ model_name: OpenAI model name (e.g., "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo")
498
+ **kwargs: Additional arguments passed to OpenAIWrapper
499
+
500
+ Returns:
501
+ Configured OpenAIWrapper instance using GitHub endpoint
502
+
503
+ Example:
504
+ >>> # Create a wrapper for GPT-4o using GitHub AI models
505
+ >>> wrapper = create_github_openai_wrapper("gpt-4o", temperature=0.3)
506
+ >>>
507
+ >>> # Use it for text generation
508
+ >>> response = wrapper([{"type": "text", "content": "What is the capital of France?"}])
509
+ """
510
+ return OpenAIWrapper(model_name=model_name, use_github_token=True, **kwargs)
511
+
512
+ def list_available_models() -> Dict[str, str]:
513
+ """
514
+ Get a dictionary of available OpenAI models
515
+
516
+ Returns:
517
+ Dictionary mapping model names to their identifiers
518
+ """
519
+ return AVAILABLE_MODELS.copy()
520
+
521
+ def get_model_capabilities(model_name: str) -> Dict[str, bool]:
522
+ """
523
+ Get the capabilities of a specific model
524
+
525
+ Args:
526
+ model_name: Name of the model
527
+
528
+ Returns:
529
+ Dictionary of capabilities (vision, files, reasoning, etc.)
530
+ """
531
+ wrapper = OpenAIWrapper(model_name=model_name)
532
+
533
+ return {
534
+ "vision": wrapper._supports_vision(model_name),
535
+ "files": wrapper._supports_files(model_name),
536
+ "reasoning": wrapper._is_o1_model(model_name),
537
+ "streaming": not wrapper._is_o1_model(model_name), # O1 models don't support streaming
538
+ "temperature": not wrapper._is_o1_model(model_name), # O1 models don't support temperature
539
+ }
540
+
541
+ if __name__ == "__main__":
542
+ # Example usage
543
+ print("Available OpenAI Models:")
544
+ for model_name, model_id in AVAILABLE_MODELS.items():
545
+ capabilities = get_model_capabilities(model_name)
546
+ print(f" {model_name} ({model_id}): {capabilities}")
547
+
548
+ print("\n" + "="*50)
549
+ print("Testing OpenAI wrapper...")
550
+
551
+ # Example 1: Regular OpenAI (requires OPENAI_API_KEY environment variable)
552
+ try:
553
+ print("\n1. Testing regular OpenAI wrapper:")
554
+ wrapper = create_openai_wrapper("gpt-4o-mini", temperature=0.3)
555
+ print("Regular OpenAI wrapper created successfully!")
556
+
557
+ # Test with a simple text prompt
558
+ response = wrapper([{"type": "text", "content": "Hello! Can you confirm you're working?"}])
559
+ print(f"Response: {response}")
560
+
561
+ except Exception as e:
562
+ print(f"Error creating regular OpenAI wrapper: {e}")
563
+ print("Make sure to set OPENAI_API_KEY environment variable")
564
+
565
+ # Example 2: GitHub AI models (requires GITHUB_TOKEN environment variable)
566
+ try:
567
+ print("\n2. Testing GitHub AI models wrapper:")
568
+ github_wrapper = create_github_openai_wrapper("gpt-4o", temperature=1.0)
569
+ print("GitHub OpenAI wrapper created successfully!")
570
+
571
+ # Test with a simple text prompt
572
+ response = github_wrapper([{
573
+ "type": "text",
574
+ "content": "What is the capital of France?"
575
+ }])
576
+ print(f"GitHub Response: {response}")
577
+
578
+ except Exception as e:
579
+ print(f"Error creating GitHub wrapper: {e}")
580
+ print("Make sure to set GITHUB_TOKEN environment variable")
581
+ # Example 3: Manual GitHub configuration
582
+ try:
583
+ print("\n3. Testing manual GitHub configuration:")
584
+ manual_wrapper = OpenAIWrapper(
585
+ model_name="openai/gpt-4o",
586
+ use_github_token=True,
587
+ temperature=1.0,
588
+ verbose=False
589
+ )
590
+ print("Manual GitHub wrapper created successfully!")
591
+
592
+ except Exception as e:
593
+ print(f"Error creating manual GitHub wrapper: {e}")
594
+ print("Make sure to set GITHUB_TOKEN environment variable")
mllm_tools/openrouter.py ADDED
@@ -0,0 +1,266 @@
1
+ import os
2
+ import re
3
+ from typing import List, Dict, Any, Optional, Union
4
+ import io
5
+ import base64
6
+ from PIL import Image
7
+ import mimetypes
8
+ from litellm import completion, completion_cost
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+
13
+ class OpenRouterWrapper:
14
+ """
15
+ OpenRouter wrapper using LiteLLM for various language models.
16
+ Compatible with the existing wrapper interface.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ model_name: str = "openrouter/deepseek/deepseek-chat-v3-0324:free",
22
+ temperature: float = 0.7,
23
+ print_cost: bool = False,
24
+ verbose: bool = False,
25
+ use_langfuse: bool = True,
26
+ site_url: str = "",
27
+ app_name: str = "Theory2Manim"
28
+ ):
29
+ """
30
+ Initialize OpenRouter wrapper.
31
+
32
+ Args:
33
+ model_name: OpenRouter model name (with openrouter/ prefix)
34
+ temperature: Temperature for completion
35
+ print_cost: Whether to print the cost of the completion
36
+ verbose: Whether to print verbose output
37
+ use_langfuse: Whether to enable Langfuse logging
38
+ site_url: Optional site URL for tracking
39
+ app_name: Optional app name for tracking
40
+ """
41
+ self.model_name = model_name
42
+ self.temperature = temperature
43
+ self.print_cost = print_cost
44
+ self.verbose = verbose
45
+ self.accumulated_cost = 0
46
+
47
+ # Setup OpenRouter environment variables
48
+ api_key = os.getenv("OPENROUTER_API_KEY")
49
+ if not api_key:
50
+ raise ValueError("No OPENROUTER_API_KEY found. Please set the environment variable.")
51
+
52
+ os.environ["OPENROUTER_API_KEY"] = api_key
53
+ os.environ["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"
54
+
55
+ if site_url or os.getenv("OR_SITE_URL"):
56
+ os.environ["OR_SITE_URL"] = site_url or os.getenv("OR_SITE_URL", "")
57
+ if app_name:
58
+ os.environ["OR_APP_NAME"] = app_name
59
+
60
+ if self.verbose:
61
+ os.environ['LITELLM_LOG'] = 'DEBUG'
62
+
63
+ # Set langfuse callback only if enabled
64
+ if use_langfuse:
65
+ import litellm
66
+ litellm.success_callback = ["langfuse"]
67
+ litellm.failure_callback = ["langfuse"]
68
+
69
+ def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
70
+ """
71
+ Encode local file or PIL Image to base64 string
72
+
73
+ Args:
74
+ file_path: Path to local file or PIL Image object
75
+
76
+ Returns:
77
+ Base64 encoded file string
78
+ """
79
+ if isinstance(file_path, Image.Image):
80
+ buffered = io.BytesIO()
81
+ file_path.save(buffered, format="PNG")
82
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
83
+ else:
84
+ with open(file_path, "rb") as file:
85
+ return base64.b64encode(file.read()).decode("utf-8")
86
+
87
+ def _get_mime_type(self, file_path: str) -> str:
88
+ """
89
+ Get the MIME type of a file based on its extension
90
+
91
+ Args:
92
+ file_path: Path to the file
93
+
94
+ Returns:
95
+ MIME type as a string (e.g., "image/jpeg", "audio/mp3")
96
+ """
97
+ mime_type, _ = mimetypes.guess_type(file_path)
98
+ if mime_type is None:
99
+ raise ValueError(f"Unsupported file type: {file_path}")
100
+ return mime_type
101
+
102
+ def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
103
+ """
104
+ Process messages and return completion
105
+
106
+ Args:
107
+ messages: List of message dictionaries with 'type' and 'content' keys
108
+ metadata: Optional metadata to pass to completion
109
+
110
+ Returns:
111
+ Generated text response
112
+ """
113
+ if metadata is None:
114
+ metadata = {}
115
+ metadata["trace_name"] = f"openrouter-completion-{self.model_name}"
116
+
117
+ # Convert messages to LiteLLM format
118
+ formatted_messages = []
119
+ for msg in messages:
120
+ if msg["type"] == "text":
121
+ formatted_messages.append({
122
+ "role": "user",
123
+ "content": [{"type": "text", "text": msg["content"]}]
124
+ })
125
+ elif msg["type"] in ["image", "audio", "video"]:
126
+ # Check if content is a local file path or PIL Image
127
+ if isinstance(msg["content"], Image.Image) or os.path.isfile(msg["content"]):
128
+ try:
129
+ if isinstance(msg["content"], Image.Image):
130
+ mime_type = "image/png"
131
+ else:
132
+ mime_type = self._get_mime_type(msg["content"])
133
+ base64_data = self._encode_file(msg["content"])
134
+ data_url = f"data:{mime_type};base64,{base64_data}"
135
+ except ValueError as e:
136
+ print(f"Error processing file {msg['content']}: {e}")
137
+ continue
138
+ else:
139
+ data_url = msg["content"]
140
+
141
+ # Format for vision models
142
+ if msg["type"] == "image":
143
+ formatted_messages.append({
144
+ "role": "user",
145
+ "content": [
146
+ {
147
+ "type": "image_url",
148
+ "image_url": {
149
+ "url": data_url,
150
+ "detail": "high"
151
+ }
152
+ }
153
+ ]
154
+ })
155
+ else:
156
+ # For audio/video, treat as text for now
157
+ formatted_messages.append({
158
+ "role": "user",
159
+ "content": [{"type": "text", "text": f"[{msg['type'].upper()}]: {msg['content']}"}]
160
+ })
161
+
162
+ try:
163
+ response = completion(
164
+ model=self.model_name,
165
+ messages=formatted_messages,
166
+ temperature=self.temperature,
167
+ metadata=metadata,
168
+ max_retries=99
169
+ )
170
+ if self.print_cost:
171
+ # Calculate and print cost
172
+ cost = completion_cost(completion_response=response)
173
+ self.accumulated_cost += cost
174
+ print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")
175
+
176
+ content = response.choices[0].message.content
177
+ if content is None:
178
+ print(f"Got null response from model. Full response: {response}")
179
+ return "Error: Received null response from model"
180
+
181
+ # Check if the response contains error messages about unmapped models
182
+ if "This model isn't mapped yet" in content or "model isn't mapped" in content.lower():
183
+ error_msg = f"Error: Model {self.model_name} is not supported by LiteLLM. Please use a supported model."
184
+ print(error_msg)
185
+ return error_msg
186
+
187
+ return content
188
+
189
+ except Exception as e:
190
+ print(f"Error in OpenRouter completion: {e}")
191
+ return f"Error: {str(e)}"
192
+
193
+
194
+ class OpenRouterClient:
195
+ """
196
+ Legacy OpenRouter client for backward compatibility.
197
+ """
198
+
199
+ def __init__(self, api_key: str, site_url: str = "", app_name: str = "Theory2Manim"):
200
+ """
201
+ Initialize OpenRouter client.
202
+
203
+ Args:
204
+ api_key: OpenRouter API key
205
+ site_url: Optional site URL for tracking
206
+ app_name: Optional app name for tracking
207
+ """
208
+ os.environ["OPENROUTER_API_KEY"] = api_key
209
+ os.environ["OPENROUTER_API_BASE"] = "https://openrouter.ai/api/v1"
210
+
211
+ if site_url:
212
+ os.environ["OR_SITE_URL"] = site_url
213
+ if app_name:
214
+ os.environ["OR_APP_NAME"] = app_name
215
+
216
+ def complete(
217
+ self,
218
+ messages: List[Dict[str, str]],
219
+ model: str = "openrouter/openai/gpt-3.5-turbo",
220
+ transforms: Optional[List[str]] = None,
221
+ route: Optional[str] = None,
222
+ **kwargs
223
+ ) -> Any:
224
+ """
225
+ Generate completion using OpenRouter model.
226
+
227
+ Args:
228
+ messages: List of message dictionaries with 'role' and 'content'
229
+ model: Model name (with openrouter/ prefix)
230
+ transforms: Optional transforms to apply
231
+ route: Optional route specification
232
+ **kwargs: Additional parameters for completion
233
+
234
+ Returns:
235
+ Completion response
236
+ """
237
+ params = {
238
+ "model": model,
239
+ "messages": messages,
240
+ **kwargs
241
+ }
242
+
243
+ if transforms:
244
+ params["transforms"] = transforms
245
+ if route:
246
+ params["route"] = route
247
+
248
+ return completion(**params)
249
+
250
+ # Convenience functions for common models
251
+ def ds_r1(messages: List[Dict[str, str]], **kwargs) -> Any:
252
+ """Use GPT-3.5 Turbo via OpenRouter"""
253
+ client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
254
+ return client.complete(messages, "deepseek/deepseek-r1:free", **kwargs)
255
+
256
+ def ds_v3(messages: List[Dict[str, str]], **kwargs) -> Any:
257
+ """Use GPT-4 via OpenRouter"""
258
+ client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
259
+ return client.complete(messages, "deepseek/deepseek-chat-v3-0324:free", **kwargs)
260
+
261
+ def qwen3(messages: List[Dict[str, str]], **kwargs) -> Any:
262
+ """Use Claude-2 via OpenRouter"""
263
+ client = OpenRouterClient(os.environ.get("OPENROUTER_API_KEY", ""))
264
+ return client.complete(messages, "qwen/qwen3-235b-a22b:free", **kwargs)
265
+
266
+
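
A basic usage sketch for OpenRouterWrapper above; it assumes OPENROUTER_API_KEY is set in the environment, and the free-tier model id is only an example.

# Sketch: OpenRouterWrapper usage (requires OPENROUTER_API_KEY).
from mllm_tools.openrouter import OpenRouterWrapper

or_llm = OpenRouterWrapper(
    model_name="openrouter/deepseek/deepseek-chat-v3-0324:free",
    temperature=0.3,
    print_cost=True,
)
print(or_llm([{"type": "text", "content": "State the Pythagorean theorem in one line."}]))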
mllm_tools/utils.py ADDED
@@ -0,0 +1,177 @@
1
+ from typing import Union, List, Dict, Any, Optional
2
+ from PIL import Image
3
+ import google.generativeai as genai
4
+ import tempfile
5
+ import os
6
+ from .gemini import GeminiWrapper
7
+ from .vertex_ai import VertexAIWrapper
8
+ from .openrouter import OpenRouterWrapper
9
+
10
+
11
+ def _prepare_text_inputs(texts: List[str]) -> List[Dict[str, str]]:
12
+ """
13
+ Converts a list of text strings into the input format for the Agent model.
14
+
15
+ Args:
16
+ texts (List[str]): The list of text strings to be processed.
17
+
18
+ Returns:
19
+ List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
20
+ """
21
+ inputs = []
22
+ # Add each text string to the inputs
23
+ if isinstance(texts, str):
24
+ texts = [texts]
25
+ for text in texts:
26
+ inputs.append({
27
+ "type": "text",
28
+ "content": text
29
+ })
30
+ return inputs
31
+
32
+ def _prepare_text_image_inputs(texts: Union[str, List[str]], images: Union[str, Image.Image, List[Union[str, Image.Image]]]) -> List[Dict[str, str]]:
33
+ """
34
+ Converts text strings and images into the input format for the Agent model.
35
+
36
+ Args:
37
+ texts (Union[str, List[str]]): Text string(s) to be processed.
38
+ images (Union[str, Image.Image, List[Union[str, Image.Image]]]): Image file path(s) or PIL Image object(s).
39
+ Returns:
40
+ List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
41
+ """
42
+ inputs = []
43
+ # Add each text string to the inputs
44
+ if isinstance(texts, str):
45
+ texts = [texts]
46
+ for text in texts:
47
+ inputs.append({
48
+ "type": "text",
49
+ "content": text
50
+ })
51
+ if isinstance(images, (str, Image.Image)):
52
+ images = [images]
53
+ for image in images:
54
+ inputs.append({
55
+ "type": "image",
56
+ "content": image
57
+ })
58
+ return inputs
59
+
60
+ def _prepare_text_video_inputs(texts: Union[str, List[str]], videos: Union[str, List[str]]) -> List[Dict[str, str]]:
61
+ """
62
+ Converts text strings and video file paths into the input format for the Agent model.
63
+
64
+ Args:
65
+ texts (Union[str, List[str]]): Text string(s) to be processed.
66
+ videos (Union[str, List[str]]): Video file path(s).
67
+ Returns:
68
+ List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
69
+ """
70
+ inputs = []
71
+ # Add each text string to the inputs
72
+ if isinstance(texts, str):
73
+ texts = [texts]
74
+ for text in texts:
75
+ inputs.append({
76
+ "type": "text",
77
+ "content": text
78
+ })
79
+ # Add each video file path to the inputs
80
+ if isinstance(videos, str):
81
+ videos = [videos]
82
+ for video in videos:
83
+ inputs.append({
84
+ "type": "video",
85
+ "content": video
86
+ })
87
+ return inputs
88
+
89
+ def _prepare_text_audio_inputs(texts: Union[str, List[str]], audios: Union[str, List[str]]) -> List[Dict[str, str]]:
90
+ """
91
+ Converts text strings and audio file paths into the input format for the Agent model.
92
+
93
+ Args:
94
+ texts (Union[str, List[str]]): Text string(s) to be processed.
95
+ audios (Union[str, List[str]]): Audio file path(s).
96
+ Returns:
97
+ List[Dict[str, str]]: A list of dictionaries formatted for the Agent model.
98
+ """
99
+ inputs = []
100
+ # Add each text string to the inputs
101
+ if isinstance(texts, str):
102
+ texts = [texts]
103
+ for text in texts:
104
+ inputs.append({
105
+ "type": "text",
106
+ "content": text
107
+ })
108
+ # Add each audio file path to the inputs
109
+ if isinstance(audios, str):
110
+ audios = [audios]
111
+ for audio in audios:
112
+ inputs.append({
113
+ "type": "audio",
114
+ "content": audio
115
+ })
116
+ return inputs
117
+
118
+ def _extract_code(text: str) -> str:
119
+ """Helper to extract code block from model response, support Gemini style and OpenAI style"""
120
+ try:
121
+ # Find code between ```python and ``` tags
122
+ start = text.split("```python\n")[-1]
123
+ end = start.split("```")[0]
124
+ return end.strip()
125
+ except IndexError:
126
+ return text
127
+
128
+ def _upload_to_gemini(input, mime_type=None):
129
+ """Uploads the given file or PIL image to Gemini.
130
+
131
+ See https://ai.google.dev/gemini-api/docs/prompting_with_media
132
+ """
133
+ if isinstance(input, str):
134
+ # Input is a file path
135
+ file = genai.upload_file(input, mime_type=mime_type)
136
+ elif isinstance(input, Image.Image):
137
+ # Input is a PIL image
138
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
139
+ input.save(tmp_file, format="JPEG")
140
+ tmp_file_path = tmp_file.name
141
+ file = genai.upload_file(tmp_file_path, mime_type=mime_type or "image/jpeg")
142
+ os.remove(tmp_file_path)
143
+ else:
144
+ raise ValueError("Unsupported input type. Must be a file path or PIL Image.")
145
+
146
+ #print(f"Uploaded file '{file.display_name}' as: {file.uri}")
147
+ return file
148
+
149
+ def get_media_wrapper(model_name: str) -> Optional[Union[GeminiWrapper, VertexAIWrapper, OpenRouterWrapper]]:
150
+ """Get appropriate wrapper for media handling based on model name"""
151
+ if model_name.startswith('gemini/'):
152
+ return GeminiWrapper(model_name=model_name.split('/')[-1])
153
+ elif model_name.startswith('vertex_ai/'):
154
+ return VertexAIWrapper(model_name=model_name.split('/')[-1])
155
+ elif model_name.startswith('openrouter/'):
156
+ return OpenRouterWrapper(model_name=model_name)
157
+ return None
158
+
159
+ def prepare_media_messages(prompt: str, media_path: Union[str, Image.Image], model_name: str) -> List[Dict[str, Any]]:
160
+ """Prepare messages for media input based on model type"""
161
+ is_video = isinstance(media_path, str) and media_path.endswith('.mp4')
162
+
163
+ if is_video and (model_name.startswith('gemini/') or model_name.startswith('vertex_ai/') or model_name.startswith('openrouter/')):
164
+ return [
165
+ {"type": "text", "content": prompt},
166
+ {"type": "video", "content": media_path}
167
+ ]
168
+ else:
169
+ # For images or non-video content
170
+ if isinstance(media_path, str):
171
+ media = Image.open(media_path)
172
+ else:
173
+ media = media_path
174
+ return [
175
+ {"type": "text", "content": prompt},
176
+ {"type": "image", "content": media}
177
+ ]
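
A sketch of the helpers above working together; the model name and video path are placeholders, and the returned wrapper is assumed to follow the same message-list call convention as the other wrappers in this commit.

# Sketch: routing a video prompt through the helpers above.
from mllm_tools.utils import get_media_wrapper, prepare_media_messages, _extract_code

model_name = "gemini/gemini-1.5-pro"
media_llm = get_media_wrapper(model_name)  # GeminiWrapper for gemini/ models
msgs = prepare_media_messages("Summarize what happens in this clip.", "scene_01.mp4", model_name)
reply = media_llm(msgs)
print(_extract_code(reply))  # returns the reply unchanged if it contains no fenced code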
mllm_tools/vertex_ai.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ from typing import List, Dict, Any, Optional
3
+ import vertexai
4
+ from vertexai.generative_models import GenerativeModel, Part
5
+ from google.auth import default
6
+ from google.auth.transport import requests
7
+
8
+
9
+ # TODO: check if this is the correct way to use Vertex AI
10
+ # TODO: add langfuse support
11
+ class VertexAIWrapper:
12
+ """Wrapper for Vertex AI to support Gemini models."""
13
+
14
+ def __init__(
15
+ self,
16
+ model_name: str = "gemini-1.5-pro",
17
+ temperature: float = 0.7,
18
+ print_cost: bool = False,
19
+ verbose: bool = False,
20
+ use_langfuse: bool = False
21
+ ):
22
+ """Initialize the Vertex AI wrapper.
23
+
24
+ Args:
25
+ model_name: Name of the model to use (e.g. "gemini-1.5-pro")
26
+ temperature: Temperature for generation between 0 and 1
27
+ print_cost: Whether to print the cost of the completion
28
+ verbose: Whether to print verbose output
29
+ use_langfuse: Whether to enable Langfuse logging
30
+ """
31
+ self.model_name = model_name
32
+ self.temperature = temperature
33
+ self.print_cost = print_cost
34
+ self.verbose = verbose
35
+
36
+ # Initialize Vertex AI
37
+ project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
38
+ location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
39
+ if not project_id:
40
+ raise ValueError("No GOOGLE_CLOUD_PROJECT found in environment variables")
41
+
42
+ vertexai.init(project=project_id, location=location)
43
+ self.model = GenerativeModel(model_name)
44
+
45
+ def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
46
+ """Process messages and return completion.
47
+
48
+ Args:
49
+ messages: List of message dictionaries containing type and content
50
+ metadata: Optional metadata dictionary to pass to the model
51
+
52
+ Returns:
53
+ Generated text response from the model
54
+
55
+ Raises:
56
+ ValueError: If message type is not supported
57
+ """
58
+ parts = []
59
+
60
+ for msg in messages:
61
+ if msg["type"] == "text":
62
+ parts.append(Part.from_text(msg["content"]))
63
+ elif msg["type"] in ["image", "video"]:
64
+ mime_type = "video/mp4" if msg["type"] == "video" else "image/jpeg"
65
+ if isinstance(msg["content"], str):
66
+ # Handle GCS URI
67
+ parts.append(Part.from_uri(
68
+ msg["content"],
69
+ mime_type=mime_type
70
+ ))
71
+ else:
72
+ # Handle file path or bytes
73
+ parts.append(Part.from_data(
74
+ msg["content"],
75
+ mime_type=mime_type
76
+ ))
77
+
78
+ response = self.model.generate_content(
79
+ parts,
80
+ generation_config={
81
+ "temperature": self.temperature,
82
+ "top_p": 0.95,
83
+ }
84
+ )
85
+
86
+ return response.text
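
A usage sketch for VertexAIWrapper above; it requires GOOGLE_CLOUD_PROJECT and application-default credentials to be configured, and the GCS URI is a placeholder.

# Sketch: VertexAIWrapper usage (placeholder GCS URI).
from mllm_tools.vertex_ai import VertexAIWrapper

vertex = VertexAIWrapper(model_name="gemini-1.5-pro", temperature=0.4)
answer = vertex([
    {"type": "text", "content": "What happens in this clip?"},
    {"type": "video", "content": "gs://my-bucket/scenes/scene_01.mp4"},
])
print(answer)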