Training sentence relevance with a BiGRU Siamese network
Published on Aug. 22, 2023, 12:10 p.m.
This example uses a BiGRU to build a Sentence-BERT-style model, trying to get by with a simpler architecture.
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
/kaggle/input/notebook134e5b9624/__results__.html
/kaggle/input/notebook134e5b9624/__notebook__.ipynb
/kaggle/input/notebook134e5b9624/__output__.json
/kaggle/input/notebook134e5b9624/custom.css
/kaggle/input/notebook134e5b9624/data/val.pkt
/kaggle/input/notebook134e5b9624/data/labels.json
/kaggle/input/notebook134e5b9624/data/train.pkt
/kaggle/input/notebook134e5b9624/data/test.pkt
Training with a Siamese network
!pip install pytorch_lightning
!pip install wandb -q
!pip install tkitAutoMask
!pip install transformers
Requirement already satisfied: pytorch_lightning in /opt/conda/lib/python3.7/site-packages (1.4.4)
Collecting tkitAutoMask
  Downloading tkitAutoMask-0.0.0.316350799-py3-none-any.whl (9.2 kB)
Installing collected packages: tkitAutoMask
Successfully installed tkitAutoMask-0.0.0.316350799
Requirement already satisfied: transformers in /opt/conda/lib/python3.7/site-packages (4.5.1)
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_API_KEY"] = secret_value_0
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
import torch.optim as optim
from transformers import BertTokenizer, AlbertModel, AutoModel, AutoTokenizer, AutoModelForTokenClassification, AutoConfig
from tqdm.auto import tqdm
import re
import random
import gzip
import csv
import numpy as np
import os
import torchmetrics
from torchmetrics.functional import precision_recall, precision_recall_curve, f1
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
# Early stopping
# https://pytorch-lightning.readthedocs.io/en/1.2.1/common/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from tkitAttNLocal import AttNLocal
# from .lr import CyclicCosineDecayLR
from tkitAutoMask import autoMask
Preprocess the data and define the model
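The .pkt files loaded below were produced in the attached input notebook (notebook134e5b9624); they are simply saved TensorDatasets containing token ids, attention masks and a 0/1 relevance label for each sentence pair. As a rough, hypothetical sketch of how such a file could be built (the pairs.csv file and its text_a/text_b/label columns are made up for illustration):
# Sketch only: build a (input_ids_a, attention_mask_a, input_ids_b, attention_mask_b, label) TensorDataset
import pandas as pd
import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("uer/chinese_roberta_L-2_H-512")
df = pd.read_csv("pairs.csv")  # hypothetical file with columns text_a, text_b, label
enc_a = tokenizer(list(df["text_a"]), padding="max_length", truncation=True, max_length=64, return_tensors="pt")
enc_b = tokenizer(list(df["text_b"]), padding="max_length", truncation=True, max_length=64, return_tensors="pt")
labels = torch.tensor(df["label"].values, dtype=torch.float)  # float targets for BCEWithLogitsLoss

dataset = torch.utils.data.TensorDataset(
    enc_a["input_ids"], enc_a["attention_mask"],
    enc_b["input_ids"], enc_b["attention_mask"],
    labels,
)
torch.save(dataset, "data/train.pkt")  # train_dataloader() below just torch.load()s this file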
class SimICD(pl.LightningModule):
    """
    Model for ICD-code normalization.
    Follows the Sentence-BERT setup and adds a bidirectional GRU to reduce the reliance on a large pretrained encoder.
    """
    def __init__(
        self, learning_rate=3e-4,
        T_max=5,
        ignore_index=0, max_len=256,
        optimizer_name="AdamW",
        dropout=0.2,
        labels=2,
        pretrained="uer/chinese_roberta_L-2_H-128",
        batch_size=2,
        trainfile="./data/train.pkt",
        valfile="./data/val.pkt",
        testfile="./data/test.pkt",
        **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained)
        config = AutoConfig.from_pretrained(pretrained)
        self.model = AutoModel.from_pretrained(pretrained, config=config)
        # batch_first=True so the GRU consumes the transformer output of shape (batch, seq_len, hidden) directly
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, dropout=dropout, num_layers=2, bidirectional=True, batch_first=True)
        # self.rnn = nn.LSTM(config.hidden_size, config.hidden_size, dropout=dropout, num_layers=2, bidirectional=True)
        # the bidirectional GRU yields 2*hidden per sentence; concatenating u, v and |u - v| gives 6*hidden
        self.pre_classifier = nn.Linear(config.hidden_size * 6, config.hidden_size)
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
        # self.classifierSigmoid = torch.nn.Sigmoid()
        self.tomask = autoMask(
            # transformer,
            mask_token_id=self.tokenizer.mask_token_id,  # the token id reserved for masking
            pad_token_id=self.tokenizer.pad_token_id,    # the token id used for padding
            mask_prob=0.05,      # masking probability for masked language modeling
            replace_prob=0.90,   # ~10% probability that a token is kept unmasked but still included in the loss, as in the BERT paper
            mask_ignore_token_ids=[self.tokenizer.cls_token_id, self.tokenizer.sep_token_id]  # tokens excluded from masking, i.e. [CLS] and [SEP]
        )
    def forward(self, input_ids_a, input_ids_b, attention_mask_a=None, attention_mask_b=None):
        """
        Classification-style head: encode both sentences and score their relevance.
        """
        B, L = input_ids_a.size()
        # print(input_ids_a.size())
        outputs_a = self.model(input_ids=input_ids_a, attention_mask=attention_mask_a)
        # emb_a = self.mean_pooling(outputs_a, attention_mask_a)
        emb_a, _ = self.rnn(outputs_a[0])
        # print(emb_a.size(), emb_a.sum(1).size())
        # emb_a = emb_a.sum(1).view(B, -1)
        # Perform pooling. In this case, mean pooling.
        emb_a = self.mean_pooling(emb_a, attention_mask_a)
        # print(emb_a.size())
        outputs_b = self.model(input_ids=input_ids_b, attention_mask=attention_mask_b)
        emb_b, _ = self.rnn(outputs_b[0])
        # Perform pooling. In this case, mean pooling.
        emb_b = self.mean_pooling(emb_b, attention_mask_b)
        # _, emb_b = self.rnn(outputs_b[0])
        # emb_b = emb_b.sum(1).view(B, -1)
        # Concatenate (u, v, |u - v|) as in Sentence-BERT
        emb_diff = emb_a - emb_b
        emb = torch.cat((emb_a, emb_b, emb_diff.abs()), -1)
        pooler = self.pre_classifier(emb)
        # cos = nn.CosineSimilarity(dim=1, eps=1e-8)
        # sim = cos(emb_a, emb_b)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        # output = self.classifierSigmoid(output)
        return output
    def mean_pooling(self, model_output, attention_mask):
        # Mask-aware mean over the token dimension
        token_embeddings = model_output  # token embeddings, shape (batch, seq_len, hidden)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # def mean_pooling(self, model_output, attention_mask):
    #     token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    #     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    #     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # def loss_fc(self, out, label):
    #     loss_fc = nn.L1Loss()
    #     loss = loss_fc(out, label)
    #     return loss

    def loss_fc(self, out, label):
        """
        Compute the loss for the discrete 0/1 relevance labels.
        """
        # loss_fc = nn.CrossEntropyLoss()
        loss_fc = nn.BCEWithLogitsLoss()
        # print(out.size(), label.size())
        # .view(-1, self.hparams.labels)
        loss = loss_fc(out.view(-1), label.view(-1))
        return loss
    def training_step(self, batch, batch_idx):
        # training_step defines the training loop; it is independent of forward
        input_ids_a, attention_mask_a, input_ids_b, attention_mask_b, labels = batch
        # input_ids_a = self.tomask(input_ids_a)[0]
        # input_ids_b = self.tomask(input_ids_b)[0]
        out = self(input_ids_a, input_ids_b, attention_mask_a, attention_mask_b)
        loss = self.loss_fc(out, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids_a, attention_mask_a, input_ids_b, attention_mask_b, labels = batch
        out = self(input_ids_a, input_ids_b, attention_mask_a, attention_mask_b)
        loss = self.loss_fc(out, labels)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids_a, attention_mask_a, input_ids_b, attention_mask_b, labels = batch
        out = self(input_ids_a, input_ids_b, attention_mask_a, attention_mask_b)
        loss = self.loss_fc(out, labels)
        self.log('test_loss', loss)
        return loss

    def train_dataloader(self):
        train = torch.load(self.hparams.trainfile)
        return DataLoader(train, batch_size=int(self.hparams.batch_size), num_workers=2, pin_memory=True, shuffle=True)

    def val_dataloader(self):
        val = torch.load(self.hparams.valfile)
        return DataLoader(val, batch_size=int(self.hparams.batch_size), num_workers=2, pin_memory=True)

    def test_dataloader(self):
        test = torch.load(self.hparams.testfile)
        return DataLoader(test, batch_size=int(self.hparams.batch_size), num_workers=2, pin_memory=True)
    def configure_optimizers(self):
        """Configure the optimizer and learning-rate scheduler."""
        optimizer = getattr(optim, self.hparams.optimizer_name)(self.parameters(), lr=self.hparams.learning_rate)
        # Other schedules that were tried:
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5000, factor=0.8, verbose=True)
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 2000, 500)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500, eta_min=1e-8)
        # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=1000)
        # scheduler = CyclicCosineDecayLR(optimizer,
        #                                 init_decay_epochs=10,
        #                                 min_decay_lr=1e-8,
        #                                 restart_interval=5,
        #                                 restart_lr=self.hparams.learning_rate / 1,
        #                                 restart_interval_multiplier=1.5,
        #                                 warmup_epochs=10,
        #                                 warmup_start_lr=self.hparams.learning_rate / 10)
        lr_scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1,
            'name': "lr_scheduler",
            'monitor': 'train_loss',  # metric to monitor
            'strict': True,
        }
        # return [optimizer], [lr_scheduler]
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
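Once trained, the model works as a pair scorer. A minimal inference sketch, assuming a trained SimICD instance called model (the example sentences and the sigmoid mapping of the single logit to a 0-1 score are illustrative):
# Sketch only: score the relevance of two sentences with a trained SimICD model
model.eval()
tok = model.tokenizer
enc_a = tok(["急性支气管炎"], padding="max_length", truncation=True, max_length=64, return_tensors="pt")
enc_b = tok(["支气管炎,急性"], padding="max_length", truncation=True, max_length=64, return_tensors="pt")
with torch.no_grad():
    logit = model(enc_a["input_ids"], enc_b["input_ids"], enc_a["attention_mask"], enc_b["attention_mask"])
    score = torch.sigmoid(logit)  # forward returns a raw logit; sigmoid maps it to a 0-1 relevance score
print(float(score))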
# /kaggle/input/reformerchinesemodel/epoch4step21209.ckpt
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
# Early stopping
# https://pytorch-lightning.readthedocs.io/en/1.2.1/common/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
# Pruning: https://pytorch-lightning.readthedocs.io/en/stable/advanced/pruning_quantization.html
from pytorch_lightning.callbacks import ModelPruning
import torch.nn.utils.prune as prune
# https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
# Quantization-aware training: lower memory via lower precision, https://pytorch-lightning.readthedocs.io/en/stable/advanced/pruning_quantization.html
from pytorch_lightning.callbacks import QuantizationAwareTraining
# When using DDP, set find_unused_parameters=False.
# By default Lightning enables detection of unused parameters (True) for backwards-compatibility reasons; it hurts performance and can be disabled in most cases.
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.loggers import WandbLogger
seed_everything(2021)
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0000,
    patience=5,
    verbose=True,
    mode='min',
    # check_finite=True,
    # check_on_train_epoch_end=True,
    # divergence_threshold=0.1
)
checkpoint_callback = ModelCheckpoint(
    # filename='/kaggle/working/{epoch}-{val_loss:.2f}',
    # dirpath="/kaggle/working/checkpoints",
    filename='checkpoint',
    # filename='/kaggle/working/bart-out',
    save_last=True,
    verbose=True,
    monitor='val_loss',
    # every_n_train_steps=2,
    mode='min',
    # best_model_path='best'
    save_top_k=1
)
lr_monitor = LearningRateMonitor(logging_interval='step')
wandb_logger = WandbLogger(project='GRU孪生网络训练相关程度notebook82fd4fd669训练')
# profilers=pl.profiler.profilers.SimpleProfiler()
# model=LitAutoEncDec(learning_rate=3e-4,T_max=5,optimizer_name="AdamW",batch_size=96)
model = SimICD(
    learning_rate=7e-5,
    T_max=5,
    ignore_index=0,
    optimizer_name="AdamW",
    dropout=0.2,
    labels=2,
    pretrained="uer/chinese_roberta_L-2_H-512",
    batch_size=1024,
    trainfile="../input/notebook134e5b9624/data/train.pkt",
    valfile="../input/notebook134e5b9624/data/val.pkt",
    testfile="../input/notebook134e5b9624/data/test.pkt"
)
trainer = pl.Trainer(
    gpus=1,
    # min_epochs=1,
    precision=16, amp_level='O2',
    # val_check_interval=0.5,  # check validation more often
    # limit_val_batches=0.5,   # limit the number of validation batches per check
    checkpoint_callback=True,
    # callbacks=[PyTorchLightningPruningCallback(trial, monitor="train_loss")],
    # resume_from_checkpoint="../input/sposeq2seqbartuerbartchinese4768model/spo新分词方案seq2seq—预训练BartForConditionalGeneration_uer_bart-chinese-4-768-cluecorpussmall/36ppt8oy/checkpoints/checkpoint.ckpt",
    auto_select_gpus=True,
    callbacks=[lr_monitor, early_stop_callback,
               # pruning,
               checkpoint_callback
               # QuantizationAwareTraining()
               ],
    gradient_clip_val=0.5,
    stochastic_weight_avg=True,  # stochastic weight averaging: https://pytorch-lightning.readthedocs.io/en/stable/advanced/training_tricks.html#stochastic-weight-averaging
    max_epochs=20,
    auto_lr_find=True,
    logger=wandb_logger,  # logging
    # logger=neptune_logger,
    # plugins=[DDPPlugin(find_unused_parameters=False)],
    accumulate_grad_batches=1,
    # overfit_batches=20,  # overfit a small fraction (float) or number (int) of training batches; useful for quick sanity checks
    # terminate_on_nan=True,  # stop if a NaN is encountered
    weights_summary="top",  # print a summary of the top-level modules
    flush_logs_every_n_steps=5,
    log_every_n_steps=5,
    # profiler=profilers,
    # accelerator=
)
# trainer.tune(model)
trainer.fit(model)
Downloading: tokenizer and model files for uer/chinese_roberta_L-2_H-512 from the Hugging Face Hub …
Some weights of BertModel were not initialized from the model checkpoint at uer/chinese_roberta_L-2_H-512 and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Validation sanity check: 0it [00:00, ?it/s]
Training: -1it [00:00, ?it/s]
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py:103: LightningDeprecationWarning: The signature of `Callback.on_train_epoch_end` has changed in v1.3. `outputs` parameter has been removed. Support for the old signature will be removed in v1.5
  "The signature of `Callback.on_train_epoch_end` has changed in v1.3."
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py:190: UserWarning: SWA is currently only supported every epoch. Found {'scheduler': <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x...>, 'name': 'lr_scheduler', 'interval': 'step', 'frequency': 1, 'reduce_on_plateau': False, 'monitor': 'train_loss', 'strict': True, 'opt_idx': None}
  rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:138: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.
  torch.nn.utils.clip_grad_norm_(parameters, clip_val)
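After training finishes, the best checkpoint written by ModelCheckpoint can be restored for evaluation. A brief sketch using standard Lightning APIs (best_model_path is filled in by the callback during trainer.fit):
# Sketch only: reload the best checkpoint and evaluate on the test set
best_path = checkpoint_callback.best_model_path
model = SimICD.load_from_checkpoint(best_path)
trainer.test(model)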