pytorch_lightning 自动保存model_checkpoint
Published on Aug. 22, 2023, 12:09 p.m.
pytorch_lightning #model_checkpoint
自动保存最佳的模型检出点
方便寻找出最佳的检出点,修改参数继续之前最佳。
示例
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
<h1>设置个名字。方便知道检出点位置</h1>
<pre><code>filename='/kaggle/working/{epoch}-{val_loss:.2f}-{other_metric:.2f}',
</code></pre>
# dirpath="/kaggle/working/checkpoints", 要求太多了 直接设置文件名为全路径来的简单
<h1>filename='bart-out-{epoch:02d}-{val_loss:.2f}',</h1>
<h1>filename='bart-out',</h1>
<pre><code># save_last=True, #保存最后一个,话说我们要最后一个干嘛,要最佳表现啊
verbose=True,
monitor='val_loss', #检测的依据指标
mode='min', #增长行数据
</code></pre>
<h1>best_model_path='best'</h1>
<pre><code>save_top_k=2 #保存最佳数目
</code></pre>
)
trainer = pl.Trainer(
<pre><code> checkpoint_callback=checkpoint_callback,
</code></pre>
)