• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohnSnowLabs / spark-nlp / 11429325160

20 Oct 2024 08:18PM UTC coverage: 60.052% (-0.2%) from 60.216%
11429325160

Pull #14439

github

web-flow
Merge 1c191569d into 9db33328b
Pull Request #14439: [SPARKNLP-1067] PromptAssembler

0 of 50 new or added lines in 2 files covered. (0.0%)

48 existing lines in 26 files now uncovered.

8985 of 14962 relevant lines covered (60.05%)

0.6 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/main/scala/com/johnsnowlabs/nlp/PromptAssembler.scala
1
package com.johnsnowlabs.nlp
2

3
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
4
import com.johnsnowlabs.nlp.llama.LlamaModel
5
import org.apache.spark.ml.Transformer
6
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
7
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
8
import org.apache.spark.sql.expressions.UserDefinedFunction
9
import org.apache.spark.sql.functions.udf
10
import org.apache.spark.sql.types._
11
import org.apache.spark.sql.{Column, DataFrame, Dataset}
12
import org.apache.spark.sql.types.StructType
13

14
/** Assembles a sequence of messages into a single string using a template. These strings can then
15
  * be used as prompts for large language models.
16
  *
17
  * This annotator expects an array of two-tuples as the type of the input column (one array of
18
  * tuples per row). The first element of the tuples should be the role and the second element is
19
  * the text of the message. Possible roles are "system", "user" and "assistant".
20
  *
21
  * An assistant header is appended to the end of the generated string by default, because
22
  * `addAssistant` defaults to `true`; call `setAddAssistant(false)` to omit it.
23
  *
24
  * At the moment, this annotator uses llama.cpp as a backend to parse and apply the templates.
25
  * llama.cpp uses basic pattern matching to determine the type of the template, then applies a
26
  * basic version of the template to the messages. This means that more advanced templates are not
27
  * supported.
28
  *
29
  * For an extended example see the
30
  * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/llama.cpp/PromptAssember_with_AutoGGUFModel.ipynb example notebook]].
31
  *
32
  * ==Example==
33
  * {{{
34
  * // Batches (whole conversations) of arrays of messages
35
  * val data: Seq[Seq[(String, String)]] = Seq(
36
  *   Seq(
37
  *     ("system", "You are a helpful assistant."),
38
  *     ("assistant", "Hello there, how can I help you?"),
39
  *     ("user", "I need help with organizing my room.")))
40
  *
41
  * val dataDF = data.toDF("messages")
42
  *
43
  * // llama3.1
44
  * val template =
45
  *   "{{- bos_token }} {%- if custom_tools is defined %} {%- set tools = custom_tools %} {%- " +
46
  *     "endif %} {%- if not tools_in_user_message is defined %} {%- set tools_in_user_message = true %} {%- " +
47
  *     "endif %} {%- if not date_string is defined %} {%- set date_string = \"26 Jul 2024\" %} {%- endif %} " +
48
  *     "{%- if not tools is defined %} {%- set tools = none %} {%- endif %} {#- This block extracts the " +
49
  *     "system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %}" +
50
  *     " {%- set system_message = messages[0]['content']|trim %} {%- set messages = messages[1:] %} {%- else" +
51
  *     " %} {%- set system_message = \"\" %} {%- endif %} {#- System message + builtin tools #} {{- " +
52
  *     "\"<|start_header_id|>system<|end_header_id|>\\n\\n\" }} {%- if builtin_tools is defined or tools is " +
53
  *     "not none %} {{- \"Environment: ipython\\n\" }} {%- endif %} {%- if builtin_tools is defined %} {{- " +
54
  *     "\"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}} " +
55
  *     "{%- endif %} {{- \"Cutting Knowledge Date: December 2023\\n\" }} {{- \"Today Date: \" + date_string " +
56
  *     "+ \"\\n\\n\" }} {%- if tools is not none and not tools_in_user_message %} {{- \"You have access to " +
57
  *     "the following functions. To call a function, please respond with JSON for a function call.\" }} {{- " +
58
  *     "'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its" +
59
  *     " value}.' }} {{- \"Do not use variables.\\n\\n\" }} {%- for t in tools %} {{- t | tojson(indent=4) " +
60
  *     "}} {{- \"\\n\\n\" }} {%- endfor %} {%- endif %} {{- system_message }} {{- \"<|eot_id|>\" }} {#- " +
61
  *     "Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message " +
62
  *     "and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if " +
63
  *     "messages | length != 0 %} {%- set first_user_message = messages[0]['content']|trim %} {%- set " +
64
  *     "messages = messages[1:] %} {%- else %} {{- raise_exception(\"Cannot put tools in the first user " +
65
  *     "message when there's no first user message!\") }} {%- endif %} {{- " +
66
  *     "'<|start_header_id|>user<|end_header_id|>\\n\\n' -}} {{- \"Given the following functions, please " +
67
  *     "respond with a JSON for a function call \" }} {{- \"with its proper arguments that best answers the " +
68
  *     "given prompt.\\n\\n\" }} {{- 'Respond in the format {\"name\": function name, \"parameters\": " +
69
  *     "dictionary of argument name and its value}.' }} {{- \"Do not use variables.\\n\\n\" }} {%- for t in " +
70
  *     "tools %} {{- t | tojson(indent=4) }} {{- \"\\n\\n\" }} {%- endfor %} {{- first_user_message + " +
71
  *     "\"<|eot_id|>\"}} {%- endif %} {%- for message in messages %} {%- if not (message.role == 'ipython' " +
72
  *     "or message.role == 'tool' or 'tool_calls' in message) %} {{- '<|start_header_id|>' + message['role']" +
73
  *     " + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }} {%- elif 'tool_calls' in " +
74
  *     "message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception(\"This model only " +
75
  *     "supports single tool-calls at once!\") }} {%- endif %} {%- set tool_call = message.tool_calls[0]" +
76
  *     ".function %} {%- if builtin_tools is defined and tool_call.name in builtin_tools %} {{- " +
77
  *     "'<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- \"<|python_tag|>\" + tool_call.name + " +
78
  *     "\".call(\" }} {%- for arg_name, arg_val in tool_call.arguments | items %} {{- arg_name + '=\"' + " +
79
  *     "arg_val + '\"' }} {%- if not loop.last %} {{- \", \" }} {%- endif %} {%- endfor %} {{- \")\" }} {%- " +
80
  *     "else %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}} {{- '{\"name\": \"' + " +
81
  *     "tool_call.name + '\", ' }} {{- '\"parameters\": ' }} {{- tool_call.arguments | tojson }} {{- \"}\" " +
82
  *     "}} {%- endif %} {%- if builtin_tools is defined %} {#- This means we're in ipython mode #} {{- " +
83
  *     "\"<|eom_id|>\" }} {%- else %} {{- \"<|eot_id|>\" }} {%- endif %} {%- elif message.role == \"tool\" " +
84
  *     "or message.role == \"ipython\" %} {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }} {%- " +
85
  *     "if message.content is mapping or message.content is iterable %} {{- message.content | tojson }} {%- " +
86
  *     "else %} {{- message.content }} {%- endif %} {{- \"<|eot_id|>\" }} {%- endif %} {%- endfor %} {%- if " +
87
  *     "add_generation_prompt %} {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }} {%- endif %} "
88
  *
89
  * val promptAssembler = new PromptAssembler()
90
  *   .setInputCol("messages")
91
  *   .setOutputCol("prompt")
92
  *   .setChatTemplate(template)
93
  *
94
  * promptAssembler.transform(dataDF).select("prompt.result").show(truncate = false)
95
  * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
96
  * |result                                                                                                                                                                                                                                                                                                                      |
97
  * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
98
  * |[<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello there, how can I help you?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI need help with organizing my room.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n]|
99
  * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
100
  *
101
  * }}}
102
  *
103
  * @param uid
104
  *   required uid for storing annotator to disk
105
  * @groupname anno Annotator types
106
  * @groupdesc anno
107
  *   Required input and expected output annotator types
108
  * @groupname Ungrouped Members
109
  * @groupname param Parameters
110
  * @groupname setParam Parameter setters
111
  * @groupname getParam Parameter getters
112
  * @groupname Ungrouped Members
113
  * @groupprio param  1
114
  * @groupprio anno  2
115
  * @groupprio Ungrouped 3
116
  * @groupprio setParam  4
117
  * @groupprio getParam  5
118
  * @groupdesc param
119
  *   A list of (hyper-)parameter keys this annotator can take. Users can set and get the
120
  *   parameter values through setters and getters, respectively.
121
  */
