Shortcuts

Source code for quaterion.loss.softmax_loss

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.loss.group_loss import GroupLoss


class SoftmaxLoss(GroupLoss):
    """Regular cross-entropy loss.

    An implementation of softmax with dot product.
    It is designed to work with the base :class:`~quaterion.loss.group_loss.GroupLoss`.

    Args:
        embedding_size: Output dimension of the encoder.
        num_groups: Number of groups in the dataset.
        temperature: Temperature value to divide logits, defaults to 0.05.
            Must be strictly positive.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """

    def __init__(self, embedding_size: int, num_groups: int, temperature: float = 0.05):
        # Written as ``not > 0`` so that NaN is rejected as well: a zero
        # temperature divides logits by zero (inf/nan loss) and a negative one
        # inverts the softmax ranking.
        if not temperature > 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        # NOTE(review): ``super(GroupLoss, self)`` starts the MRO lookup *after*
        # GroupLoss, so GroupLoss.__init__ itself is skipped and its parent's
        # __init__ is called with no arguments. Kept unchanged because GroupLoss
        # is not visible here — confirm the skip is intentional.
        super(GroupLoss, self).__init__()
        self.temperature = temperature
        # Learnable projection from embedding space to one logit per group,
        # shape: (embedding_size, num_groups). ``torch.empty`` replaces the legacy
        # ``torch.FloatTensor`` constructor; dtype is pinned to float32 to keep
        # the original behavior independent of the global default dtype.
        self.kernel = nn.Parameter(
            torch.empty(embedding_size, num_groups, dtype=torch.float32)
        )
        nn.init.normal_(self.kernel, std=0.01)

    def forward(
        self,
        embeddings: Tensor,
        groups: LongTensor,
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: shape: (batch_size, vector_length) - Output embeddings from
                the encoder; ``vector_length`` must equal ``embedding_size``.
            groups: shape: (batch_size,) - Group ids, associated with embeddings.
                Values must lie in ``[0, num_groups)``.

        Returns:
            Tensor: zero-dimensional (scalar) tensor, loss value averaged over the
            batch (default ``mean`` reduction of ``F.cross_entropy``)
        """
        # shape: (batch_size, num_groups)
        logits = torch.mm(embeddings, self.kernel) / self.temperature
        return F.cross_entropy(logits, groups)

Qdrant

Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community