BERT 各层详解

BERT Embedding

  • segment ID:0/1表示,区分是否是同一句话;第一句话全是0,第二句话全是1;”type_vocab_size”: 2
  • token ID:每个单词的索引;”vocab_size”: 21128
  • position ID: 绝对embedding;”max_position_embeddings”: 512

bert embedding = segment ID + token ID + position ID。 注意:三个ID经过embedding相加

为什么可以相加?

embedding
在原始one-hot输入向量上concat,经过变换,最终效果等于对原始向量先token embedding 加position embedding

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class BertEmbeddings(nn.Module):

def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] ##position embedding


if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) ##token_type_ids embedding

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids) ##input id embedding
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings) # LN+dropout
return embeddings

BERT Attention

$$
atten = softmax(\frac{QK^T}{\sqrt{(d_k)})})\cdot V
$$
其中$d_k = \frac{hidden_size}{num_heads}$;

为什么除以$d_k$:

如果向量的维度比较大,那么qk点积之后的结果也会比较大,这些数的数量积都会比较大,如果没有经过缩放的话,softmax很有可能就剩下[0, 0, 1]的结果了。为什么非要将均值和方差拉到0和1呢?这是ICS内部协变量偏移问题:机器学习都有一个前提假设那就是数据符合标准正态分布的,当到qk的时候,就发生了变化,因为点积的操作分布就不再是标准正态分布了,也会影响后续所有数据的分布。除以d_k,拉回标准正态分布。由于softmax的马太效应,在求偏导计算梯度的时候,梯度值为0,导致参数无法更新,即梯度消失。经过缩放之后,就不再是0或是1了,梯度值就能够正常的进行参数的更新。

为什么多头?

多头保证了transformer可以注意到不同子空间的信息,捕捉到更加丰富的特征信息。实验结果好。

Attention mask

$attention \quad mask = [1,1,1,\cdots,0,0]$

Padding mask

$padding \quad mask = [1,1,1,\cdots,0,0]$
将padding项变成1,其它项变成0,将qk的结果和padding mask相加;相加的时候,padding为1的位置会变成一个很小的数$-10^9$,这样softmax之后为0

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class BertSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.num_attention_heads = config.num_attention_heads #12
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) #768/12 = 64
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = nn.Linear(config.hidden_size, self.all_head_size) #Q 768x768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # K 768x768
self.value = nn.Linear(config.hidden_size, self.all_head_size) #V 768x768

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states)) # b x s x 768 -> b x 12 x s x 64
value_layer = self.transpose_for_scores(self.value(hidden_states)) # b x s x 768 -> b x 12 x s x 64

query_layer = self.transpose_for_scores(mixed_query_layer) # b x s x 768 -> b x 12 x s x 64


# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) #QK^T b x 12 x s x s


attention_scores = attention_scores / math.sqrt(self.attention_head_size) # QK^T/sqrt(64)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask # attention mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs) #dropout

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer) # * V

context_layer = context_layer.permute(0, 2, 1, 3).contiguous() #b x 12 x s x 64 -> b x s x 12 x 64
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)#b x s x 12 x 64 -> b x s x 768

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs

BERT Output

集成了linear+dropout+LN

1
2
3
4
5
6
7
8
9
10
11
12
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor) #残差链接
return hidden_states

为什么使用LN

所以BatchNorm会受到Batch size的影响; 当Batchsize小的时候效果往往不是非常稳定.
在nlp领域使用BN效果不好,BN的计算方式是以一个batch中的样本数据去计算它的均值和方差的。
这样计算它是有padding的影响的,并且代表不了整个的均值和方差。在刚刚小明和小红例子中,身高这个特征所在的列含义是相同的。但是在nlp里,第一行是”我“字的emb词向量,第二行是”宣“字的emb词向量,经过attention之后,形成的是包含语义信息的向量,每一列代表的含义并不相同了。不能说每个词向量的每个维度代表的含义是一样的。因此从这个⻆度来理解,这里采用BN是不合适的。所以LN用的是比较多的。以每个词向量去做标准化就可以了,这样就不会引入padding不相关的信息。batch_size太小时,一个batch的样本,其均值和方差,不足以代表总体样本的均值与方差。NLP领域不适合用BN

BERT Intermediate

集成了linear(768-3072)+GELU激活函数
$$
GELU(x) = 0.5x(1+tanh(\alpha(x+\beta x^3)))
$$
GELU函数可以能避免梯度消失问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states

AdamW

具体见优化器blog

wordpiece

BPE(Byte-Pair Encoding)BPE的大概训练过程:按照从左到右的顺序,将一个词拆分成多个子词,每个子词尽可能长。 greedy longest-match-first algorithm,贪婪最长优先匹配算法。
BPE

SELU激活函数

内部归一化的速度比外部归一化快,这意味着网络能更快收敛;
不可能出现梯度消失或爆炸问题,见 SELU 论文附录的定理 2 和 3。
$$
SELU(x) = \lambda x \quad if \quad x>0
$$
$$
SELU(x) = \lambda (\alpha e^x-\alpha) \quad if \quad x<=0
$$

进阶

Warm up

在训练初期使用较小的学习率(从 0 开始),在一定步数(比如 1000 步)内逐渐提高到正常大小(比如上面的 2e-5),避免模型过早进入局部最优而过拟合;在训练后期再慢慢将学习率降低到 0,避免后期训练还出现较大的参数变化。

标签平滑

和L1,L2,dropout一样的正则化方法,对one-hot标签向量平滑
$$
\hat y = y_{hot}(1-\alpha)+\alpha/K
$$
K是分类的个数;避免模型对于正确标签过于自信,使得预测正负样本的输出值差别不那么大,从而避免过拟合,提高模型的泛化能力。

标签平滑可以让分类之间的cluster更加紧凑,增加类间距离,减少类内距离,提高泛化性,同时还能提高Model Calibration(模型对于预测值的confidences和accuracies之间aligned的程度)。但是在模型蒸馏中使用Label smoothing会导致性能下降。

gradient checkpointing

用 gradient checkpointing 技术以降低训练时的显存占用。gradient checkpointing 即梯度检查点,通过减少保存的计算图节点压缩模型占用空间,但是在计算梯度的时候需要重新计算没有存储的值,参考论文《Training Deep Nets with Sublinear Memory Cost》


BERT 各层详解
http://example.com/2023/03/09/BERT-embedding层详解/
作者
ZHUHAI
发布于
2023年3月9日
许可协议