Error when running SRGNN #81
Comments
Hello! recbole_gnn rewrites its own main function, and things such as dataset construction and model selection differ somewhat from recbole.
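Hello, and thanks for the reply! I have been running everything through my modified main function, so I would still like to keep using it. I talked to another RecBole user who can run the code below successfully, but for me it fails during training. The modified code and its output are as follows:
=== from model ===
from model import myGCEGNN
from test import GCEGNN
if __name__ == '__main__':
Output: Training Hyper Parameters: Evaluation Hyper Parameters: Dataset Hyper Parameters: Other Hyper Parameters: 14 Mar 14:36 INFO diginetica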
Hello! I ran into the same problem. Have you solved it yet?
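Hello, and thank you for your work! I get an error when running SRGNN. My guess is that the trainer and interaction used in my main function come from recbole rather than from the recbole_gnn framework, but I do not know how to fix this. I would appreciate your help and look forward to your reply. Thank you very much!
Main function: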
from recbole_gnn.config import Config
from recbole_gnn.utils import create_dataset, data_preparation
from recbole.utils import init_logger, init_seed
from recbole_gnn.utils import set_color, get_trainer
from logging import getLogger
from test import SRGNN
if __name__ == '__main__':
    # configurations initialization
    config = Config(
        model=SRGNN,
        dataset='diginetica',
        config_file_list=['config.yaml', 'config_model.yaml'],
    )
    init_seed(config['seed'], config['reproducibility'])
Both config.yaml and config_model.yaml use the parameters provided with the framework.
Output:
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2020
state = INFO
reproducibility = True
data_path = dataset/diginetica
checkpoint_dir = saved
show_progress = True
save_dataset = False
dataset_save_path = None
save_dataloaders = False
dataloaders_save_path = None
log_wandb = False
Training Hyper Parameters:
epochs = 500
train_batch_size = 4096
learner = adam
learning_rate = 0.001
neg_sampling = None
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4
Evaluation Hyper Parameters:
eval_args = {'split': {'LS': 'valid_and_test'}, 'mode': 'full', 'order': 'TO', 'group_by': 'user'}
repeatable = True
metrics = ['MRR', 'Precision']
topk = [10, 20]
valid_metric = MRR@10
valid_metric_bigger = True
eval_batch_size = 2000
metric_decimal_place = 5
Dataset Hyper Parameters:
field_separator =
seq_separator =
USER_ID_FIELD = session_id
ITEM_ID_FIELD = item_id
RATING_FIELD = rating
TIME_FIELD = timestamp
seq_len = None
LABEL_FIELD = label
threshold = None
NEG_PREFIX = neg_
load_col = {'inter': ['session_id', 'item_id', 'timestamp']}
unload_col = None
unused_col = None
additional_feat_suffix = None
rm_dup_inter = None
val_interval = None
filter_inter_by_user_or_item = True
user_inter_num_interval = [5,inf)
item_inter_num_interval = [5,inf)
alias_of_user_id = None
alias_of_item_id = None
alias_of_entity_id = None
alias_of_relation_id = None
preload_weight = None
normalize_field = None
normalize_all = None
ITEM_LIST_LENGTH_FIELD = item_length
LIST_SUFFIX = _list
MAX_ITEM_LIST_LENGTH = 20
POSITION_FIELD = position_id
HEAD_ENTITY_ID_FIELD = head_id
TAIL_ENTITY_ID_FIELD = tail_id
RELATION_ID_FIELD = relation_id
ENTITY_ID_FIELD = entity_id
benchmark_filename = None
Other Hyper Parameters:
wandb_project = recbole
require_pow = False
embedding_size = 64
step = 1
loss_type = CE
MODEL_TYPE = ModelType.SEQUENTIAL
gnn_transform = sess_graph
train_neg_sample_args = {'strategy': 'none'}
MODEL_INPUT_TYPE = InputType.POINTWISE
eval_type = EvaluatorType.RANKING
device = cpu
eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}
06 Mar 13:17 INFO diginetica
The number of users: 72014
Average actions of users: 8.060905669809618
The number of items: 29454
Average actions of items: 19.70902794282416
The number of inters: 580490
The sparsity of the dataset: 99.97263260088765%
Remain Fields: ['session_id', 'item_id', 'timestamp']
06 Mar 13:17 INFO Constructing session graphs.
100%|██████████| 364451/364451 [00:33<00:00, 11034.37it/s]
06 Mar 13:18 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:07<00:00, 9464.61it/s]
06 Mar 13:18 INFO Constructing session graphs.
100%|██████████| 72013/72013 [00:07<00:00, 9047.17it/s]
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO SessionGraph Transform in DataLoader.
06 Mar 13:18 INFO [Training]: train_batch_size = [4096] negative sampling: [{'strategy': 'none'}]
06 Mar 13:18 INFO [Evaluation]: eval_batch_size = [2000] eval_args: [{'split': {'LS': 'valid_and_test'}, 'mode': 'full', 'order': 'TO', 'group_by': 'user'}]
06 Mar 13:18 INFO SRGNN(
(item_embedding): Embedding(29454, 64, padding_idx=0)
(gnncell): SRGNNCell(
(incomming_conv): SRGNNConv()
(outcomming_conv): SRGNNConv()
(lin_ih): Linear(in_features=128, out_features=192, bias=True)
(lin_hh): Linear(in_features=64, out_features=192, bias=True)
)
(linear_one): Linear(in_features=64, out_features=64, bias=True)
(linear_two): Linear(in_features=64, out_features=64, bias=True)
(linear_three): Linear(in_features=64, out_features=1, bias=False)
(linear_transform): Linear(in_features=128, out_features=64, bias=True)
(loss_fct): CrossEntropyLoss()
)
Trainable parameters: 1947264
Train 0: 0%| | 0/89 [00:00<?, ?it/s]
Traceback (most recent call last):
File "E:/ADACONDA/envs/pytorch/pythonproject_test/Next Work/RecBole-GNN-main/main.py", line 41, in <module>
best_valid_score, best_valid_result = trainer.fit(
File "E:\ADACONDA\envs\pytorch\lib\site-packages\recbole\trainer\trainer.py", line 335, in fit
train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
File "E:\ADACONDA\envs\pytorch\lib\site-packages\recbole\trainer\trainer.py", line 181, in _train_epoch
losses = loss_func(interaction)
File "E:\ADACONDA\envs\pytorch\pythonproject_test\Next Work\RecBole-GNN-main\test.py", line 105, in calculate_loss
x = interaction['x']
File "E:\ADACONDA\envs\pytorch\lib\site-packages\recbole\data\interaction.py", line 131, in __getitem__
return self.interaction[index]
KeyError: 'x'
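The last frame of the traceback is a plain field lookup: the batch handed to calculate_loss has no field named 'x'. As a debugging aid, one could check which fields a batch actually carries before training starts. The sketch below is only an illustration under stated assumptions: FakeInteraction and missing_fields are hypothetical names, not part of recbole or recbole_gnn, and the example field values are made up. It only mimics the lookup shown in the traceback.

```python
from typing import Dict, Iterable, List


class FakeInteraction:
    """Hypothetical stand-in for a batch object that looks fields up by name."""

    def __init__(self, fields: Dict[str, list]):
        self.interaction = fields

    def __getitem__(self, index: str):
        # Same shape as the failing frame: a missing key raises KeyError.
        return self.interaction[index]

    def __contains__(self, name: str) -> bool:
        return name in self.interaction


def missing_fields(batch, required: Iterable[str]) -> List[str]:
    """Return the required field names that the batch does not contain."""
    return [name for name in required if name not in batch]


# A batch holding only sequence fields (illustrative values).
plain_batch = FakeInteraction({"item_id_list": [[3, 7, 3]], "item_length": [3]})

# The same batch with a graph field 'x' attached (illustrative values).
graph_batch = FakeInteraction(
    {"item_id_list": [[3, 7, 3]], "item_length": [3], "x": [3, 7]}
)

print(missing_fields(plain_batch, ["x"]))
print(missing_fields(graph_batch, ["x"]))
```

Running the same kind of check on the first real batch from the training dataloader would show whether the graph fields were attached before the batch reached the model.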