Shortcuts

Source code for quaterion.eval.accumulators.group_accumulator

from typing import Dict

import torch

from quaterion.eval.accumulators import Accumulator


class GroupAccumulator(Accumulator):
    """Collect embeddings together with their group ids for group-based tasks.

    Batches are stored as lists of tensors and are only concatenated when the
    accumulated values are actually requested.
    """

    def __init__(self):
        super().__init__()
        # One tensor of group ids per `update` call; joined lazily in `groups`.
        self._groups = []

    @property
    def state(self) -> Dict[str, torch.Tensor]:
        """Accumulated state.

        Takes the state reported by the parent class and adds the
        concatenated groups to it under the ``"groups"`` key.

        Returns:
            Dict[str, torch.Tensor] - dictionary with embeddings and groups.
        """
        accumulated = super().state
        accumulated["groups"] = self.groups
        return accumulated

    def update(self, embeddings: torch.Tensor, groups: torch.Tensor, device=None):
        """Add one batch of embeddings and groups to the accumulator.

        Both tensors are detached and moved to the target device before being
        appended to the stored batches.

        Args:
            embeddings: embeddings to accumulate
            groups: corresponding groups to accumulate
            device: device to store tensors on; when ``None``, the device of
                `embeddings` is used
        """
        target_device = embeddings.device if device is None else device

        # Convert both tensors before touching any list, so a failure here
        # leaves the accumulated state unchanged.
        detached_embeddings = embeddings.detach().to(target_device)
        detached_groups = groups.detach().to(target_device)

        # NOTE(review): `_embeddings` is not defined in this class; presumably
        # it is created by the parent `Accumulator` - confirm.
        self._embeddings.append(detached_embeddings)
        self._groups.append(detached_groups)

    def reset(self):
        """Reset accumulator state.

        Delegates to the parent's `reset` and then discards all stored groups.
        """
        super().reset()
        self._groups = []

    @property
    def groups(self):
        """Stored groups concatenated into a single tensor.

        Concatenation is deferred until this property is read, which avoids
        re-concatenating on every accumulated batch.

        Returns:
            torch.Tensor: batch of groups, or an empty tensor if nothing has
            been accumulated
        """
        if not self._groups:
            return torch.Tensor()
        return torch.cat(self._groups)

Qdrant

Learn more about the Qdrant vector search project and ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community