
Calling trainer.predict inside a custom Callback causes an error #10365

@whaowhao

Description


🐛 Bug

Calling trainer.predict inside a custom Callback causes the following error:

Exception has occurred: AttributeError       (note: full exception trace is shown but execution is paused at: <module>)
'NoneType' object has no attribute 'extract_batch_size'
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py", line 136, in on_evaluation_batch_start
    self.trainer._results.extract_batch_size(batch)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 182, in on_evaluation_batch_start
    self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 105, in advance
    self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 111, in advance
    dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1122, in _run_sanity_check
    self._evaluation_loop.run()
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1000, in run_stage
    return self._run_train()
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _dispatch
    self.accelerator.start_training(self)
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 922, in _run
    self._dispatch()
  File "/home/whao/.virtualenv/mixedclassify/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 552, in fit
    self._run(model)
  File "/home/whao/whao/MultiHead/multihead/cli.py", line 131, in run_active_learn
    trainer.fit(model)
  File "/home/whao/whao/MultiHead/multihead/cli.py", line 363, in main
    run_active_learn(args)
  File "/home/whao/whao/MultiHead/multihead/cli.py", line 399, in test_active_learn
    main(args)
  File "/home/whao/whao/MultiHead/multihead/cli.py", line 413, in <module> (Current frame)
    test_active_learn()

To Reproduce

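import os

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Callback


class ActiveLearningInferencer(Callback):
    """Runs prediction on a cached active-learning dataset at the start of validation."""

    def __init__(self, cache_dir, batch_size, num_workers):
        super().__init__()
        self.cache_dir = cache_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def on_validation_start(self, trainer, pl_module):
        # Load the cached dataset that the model should score.
        ds_active_learn = torch.load(os.path.join(self.cache_dir, "ds_active_learn.pt"))
        dl_active_learn = DataLoader(
            dataset=ds_active_learn,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=self.num_workers,
        )
        # Calling trainer.predict from inside trainer.fit is what raises the
        # AttributeError above (here, during the validation sanity check).
        trainer.predict(pl_module, dl_active_learn)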
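One possible workaround, offered as an untested sketch rather than an official fix: instead of calling trainer.predict from the callback, run the forward pass manually so the trainer's internal state is never touched. The helpers predict_without_trainer and _to_device are hypothetical names; the sketch assumes each batch is a tensor or a plain list/tuple of tensors, and that all of the module's parameters live on a single device.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def _to_device(batch, device):
    # Move a tensor, or a (possibly nested) plain list/tuple of tensors, to `device`.
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, list):
        return [_to_device(b, device) for b in batch]
    if isinstance(batch, tuple):
        return tuple(_to_device(b, device) for b in batch)
    return batch


def predict_without_trainer(module: nn.Module, dataloader: DataLoader):
    """Run inference over `dataloader` without going through trainer.predict.

    Uses `module.predict_step(batch, batch_idx)` when the module defines it
    (a LightningModule does), otherwise falls back to `module(batch)`.
    """
    was_training = module.training
    params = list(module.parameters())
    # Assumption: all parameters share one device; default to CPU if there are none.
    device = params[0].device if params else torch.device("cpu")
    predict_step = getattr(module, "predict_step", None)
    outputs = []
    module.eval()
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                batch = _to_device(batch, device)
                if callable(predict_step):
                    outputs.append(predict_step(batch, batch_idx))
                else:
                    outputs.append(module(batch))
    finally:
        # Put the module back into whichever mode it was in before the call.
        module.train(was_training)
    return outputs
```

Inside on_validation_start, this would replace the trainer.predict call with preds = predict_without_trainer(pl_module, dl_active_learn). The trade-off is that anything the Trainer normally manages during prediction (predict hooks, other callbacks, multi-device strategies) is skipped.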

Expected behavior

Running prediction from inside a callback should not interfere with the surrounding trainer.fit run (in this case, the validation sanity check).

Environment

pytorch-lightning==1.4.9
Python 3.6.9

cc @tchaton

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working), duplicate (This issue or pull request already exists), help wanted (Open to be worked on), priority: 2 (Low priority task), won't fix (This will not be worked on)