122
class PromptAssembler(override val uid: String)
    extends Transformer
    with DefaultParamsWritable
    with HasOutputAnnotatorType
    with HasOutputAnnotationCol {

  /** Annotator type written into the output column's metadata. */
  override val outputAnnotatorType: AnnotatorType = DOCUMENT

  def this() = this(Identifiable.randomUID("PROMPT_ASSEMBLER"))

  /** Chat template handed to llama.cpp. It has no default value and must be set before calling
    * `transform`.
    *
    * @group param
    */
  val chatTemplate: Param[String] =
    new Param[String](this, "chatTemplate", "Template used for the chat")

  /** Name of the column holding the messages of each row, as an array of `(role, text)` tuples.
    *
    * @group param
    */
  val inputCol: Param[String] =
    new Param[String](this, "inputCol", "Input column containing a sequence of messages")

  /** Whether to add an assistant header to the end of the generated string (default: `true`).
    *
    * @group param
    */
  val addAssistant: BooleanParam =
    new BooleanParam(
      this,
      "addAssistant",
      "Whether to add an assistant header to the end of the generated string")

  setDefault(addAssistant -> true)

  /** Sets the name of the input column holding the messages (an array of `(role, text)` tuples).
    *
    * @group setParam
    */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** Gets the name of the input column holding the messages.
    *
    * @group getParam
    */
  def getInputCol: String = $(inputCol)

  /** Sets the chat template to be used for the chat. Should be something that llama.cpp can
    * parse.
    *
    * @param value
    *   The template to use
    * @group setParam
    */
  def setChatTemplate(value: String): this.type = set(chatTemplate, value)

  /** Gets the chat template to be used for the chat.
    *
    * @return
    *   The template to use
    * @group getParam
    */
  def getChatTemplate: String = $(chatTemplate)

  /** Whether to add an assistant header to the end of the generated string.
    *
    * @param value
    *   Whether to add the assistant header
    * @group setParam
    */
  def setAddAssistant(value: Boolean): this.type = set(addAssistant, value)

  /** Whether to add an assistant header to the end of the generated string.
    *
    * @return
    *   Whether to add the assistant header
    * @group getParam
    */
  def getAddAssistant: Boolean = $(addAssistant)

  // Canonical input type (what `Seq[Seq[(String, String)]]#toDF` produces). It is only used in
  // error messages; `validateInputColumn` accepts the same shape regardless of nullability flags.
  private val expectedInputType = ArrayType(
    StructType(
      Seq(
        StructField("_1", StringType, nullable = true),
        StructField("_2", StringType, nullable = true))),
    containsNull = true)

  /** Validates the input column and adds the result Annotation type to the schema.
    *
    * Requirement for pipeline transformation validation. It is called on fit()
    *
    * @throws IllegalArgumentException
    *   if the input column is missing or does not hold an array of `(String, String)` tuples
    */
  override final def transformSchema(schema: StructType): StructType = {
    validateInputColumn(schema)
    val outputField = StructField(
      getOutputCol,
      ArrayType(Annotation.dataType),
      nullable = false,
      outputMetadata)
    StructType(schema.fields :+ outputField)
  }

  /** Applies the chat template to every row of the input column and writes the resulting prompt
    * as a single `DOCUMENT` annotation into the output column.
    *
    * Rows whose message array is null, or for which the template cannot be applied, get an empty
    * annotation array.
    *
    * @throws IllegalArgumentException
    *   if the input column is missing or mistyped, or if `chatTemplate` has not been set
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    validateInputColumn(dataset.schema)
    // Fail fast: an unset template would otherwise surface inside the UDF and be swallowed,
    // silently producing empty output for every row.
    require(
      isDefined(chatTemplate),
      s"Parameter 'chatTemplate' of $uid must be set before calling transform")

    // Resolve params on the driver so the UDF closure captures plain values, not `this`.
    val prompts: Column =
      applyTemplate(getChatTemplate, getAddAssistant)(dataset.col(getInputCol))

    dataset.withColumn(getOutputCol, prompts.as(getOutputCol, outputMetadata))
  }

  /** Checks that `schema` has the input column and that it is an array of structs with String
    * fields `_1` (role) and `_2` (text). Nullability flags and field metadata are not compared.
    */
  private def validateInputColumn(schema: StructType): Unit = {
    val columnDataType = schema.fields
      .find(_.name == getInputCol)
      .getOrElse(
        throw new IllegalArgumentException(s"Dataset does not have any '$getInputCol' column"))
      .dataType

    val isSupported = columnDataType match {
      case ArrayType(
            StructType(
              Array(StructField("_1", StringType, _, _), StructField("_2", StringType, _, _))),
            _) =>
        true
      case _ => false
    }

    if (!isSupported)
      throw new IllegalArgumentException(
        s"Column '$getInputCol' must be of type Array[(String, String)] " +
          s"(an array of structs with String fields '_1' and '_2', e.g. '$expectedInputType'), " +
          s"but got $columnDataType")
  }

  /** Column metadata tagging the output column with this annotator's output type. */
  private def outputMetadata =
    new MetadataBuilder().putString("annotatorType", outputAnnotatorType).build()

  /** Builds the UDF that renders one row's messages into a prompt string.
    *
    * @param template
    *   chat template forwarded to `LlamaModel.applyChatTemplate`
    * @param addAssistantHeader
    *   whether an assistant header is appended after the last message
    */
  private def applyTemplate(template: String, addAssistantHeader: Boolean): UserDefinedFunction = {
    import scala.util.control.NonFatal

    udf { chat: Seq[(String, String)] =>
      if (chat == null) Seq.empty[Annotation] // null row: no annotation
      else
        try {
          val messages: Array[Array[String]] =
            chat.map { case (role, text) => Array(role, text) }.toArray
          Seq(Annotation(LlamaModel.applyChatTemplate(template, messages, addAssistantHeader)))
        } catch {
          // Best effort per row: a null message (MatchError above) or any non-fatal error from
          // LlamaModel.applyChatTemplate yields an empty annotation array for this row.
          case NonFatal(_) => Seq.empty[Annotation]
        }
    }
  }

  override def copy(extra: ParamMap): PromptAssembler = defaultCopy(extra)
}
246

247
/** Companion of [[PromptAssembler]]. It exposes the default Spark ML reader so that a saved
  * [[PromptAssembler]] can be loaded back. See the class documentation for usage details.
  */
object PromptAssembler extends DefaultParamsReadable[PromptAssembler] {}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc