Curative commited on
Commit
08f21ce
·
verified ·
1 Parent(s): 950842b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -14,7 +14,8 @@ def get_sentiment():
14
  def get_classifier():
15
  global classifier
16
  if not classifier:
17
- classifier = pipeline("text-classification",
 
18
  model="textattack/distilbert-base-uncased-ag-news")
19
  return classifier
20
 
@@ -43,8 +44,13 @@ def process(text, features):
43
  sent = get_sentiment()(text)[0]
44
  result["sentiment"] = {"label": sent["label"], "score": sent["score"]}
45
  if "Classification" in features:
46
- cls = get_classifier()(text)[0]
47
- result["classification"] = {"label": cls["label"], "score": cls["score"]}
 
 
 
 
 
48
  if "Entities" in features:
49
  ents = get_ner()(text)
50
  result["entities"] = [
 
14
  def get_classifier():
15
  global classifier
16
  if not classifier:
17
+ classifier = pipeline(
18
+ "zero-shot-classification",
19
  model="textattack/distilbert-base-uncased-ag-news")
20
  return classifier
21
 
 
44
  sent = get_sentiment()(text)[0]
45
  result["sentiment"] = {"label": sent["label"], "score": sent["score"]}
46
  if "Classification" in features:
47
+ candidate_labels = [
48
+ "technology", "sports", "business", "politics",
49
+ "health", "science", "travel", "entertainment"
50
+ ]
51
+ cls = get_classifier()(text, candidate_labels=candidate_labels)
52
+ # Map labels → scores
53
+ result["classification"] = dict(zip(cls["labels"], cls["scores"]))
54
  if "Entities" in features:
55
  ents = get_ner()(text)
56
  result["entities"] = [