• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

yeliudev / nncore / 8314091714

17 Mar 2024 08:21AM UTC coverage: 15.964% (-0.2%) from 16.205%
8314091714

push

github

yeliudev
Fix scatter

0 of 70 new or added lines in 1 file covered. (0.0%)

2 existing lines in 1 file now uncovered.

678 of 4247 relevant lines covered (15.96%)

3.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/nncore/parallel/parallel.py
1
# Copyright (c) Ye Liu. Licensed under the MIT License.
2

3
import torch
×
4
from torch.nn.parallel import DataParallel, DistributedDataParallel
×
NEW
5
from torch.nn.parallel._functions import Scatter, _get_stream
×
6
from torch.nn.parallel.scatter_gather import _is_namedtuple
×
7

8
import nncore
×
9
from .container import DataContainer
×
10

11

NEW
12
class _Scatter(torch.autograd.Function):
    """
    Scatter a tensor (or a nested list of tensors) to the target devices.

    Note that :meth:`forward` takes no ``ctx`` argument, so it is meant to be
    called directly as a plain static function.
    """

    @staticmethod
    def forward(target_gpus, input):
        """
        Copy ``input`` to ``target_gpus``.

        Args:
            target_gpus (list[int]): The target device ids. ``[-1]`` means
                that the data should not be moved to any GPU.
            input (:obj:`torch.Tensor` | list): A tensor or a (nested) list
                of tensors.

        Returns:
            tuple: The scattered outputs.
        """
        # Dedicated copy streams are only needed for CPU -> GPU transfers.
        use_copy_streams = (
            _get_input_device(input) == -1 and target_gpus != [-1])

        if use_copy_streams:
            streams = [
                _get_stream(torch.device('cuda', gpu_id))
                for gpu_id in target_gpus
            ]
        else:
            streams = None

        outputs = _scatter_stream(input, target_gpus, streams)

        if use_copy_streams:
            # Make the current streams wait for the background copies.
            _sync_stream(outputs, target_gpus, streams)

        return tuple(outputs)
×
27

28

NEW
29
def _get_input_device(input):
×
NEW
30
    if isinstance(input, list):
×
NEW
31
        for item in input:
×
NEW
32
            input_device = _get_input_device(item)
×
NEW
33
            if input_device != -1:
×
NEW
34
                return input_device
×
NEW
35
        return -1
×
NEW
36
    elif torch.is_tensor(input):
×
NEW
37
        return input.get_device() if input.is_cuda else -1
×
38
    else:
NEW
39
        raise TypeError('unknown type {}'.format(type(input)))
×
40

41

NEW
42
def _scatter_stream(input, devices, streams=None):
×
NEW
43
    if streams is None:
×
NEW
44
        streams = [None] * len(devices)
×
NEW
45
    if isinstance(input, list):
×
NEW
46
        chunk_size = (len(input) - 1) // len(devices) + 1
×
NEW
47
        outputs = [
×
48
            _scatter_stream(input[i], [devices[i // chunk_size]],
49
                            [streams[i // chunk_size]])
50
            for i in range(len(input))
51
        ]
NEW
52
        return outputs
×
NEW
53
    elif torch.is_tensor(input):
×
NEW
54
        output = input.contiguous()
×
NEW
55
        stream = streams[0] if output.numel() > 0 else None
×
NEW
56
        if devices != [-1]:
×
NEW
57
            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
×
NEW
58
                output = output.cuda(devices[0], non_blocking=True)
×
NEW
59
        return output
×
60
    else:
NEW
61
        raise TypeError('unknown type {}'.format(type(input)))
×
62

63

NEW
64
def _sync_stream(output, devices, streams):
×
NEW
65
    if isinstance(output, list):
×
NEW
66
        chunk_size = len(output) // len(devices)
×
NEW
67
        for i in range(len(devices)):
×
NEW
68
            for j in range(chunk_size):
×
NEW
69
                _sync_stream(output[i * chunk_size + j], [devices[i]],
×
70
                             [streams[i]])
NEW
71
    elif torch.is_tensor(output):
×
NEW
72
        if output.numel() != 0:
×
NEW
73
            with torch.cuda.device(devices[0]):
×
NEW
74
                main_stream = torch.cuda.current_stream()
×
NEW
75
                main_stream.wait_stream(streams[0])
×
NEW
76
                output.record_stream(main_stream)
×
77
    else:
NEW
78
        raise TypeError('unknown type {}'.format(type(output)))
×
79

80

UNCOV
81
def _scatter(inputs, target_gpus, dim=0):
    """
    Scatter arbitrarily nested inputs to the target devices.

    Args:
        inputs (any): The object to scatter. Tensors, :obj:`DataContainer`
            objects, namedtuples, tuples, lists and dicts are handled
            explicitly. Any other object is replicated by reference once per
            target device.
        target_gpus (list[int]): The target device ids. ``[-1]`` disables
            tensor scattering.
        dim (int, optional): The dimension along which tensors are split.
            Default: ``0``.

    Returns:
        list: One entry per target device.
    """

    def _dispatch(obj):
        if torch.is_tensor(obj) and target_gpus != [-1]:
            return Scatter.apply(target_gpus, None, dim, obj)

        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            return _Scatter.forward(target_gpus, obj.data)

        if _is_namedtuple(obj):
            per_device = zip(*map(_dispatch, obj))
            return [type(obj)(*fields) for fields in per_device]

        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(_dispatch, obj)))

        if isinstance(obj, list) and len(obj) > 0:
            per_device = zip(*map(_dispatch, obj))
            return [list(items) for items in per_device]

        if isinstance(obj, dict) and len(obj) > 0:
            per_device = zip(*map(_dispatch, obj.items()))
            return [type(obj)(pairs) for pairs in per_device]

        # Fallback: hand the very same object to every device.
        return [obj] * len(target_gpus)

    try:
        return _dispatch(inputs)
    finally:
        # The closure refers to itself; clear the cell to break the cycle.
        _dispatch = None
×
105

106

107
def _scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
×
108
    inputs = _scatter(inputs, target_gpus, dim) if inputs else []
×
109
    kwargs = _scatter(kwargs, target_gpus, dim) if kwargs else []
×
110
    if len(inputs) < len(kwargs):
×
111
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
×
112
    elif len(kwargs) < len(inputs):
×
113
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
×
114
    return tuple(inputs), tuple(kwargs)
×
115

116

117
class NNDataParallel(DataParallel):
    """
    A :obj:`nn.DataParallel` class with :obj:`DataContainer` support. This
    class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
        dim (int, optional): The dimension along which tensors are
            scattered. Default: ``0``.
        **kwargs: Extra keyword arguments passed to :obj:`nn.DataParallel`.
            ``device_ids`` and ``output_device`` are not allowed since they
            are derived from ``device_id``.
    """

    def __init__(self, module, device_id=None, dim=0, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        if device_id is None:
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
            else:
                device_id = -1

        if device_id == -1:
            logger = nncore.get_logger()
            # ``warn`` is a deprecated alias of ``warning``.
            logger.warning('{} is running on CPU'.format(
                self.__class__.__name__))

            # Skip ``DataParallel.__init__`` (which expects CUDA devices) and
            # initialize the plain ``nn.Module`` state only.
            super(DataParallel, self).__init__()
            torch._C._log_api_usage_once('torch.nn.parallel.DataParallel')

            self.module = module
            self.device_ids = []
            self.dim = dim
            return

        # ``dim`` must be forwarded explicitly, otherwise the parent class
        # silently falls back to its own default.
        super(NNDataParallel, self).__init__(
            module,
            device_ids=[device_id],
            output_device=device_id,
            dim=dim,
            **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        """
        Scatter ``inputs`` and ``kwargs`` to ``device_ids`` with
        :obj:`DataContainer` support.
        """
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped module. On CPU (empty ``device_ids``) the inputs are
        scattered with the ``[-1]`` target and the module is called directly.
        """
        if self.device_ids:
            return super(NNDataParallel, self).forward(*inputs, **kwargs)

        inputs, kwargs = self.scatter(inputs, kwargs, [-1])
        # Scattering nothing yields two empty tuples; provide one empty
        # argument set so that a module without inputs can still be called.
        if not inputs and not kwargs:
            inputs, kwargs = ((), ), ({}, )
        return self.module(*inputs[0], **kwargs[0])
×
163

164

165
class NNDistributedDataParallel(DistributedDataParallel):
    """
    A :obj:`nn.DistributedDataParallel` class with :obj:`DataContainer`
    support. This class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
        **kwargs: Extra keyword arguments passed to
            :obj:`nn.DistributedDataParallel`. ``device_ids`` and
            ``output_device`` are not allowed since they are derived from
            ``device_id``.
    """

    def __init__(self, module, device_id=None, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        if device_id is None:
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
            else:
                device_id = -1

        if device_id >= 0:
            module = module.to('cuda:{}'.format(device_id))
            device_ids = [device_id]
        else:
            logger = nncore.get_logger()
            # ``warn`` is a deprecated alias of ``warning``.
            logger.warning('{} is running on CPU'.format(
                self.__class__.__name__))
            device_ids = None

        super(NNDistributedDataParallel, self).__init__(
            module, device_ids=device_ids, **kwargs)

    def to_kwargs(self, inputs, kwargs, device_id):
        """
        Move ``inputs`` and ``kwargs`` to ``device_id`` with
        :obj:`DataContainer` support.
        """
        return _scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped module. On CPU (no ``device_ids``) the inputs are
        scattered with the ``[-1]`` target and the module is called directly.
        """
        if self.device_ids:
            return super(NNDistributedDataParallel,
                         self).forward(*inputs, **kwargs)

        # Use the DataContainer-aware scatter directly: this class does not
        # override ``scatter``, so ``self.scatter`` would resolve to whatever
        # the parent class provides, which neither unwraps DataContainer
        # objects nor understands the ``-1`` (CPU) sentinel.
        inputs, kwargs = _scatter_kwargs(inputs, kwargs, [-1], dim=self.dim)
        # Scattering nothing yields two empty tuples; provide one empty
        # argument set so that a module without inputs can still be called.
        if not inputs and not kwargs:
            inputs, kwargs = ((), ), ({}, )
        return self.module(*inputs[0], **kwargs[0])
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc