diff --git a/loss.py b/loss.py index c6b46d5..569f03c 100644 --- a/loss.py +++ b/loss.py @@ -52,7 +52,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ) if ( - flattened_images and not self.num_bins + flattened_images and not self.num_modalities ): # then output loss should be reshaped loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)