
Commit 76bdde7

[feat] Add flexible step-wise LR scheduler with minimum changes (#256)
* [doc] Add the bug-fix information for IPython users
* [fix] Delete the gcc and gnn install because I confirmed we can build without them
* [feat] Add flexible step-wise LR scheduler with minimum changes. Since we would like to merge this feature promptly, I cut this branch from the branch hot-fix-adapt... and narrowed the scope of this PR. A subsequent PR addresses the remaining issues, especially the formatting and mypy typing.
* [fix] Fix a flake8 issue and remove unneeded changes introduced by mis-branching
* [test] Add tests for the new features
* [fix] Fix flake8 issues and change torch.tensor to torch.Tensor in typing checks. I made this change because `import torch.tensor` raised a ModuleNotFoundError, and torch.tensor is not a tensor type class while torch.Tensor is.
* [feat] Allow adding the step_unit to ConfigSpace
* [fix] Fix pytest issues by adding a batch-wise update to the incumbent. Since the previous version always used batch-wise updates, I set step_unit = batch to avoid the errors I got from pytest.
* [fix] Add step_unit info to the greedy portfolio JSON file. Since the latest version only supports batch-wise updates, I inserted step_unit == "batch" so that greedy portfolio selection can run.
* [refactor] Rebase onto the latest development branch and add overridden functions in base_scheduler
* [fix] Fix flake8 and mypy issues
* [fix] Fix flake8 issues
* [test] Add a test for the train step and the LR scheduler check
* [refactor] Change the name to
* [fix] Fix flake8 issues
* [fix] Disable the step_interval option in the hyperparameter settings
* [fix] Change the default step_interval to epoch-wise
* [fix] Fix the step timing for ReduceLROnPlateau and fix flake8 issues
* [fix] Add the after-validation option to StepIntervalUnit for ReduceLROnPlateau
* [fix] Fix flake8 issues
* [fix] Fix the loss value for epoch-wise scheduler updates
* [fix] Delete the type check of step_interval and make it a property. Since step_interval should not be modified from outside, I made it a property of the base_scheduler class. Furthermore, since we only need to check the type of step_interval at initialization, I deleted the type check from the prepare method.
* [fix] Fix a mypy issue
* [fix] Fix a mypy issue
* [fix] Fix mypy issues
* [fix] Fix mypy issues
* [feedback] Address Ravin's suggestions
1 parent 999f3c3 commit 76bdde7
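
Taken together, the diff reduces to one API change: every scheduler component now accepts a step_interval argument (either a string or a StepIntervalUnit member), forwards it to BaseLRComponent, which validates it once and later publishes it through the fit dictionary. A minimal usage sketch, assuming an autoPyTorch checkout at this commit is importable:

    from autoPyTorch.pipeline.components.setup.lr_scheduler.CosineAnnealingLR import CosineAnnealingLR
    from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit

    # Both the enum member and its string name are accepted; strings are
    # converted via getattr(StepIntervalUnit, step_interval) in the base class.
    component = CosineAnnealingLR(T_max=200, step_interval='epoch')
    assert component.step_interval is StepIntervalUnit.epoch

    # Unknown strings fail fast at construction time.
    try:
        CosineAnnealingLR(T_max=200, step_interval='iteration')
    except ValueError as err:
        print(err)  # step_interval must be either ['batch', 'epoch', 'valid'], but got iteration.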

File tree

17 files changed: +287 -115 lines changed


autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingLR.py

Lines changed: 5 additions & 3 deletions

@@ -8,10 +8,10 @@
 import numpy as np
 
 import torch.optim.lr_scheduler
-from torch.optim.lr_scheduler import _LRScheduler
 
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
 
 
@@ -26,13 +26,13 @@ class CosineAnnealingLR(BaseLRComponent):
     def __init__(
         self,
         T_max: int,
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
         random_state: Optional[np.random.RandomState] = None
     ):
 
-        super().__init__()
+        super().__init__(step_interval)
         self.T_max = T_max
         self.random_state = random_state
-        self.scheduler = None  # type: Optional[_LRScheduler]
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
         """
@@ -71,6 +71,8 @@ def get_hyperparameter_search_space(
                                                        default_value=200,
                                                        )
     ) -> ConfigurationSpace:
+
         cs = ConfigurationSpace()
         add_hyperparameter(cs, T_max, UniformIntegerHyperparameter)
+
         return cs

autoPyTorch/pipeline/components/setup/lr_scheduler/CosineAnnealingWarmRestarts.py

Lines changed: 10 additions & 6 deletions

@@ -1,15 +1,18 @@
 from typing import Any, Dict, Optional, Union
 
 from ConfigSpace.configuration_space import ConfigurationSpace
-from ConfigSpace.hyperparameters import UniformFloatHyperparameter, UniformIntegerHyperparameter
+from ConfigSpace.hyperparameters import (
+    UniformFloatHyperparameter,
+    UniformIntegerHyperparameter
+)
 
 import numpy as np
 
 import torch.optim.lr_scheduler
-from torch.optim.lr_scheduler import _LRScheduler
 
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
 
 
@@ -30,13 +33,13 @@ def __init__(
         self,
         T_0: int,
         T_mult: int,
-        random_state: Optional[np.random.RandomState] = None
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
+        random_state: Optional[np.random.RandomState] = None,
     ):
-        super().__init__()
+        super().__init__(step_interval)
         self.T_0 = T_0
         self.T_mult = T_mult
         self.random_state = random_state
-        self.scheduler = None  # type: Optional[_LRScheduler]
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
         """
@@ -78,8 +81,9 @@ def get_hyperparameter_search_space(
         T_mult: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='T_mult',
                                                                       value_range=(1.0, 2.0),
                                                                       default_value=1.0,
-                                                                      ),
+                                                                      )
     ) -> ConfigurationSpace:
+
         cs = ConfigurationSpace()
         add_hyperparameter(cs, T_0, UniformIntegerHyperparameter)
         add_hyperparameter(cs, T_mult, UniformFloatHyperparameter)

autoPyTorch/pipeline/components/setup/lr_scheduler/CyclicLR.py

Lines changed: 3 additions & 3 deletions

@@ -10,10 +10,10 @@
 import numpy as np
 
 import torch.optim.lr_scheduler
-from torch.optim.lr_scheduler import _LRScheduler
 
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
 
 
@@ -39,16 +39,16 @@ def __init__(
         base_lr: float,
         mode: str,
         step_size_up: int,
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
         max_lr: float = 0.1,
         random_state: Optional[np.random.RandomState] = None
     ):
-        super().__init__()
+        super().__init__(step_interval)
         self.base_lr = base_lr
         self.mode = mode
         self.max_lr = max_lr
         self.step_size_up = step_size_up
         self.random_state = random_state
-        self.scheduler = None  # type: Optional[_LRScheduler]
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
         """

autoPyTorch/pipeline/components/setup/lr_scheduler/ExponentialLR.py

Lines changed: 5 additions & 3 deletions

@@ -8,10 +8,10 @@
 import numpy as np
 
 import torch.optim.lr_scheduler
-from torch.optim.lr_scheduler import _LRScheduler
 
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
 
 
@@ -27,13 +27,13 @@ class ExponentialLR(BaseLRComponent):
     def __init__(
         self,
         gamma: float,
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
         random_state: Optional[np.random.RandomState] = None
     ):
 
-        super().__init__()
+        super().__init__(step_interval)
         self.gamma = gamma
         self.random_state = random_state
-        self.scheduler = None  # type: Optional[_LRScheduler]
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
         """
@@ -72,6 +72,8 @@ def get_hyperparameter_search_space(
                                                        default_value=0.9,
                                                        )
     ) -> ConfigurationSpace:
+
         cs = ConfigurationSpace()
         add_hyperparameter(cs, gamma, UniformFloatHyperparameter)
+
         return cs

autoPyTorch/pipeline/components/setup/lr_scheduler/NoScheduler.py

Lines changed: 3 additions & 4 deletions

@@ -4,10 +4,9 @@
 
 import numpy as np
 
-from torch.optim.lr_scheduler import _LRScheduler
-
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
 
 
 class NoScheduler(BaseLRComponent):
@@ -17,12 +16,12 @@ class NoScheduler(BaseLRComponent):
     """
     def __init__(
         self,
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
         random_state: Optional[np.random.RandomState] = None
     ):
 
-        super().__init__()
+        super().__init__(step_interval)
         self.random_state = random_state
-        self.scheduler = None  # type: Optional[_LRScheduler]
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
         """

autoPyTorch/pipeline/components/setup/lr_scheduler/ReduceLROnPlateau.py

Lines changed: 9 additions & 5 deletions

@@ -10,10 +10,10 @@
 import numpy as np
 
 import torch.optim.lr_scheduler
-from torch.optim.lr_scheduler import _LRScheduler
 
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
 
 
@@ -31,22 +31,26 @@ class ReduceLROnPlateau(BaseLRComponent):
         factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
         patience (int): Number of epochs with no improvement after which learning
             rate will be reduced.
+        step_interval (str): step should be called after validation in the case of ReduceLROnPlateau
         random_state (Optional[np.random.RandomState]): random state
+
+    Reference:
+        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html#torch.optim.lr_scheduler.ReduceLROnPlateau
     """
 
     def __init__(
         self,
         mode: str,
         factor: float,
         patience: int,
-        random_state: Optional[np.random.RandomState] = None
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.valid,
+        random_state: Optional[np.random.RandomState] = None,
     ):
-        super().__init__()
+        super().__init__(step_interval)
         self.mode = mode
         self.factor = factor
         self.patience = patience
         self.random_state = random_state
-        self.scheduler = None  # type: Optional[_LRScheduler]
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
         """
@@ -93,7 +97,7 @@ def get_hyperparameter_search_space(
         factor: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='factor',
                                                                       value_range=(0.01, 0.9),
                                                                       default_value=0.1,
-                                                                      ),
+                                                                      )
     ) -> ConfigurationSpace:
 
         cs = ConfigurationSpace()
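
ReduceLROnPlateau is the only component whose step_interval defaults to StepIntervalUnit.valid rather than epoch: its underlying PyTorch scheduler consumes a metric, so it can only step once a validation value is available. A short illustration with the plain PyTorch scheduler (the model, optimizer, and metric here are placeholders of mine, not code from this commit):

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

    val_loss = 0.42           # stand-in for a real validation metric
    scheduler.step(val_loss)  # unlike other schedulers, step() takes the metric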

autoPyTorch/pipeline/components/setup/lr_scheduler/StepLR.py

Lines changed: 4 additions & 4 deletions

@@ -9,10 +9,10 @@
 import numpy as np
 
 import torch.optim.lr_scheduler
-from torch.optim.lr_scheduler import _LRScheduler
 
 from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit
 from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter
 
 
@@ -32,13 +32,13 @@ def __init__(
         self,
         step_size: int,
         gamma: float,
+        step_interval: Union[str, StepIntervalUnit] = StepIntervalUnit.epoch,
         random_state: Optional[np.random.RandomState] = None
     ):
-        super().__init__()
+        super().__init__(step_interval)
         self.gamma = gamma
         self.step_size = step_size
         self.random_state = random_state
-        self.scheduler = None  # type: Optional[_LRScheduler]
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseLRComponent:
         """
@@ -80,7 +80,7 @@ def get_hyperparameter_search_space(
         step_size: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter='step_size',
                                                                          value_range=(1, 10),
                                                                          default_value=5,
-                                                                         ),
+                                                                         )
     ) -> ConfigurationSpace:
         cs = ConfigurationSpace()
 
8686

autoPyTorch/pipeline/components/setup/lr_scheduler/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 )
 from autoPyTorch.pipeline.components.setup.lr_scheduler.base_scheduler import BaseLRComponent
 
+
 directory = os.path.split(__file__)[0]
 _schedulers = find_components(__package__,
                               directory,

autoPyTorch/pipeline/components/setup/lr_scheduler/base_scheduler.py

Lines changed: 23 additions & 3 deletions

@@ -1,23 +1,39 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 
 from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent
+from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit, StepIntervalUnitChoices
 from autoPyTorch.utils.common import FitRequirement
 
 
 class BaseLRComponent(autoPyTorchSetupComponent):
     """Provide an abstract interface for schedulers
     in Auto-Pytorch"""
 
-    def __init__(self) -> None:
+    def __init__(self, step_interval: Union[str, StepIntervalUnit]):
         super().__init__()
         self.scheduler = None  # type: Optional[_LRScheduler]
+        self._step_interval: StepIntervalUnit
+
+        if isinstance(step_interval, str):
+            if step_interval not in StepIntervalUnitChoices:
+                raise ValueError('step_interval must be either {}, but got {}.'.format(
+                    StepIntervalUnitChoices,
+                    step_interval
+                ))
+            self._step_interval = getattr(StepIntervalUnit, step_interval)
+        else:
+            self._step_interval = step_interval
 
         self.add_fit_requirements([
             FitRequirement('optimizer', (Optimizer,), user_defined=False, dataset_property=False)])
 
+    @property
+    def step_interval(self) -> StepIntervalUnit:
+        return self._step_interval
+
     def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
         """
         Adds the scheduler into the fit dictionary 'X' and returns it.
@@ -26,7 +42,11 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
         Returns:
             (Dict[str, Any]): the updated 'X' dictionary
         """
-        X.update({'lr_scheduler': self.scheduler})
+
+        X.update(
+            lr_scheduler=self.scheduler,
+            step_interval=self.step_interval
+        )
         return X
 
     def get_scheduler(self) -> _LRScheduler:
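
With this change, transform() hands downstream components both the scheduler and the interval on which to step it; the trainer-side consumption is not part of this excerpt. A hedged sketch of how a consumer could dispatch on the new key (maybe_step and the loop comments are illustrative, not this commit's trainer code):

    from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import StepIntervalUnit

    def maybe_step(X, when):
        """Step the scheduler only when the configured interval matches `when`."""
        scheduler = X.get('lr_scheduler')
        if scheduler is not None and X.get('step_interval') == when:
            scheduler.step()

    # Inside a training loop, roughly:
    # for epoch in range(n_epochs):
    #     for batch in loader:
    #         ...forward/backward/optimizer.step()...
    #         maybe_step(X, StepIntervalUnit.batch)
    #     maybe_step(X, StepIntervalUnit.epoch)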
autoPyTorch/pipeline/components/setup/lr_scheduler/constants.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+from enum import Enum
+
+
+class StepIntervalUnit(Enum):
+    """
+    By which interval we perform the step for learning rate schedulers.
+    Attributes:
+        batch (str): We update every batch evaluation
+        epoch (str): We update every epoch
+        valid (str): We update every validation
+    """
+    batch = 'batch'
+    epoch = 'epoch'
+    valid = 'valid'
+
+
+StepIntervalUnitChoices = [step_interval.name for step_interval in StepIntervalUnit]
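
StepIntervalUnitChoices collects the enum member names, which is exactly what the string form of step_interval is checked against in BaseLRComponent. A quick sketch of the round trip between strings and enum members:

    from autoPyTorch.pipeline.components.setup.lr_scheduler.constants import (
        StepIntervalUnit,
        StepIntervalUnitChoices,
    )

    print(StepIntervalUnitChoices)             # ['batch', 'epoch', 'valid']
    unit = getattr(StepIntervalUnit, 'valid')  # name -> member, as the base class does
    print(unit, unit.value)                    # StepIntervalUnit.valid valid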
