# Author: Prathamesh1420
# Commit: 863f205 "Create models.py" (verified)
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoTokenizer
from langchain_groq import ChatGroq
import google.generativeai as genai
def load_models(embedding_model="ibm-granite/granite-embedding-30m-english",
                llm_model="llama3-70b-8192",
                google_api_key=None,
                groq_api_key=None,
                *,
                vision_model_name="gemini-1.5-flash"):
    """
    Load all required models.

    Args:
        embedding_model: Hugging Face model name or local path. The same value
            is used for both the LangChain embeddings wrapper and the tokenizer.
        llm_model: Name of the Groq-hosted chat model.
        google_api_key: API key for Google Gemini. If falsy, no vision model
            is created and ``None`` is returned in its place.
        groq_api_key: API key for Groq. If falsy, no LLM is created and
            ``None`` is returned in its place.
        vision_model_name: Gemini model name used for the vision model
            (keyword-only). Defaults to ``"gemini-1.5-flash"``, the value that
            was previously hard-coded.

    Returns:
        tuple: (embeddings_model, embeddings_tokenizer, vision_model, llm_model)
        where ``vision_model`` and ``llm_model`` may be ``None`` when the
        corresponding API key was not supplied.
    """
    # Embedding model and its tokenizer, both loaded from the same checkpoint name.
    embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model)
    embeddings_tokenizer = AutoTokenizer.from_pretrained(embedding_model)

    # Gemini vision model (optional). NOTE: genai.configure is a module-level
    # call, so the key applies to the whole google.generativeai module, not
    # only to the model created here.
    vision_model = None
    if google_api_key:
        genai.configure(api_key=google_api_key)
        vision_model = genai.GenerativeModel(model_name=vision_model_name)

    # Groq chat LLM (optional). Kept in its own local so the ``llm_model``
    # argument (a model-name string) is not overwritten by the client object.
    llm = None
    if groq_api_key:
        llm = ChatGroq(model_name=llm_model, api_key=groq_api_key)

    return embeddings_model, embeddings_tokenizer, vision_model, llm