Tensor shapes do not match in PyTorch

I have designed a network for instance segmentation, and every time I load the mask files I get an error saying that the shapes do not match. The loaded mask tensor has shape (32, 224, 224), in NxHxW order.

1 Answer

I assume your segmentation head solves a binary segmentation problem and that you load the masks in gray-scale mode, so they have no color-channel dimension. In PyTorch, the output of a segmentation network is usually a four-dimensional tensor in NxCxHxW order (batch, channels, height, width), while your masks are NxHxW. Add the channel dimension back with torch.unsqueeze, which turns the masks into Nx1xHxW:

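import torch
input_tensor = torch.rand(32, 224, 224)  # stand-in for the loaded masks, shape (N, H, W)
input_tensor = torch.unsqueeze(input_tensor, dim=1)  # insert the channel axis: (N, 1, H, W)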
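A slightly fuller sketch of where the mismatch typically shows up, assuming (the question does not say) a one-channel binary head trained with nn.BCEWithLogitsLoss. The random tensors stand in for the real masks and network output, and to_nchw is a hypothetical helper name, not part of PyTorch.

```python
import torch
import torch.nn as nn


def to_nchw(masks: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn gray-scale masks of shape N x H x W into N x 1 x H x W."""
    if masks.dim() != 3:
        raise ValueError(f"expected N x H x W masks, got shape {tuple(masks.shape)}")
    # unsqueeze(1) inserts a size-1 channel axis right after the batch axis.
    return masks.unsqueeze(1)


# Random stand-ins for the real data: 32 binary masks of 224 x 224 pixels,
# and the raw logits of an assumed one-channel segmentation head (N x 1 x H x W).
masks = torch.randint(0, 2, (32, 224, 224)).float()
logits = torch.randn(32, 1, 224, 224)

targets = to_nchw(masks)
print(targets.shape)  # torch.Size([32, 1, 224, 224])

# BCEWithLogitsLoss expects input and target of the same shape,
# which now holds: both are N x 1 x H x W.
loss = nn.BCEWithLogitsLoss()(logits, targets)
print(loss.item())
```

masks[:, None] is an equivalent indexing shorthand for masks.unsqueeze(1). Alternatively, you can drop the channel axis from the network output with logits.squeeze(1); either way, the point is that prediction and target end up with the same shape.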
