JohnSnowLabs / spark-nlp · build 4992350528 (pending completion)

Pull Request #13797 (GitHub): SPARKNLP-835: ProtectedParam and ProtectedFeature
Merge 424c7ff18 into ef7906c5e

24 of 24 new or added lines in 6 files covered. (100.0%)

8643 of 13129 relevant lines covered (65.83%)

0.66 hits per line

Source File (51.15% covered)
/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala
/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.serialization

import com.github.liblevenshtein.serialization.PlainTextSerializer
import com.johnsnowlabs.nlp.HasFeatures
import com.johnsnowlabs.nlp.annotators.spell.context.parser.VocabParser
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

import scala.reflect.ClassTag

abstract class Feature[Serializable1, Serializable2, TComplete: ClassTag](
    model: HasFeatures,
    val name: String)
    extends Serializable {
  model.features.append(this)

  private val spark: SparkSession = ResourceHelper.spark

  val serializationMode: String =
    ConfigLoader.getConfigStringValue(ConfigHelper.serializationMode)
  val useBroadcast: Boolean = ConfigLoader.getConfigBooleanValue(ConfigHelper.useBroadcast)
  final protected var broadcastValue: Option[Broadcast[TComplete]] = None

  final protected var rawValue: Option[TComplete] = None
  final protected var fallbackRawValue: Option[TComplete] = None

  final protected var fallbackLazyValue: Option[() => TComplete] = None
  final protected var isProtected: Boolean = false

  final def serialize(
      spark: SparkSession,
      path: String,
      field: String,
      value: TComplete): Unit = {
    serializationMode match {
      case "dataset" => serializeDataset(spark, path, field, value)
      case "object" => serializeObject(spark, path, field, value)
      case _ =>
        throw new IllegalArgumentException(
          "Illegal performance.serialization setting. Can be 'dataset' or 'object'")
    }
  }

  final def serializeInfer(spark: SparkSession, path: String, field: String, value: Any): Unit =
    serialize(spark, path, field, value.asInstanceOf[TComplete])

  final def deserialize(spark: SparkSession, path: String, field: String): Option[_] = {
    if (broadcastValue.isDefined || rawValue.isDefined)
      throw new Exception(
        s"Trying to deserialize an already set value for ${this.name}. This should not happen.")
    serializationMode match {
      case "dataset" => deserializeDataset(spark, path, field)
      case "object" => deserializeObject(spark, path, field)
      case _ =>
        throw new IllegalArgumentException(
          "Illegal performance.serialization setting. Can be 'dataset' or 'object'")
    }
  }

  protected def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: TComplete): Unit

  protected def deserializeDataset(spark: SparkSession, path: String, field: String): Option[_]

  protected def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: TComplete): Unit

  protected def deserializeObject(spark: SparkSession, path: String, field: String): Option[_]

  final protected def getFieldPath(path: String, field: String): Path =
    Path.mergePaths(new Path(path), new Path("/fields/" + field))

  private def callAndSetFallback: Option[TComplete] = {
    fallbackRawValue = fallbackLazyValue.map(_())
    fallbackRawValue
  }

  final def get: Option[TComplete] = {
    broadcastValue.map(_.value).orElse(rawValue)
  }

  final def orDefault: Option[TComplete] = {
    broadcastValue
      .map(_.value)
      .orElse(rawValue)
      .orElse(fallbackRawValue)
      .orElse(callAndSetFallback)
  }

  final def getOrDefault: TComplete = {
    orDefault
      .getOrElse(throw new Exception(s"feature $name is not set"))
  }

  final def setValue(value: Option[Any]): HasFeatures = {
    if (isProtected && isSet) {
      val warnString =
        s"Warning: The parameter ${this.name} is protected and can only be set once." +
          " For a pretrained model, this was done during the initialization process." +
          " If you are trying to train your own model, please check the documentation."
      println(warnString)
    } else {
      if (useBroadcast) {
        if (isSet) broadcastValue.get.destroy()
        broadcastValue =
          value.map(v => spark.sparkContext.broadcast[TComplete](v.asInstanceOf[TComplete]))
      } else {
        rawValue = value.map(_.asInstanceOf[TComplete])
      }
    }
    model
  }

  def setFallback(v: Option[() => TComplete]): HasFeatures = {
    fallbackLazyValue = v
    model
  }

  final def isSet: Boolean = {
    broadcastValue.isDefined || rawValue.isDefined
  }

  /** Sets this feature to be protected and only settable once.
    *
    * @return
    *   This Feature
    */
  final def setProtected(): this.type = {
    isProtected = true
    this
  }

}
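
// ---------------------------------------------------------------------------
// Editor's sketch (not part of the original file): a minimal illustration of
// the Feature lifecycle and the protected behaviour added in this PR.
// `ToyModel` and "vocabulary" are hypothetical names; this assumes HasFeatures
// can be mixed in directly, as suggested by `model.features.append(this)`
// above, and that a SparkSession is available for broadcasting.
// ---------------------------------------------------------------------------
object ProtectedFeatureSketch {
  class ToyModel extends HasFeatures {
    // The Feature constructor registers itself in the model's feature list.
    val vocabulary: MapFeature[String, Int] =
      new MapFeature[String, Int](this, "vocabulary").setProtected()
  }

  def demo(): Unit = {
    val model = new ToyModel
    model.vocabulary.setValue(Some(Map("hello" -> 0))) // first set succeeds
    model.vocabulary.setValue(Some(Map("bye" -> 1)))   // protected and already set: warns, keeps first value
    println(model.vocabulary.getOrDefault)             // Map(hello -> 0)
  }
}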

class StructFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TValue, TValue, TValue](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: TValue): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(Seq(value)).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[TValue] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).first)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: TValue): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(Seq(value)).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[TValue] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].first)
    } else {
      None
    }
  }

}
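
// ---------------------------------------------------------------------------
// Editor's sketch (not part of the original file): round-tripping a
// StructFeature, assuming serializationMode resolves to "object". The path
// and field names here are hypothetical; serialize writes under
// "<path>/fields/<field>" as computed by getFieldPath in the base class.
// ---------------------------------------------------------------------------
object StructFeatureRoundTripSketch {
  def demo(model: HasFeatures, spark: SparkSession): Unit = {
    val feature = new StructFeature[String](model, "myField")
    feature.setValue(Some("some value"))
    feature.serializeInfer(spark, "/tmp/toy-model", "myField", feature.getOrDefault)

    // A fresh, unset feature instance can restore the value later; deserialize
    // throws if the feature already holds a value.
    val restored: Option[_] =
      new StructFeature[String](model, "myFieldRestored")
        .deserialize(spark, "/tmp/toy-model", "myField")
    println(restored)
  }
}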

class MapFeature[TKey: ClassTag, TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TKey, TValue, Map[TKey, TValue]](model, name) {

  implicit val encoder: Encoder[(TKey, TValue)] = Encoders.kryo[(TKey, TValue)]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: Map[TKey, TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[Map[TKey, TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[(TKey, TValue)](dataPath.toString).collect.toMap)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: Map[TKey, TValue]): Unit = {
    import spark.implicits._
    val dataPath = getFieldPath(path, field)
    value.toSeq.toDS.write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[Map[TKey, TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[(TKey, TValue)].collect.toMap)
    } else {
      None
    }
  }

}
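
// Editor's note (not part of the original file): in "object" mode the map is
// stored via RDD.saveAsObjectFile (Java serialization of key/value pairs),
// while in "dataset" mode it is written as Parquet through a Kryo-encoded
// Dataset; both layouts land under "<path>/fields/<field>", so the two modes
// are interchangeable only if the same mode is used to read and write.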

class ArrayFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TValue, TValue, Array[TValue]](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: Array[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[Array[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect())
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: Array[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(value).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[Array[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].collect)
    } else {
      None
    }
  }

}

class SetFeature[TValue: ClassTag](model: HasFeatures, override val name: String)
    extends Feature[TValue, TValue, Set[TValue]](model, name) {

  implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      value: Set[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[Set[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect().toSet)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      value: Set[TValue]): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.createDataset(value.toSeq).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[Set[TValue]] = {
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      Some(spark.read.parquet(dataPath.toString).as[TValue].collect.toSet)
    } else {
      None
    }
  }

}

class TransducerFeature(model: HasFeatures, override val name: String)
    extends Feature[VocabParser, VocabParser, VocabParser](model, name) {

  override def serializeObject(
      spark: SparkSession,
      path: String,
      field: String,
      trans: VocabParser): Unit = {
    val dataPath = getFieldPath(path, field)
    spark.sparkContext.parallelize(Seq(trans), 1).saveAsObjectFile(dataPath.toString)
  }

  override def deserializeObject(
      spark: SparkSession,
      path: String,
      field: String): Option[VocabParser] = {
    val serializer = new PlainTextSerializer
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      val sc = spark.sparkContext.objectFile[VocabParser](dataPath.toString).collect().head
      Some(sc)
    } else {
      None
    }
  }

  override def serializeDataset(
      spark: SparkSession,
      path: String,
      field: String,
      trans: VocabParser): Unit = {
    implicit val encoder: Encoder[VocabParser] = Encoders.kryo[VocabParser]
    val serializer = new PlainTextSerializer
    val dataPath = getFieldPath(path, field)
    spark.createDataset(Seq(trans)).write.mode("overwrite").parquet(dataPath.toString)
  }

  override def deserializeDataset(
      spark: SparkSession,
      path: String,
      field: String): Option[VocabParser] = {
    implicit val encoder: Encoder[VocabParser] = Encoders.kryo[VocabParser]
    val uri = new java.net.URI(path.replaceAllLiterally("\\", "/"))
    val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
    val dataPath = getFieldPath(path, field)
    if (fs.exists(dataPath)) {
      val sc = spark.read.parquet(dataPath.toString).as[VocabParser].collect.head
      Some(sc)
    } else {
      None
    }
  }

}
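
// ---------------------------------------------------------------------------
// Editor's sketch (not part of the original file): the lazy-fallback path of
// the base Feature class. A fallback is not a "set" value (isSet stays
// false), but orDefault/getOrDefault evaluate it once and cache the result in
// fallbackRawValue. Names are hypothetical; a SparkSession is assumed.
// ---------------------------------------------------------------------------
object FallbackSketch {
  def demo(model: HasFeatures): Unit = {
    val threshold = new StructFeature[Double](model, "threshold")
    threshold.setFallback(Some(() => 0.5))
    assert(threshold.get.isEmpty)         // no explicit or broadcast value
    assert(!threshold.isSet)              // a fallback does not count as set
    assert(threshold.getOrDefault == 0.5) // served from the evaluated fallback
  }
}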