Coveralls logob
Coveralls logo
  • Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ishikota / kyoka / 130

24 Dec 2016 - 0:25 coverage: 98.031% (+0.03%) from 98.0%
130

Pull #33

travis-ci

9181eb84f9c35729a3bad740fb7f9d93?size=18&default=identiconweb-flow
Bump up version to 0.2.1
Pull Request #33: release v0.2.1

77 of 80 new or added lines in 9 files covered. (96.25%)

1 existing line in 1 file now uncovered.

747 of 762 relevant lines covered (98.03%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.75
/kyoka/callback.py
1
import os
1×
2
import time
1×
3

4
from utils import build_not_implemented_msg
1×
5

6

7
class BaseCallback(object):
    """Base class for creating new GPI callbacks as you like.

    The following callback methods are called from
    BaseRLAlgorithm.run_gpi_for_an_episode at the appropriate time:

    before_gpi_start : called before training starts
    before_update : called before each call of RLalgorithm.run_gpi_for_an_episode
    after_update : called after each call of RLalgorithm.run_gpi_for_an_episode
    after_gpi_finish : called when training is finished
    interrupt_gpi : called after each update to decide whether to stop training

    Every hook is a no-op here (interrupt_gpi returns False), so a subclass
    only overrides the hooks it needs.
    """

    def before_gpi_start(self, task, value_function):
        """Hook called once before training starts. No-op by default."""
        pass

    def before_update(self, iteration_count, task, value_function):
        """Hook called before every update. No-op by default."""
        pass

    def after_update(self, iteration_count, task, value_function):
        """Hook called after every update. No-op by default."""
        pass

    def after_gpi_finish(self, task, value_function):
        """Hook called once after training finishes. No-op by default."""
        pass

    def interrupt_gpi(self, iteration_count, task, value_function):
        """Return True if you want to stop the training."""
        return False

    def define_log_tag(self):
        """Define the tag string displayed with log messages.

        Ex. if you return "MyTag" here then
            self.log("some message") would output
            "[MyTag] some message" on your console

        Defaults to the name of the concrete class.
        """
        return self.__class__.__name__

    @property
    def tag(self):
        """Tag prefixed to every log line (see define_log_tag)."""
        return self.define_log_tag()

    def log(self, message):
        """Print ``message`` to stdout prefixed with ``[tag]``.

        Empty messages (``""`` or ``None``) are ignored.
        """
        if message:
            # Single-argument print(...) behaves identically on Python 2
            # (parenthesised expression) and Python 3 (function call).
            print("[%s] %s" % (self.tag, message))
52

53
class BasePerformanceWatcher(BaseCallback):
    """Utility class to run a periodic evaluation on the intermediate result of training.

    Subclasses must implement ``define_performance_test_interval`` and
    ``run_performance_test``. ``setUp``, ``tearDown`` and
    ``define_log_message`` may optionally be overridden.

    Every value returned by ``run_performance_test`` is appended to
    ``self.performance_log``, which is reset at the start of each GPI run.
    """

    def setUp(self, task, value_function):
        """Hook run once in before_gpi_start, after the test interval is resolved."""
        pass

    def tearDown(self, task, value_function):
        """Hook run once in after_gpi_finish."""
        pass

    def define_performance_test_interval(self):
        """Define the interval at which "run_performance_test" is executed.

        For example, if you return 1 then "run_performance_test" is called after
        every update of training. Must be a positive integer.
        """
        err_msg = build_not_implemented_msg(self, "define_performance_test_interval")
        raise NotImplementedError(err_msg)

    def run_performance_test(self, task, value_function):
        """Define some calculation to see how well the current value_function works.

        Args:
            task: Task instance which is used in training
            value_function: Value function which is in the middle of training.

        Returns:
            Any value; it is stored in self.performance_log and logged.
        """
        err_msg = build_not_implemented_msg(self, "run_performance_test")
        raise NotImplementedError(err_msg)

    def define_log_message(self, iteration_count, task, value_function, test_result):
        """Build the message logged after each performance test."""
        base_msg = "Performance test result : %s (nb_iteration=%d)"
        return base_msg % (test_result, iteration_count)

    def before_gpi_start(self, task, value_function):
        """Reset the result log, resolve the test interval and run setUp.

        Raises:
            ValueError: if define_performance_test_interval returns a value < 1.
        """
        self.performance_log = []
        self.test_interval = self.define_performance_test_interval()
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError in after_update, after setUp has already run.
        if self.test_interval < 1:
            err_msg = "define_performance_test_interval must return a positive integer, but got %r"
            raise ValueError(err_msg % (self.test_interval,))
        self.setUp(task, value_function)

    def after_update(self, iteration_count, task, value_function):
        """Run, record and log a performance test every ``test_interval`` updates."""
        if iteration_count % self.test_interval == 0:
            result = self.run_performance_test(task, value_function)
            self.performance_log.append(result)
            message = self.define_log_message(iteration_count, task, value_function, result)
            self.log(message)

    def after_gpi_finish(self, task, value_function):
        """Delegate to tearDown once training is over."""
        self.tearDown(task, value_function)
101

102
class EpsilonAnnealer(BaseCallback):
    """Callback to decay epsilon of EpsilonGreedyPolicy during training.

    After every update it calls ``anneal_eps()`` on the wrapped policy and
    logs a message once, the first time the policy's ``eps`` has reached its
    ``min_eps``.
    """

    def __init__(self, epsilon_greedy_policy):
        """
        Args:
            epsilon_greedy_policy: target to execute epsilon annealing. It must
                expose ``eps``, ``min_eps`` and ``anneal_eps()``.
        """
        self.policy = epsilon_greedy_policy
        self.anneal_finished = False  # ensures the finish message is logged only once

    def define_log_tag(self):
        return "EpsilonGreedyAnnealing"

    def before_gpi_start(self, _task, _value_function):
        """Log the annealing range before training starts."""
        start_msg = "Anneal epsilon from %s to %s." % (self.policy.eps, self.policy.min_eps)
        self.log(start_msg)

    def after_update(self, iteration_count, _task, _value_function):
        """Anneal epsilon one step and report when the floor is reached."""
        self.policy.anneal_eps()
        # Compare with <= rather than ==. eps is a float, and this file does not
        # show whether anneal_eps clamps eps exactly to min_eps. If a step
        # overshoots the floor, an equality test would never fire.
        if not self.anneal_finished and self.policy.eps <= self.policy.min_eps:
            self.anneal_finished = True
            finish_msg = "Annealing has finished at %d iteration." % iteration_count
            self.log(finish_msg)
126

127
class LearningRecorder(BaseCallback):
    """Callback to save intermediate results of training.

    If you set
        root_save_dir_path=training_results
        save_interval=100
    then after 250 iterations of training ``algorithm.save`` has been called
    twice, each time on a freshly created sub-directory, so the
    training_results directory would have two items like below

    training_results/after_100_iteration/...
                    /after_200_iteration/...

    When training finishes, the algorithm is additionally saved under
    training_results/gpi_finished/.
    """

    def __init__(self, algorithm, root_save_dir_path, save_interval):
        """
        Args:
            algorithm: the RL algorithm which will be used in training.
            root_save_dir_path: save method is executed under this path of directory.
            save_interval: interval of training to execute save method.
        """
        self.algorithm = algorithm
        self.root_save_dir_path = root_save_dir_path
        self.save_interval = save_interval

    def before_gpi_start(self, _task, _value_function):
        """Check that the root directory exists and announce the save schedule.

        Raises:
            Exception: if root_save_dir_path is not an existing directory.
        """
        # isdir (not exists): a regular file at this path would otherwise pass
        # the check and only fail later, at the first checkpoint's os.mkdir.
        if not os.path.isdir(self.root_save_dir_path):
            err_msg = "Directory [ %s ] not found which you passed to LearningRecorder."
            raise Exception(err_msg % self.root_save_dir_path)
        base_msg = 'Your algorithm will be saved after each %d iteration on directory [ %s ].'
        self.log(base_msg % (self.save_interval, self.root_save_dir_path))

    def after_update(self, iteration_count, _task, _value_function):
        """Save a checkpoint every ``save_interval`` iterations."""
        if iteration_count % self.save_interval == 0:
            dir_name = self.define_checkpoint_save_dir_name(iteration_count)
            save_path = self._save_algorithm(dir_name)
            base_msg = "Saved algorithm after %d iteration at [ %s ]."
            self.log(base_msg % (iteration_count, save_path))

    def after_gpi_finish(self, task, value_function):
        """Save the final algorithm once training is over."""
        self._save_algorithm(self.define_finish_save_dir_name())

    def define_checkpoint_save_dir_name(self, iteration_count):
        """Name of the sub-directory used for the checkpoint at ``iteration_count``."""
        return "after_%d_iteration" % iteration_count

    def define_finish_save_dir_name(self):
        """Name of the sub-directory used for the final save."""
        return "gpi_finished"

    def _save_algorithm(self, dir_name):
        """Create ``dir_name`` under the root directory, save there and return its path."""
        save_path = os.path.join(self.root_save_dir_path, dir_name)
        # os.mkdir raises if the directory already exists, so a previous
        # result is never silently overwritten.
        os.mkdir(save_path)
        self.algorithm.save(save_path)
        return save_path
178

179
class BaseFinishRule(BaseCallback):
    """Base class to define the rule that stops the training.

    Child classes need to implement the following 3 methods.

    check_condition : return True if you want to stop the training.
    generate_start_message : the string returned here is logged when training starts.
    generate_finish_message : the string returned here is logged after training finishes.
    """

    def check_condition(self, iteration_count, task, value_function):
        """Return True if you want to stop the training"""
        raise NotImplementedError(build_not_implemented_msg(self, "check_condition"))

    def generate_start_message(self):
        """Return the message logged once before training starts."""
        raise NotImplementedError(build_not_implemented_msg(self, "generate_start_message"))

    def generate_finish_message(self, iteration_count):
        """Return the message logged when check_condition asks to stop."""
        raise NotImplementedError(build_not_implemented_msg(self, "generate_finish_message"))

    def before_gpi_start(self, task, value_function):
        """Log the rule's start message."""
        start_message = self.generate_start_message()
        self.log(start_message)

    def interrupt_gpi(self, iteration_count, task, value_function):
        """Evaluate check_condition, logging the finish message when it says stop.

        Returns whatever check_condition returned.
        """
        should_stop = self.check_condition(iteration_count, task, value_function)
        if should_stop:
            self.log(self.generate_finish_message(iteration_count))
        return should_stop
209

210
class ManualInterruption(BaseFinishRule):
    """Callback to stop the training by file-based communication.

    How to stop the training manually
    1. Share the path of a file with this callback.
    2. Start the training with this callback.
    3. Write the target word (default is "stop") in the shared file.
    4. The callback finds the target word in the shared file and stops the training.

    The file is read at most once every ``watch_interval`` seconds, so the
    training may keep running for up to that long after the word is written.
    """

    # The attribute name is misspelled ("WARD" for "WORD"), but it is public,
    # so it is kept as-is for backward compatibility.
    TARGET_WARD = "stop"

    def __init__(self, monitor_file_path, watch_interval=30):
        """
        Args:
            monitor_file_path: path of file to send stop command.
            watch_interval: minimum number of seconds between two reads of
                the file in monitor_file_path.
        """
        self.monitor_file_path = monitor_file_path
        self.watch_interval = watch_interval
        # Also set here so check_condition cannot hit an AttributeError if it
        # runs before generate_start_message. The start message resets it.
        self.last_check_time = time.time()

    def check_condition(self, _iteration_count, _task, _value_function):
        """Return True if the stop word was found; reads the file only every watch_interval s."""
        current_time = time.time()
        if current_time - self.last_check_time >= self.watch_interval:
            self.last_check_time = current_time
            return self.__order_found_in_monitoring_file(self.monitor_file_path, self.TARGET_WARD)
        else:
            return False

    def generate_start_message(self):
        # Restart the watch timer: the first file check happens watch_interval
        # seconds after training starts.
        self.last_check_time = time.time()
        base_first_msg ='Write word "%s" on file "%s" will finish the GPI'
        base_second_msg = "(Stopping GPI may take about %s seconds. Because we check target file every %s seconds.)"
        first_msg = base_first_msg % (self.TARGET_WARD, self.monitor_file_path)
        second_msg = base_second_msg % (self.watch_interval, self.watch_interval)
        return "\n".join([first_msg, second_msg])

    def generate_finish_message(self, iteration_count):
        base_msg = "Interrupt GPI after %d iterations because interupption order found in [ %s ]."
        return base_msg % (iteration_count, self.monitor_file_path)

    def __order_found_in_monitoring_file(self, filepath, target_word):
        # isfile (rather than just trying to open) skips directories and
        # special files such as FIFOs, which open() may block on.
        return os.path.isfile(filepath) and self.__found_target_ward_in_file(filepath, target_word)

    def __found_target_ward_in_file(self, filepath, target_word):
        """Return True if the file content contains ``target_word``."""
        src = self.__read_data(filepath)
        if not src:
            return False
        # The file is read in binary mode. On Python 3 that yields bytes, and
        # `str in bytes` raises TypeError, so a text target is encoded first.
        # On Python 2, str is bytes and this branch is skipped.
        if not isinstance(target_word, bytes):
            target_word = target_word.encode("utf-8")
        return target_word in src

    def __read_data(self, filepath):
        """Return the raw bytes of ``filepath``, or None if it has disappeared.

        The user may delete the file between the isfile() check and open().
        That case is treated as "no stop order" instead of crashing the
        training. Any other I/O error (e.g. permission denied) still propagates.
        """
        import errno
        try:
            with open(filepath, 'rb') as f:
                return f.read()
        except (IOError, OSError) as err:
            if err.errno == errno.ENOENT:
                return None
            raise
261

262
class WatchIterationCount(BaseFinishRule):
    """Finish the training after a specified number of iterations.

    When ``verbose > 0`` it also reports progress after every update: the
    number of finished iterations and the wall-clock time of the last one.
    """

    def __init__(self, target_count, verbose=1):
        """
        Args
            target_count : stop the training after this number of iterations.
            verbose : if verbose > 0 then a progress line is logged after every update.
        """
        self.target_count = target_count
        self.start_time, self.last_update_time = 0, 0
        self.verbose = verbose

    def define_log_tag(self):
        return "Progress"

    def check_condition(self, iteration_count, task, value_function):
        """Stop once ``target_count`` iterations have been run."""
        return self.target_count <= iteration_count

    def generate_start_message(self):
        """Start both timers and announce the planned number of iterations."""
        now = time.time()
        self.start_time = now
        self.last_update_time = now
        return "Start GPI iteration for %d times" % self.target_count

    def generate_finish_message(self, iteration_count):
        """Report the number of iterations run and the total elapsed seconds."""
        elapsed = time.time() - self.start_time
        return "Completed GPI iteration for %d times. (total time: %ds)" % (iteration_count, elapsed)

    def before_update(self, iteration_count, task, value_function):
        """Remember when this update started."""
        super(WatchIterationCount, self).before_update(iteration_count, task, value_function)
        self.last_update_time = time.time()

    def after_update(self, iteration_count, task, value_function):
        """Log progress and the duration of the update that just finished."""
        super(WatchIterationCount, self).after_update(iteration_count, task, value_function)
        if not self.verbose > 0:
            return
        now = time.time()
        took = now - self.last_update_time
        progress = "Finished %d / %d iterations (%.1fs)" % (iteration_count, self.target_count, took)
        self.last_update_time = now
        self.log(progress)
302

Troubleshooting · Open an Issue · Sales · Support · ENTERPRISE · CAREERS · STATUS
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2023 Coveralls, Inc