Fix AdaClipOptimizer (#779)

Summary:
## Types of changes

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

AdaClipOptimizer fails when it tries to [generate noise for `self.unclipped_num`](https://github.com/pytorch/opacus/blob/e1a695c1b82e6749e2ab80b491da9e3d2cfe823b/opacus/optimizers/adaclipoptimizer.py#L127C1-L131C10). After the first step, PyTorch raises an error complaining that `torch.normal` is not defined for LongTensors.

Converting the value with `.float()` just before the noise is added seems to be the smallest possible change that fixes the issue. Without it, AdaClip doesn't work with the current version of PyTorch:

```python
unclipped_num_noise = _generate_noise(
    std=self.unclipped_num_std,
    reference=self.unclipped_num.float(),  # <-- the fix
    generator=self.generator,
)
```

___

Btw, there is a general issue with how `unclipped_num` is handled.
Initially, it [starts as a ~int~float](https://github.com/pytorch/opacus/blob/e1a695c1b82e6749e2ab80b491da9e3d2cfe823b/opacus/optimizers/adaclipoptimizer.py#L85C9-L85C31):

```python
self.unclipped_num = 0
```

then it gets [converted to a tensor almost unintentionally](https://github.com/pytorch/opacus/blob/e1a695c1b82e6749e2ab80b491da9e3d2cfe823b/opacus/optimizers/adaclipoptimizer.py#L108-L110):

```python
self.unclipped_num += (
    len(per_sample_clip_factor) - (per_sample_clip_factor < 1).sum()
)
```

and finally, at the place of the fix, it [can only be a tensor, because it serves as the reference for the `_generate_noise`](https://github.com/pytorch/opacus/blob/42b0e7275/opacus/optimizers/optimizer.py#L106) function:

```python
unclipped_num_noise = _generate_noise(
    std=self.unclipped_num_std,
    reference=self.unclipped_num,
    generator=self.gener... (continued)