juanelot commited on
Commit
609cccb
verified
1 Parent(s): 2ac0298

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -2,25 +2,36 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  # Cargar el modelo y el tokenizador
5
- model_name = "microsoft/DialoGPT-medium" # Puedes cambiar esto por otro modelo de chatbot
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
- def chatbot(input, history=[]):
 
 
 
10
  # Agregar el input del usuario al historial
11
- history.append(input)
 
 
 
 
 
 
 
12
 
13
  # Tokenizar la conversaci贸n
14
- input_ids = tokenizer.encode(" ".join(history), return_tensors="pt")
15
 
16
  # Generar una respuesta
17
- output = model.generate(input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
 
18
 
19
  # Decodificar la respuesta
20
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
21
 
22
  # Agregar la respuesta al historial
23
- history.append(response)
24
 
25
  return history, history
26
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  # Cargar el modelo y el tokenizador
5
+ model_name = "microsoft/DialoGPT-medium"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
+ def chatbot(input, history):
10
+ # Aseg煤rate de que history sea una lista de listas
11
+ history = history or []
12
+
13
  # Agregar el input del usuario al historial
14
+ history.append([input, None])
15
+
16
+ # Preparar el contexto para el modelo
17
+ chat_history = []
18
+ for human, ai in history:
19
+ chat_history.append(human)
20
+ if ai:
21
+ chat_history.append(ai)
22
 
23
  # Tokenizar la conversaci贸n
24
+ input_ids = tokenizer.encode(" ".join(chat_history), return_tensors="pt")
25
 
26
  # Generar una respuesta
27
+ attention_mask = input_ids.new_ones(input_ids.shape)
28
+ output = model.generate(input_ids, attention_mask=attention_mask, max_length=1000, pad_token_id=tokenizer.eos_token_id)
29
 
30
  # Decodificar la respuesta
31
  response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
32
 
33
  # Agregar la respuesta al historial
34
+ history[-1][1] = response
35
 
36
  return history, history
37