quaterion.eval.samplers.pair_sampler module
- class PairSampler(sample_size: int = -1, distinguish: bool = False, encode_batch_size: int = 16, device: device | str | None = None, log_progress: bool = True)
Bases: BaseSampler
Selects embeddings and targets for pair-based tasks.
The sampler reduces the time and resources needed to compute a distance matrix. Instead of computing a square matrix of shape (num_embeddings, num_embeddings), it selects a subset of embeddings and computes a rectangular matrix.
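To make the saving concrete, here is a minimal pure-Python sketch of the idea: sample a few rows and measure their distance to every embedding, producing a (sample_size, num_embeddings) matrix. The function name `rectangular_distance_matrix`, the uniform random row selection and the Euclidean metric are illustrative assumptions, not PairSampler's actual implementation.

```python
import math
import random
from typing import List, Tuple


def rectangular_distance_matrix(
    embeddings: List[List[float]], sample_size: int, seed: int = 0
) -> Tuple[List[int], List[List[float]]]:
    """Hypothetical helper: distances from sampled rows to every embedding.

    Returns the sampled row indices and a (sample_size, num_embeddings)
    matrix instead of the full (num_embeddings, num_embeddings) one.
    """
    num_embeddings = len(embeddings)
    if not 0 < sample_size <= num_embeddings:
        raise ValueError("sample_size must be in (0, num_embeddings]")
    rng = random.Random(seed)
    # Assumption: rows are picked uniformly at random without replacement.
    rows = rng.sample(range(num_embeddings), sample_size)
    # Each sampled row is compared with all embeddings (Euclidean distance).
    matrix = [[math.dist(embeddings[i], e) for e in embeddings] for i in rows]
    return rows, matrix


embeddings = [[float(i), 0.0] for i in range(1000)]
rows, matrix = rectangular_distance_matrix(embeddings, sample_size=10)
print(len(matrix), len(matrix[0]))  # 10 1000
```

With 1,000 embeddings and 10 sampled rows, the matrix holds 10,000 entries instead of 1,000,000.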
- Parameters:
sample_size – int - number of objects to select
distinguish – bool - whether to compare all objects with each other, or only obj_a objects with obj_b objects. If True, only obj_a is compared to obj_b, which significantly reduces the matrix size.
encode_batch_size – int - batch size to use during encoding
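One way to read the `distinguish` flag described above is as a choice of which objects become rows and which become columns. The sketch below assumes, hypothetically, that all embeddings live in one list laid out as `[obj_a..., obj_b...]`; the helper name and layout are illustrative and not taken from the library.

```python
import random
from typing import List, Tuple


def select_rows_and_columns(
    num_a: int, num_b: int, sample_size: int, distinguish: bool, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: pick row/column indices into [obj_a..., obj_b...]."""
    rng = random.Random(seed)
    a_indices = list(range(num_a))
    b_indices = list(range(num_a, num_a + num_b))
    if distinguish:
        # Rows come only from obj_a, columns only from obj_b.
        return rng.sample(a_indices, sample_size), b_indices
    all_indices = a_indices + b_indices
    # Each-to-each: rows from any object, columns cover every object.
    return rng.sample(all_indices, sample_size), all_indices


for distinguish in (False, True):
    rows, cols = select_rows_and_columns(500, 500, 10, distinguish)
    print(distinguish, len(rows) * len(cols))
# False 10000
# True 5000
```

Under this layout, restricting columns to obj_b halves the matrix for a balanced dataset.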
- accumulate(model: SimilarityModel, dataset: Sized)
Encodes objects and accumulates their embeddings together with the corresponding raw labels.
- Parameters:
model – model to encode objects
dataset – Sized object to accumulate, such as a list, tuple, or torch.utils.data.Dataset
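The accumulation step can be pictured as batched encoding controlled by `encode_batch_size`. This is a minimal sketch, not the library's code: it assumes each dataset item is an `(object, raw_label)` pair reachable by integer index, and `encode` is a hypothetical stand-in for the model's encoder.

```python
from typing import Any, Callable, List, Sequence, Tuple


def accumulate(
    encode: Callable[[List[Any]], List[List[float]]],
    dataset: Sequence[Tuple[Any, int]],
    encode_batch_size: int = 16,
) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical helper: encode in batches, keep embeddings next to raw labels."""
    embeddings: List[List[float]] = []
    labels: List[int] = []
    for start in range(0, len(dataset), encode_batch_size):
        # Index-based access only needs __len__ and __getitem__.
        stop = min(start + encode_batch_size, len(dataset))
        batch = [dataset[i] for i in range(start, stop)]
        embeddings.extend(encode([obj for obj, _ in batch]))
        labels.extend(label for _, label in batch)
    return embeddings, labels


batch_sizes: List[int] = []


def toy_encode(texts: List[str]) -> List[List[float]]:
    # Hypothetical encoder: records the batch size, embeds a string by its length.
    batch_sizes.append(len(texts))
    return [[float(len(t))] for t in texts]


data = [("x" * (i % 5 + 1), i % 3) for i in range(40)]
embeddings, labels = accumulate(toy_encode, data, encode_batch_size=16)
print(batch_sizes)  # [16, 16, 8]
```

Keeping labels in the same order as embeddings is what lets a later step pair each distance with its target.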
- sample(dataset: Sized, metric: PairMetric, model: SimilarityModel) → Tuple[Tensor, Tensor]
Samples embeddings and targets for pair-based tasks.
- Parameters:
dataset – Sized object to sample from, such as a list, tuple, or torch.utils.data.Dataset
metric – PairMetric instance to compute final labels representation
model – model to encode objects
- Returns:
torch.Tensor, torch.Tensor – metric labels and the computed distance matrix
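The two returned values are meant to line up entry by entry: each distance has a matching label. The sketch below illustrates that pairing in pure Python under loudly stated assumptions: labels are derived by comparing hypothetical integer group ids (in the library this is the PairMetric's job), rows are sampled uniformly, and distances are Euclidean.

```python
import math
import random
from typing import List, Tuple


def sample_pairs(
    embeddings: List[List[float]],
    raw_labels: List[int],
    sample_size: int,
    seed: int = 0,
) -> Tuple[List[List[bool]], List[List[float]]]:
    """Hypothetical helper returning (labels, distances), both (sample_size, n).

    labels[i][j] is True when sampled object i and object j share a group id,
    an illustrative stand-in for PairMetric's label computation.
    """
    rng = random.Random(seed)
    n = len(embeddings)
    rows = rng.sample(range(n), sample_size)
    labels = [[raw_labels[i] == raw_labels[j] for j in range(n)] for i in rows]
    distances = [[math.dist(embeddings[i], e) for e in embeddings] for i in rows]
    return labels, distances


embeddings = [[0.0], [0.1], [5.0], [5.1]]
groups = [0, 0, 1, 1]
labels, distances = sample_pairs(embeddings, groups, sample_size=2)
```

For these well-separated groups, every entry labeled True has a small distance and every False entry a large one, which is the relationship a pair metric then scores.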