#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

import math
from typing import Optional

import torch
from torch import Tensor, nn

from corenet.modeling.layers.base_layer import BaseLayer
from corenet.modeling.layers.dropout import Dropout


class SinusoidalPositionalEncoding(BaseLayer):
    """
    This layer adds sinusoidal positional embeddings to a 3D input tensor. The code has been
    adapted from the `PyTorch tutorial <https://pytorch.org/tutorials/beginner/transformer_tutorial.html>`_.

    Args:
        d_model (int): dimension of the input tensor
        dropout (Optional[float]): Dropout rate. Default: 0.0
        max_len (Optional[int]): Max. number of patches (or sequence length). Default: 5000
        channels_last (Optional[bool]): Whether the channels dimension is the last dimension in the input tensor

    Shape:
        - Input: :math:`(N, C, P)` or :math:`(N, P, C)` where :math:`N` is the batch size,
          :math:`C` is the embedding dimension, :math:`P` is the number of patches
        - Output: same shape as the input
    """

    def __init__(
        self,
        d_model: int,
        dropout: Optional[float] = 0.0,
        max_len: Optional[int] = 5000,
        channels_last: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        position_last = not channels_last

        pos_encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        pos_encoding[:, 1::2] = torch.cos(position * div_term)
        # add dummy batch dimension
        pos_encoding = pos_encoding.unsqueeze(0)  # [1 x P_max x C]

        patch_dim = -2  # patch dimension is second last (N, P, C)
        if position_last:
            pos_encoding = pos_encoding.transpose(
                1, 2
            )  # patch dimension is last (N, C, P)
            patch_dim = -1  # patch dimension is last (N, C, P)

        super().__init__()
        self.dropout = Dropout(p=dropout)
        self.patch_dim = patch_dim
        self.register_buffer("pe", pos_encoding)

    def forward_patch_last(
        self, x, indices: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        # seq_length should be the last dim
        if indices is None:
            x = x + self.pe[..., : x.shape[-1]]
        else:
            ndim = x.ndim
            repeat_size = [x.shape[0]] + [-1] * (ndim - 1)

            pe = self.pe.expand(repeat_size)
            selected_pe = torch.gather(pe, index=indices, dim=-1)
            x = x + selected_pe
        return self.dropout(x)

    def forward_others(
        self, x, indices: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        # seq_length should be the second last dim
        if indices is None:
            x = x + self.pe[..., : x.shape[-2], :]
        else:
            ndim = x.ndim
            repeat_size = [x.shape[0]] + [-1] * (ndim - 1)

            pe = self.pe.expand(repeat_size)
            selected_pe = torch.gather(pe, index=indices, dim=-2)
            x = x + selected_pe
        return self.dropout(x)

    def forward(self, x, indices: Optional[Tensor] = None, *args, **kwargs) -> Tensor:
        if self.patch_dim == -1:
            return self.forward_patch_last(x, indices=indices)
        else:
            return self.forward_others(x, indices=indices)

    def __repr__(self):
        return "{}(dropout={})".format(self.__class__.__name__, self.dropout.p)


class LearnablePositionEncoding(BaseLayer):
    """
    This layer adds learnable positional embeddings to a 3D input tensor.

    Args:
        embed_dim (int): dimension of the input tensor
        num_embeddings (int): number of input embeddings. This is similar to vocab size in NLP.
        dropout (Optional[float]): Dropout rate. Default: 0.0
        channels_last (Optional[bool]): Whether the channels dimension is the last dimension in the input tensor

    Shape:
        - Input: :math:`(N, *, C, P)` or :math:`(N, *, P, C)` where :math:`N` is the batch size,
          :math:`C` is the embedding dimension, :math:`P` is the number of patches
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        num_embeddings: int,
        dropout: Optional[float] = 0.0,
        channels_last: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        self.pos_emb = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embed_dim
        )
        self.channel_last = channels_last
        self.dropout = Dropout(p=dropout)

    def forward(self, x, *args, **kwargs) -> Tensor:
        num_embeddings = x.shape[-2] if self.channel_last else x.shape[-1]
        positions = torch.arange(num_embeddings, dtype=torch.int64, device=x.device)
        position_emb = self.pos_emb(positions)
        position_emb = position_emb.expand_as(x)

        x = x + position_emb
        return self.dropout(x)

    def __repr__(self):
        return "{}(embed_dim={}, vocab_size={}, dropout={})".format(
            self.__class__.__name__,
            self.pos_emb.embedding_dim,
            self.pos_emb.num_embeddings,
            self.dropout.p,
        )
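

if __name__ == "__main__":
    # A minimal smoke test (not part of the original file); the sizes below are
    # illustrative assumptions rather than values prescribed by CoreNet.
    batch_size, num_patches, embed_dim = 2, 16, 64

    # Channels-last input of shape (N, P, C).
    x = torch.randn(batch_size, num_patches, embed_dim)

    sin_pe = SinusoidalPositionalEncoding(d_model=embed_dim, channels_last=True)
    learned_pe = LearnablePositionEncoding(
        embed_dim=embed_dim, num_embeddings=num_patches, channels_last=True
    )

    # Both layers only add positional information, so the input shape is preserved.
    assert sin_pe(x).shape == x.shape
    assert learned_pe(x).shape == x.shape

    print(sin_pe)
    print(learned_pe)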