yeliudev / nncore / 6442728583

07 Oct 2023 06:10PM UTC — coverage: 16.251% (-0.2%) from 16.456%

Build 6442728583 · push · github · yeliudev
Commit: Add support for droppath in transformers

148 of 148 new or added lines in 11 files covered (100.0%)
678 of 4172 relevant lines covered (16.25%)
3.14 hits per line

Source File

/nncore/parallel/parallel.py — 0.0% covered
# Copyright (c) Ye Liu. Licensed under the MIT License.

import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.nn.parallel._functions import Function, Scatter, _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple

from .container import DataContainer


class _Scatter(Function):

    @staticmethod
    def forward(target_gpus, input):
        input_device = _get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            streams = [_get_stream(device) for device in target_gpus]

        outputs = _scatter_stream(input, target_gpus, streams)
        if streams is not None:
            _sync_stream(outputs, target_gpus, streams)

        return tuple(outputs)


def _get_input_device(input):
    if isinstance(input, list):
        for item in input:
            input_device = _get_input_device(item)
            if input_device != -1:
                return input_device
        return -1
    elif torch.is_tensor(input):
        return input.get_device() if input.is_cuda else -1
    else:
        raise TypeError('unknown type {}'.format(type(input)))


def _scatter_stream(input, devices, streams=None):
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        chunk_size = (len(input) - 1) // len(devices) + 1
        outputs = [
            _scatter_stream(input[i], [devices[i // chunk_size]],
                            [streams[i // chunk_size]])
            for i in range(len(input))
        ]
        return outputs
    elif torch.is_tensor(input):
        output = input.contiguous()
        stream = streams[0] if output.numel() > 0 else None
        if devices != [-1]:
            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                output = output.cuda(devices[0], non_blocking=True)
        return output
    else:
        raise TypeError('unknown type {}'.format(type(input)))


def _sync_stream(output, devices, streams):
    if isinstance(output, list):
        chunk_size = len(output) // len(devices)
        for i in range(len(devices)):
            for j in range(chunk_size):
                _sync_stream(output[i * chunk_size + j], [devices[i]],
                             [streams[i]])
    elif torch.is_tensor(output):
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                output.record_stream(main_stream)
    else:
        raise TypeError('unknown type {}'.format(type(output)))


def _scatter(inputs, target_gpus, dim=0):

    def _scatter_map(obj):
        if torch.is_tensor(obj):
            if target_gpus != [-1]:
                return Scatter.apply(target_gpus, None, dim, obj)
            else:
                return _Scatter.forward(target_gpus, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return _Scatter.forward(target_gpus, obj.data)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(_scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(_scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(_scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(_scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    try:
        res = _scatter_map(inputs)
    finally:
        _scatter_map = None

    return res


def _scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    inputs = _scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = _scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
    elif len(kwargs) < len(inputs):
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
    return tuple(inputs), tuple(kwargs)


class NNDataParallel(DataParallel):
    """
    A :obj:`nn.DataParallel` class with :obj:`DataContainer` support. This
    class only bundles single-device modules.
    """

    def __init__(self, module, device_ids=None, dim=0, **kwargs):
        assert device_ids is None or len(device_ids) <= 1
        super(NNDataParallel, self).__init__(
            module,
            device_ids=[0] if device_ids is None else device_ids,
            **kwargs)
        self.dim = dim

    def scatter(self, inputs, kwargs, device_ids):
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDataParallel, self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])


class NNDistributedDataParallel(DistributedDataParallel):
    """
    A :obj:`nn.DistributedDataParallel` class with :obj:`DataContainer`
    support. This class only bundles single-device modules.
    """

    def __init__(self,
                 module,
                 device_ids=None,
                 broadcast_buffers=False,
                 **kwargs):
        assert device_ids is None or len(device_ids) <= 1

        if device_ids is None:
            if torch.cuda.is_available():
                device_ids = [torch.cuda.current_device()]
                module = module.cuda()
        elif len(device_ids) == 1:
            module = module.to('cuda:{}'.format(device_ids[0]))

        super(NNDistributedDataParallel, self).__init__(
            module,
            device_ids=device_ids,
            broadcast_buffers=broadcast_buffers,
            **kwargs)

    def _run_ddp_forward(self, *inputs, **kwargs):
        if self._use_replicated_tensor_module:
            module = self._replicated_tensor_module
        else:
            module = self.module

        if self.device_ids:
            inputs, kwargs = _scatter_kwargs(
                inputs, kwargs, self.device_ids, dim=self.dim)
            with self._inside_ddp_forward():
                return module(*inputs[0], **kwargs[0])
        else:
            with self._inside_ddp_forward():
                return module(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        return _scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if self.device_ids:
            return super(NNDistributedDataParallel,
                         self).forward(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])