yeliudev / nncore · build 8330591720 (push, via GitHub)

18 Mar 2024 04:59PM UTC · coverage: 15.882% (down 0.004% from 15.886%)
Commit by yeliudev: "Fix transformer block for ddp"

0 of 24 new or added lines in 2 files covered (0.0%).
3 existing lines in 1 file now uncovered.
678 of 4269 relevant lines covered (15.88%).
3.07 hits per line.

Source file: /nncore/parallel/parallel.py (0.0% of lines covered)
# Copyright (c) Ye Liu. Licensed under the MIT License.

import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.parallel._functions import Scatter, _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple

import nncore
from .container import DataContainer


class _Scatter(torch.autograd.Function):

    @staticmethod
    def forward(target_gpus, input):
        input_device = _get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            streams = [
                _get_stream(torch.device('cuda', gpu_id))
                for gpu_id in target_gpus
            ]
        outputs = _scatter_stream(input, target_gpus, streams)
        if streams is not None:
            _sync_stream(outputs, target_gpus, streams)
        return tuple(outputs)


def _get_input_device(input):
    if isinstance(input, list):
        for item in input:
            input_device = _get_input_device(item)
            if input_device != -1:
                return input_device
        return -1
    elif torch.is_tensor(input):
        return input.get_device() if input.is_cuda else -1
    else:
        raise TypeError('unknown type {}'.format(type(input)))


def _scatter_stream(input, devices, streams=None):
    if streams is None:
        streams = [None] * len(devices)
    if isinstance(input, list):
        chunk_size = (len(input) - 1) // len(devices) + 1
        outputs = [
            _scatter_stream(input[i], [devices[i // chunk_size]],
                            [streams[i // chunk_size]])
            for i in range(len(input))
        ]
        return outputs
    elif torch.is_tensor(input):
        output = input.contiguous()
        stream = streams[0] if output.numel() > 0 else None
        if devices != [-1]:
            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                output = output.cuda(devices[0], non_blocking=True)
        return output
    else:
        raise TypeError('unknown type {}'.format(type(input)))


def _sync_stream(output, devices, streams):
    if isinstance(output, list):
        chunk_size = len(output) // len(devices)
        for i in range(len(devices)):
            for j in range(chunk_size):
                _sync_stream(output[i * chunk_size + j], [devices[i]],
                             [streams[i]])
    elif torch.is_tensor(output):
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                output.record_stream(main_stream)
    else:
        raise TypeError('unknown type {}'.format(type(output)))
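
# NOTE: the _Scatter / _scatter_stream / _sync_stream helpers above follow
# the standard PyTorch side-stream copy idiom. A minimal standalone sketch
# of that idiom, assuming a CUDA device and a pinned CPU tensor ``cpu_t``
# (both names are placeholders, not from this file):
#
#   copy_stream = torch.cuda.Stream(device=0)
#   with torch.cuda.device(0), torch.cuda.stream(copy_stream):
#       gpu_t = cpu_t.cuda(0, non_blocking=True)  # async host-to-device copy
#   main = torch.cuda.current_stream(0)
#   main.wait_stream(copy_stream)    # main stream waits for the copy
#   gpu_t.record_stream(main)        # keep gpu_t's memory alive for main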


def _scatter(inputs, target_gpus, dim=0):

    def _scatter_map(obj):
        if torch.is_tensor(obj) and target_gpus != [-1]:
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, DataContainer):
            return obj.data if obj.cpu_only else _Scatter.forward(
                target_gpus, obj.data)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(_scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(_scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(_scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(_scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    try:
        res = _scatter_map(inputs)
    finally:
        _scatter_map = None

    return res


def _scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    inputs = _scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = _scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
    elif len(kwargs) < len(inputs):
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
    return tuple(inputs), tuple(kwargs)
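
# NOTE: a quick worked example of _scatter_kwargs on CPU (target_gpus ==
# [-1]), where scattering degenerates to a single chunk and the missing
# positional inputs are padded with an empty tuple:
#
#   >>> _scatter_kwargs((), {'x': 1}, [-1])
#   (((),), ({'x': 1},))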


def _get_device(device_id=None):
    if device_id is not None:
        return device_id
    if torch.cuda.is_available():
        return torch.cuda.current_device()
    return -1


class NNDataParallel(DataParallel):
    """
    A :obj:`nn.DataParallel` class with :obj:`DataContainer` support. This
    class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
    """

    def __init__(self, module, device_id=None, dim=0, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        device_id = _get_device(device_id)

        if device_id >= 0:
            super(NNDataParallel, self).__init__(
                module,
                device_ids=[device_id],
                output_device=device_id,
                **kwargs)
        else:
            logger = nncore.get_logger()
            logger.warn('{} is running on CPU'.format(self.__class__.__name__))

            super(DataParallel, self).__init__()
            torch._C._log_api_usage_once('torch.nn.parallel.DataParallel')

            self.module = module
            self.device_ids = []
            self.dim = dim

    def scatter(self, inputs, kwargs, device_ids):
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDataParallel, self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
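
# NOTE: a minimal usage sketch (illustrative only; the module and shapes are
# placeholders, and a CUDA-capable machine is assumed for device_id=0):
#
#   >>> model = NNDataParallel(torch.nn.Linear(8, 2).cuda(0), device_id=0)
#   >>> out = model(torch.randn(4, 8))        # inputs are scattered to GPU 0
#   >>> cpu_model = NNDataParallel(torch.nn.Linear(8, 2), device_id=-1)  # CPU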


class NNDistributedDataParallel(DistributedDataParallel):
    """
    A :obj:`nn.DistributedDataParallel` class with :obj:`DataContainer`
    support. This class only bundles single-device modules.

    Args:
        module (:obj:`nn.Module`): The module to be bundled.
        device_id (int | None, optional): The device id to be used. ``None``
            means using the default device, and ``-1`` means CPU. Default:
            ``None``.
    """

    def __init__(self, module, device_id=None, **kwargs):
        assert isinstance(device_id, int) or device_id is None
        assert 'device_ids' not in kwargs and 'output_device' not in kwargs

        device_id = _get_device(device_id)

        if device_id >= 0:
            module = module.to('cuda:{}'.format(device_id))
            device_ids = [device_id]
        else:
            logger = nncore.get_logger()
            logger.warn('{} is running on CPU'.format(self.__class__.__name__))
            device_ids = None

        super(NNDistributedDataParallel, self).__init__(
            module, device_ids=device_ids, **kwargs)

    def _pre_forward(self, *inputs, **kwargs):
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            inputs, kwargs = inputs[0], kwargs[0]
        return super(NNDistributedDataParallel,
                     self)._pre_forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDistributedDataParallel,
                         self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
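
A minimal end-to-end sketch of the distributed wrapper follows (illustrative,
not part of the source file above; it assumes these classes are re-exported
from nncore.parallel as the package layout suggests, that the script is
launched with one process per GPU, e.g. via torchrun, and that the model and
batch are placeholders):

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    from nncore.parallel import NNDistributedDataParallel

    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    # NNDistributedDataParallel moves the module to cuda:<device_id> itself
    model = NNDistributedDataParallel(nn.Linear(8, 2), device_id=local_rank)

    x = torch.randn(4, 8)      # scattered onto the local GPU by _pre_forward
    loss = model(x).sum()
    loss.backward()            # DDP all-reduces gradients across processes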