• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

saritasa-nest / django-import-export-extensions / 6943835407

21 Nov 2023 12:32PM UTC coverage: 78.492% (+0.4%) from 78.121%
6943835407

Pull #25

github

web-flow
Merge 2c3246b3f into a3aa8ce29
Pull Request #25: Feature/force import

46 of 55 new or added lines in 7 files covered. (83.64%)

26 existing lines in 3 files now uncovered.

1135 of 1446 relevant lines covered (78.49%)

9.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.81
/import_export_extensions/resources.py
1
import typing
12✔
2
from enum import Enum
12✔
3

4
from django.db.models import QuerySet
12✔
5
from django.utils import timezone
12✔
6
from django.utils.functional import classproperty
12✔
7
from django.utils.translation import gettext_lazy as _
12✔
8

9
import tablib
12✔
10
from celery import current_task, result
12✔
11
from django_filters import rest_framework as filters
12✔
12
from django_filters.utils import translate_validation
12✔
13
from import_export import resources, results
12✔
14
from import_export.formats import base_formats
12✔
15
from import_export.results import Error as BaseError
12✔
16

17

18
class Error(BaseError):
    """Customized version of the base ``Error`` class from import-export."""

    def __repr__(self) -> str:
        """Return a concise string representation of the error."""
        representation = f"Error({self.error})"
        return representation

    def __reduce__(self):
        """Prepare the object for pickling.

        The wrapped ``error`` may hold objects that cannot be pickled
        (for example, django's lazy translation strings), so it is
        coerced to a plain string before delegating to the default
        implementation.

        """
        self.error = str(self.error)
        return super().__reduce__()
34

35

36
class TaskState(Enum):
    """Class with possible task state values.

    Member names are used as celery task ``state`` values when the
    resource reports progress; the enum values are lazily translated,
    human-readable labels.

    """
    # Real data import in progress (``import_data`` with ``dry_run=False``).
    IMPORTING = _("Importing")
    # Export in progress (used by ``export``/``export_resource``).
    EXPORTING = _("Exporting")
    # Dry-run import (``import_data`` with ``dry_run=True``).
    PARSING = _("Parsing")
41

42

43
class SkippedErrorsRowResult(results.RowResult):
    """Row result that can additionally track errors skipped for the row."""

    def __init__(self, *args, **kwargs):
        # Errors not tied to a particular field, kept as plain strings.
        self.non_field_skipped_errors: list[str] = []
        # Per-field skipped errors, keyed by field name.
        self.field_skipped_errors: dict[str, str] = {}
        super().__init__()

    @property
    def has_skipped_errors(self):
        """Tell whether the row holds any skipped errors."""
        return bool(self.non_field_skipped_errors) or bool(self.field_skipped_errors)

    @property
    def skipped_errors_count(self):
        """Return total number of skipped errors for this row."""
        total = len(self.non_field_skipped_errors)
        total += len(self.field_skipped_errors)
        return total
64

65

66
class SkippedErrorsResult(results.Result):
    """Import result that exposes information about skipped rows."""

    @property
    def has_skipped_rows(self):
        """Tell whether at least one row contains skipped errors."""
        return any(row.has_skipped_errors for row in self.rows)

    @property
    def skipped_rows(self):
        """Collect every row that contains skipped errors."""
        return [row for row in self.rows if row.has_skipped_errors]
82

83

