FrancoMartino committed on
Commit
2cbffb5
·
verified ·
1 Parent(s): 0829d46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -1,10 +1,11 @@
1
- from transformers import LongformerTokenizerFast, LongformerForSequenceClassification
 
2
  import torch
3
  import gradio as gr
4
  import requests
 
5
  from bs4 import BeautifulSoup
6
 
7
-
8
  def get_text_from_url(url):
9
 
10
  headers = {
@@ -20,14 +21,26 @@ def get_text_from_url(url):
20
  else:
21
 
22
  print("Error al obtener la página:", response.status_code)
23
- return None
24
 
25
  classification_model_checkpoint = 'FrancoMartino/privacyPolicies_classification'
26
- classification_tokenizer = LongformerTokenizerFast.from_pretrained(classification_model_checkpoint)
27
- classification_model = LongformerForSequenceClassification.from_pretrained(classification_model_checkpoint)
 
 
 
 
28
 
29
  def predict(url):
30
  text = get_text_from_url(url)
 
 
 
 
 
 
 
 
31
  inputs = classification_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=4096)
32
  with torch.no_grad():
33
  logits = classification_model(**inputs).logits
 
1
+ import transformers
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
3
  import torch
4
  import gradio as gr
5
  import requests
6
+ import bs4
7
  from bs4 import BeautifulSoup
8
 
 
9
  def get_text_from_url(url):
10
 
11
  headers = {
 
21
  else:
22
 
23
  print("Error al obtener la página:", response.status_code)
24
+ return 'error'
25
 
26
  classification_model_checkpoint = 'FrancoMartino/privacyPolicies_classification'
27
+ classification_tokenizer = AutoTokenizer.from_pretrained("FrancoMartino/privacyPolicies_classification")
28
+ classification_model = AutoModelForSequenceClassification.from_pretrained("FrancoMartino/privacyPolicies_classification")
29
+
30
+ summarization_model_checkpoint = "facebook/bart-large-cnn"
31
+ summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model_checkpoint)
32
+ summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_checkpoint)
33
 
34
  def predict(url):
35
  text = get_text_from_url(url)
36
+ if text == 'error':
37
+ return {'ERROR': 'Error with the url'}
38
+ if len(classification_tokenizer.tokenize(text)) > 4096:
39
+ print('long')
40
+ inputs = summarization_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
41
+ with torch.no_grad():
42
+ summary_ids = summarization_model.generate(inputs['input_ids'], max_length=4096)
43
+ text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
44
  inputs = classification_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=4096)
45
  with torch.no_grad():
46
  logits = classification_model(**inputs).logits