TransE 模型的 PyTorch 实现
Published on Aug. 22, 2023, 12:11 p.m.
版本1
出自 https://github.com/mklimasz/TransE-PyTorch/blob/master/model.py
import numpy as np
import torch
import torch.nn as nn
class TransE(nn.Module):
    """TransE knowledge-graph embedding model.

    A triplet (h, r, t) is scored by the dissimilarity ||e_h + e_r - e_t||_p;
    lower means more plausible. Each embedding table has one extra trailing
    row (index ``entity_count`` / ``relation_count``) registered as
    ``padding_idx``; it is kept at zero and excluded from normalisation
    (presumably used for out-of-vocabulary ids -- confirm against the loader).

    :param entity_count: number of distinct entities
    :param relation_count: number of distinct relations
    :param device: stored as ``self.device``; the loss target follows the
        device of the distances passed to :meth:`loss`
    :param norm: p of the distance norm used in scoring (1 or 2)
    :param dim: embedding dimension
    :param margin: margin of the pairwise ranking loss
    """

    def __init__(self, entity_count, relation_count, device, norm=1, dim=100, margin=1.0):
        super(TransE, self).__init__()
        self.entity_count = entity_count
        self.relation_count = relation_count
        self.device = device
        self.norm = norm
        self.dim = dim
        self.entities_emb = self._init_entity_emb()
        self.relations_emb = self._init_relation_emb()
        # reduction='none' keeps one loss value per (positive, negative) pair.
        self.criterion = nn.MarginRankingLoss(margin=margin, reduction='none')

    def _init_embedding(self, count):
        """Build a (count + 1) x dim table, uniform in +-6/sqrt(dim), with a zero padding row."""
        embedding = nn.Embedding(num_embeddings=count + 1,
                                 embedding_dim=self.dim,
                                 padding_idx=count)
        uniform_range = 6 / np.sqrt(self.dim)
        embedding.weight.data.uniform_(-uniform_range, uniform_range)
        # uniform_ also overwrote the padding row that nn.Embedding had zeroed;
        # restore it so the padding/OOV vector really is zero.
        embedding.weight.data[count].zero_()
        return embedding

    def _init_entity_emb(self):
        """Entity table; entities are re-normalised at the start of every forward pass."""
        return self._init_embedding(self.entity_count)

    # Backward-compatible alias for the original (misspelled) method name.
    _init_enitity_emb = _init_entity_emb

    def _init_relation_emb(self):
        """Relation table, L2-normalised once at construction (l <- l / ||l||)."""
        relations_emb = self._init_embedding(self.relation_count)
        # Skip the last (padding) row: its norm is zero and would produce nan.
        real_rows = relations_emb.weight.data[:-1, :]
        real_rows.div_(real_rows.norm(p=2, dim=1, keepdim=True))
        return relations_emb

    def forward(self, positive_triplets: torch.LongTensor, negative_triplets: torch.LongTensor):
        """Return model losses based on the input.

        :param positive_triplets: triplets of positives in Bx3 shape (B - batch, 3 - head, relation and tail)
        :param negative_triplets: triplets of negatives in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: tuple of the per-pair loss, positive distances, negative distances
        """
        # Project entities back onto the unit L2 sphere before each step.
        # The padding row (last) is skipped because its norm is zero.
        entity_rows = self.entities_emb.weight.data[:-1, :]
        entity_rows.div_(entity_rows.norm(p=2, dim=1, keepdim=True))
        positive_distances = self._distance(positive_triplets)
        negative_distances = self._distance(negative_triplets)
        return self.loss(positive_distances, negative_distances), positive_distances, negative_distances

    def predict(self, triplets: torch.LongTensor):
        """Calculate the dissimilarity score for the given triplets.

        :param triplets: triplets in Bx3 shape (B - batch, 3 - head, relation and tail)
        :return: dissimilarity score for given triplets (lower is more plausible)
        """
        return self._distance(triplets)

    def loss(self, positive_distances, negative_distances):
        """Per-pair margin ranking loss: max(0, positive - negative + margin)."""
        # target = -1 means the first input (positive distance) should rank lower.
        # Match the inputs' dtype/device instead of assuming self.device.
        target = torch.tensor([-1], dtype=positive_distances.dtype, device=positive_distances.device)
        return self.criterion(positive_distances, negative_distances, target)

    def _distance(self, triplets):
        """Triplets should have shape Bx3 where dim 3 are head id, relation id, tail id."""
        assert triplets.size()[1] == 3
        heads, relations, tails = triplets[:, 0], triplets[:, 1], triplets[:, 2]
        translated = self.entities_emb(heads) + self.relations_emb(relations)
        return (translated - self.entities_emb(tails)).norm(p=self.norm, dim=1)
版本2
出自 https://github.com/toooooodo/pytorch-TransE/blob/master/model.py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from prepare_data import TrainSet, TestSet
import math
class TranE(nn.Module):
    """TransE model (version 2): score(h, r, t) = ||h + r - t||_{d_norm}, lower is better."""

    def __init__(self, entity_num, relation_num, device, dim=50, d_norm=2, gamma=1):
        """
        :param entity_num: number of entities
        :param relation_num: number of relations
        :param device: device the margin tensor is created on
        :param dim: embedding dim
        :param d_norm: measure d(h+l, t), either L1-norm or L2-norm
        :param gamma: margin hyperparameter
        """
        super(TranE, self).__init__()
        self.dim = dim
        self.d_norm = d_norm
        self.device = device
        # Registered as a non-persistent buffer so model.to(...) moves it with the
        # parameters, without adding a key to state_dict (old checkpoints still load).
        self.register_buffer('gamma', torch.FloatTensor([gamma]).to(self.device), persistent=False)
        self.entity_num = entity_num
        self.relation_num = relation_num
        bound = 6 / math.sqrt(self.dim)
        self.entity_embedding = nn.Embedding.from_pretrained(
            torch.empty(entity_num, self.dim).uniform_(-bound, bound), freeze=False)
        self.relation_embedding = nn.Embedding.from_pretrained(
            torch.empty(relation_num, self.dim).uniform_(-bound, bound), freeze=False)
        # l <= l / ||l||  (relations are L2-normalised once, at construction)
        relation_norm = torch.norm(self.relation_embedding.weight.data, dim=1, keepdim=True)
        self.relation_embedding.weight.data = self.relation_embedding.weight.data / relation_norm

    def _residual(self, head, relation, tail):
        """Return h + r - t for a batch of index tensors: [batch_size, embed_dim]."""
        return self.entity_embedding(head) + self.relation_embedding(relation) - self.entity_embedding(tail)

    def forward(self, pos_head, pos_relation, pos_tail, neg_head, neg_relation, neg_tail):
        """
        :param pos_head: [batch_size]
        :param pos_relation: [batch_size]
        :param pos_tail: [batch_size]
        :param neg_head: [batch_size]
        :param neg_relation: [batch_size]
        :param neg_tail: [batch_size]
        :return: summed triples loss (scalar)
        """
        pos_dis = self._residual(pos_head, pos_relation, pos_tail)
        neg_dis = self._residual(neg_head, neg_relation, neg_tail)
        # The loss already tracks gradients through the (unfrozen) embeddings; forcing
        # requires_grad_() would only hide a detached graph (e.g. under torch.no_grad()).
        return self.calculate_loss(pos_dis, neg_dis)

    def calculate_loss(self, pos_dis, neg_dis):
        """
        :param pos_dis: [batch_size, embed_dim]
        :param neg_dis: [batch_size, embed_dim]
        :return: sum over the batch of max(0, gamma + d(pos) - d(neg))
        """
        pos_norm = torch.norm(pos_dis, p=self.d_norm, dim=1)
        neg_norm = torch.norm(neg_dis, p=self.d_norm, dim=1)
        return torch.sum(F.relu(self.gamma + pos_norm - neg_norm))

    def tail_predict(self, head, relation, tail, k=10):
        """
        Tail prediction: count how many true tails appear in the top-k candidates.

        :param head: [batch_size]
        :param relation: [batch_size]
        :param tail: [batch_size]
        :param k: hits@k (clamped to the number of entities)
        :return: number of hits in the batch (int)
        """
        # Evaluation only: no autograd graph is needed for ranking.
        with torch.no_grad():
            # h_and_r: [batch_size, embed_dim]
            h_and_r = self.entity_embedding(head) + self.relation_embedding(relation)
            # distances: [batch_size, entity_num], using the same norm as training.
            distances = torch.norm(h_and_r.unsqueeze(1) - self.entity_embedding.weight.unsqueeze(0),
                                   p=self.d_norm, dim=2)
            # topk fails if k exceeds the number of candidates.
            k = min(k, distances.size(1))
            # indices: [batch_size, k]
            _, indices = torch.topk(distances, k, dim=1, largest=False)
            # tail: [batch_size] => [batch_size, 1] to broadcast against indices
            return torch.sum(torch.eq(indices, tail.view(-1, 1))).item()
if __name__ == '__main__':
    # Smoke test: build both datasets, map the raw test triples onto the
    # training vocabularies, then show the shape of one test batch.
    train_data_set = TrainSet()
    test_data_set = TestSet()
    test_data_set.convert_word_to_index(
        train_data_set.entity_to_index,
        train_data_set.relation_to_index,
        test_data_set.raw_data,
    )
    train_loader, test_loader = (
        DataLoader(dataset, batch_size=32, shuffle=True)
        for dataset in (train_data_set, test_data_set)
    )
    for first_batch in test_loader:
        print(first_batch.shape)
        break