amine_dubs committed on
Commit
2a3ff67
·
1 Parent(s): 517b06c
Files changed (2) hide show
  1. backend/main.py +26 -7
  2. backend/requirements.txt +1 -0
backend/main.py CHANGED
@@ -70,22 +70,41 @@ def initialize_model():
70
  cache_dir="/tmp/transformers_cache"
71
  )
72
 
73
- # Check if TensorFlow is available
74
  tf_available = False
75
  try:
76
  import tensorflow
77
- tf_available = True
 
 
 
 
 
 
78
  print("TensorFlow is available, will use from_tf=True")
79
  except ImportError:
80
  print("TensorFlow is not installed, will use default PyTorch loading")
81
 
82
  # Load the model with appropriate settings based on TensorFlow availability
83
  print(f"Loading model {'with from_tf=True' if tf_available else 'with default PyTorch settings'}...")
84
- model = AutoModelForSeq2SeqLM.from_pretrained(
85
- model_name,
86
- from_tf=tf_available, # Only set True if TensorFlow is available
87
- cache_dir="/tmp/transformers_cache"
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Create a pipeline with the loaded model and tokenizer
91
  print("Creating pipeline with pre-loaded model...")
 
70
  cache_dir="/tmp/transformers_cache"
71
  )
72
 
73
+ # Check if TensorFlow and tf-keras are available
74
  tf_available = False
75
  try:
76
  import tensorflow
77
+ # Try to import tf_keras which is the compatibility package
78
+ try:
79
+ import tf_keras
80
+ print("tf-keras is installed, using TensorFlow with compatibility layer")
81
+ tf_available = True
82
+ except ImportError:
83
+ print("tf-keras not found, will try to use PyTorch backend")
84
  print("TensorFlow is available, will use from_tf=True")
85
  except ImportError:
86
  print("TensorFlow is not installed, will use default PyTorch loading")
87
 
88
  # Load the model with appropriate settings based on TensorFlow availability
89
  print(f"Loading model {'with from_tf=True' if tf_available else 'with default PyTorch settings'}...")
90
+ try:
91
+ # First try with PyTorch approach which is more reliable
92
+ model = AutoModelForSeq2SeqLM.from_pretrained(
93
+ model_name,
94
+ from_tf=False, # Use PyTorch first
95
+ cache_dir="/tmp/transformers_cache"
96
+ )
97
+ except Exception as e:
98
+ print(f"PyTorch loading failed: {e}")
99
+ if tf_available:
100
+ print("Attempting to load with TensorFlow...")
101
+ model = AutoModelForSeq2SeqLM.from_pretrained(
102
+ model_name,
103
+ from_tf=True,
104
+ cache_dir="/tmp/transformers_cache"
105
+ )
106
+ else:
107
+ raise # Re-raise if we can't use TensorFlow either
108
 
109
  # Create a pipeline with the loaded model and tokenizer
110
  print("Creating pipeline with pre-loaded model...")
backend/requirements.txt CHANGED
@@ -9,3 +9,4 @@ transformers
9
  torch
10
  sentencepiece
11
  tensorflow
 
 
9
  torch
10
  sentencepiece
11
  tensorflow
12
+ tf-keras