• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

grpc / grpc-java / #18706

pending completion
#18706

push

github-actions

web-flow
implemented and tested static stride scheduler for weighted round robin load balancing policy (#10272)

30562 of 34641 relevant lines covered (88.22%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.49
/../xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
1
/*
2
 * Copyright 2023 The gRPC Authors
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package io.grpc.xds;
18

19
import static com.google.common.base.Preconditions.checkArgument;
20
import static com.google.common.base.Preconditions.checkNotNull;
21

22
import com.google.common.annotations.VisibleForTesting;
23
import com.google.common.base.MoreObjects;
24
import com.google.common.base.Preconditions;
25
import io.grpc.ConnectivityState;
26
import io.grpc.ConnectivityStateInfo;
27
import io.grpc.Deadline.Ticker;
28
import io.grpc.EquivalentAddressGroup;
29
import io.grpc.ExperimentalApi;
30
import io.grpc.LoadBalancer;
31
import io.grpc.NameResolver;
32
import io.grpc.Status;
33
import io.grpc.SynchronizationContext;
34
import io.grpc.SynchronizationContext.ScheduledHandle;
35
import io.grpc.services.MetricReport;
36
import io.grpc.util.ForwardingLoadBalancerHelper;
37
import io.grpc.util.ForwardingSubchannel;
38
import io.grpc.util.RoundRobinLoadBalancer;
39
import io.grpc.xds.orca.OrcaOobUtil;
40
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
41
import io.grpc.xds.orca.OrcaPerRequestUtil;
42
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
43
import java.util.HashMap;
44
import java.util.HashSet;
45
import java.util.List;
46
import java.util.Map;
47
import java.util.Random;
48
import java.util.concurrent.ScheduledExecutorService;
49
import java.util.concurrent.TimeUnit;
50
import java.util.concurrent.atomic.AtomicInteger;
51
import java.util.logging.Level;
52
import java.util.logging.Logger;
53

54
/**
55
 * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over
56
 * the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
57
 * determined by backend metrics using ORCA.
58
 */
59
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
60
final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
61
  private static final Logger log = Logger.getLogger(
1✔
62
      WeightedRoundRobinLoadBalancer.class.getName());
1✔
63
  private WeightedRoundRobinLoadBalancerConfig config;
64
  private final SynchronizationContext syncContext;
65
  private final ScheduledExecutorService timeService;
66
  private ScheduledHandle weightUpdateTimer;
67
  private final Runnable updateWeightTask;
68
  private final Random random;
69
  private final long infTime;
70
  private final Ticker ticker;
71

72
  /**
   * Creates a balancer whose {@code helper} is first wrapped with ORCA out-of-band reporting
   * support and then with a {@link WrrHelper}; a new {@link Random} is supplied as the
   * randomness source.
   */
  public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
    this(
        new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)),
        ticker,
        new Random());
  }
1✔
75

76
  /**
   * Creates a balancer on top of an already-wrapped {@link WrrHelper}.
   *
   * <p>The balancer registers itself with {@code helper} so that subchannels created through
   * the helper are wrapped as {@code WrrSubchannel}s bound to this instance.
   *
   * @param helper wrapped helper; also supplies the synchronization context and the scheduled
   *     executor service used for the periodic weight update
   * @param ticker time source used for weight timestamps
   * @param random randomness source handed to every {@code StaticStrideScheduler}
   * @throws NullPointerException if {@code ticker}, {@code random}, the helper's
   *     synchronization context or its scheduled executor service is null
   */
  public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random random) {
    super(helper);
    helper.setLoadBalancer(this);
    this.ticker = checkNotNull(ticker, "ticker");
    // Sentinel timestamp meaning "no non-empty load report has been seen yet".
    this.infTime = ticker.nanoTime() + Long.MAX_VALUE;
    this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
    this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
    this.updateWeightTask = new UpdateWeightTask();
    // Fail fast here instead of with an NPE when the first scheduler is built.
    this.random = checkNotNull(random, "random");
    log.log(Level.FINE, "weighted_round_robin LB created");
  }
1✔
87

88
  /**
   * Test-only variant of the public constructor that lets the caller supply the
   * {@link Random} instead of having a new one created.
   */
  @VisibleForTesting
  WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) {
    this(
        new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)),
        ticker,
        random);
  }
1✔
92

93
  @Override
94
  public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
95
    if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
1✔
96
      handleNameResolutionError(Status.UNAVAILABLE.withDescription(
1✔
97
              "NameResolver returned no WeightedRoundRobinLoadBalancerConfig. addrs="
98
                      + resolvedAddresses.getAddresses()
1✔
99
                      + ", attrs=" + resolvedAddresses.getAttributes()));
1✔
100
      return false;
1✔
101
    }
102
    config =
1✔
103
            (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
1✔
104
    boolean accepted = super.acceptResolvedAddresses(resolvedAddresses);
1✔
105
    if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
1✔
106
      weightUpdateTimer.cancel();
1✔
107
    }
108
    updateWeightTask.run();
1✔
109
    afterAcceptAddresses();
1✔
110
    return accepted;
1✔
111
  }
112

113
  @Override
114
  public RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
115
    return new WeightedRoundRobinPicker(activeList, config.enableOobLoadReport,
1✔
116
        config.errorUtilizationPenalty);
117
  }
118

119
  private final class UpdateWeightTask implements Runnable {
1✔
120
    @Override
121
    public void run() {
122
      if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
1✔
123
        ((WeightedRoundRobinPicker) currentPicker).updateWeight();
1✔
124
      }
125
      weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
1✔
126
          TimeUnit.NANOSECONDS, timeService);
1✔
127
    }
1✔
128
  }
129

130
  private void afterAcceptAddresses() {
131
    for (Subchannel subchannel : getSubchannels()) {
1✔
132
      WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel;
1✔
133
      if (config.enableOobLoadReport) {
1✔
134
        OrcaOobUtil.setListener(weightedSubchannel,
1✔
135
            weightedSubchannel.new OrcaReportListener(config.errorUtilizationPenalty),
136
                OrcaOobUtil.OrcaReportingConfig.newBuilder()
1✔
137
                        .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
1✔
138
                        .build());
1✔
139
      } else {
140
        OrcaOobUtil.setListener(weightedSubchannel, null, null);
1✔
141
      }
142
    }
1✔
143
  }
1✔
144

145
  @Override
146
  public void shutdown() {
147
    if (weightUpdateTimer != null) {
1✔
148
      weightUpdateTimer.cancel();
1✔
149
    }
150
    super.shutdown();
1✔
151
  }
1✔
152

153
  private static final class WrrHelper extends ForwardingLoadBalancerHelper {
154
    private final Helper delegate;
155
    private WeightedRoundRobinLoadBalancer wrr;
156

157
    WrrHelper(Helper helper) {
1✔
158
      this.delegate = helper;
1✔
159
    }
1✔
160

161
    void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) {
162
      this.wrr = lb;
1✔
163
    }
1✔
164

165
    @Override
166
    protected Helper delegate() {
167
      return delegate;
1✔
168
    }
169

170
    @Override
171
    public Subchannel createSubchannel(CreateSubchannelArgs args) {
172
      return wrr.new WrrSubchannel(delegate().createSubchannel(args));
1✔
173
    }
174
  }
175

176
  @VisibleForTesting
177
  final class WrrSubchannel extends ForwardingSubchannel {
178
    private final Subchannel delegate;
179
    private volatile long lastUpdated;
180
    private volatile long nonEmptySince;
181
    private volatile double weight;
182

183
    WrrSubchannel(Subchannel delegate) {
1✔
184
      this.delegate = checkNotNull(delegate, "delegate");
1✔
185
    }
1✔
186

187
    @Override
188
    public void start(SubchannelStateListener listener) {
189
      delegate().start(new SubchannelStateListener() {
1✔
190
        @Override
191
        public void onSubchannelState(ConnectivityStateInfo newState) {
192
          if (newState.getState().equals(ConnectivityState.READY)) {
1✔
193
            nonEmptySince = infTime;
1✔
194
          }
195
          listener.onSubchannelState(newState);
1✔
196
        }
1✔
197
      });
198
    }
1✔
199

200
    private double getWeight() {
201
      if (config == null) {
1✔
202
        return 0;
×
203
      }
204
      long now = ticker.nanoTime();
1✔
205
      if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
1✔
206
        nonEmptySince = infTime;
1✔
207
        return 0;
1✔
208
      } else if (now - nonEmptySince < config.blackoutPeriodNanos
1✔
209
          && config.blackoutPeriodNanos > 0) {
1✔
210
        return 0;
1✔
211
      } else {
212
        return weight;
1✔
213
      }
214
    }
