• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

vanvalenlab / deepcell-label / 4578689396

pending completion
4578689396

Pull #436

github

GitHub
Merge ddb425c30 into 6a993cb7a
Pull Request #436: Model training overhaul: SNGP model, uncertainty visualization, and custom embedding support

462 of 1163 branches covered (39.72%)

Branch coverage included in aggregate %.

20 of 628 new or added lines in 27 files covered. (3.18%)

76 existing lines in 5 files now uncovered.

3248 of 5431 relevant lines covered (59.8%)

543.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

29.17
/frontend/src/Project/service/labels/trainingMachine.js
1
/** State machine that performs tensorflow.js training and prediction, using embeddings and labeled cell types as input data
2
 */
3

4
import { actions, assign, Machine, send } from 'xstate';
5
import { fromEventBus } from '../eventBus';
6
import { predict, train } from './trainingMachineUtils';
7

8
const { choose } = actions;
19✔
9

10
const createTrainingMachine = ({ eventBuses }) =>
19✔
11
  Machine(
866✔
12
    {
13
      id: 'training',
14
      invoke: [
15
        { id: 'eventBus', src: fromEventBus('training', () => eventBuses.training) },
18✔
16
        { id: 'load', src: fromEventBus('training', () => eventBuses.load, 'LOADED') },
18✔
17
        { id: 'cellTypes', src: fromEventBus('training', () => eventBuses.cellTypes) },
18✔
18
        { id: 'cells', src: fromEventBus('training', () => eventBuses.cells, 'CELLS') },
18✔
19
        {
20
          id: 'channelExpression',
21
          src: fromEventBus('training', () => eventBuses.channelExpression),
18✔
22
        },
23
        { src: fromEventBus('training', () => eventBuses.image, 'SET_T') },
18✔
24
        { src: fromEventBus('training', () => eventBuses.labeled, 'SET_FEATURE') },
18✔
25
      ],
26
      context: {
27
        // Hyperparameters
28
        batchSize: 1,
29
        numEpochs: 20,
30
        learningRate: 0.01,
31
        valSplit: 0.8,
32
        // "Input" context
33
        embedding: 'Mean',
34
        embeddings: null,
35
        t: 0,
36
        feature: 0,
37
        cells: null,
38
        epoch: 0,
39
        numChannels: null, // from raw
40
        cellTypes: null,
41
        calculations: null,
42
        whole: false,
43
        uncertaintyThreshold: 0.5,
44
        predictionMode: 'over',
45
        // "Output" context
46
        confusionMatrix: null,
47
        trainCounts: null,
48
        valCounts: null,
49
        predUncertainties: null,
50
        range: null,
51
        model: null,
52
        valLogs: [],
53
        trainLogs: [],
54
        parameterLog: null,
55
      },
56
      initial: 'loading',
57
      on: {
58
        CELLS: { actions: 'setCells' },
59
        CELLTYPES: { actions: 'setCellTypes' },
60
        SET_T: { actions: 'setT' },
61
        SET_FEATURE: { actions: 'setFeature' },
62
      },
63
      states: {
64
        loading: {
65
          on: {
66
            LOADED: {
67
              actions: ['setCellTypes', 'setCells', 'setNumChannels', 'setEmbeddings'],
68
              target: 'loaded',
69
            },
70
          },
71
        },
72
        loaded: {
73
          initial: 'idle',
74
          states: {
75
            idle: {
76
              on: {
77
                TRAIN: { target: 'training' },
78
                PREDICT: { target: 'predicting' },
79
                EMBEDDING: { actions: 'setEmbedding' },
80
                BATCH_SIZE: { actions: 'setBatchSize' },
81
                LEARNING_RATE: { actions: 'setLearningRate' },
82
                NUM_EPOCHS: { actions: 'setNumEpochs' },
83
                VAL_SPLIT: { actions: 'setValSplit' },
84
                TOGGLE_WHOLE: { actions: 'toggleWhole' },
85
                THRESHOLD: { actions: 'setThreshold' },
86
                PREDICTION_MODE: { actions: 'setPredictionMode' },
87
              },
88
            },
89
            training: {
90
              initial: 'calculating',
91
              states: {
92
                calculating: {
93
                  entry: choose([
94
                    {
95
                      cond: (ctx) => ctx.embedding === 'Mean',
×
96
                      actions: ['resetEpoch', 'resetLogs', 'getMean'],
97
                    },
98
                    {
99
                      cond: (ctx) => ctx.embedding === 'Total',
×
100
                      actions: ['resetEpoch', 'resetLogs', 'getTotal'],
101
                    },
102
                    {
NEW
103
                      cond: (ctx) => ctx.embedding === 'Imported',
×
104
                      actions: ['resetEpoch', 'resetLogs', send({ type: 'TRAIN', imported: true })],
105
                    },
106
                  ]),
107
                  on: {
108
                    CALCULATION: { actions: 'setCalculation', target: 'train' },
109
                    TRAIN: { target: 'train' },
110
                  },
111
                },
112
                train: {
113
                  invoke: {
114
                    id: 'training',
115
                    src: (ctx, evt) => (sendBack) => {
×
116
                      // TODO: handle errors from train(); onError is disabled below, so a failure does not trigger any transition
117
                      train(ctx, evt, sendBack);
×
118
                    },
119
                    // onError: { target: 'idle', actions: (c, e) => console.log(c, e) },
120
                  },
121
                  on: {
122
                    SET_EPOCH: { actions: ['setEpoch', 'setLogs'] },
123
                  },
124
                },
125
              },
126
              on: {
127
                CANCEL: { target: 'idle' },
128
                DONE: {
129
                  target: 'idle',
130
                  actions: [
131
                    'saveModel',
132
                    'setConfusionMatrix',
133
                    'setTrainHistogram',
134
                    'setValHistogram',
135
                    'setRange',
136
                  ],
137
                },
138
              },
139
            },
140
            predicting: {
141
              initial: 'calculating',
142
              states: {
143
                calculating: {
144
                  entry: choose([
145
                    {
146
                      cond: (ctx) => ctx.embedding === 'Mean',
×
147
                      actions: 'getMean',
148
                    },
149
                    {
150
                      cond: (ctx) => ctx.embedding === 'Total',
×
151
                      actions: 'getTotal',
152
                    },
153
                    {
NEW
154
                      cond: (ctx) => ctx.embedding === 'Imported',
×
155
                      actions: send({ type: 'PREDICT', imported: true }),
156
                    },
157
                  ]),
158
                  on: {
159
                    CALCULATION: { actions: 'setCalculation', target: 'predict' },
160
                    PREDICT: { target: 'predict' },
161
                  },
162
                },
163
                predict: {
164
                  invoke: {
165
                    id: 'predicting',
166
                    src: (ctx, evt) => (sendBack) => {
×
167
                      // TODO: handle errors from predict(); onError is disabled below, so a failure does not trigger any transition
168
                      predict(ctx, evt, sendBack);
×
169
                    },
170
                    // onError: { target: 'idle', actions: (c, e) => console.log(c, e) },
171
                  },
172
                },
173
              },
174
              on: {
175
                DONE: {
176
                  target: 'idle',
177
                  actions: ['sendPredictions', 'setPredUncertainties'],
178
                },
179
              },
180
            },
181
          },
182
        },
183
      },
184
    },
185
    {
186
      actions: {
187
        setBatchSize: assign({ batchSize: (_, evt) => evt.batchSize }),
×
188
        setCells: assign({
189
          cells: (ctx, evt) => evt.cells,
39✔
190
        }),
191
        setEmbedding: assign({ embedding: (_, evt) => evt.embedding }),
×
192
        setEmbeddings: assign({ embeddings: (_, evt) => evt.embeddings }),
13✔
193
        setNumEpochs: assign({ numEpochs: (_, evt) => evt.numEpochs }),
×
194
        setLearningRate: assign({ learningRate: (_, evt) => evt.learningRate }),
×
195
        setValSplit: assign({ valSplit: (_, evt) => evt.valSplit }),
×
NEW
196
        setThreshold: assign({ uncertaintyThreshold: (_, evt) => evt.uncertaintyThreshold }),
×
NEW
197
        setPredictionMode: assign({ predictionMode: (_, evt) => evt.predictionMode }),
×
198
        setNumChannels: assign({ numChannels: (_, evt) => evt.raw.length }),
13✔
199
        setCellTypes: assign({ cellTypes: (_, evt) => evt.cellTypes }),
38✔
200
        setT: assign({ t: (_, evt) => evt.t }),
×
201
        setFeature: assign({ feature: (_, evt) => evt.feature }),
×
202
        setEpoch: assign({ epoch: (_, evt) => evt.epoch }),
×
203
        setLogs: assign({
NEW
204
          valLogs: (ctx, evt) => ctx.valLogs.concat(ctx.valSplit < 1 ? [evt.logs.val_loss] : []),
×
205
          trainLogs: (ctx, evt) => ctx.trainLogs.concat([evt.logs.loss]),
×
206
        }),
207
        setCalculation: assign({ calculations: (_, evt) => evt.calculations }),
×
208
        setConfusionMatrix: assign({ confusionMatrix: (_, evt) => evt.confusionMatrix }),
×
NEW
209
        setTrainHistogram: assign({ trainCounts: (_, evt) => evt.trainCounts }),
×
NEW
210
        setValHistogram: assign({ valCounts: (_, evt) => evt.valCounts }),
×
UNCOV
211
        toggleWhole: assign({ whole: (ctx) => !ctx.whole }),
×
212
        getMean: send({ type: 'CALCULATE', stat: 'Mean' }, { to: 'channelExpression' }),
213
        getTotal: send({ type: 'CALCULATE', stat: 'Total' }, { to: 'channelExpression' }),
214
        resetLogs: assign({
215
          valLogs: [],
216
          trainLogs: [],
217
        }),
218
        resetEpoch: assign({ epoch: () => 0 }),
×
219
        saveModel: assign({ model: (_, evt) => evt.model }),
×
220
        setRange: assign({ range: (_, evt) => [evt.inputMin, evt.inputMax] }),
×
221
        sendPredictions: send((_, evt) => ({ type: 'ADD_PREDICTIONS', predictions: evt.predMap }), {
×
222
          to: 'cellTypes',
223
        }),
NEW
224
        setPredUncertainties: assign({ predUncertainties: (_, evt) => evt.uncertainties }),
×
225
      },
226
    }
227
  );
228

229
export default createTrainingMachine;
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc