rendezqueue / rendezllama · build 21364977367 (push, via github)
grencez: Implement batch token generation in Inference class

26 Jan 2026 04:16PM UTC coverage: 87.834% (-0.8% from 88.654%)

72 of 103 new or added lines in 3 files covered (69.9%).
1 existing line in 1 file now uncovered.
2202 of 2507 relevant lines covered (87.83%).
515.98 hits per line.

Source File: /src/language/inference.cc (62.69% covered)

#include "src/language/inference.hh"

#include <algorithm>
#include <cassert>
#include <climits>
#include <cstring>
#include <ctime>
#include <stdexcept>
#include <thread>
#include <vector>

#include <fildesh/fildesh.h>
#include <fildesh/ostream.hh>

#include "src/chat/display.hh"
#include "src/chat/guide.hh"
#include "src/chat/opt.hh"
#include "src/chat/trajectory.hh"
#include "src/language/vocabulary.hh"

using rendezllama::ChatDisplay;
using rendezllama::ChatGuide;
using rendezllama::ChatOptions;
using rendezllama::ChatTrajectory;
using rendezllama::Inference;
using rendezllama::Vocabulary;
using rendezllama::inference::AdjustViaKind;

Inference::Inference(const Vocabulary& vocabulary)
  : vocabulary_(vocabulary)
{}
Inference::~Inference() {
  if (smpl_) {llama_sampler_free(smpl_);}
  llama_batch_free(batch_);
}

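// Returns the antiprompt in `antiprompts` that `text` ends with, or a reference
// to a static empty string when none matches.
// Illustrative example (not from the source): antiprompt_suffix("User:", {":", "\n"}) returns ":".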
  const std::string&
rendezllama::antiprompt_suffix(
    std::string_view text,
    const std::set<std::string>& antiprompts)
{
  static const std::string empty_string;
  for (const std::string& s : antiprompts) {
    if (text.size() >= s.size()) {
      const size_t offset = text.size() - s.size();
      if (0 == memcmp(&text[offset], &s[0], s.size())) {
        return s;
      }
    }
  }
  return empty_string;
}

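// Trims trailing spaces from `s` in place. Returns true if anything was removed.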
static bool maybe_trim_endspace(std::string& s)
{
  bool result = false;
  while (!s.empty() && s.back() == ' ') {
    s.pop_back();
    result = true;
  }
  return result;
}

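// Tokenizes user input onto the trajectory, routing it by its leading characters:
// a literal "\n" escape ends the current turn and begins the last configured
// role's turn, a leading newline comes from the /yield command, a leading space
// continues the current turn, and anything else is appended as a complete
// message for the first role before yielding the turn back. Also reports
// whether a subsequent newline should be suppressed.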
  void
rendezllama::augment_tokenize_chat_input(
    ChatGuide& chat_guide,
    ChatTrajectory& chat_traj,
    bool& prevent_subsequent_newline,
    std::string s,
    const Vocabulary& vocabulary,
    const ChatOptions& opt)
{
  prevent_subsequent_newline = false;
  if (s.size() >= 2 && s[0] == '\\' && s[1] == 'n') {
    chat_guide.end_turn();
    chat_guide.begin_turn(opt.message_opts.size()-1);
    s.erase(0, 2);
    prevent_subsequent_newline = maybe_trim_endspace(s);
    if (opt.message_opts.back().prefix.back() == '\n' && opt.linespace_on) {
      if (!s.empty() && s.front() != ' ') {
        s.insert(0, " ");
      }
    }
    chat_traj.tokenize_append(s, vocabulary);
  }
  else if (s.front() == '\n') {
    // This is from /yield.
    chat_guide.yield_turn(s.substr(1));
  }
  else if (s.front() == ' ') {
    prevent_subsequent_newline = maybe_trim_endspace(s);
    chat_traj.tokenize_append(s, vocabulary);
  }
  else {
    chat_guide.yield_turn(0);
    if (opt.message_opts[0].prefix.back() == '\n' && opt.linespace_on) {
      if (!s.empty() && s.front() != ' ') {
        s.insert(0, " ");
      }
    }
    chat_traj.tokenize_append(s, vocabulary);
    chat_guide.yield_turn();
    chat_traj.display_token_count_ = chat_traj.rfind_message_prefix_begin_at(
        chat_traj.token_count()-1);
    prevent_subsequent_newline = true;
  }
}

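// Loads the model twice: the first load only reads training-time metadata
// (context length and RoPE frequency scale) used to fill in unset options and
// to halve the RoPE frequency scale until the requested context fits; the
// model is then reloaded with auto-tuned parameters and a context is created.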
  std::tuple<struct llama_model*, struct llama_context*>
rendezllama::make_llama_context(rendezllama::ChatOptions& opt)
{
  llama_model_params model_params = llama_model_default_params();
  model_params.use_mlock = opt.mlock_on;
  model_params.use_mmap = opt.mmap_on;

  struct llama_model* model = llama_model_load_from_file(
      opt.model_filename.c_str(), model_params);
  if (!model) {
    fildesh_log_error("Failed to open model.");
    return std::make_tuple(nullptr, nullptr);
  }

  if (opt.model_token_limit == 0) {
    opt.model_token_limit = llama_model_n_ctx_train(model);
  }
  if (opt.context_token_limit == 0) {
    opt.context_token_limit = opt.model_token_limit;
  }
  float rope_freq_scale = llama_model_rope_freq_scale_train(model);
  if (rope_freq_scale <= 0.0) {
    rope_freq_scale = 1.0f;
  }
  while (
      (unsigned)(opt.model_token_limit / rope_freq_scale)
      <
      opt.context_token_limit)
  {
    rope_freq_scale /= 2;
  }
  llama_model_free(model);
  model = nullptr;

  model_params = llama_model_default_params();
  model_params.use_mlock = opt.mlock_on;
  model_params.use_mmap = opt.mmap_on;

  llama_context_params ctx_params = llama_context_default_params();
  ctx_params.n_ctx = opt.context_token_limit;
  ctx_params.n_batch = opt.batch_count;
  ctx_params.rope_freq_scale = rope_freq_scale;

  std::vector<float> tensor_split(llama_max_devices());
  std::vector<llama_model_tensor_buft_override> tensor_buft_overrides(
      llama_max_tensor_buft_overrides());
  std::vector<size_t> margins(llama_max_devices(), 0);

  // Auto-tune parameters if possible (and not manually overridden by user yet).
  // This helps avoid OOM crashes on Vulkan/GPU by fitting layers to available memory.
  auto status = llama_params_fit(
      opt.model_filename.c_str(),
      &model_params,
      &ctx_params,
      tensor_split.data(),
      tensor_buft_overrides.data(),
      margins.data(),
      /*n_ctx_min=*/0,
      GGML_LOG_LEVEL_ERROR);

  if (status != 0) {
    fildesh_log_warning("llama_params_fit failed");
  }

  model = llama_model_load_from_file(
      opt.model_filename.c_str(), model_params);
  if (!model) {
    fildesh_log_error("Failed to open model.");
    return std::make_tuple(nullptr, nullptr);
  }

  struct llama_context* ctx = llama_init_from_model(model, ctx_params);
  if (!ctx) {
    llama_model_free(model);
    fildesh_log_error("Failed to create context.");
    return std::make_tuple(nullptr, nullptr);
  }
  return std::make_tuple(model, ctx);
}

static
  int
new_sampling_seed()
{
  return static_cast<int>(INT_MAX & time(NULL));
}

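// Appends one sampler to `smpl` for the given AdjustVia variant and logs the
// chosen parameters to `eout`.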
static
  void
apply_sampler_chain(
    struct llama_sampler* smpl,
    const rendezllama::inference::AdjustVia& adjust_via,
    const struct llama_model* model,
    unsigned seed,
    std::ostream& eout)
{
  const unsigned keep_one = 1;

  if (const auto* dry = std::get_if<AdjustViaKind::dry>(&adjust_via)) {
    static const char* seq_breakers[] = {
      "\n", ":",
    };
    llama_sampler_chain_add(
        smpl,
        llama_sampler_init_dry(
            llama_model_get_vocab(model),
            llama_model_n_ctx_train(model),
            dry->multiplier,
            dry->base,
            dry->allowed_length,
            dry->window_length,
            seq_breakers,
            sizeof(seq_breakers)/sizeof(*seq_breakers)));
    eout << "dry:"
      << "\n  multiplier: " << dry->multiplier
      << "\n  base: " << dry->base
      << "\n  allowed_length: " << dry->allowed_length
      << "\n  window_length: " << dry->window_length
      << "\n";
  }
  if (const auto* min_p = std::get_if<AdjustViaKind::min_p>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(*min_p, keep_one));
    eout << "min_p: " << *min_p << "\n";
  }
  if (const auto* penalize_with = std::get_if<AdjustViaKind::penalize_with>(&adjust_via)) {
    llama_sampler_chain_add(
        smpl,
        llama_sampler_init_penalties(
            penalize_with->window_length,
            penalize_with->repetition,
            penalize_with->frequency,
            penalize_with->presence));
    eout << "penalties:"
      << "\n  window_length: " << penalize_with->window_length
      << "\n  repetition: " << penalize_with->repetition
      << "\n  frequency: " << penalize_with->frequency
      << "\n  presence: " << penalize_with->presence
      << "\n";
  }
  if (const auto* temperature = std::get_if<AdjustViaKind::temperature>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(*temperature));
    eout << "temperature: " << *temperature << "\n";
  }
  if (const auto* top_k = std::get_if<AdjustViaKind::top_k>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(*top_k));
    eout << "top_k: " << *top_k << "\n";
  }
  if (const auto* top_p = std::get_if<AdjustViaKind::top_p>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(*top_p, keep_one));
    eout << "top_p: " << *top_p << "\n";
  }
  if (const auto* typical_p = std::get_if<AdjustViaKind::typical_p>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_typical(*typical_p, keep_one));
    eout << "typical_p: " << *typical_p << "\n";
  }
  if (const auto* xtc = std::get_if<AdjustViaKind::xtc>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_xtc(xtc->probability, xtc->threshold, keep_one, seed));
    eout << "xtc: "
      << "\n  probability: " << xtc->probability
      << "\n  threshold: " << xtc->threshold
      << "\n";
  }
}

static
  void
adaptive_p_sample(
    struct llama_sampler* smpl,
    const rendezllama::inference::AdaptiveP& adaptive_p,
    unsigned seed)
{
  llama_sampler_chain_add(
      smpl,
      llama_sampler_init_adaptive_p(
          adaptive_p.target,
          adaptive_p.decay,
          seed));
}

static
  void
mirostat_sample(
    struct llama_sampler* smpl,
    const rendezllama::inference::Mirostat& mirostat,
    unsigned seed,
    const rendezllama::Vocabulary& vocabulary)
{
  if (mirostat.version == 1) {
    const int mirostat_m = 100;
    llama_sampler_chain_add(
        smpl,
        llama_sampler_init_mirostat(
            vocabulary.cardinality(), seed,
            mirostat.tau, mirostat.eta, mirostat_m));
  }
  else if (mirostat.version == 2) {
    llama_sampler_chain_add(
        smpl,
        llama_sampler_init_mirostat_v2(
            seed, mirostat.tau, mirostat.eta));
  }
}

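// Chooses generation and batch thread counts when the user leaves them at 0:
// half the hardware threads for generation (all of them on small x86 machines
// with 2-4 threads) and every hardware thread for batch decoding.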
static
  std::tuple<unsigned, unsigned>
infer_thread_counts(const rendezllama::ChatOptions& opt)
{
  unsigned thread_count = opt.thread_count;
  unsigned batch_thread_count = opt.batch_thread_count;
  const unsigned n = std::thread::hardware_concurrency();
  if (thread_count == 0) {
    thread_count = n / 2;
    if (thread_count == 0) {
      thread_count = 1;
    }
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86)
    if (2 <= n && n <= 4) {
      thread_count = n;
    }
#endif
  }
  if (batch_thread_count == 0) {
    batch_thread_count = n;
  }
  return std::make_tuple(thread_count, batch_thread_count);
}

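// Rebuilds the sampler chain from the chat options. Reseeds when retrying
// (an old chain exists) or when no fixed seed was configured, and logs the
// chain configuration to stderr only on the first initialization.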
  void
Inference::reinitialize(const ChatOptions& opt, const struct llama_model* model)
{
  fildesh::ofstream eout("/dev/stderr");

  const auto* sampling = std::get_if<rendezllama::inference::Sampling>(&opt.infer_via);
  assert(sampling);
  auto seed = sampling->seed;
  if (smpl_ || seed < 0) {
    // We're retrying or just don't have a fixed seed, so we should reseed.
    seed = new_sampling_seed();
  }
  std::tie(thread_count_, batch_thread_count_) = infer_thread_counts(opt);
  if (smpl_) {
    llama_sampler_free(smpl_);
    eout.open("/dev/null");
  }
  token_count_ = 0;
  auto smpl_param = llama_sampler_chain_default_params();
  smpl_ = llama_sampler_chain_init(smpl_param);

  for (const auto& adjust_via : sampling->adjust_thru) {
    apply_sampler_chain(smpl_, adjust_via, model, seed, eout);
  }

  if (std::get_if<rendezllama::inference::Probability>(&sampling->pick_via)) {
    llama_sampler_chain_add(smpl_, llama_sampler_init_dist(seed));
  }
  else if (std::get_if<rendezllama::inference::Determinism>(&sampling->pick_via)) {
    llama_sampler_chain_add(smpl_, llama_sampler_init_greedy());
  }
  else if (const auto* adaptive_p = std::get_if<rendezllama::inference::AdaptiveP>(&sampling->pick_via)) {
    adaptive_p_sample(smpl_, *adaptive_p, seed);
  }
  else if (const auto* mirostat = std::get_if<rendezllama::inference::Mirostat>(&sampling->pick_via)) {
    mirostat_sample(smpl_, *mirostat, seed, vocabulary_);
  }
  else {
    fildesh_log_error("Missing pick method? Using greedy.");
    llama_sampler_chain_add(smpl_, llama_sampler_init_greedy());
  }
}

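// Decodes any trajectory tokens that are not yet in the llama context, in
// batches of at most opt.batch_count, and feeds every newly committed token
// to the sampler so its state matches the trajectory. Returns false if
// decoding fails.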
  bool
Inference::commit_to_context(
    struct llama_context* ctx,
    ChatDisplay& chat_disp,
    ChatTrajectory& chat_traj,
    const ChatOptions& opt,
    const llama_model* model)
{
  assert(!chat_traj.erased_since_eval_ ||
         chat_traj.context_token_count_ < chat_traj.token_count());
  if (chat_traj.erased_since_eval_ || !smpl_) {
    this->reinitialize(opt, model);
  }
  if (chat_traj.context_token_count_ == chat_traj.token_count()) {
    return true;
  }

  chat_traj.maybe_rollforget_within_limit(opt.context_token_limit, vocabulary_);

  // Reset thread count just in case the user reconfigured it.
  llama_set_n_threads(ctx, thread_count_, batch_thread_count_);

  // Clear KV cache past current position just in case the user deleted tokens.
  llama_memory_seq_rm(
      llama_get_memory(ctx),
      0, chat_traj.context_token_count_, -1);

  while (chat_traj.context_token_count_ < chat_traj.token_count()) {
    const unsigned n = std::min(
        opt.batch_count,
        chat_traj.token_count() - chat_traj.context_token_count_);

    chat_disp.show_new(chat_traj.context_token_count_ + n, chat_traj, vocabulary_);

    if (!batch_.token || (unsigned)batch_.n_tokens < n) {
      llama_batch_free(batch_);
      unsigned n_alloc = n;
      if (n_alloc < opt.batch_count) {n_alloc = opt.batch_count;}
      batch_ = llama_batch_init(n_alloc, /*embd=*/0, /*n_seq_max=*/1);
    }
    batch_.n_tokens = n;
    for (unsigned i = 0; i < n; ++i) {
      batch_.token[i] = chat_traj.tokens()[chat_traj.context_token_count_ + i];
      batch_.pos[i] = chat_traj.context_token_count_ + i;
      batch_.n_seq_id[i] = 1;
      batch_.seq_id[i][0] = 0;
      batch_.logits[i] = (i == n - 1);
    }

    const int istat = llama_decode(ctx, batch_);

    if (istat != 0) {
      fildesh_log_error("Failed to eval.");
      chat_traj.context_token_count_ = 0;
      return false;
    }
    else {
      chat_traj.context_token_count_ += n;
    }
  }
  assert(chat_traj.context_token_count_ == chat_traj.token_count());
  chat_traj.erased_since_eval_ = false;
  while (token_count_ < chat_traj.token_count()) {
    Vocabulary::Token_id token_id = chat_traj.token_at(token_count_);
    llama_sampler_accept(smpl_, token_id);
    token_count_ += 1;
  }
  return true;
}

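// The two sample_to_trajectory() overloads build a candidate list from the
// context's logits (the most recent logits here, a specific batch index in the
// overload below), run it through the sampler chain, and append the chosen
// token to the trajectory while keeping the sampler state in sync.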
  void
Inference::sample_to_trajectory(
    ChatTrajectory& chat_traj,
    struct llama_context* ctx,
    bool preventing_newline)
{
  float* logits = llama_get_logits(ctx);
  if (preventing_newline) {
    // Demote message-ending tokens by zeroing their logits when requested.
    logits[vocabulary_.eos_token_id()] = 0;
    logits[vocabulary_.newline_token_id()] = 0;
  }

  std::vector<llama_token_data> candidates;
  candidates.resize(vocabulary_.cardinality());
  for (llama_token i = 0; i < (llama_token)candidates.size(); ++i) {
    candidates[i] = llama_token_data{
      i, logits[i], 0.0f,
    };
  }
  logits = NULL;
  llama_token_data_array candidates_data[1] = {{
    candidates.data(),
    candidates.size(),
    /*selected=*/0,
    /*sorted=*/false,
  }};
  llama_sampler_apply(smpl_, candidates_data);
  chat_traj.push_back(candidates[candidates_data->selected].id);
  llama_sampler_accept(smpl_, chat_traj.token());
  token_count_ += 1;
}

  void
Inference::sample_to_trajectory(
    ChatTrajectory& chat_traj,
    struct llama_context* ctx,
    int batch_idx)
{
  float* logits = llama_get_logits_ith(ctx, batch_idx);
  // Note: preventing_newline is not supported here yet; it is usually not
  // needed during speculation verification.

  std::vector<llama_token_data> candidates;
  candidates.resize(vocabulary_.cardinality());
  for (llama_token i = 0; i < (llama_token)candidates.size(); ++i) {
    candidates[i] = llama_token_data{
      i, logits[i], 0.0f,
    };
  }
  logits = NULL;
  llama_token_data_array candidates_data[1] = {{
    candidates.data(),
    candidates.size(),
    /*selected=*/0,
    /*sorted=*/false,
  }};
  llama_sampler_apply(smpl_, candidates_data);
  chat_traj.push_back(candidates[candidates_data->selected].id);
  llama_sampler_accept(smpl_, chat_traj.token());
  token_count_ += 1;
}

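// Prompt-lookup style drafting: scan the history backwards for the most recent
// earlier occurrence of the trailing n-gram and propose the tokens that
// followed it (up to candidate_limit) as draft continuations.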
static
  void
find_ngram_candidates(
    std::vector<llama_token>& candidates,
    const std::vector<llama_token>& tokens,
    unsigned n_gram_len,
    unsigned candidate_limit)
{
  candidates.clear();
  // Need at least one token before the trailing n-gram to search.
  if (tokens.size() <= n_gram_len) {return;}

  // Simple backward search. `i` is the index of the last token of a match
  // candidate. We want
  //   tokens[i - n_gram_len + 1 ... i] == tokens[n - n_gram_len ... n - 1].
  const size_t n = tokens.size();
  for (size_t i = n - 1 - n_gram_len; i > 0 && i + 1 >= n_gram_len; --i) {
    bool match = true;
    for (size_t j = 0; j < n_gram_len; ++j) {
      if (tokens[i - j] != tokens[n - 1 - j]) {
        match = false;
        break;
      }
    }
    if (match) {
      // Found a match ending at `i`; the tokens that followed it become the draft.
      for (size_t k = 0; k < candidate_limit && i + 1 + k < n; ++k) {
        candidates.push_back(tokens[i + 1 + k]);
      }
      if (!candidates.empty()) {return;}
    }
  }
}

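// Generates up to n_tokens new tokens. Each iteration decodes the latest
// accepted token together with a short n-gram draft in a single batch, then
// verifies the draft against what the sampler actually picks: matching draft
// tokens are accepted for free, and on the first mismatch the stale KV-cache
// entries for the rejected draft positions are removed. This is a simple
// self-speculative (prompt-lookup style) decoding loop and needs no draft model.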
  bool
Inference::generate_next_tokens(
    struct llama_context* ctx,
    ChatDisplay& chat_disp,
    ChatTrajectory& chat_traj,
    const ChatOptions& opt,
    const llama_model* model,
    unsigned n_tokens)
{
  if (!this->commit_to_context(ctx, chat_disp, chat_traj, opt, model)) {
    return false;
  }

  std::vector<llama_token> draft_candidates;

  for (unsigned i = 0; i < n_tokens; ++i) {
    // We start each iteration with a token that has been sampled and accepted
    // but not yet decoded. The previous iteration (or commit_to_context())
    // left us in this state; chat_traj.token() is that last token T.

    // Draft candidates.
    draft_candidates.clear();
    const unsigned kDraftMax = 5;
    // Only draft if we have enough context and batch space.
    if (chat_traj.context_token_count_ + kDraftMax + 1 < opt.context_token_limit &&
        opt.batch_count >= kDraftMax + 1)
    {
      find_ngram_candidates(draft_candidates, chat_traj.tokens(), 2, kDraftMax);
    }

    // Ensure batch capacity. We assume batch_ already holds at least
    // opt.batch_count tokens (allocated by commit_to_context() or a previous
    // init), so reallocate only if batch_ is null or that bound is exceeded.
    unsigned required_batch = 1 + draft_candidates.size();
    if (!batch_.token || opt.batch_count < required_batch) {
      if (batch_.token) {llama_batch_free(batch_);}
      unsigned n_alloc = std::max(opt.batch_count, required_batch);
      batch_ = llama_batch_init(n_alloc, /*embd=*/0, /*n_seq_max=*/1);
    }

    // Prepare batch: [T, C1, C2, ...].
    batch_.n_tokens = required_batch;
    batch_.token[0] = chat_traj.token();
    batch_.pos[0] = chat_traj.context_token_count_;
    batch_.n_seq_id[0] = 1;
    batch_.seq_id[0][0] = 0;
    batch_.logits[0] = true;

    for (size_t k = 0; k < draft_candidates.size(); ++k) {
      batch_.token[k+1] = draft_candidates[k];
      batch_.pos[k+1] = chat_traj.context_token_count_ + 1 + k;
      batch_.n_seq_id[k+1] = 1;
      batch_.seq_id[k+1][0] = 0;
      batch_.logits[k+1] = true;
    }

    if (llama_decode(ctx, batch_) != 0) {
      fildesh_log_error("Failed to eval.");
      return false;
    }

    // T is definitely decoded now.
    chat_traj.context_token_count_ += 1;

    // Verify the draft.
    bool divergence_found = false;
    for (size_t k = 0; k < draft_candidates.size(); ++k) {
      // Sample R from the logits of batch index k (whose input was T or C(k-1)).
      // We expect R == C(k). sample_to_trajectory() pushes R onto chat_traj.
      this->sample_to_trajectory(chat_traj, ctx, (int)k);

      if (chat_traj.token() == draft_candidates[k]) {
        // Match: C(k) is confirmed as decoded.
        chat_traj.context_token_count_ += 1;
        i += 1;  // We advanced one extra step in the main generation request.
        if (chat_traj.token() == vocabulary_.eos_token_id()) {
          chat_disp.show_new(chat_traj, vocabulary_);
          return true;
        }
      } else {
        // Divergence: R != C(k). We accepted R, but C(k) and everything after
        // it must be removed from the KV cache. C(k) was at batch_.pos[k+1].
        llama_memory_seq_rm(llama_get_memory(ctx), 0, batch_.pos[k+1], -1);
        divergence_found = true;
        break;
      }
    }

    if (!divergence_found) {
      // All candidates matched. We still have the output from the last
      // candidate (batch index draft_candidates.size()); use it to sample the
      // next token, which is not decoded yet.
      this->sample_to_trajectory(chat_traj, ctx, (int)draft_candidates.size());
    }
    // Otherwise the divergent token R is the new undecoded tail and the stale
    // KV entries have already been removed, so we are ready for the next
    // iteration either way.

    if (chat_traj.token() == vocabulary_.eos_token_id()) {
      chat_disp.show_new(chat_traj, vocabulary_);
      return true;
    }

    chat_disp.show_new(chat_traj, vocabulary_);

    if (chat_traj.context_token_count_ >= opt.context_token_limit) {
      // Sync if the context is full.
      if (!this->commit_to_context(ctx, chat_disp, chat_traj, opt, model)) {
        return false;
      }
    }
  }
  return true;
}