mipatov commited on
Commit
a0b0d25
·
1 Parent(s): f72f3a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -29,6 +29,7 @@ def predict_gpt(text, model, tokenizer, temperature=1.0):
29
  temperature= temperature,
30
  top_p=0.75,
31
  max_length=512,
 
32
  eos_token_id = tokenizer.eos_token_id,
33
  pad_token_id = tokenizer.pad_token_id,
34
  num_return_sequences = 1,
@@ -36,8 +37,8 @@ def predict_gpt(text, model, tokenizer, temperature=1.0):
36
  return_dict_in_generate=True,
37
  )
38
  decode = lambda x : tokenizer.decode(x, skip_special_tokens=True)
39
- generated_text = list(map(decode, out['sequences']))[0].split('Описание :')[1]
40
- return 'Описание : '+ generated_text
41
 
42
  def predict_t5(text, model, tokenizer, temperature=1.2):
43
  input_ids = tokenizer.encode(text, return_tensors="pt")
@@ -48,7 +49,7 @@ def predict_t5(text, model, tokenizer, temperature=1.2):
48
  temperature=temperature,
49
  top_p=0.35,
50
  max_length=512,
51
- length_penalty = 5.5,
52
  output_attentions = True,
53
  return_dict_in_generate=True,
54
  repetition_penalty = 2.5,
 
29
  temperature= temperature,
30
  top_p=0.75,
31
  max_length=512,
32
+ length_penalty = 5.5,
33
  eos_token_id = tokenizer.eos_token_id,
34
  pad_token_id = tokenizer.pad_token_id,
35
  num_return_sequences = 1,
 
37
  return_dict_in_generate=True,
38
  )
39
  decode = lambda x : tokenizer.decode(x, skip_special_tokens=True)
40
+ generated_text = list(map(decode, out['sequences']))[0].replace(text,'')
41
+ return generated_text
42
 
43
  def predict_t5(text, model, tokenizer, temperature=1.2):
44
  input_ids = tokenizer.encode(text, return_tensors="pt")
 
49
  temperature=temperature,
50
  top_p=0.35,
51
  max_length=512,
52
+ length_penalty = -1.0,
53
  output_attentions = True,
54
  return_dict_in_generate=True,
55
  repetition_penalty = 2.5,