Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +6 -19
modeling_fastesm.py
CHANGED
|
@@ -749,35 +749,22 @@ class FastEsmModel(FastEsmPreTrainedModel):
|
|
| 749 |
else:
|
| 750 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 751 |
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
input_ids=input_ids,
|
| 755 |
-
position_ids=position_ids,
|
| 756 |
attention_mask=attention_mask,
|
|
|
|
| 757 |
inputs_embeds=inputs_embeds,
|
| 758 |
-
)
|
| 759 |
-
|
| 760 |
-
if attention_mask is not None:
|
| 761 |
-
extended_attention_mask = attention_mask[:, None, None, :].expand(
|
| 762 |
-
batch_size, 1, seq_length, seq_length
|
| 763 |
-
).bool()
|
| 764 |
-
else:
|
| 765 |
-
extended_attention_mask = None
|
| 766 |
-
|
| 767 |
-
encoder_outputs = self.encoder(
|
| 768 |
-
embedding_output,
|
| 769 |
-
attention_mask=extended_attention_mask,
|
| 770 |
output_hidden_states=output_hidden_states,
|
| 771 |
output_attentions=output_attentions,
|
| 772 |
)
|
| 773 |
-
sequence_output =
|
| 774 |
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 775 |
|
| 776 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 777 |
last_hidden_state=sequence_output,
|
| 778 |
pooler_output=pooled_output,
|
| 779 |
-
hidden_states=
|
| 780 |
-
attentions=
|
| 781 |
)
|
| 782 |
|
| 783 |
|
|
|
|
| 749 |
else:
|
| 750 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 751 |
|
| 752 |
+
outputs = self.esm(
|
| 753 |
+
input_ids,
|
|
|
|
|
|
|
| 754 |
attention_mask=attention_mask,
|
| 755 |
+
position_ids=position_ids,
|
| 756 |
inputs_embeds=inputs_embeds,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
output_hidden_states=output_hidden_states,
|
| 758 |
output_attentions=output_attentions,
|
| 759 |
)
|
| 760 |
+
sequence_output = outputs.last_hidden_state
|
| 761 |
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 762 |
|
| 763 |
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 764 |
last_hidden_state=sequence_output,
|
| 765 |
pooler_output=pooled_output,
|
| 766 |
+
hidden_states=outputs.hidden_states,
|
| 767 |
+
attentions=outputs.attentions,
|
| 768 |
)
|
| 769 |
|
| 770 |
|