meta-pytorch / opacus / 24173143467

26 Mar 2026 09:13PM UTC coverage: 79.477% (+0.3%) from 79.181%
push

github

meta-codesync[bot]
Replace shape-based empty batch handling inside `DPDataLoader` with structure-aware approach (#806)

Summary:
## Types of changes

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

Replaces unstable shape-based empty batch handling with a stateful approach that learns and replicates the actual output structure from `collate_fn`. This fixes a critical bug where custom collate functions returning non-list structures (dicts, custom classes) were incompatible with Poisson sampling.

The old implementation inspected `dataset[0]` to pre-compute shapes, then hardcoded empty batches as lists:
```python
def collate(batch, collate_fn, sample_empty_shapes, dtypes):
    if len(batch) > 0:
        return collate_fn(batch)  # Could return dict, custom class, etc.
    else:
        return [torch.zeros(shape, dtype=dtype) for ...]  # Always list!
```
The bug: if `collate_fn` returns a dict, non-empty batches are dicts but empty batches are lists, and the type mismatch causes a crash.
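The mismatch can be reproduced without Opacus. The following is a minimal sketch, not the library code: `dict_collate`, `old_style_collate`, `n_fields` and `training_step` are hypothetical names, and plain Python lists stand in for tensors so the example has no dependencies.

```python
def dict_collate(batch):
    # Hypothetical custom collate_fn: turns a list of samples into a dict of columns.
    return {
        "x": [sample["x"] for sample in batch],
        "y": [sample["y"] for sample in batch],
    }


def old_style_collate(batch, collate_fn, n_fields):
    # Imitates the old behaviour: an empty batch is always a list with one
    # empty placeholder per field, whatever collate_fn would have returned.
    if len(batch) > 0:
        return collate_fn(batch)
    return [[] for _ in range(n_fields)]


def training_step(batch):
    # Downstream code written against the dict structure.
    return len(batch["x"])


samples = [{"x": 1.0, "y": 0}, {"x": 2.0, "y": 1}]
full = old_style_collate(samples, dict_collate, n_fields=2)
empty = old_style_collate([], dict_collate, n_fields=2)

print(type(full).__name__, type(empty).__name__)  # dict list
print(training_step(full))  # 2
try:
    training_step(empty)
except TypeError as exc:
    # A list cannot be indexed with the key "x".
    print("empty batch failed:", exc)
```

With Poisson sampling an empty batch can occur at any step, so a training loop written against the dict structure may fail partway through a run rather than at startup.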

Existing, related issue: https://github.com/meta-pytorch/opacus/issues/534

### Solution:
New `CollateFnWithEmpty` learns the structure from the first non-empty batch:
```python
class CollateFnWithEmpty:
    def __call__(self, batch):
        if len(batch) > 0:
            output = self.wrapped_collator_fn(batch)
            if self.first_batch is None:
                self.first_batch = copy.deepcopy(output)  # Learn structure
        else:
            output = self._make_empty_batch(self.first_batch)  # Replicate structure
        return output
```
Now empty batches match the structure of non-empty batches, regardless of what `collate_fn` returns.
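To illustrate how the structure can be replicated, here is a self-contained sketch under stated assumptions. It is not the Opacus implementation: the constructor signature, the body of `_make_empty_batch`, and the choice of `RuntimeError` are guesses made for illustration, and plain lists again stand in for tensors (a real version would presumably build zero-length tensors with matching dtype and trailing shape).

```python
import copy


class CollateFnWithEmptySketch:
    """Illustrative stand-in for the structure-aware wrapper described above."""

    def __init__(self, wrapped_collator_fn):
        self.wrapped_collator_fn = wrapped_collator_fn
        self.first_batch = None  # structure learned from the first non-empty batch

    def _make_empty_batch(self, template):
        # Hypothetical helper: walk the template and empty out every leaf.
        # Dicts and plain tuples are containers; anything else is a leaf.
        if isinstance(template, dict):
            return {key: self._make_empty_batch(value) for key, value in template.items()}
        if isinstance(template, tuple):
            return tuple(self._make_empty_batch(value) for value in template)
        return []  # leaf: a list standing in for a zero-length tensor

    def __call__(self, batch):
        if len(batch) > 0:
            output = self.wrapped_collator_fn(batch)
            if self.first_batch is None:
                self.first_batch = copy.deepcopy(output)  # learn structure
            return output
        if self.first_batch is None:
            # Assumption: no structure has been seen yet, so fail loudly.
            raise RuntimeError("empty batch received before any non-empty batch")
        return self._make_empty_batch(self.first_batch)  # replicate structure


def dict_collate(batch):
    # Hypothetical custom collate_fn returning a dict of columns.
    return {"x": [s["x"] for s in batch], "y": [s["y"] for s in batch]}


collate = CollateFnWithEmptySketch(dict_collate)
full = collate([{"x": 1.0, "y": 0}, {"x": 2.0, "y": 1}])
empty = collate([])
print(full)   # {'x': [1.0, 2.0], 'y': [0, 1]}
print(empty)  # {'x': [], 'y': []}
```

The sketch handles only dicts and plain tuples as containers; named tuples and custom classes would need their own reconstruction logic.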

If an empty batch arrives before any non-empty batch has been seen, there is no structure to replicate yet, so an error is raised:
```python
if self... (continued)
```

258 of 259 new or added lines in 4 files covered. (99.61%)

27 existing lines in 4 files now uncovered.

6111 of 7689 relevant lines covered (79.48%)

1.76 hits per line

Source File

83.33
/opacus/utils/batch_memory_manager.py


