path: root/modules/xlmr_m18.py
Diffstat (limited to 'modules/xlmr_m18.py')
-rw-r--r--  modules/xlmr_m18.py  12
1 file changed, 6 insertions, 6 deletions
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
index 18785692..a727e865 100644
--- a/modules/xlmr_m18.py
+++ b/modules/xlmr_m18.py
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel,BertConfig
 import torch.nn as nn
 import torch
 from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
     config_class = BertSeriesConfig
 
     def __init__(self, config=None, **kargs):
-        # modify initialization for autoloading
+        # modify initialization for autoloading
         if config is None:
             config = XLMRobertaConfig()
             config.attention_probs_dropout_prob= 0.1
@@ -80,7 +80,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
- return features['projection_state']
+ return features['projection_state']
def forward(
self,
@@ -147,8 +147,8 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
"hidden_states": outputs.hidden_states,
"attentions": outputs.attentions,
}
-
-
+
+
         # return {
         # 'pooler_output':pooler_output,
         # 'last_hidden_state':outputs.last_hidden_state,
@@ -161,4 +161,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
 
 class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
     base_model_prefix = 'roberta'
-    config_class= RobertaSeriesConfig
\ No newline at end of file
+    config_class= RobertaSeriesConfig
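
For orientation, a minimal sketch of how the class touched above can be exercised; it is not part of the commit. The config=None default path and the 'projection_state' key come from the hunks themselves, while the dummy tensors and the input_ids/attention_mask keyword names are assumptions based on the tokenizer output that encode() passes to self(**text).

# Sketch only (not part of this commit): drive the forward pass that backs
# encode()'s `return features['projection_state']` in the diff above.
import torch
from modules.xlmr_m18 import BertSeriesModelWithTransformation

# config=None takes the autoloading branch shown in __init__ (XLMRobertaConfig
# defaults); weights are randomly initialised, so this only checks shapes.
model = BertSeriesModelWithTransformation().eval()

# Dummy batch; real callers go through encode(), which tokenizes text first.
# The keyword names below are assumed from that tokenizer output.
input_ids = torch.tensor([[0, 5, 6, 7, 2]])
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    features = model(input_ids=input_ids, attention_mask=attention_mask)

print(features["projection_state"].shape)  # the tensor encode() returns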