Training a BiGRU Siamese Network for Sentence Relatedness

Published on Aug. 22, 2023, 12:10 p.m.

This example uses a BiGRU to implement a Sentence-BERT-style model, with the aim of getting there using a simple, small model.
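The model below uses a Sentence-BERT-style head. The following is a minimal NumPy sketch of that head, with random arrays standing in for the BiGRU outputs. The shapes and the hidden size H=4 are hypothetical.

Each sentence is mean-pooled over its non-padding tokens, and the pair is then represented as concat(u, v, |u - v|). A bidirectional GRU emits 2*H values per token, so the concatenation has 6*H features. This matches the `hidden_size*6` input of `pre_classifier`.

```python
import numpy as np

def mean_pool(token_emb, mask):
    """Average token vectors over positions where mask == 1.
    token_emb: (B, L, D); mask: (B, L) of 0/1."""
    m = mask[..., None].astype(token_emb.dtype)
    return (token_emb * m).sum(axis=1) / np.clip(m.sum(axis=1), 1e-9, None)

def pair_features(u, v):
    """Sentence-BERT style classifier input: concat(u, v, |u - v|)."""
    return np.concatenate([u, v, np.abs(u - v)], axis=-1)

H = 4                                       # hypothetical encoder hidden size
rng = np.random.default_rng(0)
gru_out_a = rng.normal(size=(2, 5, 2 * H))  # stand-in BiGRU output: 2*H per token
gru_out_b = rng.normal(size=(2, 7, 2 * H))
mask_a = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # first sentence has 2 padding tokens
mask_b = np.ones((2, 7), dtype=int)

feats = pair_features(mean_pool(gru_out_a, mask_a), mean_pool(gru_out_b, mask_b))
print(feats.shape)  # (2, 24) == (batch, 6 * H)
```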


import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)


import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/notebook134e5b9624/__results__.html
/kaggle/input/notebook134e5b9624/__notebook__.ipynb
/kaggle/input/notebook134e5b9624/__output__.json
/kaggle/input/notebook134e5b9624/custom.css
/kaggle/input/notebook134e5b9624/data/val.pkt
/kaggle/input/notebook134e5b9624/data/labels.json
/kaggle/input/notebook134e5b9624/data/train.pkt
/kaggle/input/notebook134e5b9624/data/test.pkt

Training with a Siamese network
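The Siamese model scores each sentence pair with a single logit and trains it with `nn.BCEWithLogitsLoss` (targets in [0, 1]). Below is a small NumPy sketch of what that loss computes. It uses the numerically stable form, which does not overflow `exp()` for large logits. The helper names are hypothetical.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits, stable form:
    max(x, 0) - x*y + log(1 + exp(-|x|))."""
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    return float(np.mean(np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))))

def naive_bce(logits, targets):
    """Textbook version: sigmoid first, then log; overflows for large |x|."""
    p = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    y = np.asarray(targets, dtype=np.float64)
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))

logits = [2.0, -1.0, 0.0]   # hypothetical model outputs
labels = [1.0, 0.0, 1.0]    # hypothetical relatedness targets
print(round(bce_with_logits(logits, labels), 4))  # 0.3778
```

The two functions agree on moderate logits. Only the stable form stays finite for something like a logit of 1000 with target 0.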

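# The code below targets the PyTorch Lightning 1.4 API (gpus=, amp_level=, checkpoint_callback=, DDPPlugin, ...),
# so pin the versions that were present in the original environment.
!pip install pytorch_lightning==1.4.4 torchmetrics==0.5.0
!pip install wandb -q
!pip install tkitAutoMask
!pip install transformers==4.5.1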
Requirement already satisfied: pytorch_lightning (1.4.4), torch (1.9.1), torchmetrics (0.5.0), transformers (4.5.1), tokenizers (0.10.3) and their dependencies in /opt/conda/lib/python3.7/site-packages
Collecting tkitAutoMask
  Downloading tkitAutoMask-0.0.0.316350799-py3-none-any.whl (9.2 kB)
Installing collected packages: tkitAutoMask
Successfully installed tkitAutoMask-0.0.0.316350799
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")
os.environ["WANDB_SILENT"] = "true"
os.environ["WANDB_API_KEY"] = secret_value_0
import re
import csv
import gzip
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from tqdm.auto import tqdm  # imported once; a later `import tqdm` would shadow this name
from transformers import BertTokenizer, AutoModel, AutoTokenizer, AutoConfig
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
# Early stopping
# https://pytorch-lightning.readthedocs.io/en/1.2.1/common/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# from tkitAttNLocal import AttNLocal
# from .lr import CyclicCosineDecayLR

from tkitAutoMask import autoMask

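Preprocessed data: the model below loads the train, validation, and test splits from the .pkt files listed above.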
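The dataloaders below `torch.load` each .pkt file and wrap it in a `DataLoader`. Every step then unpacks a batch as (input_ids_a, attention_mask_a, input_ids_b, attention_mask_b, labels).

The preprocessing notebook is not shown here. Assuming each file holds a `TensorDataset` in that order, a toy dataset with the same layout could be built as follows. The sizes, token ids, pad id 0, and labels are all made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
N, L = 4, 6                                    # hypothetical: 4 pairs, padded to length 6
input_ids_a = torch.randint(1, 100, (N, L))
attention_mask_a = torch.ones(N, L, dtype=torch.long)
input_ids_b = torch.randint(1, 100, (N, L))
attention_mask_b = torch.ones(N, L, dtype=torch.long)
input_ids_b[:, 4:] = 0                         # assumed pad id 0 for the last two positions
attention_mask_b[:, 4:] = 0                    # ...and those positions are masked out
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])    # float targets for BCEWithLogitsLoss

dataset = TensorDataset(input_ids_a, attention_mask_a, input_ids_b, attention_mask_b, labels)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Same unpacking order as training_step / validation_step / test_step
ids_a, mask_a, ids_b, mask_b, y = next(iter(loader))
print(ids_a.shape, y.dtype)   # torch.Size([2, 6]) torch.float32
```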

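class SimICD(pl.LightningModule):
    """
    Model for ICD code normalization.
    Follows the Sentence-BERT pattern and adds a bidirectional GRU on top of the
    encoder to reduce reliance on a large pretrained model.
    """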
    def __init__(
        self,learning_rate=3e-4,
        T_max=5,
        ignore_index=0,max_len=256,
        optimizer_name="AdamW",
        dropout=0.2,
        labels=2,
        pretrained="uer/chinese_roberta_L-2_H-128",
        batch_size=2,
        trainfile="./data/train.pkt",
        valfile="./data/val.pkt",
        testfile="./data/test.pkt",
        **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained) 
        config = AutoConfig.from_pretrained(pretrained)
        self.model = AutoModel.from_pretrained(pretrained,config=config)
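        # batch_first=True: the encoder output is (batch, seq_len, hidden). Without it, nn.GRU
        # treats dim 0 as time and runs the recurrence across the samples of the batch.
        self.rnn = nn.GRU(config.hidden_size, config.hidden_size, dropout=dropout, num_layers=2,
                          bidirectional=True, batch_first=True)
#         self.rnn = nn.LSTM(config.hidden_size, config.hidden_size,dropout=dropout,num_layers=2,bidirectional=True,batch_first=True)
        # Input is [u, v, |u - v|], each 2*hidden_size wide (bidirectional), so 6*hidden_size in total
        self.pre_classifier=nn.Linear(config.hidden_size*6,config.hidden_size)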
        self.dropout = torch.nn.Dropout(dropout)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
#         self.classifierSigmoid = torch.nn.Sigmoid()

        # Random token masking (MLM-style augmentation); only used if the tomask lines in training_step are uncommented
        self.tomask = autoMask(
                # transformer,
                mask_token_id = self.tokenizer.mask_token_id,          # the token id reserved for masking
                pad_token_id = self.tokenizer.pad_token_id,           # the token id for padding
                mask_prob = 0.05,           # masking probability for masked language modeling
                replace_prob = 0.90,        # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
                mask_ignore_token_ids = [self.tokenizer.cls_token_id,self.tokenizer.sep_token_id]  # exclude [CLS] and [SEP] from masking (BertTokenizer has no EOS token)
                )
    def forward(self, input_ids_a,input_ids_b,attention_mask_a=None,attention_mask_b=None):
        """
        Classification head: encode both sentences with the shared encoder + BiGRU,
        mean-pool each one, and return one relatedness logit per pair, shape (batch, 1).
        """
        # Treat every token as real when no mask is given, so mean pooling still works
        if attention_mask_a is None:
            attention_mask_a = torch.ones_like(input_ids_a)
        if attention_mask_b is None:
            attention_mask_b = torch.ones_like(input_ids_b)

        # Sentence A: token embeddings -> BiGRU -> mean pooling over non-padding tokens
        outputs_a=self.model(input_ids=input_ids_a,attention_mask=attention_mask_a)
        emb_a,_=self.rnn(outputs_a[0])
        emb_a = self.mean_pooling(emb_a, attention_mask_a)
        # Sentence B: same weights (Siamese), same pooling
        outputs_b=self.model(input_ids=input_ids_b,attention_mask=attention_mask_b)
        emb_b,_=self.rnn(outputs_b[0])
        emb_b = self.mean_pooling(emb_b, attention_mask_b)

        # Sentence-BERT style pair features: [u, v, |u - v|]
        emb_diff=emb_a-emb_b
        emb=torch.cat((emb_a,emb_b,emb_diff.abs()),-1)
        pooler=self.pre_classifier(emb)
#         cos = nn.CosineSimilarity(dim=1, eps=1e-8)
#         sim=cos(emb_a, emb_b)

        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)  # raw logit; BCEWithLogitsLoss applies the sigmoid
        return output

    def mean_pooling(self,model_output, attention_mask):
        """Average token vectors (batch, seq_len, dim) over the non-padding positions."""
        token_embeddings = model_output  # here the BiGRU output tensor, not the transformer's output tuple
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def loss_fc(self,out,label):
        """
        Binary cross-entropy on the relatedness logit (targets in [0, 1]).
        """
        loss_fc = nn.BCEWithLogitsLoss()
        # BCEWithLogitsLoss expects float targets with the same shape as the logits
        loss=loss_fc(out.view(-1).float(),label.view(-1).float())
        return loss

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        input_ids_a,attention_mask_a,input_ids_b,attention_mask_b,labels = batch

#         input_ids_a=self.tomask(input_ids_a)[0]
#         input_ids_b=self.tomask(input_ids_b)[0]
        out=self(input_ids_a,input_ids_b,attention_mask_a,attention_mask_b)
        loss=self.loss_fc(out,labels)

        self.log('train_loss',loss)
        return  loss
    def validation_step(self, batch, batch_idx):
        # Same computation as training_step; the loss is logged as val_loss
        input_ids_a,attention_mask_a,input_ids_b,attention_mask_b,labels = batch

        out=self(input_ids_a,input_ids_b,attention_mask_a,attention_mask_b)
        loss=self.loss_fc(out,labels)
        self.log('val_loss',loss)
        return  loss
    def test_step(self, batch, batch_idx):
        # Same computation as training_step; the loss is logged as test_loss
        input_ids_a,attention_mask_a,input_ids_b,attention_mask_b,labels = batch

        out=self(input_ids_a,input_ids_b,attention_mask_a,attention_mask_b)
        loss=self.loss_fc(out,labels)
        self.log('test_loss',loss)
        return  loss
    def train_dataloader(self):
        train=torch.load(self.hparams.trainfile)
        return DataLoader(train, batch_size=int(self.hparams.batch_size),num_workers=2,pin_memory=True, shuffle=True)
    def val_dataloader(self):
        val=torch.load(self.hparams.valfile)
        return DataLoader(val, batch_size=int(self.hparams.batch_size),num_workers=2,pin_memory=True)
    def test_dataloader(self):
        val=torch.load(self.hparams.testfile)
        return DataLoader(val, batch_size=int(self.hparams.batch_size),num_workers=2,pin_memory=True)

    def configure_optimizers(self):
        """Build the optimizer named by optimizer_name (looked up in torch.optim) and a per-step LR scheduler."""
        optimizer = getattr(optim, self.hparams.optimizer_name)(self.parameters(), lr=self.hparams.learning_rate)
        # Alternative adaptive schedulers (disabled):
#         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=5000,factor=0.8,verbose=True)
#         scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 2000, 500)
        # Note: T_max is hard-coded to 500 steps here; the T_max hyperparameter is not used
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500, eta_min=1e-8)

#         scheduler = CyclicCosineDecayLR(optimizer, 
#                                         init_decay_epochs=10,
#                                         min_decay_lr=1e-8,
#                                         restart_interval = 5,
#                                         restart_lr=self.hparams.learning_rate/1,
#                                         restart_interval_multiplier=1.5,
#                                         warmup_epochs=10,
#                                         warmup_start_lr=self.hparams.learning_rate/10)
#
        lr_scheduler={
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1,
            'name':"lr_scheduler",
            'monitor': 'train_loss', # metric to monitor; only needed by ReduceLROnPlateau-style schedulers
            'strict': True,
        }
#         return [optimizer], [lr_scheduler]
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
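This check shows why `batch_first` matters for `self.rnn`. The encoder's `outputs[0]` has shape (batch, seq_len, hidden), but `nn.GRU` by default reads its input as (seq_len, batch, hidden).

The check below uses random inputs with arbitrary sizes. For each GRU, it compares the output for the first sample when that sample is run inside a batch versus run alone. Only with `batch_first=True` is the result independent of the other samples.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, L, H = 3, 5, 4              # arbitrary sizes for the check
x = torch.randn(B, L, H)       # (batch, seq_len, hidden), like outputs_a[0]

# Same layer count and directions as self.rnn (dropout omitted)
gru_bf = nn.GRU(H, H, num_layers=2, bidirectional=True, batch_first=True).eval()
gru_default = nn.GRU(H, H, num_layers=2, bidirectional=True).eval()  # reads (seq_len, batch, hidden)

with torch.no_grad():
    full_bf, _ = gru_bf(x)          # whole batch
    single_bf, _ = gru_bf(x[:1])    # first sample alone
    full_def, _ = gru_default(x)
    single_def, _ = gru_default(x[:1])

print(full_bf.shape)                                        # torch.Size([3, 5, 8])
print(torch.allclose(full_bf[:1], single_bf, atol=1e-5))    # True: samples are processed independently
print(torch.allclose(full_def[:1], single_def, atol=1e-5))  # False: the recurrence runs across samples
```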

# /kaggle/input/reformerchinesemodel/epoch4step21209.ckpt
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
# 自动停止
# https://pytorch-lightning.readthedocs.io/en/1.2.1/common/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor

# Pruning: https://pytorch-lightning.readthedocs.io/en/stable/advanced/pruning_quantization.html
from pytorch_lightning.callbacks import ModelPruning
import torch.nn.utils.prune as prune
# https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
# Quantization (lower memory, lower precision): https://pytorch-lightning.readthedocs.io/en/stable/advanced/pruning_quantization.html
from pytorch_lightning.callbacks import QuantizationAwareTraining

# When using DDP, set find_unused_parameters=False.
# Lightning enables find_unused_parameters=True by default because of past compatibility issues;
# this costs performance and can be disabled in most cases.
from pytorch_lightning.plugins import DDPPlugin

from pytorch_lightning.loggers import WandbLogger

seed_everything(2021)

early_stop_callback = EarlyStopping(
   monitor='val_loss',
   min_delta=0.0000,
   patience=5,
   verbose=True,
   mode='min',
#   check_finite=True,
#  check_on_train_epoch_end=True,
#     divergence_threshold=0.1

)
checkpoint_callback = ModelCheckpoint(
#     filename='/kaggle/working/{epoch}-{val_loss:.2f}',
#     dirpath="/kaggle/working/checkpoints",
    filename='checkpoint',
#     filename='/kaggle/working/bart-out',
    save_last=True,
    verbose=True,
    monitor='val_loss',
#     every_n_train_steps=2,
    mode='min',
#     best_model_path='best'
    save_top_k=1
)
lr_monitor = LearningRateMonitor(logging_interval='step')

wandb_logger = WandbLogger(project='GRU孪生网络训练相关程度notebook82fd4fd669训练')

# profilers=pl.profiler.profilers.SimpleProfiler()

# model=LitAutoEncDec(learning_rate=3e-4,T_max=5,optimizer_name="AdamW",batch_size=96)
model= SimICD(
        learning_rate=7e-5,
        T_max=5,
        ignore_index=0,
        optimizer_name="AdamW",
        dropout=0.2,
        labels=2,
        pretrained="uer/chinese_roberta_L-2_H-512",
        batch_size=1024,
        trainfile="../input/notebook134e5b9624/data/train.pkt",
        valfile="../input/notebook134e5b9624/data/val.pkt",
        testfile="../input/notebook134e5b9624/data/test.pkt"
             )
trainer = pl.Trainer(
        gpus=1,
        #     min_epochs=1,
        precision=16,amp_level='O2',  # amp_level only applies to the apex backend; the default native AMP ignores it

#         val_check_interval=0.5, # validate more often (twice per epoch)
#         limit_val_batches=0.5, # use only a fraction of the validation batches on each check
        checkpoint_callback=True,
        #         callbacks=[PyTorchLightningPruningCallback(trial, monitor="train_loss")],
#         resume_from_checkpoint="../input/sposeq2seqbartuerbartchinese4768model/spo新分词方案seq2seq—预训练BartForConditionalGeneration_uer_bart-chinese-4-768-cluecorpussmall/36ppt8oy/checkpoints/checkpoint.ckpt",
        auto_select_gpus=True,
        callbacks=[lr_monitor,early_stop_callback,
#                pruning,
               checkpoint_callback
        #                        QuantizationAwareTraining()
              ],

        gradient_clip_val=0.5,
        stochastic_weight_avg=True,# stochastic weight averaging: https://pytorch-lightning.readthedocs.io/en/stable/advanced/training_tricks.html#stochastic-weight-averaging
        max_epochs=20,
        auto_lr_find=True,  # only takes effect when trainer.tune(model) is called (commented out below)
        logger=wandb_logger, # Weights & Biases logger
#         logger=neptune_logger,
#         plugins=[DDPPlugin(find_unused_parameters=False)],
        accumulate_grad_batches=1,
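        #     overfit_batches=20, # overfit a fraction (float) or a number (int) of training batches; useful for small-data sanity checks
#         terminate_on_nan=True, # stop training when a NaN appears
        weights_summary="top", # print a parameter summary at the start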
        flush_logs_every_n_steps=5,
        log_every_n_steps=5,
#         profiler=profilers,
#     accelerator=
)

# trainer.tune(model)

trainer.fit(model)
Some weights of BertModel were not initialized from the model checkpoint at uer/chinese_roberta_L-2_H-512 and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py:103: LightningDeprecationWarning: The signature of `Callback.on_train_epoch_end` has changed in v1.3. `outputs` parameter has been removed. Support for the old signature will be removed in v1.5
  "The signature of `Callback.on_train_epoch_end` has changed in v1.3."

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py:190: UserWarning: SWA is currently only supported every epoch. Found {'scheduler': , 'name': 'lr_scheduler', 'interval': 'step', 'frequency': 1, 'reduce_on_plateau': False, 'monitor': 'train_loss', 'strict': True, 'opt_idx': None}
 rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")


/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:138: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.
 torch.nn.utils.clip_grad_norm_(parameters, clip_val)