215

216
    @Override
217
    protected Subchannel delegate() {
218
      return delegate;
1✔
219
    }
220

221
    final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
222
      private final float errorUtilizationPenalty;
223

224
      OrcaReportListener(float errorUtilizationPenalty) {
1✔
225
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
226
      }
1✔
227

228
      @Override
229
      public void onLoadReport(MetricReport report) {
230
        double newWeight = 0;
1✔
231
        // Prefer application utilization and fallback to CPU utilization if unset.
232
        double utilization =
233
            report.getApplicationUtilization() > 0 ? report.getApplicationUtilization()
1✔
234
                : report.getCpuUtilization();
1✔
235
        if (utilization > 0 && report.getQps() > 0) {
1✔
236
          double penalty = 0;
1✔
237
          if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
1✔
238
            penalty = report.getEps() / report.getQps() * errorUtilizationPenalty;
1✔
239
          }
240
          newWeight = report.getQps() / (utilization + penalty);
1✔
241
        }
242
        if (newWeight == 0) {
1✔
243
          return;
1✔
244
        }
245
        if (nonEmptySince == infTime) {
1✔
246
          nonEmptySince = ticker.nanoTime();
1✔
247
        }
248
        lastUpdated = ticker.nanoTime();
1✔
249
        weight = newWeight;
1✔
250
      }
1✔
251
    }
252
  }
253

254
  @VisibleForTesting
255
  final class WeightedRoundRobinPicker extends RoundRobinPicker {
256
    private final List<Subchannel> list;
257
    private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
1✔
258
        new HashMap<>();
259
    private final boolean enableOobLoadReport;
260
    private final float errorUtilizationPenalty;
261
    private volatile StaticStrideScheduler scheduler;
262

263
    WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport,
264
        float errorUtilizationPenalty) {
1✔
265
      checkNotNull(list, "list");
1✔
266
      Preconditions.checkArgument(!list.isEmpty(), "empty list");
1✔
267
      this.list = list;
1✔
268
      for (Subchannel subchannel : list) {
1✔
269
        this.subchannelToReportListenerMap.put(subchannel,
1✔
270
            ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty));
271
      }
1✔
272
      this.enableOobLoadReport = enableOobLoadReport;
1✔
273
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
274
      updateWeight();
1✔
275
    }
1✔
276

277
    @Override
278
    public PickResult pickSubchannel(PickSubchannelArgs args) {
279
      Subchannel subchannel = list.get(scheduler.pick());
1✔
280
      if (!enableOobLoadReport) {
1✔
281
        return PickResult.withSubchannel(subchannel,
1✔
282
                OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
1✔
283
                subchannelToReportListenerMap.getOrDefault(subchannel,
1✔
284
                    ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty))));
285
      } else {
286
        return PickResult.withSubchannel(subchannel);
1✔
287
      }
288
    }
289

290
    private void updateWeight() {
291
      float[] newWeights = new float[list.size()];
1✔
292
      for (int i = 0; i < list.size(); i++) {
1✔
293
        WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
1✔
294
        double newWeight = subchannel.getWeight();
1✔
295
        newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
1✔
296
      }
297

298
      StaticStrideScheduler scheduler = new StaticStrideScheduler(newWeights, random);
1✔
299
      this.scheduler = scheduler;
1✔
300
    }
1✔
301

302
    @Override
303
    public String toString() {
304
      return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
1✔
305
          .add("enableOobLoadReport", enableOobLoadReport)
1✔
306
          .add("errorUtilizationPenalty", errorUtilizationPenalty)
1✔
307
          .add("list", list).toString();
1✔
308
    }
309

310
    @VisibleForTesting
311
    List<Subchannel> getList() {
312
      return list;
1✔
313
    }
314

315
    @Override
316
    public boolean isEquivalentTo(RoundRobinPicker picker) {
317
      if (!(picker instanceof WeightedRoundRobinPicker)) {
1✔
318
        return false;
×
319
      }
320
      WeightedRoundRobinPicker other = (WeightedRoundRobinPicker) picker;
1✔
321
      if (other == this) {
1✔
322
        return true;
×
323
      }
324
      // the lists cannot contain duplicate subchannels
325
      return enableOobLoadReport == other.enableOobLoadReport
1✔
326
          && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
1✔
327
          && list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list);
1✔
328
    }
329
  }
330

331
  /*
332
   * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler
333
   * in which each object's deadline is the multiplicative inverse of the object's weight.
334
   * <p>
335
   * The way in which this is implemented is through a static stride scheduler. 
336
   * The Static Stride Scheduler works by iterating through the list of subchannel weights
337
   * and using modular arithmetic to proportionally distribute picks, favoring entries 
338
   * with higher weights. It is based on the observation that the intended sequence generated 
339
   * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic. 
340
   * The Static Stride Scheduler is more performant than other implementations of the EDF
341
   * Scheduler, as it removes the need for a priority queue (and thus mutex locks).
342
   * <p>
343
   * go/static-stride-scheduler
344
   * <p>
345
   *
346
   * <ul>
347
   *  <li>nextSequence() - O(1)
348
   *  <li>pick() - O(n)
349
   */
350
  @VisibleForTesting
351
  static final class StaticStrideScheduler {
352
    private final short[] scaledWeights;
353
    private final int sizeDivisor;
354
    private final AtomicInteger sequence;
355
    private static final int K_MAX_WEIGHT = 0xFFFF;
356

357
    StaticStrideScheduler(float[] weights, Random random) {
1✔
358
      checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
1✔
359
      int numChannels = weights.length;
1✔
360
      int numWeightedChannels = 0;
1✔
361
      double sumWeight = 0;
1✔
362
      float maxWeight = 0;
1✔
363
      short meanWeight = 0;
1✔
364
      for (float weight : weights) {
1✔
365
        if (weight > 0) {
1✔
366
          sumWeight += weight;
1✔
367
          maxWeight = Math.max(weight, maxWeight);
1✔
368
          numWeightedChannels++;
1✔
369
        }
370
      }
371

372
      double scalingFactor = K_MAX_WEIGHT / maxWeight;
1✔
373
      if (numWeightedChannels > 0) {
1✔
374
        meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels);
1✔
375
      } else {
376
        meanWeight = 1;
1✔
377
      }
378

379
      // scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
380
      short[] scaledWeights = new short[numChannels];
1✔
381
      for (int i = 0; i < numChannels; i++) {
1✔
382
        if (weights[i] <= 0) {
1✔
383
          scaledWeights[i] = meanWeight;
1✔
384
        } else {
385
          scaledWeights[i] = (short) Math.round(weights[i] * scalingFactor);
1✔
386
        }
387
      }
388

389
      this.scaledWeights = scaledWeights;
1✔
390
      this.sizeDivisor = numChannels;
1✔
391
      this.sequence = new AtomicInteger(random.nextInt());
1✔
392

393
    }
1✔
394

395
    /** Returns the next sequence number and atomically increases sequence with wraparound. */
396
    private long nextSequence() {
397
      return Integer.toUnsignedLong(sequence.getAndIncrement());
1✔
398
    }
399

400
    @VisibleForTesting
401
    long getSequence() {
402
      return Integer.toUnsignedLong(sequence.get());
1✔
403
    }
404

