THeaxxxxxxxx committed
Commit 41370ea · verified · 1 Parent(s): d35c94b

Update app.py

Files changed (1): app.py +5 -9
app.py CHANGED
@@ -26,32 +26,28 @@ def load_pipelines():
         "branch service", "transaction delay", "account closure", "information error"
     ]
 
-    device = 0 if torch.cuda.is_available() else -1
-    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    dtype = torch.float32
 
     topic_classifier = pipeline(
         "zero-shot-classification",
         model="MoritzLaurer/deberta-v3-base-zeroshot-v1",
-        device=device,
-        torch_dtype=dtype
     )
 
     # Sentiment Analysis Model
     sentiment_classifier = pipeline(
         "sentiment-analysis",
-        model="cardiffnlp/twitter-roberta-base-sentiment-latest",  # Best performing model from notebook
-        device=device
+        model="cardiffnlp/twitter-roberta-base-sentiment-latest",
     )
 
     # Reply Generation Model
-    model_name = "Leo66277/finetuned-tinyllama-customer-replies"  # Using the fine-tuned model in the notebook
+    model_name = "Leo66277/finetuned-tinyllama-customer-replies"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
-    model.to(device)
+
 
     def generate_reply(text):
         prompt_text = f"Please write a short, polite English customer service reply to the following customer comment:\n{text}"
-        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512).to(device)
+        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
 
         with torch.no_grad():
             gen_ids = model.generate(
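
For orientation, here is a minimal sketch of how the patched, CPU-only load_pipelines plausibly reads in full. Everything outside the hunk is an assumption: the imports, the name and earlier entries of the label list, the arguments to generate() past the visible truncation, and the return statement.

# Sketch only; names and values not visible in the diff are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

def load_pipelines():
    # Hypothetical name; the hunk shows only the tail of this list.
    candidate_labels = [
        "branch service", "transaction delay", "account closure", "information error"
    ]

    # Added by this commit; unused within the hunk (torch_dtype=dtype was
    # removed from the pipeline call), so it may serve code outside the hunk.
    dtype = torch.float32

    # With no device= argument, transformers pipelines default to CPU.
    topic_classifier = pipeline(
        "zero-shot-classification",
        model="MoritzLaurer/deberta-v3-base-zeroshot-v1",
    )

    # Sentiment Analysis Model
    sentiment_classifier = pipeline(
        "sentiment-analysis",
        model="cardiffnlp/twitter-roberta-base-sentiment-latest",
    )

    # Reply Generation Model
    model_name = "Leo66277/finetuned-tinyllama-customer-replies"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate_reply(text):
        prompt_text = f"Please write a short, polite English customer service reply to the following customer comment:\n{text}"
        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)

        with torch.no_grad():
            gen_ids = model.generate(
                **inputs,
                max_new_tokens=128,  # assumed; the diff truncates here
            )
        return tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    # Assumed return values; candidate_labels would be passed to
    # topic_classifier at call time, e.g. topic_classifier(text, candidate_labels).
    return topic_classifier, sentiment_classifier, candidate_labels, generate_reply

Under these assumptions the reply model runs in float32 on CPU, consistent with the commit removing model.to(device), the .to(device) on the tokenized inputs, and the fp16/CUDA branch.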