amine_dubs
commited on
Commit
·
2a3ff67
1
Parent(s):
517b06c
erroe
Browse files- backend/main.py +26 -7
- backend/requirements.txt +1 -0
backend/main.py
CHANGED
@@ -70,22 +70,41 @@ def initialize_model():
|
|
70 |
cache_dir="/tmp/transformers_cache"
|
71 |
)
|
72 |
|
73 |
-
# Check if TensorFlow
|
74 |
tf_available = False
|
75 |
try:
|
76 |
import tensorflow
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
print("TensorFlow is available, will use from_tf=True")
|
79 |
except ImportError:
|
80 |
print("TensorFlow is not installed, will use default PyTorch loading")
|
81 |
|
82 |
# Load the model with appropriate settings based on TensorFlow availability
|
83 |
print(f"Loading model {'with from_tf=True' if tf_available else 'with default PyTorch settings'}...")
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Create a pipeline with the loaded model and tokenizer
|
91 |
print("Creating pipeline with pre-loaded model...")
|
|
|
70 |
cache_dir="/tmp/transformers_cache"
|
71 |
)
|
72 |
|
73 |
+
# Check if TensorFlow and tf-keras are available
|
74 |
tf_available = False
|
75 |
try:
|
76 |
import tensorflow
|
77 |
+
# Try to import tf_keras which is the compatibility package
|
78 |
+
try:
|
79 |
+
import tf_keras
|
80 |
+
print("tf-keras is installed, using TensorFlow with compatibility layer")
|
81 |
+
tf_available = True
|
82 |
+
except ImportError:
|
83 |
+
print("tf-keras not found, will try to use PyTorch backend")
|
84 |
print("TensorFlow is available, will use from_tf=True")
|
85 |
except ImportError:
|
86 |
print("TensorFlow is not installed, will use default PyTorch loading")
|
87 |
|
88 |
# Load the model with appropriate settings based on TensorFlow availability
|
89 |
print(f"Loading model {'with from_tf=True' if tf_available else 'with default PyTorch settings'}...")
|
90 |
+
try:
|
91 |
+
# First try with PyTorch approach which is more reliable
|
92 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
93 |
+
model_name,
|
94 |
+
from_tf=False, # Use PyTorch first
|
95 |
+
cache_dir="/tmp/transformers_cache"
|
96 |
+
)
|
97 |
+
except Exception as e:
|
98 |
+
print(f"PyTorch loading failed: {e}")
|
99 |
+
if tf_available:
|
100 |
+
print("Attempting to load with TensorFlow...")
|
101 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
102 |
+
model_name,
|
103 |
+
from_tf=True,
|
104 |
+
cache_dir="/tmp/transformers_cache"
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
raise # Re-raise if we can't use TensorFlow either
|
108 |
|
109 |
# Create a pipeline with the loaded model and tokenizer
|
110 |
print("Creating pipeline with pre-loaded model...")
|
backend/requirements.txt
CHANGED
@@ -9,3 +9,4 @@ transformers
|
|
9 |
torch
|
10 |
sentencepiece
|
11 |
tensorflow
|
|
|
|
9 |
torch
|
10 |
sentencepiece
|
11 |
tensorflow
|
12 |
+
tf-keras
|