asmashayea committed on
Commit
ca6eb6e
·
1 Parent(s): a2b7e8f
Files changed (1) hide show
  1. seq2seq_inference.py +34 -10
seq2seq_inference.py CHANGED
@@ -1,5 +1,5 @@
1
- import json
2
  import torch
 
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from peft import PeftModel
5
 
@@ -18,10 +18,38 @@ SYSTEM_PROMPT = (
18
  )
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- def infer_t5_prompt(review_text, tokenizer, peft_model):
23
- prompt = SYSTEM_PROMPT + f"\n\nReview: {review_text}"
 
 
24
 
 
 
 
 
 
25
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(peft_model.device)
26
 
27
  with torch.no_grad():
@@ -36,15 +64,11 @@ def infer_t5_prompt(review_text, tokenizer, peft_model):
36
  eos_token_id=tokenizer.eos_token_id,
37
  )
38
 
39
- decoded = tokenizer.decode(
40
- outputs[0],
41
- skip_special_tokens=True,
42
- clean_up_tokenization_spaces=False
43
- ).strip()
44
-
45
  decoded = decoded.replace('<extra_id_0>', '').replace('</s>', '').strip()
46
 
47
  try:
48
  return json.loads(decoded)
49
  except json.JSONDecodeError:
50
- return decoded
 
 
 
1
  import torch
2
+ import json
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from peft import PeftModel
5
 
 
18
  )
19
 
20
 
21
+ MODEL_OPTIONS = {
22
+ "Araberta": {
23
+ "base": "asmashayea/absa-araberta",
24
+ "adapter": "asmashayea/absa-araberta"
25
+ },
26
+ "mT5": {
27
+ "base": "google/mt5-base",
28
+ "adapter": "asmashayea/mt4-absa"
29
+ },
30
+ "mBART": {
31
+ "base": "facebook/mbart-large-50-many-to-many-mmt",
32
+ "adapter": "asmashayea/mbart-absa"
33
+ },
34
+ }
35
+
36
+ cached_models = {}
37
+
38
# ✅ Reusable for both mT5 + mBART
def load_mt5_bart(model_key):
    """Return (tokenizer, peft_model) for *model_key*, loading at most once.

    *model_key* must be a key of MODEL_OPTIONS (e.g. "mT5", "mBART").
    The tokenizer is loaded from the adapter repo, the base seq2seq model
    from the base repo, and the PEFT adapter is applied on top. The pair
    is memoized in ``cached_models``.
    """
    # BUG FIX: the original populated cached_models but never consulted it,
    # so every call re-downloaded and rebuilt the model. Check the cache first.
    if model_key in cached_models:
        return cached_models[model_key]

    base_id = MODEL_OPTIONS[model_key]["base"]
    adapter_id = MODEL_OPTIONS[model_key]["adapter"]

    tokenizer = AutoTokenizer.from_pretrained(adapter_id)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id)
    peft_model = PeftModel.from_pretrained(base_model, adapter_id)
    peft_model.eval()  # inference mode: disables dropout etc.

    cached_models[model_key] = (tokenizer, peft_model)
    return tokenizer, peft_model
50
+
51
+ def infer_t5_bart(text, tokenizer, peft_model):
52
+ prompt = SYSTEM_PROMPT + f"\n\nReview: {text}"
53
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(peft_model.device)
54
 
55
  with torch.no_grad():
 
64
  eos_token_id=tokenizer.eos_token_id,
65
  )
66
 
67
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
 
 
 
 
68
  decoded = decoded.replace('<extra_id_0>', '').replace('</s>', '').strip()
69
 
70
  try:
71
  return json.loads(decoded)
72
  except json.JSONDecodeError:
73
+ return {"raw_output": decoded, "error": "Invalid JSON"}
74
+