• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kermitt2 / grobid / 385

pending completion
385

push

circleci

review incremental training

3 of 3 new or added lines in 2 files covered. (100.0%)

14846 of 37503 relevant lines covered (39.59%)

0.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

4.94
/grobid-core/src/main/java/org/grobid/core/jni/DeLFTModel.java
1
package org.grobid.core.jni;
2

3
import org.grobid.core.GrobidModel;
4
import org.grobid.core.engines.label.TaggingLabels;
5
import org.grobid.core.exceptions.GrobidException;
6
import org.grobid.core.utilities.GrobidProperties;
7
import org.grobid.core.utilities.IOUtilities;
8
import org.slf4j.Logger;
9
import org.slf4j.LoggerFactory;
10

11
import java.util.concurrent.*;  
12
import java.io.*;
13
import java.lang.StringBuilder;
14
import java.util.*;
15
import java.util.regex.*;
16

17
import jep.Jep;
18
import jep.JepException;
19

20
import java.util.function.Consumer;
21

22
public class DeLFTModel {
23
    public static final Logger LOGGER = LoggerFactory.getLogger(DeLFTModel.class);
1✔
24

25
    // Exploit the JNI CPython interpreter to load and execute a DeLFT deep learning model
26
    private String modelName;
27
    private String architecture;
28

29
    public DeLFTModel(GrobidModel model, String architecture) {
1✔
30
        this.modelName = model.getModelName().replace("-", "_");
1✔
31
        this.architecture = architecture;
1✔
32
        try {
33
            LOGGER.info("Loading DeLFT model for " + model.getModelName() + " with architecture " + architecture + "...");            
1✔
34
            JEPThreadPool.getInstance().run(new InitModel(this.modelName, GrobidProperties.getInstance().getModelPath(), architecture));
1✔
35
        } catch(InterruptedException | RuntimeException e) {
×
36
            LOGGER.error("DeLFT model " + this.modelName + " initialization failed", e);
×
37
        }
1✔
38
    }
1✔
39

40
    class InitModel implements Runnable { 
41
        private String modelName;
42
        private File modelPath;
43
        private String architecture;
44
          
45
        public InitModel(String modelName, File modelPath, String architecture) { 
1✔
46
            this.modelName = modelName;
1✔
47
            this.modelPath = modelPath;
1✔
48
            this.architecture = architecture;
1✔
49
        } 
1✔
50
          
51
        @Override
52
        public void run() { 
53
            Jep jep = JEPThreadPool.getInstance().getJEPInstance(); 
×
54
            try { 
55
                String fullModelName = this.modelName.replace("_", "-");
×
56

57
                //if (architecture != null && !architecture.equals("BidLSTM_CRF"))
58
                if (architecture != null)
×
59
                    fullModelName += "-" + this.architecture;
×
60

61
                if (GrobidProperties.getInstance().useELMo(this.modelName) && modelName.toLowerCase().indexOf("bert") == -1)
×
62
                    fullModelName += "-with_ELMo";
×
63

64
                jep.eval(this.modelName+" = Sequence('" + fullModelName + "')");
×
65
                jep.eval(this.modelName+".load(dir_path='"+modelPath.getAbsolutePath()+"')");
×
66

67
                if (GrobidProperties.getInstance().getDelftRuntimeMaxSequenceLength(this.modelName) != -1) {
×
68
                    jep.eval(this.modelName+".model_config.max_sequence_length="+
×
69
                        GrobidProperties.getInstance().getDelftRuntimeMaxSequenceLength(this.modelName));
×
70
                }
71

72
                if (GrobidProperties.getInstance().getDelftRuntimeBatchSize(this.modelName) != -1) {
×
73
                    jep.eval(this.modelName+".model_config.batch_size="+
×
74
                        GrobidProperties.getInstance().getDelftRuntimeBatchSize(this.modelName));
×
75
                }
76

77
            } catch(JepException e) {
×
78
                LOGGER.error("DeLFT model initialization failed. ", e);
×
79
                throw new GrobidException("DeLFT model initialization failed. ", e);
×
80
            }
×
81
        } 
×
82
    } 
83

84
    private class LabelTask implements Callable<String> { 
85
        private String data;
86
        private String modelName;
87
        private String architecture;
88

89
        public LabelTask(String modelName, String data, String architecture) { 
×
90
            //System.out.println("label thread: " + Thread.currentThread().getId());
91
            this.modelName = modelName;
×
92
            this.data = data;
×
93
            this.architecture = architecture;
×
94
        }
×
95

96
        private void setJepStringValueWithFileFallback(
97
            Jep jep, String name, String value
98
        ) throws JepException, IOException {
99
            try {
100
                jep.set(name, value);
×
101
            } catch(JepException e) {
×
102
                File tempFile = IOUtilities.newTempFile(name, ".data");
×
103
                LOGGER.debug(
×
104
                    "Falling back to file {} due to exception: {}",
105
                    tempFile, e.toString()
×
106
                );
107
                IOUtilities.writeInFile(tempFile.getAbsolutePath(), value);
×
108
                jep.eval("from pathlib import Path");
×
109
                jep.eval(
×
110
                    name + " = Path('" + tempFile.getAbsolutePath() +
×
111
                    "').read_text(encoding='utf-8')"
112
                );
113
                tempFile.delete();
×
114
            }
×
115
        }
×
116

117
        @Override
118
        public String call() { 
119
            Jep jep = JEPThreadPool.getInstance().getJEPInstance(); 
×
120
            StringBuilder labelledData = new StringBuilder();
×
121
            try {
122
                //System.out.println(this.data);
123

124
                // load and tag
125
                this.setJepStringValueWithFileFallback(jep, "input", this.data);
×
126
                jep.eval("x_all, f_all = load_data_crf_string(input)");
×
127
                Object objectResults = null;
×
128
                if (architecture.indexOf("FEATURE") != -1) {
×
129
                    // model is expecting features
130
                    objectResults = jep.getValue(this.modelName+".tag(x_all, None, features=f_all)");
×
131
                } else {
132
                    // no features used by the model
133
                    objectResults = jep.getValue(this.modelName+".tag(x_all, None)");
×
134
                }
135

136
                // inject back the labels
137
                List<List<List<String>>> results = (List<List<List<String>>>) objectResults;
×
138
                BufferedReader bufReader = new BufferedReader(new StringReader(data));
×
139
                String inputLine;
140
                int i = 0; // sentence index
×
141
                int j = 0; // word index in the sentence
×
142
                if (results.size() > 0) {
×
143
                    List<List<String>> result = results.get(0);
×
144
                    while ((inputLine = bufReader.readLine()) != null) {
×
145
                        inputLine = inputLine.trim();
×
146
                        if ((inputLine.length() == 0) && (j != 0)) {
×
147
                            j = 0;
×
148
                            i++;
×
149
                            if (i == results.size())
×
150
                                break;
×
151
                            result = results.get(i);
×
152
                            continue;
×
153
                        }
154

155
                        if (inputLine.length() == 0) {
×
156
                            labelledData.append("\n");
×
157
                            continue;
×
158
                        }
159
                        labelledData.append(inputLine);
×
160
                        labelledData.append(" ");
×
161

162
                        if (j >= result.size()) {
×
163
                            labelledData.append(TaggingLabels.OTHER_LABEL);
×
164
                        } else {
165
                            List<String> pair = result.get(j);
×
166
                            // first is the token, second is the label (DeLFT format)
167
                            String token = pair.get(0);
×
168
                            String label = pair.get(1);
×
169
                            labelledData.append(DeLFTModel.delft2grobidLabel(label));
×
170
                        }
171
                        labelledData.append("\n");
×
172
                        j++;
×
173
                    }
174
                }
175
                
176
                // cleaning
177
                jep.eval("del input");
×
178
                jep.eval("del x_all");
×
179
                jep.eval("del f_all");
×
180
                //jep.eval("K.clear_session()");
181
            } catch(JepException e) {
×
182
                LOGGER.error("DeLFT model labelling via JEP failed", e);
×
183
            } catch(IOException e) {
×
184
                LOGGER.error("DeLFT model labelling failed", e);
×
185
            }
×
186
            //System.out.println(labelledData.toString());
187
            return labelledData.toString();
×
188
        } 
189
    } 
190

191
    public String label(String data) {
192
        String result = null;
×
193
        try {
194
            result = JEPThreadPool.getInstance().call(new LabelTask(this.modelName, data, this.architecture));
×
195
        } catch(InterruptedException e) {
×
196
            LOGGER.error("DeLFT model " + this.modelName + " labelling interrupted", e);
×
197
        } catch(ExecutionException e) {
×
198
            LOGGER.error("DeLFT model " + this.modelName + " labelling failed", e);
×
199
        }
×
200
        // In some areas, GROBID currently expects tabs as feature separators.
201
        // (Same as in WapitiModel.label)
202
        if (result != null)
×
203
            result = result.replaceAll(" ", "\t");
×
204
        return result;
×
205
    }
206

207
    /**
208
     * Training via JNI CPython interpreter (JEP). It appears that after some epochs, the JEP thread
209
     * usually hangs... Possibly issues with IO threads at the level of JEP (output not consumed because
210
     * of \r and no end of line?). 
211
     */
212
    public static void trainJNI(String modelName, File trainingData, File outputModel, String architecture, boolean incremental) {
213
        try {
214
            LOGGER.info("Train DeLFT model " + modelName + "...");
×
215
            JEPThreadPool.getInstance().run(
×
216
                new TrainTask(modelName, trainingData, GrobidProperties.getInstance().getModelPath(), architecture, incremental));
×
217
        } catch(InterruptedException e) {
×
218
            LOGGER.error("Train DeLFT model " + modelName + " task failed", e);
×
219
        }
×
220
    }
×
221

222
    private static class TrainTask implements Runnable { 
223
        private String modelName;
224
        private File trainPath;
225
        private File modelPath;
226
        private String architecture;
227
        private boolean incremental;
228

229
        public TrainTask(String modelName, File trainPath, File modelPath, String architecture, boolean incremental) { 
×
230
            //System.out.println("train thread: " + Thread.currentThread().getId());
231
            this.modelName = modelName;
×
232
            this.trainPath = trainPath;
×
233
            this.modelPath = modelPath;
×
234
            this.architecture = architecture;
×
235
            this.incremental = incremental;
×
236
        } 
×
237
          
238
        @Override
239
        public void run() { 
240
            Jep jep = JEPThreadPool.getInstance().getJEPInstance(); 
×
241
            try {
242
                // load data
243
                jep.eval("x_all, y_all, f_all = load_data_and_labels_crf_file('" + this.trainPath.getAbsolutePath() + "')");
×
244
                jep.eval("x_train, x_valid, y_train, y_valid = train_test_split(x_all, y_all, test_size=0.1)");
×
245
                jep.eval("print(len(x_train), 'train sequences')");
×
246
                jep.eval("print(len(x_valid), 'validation sequences')");
×
247

248
                String useELMo = "False";
×
249
                if (GrobidProperties.getInstance().useELMo(this.modelName) && modelName.toLowerCase().indexOf("bert") == -1) {
×
250
                    useELMo = "True";
×
251
                }
252

253
                String localArgs = "";
×
254
                if (GrobidProperties.getInstance().getDelftTrainingMaxSequenceLength(this.modelName) != -1)
×
255
                    localArgs += ", max_sequence_length="+
×
256
                        GrobidProperties.getInstance().getDelftTrainingMaxSequenceLength(this.modelName);
×
257

258
                if (GrobidProperties.getInstance().getDelftTrainingBatchSize(this.modelName) != -1)
×
259
                    localArgs += ", batch_size="+
×
260
                        GrobidProperties.getInstance().getDelftTrainingBatchSize(this.modelName);
×
261

262
                if (GrobidProperties.getInstance().getDelftTranformer(modelName) != null) {
×
263
                    localArgs += ", transformer="+
×
264
                        GrobidProperties.getInstance().getDelftTranformer(modelName);
×
265
                }
266

267
                // init model to be trained
268
                if (architecture == null)
×
269
                    jep.eval("model = Sequence('"+this.modelName+
×
270
                        "', max_epoch=100, recurrent_dropout=0.50, embeddings_name='glove-840B', use_ELMo="+useELMo+localArgs+")");
271
                else
272
                    jep.eval("model = Sequence('"+this.modelName+
×
273
                        "', max_epoch=100, recurrent_dropout=0.50, embeddings_name='glove-840B', use_ELMo="+useELMo+localArgs+ 
274
                        ", architecture='"+architecture+"')");
275

276
                // actual training
277
                //start_time = time.time()
278
                if (incremental) {
×
279
                    // if incremental training, we need to load the existing model
280
                    if (this.modelPath != null && 
×
281
                        this.modelPath.exists() &&
×
282
                        this.modelPath.isDirectory()) {
×
283
                        jep.eval("model.load('" + this.modelPath.getAbsolutePath() + "')");
×
284
                        jep.eval("model.train(x_train, y_train, x_valid, y_valid, incremental=True)");
×
285
                    } else {
286
                        throw new GrobidException("the path to the model to be used for starting incremental training is invalid: " +
×
287
                            this.modelPath.getAbsolutePath());
×
288
                    }
289
                } else
290
                    jep.eval("model.train(x_train, y_train, x_valid, y_valid)");
×
291
                //runtime = round(time.time() - start_time, 3)
292
                //print("training runtime: %s seconds " % (runtime))
293

294
                // saving the model
295
                System.out.println(this.modelPath.getAbsolutePath());
×
296
                jep.eval("model.save('"+this.modelPath.getAbsolutePath()+"')");
×
297
                
298
                // cleaning
299
                jep.eval("del x_all");
×
300
                jep.eval("del y_all");
×
301
                jep.eval("del f_all");
×
302
                jep.eval("del x_train");
×
303
                jep.eval("del x_valid");
×
304
                jep.eval("del y_train");
×
305
                jep.eval("del y_valid");
×
306
                jep.eval("del model");
×
307
            } catch(JepException e) {
×
308
                LOGGER.error("DeLFT model training via JEP failed", e);
×
309
            } catch(GrobidException e) {
×
310
                LOGGER.error("GROBID call to DeLFT training via JEP failed", e);
×
311
            } 
×
312
        } 
×
313
    } 
314

315
    /**
316
     *  Train with an external process rather than with JNI, this approach appears to be more stable for the
317
     *  training process (JNI approach hangs after a while) and does not raise any runtime/integration issues. 
318
     */
319
    public static void train(String modelName, File trainingData, File outputModel, String architecture, boolean incremental) {
320
        try {
321
            LOGGER.info("Train DeLFT model " + modelName + "...");
×
322
            List<String> command = new ArrayList<>();
×
323
            List<String> subcommands = Arrays.asList("python3", 
×
324
                "delft/applications/grobidTagger.py", 
325
                modelName,
326
                "train",
327
                "--input", trainingData.getAbsolutePath(),
×
328
                "--output", GrobidProperties.getInstance().getModelPath().getAbsolutePath());
×
329
            command.addAll(subcommands);
×
330
            if (architecture != null) {
×
331
                command.add("--architecture");
×
332
                command.add(architecture);
×
333
            }
334
            if (GrobidProperties.getInstance().getDelftTranformer(modelName) != null) {
×
335
                command.add("--transformer");
×
336
                command.add(GrobidProperties.getInstance().getDelftTranformer(modelName));
×
337
            }
338
            if (GrobidProperties.getInstance().useELMo(modelName) && modelName.toLowerCase().indexOf("bert") == -1) {
×
339
                command.add("--use-ELMo");
×
340
            }
341
            if (GrobidProperties.getInstance().getDelftTrainingMaxSequenceLength(modelName) != -1) {
×
342
                command.add("--max-sequence-length");
×
343
                command.add(String.valueOf(GrobidProperties.getInstance().getDelftTrainingMaxSequenceLength(modelName)));
×
344
            }
345
            if (GrobidProperties.getInstance().getDelftTrainingBatchSize(modelName) != -1) {
×
346
                command.add("--batch-size");
×
347
                command.add(String.valueOf(GrobidProperties.getInstance().getDelftTrainingBatchSize(modelName)));
×
348
            }
349
            if (incremental) {
×
350
                command.add("--incremental");
×
351

352
                // if incremental training, we need to load the existing model
353
                File modelPath = GrobidProperties.getInstance().getModelPath();
×
354
                if (modelPath != null && 
×
355
                    modelPath.exists() &&
×
356
                    modelPath.isDirectory()) {
×
357
                    command.add("--input-model");
×
358
                    command.add(GrobidProperties.getInstance().getModelPath().getAbsolutePath());
×
359
                } else {
360
                    throw new GrobidException("the path to the model to be used for starting incremental training is invalid: " +
×
361
                        GrobidProperties.getInstance().getModelPath().getAbsolutePath());
×
362
                }
363
            }
364
            ProcessBuilder pb = new ProcessBuilder(command);
×
365
            File delftPath = new File(GrobidProperties.getInstance().getDeLFTFilePath());
×
366
            pb.directory(delftPath);
×
367
            Process process = pb.start(); 
×
368
            //pb.inheritIO();
369
            CustomStreamGobbler customStreamGobbler = 
×
370
                new CustomStreamGobbler(process.getInputStream(), System.out);
×
371
            Executors.newSingleThreadExecutor().submit(customStreamGobbler);
×
372
            SimpleStreamGobbler streamGobbler = new SimpleStreamGobbler(process.getErrorStream(), System.err::println);
×
373
            Executors.newSingleThreadExecutor().submit(streamGobbler);
×
374
            int exitCode = process.waitFor();
×
375
            //assert exitCode == 0;
376
        } catch(IOException e) {
×
377
            LOGGER.error("IO error when training DeLFT model " + modelName, e);
×
378
        } catch(InterruptedException e) {
×
379
            LOGGER.error("Train DeLFT model " + modelName + " task failed", e);
×
380
        } catch(GrobidException e) {
×
381
            LOGGER.error("GROBID call to DeLFT training via JEP failed", e);
×
382
        } 
×
383
    }
×
384

385
    public synchronized void close() {
386
        try {
387
            LOGGER.info("Close DeLFT model " + this.modelName + "...");
×
388
            JEPThreadPool.getInstance().run(new CloseModel(this.modelName));
×
389
        } catch(InterruptedException e) {
×
390
            LOGGER.error("Close DeLFT model " + this.modelName + " task failed", e);
×
391
        }
×
392
    }
×
393

394
    private class CloseModel implements Runnable { 
395
        private String modelName;
396
          
397
        public CloseModel(String modelName) { 
×
398
            this.modelName = modelName;
×
399
        } 
×
400
          
401
        @Override
402
        public void run() { 
403
            Jep jep = JEPThreadPool.getInstance().getJEPInstance(); 
×
404
            try { 
405
                jep.eval("del "+this.modelName);
×
406
            } catch(JepException e) {
×
407
                LOGGER.error("Closing DeLFT model failed", e);
×
408
            } 
×
409
        } 
×
410
    }
411

412
    private static String delft2grobidLabel(String label) {
413
        if (label.equals(TaggingLabels.IOB_OTHER_LABEL)) {
×
414
            label = TaggingLabels.OTHER_LABEL;
×
415
        } else if (label.startsWith(TaggingLabels.IOB_START_ENTITY_LABEL_PREFIX)) {
×
416
            label = label.replace(TaggingLabels.IOB_START_ENTITY_LABEL_PREFIX, TaggingLabels.GROBID_START_ENTITY_LABEL_PREFIX);
×
417
        } else if (label.startsWith(TaggingLabels.IOB_INSIDE_LABEL_PREFIX)) {
×
418
            label = label.replace(TaggingLabels.IOB_INSIDE_LABEL_PREFIX, TaggingLabels.GROBID_INSIDE_ENTITY_LABEL_PREFIX);
×
419
        } 
420
        return label;
×
421
    }
422

423
    private static class SimpleStreamGobbler implements Runnable {
424
        private InputStream inputStream;
425
        private Consumer<String> consumer;
426
     
427
        public SimpleStreamGobbler(InputStream inputStream, Consumer<String> consumer) {
×
428
            this.inputStream = inputStream;
×
429
            this.consumer = consumer;
×
430
        }
×
431
     
432
        @Override
433
        public void run() {
434
            new BufferedReader(new InputStreamReader(inputStream)).lines()
×
435
              .forEach(consumer);
×
436
        }
×
437
    }
438

439
    /**
440
     * This is a custom gobbler that reproduces correctly the Keras training progress bar
441
     * by injecting a \r for progress line updates. 
442
     */ 
443
    private static class CustomStreamGobbler implements Runnable {
444
        public static final Logger LOGGER = LoggerFactory.getLogger(CustomStreamGobbler.class);
×
445

446
        private final InputStream is;
447
        private final PrintStream os;
448
        private Pattern pattern = Pattern.compile("\\d/\\d+ \\[");
×
449

450
        public CustomStreamGobbler(InputStream is, PrintStream os) {
×
451
            this.is = is;
×
452
            this.os = os;
×
453
        }
×
454
     
455
        @Override
456
        public void run() {
457
            try {
458
                InputStreamReader isr = new InputStreamReader(this.is);
×
459
                BufferedReader br = new BufferedReader(isr);
×
460
                String line = null;
×
461
                while ((line = br.readLine()) != null) {
×
462
                    Matcher matcher = pattern.matcher(line);
×
463
                    if (matcher.find()) {
×
464
                        os.print("\r" + line);
×
465
                        os.flush();
×
466
                    } else {
467
                        os.println(line);
×
468
                    }
469
                }
×
470
            }
471
            catch (IOException e) {
×
472
                LOGGER.warn("IO error between embedded python and java process", e);
×
473
            }
×
474
        }
×
475
    }
476

477
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc