This is somewhat unusual: the “batch” dimension is usually defined as exactly the dimension along which all operations of the network are invariant, i.e. every sample is processed independently of the others.
So one option is to introduce an additional dimension. The “former batch dimension” then becomes an ordinary dimension along which your operation is not invariant, and you keep your current implementation for it. You then parallelize over the new dimension, which indexes several “actual batches” of data.
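A minimal sketch of this idea, assuming the coupled operation can be applied to one former batch at a time. `BatchCoupled` (here: centering each sample by the batch mean) and `OverGroups` are hypothetical names standing in for your current implementation. `nn.DataParallel` splits its inputs along dim 0, which is now the new “group” dimension, so every replica receives complete former batches. On a machine without GPUs, `DataParallel` simply calls the wrapped module directly.

```python
import torch
import torch.nn as nn

class BatchCoupled(nn.Module):
    """Hypothetical stand-in for the current implementation: the output for one
    sample depends on the other samples in the same (former) batch."""
    def forward(self, x):  # x: (batch, features)
        return x - x.mean(dim=0, keepdim=True)

class OverGroups(nn.Module):
    """Applies the existing module independently to each group along a new leading dim."""
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):  # x: (groups, batch, features)
        # Each slice along dim 0 is one complete former batch.
        return torch.stack([self.inner(g) for g in x])

x = torch.randn(4, 16, 8)  # 4 "actual batches" of 16 samples each
net = nn.DataParallel(OverGroups(BatchCoupled()))
out = net(x)  # DataParallel splits x along dim 0, i.e. across the 4 groups
```

For this to use all GPUs, the number of groups should be at least the number of devices, since the split happens over groups rather than samples.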
But, to stay closer to the question you asked, I see two options:
- As you suggested, figure out inside your implementation which original batch you are operating on (this depends on the total number of parallel splits, etc.). This can get hairy.
- Treat your parameter as part of the input! In the outer call, pass the parameter alongside your input data to the `forward` of your model. `nn.DataParallel` then splits it along dim 0 together with the data, so each replica receives the matching slice.
So, in Python-like pseudocode:
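```python
import torch
import torch.nn as nn

class Network(nn.Module):
    ...

    def forward(self, x, parameter):
        x = self.pre_modules(x)
        # DataParallel scatters every tensor argument along dim 0,
        # so parameter is split across the replicas together with x
        x = self.custom_module(x, parameter)
        return x

parameter = torch.zeros(16, requires_grad=True)  # one entry per sample
net = nn.DataParallel(Network())
net(inputs, parameter)
```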
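A runnable version of the same pattern, as a sketch under a few assumptions: the per-sample interpretation of `parameter` (its length equals the batch size), the hypothetical `PerSampleScale` module that multiplies sample `i` by `parameter[i]`, and the layer sizes are all chosen for illustration. It initializes `parameter` with ones rather than zeros only so the output is not trivially zero. Because both `x` and `parameter` have the batch as their first dimension, `DataParallel` splits them with the same chunk sizes, and gradients flow back through the scatter into `parameter`.

```python
import torch
import torch.nn as nn

class PerSampleScale(nn.Module):
    """Hypothetical custom module: scales sample i by parameter[i]."""
    def forward(self, x, parameter):
        # x: (batch, features), parameter: (batch,)
        return x * parameter.unsqueeze(1)

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_modules = nn.Linear(8, 8)
        self.custom_module = PerSampleScale()

    def forward(self, x, parameter):
        x = self.pre_modules(x)
        # Each replica receives matching rows of x and parameter.
        return self.custom_module(x, parameter)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Network().to(device)
net = nn.DataParallel(model)  # falls back to a plain call without GPUs

batch = 16
parameter = torch.ones(batch, requires_grad=True, device=device)
inputs = torch.randn(batch, 8, device=device)

out = net(inputs, parameter)
out.sum().backward()  # parameter.grad is populated like any leaf tensor's
```

To train it, hand `parameter` to the optimizer next to the model weights, e.g. `torch.optim.SGD([parameter, *model.parameters()], lr=0.1)`.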
If you are willing to accept that this makes the network a leaky abstraction, and you are mainly interested in getting things to work, I would try the latter approach first.