Unflatten
class torch.nn.Unflatten(dim, unflattened_size)

Unflattens a tensor dim, expanding it to a desired shape. For use with Sequential.

- dim specifies the dimension of the input tensor to be unflattened. It can be either int or str when Tensor or NamedTensor is used, respectively.
- unflattened_size is the new shape of the unflattened dimension of the tensor. It can be a tuple of ints, a list of ints, or torch.Size for Tensor input, and a NamedShape (tuple of (name, size) tuples) for NamedTensor input.
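As a quick check of the description above, the sketch below (assuming a recent PyTorch release) passes unflattened_size as a list of ints and confirms that only dim 1 is split: the values keep their original row-major order, so the result matches a plain reshape and Tensor.unflatten.

```python
import torch
from torch import nn

# A batch of 2 samples, each with 50 features stored contiguously.
x = torch.arange(100, dtype=torch.float32).reshape(2, 50)

# unflattened_size given as a list of ints rather than a tuple.
unflatten = nn.Unflatten(1, [2, 5, 5])
y = unflatten(x)
print(y.shape)  # torch.Size([2, 2, 5, 5])

# Only dim 1 is split; element order is unchanged, so this equals
# reshaping that one dimension, or calling Tensor.unflatten directly.
print(torch.equal(y, x.reshape(2, 2, 5, 5)))    # True
print(torch.equal(y, x.unflatten(1, (2, 5, 5))))  # True
```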
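- Shape:
Input: $(*, S_{\text{dim}}, *)$, where $S_{\text{dim}}$ is the size at dimension dim and $*$ means any number of dimensions including none.
Output: $(*, U_1, \ldots, U_n, *)$, where $U$ = unflattened_size and $\prod_{i=1}^{n} U_i = S_{\text{dim}}$.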
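The shape rule can be checked directly. In this sketch (assuming a recent PyTorch release), the size of dimension dim is 12, the leading and trailing dimensions pass through unchanged, and sizes whose product is not 12 are rejected. Because the module never sees the input at construction time, the mismatch only surfaces in the forward pass; the sketch catches both RuntimeError and ValueError rather than relying on one exact exception type.

```python
import torch
from torch import nn

# Size 12 at dim=1; the leading 3 and trailing 7 are the "*" dims.
x = torch.randn(3, 12, 7)

# 3 * 4 == 12, so this is valid and the other dims are untouched.
y = nn.Unflatten(1, (3, 4))(x)
print(y.shape)  # torch.Size([3, 3, 4, 7])

# 5 * 5 == 25 != 12: construction succeeds, but the forward pass fails.
mismatch_rejected = False
try:
    nn.Unflatten(1, (5, 5))(x)
except (RuntimeError, ValueError):
    mismatch_rejected = True
print("mismatch rejected:", mismatch_rejected)
```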
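- Parameters
dim (Union[int, str]): Dimension to be unflattened.
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension.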
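A short sketch of how the two parameters behave for plain (unnamed) tensors, assuming a recent PyTorch release: a negative int dim counts from the end as in other tensor ops, torch.Size works because it is a tuple subclass, and, in the versions I am aware of, a non-int entry in unflattened_size or a non-int/str dim is rejected with TypeError when the module is constructed.

```python
import torch
from torch import nn

x = torch.randn(2, 50)

# Negative dim counts from the end: -1 refers to the 50-wide dimension.
a = nn.Unflatten(-1, (5, 10))(x)
print(a.shape)  # torch.Size([2, 5, 10])

# torch.Size is a tuple subclass, so it is accepted like a tuple of ints.
b = nn.Unflatten(1, torch.Size([10, 5]))(x)
print(b.shape)  # torch.Size([2, 10, 5])

# A float entry in unflattened_size is rejected at construction time.
bad_size_rejected = False
try:
    nn.Unflatten(1, (2, 25.0))
except TypeError:
    bad_size_rejected = True

# A dim that is neither int nor str is rejected as well.
bad_dim_rejected = False
try:
    nn.Unflatten(1.0, (2, 25))
except TypeError:
    bad_dim_rejected = True

print(bad_size_rejected, bad_dim_rejected)
```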
Examples
>>> import torch
>>> from torch import nn
>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, (2, 5, 5))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
...     nn.Linear(50, 50),
...     nn.Unflatten(1, torch.Size([2, 5, 5]))
... )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
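Since Unflatten is meant for use with Sequential, a common pattern is to pair it with nn.Flatten, or to place it between a linear layer and a convolutional stage. The sketch below assumes a recent PyTorch release; the layer sizes in the decoder are illustrative only. The first part shows that nn.Flatten() (which flattens dims 1 to -1 by default) followed by the matching nn.Unflatten restores the original tensor.

```python
import torch
from torch import nn

x = torch.randn(4, 3, 8, 8)  # (N, C, H, W)

# Flatten dims 1..-1, then split dim 1 back into (C, H, W).
roundtrip = nn.Sequential(
    nn.Flatten(),                # -> (4, 192)
    nn.Unflatten(1, (3, 8, 8)),  # -> (4, 3, 8, 8)
)
y = roundtrip(x)
print(torch.equal(y, x))  # True

# Illustrative decoder: a latent vector is projected, reshaped into a
# small feature map, then upsampled by a transposed convolution.
decoder = nn.Sequential(
    nn.Linear(16, 8 * 4 * 4),
    nn.Unflatten(1, (8, 4, 4)),                         # -> (N, 8, 4, 4)
    nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2),  # -> (N, 1, 8, 8)
)
out = decoder(torch.randn(5, 16))
print(out.shape)  # torch.Size([5, 1, 8, 8])
```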
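- NamedShape: alias of typing.Tuple, specifically Tuple[Tuple[str, int]]. A NamedShape is a tuple of (name, size) tuples, such as (('C', 2), ('H', 5), ('W', 5)).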
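When names and sizes come from separate sequences, a NamedShape can be built by zipping them. The helper make_named_shape below is hypothetical (not part of PyTorch); it returns a tuple of tuples, since the named form of unflattened_size is expected to be a tuple rather than a list. The sketch assumes a recent PyTorch release and silences the UserWarning that named tensors emit because they are a prototype feature.

```python
import warnings
from typing import Sequence, Tuple

import torch
from torch import nn


def make_named_shape(names: Sequence[str], sizes: Sequence[int]) -> Tuple[Tuple[str, int], ...]:
    """Hypothetical helper: pair each dimension name with its size."""
    if len(names) != len(sizes):
        raise ValueError("names and sizes must have the same length")
    # A tuple (not a list) of (name, size) tuples.
    return tuple(zip(names, sizes))


with warnings.catch_warnings():
    # Named tensors are a prototype feature and warn on creation.
    warnings.simplefilter("ignore", UserWarning)
    x = torch.randn(2, 50, names=("N", "features"))
    shape = make_named_shape(("C", "H", "W"), (2, 5, 5))
    y = nn.Unflatten("features", shape)(x)

print(shape)    # (('C', 2), ('H', 5), ('W', 5))
print(y.names)  # ('N', 'C', 'H', 'W')
print(y.shape)  # torch.Size([2, 2, 5, 5])
```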