TextCNN.py示例
Published on Aug. 22, 2023, 12:09 p.m.
import torch
import torch.nn as nn
class TextCNN(nn.Module):
    """Convolutional text classifier.

    Pipeline: token embedding -> parallel Conv1d branches (one per kernel
    width) -> max-over-time pooling -> concat -> ReLU -> dropout -> linear
    -> log_softmax.

    Args:
        num_classes: Number of output classes.
        num_embeddings: Vocabulary size of the token embedding table.
        embedding_dim: Size of each token embedding; used as the Conv1d
            input channel count.
        hidden_size: Output channels of every convolution branch.
        dropout_prob: Dropout probability applied before the classifier.
        kernel_sizes: Kernel widths of the convolution branches, default
            ``(2, 3, 5)``. Branches are registered as ``conv1``, ``conv2``,
            ... so the default configuration keeps the same attribute names
            and ``state_dict`` keys as the original three-branch model.

    Raises:
        ValueError: If ``kernel_sizes`` is empty.
    """

    def __init__(self, num_classes, num_embeddings, embedding_dim, hidden_size,
                 dropout_prob=0.5, kernel_sizes=(2, 3, 5)):
        super(TextCNN, self).__init__()
        kernel_sizes = tuple(kernel_sizes)
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel width")
        self.num_classes = num_classes
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.kernel_sizes = kernel_sizes
        # Modules are created in the same order as the original model
        # (embedding, convs, relu, dropout, fc), so seeded initialisation
        # is reproducible.
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        # Register branches individually as conv1, conv2, ... instead of
        # using an nn.ModuleList, so existing checkpoints keep loading.
        for index, kernel_size in enumerate(kernel_sizes, start=1):
            self.add_module(
                f"conv{index}",
                nn.Conv1d(in_channels=embedding_dim, out_channels=hidden_size,
                          kernel_size=kernel_size),
            )
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout_prob)
        # Each branch contributes one pooled vector of size hidden_size.
        self.fc = nn.Linear(in_features=len(kernel_sizes) * hidden_size, out_features=num_classes)

    def _branches(self):
        """Yield the convolution branches in registration order (conv1, conv2, ...)."""
        for index in range(1, len(self.kernel_sizes) + 1):
            yield getattr(self, f"conv{index}")

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: Token ids shaped [N, L] (batch, sequence length).
                Sequences shorter than the widest kernel are right-padded
                with zero vectors in embedding space, so the convolutions
                do not fail on them.

        Returns:
            Tensor shaped [N, num_classes] of log-probabilities. Pair it with
            ``nn.NLLLoss``; ``nn.CrossEntropyLoss`` would apply log_softmax
            a second time.
        """
        embedded = self.embedding(x)              # [N, L, E]
        features = embedded.permute(0, 2, 1)      # [N, E, L]: Conv1d wants channels first
        shortfall = max(self.kernel_sizes) - features.size(-1)
        if shortfall > 0:
            # Without this, Conv1d raises when L is smaller than a kernel.
            features = nn.functional.pad(features, (0, shortfall))
        # Max-over-time pooling reduces each branch to [N, hidden_size].
        pooled = [torch.max(conv(features), dim=-1)[0] for conv in self._branches()]
        # ReLU after the max equals ReLU before it (ReLU is monotonic),
        # so activating only the pooled features is cheaper and equivalent.
        hidden = self.dropout(self.relu(torch.cat(pooled, dim=-1)))
        return torch.log_softmax(self.fc(hidden), dim=-1)
虽然TextCNN很经典，不过在BERT面前还是不够打的：即便是对比2层、隐藏维度128这种迷你版本的BERT，TextCNN依然没啥优势。