Source code for quaterion.loss.softmax_loss
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor
from quaterion.loss.group_loss import GroupLoss
class SoftmaxLoss(GroupLoss):
    """Regular cross-entropy loss.

    An implementation of softmax with dot product: embeddings are multiplied by a
    learnable ``(embedding_size, num_groups)`` kernel, the resulting logits are
    divided by ``temperature``, and cross-entropy is computed against group ids.

    It is designed to work with the base :class:`~quaterion.loss.group_loss.GroupLoss`.

    Args:
        embedding_size: Output dimension of the encoder.
        num_groups: Number of groups in the dataset.
        temperature: Temperature value to divide logits, defaults to 0.05.
            Must be strictly positive.

    Raises:
        ValueError: If ``temperature`` is not a positive number.
    """

    def __init__(self, embedding_size: int, num_groups: int, temperature: float = 0.05):
        # Validate before any module state is created. A zero temperature makes
        # every logit inf/nan (nan loss); a negative one inverts the ordering of
        # the logits so training would push toward the wrong group. The
        # ``not ... > 0`` form also rejects NaN.
        if not temperature > 0:
            raise ValueError(
                f"temperature must be a positive number, got {temperature!r}"
            )
        # NOTE(review): ``super(GroupLoss, self)`` begins the MRO lookup *after*
        # GroupLoss, so GroupLoss.__init__ itself is skipped and only the next
        # initializer in the MRO runs (presumably nn.Module.__init__). This looks
        # deliberate, since nothing below relies on GroupLoss state.
        # TODO confirm against group_loss.py.
        super(GroupLoss, self).__init__()
        self.temperature = temperature
        # Learnable projection from embedding space to one logit per group.
        # ``torch.empty`` with an explicit float32 dtype replaces the legacy
        # ``torch.FloatTensor(rows, cols)`` constructor (uninitialised float32
        # storage); the contents are overwritten by the normal init just below.
        self.kernel = nn.Parameter(
            torch.empty(embedding_size, num_groups, dtype=torch.float32)
        )
        nn.init.normal_(self.kernel, std=0.01)

    def forward(
        self,
        embeddings: Tensor,
        groups: LongTensor,
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: shape: (batch_size, vector_length) - Output embeddings from the
                encoder; ``vector_length`` must equal ``embedding_size`` for the
                matrix product with the kernel.
            groups: shape: (batch_size,) - Group ids, associated with embeddings

        Returns:
            Tensor: zero-dimensional (scalar) tensor, loss value
        """
        # shape: (batch_size, num_groups)
        logits = torch.mm(embeddings, self.kernel) / self.temperature
        # Default (mean) reduction yields a scalar. Group ids index the columns of
        # ``logits``, so they are expected to lie in [0, num_groups).
        return F.cross_entropy(logits, groups)