-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
executable file
·124 lines (96 loc) · 5.05 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright 2018 Saarland University, Spoken Language
# Systems LSV (author: Youssef Oualil, during his work period at LSV)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS*, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
#
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Parts of this code are based on the Tensorflow PTB-LM recipe licensed under
# the Apache License, Version 2.0 by the TensorFlow Authors.
# (Source: https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py
# retrieved in January 2018)
###############################################################################
from __future__ import print_function
import time
import sys
import os
from six.moves import cPickle
import argparse
#from tensorflow.python import pywrap_tensorflow
import tensorflow as tf
import numpy as np
from data_processor import DataProcessor
from basic_rnn_models import LM as Basic_RNN
from srnn import LM as SRNN
from lsrc import LM as LSRC
def main():
    """Parse command-line options, merge them with the pickled training
    configuration, and run perplexity evaluation on the test corpus."""
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument('--model_file', type=str, default='model.ckpt-15.data-00000-of-00001',
                            help='path to the file storing the model to be evaluated')
    arg_parser.add_argument('--test_file', type=str, default='data/ptb/ptb.test.txt',
                            help='path to the test file containing one sentence per line (</eos> is automatically added)')
    arg_parser.add_argument('--batch_size', type=int, default=20,
                            help='mini-batch size')
    arg_parser.add_argument('--seq_length', type=int, default=1,
                            help='Word sequence length processed at each forward pass. Increasing the value can increase '
                            + 'processing speed, but might result in slightly different results due to some words '
                            + 'at the end of the corpus not fitting into the batch-size x seq-length matrix.')
    cli = arg_parser.parse_args()
    cli.save_dir = os.path.dirname(cli.model_file)

    # The training-time configuration is pickled next to the checkpoint.
    config_path = os.path.join(cli.save_dir, 'config.pkl')
    try:
        with open(config_path, 'rb') as f:
            config = cPickle.load(f)
    except IOError:
        raise IOError("Could not open and/or read the config file {}.".format(config_path))

    # Override the training config with the test-specific settings.
    config.save_dir = cli.save_dir
    config.test_file = cli.test_file
    config.batch_size = cli.batch_size
    config.seq_length = cli.seq_length
    config.model_path = cli.model_file
    # Older saved configs pre-date the history_size option; default it to 1.
    if not hasattr(config, 'history_size'):
        config.history_size = 1

    calculate_perplexity(config)
def calculate_perplexity(config):
    """Evaluate a trained language model on a test corpus and print its perplexity.

    Args:
        config: training configuration (unpickled from config.pkl) augmented
            with the test-time attributes set by main(): model_path,
            test_file, batch_size, seq_length, history_size and save_dir.

    Raises:
        Exception: if the checkpoint meta graph file cannot be found, or the
            configured model type is not one of the supported architectures.
    """
    # The checkpoint's meta graph must exist next to the model weights.
    # NOTE(review): saver.restore below expects model_path to be the
    # checkpoint prefix (e.g. model.ckpt-15), not a .data-* shard — the
    # default --model_file in main() looks like a shard; confirm usage.
    if not os.path.exists(config.model_path + '.meta'):
        raise Exception("Could not open and/or read model file {}".format(config.model_path + '.meta'))

    # Process the test corpus and load it into batches. The vocabulary is
    # loaded from the model directory so word IDs match the trained model.
    test_data = DataProcessor(config.test_file, config.batch_size, config.seq_length,
                              is_training=False, unk='<unk>', history_size=config.history_size,
                              vocab_dir_path=config.save_dir)

    # Build the evaluation graph matching the architecture used at training
    # time (membership tests replace the original chained == ... or == ...).
    with tf.variable_scope("Model", reuse=False):
        if config.model == 'lsrc':
            model_test = LSRC(config, False)
        elif config.model in ('wi-srnn', 'wd-srnn', 'ff-srnn'):
            model_test = SRNN(config, False)
        elif config.model in ('lstm', 'lstmp', 'rnn', 'gru'):
            model_test = Basic_RNN(config, False)
        else:
            raise Exception("model type not supported: {}".format(config.model))

    with tf.Session() as session:
        # Restore the trained weights into the freshly built graph.
        saver = tf.train.Saver(tf.global_variables())
        saver.restore(session, config.model_path)
        # eval_op=None -> pure evaluation, no parameter updates.
        test_perplexity = model_test.run_model(session, test_data, eval_op=None, verbosity=10000, verbose=True)
        print("\n[SUMMARY] Perplexity: %.3f" % test_perplexity)
        print('========================\n')
# Script entry point: parse CLI args and evaluate the given checkpoint.
if __name__ == '__main__':
    main()