Default / 默认 · 9月 1, 2021

pytorch_lightning 自动保存model_checkpoint

内容纲要

#pytorch_lightning #model_checkpoint

自动保存最佳的模型检出点
方便寻找出最佳的检出点,修改参数继续之前最佳。

示例


from pytorch_lightning.callbacks import ModelCheckpoint checkpoint_callback = ModelCheckpoint( <h1>设置个名字。方便知道检出点位置</h1> <pre><code>filename='/kaggle/working/{epoch}-{val_loss:.2f}-{other_metric:.2f}', </code></pre> # dirpath="/kaggle/working/checkpoints", 要求太多了 直接设置文件名为全路径来的简单 <h1>filename='bart-out-{epoch:02d}-{val_loss:.2f}',</h1> <h1>filename='bart-out',</h1> <pre><code># save_last=True, #保存最后一个,话说我们要最后一个干嘛,要最佳表现啊 verbose=True, monitor='val_loss', #检测的依据指标 mode='min', #增长行数据 </code></pre> <h1>best_model_path='best'</h1> <pre><code>save_top_k=2 #保存最佳数目 </code></pre> ) trainer = pl.Trainer( <pre><code> checkpoint_callback=checkpoint_callback, </code></pre> )

https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.model_checkpoint.html

%d 博主赞过: