File size: 1,370 Bytes
863f205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoTokenizer
from langchain_groq import ChatGroq
import google.generativeai as genai

def load_models(embedding_model="ibm-granite/granite-embedding-30m-english", 
                llm_model="llama3-70b-8192",
                google_api_key=None,
                groq_api_key=None):
    """
    Load all required models.

    Args:
        embedding_model: Name/path of the HuggingFace embedding model.
        llm_model: Name of the Groq-hosted LLM model.
        google_api_key: API key for Google Gemini. If None, the vision
            model is skipped and returned as None.
        groq_api_key: API key for Groq. If None, the LLM is skipped and
            returned as None.

    Returns:
        tuple: (embeddings_model, embeddings_tokenizer, vision_model, llm)
            where vision_model and/or llm are None when their respective
            API key was not provided.
    """
    # Load embedding model and its matching tokenizer from the same
    # checkpoint so token counts line up with the embedder's vocabulary.
    embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model)
    embeddings_tokenizer = AutoTokenizer.from_pretrained(embedding_model)

    # Initialize the Gemini vision model only when a key is available;
    # genai.configure sets the key process-wide for subsequent calls.
    if google_api_key:
        genai.configure(api_key=google_api_key)
        vision_model = genai.GenerativeModel(model_name="gemini-1.5-flash")
    else:
        vision_model = None

    # Initialize the Groq LLM. Use a distinct local name instead of
    # rebinding the `llm_model` parameter, so the model-name string and
    # the constructed client are never conflated.
    llm = ChatGroq(model_name=llm_model, api_key=groq_api_key) if groq_api_key else None

    return embeddings_model, embeddings_tokenizer, vision_model, llm