# Page header captured together with the file (not part of the program):
# FrancoMartino's picture / Update app.py / 2cbffb5 verified
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch
import gradio as gr
import requests
import bs4
from bs4 import BeautifulSoup
def get_text_from_url(url, timeout=10):
    """Download a web page and return the text extracted from its HTML.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot block the caller indefinitely.

    Returns:
        The text produced by BeautifulSoup's ``get_text()`` for the page, or
        the string ``'error'`` when the request fails (network problem,
        malformed URL, timeout) or the response status is not 200.
    """
    headers = {
        'Accept-Language': 'en-US,en;q=0.9',
    }
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors, timeouts and invalid URLs are reported through
        # the same 'error' sentinel as a bad status code, because that is the
        # only failure signal the caller checks for.
        print("Error al obtener la página:", exc)
        return 'error'

    if response.status_code != 200:
        print("Error al obtener la página:", response.status_code)
        return 'error'

    soup = BeautifulSoup(response.content, 'html.parser')
    return soup.get_text()
# Sequence-classification checkpoint; its tokenizer and model are loaded once
# at import time and shared by every call to predict().
classification_model_checkpoint = 'FrancoMartino/privacyPolicies_classification'
classification_tokenizer = AutoTokenizer.from_pretrained(
    classification_model_checkpoint
)
classification_model = AutoModelForSequenceClassification.from_pretrained(
    classification_model_checkpoint
)

# Seq2seq checkpoint used to shorten texts before classification.
summarization_model_checkpoint = "facebook/bart-large-cnn"
summarization_tokenizer = AutoTokenizer.from_pretrained(
    summarization_model_checkpoint
)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
    summarization_model_checkpoint
)
def predict(url):
    """Score the page at *url* with the classification model.

    The page text is fetched with ``get_text_from_url``. If it tokenizes to
    more than 4096 classification tokens, it is first replaced by a summary
    generated with the summarization model, and the summary is classified
    instead.

    Args:
        url: Address of the page to analyse.

    Returns:
        ``{'Risk Indicator': p}`` where ``p`` is the softmax probability of
        class index 1, or ``{'ERROR': 'Error with the url'}`` when the page
        text could not be retrieved.
    """
    max_tokens = 4096

    text = get_text_from_url(url)
    if text == 'error':
        return {'ERROR': 'Error with the url'}

    if len(classification_tokenizer.tokenize(text)) > max_tokens:
        print('long')
        summary_inputs = summarization_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True
        )
        # Never request more tokens than the summarizer has position
        # embeddings for; fall back to the original limit when the config
        # does not expose that attribute.
        summary_limit = min(
            max_tokens,
            getattr(summarization_model.config, 'max_position_embeddings', max_tokens),
        )
        with torch.no_grad():
            # Pass the attention mask explicitly so generation does not have
            # to infer it from the input ids.
            summary_ids = summarization_model.generate(
                summary_inputs['input_ids'],
                attention_mask=summary_inputs['attention_mask'],
                max_length=summary_limit,
            )
        text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    inputs = classification_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_tokens,
    )
    with torch.no_grad():
        logits = classification_model(**inputs).logits

    # Column 1 of the softmax over the class dimension is the reported score.
    probabilities = torch.softmax(logits, dim=1)
    prediction = probabilities[:, 1].item()
    return {'Risk Indicator': prediction}
# Example inputs offered in the UI; each inner list is one example row.
examples_urls = [
    ["https://help.instagram.com/155833707900388"],
    ["https://www.apple.com/legal/privacy/en-ww/"],
]

# Single text box in, label component out, backed by predict().
interface = gr.Interface(
    fn=predict,
    inputs="text",
    examples=examples_urls,
    outputs="label",
    title="Privacy Policy Risk Indicator",
    description="Enter a privacy policy URL to calculate risk.",
)
interface.launch()