84
class CeleryResourceMixin:
    """Mixin for resources for background import/export using celery.

    Reports progress by updating the current celery task's state as rows
    are imported/exported, and filters the export queryset through a
    ``django_filters`` filterset when ``filter_kwargs`` are provided.

    """
    filterset_class: typing.Type[filters.FilterSet]
    SUPPORTED_FORMATS: list[
        typing.Type[base_formats.Format]
    ] = base_formats.DEFAULT_FORMATS

    def __init__(
        self,
        filter_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        **kwargs,
    ):
        """Remember init kwargs.

        Args:
            filter_kwargs: data passed to ``filterset_class`` to filter
                the export queryset (see ``get_queryset``).
            kwargs: remaining kwargs, stored as ``resource_init_kwargs``.

        """
        self._filter_kwargs = filter_kwargs
        self.resource_init_kwargs: dict[str, typing.Any] = kwargs
        super().__init__()

    def get_queryset(self):
        """Filter export queryset via filterset class.

        Raises:
            Validation error (translated by ``translate_validation``) if
            ``filter_kwargs`` do not pass filterset validation.

        """
        queryset = super().get_queryset()
        if not self._filter_kwargs:
            return queryset
        filter_instance = self.filterset_class(
            data=self._filter_kwargs,
        )
        if not filter_instance.is_valid():
            raise translate_validation(filter_instance.errors)
        return filter_instance.filter_queryset(queryset=queryset)

    @classproperty
    def class_path(cls) -> str:
        """Get dotted path of class to import it."""
        return ".".join([cls.__module__, cls.__name__])

    @classmethod
    def get_supported_formats(cls) -> list[typing.Type[base_formats.Format]]:
        """Get a list of supported formats."""
        return cls.SUPPORTED_FORMATS

    @classmethod
    def get_supported_extensions_map(cls) -> dict[
        str, typing.Type[base_formats.Format],
    ]:
        """Get a map of supported formats and their extensions."""
        return {
            supported_format().get_extension(): supported_format
            for supported_format in cls.SUPPORTED_FORMATS
        }

    def import_data(
        self,
        dataset: tablib.Dataset,
        dry_run: bool = False,
        raise_errors: bool = False,
        use_transactions: typing.Optional[bool] = None,
        collect_failed_rows: bool = False,
        rollback_on_validation_errors: bool = False,
        force_import: bool = False,
        **kwargs,
    ):
        """Init task state before importing.

        If `force_import=True`, then rows with errors will be skipped.

        """
        self.initialize_task_state(
            state=(
                TaskState.IMPORTING.name if not dry_run
                else TaskState.PARSING.name
            ),
            queryset=dataset,
        )
        return super().import_data(  # type: ignore
            dataset=dataset,
            dry_run=dry_run,
            raise_errors=raise_errors,
            use_transactions=use_transactions,
            collect_failed_rows=collect_failed_rows,
            rollback_on_validation_errors=rollback_on_validation_errors,
            force_import=force_import,
            **kwargs,
        )

    def import_row(
        self,
        row,
        instance_loader,
        using_transactions=True,
        dry_run=False,
        raise_errors=False,
        force_import=False,
        **kwargs,
    ):
        """Update task status as we import rows.

        If `force_import=True`, then row errors will be stored in
        `field_skipped_errors` or `non_field_skipped_errors`.

        """
        imported_row: SkippedErrorsRowResult = super().import_row(
            row=row,
            instance_loader=instance_loader,
            using_transactions=using_transactions,
            dry_run=dry_run,
            raise_errors=raise_errors,
            **kwargs,
        )
        self.update_task_state(
            state=(
                TaskState.IMPORTING.name if not dry_run
                else TaskState.PARSING.name
            ),
        )
        if not force_import:
            return imported_row
        if imported_row.import_type in (
            results.RowResult.IMPORT_TYPE_ERROR,
            results.RowResult.IMPORT_TYPE_INVALID,
        ):
            # Rebuild diff from the raw row values so the skipped row
            # still displays its data in the result.
            imported_row.diff = []
            for field in self.get_fields():
                imported_row.diff.append(row.get(field.column_name, ""))

            imported_row.non_field_skipped_errors.extend(
                str(error.error) for error in imported_row.errors
            )
            if imported_row.validation_error is not None:
                # ``field_skipped_errors`` is a dict, so merge the
                # per-field validation messages into it (a dict has no
                # ``append``, which previously raised AttributeError).
                imported_row.field_skipped_errors.update(
                    imported_row.validation_error.message_dict,
                )
            # Clear errors so the row is reported as skipped, not failed.
            imported_row.errors = []
            imported_row.validation_error = None

            imported_row.import_type = results.RowResult.IMPORT_TYPE_SKIP
        return imported_row

    @classmethod
    def get_row_result_class(cls):
        """Return custom row result class."""
        return SkippedErrorsRowResult

    @classmethod
    def get_result_class(cls):
        """Get custom result class."""
        return SkippedErrorsResult

    def export(
        self,
        queryset: typing.Optional[QuerySet] = None,
        *args,
        **kwargs,
    ) -> tablib.Dataset:
        """Init task state before exporting."""
        if queryset is None:
            queryset = self.get_queryset()
        self.initialize_task_state(
            state=TaskState.EXPORTING.name,
            queryset=queryset,
        )
        return super().export(  # type: ignore
            queryset=queryset,
            *args,
            **kwargs,
        )

    def export_resource(self, obj):
        """Update task status as we export rows."""
        resource = [
            self.export_field(field, obj) for field in self.get_export_fields()
        ]
        self.update_task_state(state=TaskState.EXPORTING.name)
        return resource

    def initialize_task_state(
        self,
        state: str,
        queryset: typing.Union[QuerySet, tablib.Dataset],
    ):
        """Set initial state of the task to track progress.

        Counts total number of instances to import/export and
        generates state for the task.

        """
        # No-op when not running inside a celery worker.
        if not current_task or current_task.request.called_directly:
            return

        if isinstance(queryset, QuerySet):
            total = queryset.count()
        else:
            total = len(queryset)

        self._update_current_task_state(
            state=state,
            meta=dict(
                current=0,
                total=total,
            ),
        )

    def update_task_state(
        self,
        state: str,
    ):
        """Update state of the current event.

        Receives meta of the current task and increases the `current`
        field by 1.

        """
        if not current_task or current_task.request.called_directly:
            return

        async_result = result.AsyncResult(current_task.request.get("id"))
        if not async_result.result:
            return

        self._update_current_task_state(
            state=state,
            meta=dict(
                current=async_result.result.get("current", 0) + 1,
                total=async_result.result.get("total", 0),
            ),
        )

    def _update_current_task_state(self, state: str, meta: dict[str, int]):
        """Update state of task where resource is executed."""
        current_task.update_state(
            state=state,
            meta=meta,
        )

    def generate_export_filename(self, file_format: base_formats.Format):
        """Generate export filename."""
        return self._generate_export_filename_from_model(file_format)

    def _generate_export_filename_from_model(
        self,
        file_format: base_formats.Format,
    ):
        """Generate export file name from model name."""
        model = self._meta.model._meta.verbose_name_plural
        date_str = timezone.now().strftime("%Y-%m-%d")
        extension = file_format.get_extension()
        return f"{model}-{date_str}.{extension}"

    @classmethod
    def get_error_result_class(cls):
        """Override default error class."""
        return Error
334

335

336
class CeleryResource(CeleryResourceMixin, resources.Resource):
    """Resource which supports importing via celery.

    Combines the task-progress tracking of ``CeleryResourceMixin`` with
    the plain (non-model) ``import_export`` ``Resource``.

    """
338

339

340
class CeleryModelResource(CeleryResourceMixin, resources.ModelResource):
    """ModelResource which supports importing via celery."""

    @classmethod
    def get_model_queryset(cls):
        """Return a queryset with all objects of the resource's model.

        Mirrors ``resources.ModelResource.get_queryset``; override this
        method to limit the returned queryset.

        """
        model = cls._meta.model
        return model.objects.all()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc