Workaround for nn.DataParallel bug #1

Open
gogolgrind opened this issue Apr 14, 2020 · 1 comment
Assignees
Labels
bug Something isn't working question Further information is requested wontfix This will not be worked on

Comments

gogolgrind commented Apr 14, 2020

It seems cplxmodule doesn't work with nn.DataParallel.
The attached minimal example gives the following error:

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 0 does not equal 1 (while checking arguments for cudnn_convolution)

cplxmodule-bug.py.txt

Owner

ivannz commented Apr 14, 2020

Thank you for the issue!

Analysis

Here is a minimal reproducing example:

import torch
from torch import nn

from cplxmodule import cplx
import cplxmodule.nn as cplxnn

from torch.nn.parallel.data_parallel import data_parallel


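# a real-valued layer with a plain Tensor input, for comparison
net, x = nn.Conv2d(3, 3, 3), torch.randn(2, 3, 6, 6)
data_parallel(net.cuda(0), x.cuda(0), [0, 1])

# the complex-valued layer with a Cplx input raises the RuntimeError above
net, x = cplxnn.CplxConv2d(3, 3, 3), cplx.Cplx(torch.randn(1, 3, 6, 6))
data_parallel(net.cuda(0), x.cuda(0), [0, 1])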

data_parallel uses three key functions:

from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply

scatter_kwargs is responsible for splitting the input along the batch dimension and moving the shards to the appropriate devices. replicate is responsible for performing in-vivo surgery on the model: taking it apart and rebuilding it with its parameters wrapped in scatter-gather operations. Finally, parallel_apply takes the properly placed inputs and model replicas and runs them in parallel threads, collecting the outputs upon termination.
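To make the division of labour concrete, here is a rough sketch of how the three functions fit together. The helper name manual_data_parallel is made up for illustration, and the final gather call is my addition to collect the results; this is a simplified reading of what data_parallel does, not its actual source. It needs at least two CUDA devices to do anything.

```python
import torch
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply


def manual_data_parallel(module, inputs, device_ids, dim=0):
    """Hypothetical helper: the stages of data_parallel, spelled out."""
    if not isinstance(inputs, tuple):
        inputs = (inputs,)
    # 1. split the positional inputs along `dim` and move each shard to its device
    shards, kwargs = scatter_kwargs(inputs, None, device_ids, dim)
    # a small batch may yield fewer shards than devices
    used = device_ids[:len(shards)]
    # 2. copy the module (which must live on device_ids[0]) onto each used device
    replicas = replicate(module, used)
    # 3. run every replica on its own shard in a separate thread
    outputs = parallel_apply(replicas, shards, kwargs, used)
    # collect the per-device outputs back on the first device
    return gather(outputs, device_ids[0], dim)


if torch.cuda.device_count() >= 2:
    net = torch.nn.Conv2d(3, 3, 3).cuda(0)
    x = torch.randn(4, 3, 6, 6).cuda(0)
    out = manual_data_parallel(net, x, [0, 1])
    print(out.shape, out.device)
```

Splitting the call up like this makes it possible to inspect the devices of the shards and of the replicas separately.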

Replicating the model and moving the inputs manually seems to work fine:

from torch.nn.parallel.replicate import replicate

net = cplxnn.CplxConv2d(3, 3, 3)
x = cplx.Cplx(torch.randn(1, 3, 6, 6), torch.randn(1, 3, 6, 6))

replicas = replicate(net.cuda(0), [0, 1])
print([m.weight.device for m in replicas])
print([m(x.to(m.weight.device)).device for m in replicas])

The error message from parallel_apply coincides with the one raised by

import torch.nn.functional as F

F.conv2d(torch.randn(1, 3, 6, 6).cuda(0), torch.randn(3, 3, 3, 3).cuda(1))

I think the issue here is that the input is on the wrong device, while the model is probably on the right one. Thus I investigated scatter_kwargs.

scatter_kwargs calls scatter on the input tensor. scatter uses internal low-level functionality to split a tensor along the zeroth dimension. Unfortunately, Cplx is a high-level Python object that is duck-typed to behave like a Tensor only at a high level; it is not binary compatible with torch.Tensor at the C++ level. As far as I can tell, scatter therefore does not split it and instead hands the same object, still on its original device, to every replica.
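The pass-through behaviour can be checked without cplxmodule and without a GPU. The sketch below uses a made-up class, PairOfTensors, as a stand-in for cplx.Cplx (it is not the real class), and relies on my understanding that scatter returns non-Tensor objects unchanged, once per target device.

```python
import torch
from torch.nn.parallel.scatter_gather import scatter


class PairOfTensors:
    """Hypothetical stand-in for cplx.Cplx: a plain Python object holding two tensors."""

    def __init__(self, real, imag):
        self.real, self.imag = real, imag


obj = PairOfTensors(torch.randn(4, 3), torch.randn(4, 3))

# scatter splits Tensors and recurses into tuples, lists and dicts;
# any other object is handed to every target device as-is
shards = scatter(obj, [0, 1])

print(len(shards))                    # one entry per target device
print(all(s is obj for s in shards))  # whether each entry is the very same object
print(shards[1].real.device)          # device of the data seen by the second replica
```

If the last line still reports the original device, the second replica receives data that was never moved, which would match the device mismatch in the error message.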

Solution

Unfortunately, I can only suggest a workaround wrapper:

def r2r_wrap(model, dim_in=1, dim_out=1):
    return torch.nn.Sequential(
        # convert Tensor `B x F*2 x ...` (dim_in=1) to Cplx as part of the computations
        cplxnn.RealToCplx(dim=dim_in),
        model,
        # convert Cplx back to B x O*2 x ... (dim_out=1) Tensor as part of the pipeline
        cplxnn.CplxToReal(dim=dim_out) 
    )

model_complex = r2r_wrap(cplxtest_net()).to(device)

This is a thin Real-to-Real wrapper around the whole model, which makes conversion from torch Tensors to Cplx and back a part of the model structure, and thus bypasses the scatter issue.
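The same pattern can be tried in isolation with plain torch. In the sketch below, ToyComplexScale is a made-up stand-in for a complex-valued network (it is not part of cplxmodule): it does the real-to-complex and complex-to-real conversions inside forward, so only ordinary Tensors cross the nn.DataParallel boundary. It should run on a CPU-only machine too, where nn.DataParallel simply calls the wrapped module.

```python
import torch
from torch import nn


class ToyComplexScale(nn.Module):
    """Hypothetical stand-in for a complex-valued network (not part of cplxmodule)."""

    def __init__(self):
        super().__init__()
        # real and imaginary parts of a single learnable complex scalar
        self.w_re = nn.Parameter(torch.tensor(0.5))
        self.w_im = nn.Parameter(torch.tensor(-1.0))

    def forward(self, x):
        # x is a plain Tensor `B x F*2 x ...` with re/im alternating along dim 1
        re, im = x[:, 0::2], x[:, 1::2]
        # complex multiplication (re + i*im) * (w_re + i*w_im)
        out_re = self.w_re * re - self.w_im * im
        out_im = self.w_re * im + self.w_im * re
        # re-interleave so the output is again a plain Tensor `B x F*2 x ...`
        return torch.stack([out_re, out_im], dim=2).flatten(1, 2)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(ToyComplexScale().to(device))

x = torch.randn(4, 3 * 2, 8, 8, device=device)
y = model(x)
print(type(y).__name__, tuple(y.shape))
```

Because both the input and the output are Tensors, scatter and gather can split and merge them along the batch dimension as usual.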

The key nuance is that both the input to the model and its output are plain Tensors, not cplx.Cplx:

# put real-imag pairs along the 1st dim
complex_data = torch.randn(1, 3*2, 224, 224).to(device)

# real-imag pairs are assumed to alternate along the 1st dim
cplx_data = cplxnn.RealToCplx(dim=1)(complex_data)
assert torch.allclose(cplx_data.real, complex_data[:, 0::2])
assert torch.allclose(cplx_data.imag, complex_data[:, 1::2])
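If the real and imaginary parts start out as separate tensors, they have to be interleaved into this layout first. A minimal sketch with plain torch follows; pack_interleaved is a made-up helper name, not a cplxmodule function.

```python
import torch


def pack_interleaved(real, imag):
    """Hypothetical helper: `B x F x ...` real and imag parts -> `B x F*2 x ...` Tensor."""
    # stack to `B x F x 2 x ...`, then merge dims 1 and 2 so re/im alternate along dim 1
    return torch.stack([real, imag], dim=2).flatten(1, 2)


real, imag = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
packed = pack_interleaved(real, imag)

print(packed.shape)  # torch.Size([1, 6, 8, 8])
assert torch.equal(packed[:, 0::2], real)
assert torch.equal(packed[:, 1::2], imag)
```

The resulting tensor has the same alternating layout that RealToCplx(dim=1) is described as expecting above.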

The output is handled the same way:

tensor_output = model_complex(complex_data)
cplx_data = cplxnn.RealToCplx(dim=1)(tensor_output)
assert torch.allclose(cplx_data.real, tensor_output[:, 0::2])
assert torch.allclose(cplx_data.imag, tensor_output[:, 1::2])

These Tensors are just a way to store the complex numbers and in no way affect their arithmetic or operations inside the Cplx network or cplxmodule itself.
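To illustrate that the interleaved Tensor is only a storage layout, the sketch below converts it to a native complex tensor and back without changing any values. It assumes a PyTorch version that provides torch.complex; both helper names are made up for this example.

```python
import torch


def interleaved_to_complex(t):
    """Hypothetical helper: `B x F*2 x ...` real Tensor -> `B x F x ...` complex Tensor."""
    return torch.complex(t[:, 0::2].contiguous(), t[:, 1::2].contiguous())


def complex_to_interleaved(z):
    """Hypothetical helper: the inverse of interleaved_to_complex."""
    return torch.stack([z.real, z.imag], dim=2).flatten(1, 2)


out = torch.randn(2, 3 * 2, 4, 4)
z = interleaved_to_complex(out)

print(z.dtype, tuple(z.shape))
# the round trip reproduces the original values exactly
assert torch.equal(complex_to_interleaved(z), out)
```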

@ivannz ivannz added bug Something isn't working wontfix This will not be worked on labels Apr 14, 2020
@ivannz ivannz self-assigned this Apr 14, 2020
@ivannz ivannz closed this as completed May 25, 2020
@ivannz ivannz pinned this issue Jun 8, 2020
@ivannz ivannz reopened this Jun 8, 2020
@ivannz ivannz changed the title nn.DataParallel bug Workaround for nn.DataParallel bug Jun 8, 2020
Repository owner locked as resolved and limited conversation to collaborators Jun 8, 2020
@ivannz ivannz added the question Further information is requested label Aug 16, 2020