# Spaces: Runtime error
import transformers | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM | |
import torch | |
import gradio as gr | |
import requests | |
import bs4 | |
from bs4 import BeautifulSoup | |
def get_text_from_url(url):
    """Download a web page and return the text extracted from its HTML.

    Args:
        url: Address of the page to fetch.

    Returns:
        The text of the page as produced by BeautifulSoup's ``get_text()``,
        or the sentinel string ``'error'`` when the request cannot be
        completed or the server answers with a status other than 200.
    """
    headers = {
        'Accept-Language': 'en-US,en;q=0.9',
    }
    try:
        # A timeout keeps an unresponsive host from blocking the caller
        # indefinitely.
        response = requests.get(url, headers=headers, timeout=30)
    except requests.RequestException as exc:
        # Malformed URLs, DNS failures, refused connections and timeouts are
        # reported through the same sentinel as a bad HTTP status, so the
        # caller only has one failure value to check.
        print("Error al obtener la página:", exc)
        return 'error'

    if response.status_code != 200:
        print("Error al obtener la página:", response.status_code)
        return 'error'

    soup = BeautifulSoup(response.content, 'html.parser')
    return soup.get_text()
# Hub identifiers of the two pretrained checkpoints loaded below.
classification_model_checkpoint = 'FrancoMartino/privacyPolicies_classification'
summarization_model_checkpoint = "facebook/bart-large-cnn"

# Sequence-classification model and its tokenizer; predict() reads the
# probability of class index 1 from this model's logits.
classification_tokenizer = AutoTokenizer.from_pretrained(
    classification_model_checkpoint
)
classification_model = AutoModelForSequenceClassification.from_pretrained(
    classification_model_checkpoint
)

# Seq2seq model and its tokenizer; predict() uses them to shorten documents
# that exceed the classifier's token budget.
summarization_tokenizer = AutoTokenizer.from_pretrained(
    summarization_model_checkpoint
)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
    summarization_model_checkpoint
)
# Longest token sequence handed to the classifier; longer documents are
# summarized first.
_CLASSIFIER_MAX_TOKENS = 4096


def predict(url):
    """Fetch the page at ``url`` and score its text with the classifier.

    Args:
        url: Address of the page to analyse.

    Returns:
        ``{'Risk Indicator': p}`` where ``p`` is the softmax probability of
        class index 1 for the page text, or ``{'ERROR': 'Error with the url'}``
        when the page could not be retrieved.
    """
    text = get_text_from_url(url)
    if text == 'error':
        # NOTE(review): this dict is sent to a "label" output, which normally
        # receives numeric confidences -- confirm that a string value renders
        # as intended in the deployed Gradio version.
        return {'ERROR': 'Error with the url'}

    if len(classification_tokenizer.tokenize(text)) > _CLASSIFIER_MAX_TOKENS:
        print('long')
        # Bound the summarizer's input and output explicitly by the size of
        # its position embeddings instead of relying on the tokenizer having
        # a configured maximum length; truncation without an explicit limit
        # may otherwise leave the input longer than the model accepts.
        summarizer_limit = getattr(
            summarization_model.config, 'max_position_embeddings', 1024
        )
        inputs = summarization_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=summarizer_limit,
        )
        with torch.no_grad():
            summary_ids = summarization_model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=summarizer_limit,
            )
        # From here on the summary replaces the full document.
        text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    inputs = classification_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=_CLASSIFIER_MAX_TOKENS,
    )
    with torch.no_grad():
        logits = classification_model(**inputs).logits
    probabilities = torch.softmax(logits, dim=1)
    # A single document is scored, so column 1 holds exactly one value.
    prediction = probabilities[:, 1].item()
    return {'Risk Indicator': prediction}
# URLs offered as ready-made examples in the interface.
examples_urls = [
    ["https://help.instagram.com/155833707900388"],
    ["https://www.apple.com/legal/privacy/en-ww/"],
]

# One text input, one label output; predict() handles each submitted URL.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    examples=examples_urls,
    title="Privacy Policy Risk Indicator",
    description="Enter a privacy policy URL to calculate risk.",
)
interface.launch()