quaterion.loss.group_loss module
- class GroupLoss(distance_metric_name: Distance = Distance.COSINE)
Bases: SimilarityLoss
Base class for group losses.
- Parameters:
distance_metric_name – Name of the distance function, given as a member of Distance. Defaults to Distance.COSINE.
- forward(embeddings: Tensor, groups: LongTensor) → Tensor
- Parameters:
embeddings – shape: (batch_size, vector_length)
groups – shape: (batch_size,) - Group ids associated with embeddings.
- Returns:
Tensor – zero-dimensional (scalar) tensor holding the loss value.
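The input and output shapes of forward can be illustrated with a minimal sketch. The function `toy_group_loss` and its `margin` parameter are hypothetical and are not part of Quaterion; the contrastive-style formula is only an assumption chosen to show how a `groups` tensor pairs up with `embeddings` and how the result collapses to a scalar tensor.

```python
import torch
import torch.nn.functional as F


def toy_group_loss(
    embeddings: torch.Tensor, groups: torch.Tensor, margin: float = 0.5
) -> torch.Tensor:
    """Hypothetical stand-in that follows the forward() shape contract."""
    # Cosine distance matrix, shape (batch_size, batch_size).
    normalized = F.normalize(embeddings, dim=1)
    distances = 1.0 - normalized @ normalized.T

    # same_group[i, j] is True when samples i and j share a group id.
    same_group = groups.unsqueeze(0) == groups.unsqueeze(1)
    not_self = ~torch.eye(len(groups), dtype=torch.bool, device=groups.device)

    positive = distances[same_group & not_self]  # same group, distinct samples
    negative = distances[~same_group]            # different groups

    # Pull same-group pairs together, push other pairs beyond the margin.
    loss = embeddings.new_zeros(())
    if positive.numel() > 0:
        loss = loss + positive.mean()
    if negative.numel() > 0:
        loss = loss + F.relu(margin - negative).mean()
    return loss


embeddings = torch.randn(4, 8)        # (batch_size, vector_length)
groups = torch.tensor([0, 0, 1, 1])   # (batch_size,)
loss = toy_group_loss(embeddings, groups)
print(loss.shape)  # torch.Size([])
```

The empty shape confirms that the result is a single scalar value, suitable for calling `backward()` on when the embeddings require gradients.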
- xbm_loss(embeddings: Tensor, groups: LongTensor, memory_embeddings: Tensor, memory_groups: LongTensor) → Tensor
Compute the XBM (cross-batch memory) variant of this loss.
- Parameters:
embeddings – shape: (batch_size, vector_length) - Output embeddings from the encoder.
groups – shape: (batch_size,) - Group ids associated with embeddings.
memory_embeddings – shape: (memory_buffer_size, vector_length) - Embeddings stored in a ring buffer.
memory_groups – shape: (memory_buffer_size,) - Group ids associated with memory_embeddings.
- Returns:
Tensor – zero-dimensional (scalar) tensor holding the XBM loss value.
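As a rough illustration of the four arguments, the sketch below compares each batch embedding against each memory embedding rather than against the rest of the batch. `toy_xbm_loss` and `margin` are hypothetical names, the formula is an assumed contrastive-style loss, and treating the memory entries as constants without gradients is likewise an assumption, not a statement about Quaterion's implementation.

```python
import torch
import torch.nn.functional as F


def toy_xbm_loss(
    embeddings: torch.Tensor,
    groups: torch.Tensor,
    memory_embeddings: torch.Tensor,
    memory_groups: torch.Tensor,
    margin: float = 0.5,
) -> torch.Tensor:
    """Hypothetical stand-in that follows the xbm_loss() shape contract."""
    # Cosine distances between batch rows and memory rows,
    # shape (batch_size, memory_buffer_size).
    batch = F.normalize(embeddings, dim=1)
    memory = F.normalize(memory_embeddings.detach(), dim=1)  # assumed constant
    distances = 1.0 - batch @ memory.T

    # same_group[i, j] is True when batch sample i and memory entry j match.
    same_group = groups.unsqueeze(1) == memory_groups.unsqueeze(0)
    positive = distances[same_group]
    negative = distances[~same_group]

    loss = embeddings.new_zeros(())
    if positive.numel() > 0:
        loss = loss + positive.mean()
    if negative.numel() > 0:
        loss = loss + F.relu(margin - negative).mean()
    return loss


embeddings = torch.randn(4, 8, requires_grad=True)  # (batch_size, vector_length)
groups = torch.tensor([0, 0, 1, 1])                 # (batch_size,)
memory_embeddings = torch.randn(16, 8)              # (memory_buffer_size, vector_length)
memory_groups = torch.randint(0, 2, (16,))          # (memory_buffer_size,)

xbm = toy_xbm_loss(embeddings, groups, memory_embeddings, memory_groups)
xbm.backward()  # gradients reach only the current batch embeddings
print(xbm.shape, embeddings.grad.shape)
```

Because the memory tensor is detached in this sketch, only the current batch receives gradients, while the buffer supplies extra pairs to compare against.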
- training: bool