• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

saritasa-nest / django-import-export-extensions / 6956628324

22 Nov 2023 11:25AM UTC coverage: 78.492% (+0.4%) from 78.121%
6956628324

Pull #25

github

web-flow
Merge bc329a400 into a3aa8ce29
Pull Request #25: Add force-import feature

58 of 71 new or added lines in 8 files covered. (81.69%)

14 existing lines in 1 file now uncovered.

1135 of 1446 relevant lines covered (78.49%)

9.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.89
/import_export_extensions/resources.py
1
import typing
12✔
2
from enum import Enum
12✔
3

4
from django.core.exceptions import ValidationError
12✔
5
from django.db.models import QuerySet
12✔
6
from django.utils import timezone
12✔
7
from django.utils.functional import classproperty
12✔
8
from django.utils.translation import gettext_lazy as _
12✔
9

10
import tablib
12✔
11
from celery import current_task, result
12✔
12
from django_filters import rest_framework as filters
12✔
13
from django_filters.utils import translate_validation
12✔
14
from import_export import resources, results
12✔
15
from import_export.formats import base_formats
12✔
16
from import_export.results import Error as BaseError
12✔
17

18

19
class Error(BaseError):
    """Error subclass of import-export's base ``Error``.

    Adds a readable ``repr`` and makes instances safe to pickle.
    """

    def __repr__(self) -> str:
        """Return a short, human-readable representation."""
        return f"Error({self.error})"

    def __reduce__(self):
        """Prepare the object for pickling.

        The wrapped ``error`` may hold objects that cannot be pickled
        (for example, django's lazy translation strings), so it is
        coerced to a plain string before delegating to the default
        pickling machinery.
        """
        # Coerce to ``str`` so pickling never fails on lazy objects.
        self.error = str(self.error)
        return super().__reduce__()
35

36

37
class TaskState(Enum):
    """Enumeration of states a background import/export task can report."""

    IMPORTING = _("Importing")
    EXPORTING = _("Exporting")
    PARSING = _("Parsing")
42

43

44
class SkippedErrorsRowResult(results.RowResult):
    """Row result with ability to store errors skipped during force-import.

    When ``force_import`` is enabled, failing rows are skipped instead of
    aborting the import; their errors are preserved here so they can be
    reported to the user afterwards.
    """

    def __init__(self, *args, **kwargs):
        # Errors not tied to a particular field (import/save failures).
        self.non_field_skipped_errors: list[Error] = []
        # Per-field validation errors, keyed by field name.
        self.field_skipped_errors: dict[str, list[ValidationError]] = {}
        # NOTE(review): base ``RowResult.__init__`` takes no arguments, so
        # ``args``/``kwargs`` are deliberately not forwarded.
        super().__init__()

    @property
    def has_skipped_errors(self) -> bool:
        """Return True if row contains any skipped errors."""
        return bool(
            self.non_field_skipped_errors or self.field_skipped_errors,
        )

    @property
    def skipped_errors_count(self) -> int:
        """Return count of skipped errors.

        Field errors are counted per affected field, not per individual
        validation error.

        """
        return (
            len(self.non_field_skipped_errors)
            + len(self.field_skipped_errors)
        )
65

66

67
class SkippedErrorsResult(results.Result):
    """Custom result class with ability to store info about skipped rows."""

    @property
    def has_skipped_rows(self) -> bool:
        """Return True if result contains any skipped rows."""
        return any(row.has_skipped_errors for row in self.rows)

    @property
    def skipped_rows(self) -> list[SkippedErrorsRowResult]:
        """Return all rows with skipped errors."""
        return [row for row in self.rows if row.has_skipped_errors]
83

84

