torch.searchsorted

torch.searchsorted(sorted_sequence, values, *, out_int32=False, right=False, out=None) → Tensor

Find the indices from the innermost dimension of `sorted_sequence` such that, if the corresponding values in `values` were inserted before the indices, the order of the corresponding innermost dimension within `sorted_sequence` would be preserved. Return a new tensor with the same size as `values`. If `right` is False (default), then the left boundary of `sorted_sequence`
is closed. More formally, the returned index `i` satisfies the following rules:

| sorted_sequence | right | returned index satisfies |
| --- | --- | --- |
| 1-D | False | `sorted_sequence[i-1] < values[m][n]...[l][x] <= sorted_sequence[i]` |
| 1-D | True | `sorted_sequence[i-1] <= values[m][n]...[l][x] < sorted_sequence[i]` |
| N-D | False | `sorted_sequence[m][n]...[l][i-1] < values[m][n]...[l][x] <= sorted_sequence[m][n]...[l][i]` |
| N-D | True | `sorted_sequence[m][n]...[l][i-1] <= values[m][n]...[l][x] < sorted_sequence[m][n]...[l][i]` |
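The rules in the table can be checked directly on a small 1-D input. The following is a minimal sketch, not part of the library: `satisfies_rule` is a hypothetical helper written for this illustration, and the sample tensors are arbitrary. Comparisons that would index outside the sequence (`i == 0` or `i == len`) are treated as trivially satisfied.

```python
import torch

# Arbitrary sample data; 3 and 9 tie with elements of the sequence.
sorted_sequence = torch.tensor([1, 3, 5, 7, 9])
values = torch.tensor([3, 6, 9])

left_idx = torch.searchsorted(sorted_sequence, values)               # right=False
right_idx = torch.searchsorted(sorted_sequence, values, right=True)

print(left_idx)   # tensor([1, 3, 4])
print(right_idx)  # tensor([2, 3, 5])


def satisfies_rule(seq, vals, idx, right):
    """Hypothetical helper: check the 1-D inequality from the table
    for every (value, index) pair."""
    s = seq.tolist()
    n = len(s)
    for v, i in zip(vals.tolist(), idx.tolist()):
        if right:
            lower_ok = i == 0 or s[i - 1] <= v
            upper_ok = i == n or v < s[i]
        else:
            lower_ok = i == 0 or s[i - 1] < v
            upper_ok = i == n or v <= s[i]
        if not (lower_ok and upper_ok):
            return False
    return True


print(satisfies_rule(sorted_sequence, values, left_idx, right=False))  # True
print(satisfies_rule(sorted_sequence, values, right_idx, right=True))  # True
```

The two results differ only where a value equals an element of the sequence: with `right=False` the index points at the matching element, and with `right=True` it points just past it.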
Parameters
- sorted_sequence (Tensor) – N-D or 1-D tensor containing a monotonically increasing sequence on the innermost dimension.
- values (Tensor or Scalar) – N-D tensor or a Scalar containing the search value(s).

Keyword Arguments
- out_int32 (bool, optional) – indicates the output data type: torch.int32 if True, torch.int64 otherwise. Default value is False, i.e. the default output data type is torch.int64.
- right (bool, optional) – if False, return the first suitable location that is found. If True, return the last such index. If no suitable index is found, return 0 for non-numerical values (e.g. nan, inf) or the size of the innermost dimension within `sorted_sequence` (one past the last index of the innermost dimension). In other words, if False, gets the lower bound index for each value in `values` on the corresponding innermost dimension of `sorted_sequence`. If True, gets the upper bound index instead. Default value is False.
- out (Tensor, optional) – the output tensor; must be the same size as `values` if provided.
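As an illustration of the keyword arguments, the sketch below uses arbitrary sample data. It shows the output dtype under `out_int32`, the lower and upper bound indices selected by `right`, and a preallocated `out` tensor. The `counts` line is an assumption-free consequence of the bound definitions rather than something the function provides itself.

```python
import torch

sorted_sequence = torch.tensor([1, 3, 5, 7, 9])
values = torch.tensor([0, 5, 5, 10])  # includes values below and above the range

# out_int32 selects the dtype of the returned indices.
idx64 = torch.searchsorted(sorted_sequence, values)
idx32 = torch.searchsorted(sorted_sequence, values, out_int32=True)
print(idx64.dtype, idx32.dtype)  # torch.int64 torch.int32

# right=False gives the lower bound, right=True the upper bound.
lower = torch.searchsorted(sorted_sequence, values)              # tensor([0, 2, 2, 5])
upper = torch.searchsorted(sorted_sequence, values, right=True)  # tensor([0, 3, 3, 5])

# Their difference is the number of elements equal to each value.
counts = upper - lower  # tensor([0, 1, 1, 0])

# out must have the same size as values; int64 matches the default output dtype.
out = torch.empty(values.shape, dtype=torch.int64)
torch.searchsorted(sorted_sequence, values, out=out)
print(out)  # tensor([0, 2, 2, 5])
```

Values larger than every element map to the size of the innermost dimension (5 here), which is one past the last valid index.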
Note
If your use case is always a 1-D sorted sequence, torch.bucketize() is preferred, because it has fewer dimension checks, resulting in slightly better performance.

Example:
>>> import torch
>>> sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
>>> sorted_sequence
tensor([[ 1,  3,  5,  7,  9],
        [ 2,  4,  6,  8, 10]])
>>> values = torch.tensor([[3, 6, 9], [3, 6, 9]])
>>> values
tensor([[3, 6, 9],
        [3, 6, 9]])
>>> torch.searchsorted(sorted_sequence, values)
tensor([[1, 3, 4],
        [1, 2, 4]])
>>> torch.searchsorted(sorted_sequence, values, right=True)
tensor([[2, 3, 5],
        [1, 3, 4]])

>>> sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9])
>>> sorted_sequence_1d
tensor([1, 3, 5, 7, 9])
>>> torch.searchsorted(sorted_sequence_1d, values)
tensor([[1, 3, 4],
        [1, 3, 4]])
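For the 1-D case mentioned in the note, the sketch below compares the two calls on the same data as the example. Be aware that torch.bucketize() takes its arguments in the opposite order: the values to look up come first and the sorted boundaries second. The variable names are illustrative only.

```python
import torch

sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9])
values = torch.tensor([[3, 6, 9], [3, 6, 9]])

# searchsorted: sorted sequence first, values second.
via_searchsorted = torch.searchsorted(sorted_sequence_1d, values)

# bucketize: values first, 1-D boundaries second.
via_bucketize = torch.bucketize(values, sorted_sequence_1d)

print(via_searchsorted)
# tensor([[1, 3, 4],
#         [1, 3, 4]])
print(torch.equal(via_searchsorted, via_bucketize))  # True
```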