• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pim-book / programmers-introduction-to-mathematics / #987

pending completion
#987

push

travis-ci

GitHub
build(deps): bump decode-uri-component in /waves/javascript_demo

910 of 910 relevant lines covered (100.0%)

1.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/neural_network/mnist_network.py
1
from neural_network import InputNode
1✔
2
from neural_network import L2ErrorNode
1✔
3
from neural_network import LinearNode
1✔
4
from neural_network import NeuralNetwork
1✔
5
from neural_network import ReluNode
1✔
6
from neural_network import SigmoidNode
1✔
7
from random import shuffle
1✔
8
import os
1✔
9

10

11
def load_1s_and_7s(filename):
    """Read an MNIST CSV file, keeping only the rows labelled 1 or 7.

    Each kept row becomes ``[pixels, target]``. ``pixels`` is the list of the
    row's remaining integer fields divided by 255, and ``target`` is 0 for a
    "1" and 1 for a "7". Progress messages are printed before and after
    loading.
    """
    print('Loading data {}...'.format(filename))
    target_for_label = {1: 0, 7: 1}
    examples = []
    with open(filename, 'r') as infile:
        for row in infile:
            # Cheap first-character filter so most rows are never parsed.
            if row[0] not in ('1', '7'):
                continue
            label, *raw_pixels = (int(field) for field in row.split(','))
            if label in target_for_label:
                pixels = [value / 255 for value in raw_pixels]  # scale to [0,1]
                examples.append([pixels, target_for_label[label]])
    print('Data loaded.')
    return examples
26

27

28
def print_example(example, width=28):
    """Print a flattened image as a grid of 0-255 intensity values.

    Args:
        example: sequence of pixel intensities, presumably scaled to [0, 1]
            (as ``load_1s_and_7s`` produces); each is multiplied by 255 for
            display.
        width: number of pixels per printed row (28 for MNIST digits).

    A newline is printed *before* each row, so the output starts with a blank
    line and does not end with a newline.
    """
    for i, pixel in enumerate(example):
        if i % width == 0:
            print()
        # round() rather than int(): pixel * 255 can land a hair below the
        # intended integer (e.g. 254.99999999999997), which int() truncates.
        print('%4d' % round(pixel * 255), end='')
33

34

35
def show_random_examples(network, test, n=5):
    """Print up to ``n`` randomly chosen examples with the network's prediction.

    Args:
        network: trained network; only its ``evaluate(example)`` method is used.
        test: list of ``(example, label)`` pairs. It is copied before
            shuffling, so the caller's list is left in its original order.
        n: maximum number of examples to show. If ``test`` has fewer than
            ``n`` entries, all of them are shown instead of raising
            IndexError; a non-positive ``n`` shows nothing.
    """
    shuffled = test[:]  # copy so the caller's ordering is not disturbed
    shuffle(shuffled)
    # max() keeps negative n meaning "show nothing" (a negative slice bound
    # would otherwise mean "all but the last few").
    for example, label in shuffled[:max(n, 0)]:
        print_example(example)
        print("\nExample with label {} is predicted to have label {}".format(
            label, network.evaluate(example)))
43

44

45
def build_network(num_inputs=28 * 28, hidden_layer_sizes=(10, 10),
                  step_size=0.05):
    """Build the single-output classifier network used by ``train_mnist``.

    With the default arguments the architecture is
        784 inputs -> 10 Linear+ReLU -> 10 Linear+ReLU -> Linear -> Sigmoid,
    with an ``L2ErrorNode`` attached to the sigmoid output. Every
    ``LinearNode`` receives all nodes of the previous layer as its inputs.

    Args:
        num_inputs: number of input nodes (28*28 pixels for MNIST).
        hidden_layer_sizes: widths of the Linear+ReLU hidden layers, in order
            from the input side. An empty sequence feeds the inputs straight
            into the output LinearNode.
        step_size: passed through to ``NeuralNetwork``.

    Returns:
        The ``NeuralNetwork`` wrapping the output and input nodes.
    """
    input_nodes = InputNode.make_input_nodes(num_inputs)

    # Nodes are created in the same order as the original two-layer version,
    # so the default call builds an identical network.
    previous_layer = input_nodes
    for width in hidden_layer_sizes:
        linear_layer = [LinearNode(previous_layer) for _ in range(width)]
        previous_layer = [ReluNode(node) for node in linear_layer]

    linear_output = LinearNode(previous_layer)
    output = SigmoidNode(linear_output)
    error_node = L2ErrorNode(output)
    network = NeuralNetwork(
        output, input_nodes, error_node=error_node, step_size=step_size)

    return network
61

62

63
# Message printed by train_mnist when the MNIST CSV files cannot be loaded.
# The two ``{}`` placeholders receive the training and test file paths.
cant_find_files = (
    '\n'
    'Was unable to find the files {}, {}.\n'
    '\n'
    'You may have to extract them from the gzipped tarball mnist/mnist.tar.gz.\n'
)
68

69

70
def train_mnist(data_dirname, num_epochs=5):
    """Train the 1-vs-7 network on MNIST CSV data and report its error.

    Args:
        data_dirname: directory containing ``mnist_train.csv`` and
            ``mnist_test.csv``.
        num_epochs: number of training rounds.

    Each round reshuffles the full training list, takes the first tenth as a
    validation set, and trains on the *next* tenth only. Because the shuffle
    happens every round, a later round's validation set may contain examples
    that were trained on in an earlier round.

    After training, the test-set error is printed and a few random test
    examples are displayed with their predictions.

    Returns:
        The trained network.

    Raises:
        OSError: if loading either CSV file fails with an OSError (e.g. the
            file is missing). A hint about extracting the tarball is printed
            first.
    """
    train_file = os.path.join(data_dirname, 'mnist_train.csv')
    test_file = os.path.join(data_dirname, 'mnist_test.csv')
    try:
        train = load_1s_and_7s(train_file)
        test = load_1s_and_7s(test_file)
    except OSError:  # pragma: no cover
        # Only file-access failures warrant the "extract the tarball" hint;
        # other errors (e.g. ValueError from a malformed row) propagate
        # without a misleading message.
        print(cant_find_files.format(train_file, test_file))
        raise

    network = build_network()
    epoch_size = len(train) // 10

    for _ in range(num_epochs):
        shuffle(train)
        validation = train[:epoch_size]
        real_train = train[epoch_size:2 * epoch_size]

        print("Starting epoch of {} examples with {} validation".format(
            len(real_train), len(validation)))

        network.train(real_train, max_steps=len(real_train))

        print("Finished epoch. Validation error={:.3f}".format(
            network.error_on_dataset(validation)))

    print("Test error={:.3f}".format(network.error_on_dataset(test)))
    show_random_examples(network, test)
    return network
100

101

102
if __name__ == "__main__":
    # Look for the extracted CSV files in an ``mnist`` directory located
    # beside this script, then train and evaluate.
    mnist_dirname = os.path.join(os.path.dirname(__file__), 'mnist')
    network = train_mnist(mnist_dirname)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc