# Web-page residue from the Hugging Face file viewer (not part of app.py):
# FrancoMartino's picture
# Update app.py
# 2cbffb5 verified
# raw
# history blame
# 2.33 kB
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch
import gradio as gr
import requests
import bs4
from bs4 import BeautifulSoup
def get_text_from_url(url, timeout=10):
    """Download a web page and return the text extracted from its HTML.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait for the server before giving up. Without
            a timeout ``requests.get`` may block forever on a stalled host.

    Returns:
        The text of the page as produced by ``BeautifulSoup.get_text()``,
        or the sentinel string ``'error'`` when the request cannot be made
        or the server answers with a status other than 200.
    """
    headers = {
        'Accept-Language': 'en-US,en;q=0.9',
    }
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        # Malformed URLs, DNS failures, refused connections and timeouts all
        # derive from RequestException; report them through the same
        # sentinel the caller already checks for instead of crashing.
        print("Error al obtener la página:", exc)
        return 'error'

    if response.status_code != 200:
        print("Error al obtener la página:", response.status_code)
        return 'error'

    soup = BeautifulSoup(response.content, 'html.parser')
    return soup.get_text()
# Classifier whose class-1 probability is reported as the risk score.
# Both the tokenizer and the model are loaded from the same checkpoint id so
# the two cannot drift apart if the id is ever changed.
classification_model_checkpoint = 'FrancoMartino/privacyPolicies_classification'
classification_tokenizer = AutoTokenizer.from_pretrained(classification_model_checkpoint)
classification_model = AutoModelForSequenceClassification.from_pretrained(classification_model_checkpoint)

# Summarizer used to shorten pages that exceed the classifier's token budget.
summarization_model_checkpoint = "facebook/bart-large-cnn"
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model_checkpoint)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_checkpoint)
# Token budget: pages longer than this are summarized first, and the
# classifier input is truncated to this many tokens.
_MAX_CLASSIFICATION_TOKENS = 4096


def predict(url):
    """Score the privacy policy found at *url*.

    The page text is fetched, summarized when it has more tokens than
    ``_MAX_CLASSIFICATION_TOKENS``, and then classified.

    Args:
        url: Address of the privacy-policy page.

    Returns:
        ``{'Risk Indicator': p}`` where ``p`` is the softmax probability of
        class index 1, or ``{'ERROR': 'Error with the url'}`` when the URL
        is blank or the page could not be fetched.
    """
    url = (url or '').strip()
    if not url:
        # A blank URL can never be fetched; answer through the existing
        # error path rather than letting the HTTP layer raise.
        return {'ERROR': 'Error with the url'}

    text = get_text_from_url(url)
    if text == 'error':
        return {'ERROR': 'Error with the url'}

    if len(classification_tokenizer.tokenize(text)) > _MAX_CLASSIFICATION_TOKENS:
        print('long')
        # NOTE(review): truncation=True keeps only the leading part of the
        # page that fits the summarizer's input limit, so the summary may
        # not cover the whole policy -- confirm this is acceptable.
        inputs = summarization_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        # Do not request more tokens than the summarizer declares positions
        # for; presumably generating past that limit fails -- verify against
        # the model documentation.
        summary_limit = min(
            _MAX_CLASSIFICATION_TOKENS,
            getattr(
                summarization_model.config,
                'max_position_embeddings',
                _MAX_CLASSIFICATION_TOKENS,
            ),
        )
        with torch.no_grad():
            # Forward the attention mask together with the input ids.
            summary_ids = summarization_model.generate(**inputs, max_length=summary_limit)
        text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    inputs = classification_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=_MAX_CLASSIFICATION_TOKENS,
    )
    with torch.no_grad():
        logits = classification_model(**inputs).logits
    probabilities = torch.softmax(logits, dim=1)
    # One page is scored per call, so the batch holds a single row.
    prediction = probabilities[:, 1].item()
    return {'Risk Indicator': prediction}
# Sample pages offered as ready-made inputs in the UI.
examples_urls = [
    ["https://help.instagram.com/155833707900388"],
    ["https://www.apple.com/legal/privacy/en-ww/"],
]

# Web UI: a single text input (the URL) routed to predict(), whose result
# is rendered by a label output.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    examples=examples_urls,
    title="Privacy Policy Risk Indicator",
    description="Enter a privacy policy URL to calculate risk.",
)
interface.launch()