@@ -18,7 +18,6 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
-from torch import Tensor
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
@@ -33,83 +32,6 @@
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
 
-class TensorRunningAccum:
-    """Tracks running accumulated values (min, max, mean) without graph references.
-
-    Examples:
-        >>> accum = TensorRunningAccum(5)
-        >>> accum.last(), accum.mean()
-        (None, None)
-        >>> accum.append(torch.tensor(1.5))
-        >>> accum.last(), accum.mean()
-        (tensor(1.5000), tensor(1.5000))
-        >>> accum.append(torch.tensor(2.5))
-        >>> accum.last(), accum.mean()
-        (tensor(2.5000), tensor(2.))
-        >>> accum.reset()
-        >>> _ = [accum.append(torch.tensor(i)) for i in range(13)]
-        >>> accum.last(), accum.mean(), accum.min(), accum.max()
-        (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
-    """
-
-    def __init__(self, window_length: int):
-        self.window_length = window_length
-        self.reset(window_length)
-
-    def reset(self, window_length: Optional[int] = None) -> None:
-        """Empty the accumulator."""
-        if window_length is not None:
-            self.window_length = window_length
-        self.memory: Optional[Tensor] = None
-        self.current_idx: int = 0
-        self.last_idx: Optional[int] = None
-        self.rotated: bool = False
-
-    def last(self) -> Optional[Tensor]:
-        """Get the last added element."""
-        if self.last_idx is not None:
-            assert isinstance(self.memory, Tensor)
-            return self.memory[self.last_idx].float()
-
-    def append(self, x: Tensor) -> None:
-        """Add an element to the accumulator."""
-        if self.memory is None:
-            # tradeoff memory for speed by keeping the memory on device
-            self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)
-
-        # store without grads
-        with torch.no_grad():
-            self.memory[self.current_idx] = x
-            self.last_idx = self.current_idx
-
-        # increase index
-        self.current_idx += 1
-
-        # reset index when hit limit of tensor
-        self.current_idx = self.current_idx % self.window_length
-        if self.current_idx == 0:
-            self.rotated = True
-
-    def mean(self) -> Optional[Tensor]:
-        """Get mean value from stored elements."""
-        return self._agg_memory("mean")
-
-    def max(self) -> Optional[Tensor]:
-        """Get maximal value from stored elements."""
-        return self._agg_memory("max")
-
-    def min(self) -> Optional[Tensor]:
-        """Get minimal value from stored elements."""
-        return self._agg_memory("min")
-
-    def _agg_memory(self, how: str) -> Optional[Tensor]:
-        if self.last_idx is not None:
-            assert isinstance(self.memory, Tensor)
-            if self.rotated:
-                return getattr(self.memory.float(), how)()
-            return getattr(self.memory[: self.current_idx].float(), how)()
-
-
 @dataclass
 class SharedCycleIteratorState:
     """A state shared between all CycleIterators in a CombinedLoader.
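For context on what the removal takes away: `TensorRunningAccum` stores the last `window_length` values in a fixed-size tensor kept on the same device as its inputs, writes under `torch.no_grad()` so no autograd graph is retained, and wraps its write index modulo `window_length`; once the buffer has rotated, `mean`/`min`/`max` aggregate over the whole window rather than only the filled prefix. Below is a minimal standalone sketch of that ring-buffer behavior — illustrative only, mirroring the doctest above rather than any Lightning API:

import torch

window_length = 5
memory = torch.zeros(window_length)  # fixed-size ring buffer
current_idx, rotated = 0, False

for i in range(13):  # same sequence as the doctest: 0, 1, ..., 12
    with torch.no_grad():  # store without keeping autograd history
        memory[current_idx] = float(i)
    current_idx = (current_idx + 1) % window_length  # wrap the write index
    rotated = rotated or current_idx == 0  # True once a full window has been filled

window = memory if rotated else memory[:current_idx]
print(window.mean(), window.min(), window.max())  # tensor(10.) tensor(8.) tensor(12.)

The `rotated` flag is what keeps early aggregations honest: before the first wrap-around only `memory[:current_idx]` holds real data, so aggregating the full buffer would mix in the zero initialization.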