asmashayea committed on
Commit
5825fbe
·
1 Parent(s): e90dd4b
Files changed (1) hide show
  1. inference.py +25 -0
inference.py CHANGED
@@ -7,6 +7,31 @@ from seq2seq_inference import infer_t5_prompt, infer_mBart_prompt
7
  from peft import LoraConfig, get_peft_model, PeftModel
8
  from modeling_bilstm_crf import BERT_BiLSTM_CRF
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  cached_models = {}
11
 
12
  def load_araberta():
 
7
  from peft import LoraConfig, get_peft_model, PeftModel
8
  from modeling_bilstm_crf import BERT_BiLSTM_CRF
9
 
10
# Supported models: maps a UI display name to the Hugging Face repo ids
# for the base checkpoint and its fine-tuned (LoRA/PEFT) adapter.
MODEL_OPTIONS = {
    "Araberta": {
        "base": "asmashayea/absa-araberta",
        # Base and adapter live in the same repo for this model.
        "adapter": "asmashayea/absa-araberta",
    },
    "mT5": {
        "base": "google/mt5-base",
        "adapter": "asmashayea/mt4-absa",
    },
    # mBART is currently disabled; kept here for reference.
    # "mBART": {
    #     "base": "facebook/mbart-large-50-many-to-many-mmt",
    #     "adapter": "asmashayea/mbart-absa"
    # },
    "GPT3.5": {
        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
        "adapter": "asmashayea/gpt-absa",
    },
    "GPT4o": {
        "base": "bigscience/bloom-560m",  # example, not ideal for ABSA
        "adapter": "asmashayea/gpt-absa",
    },
}

# Lazy in-memory cache of loaded models, keyed by option name.
cached_models = {}
36
 
37
  def load_araberta():