Source code for beta_rec.datasets.dataset_base

import gzip
import os
import shutil

import pandas as pd
from py7zr import unpack_7zarchive
from tabulate import tabulate

from ..datasets.data_split import (
    filter_user_item,
    filter_user_item_order,
    generate_parameterized_path,
    load_split_data,
    split_data,
)
from ..utils.common_util import (
    get_dataframe_from_npz,
    save_dataframe_as_npz,
    timeit,
    un_zip,
)
from ..utils.constants import DEFAULT_ORDER_COL, DEFAULT_TIMESTAMP_COL
from ..utils.download import download_file, get_format
from ..utils.onedrive import OneDrive

default_root_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)

# register 7z unpack
shutil.register_unpack_format("7zip", [".7z"], unpack_7zarchive)


[docs]class DatasetBase(object): """Base class for processing raw dataset into interactions, making and loading data splits. This is an beta dataset which can derive to other dataset. Several directory that store the dataset file would be created in the initial process. Attributes: dataset_name: the dataset name. min_u_c: filter the items that were purchased by less than min_u_c users. (default: :obj:`0`) min_i_c: filter the users that have purchased by less than min_i_c items. (default: :obj:`3`) min_o_c: filter the users that have purchased by less than min_o_c orders. (default: :obj:`0`) url: the url of raw files. manual_download_url: the url that users use to download raw files manually """ def __init__( self, dataset_name, min_u_c=0, min_i_c=3, min_o_c=0, url=None, root_dir=None, manual_download_url=None, processed_leave_one_out_url="", processed_leave_one_basket_url="", processed_random_split_url="", processed_random_basket_split_url="", processed_temporal_split_url="", processed_temporal_basket_split_url="", tips=None, ): """Init DatasetBase Class.""" self.url = url self.manual_download_url = manual_download_url if manual_download_url else url self.processed_leave_one_out_url = processed_leave_one_out_url self.processed_leave_one_basket_url = processed_leave_one_basket_url self.processed_random_split_url = processed_random_split_url self.processed_random_basket_split_url = processed_random_basket_split_url self.processed_temporal_split_url = processed_temporal_split_url self.processed_temporal_basket_split_url = processed_temporal_basket_split_url self.min_u_c = min_u_c self.min_i_c = min_i_c self.min_o_c = min_o_c self.dataset_name = dataset_name # compatible method for the previous version self.save_dataframe_as_npz = save_dataframe_as_npz # create the root datasets directory if not root_dir: root_dir = default_root_dir self.dataset_dir = os.path.join(root_dir, "datasets") if not os.path.exists(self.dataset_dir): os.mkdir(self.dataset_dir) # create the dataset directory self.dataset_dir = os.path.join(self.dataset_dir, dataset_name) if not os.path.exists(self.dataset_dir): os.mkdir(self.dataset_dir) # create the raw directory self.raw_path = os.path.join(self.dataset_dir, "raw") if not os.path.exists(self.raw_path): os.mkdir(self.raw_path) self.processed_path = os.path.join(self.dataset_dir, "processed") if not os.path.exists(self.processed_path): os.mkdir(self.processed_path) if tips is None: tips = ( f"please download the dataset by your self via {self.manual_download_url}, rename to " + f"{self.dataset_name} and put it into {self.raw_path} after decompression " ) self.tips = tips
[docs] @timeit def download(self): """Download the raw dataset. Download the dataset with the given url and unpack the file. """ if not self.url: raise RuntimeError(self.tips) download_file_name = os.path.join( self.raw_path, os.path.splitext(os.path.basename(self.url))[0] ) file_format = self.url.split(".")[-1] if "amazon" in self.url: raw_file_path = os.path.join( self.raw_path, f"{self.dataset_name}.json.{file_format}" ) else: raw_file_path = os.path.join( self.raw_path, f"{self.dataset_name}.{file_format}" ) if "1drv.ms" in self.url: file_format = "zip" raw_file_path = os.path.join( self.raw_path, f"{self.dataset_name}.{file_format}" ) if not os.path.exists(raw_file_path): print(f"download_file: url: {self.url}, raw_file_path: {raw_file_path}") download_file(self.url, raw_file_path) if "amazon" in raw_file_path: # amazon dataset do not unzip print("amazon dataset do not decompress") return elif file_format == "gz": file_name = raw_file_path.replace(".gz", "") with gzip.open(raw_file_path, "rb") as fin: with open(file_name, "wb") as fout: shutil.copyfileobj(fin, fout) else: shutil.unpack_archive( raw_file_path, self.raw_path, format=get_format(file_format) ) if not os.path.exists(download_file_name): return elif os.path.isdir(download_file_name): os.rename( download_file_name, os.path.join(self.raw_path, self.dataset_name) ) else: os.rename( download_file_name, os.path.join( self.raw_path, f'{self.dataset_name}.{download_file_name.split(".")[-1]}', ), )
[docs] @timeit def preprocess(self): """Preprocess the raw file. A virtual function that needs to be implement in the derived class. Preprocess the file downloaded via the url, convert it to a dataframe consist of the user-item interaction and save in the processed directory. """ raise RuntimeError("please implement this function!")
[docs] def load_interaction(self): """Load the user-item interaction. And filter users, items or orders. Returns: DataFrame: Loaded interactions after filtering Load the interaction from the processed file(Need to preprocess the raw file before loading) """ processed_file_path = os.path.join( self.processed_path, f"{self.dataset_name}_interaction.npz" ) if not os.path.exists(os.path.join(processed_file_path)): try: self.preprocess() except FileNotFoundError: print("origin file is broken, re-download it") raw_file_path = os.path.join(self.raw_path, f"{self.dataset_name}.zip") os.remove(raw_file_path) self.download() finally: self.preprocess() data = get_dataframe_from_npz(processed_file_path) print("-" * 80) print("Raw interaction statistics") print( tabulate( data.agg(["count", "nunique"]), headers=data.columns, tablefmt="psql", disable_numparse=True, ) ) print("-" * 80) if self.min_o_c > 0: data = filter_user_item_order( data, min_u_c=self.min_u_c, min_i_c=self.min_i_c, min_o_c=self.min_o_c ) elif self.min_u_c > 0 or self.min_i_c > 0: data = filter_user_item(data, min_u_c=self.min_u_c, min_i_c=self.min_i_c) print("-" * 80) print( "Interaction statistics after filtering " + f"-- min_u_c:{self.min_u_c}, min_i_c:{self.min_i_c}, min_o_c:{self.min_o_c}." ) print( tabulate( data.agg(["count", "nunique"]), headers=data.columns, tablefmt="psql", disable_numparse=True, ) ) print("-" * 80) return data
[docs] @timeit def make_leave_one_out(self, data=None, random=False, n_negative=100, n_test=10): """Generate split data with leave_one_out. Generate split data with leave_one_out method. Args: data (DataFrame): DataFrame to be split. - Default is None. It will load the raw interaction, with a default filter ''' filter_user_item(data, min_u_c=0, min_i_c=3) ''' - Users can specify their filtered data by using filter methods in data_split.py random: bool. Whether randomly leave one item as testing. n_negative: Number of negative samples for testing and validation data. n_test: int. Default 10. The number of testing and validation copies. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ if data is None: data = self.load_interaction() if not isinstance(data, pd.DataFrame): raise RuntimeError("data is not a type of DataFrame") if DEFAULT_TIMESTAMP_COL not in data.columns: random = True result = split_data( data, split_type="leave_one_out", test_rate=0, random=random, n_negative=n_negative, save_dir=self.processed_path, n_test=n_test, ) return result
[docs] @timeit def make_leave_one_basket(self, data=None, random=False, n_negative=100, n_test=10): """Generate split data with leave_one_basket. Generate split data with leave_one_basket method. Args: data (DataFrame): DataFrame to be split. random: bool. Whether randomly leave one basket as testing. n_negative: Number of negative samples for testing and validation data. n_test: int. Default 10. The number of testing and validation copies. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ if data is None: data = self.load_interaction() if not isinstance(data, pd.DataFrame): raise RuntimeError("data is not a type of DataFrame") if DEFAULT_TIMESTAMP_COL not in data.columns: raise RuntimeError("This dataset doesn't have an TIMESTAMP_COL") if DEFAULT_ORDER_COL not in data.columns: raise RuntimeError("This dataset doesn't have an ORDER_COL") result = split_data( data, split_type="leave_one_basket", test_rate=0, random=random, n_negative=n_negative, save_dir=self.processed_path, n_test=n_test, ) return result
[docs] @timeit def make_random_split( self, data=None, test_rate=0.1, n_negative=100, by_user=False, n_test=10 ): """Generate split data with random_split. Generate split data with random_split method Args: data (DataFrame): DataFrame to be split. - Default is None. It will load the raw interaction, with a default filter ``` data = filter_user_item(data, min_u_c=3, min_i_c=3) ``` - Users can specify their filtered data by using filter methods in data_split.py test_rate: percentage of the test data. Note that percentage of the validation data will be the same as testing. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - True: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ if data is None: data = self.load_interaction() data = filter_user_item(data, min_u_c=3, min_i_c=3) if not isinstance(data, pd.DataFrame): raise RuntimeError("data is not a type of DataFrame") result = split_data( data, split_type="random", test_rate=test_rate, n_negative=n_negative, save_dir=self.processed_path, by_user=by_user, n_test=n_test, ) return result
[docs] @timeit def make_random_basket_split( self, data=None, test_rate=0.1, n_negative=100, by_user=False, n_test=10 ): """Generate split data with random_basket_split. Generate split data with random_basket_split method. Args: data (DataFrame): DataFrame to be split. test_rate: percentage of the test data. Note that percentage of the validation data will be the same as testing. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - True: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ if data is None: data = self.load_interaction() if not isinstance(data, pd.DataFrame): raise RuntimeError("data is not a type of DataFrame") if DEFAULT_ORDER_COL not in data.columns: raise RuntimeError("This dataset doesn't have an ORDER_COL") result = split_data( data, split_type="random_basket", test_rate=test_rate, n_negative=n_negative, save_dir=self.processed_path, by_user=by_user, n_test=n_test, ) return result
[docs] @timeit def make_temporal_split( self, data=None, test_rate=0.1, n_negative=100, by_user=False, n_test=10 ): """Generate split data with temporal_split. Generate split data with temporal_split method. Args: data (DataFrame): DataFrame to be split. - Default is None. It will load the raw interaction, with a default filter ``` data = filter_user_item(data, min_u_c=3, min_i_c=3) ``` - Users can specify their filtered data by using filter methods in data_split.py test_rate: percentage of the test data. Note that percentage of the validation data will be the same as testing. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - True: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ if data is None: data = self.load_interaction() data = filter_user_item(data, min_u_c=3, min_i_c=3) if not isinstance(data, pd.DataFrame): raise RuntimeError("data is not a type of DataFrame") if DEFAULT_TIMESTAMP_COL not in data.columns: raise RuntimeError("This dataset doesn't have an TIMESTAMP_COL") result = split_data( data, split_type="temporal", test_rate=test_rate, n_negative=n_negative, save_dir=self.processed_path, by_user=by_user, n_test=n_test, ) return result
[docs] @timeit def make_temporal_basket_split( self, data=None, test_rate=0.1, n_negative=100, by_user=False, n_test=10 ): """Generate split data with temporal_basket_split. Generate split data with temporal_basket_split method. Args: data (DataFrame): DataFrame to be split. - Default is None. It will load the raw interaction, with a default filter - Users can specify their filtered data by using filter methods in data_split.py test_rate: percentage of the test data. Note that percentage of the validation data will be the same as testing. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - True: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ if data is None: data = self.load_interaction() if not isinstance(data, pd.DataFrame): raise RuntimeError("data is not a type of DataFrame") if DEFAULT_TIMESTAMP_COL not in data.columns: raise RuntimeError("This dataset doesn't have an TIMESTAMP_COL") if DEFAULT_ORDER_COL not in data.columns: raise RuntimeError("This dataset doesn't have an ORDER_COL") result = split_data( data, split_type="temporal_basket", test_rate=test_rate, n_negative=n_negative, save_dir=self.processed_path, by_user=by_user, n_test=n_test, ) return result
[docs] def load_leave_one_out( self, random=False, n_negative=100, n_test=10, download=False, force_redo=False ): """Load split data generated by leave_out_out without random select. Load split data generated by leave_out_out without random select from Onedrive. Args: random (bool): . Whether randomly leave one item as testing. n_negative (int): Number of negative samples for testing and validation data. download (bool): Whether download the split produced by the Beta-rec team (With random seed:2020). force_redo (bool): Whether force to re-split the dataset. n_test (int): Default 10. The number of testing and validation copies. If n_test==0, will load the original (no negative items) valid and test datasets. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ processed_leave_one_out_path = os.path.join( self.processed_path, "leave_one_out" ) if not os.path.exists(processed_leave_one_out_path): os.mkdir(processed_leave_one_out_path) parameterized_path = generate_parameterized_path( test_rate=0, random=random, n_negative=n_negative, by_user=False ) download_path = processed_leave_one_out_path processed_leave_one_out_path = os.path.join( processed_leave_one_out_path, parameterized_path ) if force_redo: self.make_leave_one_out(random=random, n_negative=n_negative, n_test=n_test) elif not os.path.exists(processed_leave_one_out_path): if download and random is False and n_negative == 100: # default parameters, can be downloaded from Onedrive folder = OneDrive( url=self.processed_leave_one_out_url, path=download_path ) folder.download() un_zip(processed_leave_one_out_path + ".zip", download_path) else: # make self.make_leave_one_out( random=random, n_negative=n_negative, n_test=n_test ) # load data from local storage return load_split_data(processed_leave_one_out_path, n_test=n_test)
[docs] def load_leave_one_basket( self, random=False, n_negative=100, n_test=10, download=False, force_redo=False ): """Load split date generated by leave_one_basket without random select. Load split data generated by leave_one_basket without random select from Onedrive. Args: random: bool. Whether randomly leave one basket as testing. n_negative: Number of negative samples for testing and validation data. download (bool): Whether download the split produced by the Beta-rec team (With random seed:2020). force_redo (bool): Whether force to re-split the dataset. n_test: int. Default 10. The number of testing and validation copies. If n_test==0, will load the original (no negative items) valid and test datasets. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ processed_leave_one_basket_path = os.path.join( self.processed_path, "leave_one_basket" ) if not os.path.exists(processed_leave_one_basket_path): os.mkdir(processed_leave_one_basket_path) parameterized_path = generate_parameterized_path( test_rate=0, random=random, n_negative=n_negative, by_user=False ) download_path = processed_leave_one_basket_path processed_leave_one_basket_path = os.path.join( processed_leave_one_basket_path, parameterized_path ) if force_redo: self.make_leave_one_basket( random=random, n_negative=n_negative, n_test=n_test ) elif not os.path.exists(processed_leave_one_basket_path): if download and random is False and n_negative == 100: # default parameters, can be downloaded from Onedrive folder = OneDrive( url=self.processed_leave_one_basket_url, path=download_path ) folder.download() un_zip(processed_leave_one_basket_path + ".zip", download_path) else: # make self.make_leave_one_basket( random=random, n_negative=n_negative, n_test=n_test ) # load data from local storage return load_split_data(processed_leave_one_basket_path, n_test=n_test)
[docs] def load_random_split( self, test_rate=0.1, random=False, n_negative=100, by_user=False, n_test=10, download=False, force_redo=False, ): """Load split date generated by random_split. Load split data generated by random_split from Onedrive, with test_rate = 0.1 and by_user = False. Args: test_rate: percentage of the test data. Note that percentage of the validation data will be the same as test data. random: bool. Whether randomly leave one basket as testing. download (bool): Whether download the split produced by the Beta-rec team (With random seed:2020). force_redo (bool): Whether force to re-split the dataset. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - Ture: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. If n_test==0, will load the original (no negative items) valid and test datasets. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ processed_random_split_path = os.path.join(self.processed_path, "random") if not os.path.exists(processed_random_split_path): os.mkdir(processed_random_split_path) parameterized_path = generate_parameterized_path( test_rate=test_rate, random=random, n_negative=n_negative, by_user=by_user ) download_path = processed_random_split_path processed_random_split_path = os.path.join( processed_random_split_path, parameterized_path ) if force_redo: self.make_random_split( test_rate=test_rate, random=random, n_negative=n_negative, by_user=by_user, n_test=n_test, ) elif not os.path.exists(processed_random_split_path): if ( download and test_rate == 0.1 and random is False and n_negative == 100 and by_user is False ): # default parameters, can be downloaded from Onedrive folder = OneDrive( url=self.processed_random_split_url, path=download_path ) folder.download() un_zip(processed_random_split_path + ".zip", download_path) else: # make self.make_random_split( test_rate=test_rate, random=random, n_negative=n_negative, by_user=by_user, n_test=n_test, ) # load data from local storage return load_split_data(processed_random_split_path, n_test=n_test)
[docs] def load_random_basket_split( self, test_rate=0.1, random=False, n_negative=100, by_user=False, n_test=10, download=False, force_redo=False, ): """Load split date generated by random_basket_split. Load split data generated by random_basket_split from Onedrive, with test_rate = 0.1 and by_user = False. Args: test_rate: percentage of the test data. Note that percentage of the validation data will be the same as test data. random: bool. Whether randomly leave one basket as testing. download (bool): Whether download the split produced by the Beta-rec team (With random seed:2020). force_redo (bool): Whether force to re-split the dataset. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - True: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. If n_test==0, will load the original (no negative items) valid and test datasets. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ processed_random_basket_split_path = os.path.join( self.processed_path, "random_basket" ) if not os.path.exists(processed_random_basket_split_path): os.mkdir(processed_random_basket_split_path) parameterized_path = generate_parameterized_path( test_rate=test_rate, random=random, n_negative=n_negative, by_user=by_user ) download_path = processed_random_basket_split_path processed_random_basket_split_path = os.path.join( processed_random_basket_split_path, parameterized_path ) if force_redo: self.make_random_basket_split( test_rate=test_rate, random=random, n_negative=n_negative, by_user=by_user, n_test=n_test, ) elif not os.path.exists(processed_random_basket_split_path): if ( download and test_rate == 0.1 and random is False and n_negative == 100 and by_user is False ): # default parameters, can be downloaded from Onedrive folder = OneDrive( url=self.processed_random_basket_split_url, path=download_path ) folder.download() un_zip(processed_random_basket_split_path + ".zip", download_path) else: # make self.make_random_basket_split( test_rate=test_rate, random=random, n_negative=n_negative, by_user=by_user, n_test=n_test, ) # load data from local storage return load_split_data(processed_random_basket_split_path, n_test=n_test)
[docs] def load_temporal_split( self, test_rate=0.1, n_negative=100, by_user=False, n_test=10, download=False, force_redo=False, ): """Load split date generated by temporal_split. Load split data generated by temporal_split from Onedrive, with test_rate = 0.1 and by_user = False. Args: test_rate: percentage of the test data. Note that percentage of the validation data will be the same as test data. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - True: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. If n_test==0, will load the original (no negative items) valid and test datasets. download (bool): Whether download the split produced by the Beta-rec team (With random seed:2020). force_redo (bool): Whether force to re-split the dataset. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ processed_temporal_split_path = os.path.join(self.processed_path, "temporal") if not os.path.exists(processed_temporal_split_path): os.mkdir(processed_temporal_split_path) parameterized_path = generate_parameterized_path( test_rate=test_rate, random=False, n_negative=n_negative, by_user=by_user ) download_path = processed_temporal_split_path processed_temporal_split_path = os.path.join( processed_temporal_split_path, parameterized_path ) if force_redo: self.make_temporal_split( test_rate=test_rate, n_negative=n_negative, by_user=by_user, n_test=n_test, ) elif not os.path.exists(processed_temporal_split_path): if download and test_rate == 0.1 and n_negative == 100 and by_user is False: # default parameters, can be downloaded from Onedrive folder = OneDrive( url=self.processed_temporal_split_url, path=download_path ) folder.download() un_zip(processed_temporal_split_path + ".zip", download_path) else: # make self.make_temporal_split( test_rate=test_rate, n_negative=n_negative, by_user=by_user, n_test=n_test, ) # load data from local storage return load_split_data(processed_temporal_split_path, n_test=n_test)
[docs] def load_temporal_basket_split( self, test_rate=0.1, n_negative=100, by_user=False, n_test=10, download=False, force_redo=False, ): """Load split date generated by temporal_basket_split. Load split data generated by temporal_basket_split from Onedrive, with test_rate = 0.1 and by_user = False. Args: test_rate: percentage of the test data. Note that percentage of the validation data will be the same as test data. n_negative: Number of negative samples for testing and validation data. by_user: bool. Default False. - True: user-based split, - False: global split, n_test: int. Default 10. The number of testing and validation copies. If n_test==0, will load the original (no negative items) valid and test datasets. download (bool): Whether download the split produced by the Beta-rec team (With random seed:2020). force_redo (bool): Whether force to re-split the dataset. Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ processed_temporal_basket_split_path = os.path.join( self.processed_path, "temporal_basket" ) if not os.path.exists(processed_temporal_basket_split_path): os.mkdir(processed_temporal_basket_split_path) parameterized_path = generate_parameterized_path( test_rate=test_rate, random=False, n_negative=n_negative, by_user=by_user ) download_path = processed_temporal_basket_split_path processed_temporal_basket_split_path = os.path.join( processed_temporal_basket_split_path, parameterized_path ) if force_redo: self.make_temporal_basket_split( test_rate=test_rate, n_negative=n_negative, by_user=by_user, n_test=n_test, ) elif not os.path.exists(processed_temporal_basket_split_path): if download and test_rate == 0.1 and n_negative == 100 and by_user is False: # default parameters, can be downloaded from Onedrive folder = OneDrive( url=self.processed_temporal_basket_split_url, path=download_path ) folder.download() un_zip(processed_temporal_basket_split_path + ".zip", download_path) else: # make self.make_temporal_basket_split( test_rate=test_rate, n_negative=n_negative, by_user=by_user, n_test=n_test, ) # load data from local storage return load_split_data(processed_temporal_basket_split_path, n_test=n_test)
[docs] def load_split(self, config): """Load split data by config dict. Args: config (dict): config (dict): Dictionary of configuration Returns: train_data (DataFrame): Interaction for training. valid_data list(DataFrame): List of interactions for validation test_data list(DataFrame): List of interactions for testing """ data_split_str = config["data_split"] split_paras = {} split_paras["test_rate"] = config["test_rate"] if "test_rate" in config else 0.1 split_paras["random"] = config["random"] if "random" in config else False split_paras["download"] = config["download"] if "download" in config else False split_paras["n_negative"] = ( config["n_negative"] if "n_negative" in config else 100 ) split_paras["by_user"] = config["by_user"] if "by_user" in config else False split_paras["n_test"] = config["n_test"] if "n_test" in config else 10 if split_paras["n_negative"] < 0 and split_paras["n_test"] > 1: # n_negative < 0, validate and testing sets of splits will contain all the negative items. # There will be only one validata and one testing sets. split_paras["n_test"] = 1 data_split_mapping = { "leave_one_out": self.load_leave_one_out, "leave_one_basket": self.load_leave_one_basket, "random_split": self.load_random_split, "random_basket_split": self.load_random_basket_split, "temporal": self.load_temporal_split, "temporal_basket": self.load_temporal_basket_split, } split_para_mapping = { "leave_one_out": ["random", "download", "n_negative", "n_test"], "leave_one_basket": ["random", "download", "n_negative", "n_test"], "random_split": [ "test_rate", "download", "by_user", "n_negative", "n_test", ], "random_basket_split": [ "test_rate", "download", "by_user", "n_negative", "n_test", ], "temporal": ["test_rate", "by_user", "download", "n_negative", "n_test"], "temporal_basket": [ "test_rate", "download", "by_user", "n_negative", "n_test", ], } para_dic = { split_para_key: split_paras[split_para_key] if split_para_key in split_paras else None for split_para_key in split_para_mapping[data_split_str] } train_data, valid_data, test_data = data_split_mapping[data_split_str]( **para_dic ) return train_data, valid_data, test_data