import torch
import torch.nn as nn
from beta_rec.models.torch_engine import ModelEngine
class LightGCN(torch.nn.Module):
    """LightGCN: embedding initialisation, propagation and prediction.

    Implements the simplified graph convolution of LightGCN: layer-wise
    neighbourhood aggregation with no feature transformation and no
    non-linearity, followed by mean-pooling of the per-layer embeddings.
    """

    def __init__(self, config, norm_adj):
        """Initialize LightGCN Class.

        Args:
            config (dict): model configuration; must provide ``n_users``,
                ``n_items``, ``emb_dim``, ``layer_size`` (list of layer
                widths) and ``keep_pro`` (edge-dropout keep probability).
            norm_adj (torch sparse tensor): normalised adjacency matrix of
                the user-item interaction graph, of shape
                (n_users + n_items, n_users + n_items).
        """
        super(LightGCN, self).__init__()
        self.config = config
        self.n_users = config["n_users"]
        self.n_items = config["n_items"]
        self.emb_dim = config["emb_dim"]
        self.layer_size = config["layer_size"]
        # The number of propagation layers is the number of configured sizes
        # (captured BEFORE emb_dim is prepended below).
        self.n_layers = len(self.layer_size)
        self.norm_adj = norm_adj
        self.layer_size = [self.emb_dim] + self.layer_size
        self.user_embedding = nn.Embedding(self.n_users, self.emb_dim)
        self.item_embedding = nn.Embedding(self.n_items, self.emb_dim)
        self.f = nn.Sigmoid()
        self.init_emb()

    def dropout(self, x, keep_prob):
        """Randomly drop non-zero entries of a sparse tensor (edge dropout).

        Args:
            x (torch sparse tensor): coalesced sparse adjacency matrix.
            keep_prob (float): probability of keeping each entry; kept values
                are rescaled by 1 / keep_prob so the expectation is unchanged.

        Returns:
            torch sparse tensor: the dropped-out adjacency matrix.
        """
        size = x.size()
        index = x.indices().t()
        values = x.values()
        # BUGFIX: draw the mask on the same device as the sparse tensor's
        # values — a CPU boolean mask cannot index GPU tensors.
        random_index = torch.rand(len(values), device=values.device) + keep_prob
        # floor(U[0,1) + keep_prob) is 1 with probability keep_prob.
        random_index = random_index.int().bool()
        index = index[random_index]
        values = values[random_index] / keep_prob
        g = torch.sparse.FloatTensor(index.t(), values, size)
        return g

    def init_emb(self):
        """Initialize users' and items' embeddings with Xavier uniform."""
        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)

    def forward(self, norm_adj):
        """Propagate user and item embeddings over the interaction graph.

        Args:
            norm_adj (torch sparse tensor): the norm adjacent matrix of the
                user-item interaction matrix.

        Returns:
            u_g_embeddings (tensor): processed user embeddings, (n_users, emb_dim).
            i_g_embeddings (tensor): processed item embeddings, (n_items, emb_dim).
        """
        all_emb = torch.cat(
            (self.user_embedding.weight, self.item_embedding.weight), dim=0
        )
        embs = [all_emb]
        norm_adj = norm_adj.coalesce()
        # NOTE(review): ``self.device`` is assumed to be attached externally
        # (e.g. by the engine) — nn.Module does not define it; confirm.
        norm_adj = norm_adj.to(self.device)
        # Edge dropout is applied only while training.
        if self.training:
            norm_adj = self.dropout(x=norm_adj, keep_prob=self.config["keep_pro"])
        for _ in range(self.n_layers):
            all_emb = torch.sparse.mm(norm_adj, all_emb)
            embs.append(all_emb)
        # Mean-pool the layer-0..layer-n embeddings.
        embs = torch.stack(embs, dim=1)
        embs = torch.mean(embs, dim=1)
        u_g_embeddings, i_g_embeddings = torch.split(embs, [self.n_users, self.n_items])
        return u_g_embeddings, i_g_embeddings

    def predict(self, users, items):
        """Predict result with the model.

        Args:
            users (int, or list of int): user id.
            items (int, or list of int): item id.

        Return:
            scores (tensor): sigmoid of the dot product of the propagated
                user and item embeddings, one score per (user, item) pair.
        """
        self.eval()
        users_t = torch.tensor(users, dtype=torch.int64, device=self.device)
        items_t = torch.tensor(items, dtype=torch.int64, device=self.device)
        with torch.no_grad():
            ua_embeddings, ia_embeddings = self.forward(self.norm_adj)
            u_g_embeddings = ua_embeddings[users_t]
            i_g_embeddings = ia_embeddings[items_t]
            scores = self.f(torch.mul(u_g_embeddings, i_g_embeddings).sum(dim=1))
        return scores
class LightGCNEngine(ModelEngine):
    """LightGCNEngine Class.

    Wraps a LightGCN model with per-epoch and per-batch BPR training.
    """

    def __init__(self, config):
        """Initialize LightGCNEngine Class.

        Args:
            config (dict): global configuration; ``config["model"]`` must
                provide ``regs`` (regularisation weights) and ``norm_adj``.
        """
        self.config = config
        self.regs = config["model"]["regs"]  # reg is the regularisation
        self.decay = self.regs[0]  # L2 penalty coefficient for the BPR loss
        self.norm_adj = config["model"]["norm_adj"]
        self.model = LightGCN(config["model"], self.norm_adj)
        # NOTE(review): super().__init__ is assumed to set self.device,
        # self.optimizer and self.writer — confirm against ModelEngine.
        super(LightGCNEngine, self).__init__(config)
        self.model.to(self.device)

    def train_single_batch(self, batch_data):
        """Train the model in a single batch.

        Args:
            batch_data (list): batch users, positive items and negative items.

        Return:
            loss (float): batch loss.
        """
        assert hasattr(self, "model"), "Please specify the exact model !"
        self.optimizer.zero_grad()
        norm_adj = self.norm_adj
        ua_embeddings, ia_embeddings = self.model.forward(norm_adj)
        batch_users, pos_items, neg_items = batch_data
        u_g_embeddings = ua_embeddings[batch_users]
        pos_i_g_embeddings = ia_embeddings[pos_items]
        neg_i_g_embeddings = ia_embeddings[neg_items]
        batch_mf_loss, batch_reg_loss = self.loss_comput(
            u_g_embeddings,
            pos_i_g_embeddings,
            neg_i_g_embeddings,
            batch_users,
            pos_items,
            neg_items,
        )
        batch_loss = batch_mf_loss + batch_reg_loss
        batch_loss.backward()
        self.optimizer.step()
        loss = batch_loss.item()
        return loss

    def train_an_epoch(self, train_loader, epoch_id):
        """Train the model in one epoch.

        Args:
            epoch_id (int): the number of epoch.
            train_loader (function): user, pos_items and neg_items generator.
        """
        assert hasattr(self, "model"), "Please specify the exact model !"
        self.model.train()
        total_loss = 0.0
        for batch_data in train_loader:
            loss = self.train_single_batch(batch_data)
            total_loss += loss
        # BUGFIX: report the accumulated epoch loss (the same quantity logged
        # to TensorBoard below), not just the last batch's loss.
        print("[Training Epoch {}], Loss {}".format(epoch_id, total_loss))
        self.writer.add_scalar("model/loss", total_loss, epoch_id)

    def loss_comput(self, usersE, pos_itemsE, neg_itemsE, users, pos_item, neg_item):
        """Calculate BPR loss.

        Args:
            usersE, pos_itemsE, neg_itemsE (tensor): propagated embeddings
                for the batch users / positive items / negative items.
            users, pos_item, neg_item (tensor): the corresponding id tensors,
                used to fetch layer-0 embeddings for L2 regularisation.

        Return:
            (loss, reg_loss): BPR softplus loss and decayed L2 penalty.
        """
        pos_scores = torch.sum(torch.mul(usersE, pos_itemsE), dim=1)
        neg_scores = torch.sum(torch.mul(usersE, neg_itemsE), dim=1)
        # Regularise the raw (layer-0) embeddings, per the LightGCN paper.
        userEmb0 = self.model.user_embedding(users.to(self.device))
        posEmb0 = self.model.item_embedding(pos_item.to(self.device))
        negEmb0 = self.model.item_embedding(neg_item.to(self.device))
        reg_loss = (
            (1 / 2)
            * (
                userEmb0.norm(2).pow(2)
                + posEmb0.norm(2).pow(2)
                + negEmb0.norm(2).pow(2)
            )
            / float(len(users))
        )
        reg_loss = reg_loss * self.decay
        # BPR: softplus(neg - pos) == -log(sigmoid(pos - neg)).
        loss = torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))
        return loss, reg_loss