TextCNN.py示例

Published on Aug. 22, 2023, 12:09 p.m.


import torch
import torch.nn as nn

class TextCNN(nn.Module):
    """Convolutional text classifier with three parallel Conv1d branches.

    Token ids are embedded, passed through three 1-D convolutions with
    kernel sizes 2, 3 and 5, max-pooled over the sequence axis,
    concatenated, and projected to ``num_classes`` log-probabilities.

    Args:
        num_classes: Size of the output (class) dimension.
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector; also the number of
            input channels of every convolution.
        hidden_size: Number of output channels of every convolution.
        dropout_prob: Dropout probability applied before the final linear
            layer.
    """

    def __init__(self, num_classes, num_embeddings, embedding_dim, hidden_size, dropout_prob=0.5):
        super().__init__()

        self.num_classes = num_classes
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

        # NOTE: sub-modules are created in a fixed order (embedding, conv1,
        # conv2, conv3, fc) so that seeded initialisation and state_dict
        # layout stay stable.
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.conv1 = nn.Conv1d(in_channels=embedding_dim, out_channels=hidden_size, kernel_size=2)
        self.conv2 = nn.Conv1d(in_channels=embedding_dim, out_channels=hidden_size, kernel_size=3)
        self.conv3 = nn.Conv1d(in_channels=embedding_dim, out_channels=hidden_size, kernel_size=5)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout_prob)
        # The three pooled branches are concatenated, hence 3 * hidden_size.
        self.fc = nn.Linear(in_features=3 * hidden_size, out_features=num_classes)

    @staticmethod
    def _conv_max(conv, features):
        """Apply ``conv`` to ``features`` [N, E, L] and max-pool over L.

        Conv1d raises when the sequence is shorter than its kernel, so the
        sequence axis is right-padded with zeros up to the kernel size in
        that case. Inputs that are already long enough are left untouched.
        """
        shortfall = conv.kernel_size[0] - features.size(-1)
        if shortfall > 0:
            features = nn.functional.pad(features, (0, shortfall))
        return conv(features).max(dim=-1).values

    def forward(self, x):
        """Compute class log-probabilities for a batch of token-id sequences.

        Args:
            x: [N, L] tensor of token indices for the embedding table.
                Sequences shorter than a convolution kernel are zero-padded
                (in embedding space) for that branch instead of raising.

        Returns:
            [N, num_classes] tensor of log-softmax scores.
        """
        embedded = self.embedding(x)              # [N, L, E]
        features = embedded.permute(0, 2, 1)      # [N, E, L] as Conv1d expects
        pooled = [
            self._conv_max(conv, features)        # each [N, hidden_size]
            for conv in (self.conv1, self.conv2, self.conv3)
        ]
        merged = torch.cat(pooled, dim=-1)        # [N, 3 * hidden_size]
        activated = self.relu(merged)
        regularized = self.dropout(activated)
        logits = self.fc(regularized)             # [N, num_classes]
        return torch.log_softmax(logits, dim=-1)

虽然TextCNN很经典,不过在BERT面前还是不够打的,即便是对比2层128这种迷你版本的,依然没啥优势。