diff --git a/setup.py b/setup.py
index af5676ba..4e44950d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.4',
+  version = '1.43.5',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 482a3800..c41fcaff 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -38,6 +38,7 @@ class LayerIntermediates:
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
+    logit_entropies: Tensor | None = None
 
 LinearNoBias = partial(nn.Linear, bias = False)
 
@@ -136,6 +137,15 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# entropy
+
+def calc_entropy(
+    t: Tensor,
+    is_prob = False
+):
+    prob = t.softmax(dim = -1) if not is_prob else t
+    return -(prob * log(prob)).sum(dim = -1)
+
 # auxiliary loss helpers
 
 def calc_z_loss(
@@ -2592,6 +2602,7 @@ def forward(
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_logit_entropies = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -2809,6 +2820,12 @@ def forward(
         else:
             out = logits
 
+        # logit entropies
+
+        if return_logit_entropies:
+            intermediates.logit_entropies = calc_entropy(logits)
+            return_intermediates = True
+
         # aux loss
 
         if return_attn_z_loss:
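
Usage note (not part of the diff): a minimal sketch of the new flag, using a small illustrative TransformerWrapper / Decoder configuration chosen here only for the example. Passing return_logit_entropies = True makes the forward pass run calc_entropy(logits) and forces intermediates to be returned, so the per-position entropies come back as intermediates.logit_entropies.

import torch
from x_transformers import TransformerWrapper, Decoder

# illustrative config; any TransformerWrapper works the same way
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

x = torch.randint(0, 256, (1, 1024))

# requesting entropies implicitly sets return_intermediates = True,
# so the call returns (logits, intermediates)
logits, intermediates = model(x, return_logit_entropies = True)

# entropy of the softmax distribution at every position: shape (batch, seq_len)
print(intermediates.logit_entropies.shape)  # torch.Size([1, 1024])

Note that calc_entropy uses the module-level log helper rather than torch.log; in this codebase that helper clamps its input away from zero, which keeps the entropy finite even for near-one-hot output distributions.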