
yeliudev / nncore / build 8318269918 (push via github by yeliudev)

17 Mar 2024 09:04PM UTC coverage: 15.886% (-0.01%) from 15.897%

Commit: Fix NNDistributedDataParallel

0 of 5 new or added lines in 1 file covered. (0.0%)

678 of 4268 relevant lines covered (15.89%)

3.07 hits per line

Source File

/nncore/parallel/parallel.py (0.0% of lines covered)
# Copyright (c) Ye Liu. Licensed under the MIT License.

import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.parallel._functions import Scatter, _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple

import nncore
from .container import DataContainer


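# Stream-aware scatter for the tensors held by a DataContainer. Unlike
# torch's Scatter, forward() takes no ctx and is invoked directly as a plain
# function (see _scatter below), copying each chunk to its target GPU on a
# dedicated copy stream.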
class _Scatter(torch.autograd.Function):

    @staticmethod
    def forward(target_gpus, input):
        input_device = _get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            streams = [
                _get_stream(torch.device('cuda', gpu_id))
                for gpu_id in target_gpus
            ]
        outputs = _scatter_stream(input, target_gpus, streams)
        if streams is not None:
            _sync_stream(outputs, target_gpus, streams)
        return tuple(outputs)


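# Return the CUDA device index of the first GPU tensor found in a (possibly
# nested) list, or -1 if every tensor lives on CPU.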
def _get_input_device(input):
    if isinstance(input, list):
        for item in input:
            input_device = _get_input_device(item)
            if input_device != -1:
                return input_device
        return -1
    elif torch.is_tensor(input):
        return input.get_device() if input.is_cuda else -1
    else:
        raise TypeError('unknown type {}'.format(type(input)))


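# Recursively split a list of tensors into per-device chunks and copy each
# chunk to its device asynchronously on the given copy stream. A device list
# of [-1] keeps the data on CPU.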
42
def _scatter_stream(input, devices, streams=None):
×
43
    if streams is None:
×
44
        streams = [None] * len(devices)
×
45
    if isinstance(input, list):
×
46
        chunk_size = (len(input) - 1) // len(devices) + 1
×
47
        outputs = [
×
48
            _scatter_stream(input[i], [devices[i // chunk_size]],
49
                            [streams[i // chunk_size]])
50
            for i in range(len(input))
51
        ]
52
        return outputs
×
53
    elif torch.is_tensor(input):
×
54
        output = input.contiguous()
×
55
        stream = streams[0] if output.numel() > 0 else None
×
56
        if devices != [-1]:
×
57
            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
×
58
                output = output.cuda(devices[0], non_blocking=True)
×
59
        return output
×
60
    else:
61
        raise TypeError('unknown type {}'.format(type(input)))
×
62

63

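# Make each device's current stream wait for the copy stream that produced
# the scattered tensors, and record the tensors on the current stream so the
# caching allocator does not reuse their memory too early.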
64
def _sync_stream(output, devices, streams):
×
65
    if isinstance(output, list):
×
66
        chunk_size = len(output) // len(devices)
×
67
        for i in range(len(devices)):
×
68
            for j in range(chunk_size):
×
69
                _sync_stream(output[i * chunk_size + j], [devices[i]],
×
70
                             [streams[i]])
71
    elif torch.is_tensor(output):
×
72
        if output.numel() != 0:
×
73
            with torch.cuda.device(devices[0]):
×
74
                main_stream = torch.cuda.current_stream()
×
75
                main_stream.wait_stream(streams[0])
×
76
                output.record_stream(main_stream)
×
77
    else:
78
        raise TypeError('unknown type {}'.format(type(output)))
×
79

80

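# DataContainer-aware counterpart of torch's scatter(): CPU-only containers
# return their underlying data as-is, other containers go through _Scatter,
# plain tensors go through torch's Scatter, and tuples/lists/dicts are
# scattered recursively. The closure is dropped in the finally block to break
# the reference cycle, mirroring torch's implementation.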
81
def _scatter(inputs, target_gpus, dim=0):
×
82

83
    def _scatter_map(obj):
×
84
        if torch.is_tensor(obj) and target_gpus != [-1]:
×
85
            return Scatter.apply(target_gpus, None, dim, obj)
×
86
        if isinstance(obj, DataContainer):
×
87
            return obj.data if obj.cpu_only else _Scatter.forward(
×
88
                target_gpus, obj.data)
89
        if _is_namedtuple(obj):
×
90
            return [type(obj)(*args) for args in zip(*map(_scatter_map, obj))]
×
91
        if isinstance(obj, tuple) and len(obj) > 0:
×
92
            return list(zip(*map(_scatter_map, obj)))
×
93
        if isinstance(obj, list) and len(obj) > 0:
×
94
            return [list(i) for i in zip(*map(_scatter_map, obj))]
×
95
        if isinstance(obj, dict) and len(obj) > 0:
×
96
            return [type(obj)(i) for i in zip(*map(_scatter_map, obj.items()))]
×
97
        return [obj for _ in target_gpus]
×
98

99
    try:
×
100
        res = _scatter_map(inputs)
×
101
    finally:
102
        _scatter_map = None
×
103

104
    return res
×
105

106

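# Scatter positional and keyword arguments across the target GPUs, padding
# the shorter of the two with empty tuples/dicts so that both have one entry
# per device (the DataContainer-aware version of torch's scatter_kwargs).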
def _scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    inputs = _scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = _scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
    elif len(kwargs) < len(inputs):
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
    return tuple(inputs), tuple(kwargs)


class NNDataParallel(DataParallel):
    """
    A :obj:`nn.DataParallel` class with :obj:`DataContainer` support. This
    class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
    """

    def __init__(self, module, device_id=None, dim=0, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        if device_id is None:
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
            else:
                device_id = -1

        if device_id == -1:
            logger = nncore.get_logger()
            logger.warn('{} is running on CPU'.format(self.__class__.__name__))

            super(DataParallel, self).__init__()
            torch._C._log_api_usage_once('torch.nn.parallel.DataParallel')

            self.module = module
            self.device_ids = []
            self.dim = dim
            return

        super(NNDataParallel, self).__init__(
            module, device_ids=[device_id], output_device=device_id, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDataParallel, self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])


class NNDistributedDataParallel(DistributedDataParallel):
    """
    A :obj:`nn.DistributedDataParallel` class with :obj:`DataContainer`
    support. This class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
    """

    def __init__(self, module, device_id=None, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        if device_id is None:
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
            else:
                device_id = -1

        if device_id >= 0:
            module = module.to('cuda:{}'.format(device_id))
            device_ids = [device_id]
        else:
            logger = nncore.get_logger()
            logger.warn('{} is running on CPU'.format(self.__class__.__name__))
            device_ids = None

        super(NNDistributedDataParallel, self).__init__(
            module, device_ids=device_ids, **kwargs)

    def _run_ddp_forward(self, *inputs, **kwargs):
        if self.device_ids:
            inputs, kwargs = _scatter_kwargs(
                inputs, kwargs, self.device_ids, dim=self.dim)
            inputs, kwargs = inputs[0], kwargs[0]
        return super(NNDistributedDataParallel,
                     self)._run_ddp_forward(*inputs, **kwargs)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDistributedDataParallel,
                         self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
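
Usage sketch (not part of the file above): a minimal example of how the wrappers can be constructed, based only on the signatures shown in this file. The toy Net module, the input shape, and the nncore.parallel import path are illustrative assumptions; NNDistributedDataParallel additionally needs an initialized torch.distributed process group, so it is only hinted at in a comment.

import torch
import torch.nn as nn

from nncore.parallel import NNDataParallel  # assumed export path


class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)


# device_id=None picks the current CUDA device, or falls back to CPU with a
# warning, as handled in NNDataParallel.__init__ above.
model = NNDataParallel(Net(), device_id=None)

# Plain tensors are scattered with torch's Scatter; fields wrapped in
# DataContainer would instead be routed through the custom _scatter above.
out = model(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 2])

# For distributed training (after torch.distributed.init_process_group and
# with one process per GPU), the module would be wrapped analogously:
# model = NNDistributedDataParallel(Net(), device_id=local_rank)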