"""Gradio app that scores a privacy policy fetched from a URL.

The page text is downloaded and, when it is too long for the classifier,
condensed with a BART summarizer. It is then scored by a sequence-classification
model, and the softmax probability of class index 1 is reported as the
"Risk Indicator".
"""
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch
import gradio as gr
import requests
import bs4
from bs4 import BeautifulSoup

# Seconds to wait for the remote server. Without a timeout, requests.get can
# block the Gradio worker indefinitely.
REQUEST_TIMEOUT_SECONDS = 30

# Token budget used both to decide when to summarize and to truncate the
# classifier input.
MAX_CLASSIFICATION_TOKENS = 4096


def get_text_from_url(url):
    """Fetch *url* and return the text content of the HTML page.

    Args:
        url: Address of the page to download.

    Returns:
        The page text extracted by BeautifulSoup, or the sentinel string
        ``'error'`` if the request fails or the server does not answer 200.
    """
    headers = {
        'Accept-Language': 'en-US,en;q=0.9',
    }
    try:
        response = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS)
    except requests.exceptions.RequestException as err:
        # Malformed or empty URLs, DNS failures, refused connections and
        # timeouts all end up here. Report them through the same sentinel
        # used for HTTP errors instead of crashing the request handler.
        print("Error al obtener la página:", err)
        return 'error'
    if response.status_code != 200:
        print("Error al obtener la página:", response.status_code)
        return 'error'
    soup = BeautifulSoup(response.content, 'html.parser')
    return soup.get_text()


classification_model_checkpoint = 'FrancoMartino/privacyPolicies_classification'
classification_tokenizer = AutoTokenizer.from_pretrained(classification_model_checkpoint)
classification_model = AutoModelForSequenceClassification.from_pretrained(classification_model_checkpoint)

summarization_model_checkpoint = "facebook/bart-large-cnn"
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model_checkpoint)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_checkpoint)

# The summarizer cannot generate beyond its positional-embedding limit
# (1024 for facebook/bart-large-cnn), so cap the requested summary length.
SUMMARY_MAX_LENGTH = min(
    MAX_CLASSIFICATION_TOKENS,
    getattr(summarization_model.config, 'max_position_embeddings', MAX_CLASSIFICATION_TOKENS),
)


def predict(url):
    """Return the privacy-risk score for the policy published at *url*.

    Args:
        url: Address of the privacy-policy page entered in the Gradio textbox.

    Returns:
        ``{'Risk Indicator': p}``, where ``p`` is the classifier's softmax
        probability for class index 1. Returns ``{'ERROR': 'Error with the url'}``
        when the page could not be fetched.
    """
    text = get_text_from_url(url)
    if text == 'error':
        return {'ERROR': 'Error with the url'}

    if len(classification_tokenizer.tokenize(text)) > MAX_CLASSIFICATION_TOKENS:
        # Too long for the classifier: condense it with the summarizer first.
        print('long')
        inputs = summarization_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            # Pass the attention mask along with input_ids, and keep the
            # output length within the summarizer's positional-embedding range.
            summary_ids = summarization_model.generate(**inputs, max_length=SUMMARY_MAX_LENGTH)
        text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    inputs = classification_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_CLASSIFICATION_TOKENS,
    )
    with torch.no_grad():
        logits = classification_model(**inputs).logits
    probabilities = torch.softmax(logits, dim=1)
    # NOTE(review): class index 1 is presumably the "risky" label; confirm
    # against the model's id2label mapping.
    prediction = probabilities[:, 1].item()
    return {'Risk Indicator': prediction}


examples_urls = [
    ["https://help.instagram.com/155833707900388"],
    ["https://www.apple.com/legal/privacy/en-ww/"],
]

interface = gr.Interface(
    fn=predict,
    inputs="text",
    examples=examples_urls,
    outputs="label",
    title="Privacy Policy Risk Indicator",
    description="Enter a privacy policy URL to calculate risk.",
)
interface.launch()