"""运行顺序
for epoch_i in range(num_epoch):
on_train_epoch_start()
for batch in trainning_dataloader:
on_train_batch_start()
trainning_step()
on_train_batch_end()
on_train_epoch_end()
on_validation_start()
for batch in validation_dataloader:
on_validation_batch_start()
validation_step()
on_validation_batch_end
on_validation_end()
"""
class CustomNet(pl.LightningModule):
    """VQ-VAE style pattern network wrapped as a LightningModule.

    NOTE(review): several method bodies in this class are elided with
    `...` in this view; documentation here describes only what is visible.
    """

    def __init__(self, options: dict):
        """Build the module from a raw hyper-parameter dict.

        Args:
            options: raw hyper-parameter mapping; parsed/validated below
                into ``VqVaePartttenNetHyperParameters``.
        """
        super().__init__()
        # Persist the raw `options` dict to the checkpoint/logger.
        self.save_hyperparameters()
        # pydantic-v1 style validation (`parse_obj`) of the raw dict.
        self._options = VqVaePartttenNetHyperParameters.parse_obj(options)
        # Cached for configure_optimizers().
        self._lr = self._options.learning_rate
        ...
def configure_optimizers(self):
    """Create the AdamW optimizer and its step-decay LR schedule.

    Returns:
        Lightning optimizer config dict: AdamW at ``self._lr`` paired
        with a MultiStepLR schedule (decay x0.1 at epochs 30 and 60),
        stepped once per epoch and exposed under the name
        ``"learning_rate"``.
    """
    optimizer = torch.optim.AdamW(self.parameters(), lr=self._lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[30, 60], gamma=0.1
    )
    scheduler_config = {
        "scheduler": scheduler,
        "frequency": 1,
        "interval": "epoch",
        "strict": True,
        "name": "learning_rate",
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
def forward(
    self,
    # NOTE(review): "daliy" typo kept — renaming would break keyword callers.
    seq_daliy_industry_lv1_mean: torch.Tensor,
    seq_intraday: Optional[torch.Tensor] = None,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Run the forward pass (body elided in this view).

    Args:
        seq_daliy_industry_lv1_mean: daily industry level-1 mean
            sequence (shape not visible here — TODO confirm).
        seq_intraday: optional intraday sequence (TODO confirm shape).

    Returns:
        ``(nearest_neighbor, preds)``. The original ``-> torch.Tensor``
        annotation was wrong — the visible return statement yields a
        2-tuple. Element types assumed to be tensors — TODO confirm,
        body not visible.
    """
    ...
    return nearest_neighbor, preds
def training_step(self, batch, batch_idx):
    """One optimization step on `batch` (body elided in this view).

    Returns:
        loss: presumably a scalar tensor consumed by Lightning for
            backprop — TODO confirm, body not visible.
    """
    ...
    return loss
def validation_step(self, batch, batch_idx):
    """Validation step for one batch (body elided; returns None)."""
    ...
def on_train_epoch_start(self) -> None:
    """Hook: start of each training epoch (body elided in this view)."""
    ...

# NOTE(review): a stub `on_train_epoch_end` was defined here, but it was
# silently shadowed by the full redefinition further down in the class
# (Python keeps only the last `def` with a given name). The duplicate
# stub has been removed; the real implementation lives below.

def on_validation_epoch_start(self) -> None:
    """Hook: start of each validation epoch (body elided in this view)."""
    ...

def on_validation_epoch_end(self) -> None:
    """Hook: end of each validation epoch (body elided in this view)."""
    ...
def on_train_epoch_end(self) -> None:
    """Compute epoch-level training diagnostics, then reset the per-epoch
    codebook-index buffer.

    Fixes over the previous version:
    * ``torch.bincount`` raises on negative values, but the buffer uses
      -1 as the "not written this epoch" sentinel — only the written
      (>= 0) entries are counted now.
    * ``register_buffer`` raises if the attribute already exists, so it
      cannot re-register the buffer here after the first epoch; the
      buffer is now reset in place, which also removes the use of the
      undefined name ``train_dataset_length``.
    * the computed entropy was previously discarded; it is now logged.
    """
    # NOTE(review): `result` was computed and discarded before; the call
    # is kept because metric `.compute()` typically finalises internal
    # state. TODO confirm whether `result` should also be logged.
    result = self._signal_analysis_cache_train.compute()

    if self.trainer.sanity_checking:
        # Sanity-check batches are not representative; report 0 entropy.
        entropy = torch.tensor(0.0)
    else:
        # Keep only the slots actually written this epoch (sentinel -1).
        written = self.zq_indices_train[self.zq_indices_train >= 0]
        if written.numel() == 0:
            entropy = torch.tensor(0.0)
        else:
            counts = torch.bincount(written, minlength=32)
            probs = counts.float() / written.shape[0]
            # Shannon entropy (bits) of codebook usage; +1e-9 guards log2(0).
            entropy = -(probs * torch.log2(probs + 1e-9)).sum()

    # Surface the diagnostic; previously computed but never used.
    self.log("train_codebook_entropy", entropy)

    # In-place reset to the -1 sentinel for the next epoch.
    self.zq_indices_train.fill_(-1)