A Detailed Look at Each GPT Layer

Masked attention

# GPT2Attention as implemented in Hugging Face transformers' modeling_gpt2.py.
import torch
from torch import nn
from transformers.pytorch_utils import (
    Conv1D,
    find_pruneable_heads_and_indices,
    prune_conv1d_layer,
)

# Autocast is only needed for the upcast/reorder attention path below.
try:
    from torch.cuda.amp import autocast

    is_amp_available = True
except ImportError:
    is_amp_available = False


class GPT2Attention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()

        max_positions = config.max_position_embeddings
        # Lower-triangular matrix used as the causal ("masked attention") mask.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / (value.size(-1) ** 0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        if is_amp_available:
            with autocast(enabled=False):
                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
        else:
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
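
The causal mask above is what makes this "masked" self-attention: each position can attend only to itself and to earlier positions. A minimal usage sketch (assuming the `transformers` library is installed, so `GPT2Config` provides the config fields read in `__init__`):

import torch
from transformers import GPT2Config

config = GPT2Config()                      # defaults: hidden_size=768, 12 attention heads
attn = GPT2Attention(config, layer_idx=0)  # the class defined above
attn.eval()                                # disable dropout for the check below

hidden_states = torch.randn(1, 5, config.hidden_size)  # (batch, seq_len, hidden)
with torch.no_grad():
    attn_output, present, attn_weights = attn(
        hidden_states, use_cache=True, output_attentions=True
    )

print(attn_output.shape)   # torch.Size([1, 5, 768])
print(attn_weights.shape)  # torch.Size([1, 12, 5, 5])
# Because of the causal mask, every head's weight matrix is (numerically)
# zero above the diagonal, i.e. no token attends to a future token.
print(torch.allclose(attn_weights[0, 0].triu(1), torch.zeros(5, 5), atol=1e-4))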

Cross attention

Generate

Greedy search picks the highest-probability word at every step. Its main drawback is that it misses high-probability words hidden behind a low-probability word, as the accompanying sketch shows.
[Figure: greedy search]
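A minimal greedy-decoding sketch with the `transformers` library (the `gpt2` checkpoint and the prompt are just example choices); `generate` performs greedy search when sampling and beam search are left disabled, and the manual loop does the same thing explicitly by taking the argmax at every step:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer("I enjoy walking with my cute dog", return_tensors="pt").input_ids

# Greedy search: do_sample=False and num_beams=1 are the defaults.
greedy_output = model.generate(input_ids, max_length=50)
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

# The same idea written out by hand: always keep the single most probable token.
generated = input_ids
with torch.no_grad():
    for _ in range(20):
        next_token_logits = model(generated).logits[:, -1, :]
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
print(tokenizer.decode(generated[0], skip_special_tokens=True))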

Beam search will always find an output sequence with higher probability than greedy search, but it is not guaranteed to find the most likely output overall. The most common n-gram penalty ensures that no n-gram appears twice by manually setting the probability of any next word that would repeat an already-seen n-gram to 0.
[Figure: beam search]
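A minimal beam-search sketch, reusing the same example model and prompt as in the greedy sketch: `num_beams` keeps several candidate sequences alive at each step, and `no_repeat_ngram_size=2` is the n-gram penalty described above:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("I enjoy walking with my cute dog", return_tensors="pt").input_ids

beam_output = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,             # keep the 5 most probable partial sequences at every step
    no_repeat_ngram_size=2,  # n-gram penalty: never generate the same 2-gram twice
    early_stopping=True,     # stop once all beams have produced the EOS token
)
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))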
In open-ended generation, several reasons have recently been put forward for why beam search may not be the best possible option:

  • Beam search can work well on tasks where the length of the desired output is more or less predictable, such as machine translation or summarization; see Murray et al. (2018) and Yang et al. (2018). This is not the case for open-ended generation, where the desired output length can vary greatly, e.g. dialogue and story generation.
  • We have seen that beam search suffers heavily from repetitive generation. This is especially hard to control with n-gram or other penalties in story generation, since finding a good trade-off between forcing "no repetition" and cycles of repeating identical n-grams requires a lot of fine-tuning.
  • As argued by Ari Holtzman et al. (2019), high-quality human language does not follow a distribution of high-probability next words. In other words, as humans we want generated text to surprise us, not to be boring or predictable. The authors show this nicely by plotting the probability a model assigns to human text against what beam search produces.

Sampling

In sampling, the next word is drawn at random from the model's probability distribution; the distribution can be sharpened by applying a temperature coefficient inside the softmax.
[Figure: sampling]
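A minimal pure-PyTorch sketch of the temperature effect on a made-up logit vector: dividing the logits by a temperature T < 1 before the softmax sharpens the distribution, while T > 1 flattens it:

import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])  # toy next-token logits

for temperature in (1.5, 1.0, 0.7, 0.3):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")
# As T decreases, probability mass concentrates on the top token,
# so sampling behaves more and more like greedy search.

# Drawing one token from the sharpened distribution:
probs = torch.softmax(logits / 0.7, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)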

Top-K sampling

Top-K sampling means sorting tokens by probability and zeroing out the probability of everything below the k-th token. It appears to improve quality by removing the tail of the distribution, making it less likely that generation drifts off topic. However, it cannot dynamically adapt the number of words filtered from the next-word distribution; this can be problematic because some steps have a very sharp distribution (the right-hand distribution in the figure) while others have a much flatter one (the left-hand distribution).
[Figure: Top-K sampling]
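A minimal Top-K filtering sketch on toy logits (the helper name `top_k_filter` is just for illustration): everything outside the K most probable tokens is pushed to -inf, which the softmax turns into probability 0:

import torch

def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest logits; everything else gets probability 0 after softmax."""
    kth_largest = torch.topk(logits, k).values[..., -1, None]
    return logits.masked_fill(logits < kth_largest, float("-inf"))

logits = torch.tensor([2.0, 1.5, 0.3, -0.5, -1.0, -2.0])
probs = torch.softmax(top_k_filter(logits, k=3), dim=-1)
print(probs)  # only the 3 most probable tokens keep non-zero probability
next_token_id = torch.multinomial(probs, num_samples=1)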

Top-p (nucleus) sampling

Instead of sampling only from the K most likely words, Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p.
[Figure: Top-p sampling]
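A minimal Top-p (nucleus) filtering sketch on the same toy logits (`top_p_filter` is an illustrative helper, not a library function): tokens are sorted by probability and only the smallest prefix whose cumulative probability exceeds p survives, so the size of the candidate set adapts to how sharp the distribution is:

import torch

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Keep the smallest set of tokens whose cumulative probability exceeds p."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > p
    # Shift right so the token that first crosses the threshold is still kept.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    mask = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, sorted_idx, remove)
    return logits.masked_fill(mask, float("-inf"))

logits = torch.tensor([2.0, 1.5, 0.3, -0.5, -1.0, -2.0])
probs = torch.softmax(top_p_filter(logits, p=0.9), dim=-1)
print(probs)  # only the nucleus covering (just over) 90% of the mass survives
next_token_id = torch.multinomial(probs, num_samples=1)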