405
    /*
406
     * Selects index of next backend server.
407
     * <p>
408
     * A 2D array is compactly represented as a function of W(backend), where the row
409
     * represents the generation and the column represents the backend index:
410
     * X(backend,generation) | generation ∈ [0,kMaxWeight).
411
     * Each element in the conceptual array is a boolean indicating whether the backend at
412
     * this index should be picked now. If false, the counter is incremented again,
413
     * and the new element is checked. An atomically incremented counter keeps track of our
414
     * backend and generation through modular arithmetic within the pick() method.
415
     * <p>
416
     * Modular arithmetic allows us to evenly distribute picks and skips between
417
     * generations based on W(backend).
418
     * X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend)
419
     * If we have the same three backends with weights:
420
     * W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is:
421
     * <p>
422
     * B0    B1    B2
423
     * T     T     T
424
     * F     F     T
425
     * F     T     T
426
     * T     F     T
427
     * F     T     T
428
     * F     F     T
429
     * The sequence of picked backend indices is given by
430
     * walking across and down: {0,1,2,2,1,2,0,2,1,2,2}.
431
     * <p>
432
     * To reduce the variance and spread the wasted work among different picks,
433
     * an offset that varies per backend index is also included to the calculation.
434
     */
435
    int pick() {
436
      while (true) {
437
        long sequence = this.nextSequence();
1✔
438
        int backendIndex = (int) (sequence % this.sizeDivisor);
1✔
439
        long generation = sequence / this.sizeDivisor;
1✔
440
        int weight = Short.toUnsignedInt(this.scaledWeights[backendIndex]);
1✔
441
        long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
1✔
442
        if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
1✔
443
          continue;
1✔
444
        }
445
        return backendIndex;
1✔
446
      }
447
    }
448
  }
449

450
  static final class WeightedRoundRobinLoadBalancerConfig {
451
    final long blackoutPeriodNanos;
452
    final long weightExpirationPeriodNanos;
453
    final boolean enableOobLoadReport;
454
    final long oobReportingPeriodNanos;
455
    final long weightUpdatePeriodNanos;
456
    final float errorUtilizationPenalty;
457

458
    public static Builder newBuilder() {
459
      return new Builder();
1✔
460
    }
461

462
    private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos,
463
                                                 long weightExpirationPeriodNanos,
464
                                                 boolean enableOobLoadReport,
465
                                                 long oobReportingPeriodNanos,
466
                                                 long weightUpdatePeriodNanos,
467
                                                 float errorUtilizationPenalty) {
1✔
468
      this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
469
      this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
470
      this.enableOobLoadReport = enableOobLoadReport;
1✔
471
      this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
472
      this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
473
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
474
    }
1✔
475

476
    static final class Builder {
477
      long blackoutPeriodNanos = 10_000_000_000L; // 10s
1✔
478
      long weightExpirationPeriodNanos = 180_000_000_000L; //3min
1✔
479
      boolean enableOobLoadReport = false;
1✔
480
      long oobReportingPeriodNanos = 10_000_000_000L; // 10s
1✔
481
      long weightUpdatePeriodNanos = 1_000_000_000L; // 1s
1✔
482
      float errorUtilizationPenalty = 1.0F;
1✔
483

484
      private Builder() {
1✔
485

486
      }
1✔
487

488
      Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
489
        this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
490
        return this;
1✔
491
      }
492

493
      Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
494
        this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
495
        return this;
1✔
496
      }
497

498
      Builder setEnableOobLoadReport(boolean enableOobLoadReport) {
499
        this.enableOobLoadReport = enableOobLoadReport;
1✔
500
        return this;
1✔
501
      }
502

503
      Builder setOobReportingPeriodNanos(long oobReportingPeriodNanos) {
504
        this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
505
        return this;
1✔
506
      }
507

508
      Builder setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos) {
509
        this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
510
        return this;
1✔
511
      }
512

513
      Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) {
514
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
515
        return this;
1✔
516
      }
517

518
      WeightedRoundRobinLoadBalancerConfig build() {
519
        return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos,
1✔
520
                weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
521
                weightUpdatePeriodNanos, errorUtilizationPenalty);
522
      }
523
    }
524
  }
525
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc