Prathamesh1420 commited on
Commit
863f205
·
verified ·
1 Parent(s): d237c98

Create models.py

Browse files
Files changed (1) hide show
  1. utils/models.py +39 -0
utils/models.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_huggingface import HuggingFaceEmbeddings
2
+ from transformers import AutoTokenizer
3
+ from langchain_groq import ChatGroq
4
+ import google.generativeai as genai
5
+
6
+ def load_models(embedding_model="ibm-granite/granite-embedding-30m-english",
7
+ llm_model="llama3-70b-8192",
8
+ google_api_key=None,
9
+ groq_api_key=None):
10
+ """
11
+ Load all required models.
12
+
13
+ Args:
14
+ embedding_model: Name/path of the embedding model
15
+ llm_model: Name of the LLM model
16
+ google_api_key: API key for Google Gemini
17
+ groq_api_key: API key for Groq
18
+
19
+ Returns:
20
+ tuple: (embeddings_model, embeddings_tokenizer, vision_model, llm_model)
21
+ """
22
+ # Load embedding model and tokenizer
23
+ embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model)
24
+ embeddings_tokenizer = AutoTokenizer.from_pretrained(embedding_model)
25
+
26
+ # Initialize Gemini vision model
27
+ if google_api_key:
28
+ genai.configure(api_key=google_api_key)
29
+ vision_model = genai.GenerativeModel(model_name="gemini-1.5-flash")
30
+ else:
31
+ vision_model = None
32
+
33
+ # Initialize Groq LLM
34
+ if groq_api_key:
35
+ llm_model = ChatGroq(model_name=llm_model, api_key=groq_api_key)
36
+ else:
37
+ llm_model = None
38
+
39
+ return embeddings_model, embeddings_tokenizer, vision_model, llm_model