• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

inventree / InvenTree / 4361124568

pending completion
4361124568

push

github

GitHub
Unit test speed improvements (#4463)

181 of 181 new or added lines in 20 files covered. (100.0%)

25546 of 29143 relevant lines covered (87.66%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.68
/InvenTree/InvenTree/api_tester.py
1
"""Helper functions for performing API unit tests."""
2

3
import csv
1✔
4
import io
1✔
5
import re
1✔
6

7
from django.contrib.auth import get_user_model
1✔
8
from django.contrib.auth.models import Group
1✔
9
from django.http.response import StreamingHttpResponse
1✔
10

11
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
1✔
12
from rest_framework.test import APITestCase
1✔
13

14
from plugin import registry
1✔
15
from plugin.models import PluginConfig
1✔
16

17

18
class UserMixin:
    """Mixin to setup a user and login for tests.

    Use parameters to set username, password, email, roles and permissions.
    """

    # User information
    username = 'testuser'
    password = 'mypassword'
    email = 'test@testing.com'

    superuser = False
    is_staff = True
    auto_login = True

    # Set list of roles automatically associated with the user.
    # Either a list of 'rule.permission' strings (e.g. 'part.add'), or the string 'all'
    roles = []

    @classmethod
    def setUpTestData(cls):
        """Run setup for all tests in a given class.

        Creates the test user, adds it to a dedicated group, applies the
        superuser / staff flags, and assigns the configured roles to the group.
        """
        super().setUpTestData()

        # Create a user to log in with
        cls.user = get_user_model().objects.create_user(
            username=cls.username,
            password=cls.password,
            email=cls.email
        )

        # Create a group for the user
        cls.group = Group.objects.create(name='my_test_group')
        cls.user.groups.add(cls.group)

        if cls.superuser:
            cls.user.is_superuser = True

        if cls.is_staff:
            cls.user.is_staff = True

        cls.user.save()

        if cls.roles == 'all':
            # Assign all roles
            cls.assignRole(group=cls.group, assign_all=True)
        else:
            # Otherwise assign only the listed roles
            for role in cls.roles:
                cls.assignRole(role=role, group=cls.group)

    def setUp(self):
        """Run setup for individual test methods.

        Calls the parent setUp() (so that other cooperative mixins also run),
        then logs the test client in if 'auto_login' is set.
        """
        super().setUp()

        if self.auto_login:
            self.client.login(username=self.username, password=self.password)

    @classmethod
    def assignRole(cls, role=None, assign_all: bool = False, group=None):
        """Set the user roles for the registered user.

        Arguments:
            role: Role of the format 'rule.permission' e.g. 'part.add'
            assign_all: Set to True to assign *all* roles (every permission on every ruleset)
            group: The group to assign roles to (or leave None to use the group assigned to this class)

        Raises:
            TypeError: If assign_all is not a boolean
            ValueError: If neither role nor assign_all is provided
        """
        if group is None:
            group = cls.group

        if not isinstance(assign_all, bool):
            # Raise exception if common mistake is made!
            raise TypeError('assignRole: assign_all must be a boolean value')

        if not role and not assign_all:
            raise ValueError('assignRole: either role must be provided, or assign_all must be set')

        rule = perm = None

        if not assign_all:
            rule, perm = role.split('.')

        for ruleset in group.rule_sets.all():

            if not assign_all and ruleset.name != rule:
                continue

            # Each permission is checked independently so that 'assign_all'
            # grants every permission, not just the first one in the chain
            if assign_all or perm == 'view':
                ruleset.can_view = True
            if assign_all or perm == 'change':
                ruleset.can_change = True
            if assign_all or perm == 'delete':
                ruleset.can_delete = True
            if assign_all or perm == 'add':
                ruleset.can_add = True

            ruleset.save()

            # A named rule targets a single ruleset; when assigning all roles,
            # every ruleset in the group must be processed
            if not assign_all:
                break
113

114

115
class PluginMixin:
    """Mixin which makes sure plugin configuration data is available to tests.

    After the parent setUp() has run, the PluginConfig entries are stored on
    ``self.plugin_confs``. If none exist yet, the plugin registry is reloaded
    and the query is repeated.
    """

    def setUp(self):
        """Setup for plugin tests."""
        super().setUp()

        self.plugin_confs = PluginConfig.objects.all()

        if self.plugin_confs:
            # Plugin configs are already present - nothing more to do
            return

        # No configs found: reload the plugin registry, then query again
        registry.reload_plugins()
        self.plugin_confs = PluginConfig.objects.all()
128

129

130
class ExchangeRateMixin:
    """Mixin class for generating exchange rate data."""

    def generate_exchange_rates(self):
        """Helper function which generates some exchange rates to work with.

        Creates a dummy 'InvenTreeExchange' backend with a 'USD' base currency,
        and bulk-creates a fixed set of Rate entries (AUD, CAD, GBP, USD) for it.
        """
        rates = {
            'AUD': 1.5,
            'CAD': 1.7,
            'GBP': 0.9,
            'USD': 1.0,
        }

        # Create a dummy backend - create() already returns the saved instance,
        # so there is no need to query it back from the database
        backend = ExchangeBackend.objects.create(
            name='InvenTreeExchange',
            base_currency='USD',
        )

        Rate.objects.bulk_create([
            Rate(currency=currency, value=rate, backend=backend)
            for currency, rate in rates.items()
        ])
163

164

165
class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
    """Base class for running InvenTree API tests."""

    def _check_response_code(self, response, url, expected_code):
        """Assert that a response has the expected HTTP status code.

        No check is performed if expected_code is None. On a mismatch, some
        diagnostic information is printed before the assertion fails, to make
        failing tests easier to debug.
        """
        if expected_code is None:
            return

        if response.status_code != expected_code:
            print(f"Unexpected response at '{url}': status code = {response.status_code}")

            # Plain Django responses (e.g. HTML error pages) have no 'data' attribute
            if hasattr(response, 'data'):
                print(response.data)
            else:
                print(f"(response object {type(response)} has no 'data' attribute)")

        self.assertEqual(response.status_code, expected_code)

    def getActions(self, url):
        """Return a dict of the 'actions' available at a given endpoint.

        Makes use of the HTTP 'OPTIONS' method to request this.
        An empty dict is returned if the endpoint reports no actions.
        """
        response = self.client.options(url)
        self.assertEqual(response.status_code, 200)

        return response.data.get('actions', None) or {}

    def get(self, url, data=None, expected_code=200, format='json'):
        """Issue a GET request.

        Arguments:
            url: The URL to request
            data: Query parameters (defaults to an empty dict)
            expected_code: Expected HTTP status code, or None to skip the check
            format: Request format passed to the test client

        Returns:
            The response object
        """
        # Set default - see B006
        if data is None:
            data = {}

        response = self.client.get(url, data, format=format)
        self._check_response_code(response, url, expected_code)

        return response

    def post(self, url, data=None, expected_code=None, format='json'):
        """Issue a POST request.

        Arguments:
            url: The URL to request
            data: Request body (defaults to an empty dict)
            expected_code: Expected HTTP status code, or None to skip the check
            format: Request format passed to the test client

        Returns:
            The response object
        """
        # Set default value - see B006
        if data is None:
            data = {}

        response = self.client.post(url, data=data, format=format)
        self._check_response_code(response, url, expected_code)

        return response

    def delete(self, url, data=None, expected_code=None, format='json'):
        """Issue a DELETE request.

        Arguments:
            url: The URL to request
            data: Request body (defaults to an empty dict)
            expected_code: Expected HTTP status code, or None to skip the check
            format: Request format passed to the test client

        Returns:
            The response object
        """
        # Set default value - see B006
        if data is None:
            data = {}

        response = self.client.delete(url, data=data, format=format)
        self._check_response_code(response, url, expected_code)

        return response

    def patch(self, url, data, expected_code=None, format='json'):
        """Issue a PATCH request.

        Arguments:
            url: The URL to request
            data: Request body
            expected_code: Expected HTTP status code, or None to skip the check
            format: Request format passed to the test client

        Returns:
            The response object
        """
        response = self.client.patch(url, data=data, format=format)
        self._check_response_code(response, url, expected_code)

        return response

    def put(self, url, data, expected_code=None, format='json'):
        """Issue a PUT request.

        Arguments:
            url: The URL to request
            data: Request body
            expected_code: Expected HTTP status code, or None to skip the check
            format: Request format passed to the test client

        Returns:
            The response object
        """
        response = self.client.put(url, data=data, format=format)
        self._check_response_code(response, url, expected_code)

        return response

    def options(self, url, expected_code=None):
        """Issue an OPTIONS request.

        Arguments:
            url: The URL to request
            expected_code: Expected HTTP status code, or None to skip the check

        Returns:
            The response object
        """
        response = self.client.options(url, format='json')
        self._check_response_code(response, url, expected_code)

        return response

    def download_file(self, url, data, expected_code=None, expected_fn=None, decode=True):
        """Download a file from the server, and return an in-memory file.

        Arguments:
            url: The URL to request
            data: Query parameters for the GET request
            expected_code: Expected HTTP status code, or None to skip the check
            expected_fn: Expected filename (from the Content-Disposition header), or None to skip the check
            decode: If True, decode the content as UTF-8 and return a StringIO; otherwise return a BytesIO

        Returns:
            A file object, rewound to the start, whose 'name' attribute is the downloaded filename

        Raises:
            ValueError: If the response is not a StreamingHttpResponse
        """
        response = self.client.get(url, data=data, format='json')

        self._check_response_code(response, url, expected_code)

        # Check that the response is of the correct type
        if not isinstance(response, StreamingHttpResponse):
            raise ValueError("Response is not a StreamingHttpResponse object as expected")

        # Extract filename (any characters up to the closing quote, e.g. including '-')
        disposition = response.headers['Content-Disposition']

        result = re.search(r'attachment; filename="([^"]+)"', disposition)

        self.assertIsNotNone(result, f"Could not extract filename from Content-Disposition header: '{disposition}'")

        fn = result.group(1)

        if expected_fn is not None:
            self.assertEqual(expected_fn, fn)

        content = response.getvalue()

        if decode:
            # Decode data and return as StringIO file object
            fo = io.StringIO()
            fo.write(content.decode('UTF-8'))
        else:
            # Return a BytesIO file object
            fo = io.BytesIO()
            fo.write(content)

        fo.name = fn
        fo.seek(0)

        return fo

    def process_csv(self, fo, delimiter=',', required_cols=None, excluded_cols=None, required_rows=None):
        """Helper function to process and validate a downloaded csv file.

        Arguments:
            fo: A StringIO file object (e.g. as returned by download_file with decode=True)
            delimiter: CSV field delimiter
            required_cols: Column headers which must be present (or None to skip)
            excluded_cols: Column headers which must not be present (or None to skip)
            required_rows: Exact number of data rows (excluding the header row), or None to skip

        Returns:
            A list of dicts, one per data row, mapping each header to the cell value
        """
        # Check that the correct object type has been passed
        self.assertIsInstance(fo, io.StringIO)

        fo.seek(0)

        reader = csv.reader(fo, delimiter=delimiter)

        # The first row holds the column headers; every following row is data
        headers = next(reader, [])
        rows = list(reader)

        if required_cols is not None:
            for col in required_cols:
                self.assertIn(col, headers)

        if excluded_cols is not None:
            for col in excluded_cols:
                self.assertNotIn(col, headers)

        if required_rows is not None:
            self.assertEqual(len(rows), required_rows)

        # Return the file data as a list of dict items, based on the headers
        return [
            {col: row[idx] for idx, col in enumerate(headers)}
            for row in rows
        ]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc