Skip to content

Commit

Permalink
Fix for data import (#9060) (#9065)
Browse files Browse the repository at this point in the history
- Prevent shadow overwrite of default_values dict
- Remove dead code

(cherry picked from commit 7049e84)

Co-authored-by: Oliver <[email protected]>
  • Loading branch information
github-actions[bot] and SchrodingersGat authored Feb 11, 2025
1 parent 407ccb7 commit 3b6b419
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 34 deletions.
32 changes: 15 additions & 17 deletions src/backend/InvenTree/importer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,13 @@ def save(self, *args, **kwargs):
)

@property
def field_mapping(self):
def field_mapping(self) -> dict:
"""Construct a dict of field mappings for this import session.
Returns: A dict of field: column mappings
Returns:
A dict of field -> column mappings
"""
mapping = {}

for i in self.column_mappings.all():
mapping[i.field] = i.column

return mapping
return {mapping.field: mapping.column for mapping in self.column_mappings.all()}

@property
def model_class(self):
Expand All @@ -138,7 +134,7 @@ def serializer_class(self):

return supported_models().get(self.model_type, None)

def extract_columns(self):
def extract_columns(self) -> None:
"""Run initial column extraction and mapping.
This method is called when the import session is first created.
Expand Down Expand Up @@ -204,7 +200,7 @@ def extract_columns(self):
self.status = DataImportStatusCode.MAPPING.value
self.save()

def accept_mapping(self):
def accept_mapping(self) -> None:
"""Accept current mapping configuration.
- Validate that the current column mapping is correct
Expand Down Expand Up @@ -243,7 +239,7 @@ def accept_mapping(self):
# No errors, so trigger the data import process
self.trigger_data_import()

def trigger_data_import(self):
def trigger_data_import(self) -> None:
"""Trigger the data import process for this session.
Offloads the task to the background worker process.
Expand All @@ -256,7 +252,7 @@ def trigger_data_import(self):

offload_task(importer.tasks.import_data, self.pk)

def import_data(self):
def import_data(self) -> None:
"""Perform the data import process for this session."""
# Clear any existing data rows
self.rows.all().delete()
Expand Down Expand Up @@ -316,12 +312,12 @@ def check_complete(self) -> bool:
return True

@property
def row_count(self):
def row_count(self) -> int:
"""Return the number of rows in the import session."""
return self.rows.count()

@property
def completed_row_count(self):
def completed_row_count(self) -> int:
"""Return the number of completed rows for this session."""
return self.rows.filter(complete=True).count()

Expand Down Expand Up @@ -349,7 +345,7 @@ def available_fields(self):
self._available_fields = fields
return fields

def required_fields(self):
def required_fields(self) -> dict:
"""Returns information on which fields are *required* for import."""
fields = self.available_fields()

Expand Down Expand Up @@ -591,7 +587,7 @@ def extract_data(
value = value or None

# Use the default value, if provided
if value in [None, ''] and field in default_values:
if value is None and field in default_values:
value = default_values[field]

data[field] = value
Expand All @@ -607,7 +603,9 @@ def serializer_data(self):
- If available, we use the "default" values provided by the import session
- If available, we use the "override" values provided by the import session
"""
data = self.default_values
data = {}

data.update(self.default_values)

if self.data:
data.update(self.data)
Expand Down
17 changes: 0 additions & 17 deletions src/backend/InvenTree/importer/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,6 @@ def extract_column_names(data_file) -> list:
return headers


def extract_rows(data_file) -> list:
"""Extract rows from the data file.
Each returned row is a dictionary of column_name: value pairs.
"""
data = load_data_file(data_file)

headers = data.headers

rows = []

for row in data:
rows.append(dict(zip(headers, row)))

return rows


def get_field_label(field) -> str:
"""Return the label for a field in a serializer class.
Expand Down

0 comments on commit 3b6b419

Please sign in to comment.