Gold Price Trends with PyTorch Lightning: Training an LSTM on Timeseries

Published on Aug. 22, 2023, 12:10 p.m.

Import dependencies
https://www.kaggle.com/terrychanorg/pytorch-lightning-lstm-timeseries-cl001****

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")

!pip install pytorch_lightning==1.4.8
# !pip install torchmetrics
!pip install wandb -q

import os
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_API_KEY"] = secret_value_0

Collecting pytorch_lightning==1.4.8
  Downloading pytorch_lightning-1.4.8-py3-none-any.whl (924 kB)
Collecting torchmetrics>=0.4.0
  Downloading torchmetrics-0.6.0-py3-none-any.whl (329 kB)
Collecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.11.0-py3-none-any.whl (132 kB)
Installing collected packages: torchmetrics, pyDeprecate, fsspec, pytorch-lightning
  Attempting uninstall: fsspec
    Found existing installation: fsspec 0.8.2
    Successfully uninstalled fsspec-0.8.2
  Attempting uninstall: pytorch-lightning
    Found existing installation: pytorch-lightning 0.9.0
    Successfully uninstalled pytorch-lightning-0.9.0
Successfully installed fsspec-2021.11.0 pyDeprecate-0.3.1 pytorch-lightning-1.4.8 torchmetrics-0.6.0
WARNING: You are using pip version 20.2.3; however, version 21.3.1 is available.
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.

# Re-loads all imports every time the cell is run.
%load_ext autoreload
%autoreload 2

from time import time

import numpy as np
import pandas as pd
pd.options.display.float_format = '{:,.5f}'.format

from IPython.display import display

# Sklearn tools
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Neural Networks
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import WandbLogger
# Plotting
%matplotlib inline
import matplotlib.pyplot as plt

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/electric-power-consumption-data-set/household_power_consumption.txt
/kaggle/input/historic-gold-prices/goldx.csv

Prediction task
We are going to predict the gold price one step ahead.

TimeseriesDataset

```
class TimeseriesDataset(Dataset):
    '''
    Custom Dataset subclass.
    Serves as input to DataLoader to transform X
    into sequence data using a rolling window.
    A DataLoader using this dataset will output batches
    of shape (batch_size, seq_len, n_features),
    suitable as input to RNNs.
    Slices the series by seq_len; each window is one sample.
    '''
    def __init__(self, X: np.ndarray, y: np.ndarray, seq_len: int = 1):
        self.X = torch.tensor(X).float()
        self.y = torch.tensor(y).float()
        self.seq_len = seq_len

    def __len__(self):
        return self.X.__len__() - (self.seq_len - 1)

    def __getitem__(self, index):
        # Produce one window and the target aligned with its last timestep
        return (self.X[index:index + self.seq_len], self.y[index + self.seq_len - 1])
```
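
To make the rolling window concrete, here is a quick check of the dataset on toy data (the toy arrays and seq_len=3 are my own illustration, not from the notebook):

```
import numpy as np

# 6 timesteps, 2 features; targets 0..5 so the alignment is easy to read
X_demo = np.arange(12, dtype=np.float32).reshape(6, 2)
y_demo = np.arange(6, dtype=np.float32).reshape(6, 1)

ds = TimeseriesDataset(X_demo, y_demo, seq_len=3)
print(len(ds))       # 4 windows: 6 - (3 - 1)
x0, y0 = ds[0]
print(x0.shape)      # torch.Size([3, 2])
print(y0)            # tensor([2.]) — the target of the window's last row
```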

DataModule

path = '../input/historic-gold-prices/goldx.csv'

df1 = pd.read_csv(path)
print(df1.head(10))

df = pd.read_csv(
    path,
    sep=';',
    parse_dates={'dt': ['Date']},
    infer_datetime_format=True,
    low_memory=False,
    na_values=['nan', '?'],
    index_col='dt'
)
print(df.head(10))

df_resample = df.resample('h').mean()

X = df_resample.dropna().copy()

# Shift by one step, so each row's target is the next price
y = X['Price'].shift(-1).ffill()

df.head(3), X.head(3), y.head(3)


( Price Open High Low
dt

2018-08-01 1,216.60000 1,223.40000 1,223.40000 1,216.20000
2018-07-31 1,223.70000 1,220.40000 1,228.10000 1,213.00000
2018-07-30 1,221.30000 1,222.50000 1,223.90000 1,218.10000,
Price Open High Low
dt

1979-12-27 515.50000 517.00000 517.00000 513.00000
1979-12-28 517.80000 516.00000 517.80000 510.40000
1979-12-31 533.60000 527.88000 534.50000 527.88000,
dt
1979-12-27 517.80000
1979-12-28 533.60000
1979-12-31 575.50000
Name: Price, dtype: float64)
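
The printed y is just the Price column shifted back one row: the entry at 1979-12-27 is the next day's price (517.80, from 1979-12-28), and so on. A minimal pandas sketch of the same trick (the toy series is my own):

```
import pandas as pd

s = pd.Series([100.0, 101.0, 103.0], name='Price')
y = s.shift(-1).ffill()
print(y.tolist())  # [101.0, 103.0, 103.0]
# Each entry is the next price; ffill() replaces the trailing NaN
# with the last observed value so the lengths still match.
```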

```
class PowerConsumptionDataModule(pl.LightningDataModule):
    '''
    PyTorch Lightning DataModule subclass:
    https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html

    Serves the purpose of aggregating all data loading
    and processing work in one place.
    '''

    def __init__(self, seq_len=1, batch_size=128, num_workers=0):
        super().__init__()
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.X_train = None
        self.y_train = None
        self.X_val = None
        self.y_val = None
        self.X_test = None
        self.y_test = None
        self.columns = None
        self.preprocessing = None

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        '''
        Data is resampled to hourly intervals.
        Both 'nan' and '?' are converted to np.nan.
        The 'Date' column is parsed into the 'dt' index.
        '''
        if stage == 'fit' and self.X_train is not None:
            return
        if stage == 'test' and self.X_test is not None:
            return
        if stage is None and self.X_train is not None and self.X_test is not None:
            return

        # Load and preprocess the data
        path = '../input/historic-gold-prices/goldx.csv'
        df = pd.read_csv(
            path,
            sep=';',
            parse_dates={'dt': ['Date']},
            infer_datetime_format=True,
            low_memory=False,
            na_values=['nan', '?'],
            index_col='dt'
        )

        df_resample = df.resample('h').mean()
        X = df_resample.dropna().copy()

        # Shift by one step, so each row's target is the next price
        y = X['Price'].shift(-1).ffill()
        self.columns = X.columns

        X_cv, X_test, y_cv, y_test = train_test_split(
            X, y, test_size=0.2, shuffle=False
        )
        X_train, X_val, y_train, y_val = train_test_split(
            X_cv, y_cv, test_size=0.25, shuffle=False
        )

        preprocessing = StandardScaler()
        preprocessing.fit(X_train)

        if stage == 'fit' or stage is None:
            self.X_train = preprocessing.transform(X_train)
            self.y_train = y_train.values.reshape((-1, 1))
            self.X_val = preprocessing.transform(X_val)
            self.y_val = y_val.values.reshape((-1, 1))

        if stage == 'test' or stage is None:
            self.X_test = preprocessing.transform(X_test)
            self.y_test = y_test.values.reshape((-1, 1))

    def train_dataloader(self):
        # seq_len is the window length
        train_dataset = TimeseriesDataset(self.X_train,
                                          self.y_train,
                                          seq_len=self.seq_len)
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers)
        return train_loader

    def val_dataloader(self):
        val_dataset = TimeseriesDataset(self.X_val,
                                        self.y_val,
                                        seq_len=self.seq_len)
        val_loader = DataLoader(val_dataset,
                                batch_size=self.batch_size,
                                shuffle=False,
                                num_workers=self.num_workers)
        return val_loader

    def test_dataloader(self):
        test_dataset = TimeseriesDataset(self.X_test,
                                         self.y_test,
                                         seq_len=self.seq_len)
        test_loader = DataLoader(test_dataset,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 num_workers=self.num_workers)
        return test_loader
```
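
Note the two chained train_test_split calls with shuffle=False: they cut the series chronologically into roughly 60% train, 20% validation, and 20% test, so no future data leaks into training. A quick sketch on an index array (the toy data is my own):

```
import numpy as np
from sklearn.model_selection import train_test_split

idx = np.arange(10)
cv, test = train_test_split(idx, test_size=0.2, shuffle=False)
train, val = train_test_split(cv, test_size=0.25, shuffle=False)
print(train, val, test)  # [0 1 2 3 4 5] [6 7] [8 9] — order preserved
```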

Data processing example

```
dd = PowerConsumptionDataModule(seq_len=5, batch_size=16)
dd.setup("test")
for x, y in dd.test_dataloader():
    print(x.size(), y.size())
    print(x, y)
    break
```

torch.Size([16, 5, 4]) torch.Size([16, 1])
tensor([[[12.5412, 12.4623, 12.2739, 12.6585],
[12.5503, 12.5846, 12.4046, 12.8737],
[12.3597, 12.6132, 12.4084, 12.5894],
[12.6509, 12.2491, 12.4674, 12.4473],
[12.5608, 12.5742, 12.4584, 12.6545]],

    [[12.5503, 12.5846, 12.4046, 12.8737],
     [12.3597, 12.6132, 12.4084, 12.5894],
     [12.6509, 12.2491, 12.4674, 12.4473],
     [12.5608, 12.5742, 12.4584, 12.6545],
     [12.6365, 12.4974, 12.4200, 12.6811]],

    [[12.3597, 12.6132, 12.4084, 12.5894],
     [12.6509, 12.2491, 12.4674, 12.4473],
     [12.5608, 12.5742, 12.4584, 12.6545],
     [12.6365, 12.4974, 12.4200, 12.6811],
     [12.5307, 12.6392, 12.4341, 12.7037]],

    [[12.6509, 12.2491, 12.4674, 12.4473],
     [12.5608, 12.5742, 12.4584, 12.6545],
     [12.6365, 12.4974, 12.4200, 12.6811],
     [12.5307, 12.6392, 12.4341, 12.7037],
     [12.8024, 12.5300, 12.5852, 12.6758]],

    [[12.5608, 12.5742, 12.4584, 12.6545],
     [12.6365, 12.4974, 12.4200, 12.6811],
     [12.5307, 12.6392, 12.4341, 12.7037],
     [12.8024, 12.5300, 12.5852, 12.6758],
     [12.7501, 12.7796, 12.6390, 12.9428]],

    [[12.6365, 12.4974, 12.4200, 12.6811],
     [12.5307, 12.6392, 12.4341, 12.7037],
     [12.8024, 12.5300, 12.5852, 12.6758],
     [12.7501, 12.7796, 12.6390, 12.9428],
     [12.7410, 12.6626, 12.5596, 12.9268]],

    [[12.5307, 12.6392, 12.4341, 12.7037],
     [12.8024, 12.5300, 12.5852, 12.6758],
     [12.7501, 12.7796, 12.6390, 12.9428],
     [12.7410, 12.6626, 12.5596, 12.9268],
     [12.9473, 12.7328, 12.7287, 13.0092]],

    [[12.8024, 12.5300, 12.5852, 12.6758],
     [12.7501, 12.7796, 12.6390, 12.9428],
     [12.7410, 12.6626, 12.5596, 12.9268],
     [12.9473, 12.7328, 12.7287, 13.0092],
     [12.9656, 12.8837, 12.7223, 13.1367]],

    [[12.7501, 12.7796, 12.6390, 12.9428],
     [12.7410, 12.6626, 12.5596, 12.9268],
     [12.9473, 12.7328, 12.7287, 13.0092],
     [12.9656, 12.8837, 12.7223, 13.1367],
     [12.9277, 12.9032, 12.6967, 13.0317]],

    [[12.7410, 12.6626, 12.5596, 12.9268],
     [12.9473, 12.7328, 12.7287, 13.0092],
     [12.9656, 12.8837, 12.7223, 13.1367],
     [12.9277, 12.9032, 12.6967, 13.0317],
     [12.9016, 12.8993, 12.7517, 13.0782]],

    [[12.9473, 12.7328, 12.7287, 13.0092],
     [12.9656, 12.8837, 12.7223, 13.1367],
     [12.9277, 12.9032, 12.6967, 13.0317],
     [12.9016, 12.8993, 12.7517, 13.0782],
     [12.9630, 12.8187, 12.7107, 13.0968]],

    [[12.9656, 12.8837, 12.7223, 13.1367],
     [12.9277, 12.9032, 12.6967, 13.0317],
     [12.9016, 12.8993, 12.7517, 13.0782],
     [12.9630, 12.8187, 12.7107, 13.0968],
     [13.0805, 12.8681, 12.8440, 13.1632]],

    [[12.9277, 12.9032, 12.6967, 13.0317],
     [12.9016, 12.8993, 12.7517, 13.0782],
     [12.9630, 12.8187, 12.7107, 13.0968],
     [13.0805, 12.8681, 12.8440, 13.1632],
     [13.0949, 13.0241, 12.9208, 13.2815]],

    [[12.9016, 12.8993, 12.7517, 13.0782],
     [12.9630, 12.8187, 12.7107, 13.0968],
     [13.0805, 12.8681, 12.8440, 13.1632],
     [13.0949, 13.0241, 12.9208, 13.2815],
     [13.2255, 13.0774, 12.9554, 13.3678]],

    [[12.9630, 12.8187, 12.7107, 13.0968],
     [13.0805, 12.8681, 12.8440, 13.1632],
     [13.0949, 13.0241, 12.9208, 13.2815],
     [13.2255, 13.0774, 12.9554, 13.3678],
     [13.2712, 13.1607, 13.0464, 13.4316]],

    [[13.0805, 12.8681, 12.8440, 13.1632],
     [13.0949, 13.0241, 12.9208, 13.2815],
     [13.2255, 13.0774, 12.9554, 13.3678],
     [13.2712, 13.1607, 13.0464, 13.4316],
     [13.4318, 13.3037, 13.2628, 13.6082]]]) tensor([[1339.6000],
    [1331.5000],
    [1352.3000],
    [1348.3000],
    [1347.6000],
    [1363.4000],
    [1364.8000],
    [1361.9000],
    [1359.9000],
    [1364.6000],
    [1373.6000],
    [1374.7000],
    [1384.7000],
    [1388.2000],
    [1400.5000],
    [1413.4000]])

Model
Implement an LSTM regressor as a PyTorch Lightning module.


class LSTMRegressor(pl.LightningModule):
    '''
    Standard PyTorch Lightning module:
    https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
    '''
    def __init__(self,
                 n_features,
                 hidden_size,
                 seq_len,
                 batch_size,
                 num_layers,
                 dropout,
                 learning_rate,
                 criterion,
                 optimizer_name="AdamW"):
        super(LSTMRegressor, self).__init__()
        self.save_hyperparameters()

        self.n_features = n_features
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.criterion = criterion
        self.learning_rate = learning_rate

        self.norm = nn.LayerNorm(n_features)
        self.lstm = nn.LSTM(input_size=n_features,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # lstm_out shape: (batch_size, seq_len, hidden_size)
        x = self.norm(x)
        lstm_out, _ = self.lstm(x)
        # Regress on the hidden state of the last timestep
        y_pred = self.linear(lstm_out[:, -1])
        return y_pred

    def configure_optimizers(self):
        '''
        Optimizer and LR scheduler. CosineAnnealingWarmRestarts behaves like
        cosine annealing, but its period grows: the first period is T_0 steps
        and each later period is multiplied by T_mult. Within each period the
        learning rate decays from large to small.
        https://www.notion.so/62e72678923f4e8aa04b73dc3eefaf71
        '''
        optimizer = getattr(torch.optim, self.hparams.optimizer_name)(
            self.parameters(), lr=self.hparams.learning_rate)
        # Adaptive schedule with warm restarts
        T_mult = 2
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=100, T_mult=T_mult, eta_min=0)
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/6dc1078822c33fa4710618dc2f03945123edecec/pytorch_lightning/core/lightning.py#L1119
        lr_scheduler = {
            'scheduler': scheduler,
            # 'reduce_on_plateau': True,  # for a ReduceLROnPlateau scheduler
            'interval': 'step',  # epoch/step
            'frequency': 10,
            'name': "lr_scheduler",
            'monitor': 'train_loss',  # metric watched by the scheduler
            'strict': True,
        }
        return [optimizer], [lr_scheduler]
        # Equivalent dict form:
        # return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('test_loss', loss)
        return loss
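
To see what the warm-restart schedule actually does, the sketch below drives CosineAnnealingWarmRestarts with the same T_0=100, T_mult=2 settings on a dummy parameter (the dummy optimizer setup is my own illustration):

```
import torch

param = torch.nn.Parameter(torch.zeros(1))   # dummy parameter for the optimizer
opt = torch.optim.AdamW([param], lr=0.001)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opt, T_0=100, T_mult=2, eta_min=0)

for step in range(300):
    opt.step()
    sched.step()
    if step % 50 == 49:
        # The LR decays from 0.001 toward 0 within each period, then jumps
        # back to 0.001 at a restart; after the first restart (step 100)
        # the period doubles to 200 steps because T_mult=2.
        print(step + 1, opt.param_groups[0]['lr'])
```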


Parameters
'''
All parameters are aggregated in one place.
This is useful for reporting experiment params to experiment tracking software.
'''

p = dict(
    seq_len = 30,  # how many past days each prediction is based on
    batch_size = 16,
    criterion = nn.MSELoss(),
    max_epochs = 500,
    n_features = 4,
    hidden_size = 256,
    num_layers = 2,
    dropout = 0.2,
    learning_rate = 0.001,
)
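
Before kicking off a 500-epoch run, it is cheap to sanity-check the model's wiring on one random batch; this smoke test is my own addition, reusing the p dict above:

```
smoke_model = LSTMRegressor(
    n_features = p['n_features'],
    hidden_size = p['hidden_size'],
    seq_len = p['seq_len'],
    batch_size = p['batch_size'],
    criterion = p['criterion'],
    num_layers = p['num_layers'],
    dropout = p['dropout'],
    learning_rate = p['learning_rate']
)
x = torch.randn(p['batch_size'], p['seq_len'], p['n_features'])
with torch.no_grad():
    print(smoke_model(x).shape)  # expect torch.Size([16, 1]): one prediction per sequence
```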



Train loop

from pytorch_lightning.callbacks import ModelCheckpoint
# Early stopping:
# https://pytorch-lightning.readthedocs.io/en/1.2.1/common/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor

seed_everything(288)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0000,
    patience=50,
    verbose=True,
    mode='min',
    # check_finite=True,
    # check_on_train_epoch_end=True,
    # divergence_threshold=0.1
)

checkpoint_callback = ModelCheckpoint(
    # filename='/kaggle/working/{epoch}-{val_loss:.2f}',
    # dirpath="/kaggle/working/checkpoints",
    # filename='checkpoint',
    # filename='/kaggle/working/bart-out',
    # save_last=True,
    verbose=True,
    monitor='val_loss',
    # every_n_train_steps=2,
    mode='min',
    # best_model_path='best',
    save_top_k=1
)

lr_monitor = LearningRateMonitor(logging_interval='step')

# csv_logger = CSVLogger('./', name='lstm', version='0')
wandb_logger = WandbLogger(project='金价走势 PyTorch Lightning,训练时 LSTM, Timeseries, Cl001')

trainer = Trainer(
    max_epochs=p['max_epochs'],
    precision=16, amp_level='O2',
    checkpoint_callback=True,
    # callbacks=[PyTorchLightningPruningCallback(trial, monitor="train_loss")],
    # resume_from_checkpoint="../input/seq2seq-bart-model/新分词方案seq2seq—预训练BartForConditionalGeneration_uer_bart-chinese-4-768-cluecorpussmall/2oij20bh/checkpoints/checkpoint.ckpt",
    auto_select_gpus=True,
    callbacks=[
        lr_monitor,
        early_stop_callback,
        checkpoint_callback
    ],
    logger=wandb_logger,
    gpus=1,
    weights_summary="top",  # print a summary of the model at the start
    # row_log_interval=1,
    progress_bar_refresh_rate=2,
)

model = LSTMRegressor(
    n_features = p['n_features'],
    hidden_size = p['hidden_size'],
    seq_len = p['seq_len'],
    batch_size = p['batch_size'],
    criterion = p['criterion'],
    num_layers = p['num_layers'],
    dropout = p['dropout'],
    learning_rate = p['learning_rate']
)

dm = PowerConsumptionDataModule(
    seq_len = p['seq_len'],
    batch_size = p['batch_size']
)

trainer.fit(model, dm)
trainer.test(model, datamodule=dm)
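
After trainer.test reports the loss, you may want the raw predictions for plotting against the actual series. A minimal sketch, assuming training has finished and the test split has been set up by the trainer:

```
model = model.cpu()
model.eval()
preds, actuals = [], []
with torch.no_grad():
    for x, y in dm.test_dataloader():
        preds.append(model(x))
        actuals.append(y)
preds = torch.cat(preds).squeeze().numpy()
actuals = torch.cat(actuals).squeeze().numpy()

plt.plot(actuals, label='actual next-step price')
plt.plot(preds, label='predicted')
plt.legend()
plt.show()
```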
