• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ishikota / kyoka / 20

24 Oct 2016 - 5:58 coverage: 92.529% (+2.5%) from 90.06%
20

Pull #16

travis-ci

web-flow
Update sample script to use new save logic
Pull Request #16: Change save logic

116 of 122 new or added lines in 8 files covered. (95.08%)

1 existing line in 1 file now uncovered.

644 of 696 relevant lines covered (92.53%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.0
/kyoka/algorithm/base_rl_algorithm.py
1
from kyoka.finish_rule.watch_iteration_count import WatchIterationCount
1×
2

3
class BaseRLAlgorithm(object):
  """Base class for reinforcement-learning algorithms.

  Subclasses must override ``update_value_function``. Setup, save/load, the
  iteration loop (``run_gpi``) and episode generation are shared here.
  ``save_algorithm_state`` / ``load_algorithm_state`` are no-op hooks a
  subclass may override to persist additional state of its own.
  """

  def setUp(self, domain, policy, value_function):
    """Attach the domain, policy and value function used by this algorithm.

    ``value_function.setUp()`` is invoked before the policy is stored.
    """
    self.domain = domain
    self.value_function = value_function
    self.value_function.setUp()
    self.policy = policy

  def save(self, save_dir_path):
    """Save the value function, then the algorithm-specific state, to ``save_dir_path``."""
    self.value_function.save(save_dir_path)
    self.save_algorithm_state(save_dir_path)

  def load(self, load_dir_path):
    """Load the value function, then the algorithm-specific state, from ``load_dir_path``."""
    self.value_function.load(load_dir_path)
    self.load_algorithm_state(load_dir_path)

  def save_algorithm_state(self, save_dir_path):
    """Hook for subclasses to save extra state; the base implementation does nothing."""
    pass

  def load_algorithm_state(self, load_dir_path):
    """Hook for subclasses to load extra state; the base implementation does nothing."""
    pass

  def update_value_function(self, domain, policy, value_function):
    """Perform one update step of the value function; must be overridden.

    Raises:
      NotImplementedError: always, in this base class.
    """
    err_msg = self.__build_err_msg("update_value_function")
    raise NotImplementedError(err_msg)

  def run_gpi(self, nb_iteration, finish_rules=None, callbacks=None):
    """Repeatedly call ``update_value_function`` until a finish rule fires.

    Args:
      nb_iteration: passed to a ``WatchIterationCount`` rule that is always
        appended after the caller's rules.
      finish_rules: a single finish rule, a list/tuple of rules, or None.
        The caller's list is never modified.
      callbacks: a single callback, a list/tuple of callbacks, or None.
        Each receives ``before_gpi_start``, ``before_update``,
        ``after_update`` and ``after_gpi_finish`` notifications.

    Returns:
      The ``generate_finish_message`` result of the first rule (in list
      order) whose ``satisfy_condition`` returns True after an iteration.
    """
    # None defaults plus fresh lists: appending to a shared default [] would
    # accumulate WatchIterationCount rules across calls, and appending to a
    # caller-supplied list would mutate the caller's data.
    callbacks = self.__wrap_item_if_single(callbacks)
    finish_rules = self.__wrap_item_if_single(finish_rules)
    finish_rules.append(WatchIterationCount(nb_iteration))
    iteration_counter = 0
    for callback in callbacks:
      callback.before_gpi_start(self.domain, self.value_function)
    while True:
      for callback in callbacks:
        callback.before_update(iteration_counter, self.domain, self.value_function)
      self.update_value_function(self.domain, self.policy, self.value_function)
      for callback in callbacks:
        callback.after_update(iteration_counter, self.domain, self.value_function)
      iteration_counter += 1
      for finish_rule in finish_rules:
        if finish_rule.satisfy_condition(iteration_counter):
          finish_msg = finish_rule.generate_finish_message(iteration_counter)
          for callback in callbacks:
            callback.after_gpi_finish(self.domain, self.value_function)
          return finish_msg

  def generate_episode(self, domain, value_function, policy):
    """Play one episode from the domain's initial state until a terminal state.

    Returns:
      A list of ``(state, action, next_state, reward)`` tuples, one per step,
      where ``reward`` is ``domain.calculate_reward(next_state)``.
    """
    state = domain.generate_initial_state()
    episode = []
    while not domain.is_terminal_state(state):
      action = policy.choose_action(domain, value_function, state)
      next_state = domain.transit_state(state, action)
      reward = domain.calculate_reward(next_state)
      episode.append((state, action, next_state, reward))
      state = next_state
    return episode

  def __wrap_item_if_single(self, item):
    """Return ``item`` as a new list.

    ``None`` becomes ``[]``, a list or tuple is shallow-copied, and any other
    value becomes ``[item]``. A fresh list is always returned so callers can
    append to it without affecting the original object.
    """
    if item is None:
      return []
    if isinstance(item, (list, tuple)):
      return list(item)
    return [item]

  def __build_err_msg(self, msg):
    """Build the error message for an abstract method ``msg`` that was not overridden."""
    return "Accessed [ {0} ] method of BaseRLAlgorithm which should be overridden".format(msg)

Troubleshooting · Open an Issue · Sales · Support · ENTERPRISE · CAREERS · STATUS
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2023 Coveralls, Inc