Antonio Toro Jaén commited on
Commit
1a1813d
·
1 Parent(s): 5be1c50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -24
app.py CHANGED
@@ -1,30 +1,20 @@
1
  import gradio as gr
2
  import torch
 
3
  import re
4
  import os
5
  import csv
6
- from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from huggingface_hub import login
8
- import spaces
9
-
10
- @spaces.GPU
11
- def confirm_gpu():
12
- import torch
13
- return torch.cuda.is_available()
14
 
15
  login(os.environ["HF_TOKEN"])
16
 
17
- model_name = "atorojaen/DeepSeekMisogyny"
18
-
19
- tokenizer = AutoTokenizer.from_pretrained(model_name)
20
- model = AutoModelForCausalLM.from_pretrained(
21
- model_name,
22
- torch_dtype=torch.float16,
23
- device_map="auto"
24
  )
25
-
26
- assert torch.cuda.is_available(), "CUDA no está disponible"
27
-
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  model.eval()
30
 
@@ -32,13 +22,18 @@ FLAG_FILE = "flags_data/flags.csv"
32
  os.makedirs(os.path.dirname(FLAG_FILE), exist_ok=True)
33
 
34
  def clean_lyrics(text):
 
35
  text = re.sub(r"[^a-zA-ZáéíóúñüÁÉÍÓÚÑÜ ]+", " ", text)
 
36
  text = text.lower()
 
37
  text = re.sub(r"\s+", " ", text).strip()
38
  return text
39
 
 
40
  def detect_misogyny(text):
41
  cleaned_text = clean_lyrics(text)
 
42
  prompt = """
43
  ### Instruccion
44
  Analiza la siguiente letra de canción y determina si contiene contenido misógino. Evalúa si incluye lenguaje, actitudes o mensajes que:
@@ -54,6 +49,7 @@ def detect_misogyny(text):
54
 
55
  ### Respuesta:
56
  <think>"""
 
57
  prompt = prompt.format(lyrics=text)
58
 
59
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -66,30 +62,34 @@ def detect_misogyny(text):
66
  )
67
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
 
 
69
  explanation_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
70
  explanation = explanation_match.group(1).strip() if explanation_match else ""
71
 
 
72
  label_match = re.search(r"</think>\s*(\d)", response)
73
  label = label_match.group(1) if label_match else ""
74
 
 
75
  return f"{explanation}\n\nRespuesta final: {label}" if explanation and label else response.strip()
76
 
 
77
  def save_flag(user_text, response, flag_type):
 
78
  with open(FLAG_FILE, mode="a", newline="", encoding="utf-8") as f:
79
  writer = csv.writer(f)
80
  writer.writerow([user_text, response, flag_type])
81
  return f"Guardado flag: {flag_type}"
82
 
83
  with gr.Blocks() as demo:
 
 
84
  user_input = gr.Textbox(label="Letra de canción", lines=10)
85
  result = gr.Textbox(label="Respuesta del modelo", lines=10)
86
-
87
  btn_analizar = gr.Button("Analizar")
88
- btn_correcto = gr.Button("Respuesta correcta")
89
- btn_incorrecto = gr.Button("Respuesta incorrecta")
90
-
91
  btn_analizar.click(fn=detect_misogyny, inputs=user_input, outputs=result)
92
- btn_correcto.click(fn=save_flag, inputs=[user_input, result, gr.State("correcto")], outputs=result)
93
- btn_incorrecto.click(fn=save_flag, inputs=[user_input, result, gr.State("incorrecto")], outputs=result)
94
 
95
- demo.launch(share=True, ssr_mode=False)
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import re
5
  import os
6
  import csv
 
7
  from huggingface_hub import login
8
+ from unsloth import FastLanguageModel
 
 
 
 
 
9
 
10
  login(os.environ["HF_TOKEN"])
11
 
12
+ model, tokenizer = FastLanguageModel.from_pretrained(
13
+ model_name = "atorojaen/DeepSeek-R1-MiSonGyny", # Modelo base
14
+ max_seq_length = 2048,
15
+ dtype = torch.float16,
16
+ load_in_4bit = True,
 
 
17
  )
 
 
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.eval()
20
 
 
22
  os.makedirs(os.path.dirname(FLAG_FILE), exist_ok=True)
23
 
24
  def clean_lyrics(text):
25
+ # Elimina caracteres no alfabéticos (excepto espacios y letras acentuadas comunes en español)
26
  text = re.sub(r"[^a-zA-ZáéíóúñüÁÉÍÓÚÑÜ ]+", " ", text)
27
+ # Convierte a minúsculas
28
  text = text.lower()
29
+ # Reduce espacios múltiples
30
  text = re.sub(r"\s+", " ", text).strip()
31
  return text
32
 
33
+ # Función de predicción
34
  def detect_misogyny(text):
35
  cleaned_text = clean_lyrics(text)
36
+ # Construir el prompt de entrada
37
  prompt = """
38
  ### Instruccion
39
  Analiza la siguiente letra de canción y determina si contiene contenido misógino. Evalúa si incluye lenguaje, actitudes o mensajes que:
 
49
 
50
  ### Respuesta:
51
  <think>"""
52
+
53
  prompt = prompt.format(lyrics=text)
54
 
55
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
62
  )
63
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
 
65
+ # Extraer explicación entre <think>...</think>
66
  explanation_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
67
  explanation = explanation_match.group(1).strip() if explanation_match else ""
68
 
69
+ # Extraer "1" o "0" después de </think>
70
  label_match = re.search(r"</think>\s*(\d)", response)
71
  label = label_match.group(1) if label_match else ""
72
 
73
+ # Combinar resultado final
74
  return f"{explanation}\n\nRespuesta final: {label}" if explanation and label else response.strip()
75
 
76
+
77
  def save_flag(user_text, response, flag_type):
78
+ # Guarda la entrada, salida y si fue correcta o incorrecta en CSV
79
  with open(FLAG_FILE, mode="a", newline="", encoding="utf-8") as f:
80
  writer = csv.writer(f)
81
  writer.writerow([user_text, response, flag_type])
82
  return f"Guardado flag: {flag_type}"
83
 
84
  with gr.Blocks() as demo:
85
+ gr.Markdown("# Detector de misoginia en letras de canciones") # Título principal
86
+ gr.Markdown("Este sistema analiza letras de canciones en español y detecta contenido misógino utilizando el modelo DeepSeek R1 entrenado.")
87
  user_input = gr.Textbox(label="Letra de canción", lines=10)
88
  result = gr.Textbox(label="Respuesta del modelo", lines=10)
89
+
90
  btn_analizar = gr.Button("Analizar")
91
+
 
 
92
  btn_analizar.click(fn=detect_misogyny, inputs=user_input, outputs=result)
 
 
93
 
94
+
95
+ demo.launch()