rendezqueue / rendezllama / 21178347239

20 Jan 2026 04:02PM UTC coverage: 86.957% (-3.8%) from 90.785%
Build 21178347239 (push, github)
grencez: "qual: Do not reinitialize sampling for each token"

1 of 3 new or added lines in 1 file covered. (33.33%)
85 existing lines in 1 file now uncovered.
1980 of 2277 relevant lines covered (86.96%)
118.41 hits per line

Source File: /src/language/inference.cc (38.84%)
#include "src/language/inference.hh"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <thread>

#include <fildesh/fildesh.h>
#include <fildesh/ostream.hh>

#include "src/chat/display.hh"
#include "src/chat/guide.hh"
#include "src/chat/opt.hh"
#include "src/chat/trajectory.hh"
#include "src/language/vocabulary.hh"

using rendezllama::ChatDisplay;
using rendezllama::ChatGuide;
using rendezllama::ChatOptions;
using rendezllama::ChatTrajectory;
using rendezllama::Inference;
using rendezllama::Vocabulary;
using rendezllama::inference::AdjustViaKind;

Inference::Inference(const Vocabulary& vocabulary)
  : vocabulary_(vocabulary)
{}
Inference::~Inference() {
  if (smpl_) {llama_sampler_free(smpl_);}
}

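// Returns the antiprompt that `text` ends with, or a reference to a static
// empty string when no antiprompt matches.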
  const std::string&
rendezllama::antiprompt_suffix(
    std::string_view text,
    const std::set<std::string>& antiprompts)
{
  static const std::string empty_string;
  for (const std::string& s : antiprompts) {
    if (text.size() >= s.size()) {
      const size_t offset = text.size() - s.size();
      if (0 == memcmp(&text[offset], &s[0], s.size())) {
        return s;
      }
    }
  }
  return empty_string;
}

static bool maybe_trim_endspace(std::string& s)
{
  bool result = false;
  while (!s.empty() && s.back() == ' ') {
    s.pop_back();
    result = true;
  }
  return result;
}

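// Routes one line of user input into the chat trajectory based on its leading
// characters: a literal "\n" escape ends the current turn and begins the last
// configured role's turn, a leading newline comes from the /yield command, a
// leading space simply continues the current turn, and anything else is
// appended as a message for role 0 before yielding the turn back.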
  void
rendezllama::augment_tokenize_chat_input(
    ChatGuide& chat_guide,
    ChatTrajectory& chat_traj,
    bool& prevent_subsequent_newline,
    std::string s,
    const Vocabulary& vocabulary,
    const ChatOptions& opt)
{
  prevent_subsequent_newline = false;
  if (s.size() >= 2 && s[0] == '\\' && s[1] == 'n') {
    chat_guide.end_turn();
    chat_guide.begin_turn(opt.message_opts.size()-1);
    s.erase(0, 2);
    prevent_subsequent_newline = maybe_trim_endspace(s);
    if (opt.message_opts.back().prefix.back() == '\n' && opt.linespace_on) {
      if (!s.empty() && s.front() != ' ') {
        s.insert(0, " ");
      }
    }
    chat_traj.tokenize_append(s, vocabulary);
  }
  else if (s.front() == '\n') {
    // This is from /yield.
    chat_guide.yield_turn(s.substr(1));
  }
  else if (s.front() == ' ') {
    prevent_subsequent_newline = maybe_trim_endspace(s);
    chat_traj.tokenize_append(s, vocabulary);
  }
  else {
    chat_guide.yield_turn(0);
    if (opt.message_opts[0].prefix.back() == '\n' && opt.linespace_on) {
      if (!s.empty() && s.front() != ' ') {
        s.insert(0, " ");
      }
    }
    chat_traj.tokenize_append(s, vocabulary);
    chat_guide.yield_turn();
    chat_traj.display_token_count_ = chat_traj.rfind_message_prefix_begin_at(
        chat_traj.token_count()-1);
    prevent_subsequent_newline = true;
  }
}

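// Loads the model and creates a llama.cpp context from the chat options.
// Unset token limits default to the model's training context size, and
// rope_freq_scale is halved until the scaled training window covers the
// requested context length.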
  std::tuple<struct llama_model*, struct llama_context*>
rendezllama::make_llama_context(rendezllama::ChatOptions& opt)
{
  llama_model_params model_params = llama_model_default_params();
  model_params.use_mlock = opt.mlock_on;
  model_params.use_mmap = opt.mmap_on;

  struct llama_model* model = llama_model_load_from_file(
      opt.model_filename.c_str(), model_params);
  if (!model) {
    fildesh_log_error("Failed to open model.");
    return std::make_tuple(nullptr, nullptr);
  }

  if (opt.model_token_limit == 0) {
    opt.model_token_limit = llama_model_n_ctx_train(model);
  }
  if (opt.context_token_limit == 0) {
    opt.context_token_limit = opt.model_token_limit;
  }

  model_params = llama_model_default_params();
  model_params.use_mlock = opt.mlock_on;
  model_params.use_mmap = opt.mmap_on;

  llama_context_params ctx_params = llama_context_default_params();
  ctx_params.n_ctx = opt.context_token_limit;
  ctx_params.n_threads = opt.thread_count;
  ctx_params.n_batch = opt.batch_count;
  ctx_params.rope_freq_scale = llama_model_rope_freq_scale_train(model);
  assert(ctx_params.rope_freq_scale > 0.0);
  while (
      (unsigned)(opt.model_token_limit / ctx_params.rope_freq_scale)
      <
      opt.context_token_limit)
  {
    ctx_params.rope_freq_scale /= 2;
  }

  struct llama_context* ctx = llama_init_from_model(model, ctx_params);
  if (!ctx) {
    llama_model_free(model);
    fildesh_log_error("Failed to create context.");
    return std::make_tuple(nullptr, nullptr);
  }
  return std::make_tuple(model, ctx);
}

static
  int
new_sampling_seed()
{
  return static_cast<int>(INT_MAX & time(NULL));
}

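// Translates one configured adjustment (dry, min_p, penalties, temperature,
// top_k, top_p, typical_p, or xtc) into a sampler on the llama.cpp chain and
// echoes the chosen parameter values to `eout`.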
static
  void
apply_sampler_chain(
    struct llama_sampler* smpl,
    const rendezllama::inference::AdjustVia& adjust_via,
    const struct llama_model* model,
    unsigned seed,
    std::ostream& eout)
{
  const unsigned keep_one = 1;

  if (const auto* dry = std::get_if<AdjustViaKind::dry>(&adjust_via)) {
    static const char* seq_breakers[] = {
      "\n", ":",
    };
    llama_sampler_chain_add(smpl, llama_sampler_init_dry(
        llama_model_get_vocab(model),
        llama_model_n_ctx_train(model),
        dry->multiplier,
        dry->base,
        dry->allowed_length,
        dry->window_length,
        seq_breakers,
        sizeof(seq_breakers)/sizeof(*seq_breakers)));
    eout << "dry:"
      << "\n  multiplier: " << dry->multiplier
      << "\n  base: " << dry->base
      << "\n  allowed_length: " << dry->allowed_length
      << "\n  window_length: " << dry->window_length
      << "\n";
  }
  if (const auto* min_p = std::get_if<AdjustViaKind::min_p>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_min_p(*min_p, keep_one));
    eout << "min_p: " << *min_p << "\n";
  }
  if (const auto* penalize_with = std::get_if<AdjustViaKind::penalize_with>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
        penalize_with->window_length,
        penalize_with->repetition,
        penalize_with->frequency,
        penalize_with->presence));
    eout << "penalties:"
      << "\n  window_length: " << penalize_with->window_length
      << "\n  repetition: " << penalize_with->repetition
      << "\n  frequency: " << penalize_with->frequency
      << "\n  presence: " << penalize_with->presence
      << "\n";
  }
  if (const auto* temperature = std::get_if<AdjustViaKind::temperature>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_temp(*temperature));
    eout << "temperature: " << *temperature << "\n";
  }
  if (const auto* top_k = std::get_if<AdjustViaKind::top_k>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(*top_k));
    eout << "top_k: " << *top_k << "\n";
  }
  if (const auto* top_p = std::get_if<AdjustViaKind::top_p>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(*top_p, keep_one));
    eout << "top_p: " << *top_p << "\n";
  }
  if (const auto* typical_p = std::get_if<AdjustViaKind::typical_p>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_typical(*typical_p, keep_one));
    eout << "typical_p: " << *typical_p << "\n";
  }
  if (const auto* xtc = std::get_if<AdjustViaKind::xtc>(&adjust_via)) {
    llama_sampler_chain_add(smpl, llama_sampler_init_xtc(xtc->probability, xtc->threshold, keep_one, seed));
    eout << "xtc: "
      << "\n  probability: " << xtc->probability
      << "\n  threshold: " << xtc->threshold
      << "\n";
  }
}

static
  void
adaptive_p_sample(
    struct llama_sampler* smpl,
    const rendezllama::inference::AdaptiveP& adaptive_p,
    unsigned seed)
{
  llama_sampler_chain_add(
      smpl,
      llama_sampler_init_adaptive_p(
          adaptive_p.target,
          adaptive_p.decay,
          seed));
}

static
  void
mirostat_sample(
    struct llama_sampler* smpl,
    const rendezllama::inference::Mirostat& mirostat,
    unsigned seed,
    const rendezllama::Vocabulary& vocabulary)
{
  if (mirostat.version == 1) {
    const int mirostat_m = 100;
    llama_sampler_chain_add(
        smpl,
        llama_sampler_init_mirostat(
            vocabulary.cardinality(), seed,
            mirostat.tau, mirostat.eta, mirostat_m));
  }
  else if (mirostat.version == 2) {
    llama_sampler_chain_add(
        smpl,
        llama_sampler_init_mirostat_v2(
            seed, mirostat.tau, mirostat.eta));
  }
}

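// Rebuilds the sampler chain from the current options. A fresh seed is drawn
// when a chain already exists (a retry) or when no fixed seed was configured,
// and parameter logging is redirected to /dev/null on rebuilds so settings
// are only printed the first time.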
  void
Inference::reinitialize(const ChatOptions& opt, const struct llama_model* model)
{
  fildesh::ofstream eout("/dev/stderr");

  const auto* sampling = std::get_if<rendezllama::inference::Sampling>(&opt.infer_via);
  assert(sampling);
  auto seed = sampling->seed;
  if (smpl_ || seed < 0) {
    // We're retrying or just don't have a fixed seed, so we should reseed.
    seed = new_sampling_seed();
  }
  if (smpl_) {
    llama_sampler_free(smpl_);
    eout.open("/dev/null");
  }
  token_count_ = 0;
  auto smpl_param = llama_sampler_chain_default_params();
  smpl_ = llama_sampler_chain_init(smpl_param);

  for (const auto& adjust_via : sampling->adjust_thru) {
    apply_sampler_chain(smpl_, adjust_via, model, seed, eout);
  }

  if (std::get_if<rendezllama::inference::Probability>(&sampling->pick_via)) {
    llama_sampler_chain_add(smpl_, llama_sampler_init_dist(seed));
  }
  else if (std::get_if<rendezllama::inference::Determinism>(&sampling->pick_via)) {
    llama_sampler_chain_add(smpl_, llama_sampler_init_greedy());
  }
  else if (const auto* adaptive_p = std::get_if<rendezllama::inference::AdaptiveP>(&sampling->pick_via)) {
    adaptive_p_sample(smpl_, *adaptive_p, seed);
  }
  else if (const auto* mirostat = std::get_if<rendezllama::inference::Mirostat>(&sampling->pick_via)) {
    mirostat_sample(smpl_, *mirostat, seed, vocabulary_);
  }
  else {
    fildesh_log_error("Missing pick method? Using greedy.");
    llama_sampler_chain_add(smpl_, llama_sampler_init_greedy());
  }
}

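// Feeds any not-yet-evaluated trajectory tokens to the llama context in
// batches, first trimming the KV cache past the last evaluated position, and
// then replays those tokens into the sampler chain so its state stays in sync
// with the trajectory.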
  bool
Inference::commit_to_context(
    struct llama_context* ctx,
    ChatDisplay& chat_disp,
    ChatTrajectory& chat_traj,
    const ChatOptions& opt,
    const llama_model* model)
{
  assert(!chat_traj.erased_since_eval_ ||
         chat_traj.context_token_count_ < chat_traj.token_count());
  if (chat_traj.erased_since_eval_) {
    this->reinitialize(opt, model);
  }
  if (chat_traj.context_token_count_ == chat_traj.token_count()) {
    return true;
  }

  chat_traj.maybe_rollforget_within_limit(opt.context_token_limit, vocabulary_);

  // Reset thread count just in case the user reconfigured it.
  const unsigned thread_count = opt.thread_count;
  unsigned batch_thread_count = opt.batch_thread_count;
  if (batch_thread_count == 0) {
    batch_thread_count = std::thread::hardware_concurrency();
  }
  if (batch_thread_count == 0) {
    batch_thread_count = thread_count;
  }
  llama_set_n_threads(ctx, thread_count, batch_thread_count);

  // Clear KV cache past current position just in case the user deleted tokens.
  llama_memory_seq_rm(
      llama_get_memory(ctx),
      0, chat_traj.context_token_count_, -1);

  while (chat_traj.context_token_count_ < chat_traj.token_count()) {
    const unsigned n = std::min(
        opt.batch_count,
        chat_traj.token_count() - chat_traj.context_token_count_);

#if LLAMA_OPENBLAS_ON
    if (n < 32) {
      llama_set_n_threads(ctx, thread_count, batch_thread_count);
    }
    else {
      llama_set_n_threads(ctx, thread_count, 1);
    }
#endif
    chat_disp.show_new(chat_traj.context_token_count_ + n, chat_traj, vocabulary_);

    llama_batch batch = llama_batch_get_one(
        const_cast<int*>(&chat_traj.tokens()[chat_traj.context_token_count_]),
        n);
    const int istat = llama_decode(ctx, batch);
    if (istat != 0) {
      fildesh_log_error("Failed to eval.");
      chat_traj.context_token_count_ = 0;
      return false;
    }
    else {
      chat_traj.context_token_count_ += n;
    }
  }
  assert(chat_traj.context_token_count_ == chat_traj.token_count());
  chat_traj.erased_since_eval_ = false;
  while (token_count_ < chat_traj.token_count()) {
    Vocabulary::Token_id token_id = chat_traj.token_at(token_count_);
    llama_sampler_accept(smpl_, token_id);
    token_count_ += 1;
  }
  return true;
}

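// Samples the next token from the current logits and appends it to the
// trajectory. Note that `preventing_newline` only zeroes the EOS and newline
// logits; after softmax those tokens keep a nonzero probability rather than
// being ruled out entirely.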
  void
Inference::sample_to_trajectory(
    ChatTrajectory& chat_traj,
    struct llama_context* ctx,
    bool preventing_newline)
{
  float* logits = llama_get_logits(ctx);
  if (preventing_newline) {
    // Zero probability for message-ending tokens when requested.
    logits[vocabulary_.eos_token_id()] = 0;
    logits[vocabulary_.newline_token_id()] = 0;
  }

  std::vector<llama_token_data> candidates;
  candidates.resize(vocabulary_.cardinality());
  for (llama_token i = 0; i < (llama_token)candidates.size(); ++i) {
    candidates[i] = llama_token_data{
      i, logits[i], 0.0f,
    };
  }
  logits = NULL;
  llama_token_data_array candidates_data[1] = {{
    candidates.data(),
    candidates.size(),
    /*selected=*/0,
    /*sorted=*/false,
  }};
  llama_sampler_apply(smpl_, candidates_data);
  chat_traj.push_back(candidates[candidates_data->selected].id);
  llama_sampler_accept(smpl_, chat_traj.token());
  token_count_ += 1;
}
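
// A minimal usage sketch (not part of this file) of how the functions above
// seem to fit together in a generation loop; construction of ChatOptions,
// ChatDisplay, ChatGuide, ChatTrajectory, and Vocabulary is assumed to happen
// elsewhere in the project, and only calls whose signatures appear above are
// used:
//
//   auto [model, ctx] = rendezllama::make_llama_context(opt);
//   Inference inference(vocabulary);
//   inference.reinitialize(opt, model);
//   while (/* more tokens wanted */) {
//     if (!inference.commit_to_context(ctx, chat_disp, chat_traj, opt, model)) {
//       break;
//     }
//     inference.sample_to_trajectory(chat_traj, ctx, /*preventing_newline=*/false);
//   }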
