400% higher error with PyTorch compared with identical Keras model (with Adam optimizer)

The problem here is unintentional broadcasting in the PyTorch training loop.

The result of an nn.Linear operation always has shape [B, D], where B is the batch size and D is the output dimension. Therefore, in your mean_squared_error function, ypred has shape [32, 1] while ytrue has shape [32]. By the broadcasting rules used by NumPy and PyTorch, ytrue - ypred therefore has shape [32, 32]: every target is subtracted from every prediction, and the loss averages over all 32×32 pairs instead of the 32 per-sample errors. What you almost certainly meant is for ypred to have shape [32]. This can be accomplished in many ways; probably the most readable is to use Tensor.flatten

class TorchLinearModel(nn.Module):
    ...
    def forward(self, x):
        x = self.hidden_layer(x)
        x = self.output_layer(x)
        return x.flatten()

which produces the following train/val curves

[image: train/val loss curves]
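For reference, here is a minimal sketch of the shape mismatch itself (batch size 32 assumed, with random tensors standing in for real data):

```python
import torch

ytrue = torch.randn(32)     # targets from the DataLoader: shape [32]
ypred = torch.randn(32, 1)  # nn.Linear output: shape [32, 1]

# Broadcasting aligns [32] against [32, 1], producing a [32, 32] matrix
# of all pairwise differences rather than 32 per-sample errors.
diff = ytrue - ypred
print(diff.shape)  # torch.Size([32, 32])

# Flattening the predictions restores elementwise subtraction.
fixed = ytrue - ypred.flatten()
print(fixed.shape)  # torch.Size([32])
```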