85
class CeleryResourceMixin:
    """Mixin for resources for background import/export using celery.

    Tracks task progress (``current``/``total`` counters) through celery
    task state while rows are imported or exported, and supports skipping
    broken rows when ``force_import`` is enabled.
    """

    # FilterSet used to narrow down the export queryset.
    filterset_class: typing.Type[filters.FilterSet]
    SUPPORTED_FORMATS: list[
        typing.Type[base_formats.Format]
    ] = base_formats.DEFAULT_FORMATS

    def __init__(
        self,
        filter_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        **kwargs,
    ):
        """Remember init kwargs.

        Args:
            filter_kwargs: data for ``filterset_class`` used to filter
                the export queryset.
            **kwargs: stored in ``resource_init_kwargs`` so the resource
                can be re-created with the same arguments later (e.g.
                inside a celery task).

        """
        self._filter_kwargs = filter_kwargs
        self.resource_init_kwargs: dict[str, typing.Any] = kwargs
        # NOTE(review): kwargs are persisted, not forwarded to super().
        super().__init__()

    def get_queryset(self):
        """Filter export queryset via filterset class."""
        queryset = super().get_queryset()
        if not self._filter_kwargs:
            return queryset
        filter_instance = self.filterset_class(
            data=self._filter_kwargs,
        )
        if not filter_instance.is_valid():
            # Convert filterset errors to a DRF-style validation error.
            raise translate_validation(filter_instance.errors)
        return filter_instance.filter_queryset(queryset=queryset)

    @classproperty
    def class_path(cls) -> str:
        """Get dotted path of class to import it (e.g. in celery tasks)."""
        return ".".join([cls.__module__, cls.__name__])

    @classmethod
    def get_supported_formats(cls) -> list[typing.Type[base_formats.Format]]:
        """Get a list of supported formats."""
        return cls.SUPPORTED_FORMATS

    @classmethod
    def get_supported_extensions_map(cls) -> dict[
        str, typing.Type[base_formats.Format],
    ]:
        """Get a map of supported formats and their extensions."""
        return {
            supported_format().get_extension(): supported_format
            for supported_format in cls.SUPPORTED_FORMATS
        }

    def import_data(
        self,
        dataset: tablib.Dataset,
        dry_run: bool = False,
        raise_errors: bool = False,
        use_transactions: typing.Optional[bool] = None,
        collect_failed_rows: bool = False,
        rollback_on_validation_errors: bool = False,
        force_import: bool = False,
        **kwargs,
    ):
        """Init task state before importing.

        If `force_import=True`, then rows with errors will be skipped.

        """
        self.initialize_task_state(
            state=(
                TaskState.IMPORTING.name if not dry_run
                else TaskState.PARSING.name
            ),
            queryset=dataset,
        )
        return super().import_data(  # type: ignore
            dataset=dataset,
            dry_run=dry_run,
            raise_errors=raise_errors,
            use_transactions=use_transactions,
            collect_failed_rows=collect_failed_rows,
            rollback_on_validation_errors=rollback_on_validation_errors,
            force_import=force_import,
            **kwargs,
        )

    def import_row(
        self,
        row,
        instance_loader,
        using_transactions=True,
        dry_run=False,
        raise_errors=False,
        force_import=False,
        **kwargs,
    ):
        """Update task status as we import rows.

        If `force_import=True`, then row errors will be stored in
        `field_skipped_errors` or `non_field_skipped_errors`.

        """
        imported_row: SkippedErrorsRowResult = super().import_row(
            row=row,
            instance_loader=instance_loader,
            using_transactions=using_transactions,
            dry_run=dry_run,
            raise_errors=raise_errors,
            **kwargs,
        )
        self.update_task_state(
            state=(
                TaskState.IMPORTING.name if not dry_run
                else TaskState.PARSING.name
            ),
        )
        if not force_import:
            return imported_row
        if imported_row.import_type in (
            results.RowResult.IMPORT_TYPE_ERROR,
            results.RowResult.IMPORT_TYPE_INVALID,
        ):
            # Rebuild the diff from raw row values so the skipped row is
            # still displayed in the import preview.
            imported_row.diff = [
                row.get(field.column_name, "")
                for field in self.get_fields()
            ]

            imported_row.non_field_skipped_errors.extend(
                imported_row.errors,
            )
            if imported_row.validation_error is not None:
                imported_row.field_skipped_errors.update(
                    **imported_row.validation_error.error_dict,
                )
            # Clear original errors and mark the row as skipped so the
            # import continues instead of failing.
            imported_row.errors = []
            imported_row.validation_error = None

            imported_row.import_type = results.RowResult.IMPORT_TYPE_SKIP
        return imported_row

    @classmethod
    def get_row_result_class(cls):
        """Return custom row result class."""
        return SkippedErrorsRowResult

    @classmethod
    def get_result_class(cls):
        """Get custom result class."""
        return SkippedErrorsResult

    def export(
        self,
        queryset: typing.Optional[QuerySet] = None,
        *args,
        **kwargs,
    ) -> tablib.Dataset:
        """Init task state before exporting."""
        if queryset is None:
            queryset = self.get_queryset()
        self.initialize_task_state(
            state=TaskState.EXPORTING.name,
            queryset=queryset,
        )
        return super().export(  # type: ignore
            queryset=queryset,
            *args,
            **kwargs,
        )

    def export_resource(self, obj):
        """Update task status as we export rows."""
        resource = [
            self.export_field(field, obj) for field in self.get_export_fields()
        ]
        self.update_task_state(state=TaskState.EXPORTING.name)
        return resource

    def initialize_task_state(
        self,
        state: str,
        queryset: typing.Union[QuerySet, tablib.Dataset],
    ):
        """Set initial state of the task to track progress.

        Counts total number of instances to import/export and
        generates initial state for the task.

        """
        # No-op outside a real celery worker (eager/local execution).
        if not current_task or current_task.request.called_directly:
            return

        if isinstance(queryset, QuerySet):
            total = queryset.count()
        else:
            total = len(queryset)

        self._update_current_task_state(
            state=state,
            meta=dict(
                current=0,
                total=total,
            ),
        )

    def update_task_state(
        self,
        state: str,
    ):
        """Update state of the current event.

        Receives meta of the current task and increases the `current`
        field by 1.

        """
        if not current_task or current_task.request.called_directly:
            return

        async_result = result.AsyncResult(current_task.request.get("id"))
        if not async_result.result:
            # No meta recorded yet; nothing to increment.
            return

        self._update_current_task_state(
            state=state,
            meta=dict(
                current=async_result.result.get("current", 0) + 1,
                total=async_result.result.get("total", 0),
            ),
        )

    def _update_current_task_state(self, state: str, meta: dict[str, int]):
        """Update state of task where resource is executed."""
        current_task.update_state(
            state=state,
            meta=meta,
        )

    def generate_export_filename(self, file_format: base_formats.Format):
        """Generate export filename."""
        return self._generate_export_filename_from_model(file_format)

    def _generate_export_filename_from_model(
        self,
        file_format: base_formats.Format,
    ):
        """Generate export file name from model name, date and extension."""
        model = self._meta.model._meta.verbose_name_plural
        date_str = timezone.now().strftime("%Y-%m-%d")
        extension = file_format.get_extension()
        return f"{model}-{date_str}.{extension}"

    @classmethod
    def get_error_result_class(cls):
        """Override default error class."""
        return Error
335

336

337
class CeleryResource(CeleryResourceMixin, resources.Resource):
    """Plain resource with celery-based background import/export support."""
339

340

341
class CeleryModelResource(CeleryResourceMixin, resources.ModelResource):
    """ModelResource with celery-based background import/export support."""

    @classmethod
    def get_model_queryset(cls):
        """Return a queryset of all objects for this model.

        Override this if you want to limit the returned queryset.

        Mirrors ``resources.ModelResource.get_queryset``.

        """
        return cls._meta.model.objects.all()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc