import torch
import torch.nn as nn
from ..models.torch_engine import ModelEngine
from ..utils.common_util import timeit
class GMF(torch.nn.Module):
    """Generalized Matrix Factorization (GMF) model.

    Scores a (user, item) pair by taking the element-wise product of the user
    and item embeddings, projecting it to a single logit with a linear layer,
    and squashing that logit with a sigmoid.

    Args:
        config (dict): must provide ``"n_users"``, ``"n_items"`` and
            ``"emb_dim"``.
    """

    def __init__(self, config):
        """Initialize GMF Class."""
        super(GMF, self).__init__()
        self.config = config
        self.num_users = config["n_users"]
        self.num_items = config["n_items"]
        self.emb_dim = config["emb_dim"]
        self.embedding_user = torch.nn.Embedding(
            num_embeddings=self.num_users, embedding_dim=self.emb_dim
        )
        self.embedding_item = torch.nn.Embedding(
            num_embeddings=self.num_items, embedding_dim=self.emb_dim
        )
        # Re-initialize the embedding tables with a small std (see init_weight).
        self.init_weight()
        self.affine_output = torch.nn.Linear(in_features=self.emb_dim, out_features=1)
        self.logistic = torch.nn.Sigmoid()

    def forward(self, user_indices, item_indices):
        """Compute predicted interaction probabilities.

        Args:
            user_indices (torch.Tensor): integer user ids.
            item_indices (torch.Tensor): integer item ids, aligned with
                ``user_indices``.

        Returns:
            torch.Tensor: sigmoid scores in (0, 1) with a trailing dimension
            of size 1 (e.g. shape ``(B, 1)`` for 1-D index tensors).
        """
        user_embedding = self.embedding_user(user_indices)
        item_embedding = self.embedding_item(item_indices)
        element_product = torch.mul(user_embedding, item_embedding)
        logits = self.affine_output(element_product)
        rating = self.logistic(logits)
        return rating

    def predict(self, user_indices, item_indices):
        """Predict scores without tracking gradients.

        Args:
            user_indices: sequence of user ids (list, array or tensor).
            item_indices: sequence of item ids, aligned with ``user_indices``.

        Returns:
            torch.Tensor: scores as produced by :meth:`forward`.
        """
        # ``device`` is never assigned in this class; callers may attach it
        # externally. Fall back to the parameters' device so a standalone
        # model works too.
        device = getattr(self, "device", self.embedding_user.weight.device)
        user_indices = torch.as_tensor(user_indices, dtype=torch.long, device=device)
        item_indices = torch.as_tensor(item_indices, dtype=torch.long, device=device)
        with torch.no_grad():
            return self.forward(user_indices, item_indices)

    def init_weight(self):
        """Initialize both embedding tables from a normal distribution (std=0.01)."""
        nn.init.normal_(self.embedding_user.weight, std=0.01)
        # Previously the user table was initialized twice and the item table
        # was left at the default N(0, 1) init.
        nn.init.normal_(self.embedding_item.weight, std=0.01)
class GMFEngine(ModelEngine):
    """Engine for training & evaluating GMF model."""

    def __init__(self, config):
        """Initialize GMFEngine Class.

        Args:
            config (dict): full configuration; ``config["model"]`` is passed
                to ``GMF`` and the whole dict to ``ModelEngine``.
        """
        # Model and loss are created before the base initializer runs.
        # NOTE(review): presumably ModelEngine.__init__ relies on self.model
        # (e.g. to build the optimizer / move it to a device) — confirm.
        self.model = GMF(config["model"])
        self.loss = torch.nn.BCELoss()
        super(GMFEngine, self).__init__(config)

    def train_single_batch(self, users, items, ratings):
        """Train the model in a single batch.

        Args:
            users (torch.Tensor): batch of user ids.
            items (torch.Tensor): batch of item ids, aligned with ``users``.
            ratings (torch.Tensor): float targets for the BCE loss, aligned
                with ``users``.

        Returns:
            float: batch loss.
        """
        assert hasattr(self, "model"), "Please specify the exact model !"
        users, items, ratings = (
            users.to(self.device),
            items.to(self.device),
            ratings.to(self.device),
        )
        self.optimizer.zero_grad()
        ratings_pred = self.model(users, items)
        # Flatten the (B, 1) predictions so they match the 1-D targets.
        loss = self.loss(ratings_pred.view(-1), ratings)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    @timeit
    def train_an_epoch(self, train_loader, epoch_id):
        """Train the model in one epoch.

        Args:
            train_loader (iterable): yields ``(users, items, ratings)`` batches;
                ``users`` must be an int64 tensor.
            epoch_id (int): the number of epoch.
        """
        assert hasattr(self, "model"), "Please specify the exact model !"
        self.model.train()
        total_loss = 0
        for batch in train_loader:
            # Check the dtype rather than the legacy CPU-only LongTensor type,
            # so int64 tensors on any device are accepted.
            assert isinstance(batch[0], torch.Tensor) and batch[0].dtype == torch.long
            user, item, rating = batch[0], batch[1], batch[2]
            rating = rating.float()
            loss = self.train_single_batch(user, item, rating)
            # Reported loss is the sum of per-batch losses over the epoch.
            total_loss += loss
        print("[Training Epoch {}], Loss {}".format(epoch_id, total_loss))
        self.writer.add_scalar("model/loss", total_loss, epoch_id)