• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

grpc / grpc-java / #19230

13 May 2024 04:04PM UTC coverage: 88.403% (-0.004%) from 88.407%
#19230

push

github

web-flow
xds, rls: Experimental metrics are disabled by default (#11196) (#11197)

Experimental metrics (i.e., WRR and RLS metrics) are disabled by default. Users are expected to enable them explicitly when configuring metrics.

31606 of 35752 relevant lines covered (88.4%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.31
/../xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
1
/*
2
 * Copyright 2023 The gRPC Authors
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package io.grpc.xds;
18

19
import static com.google.common.base.Preconditions.checkArgument;
20
import static com.google.common.base.Preconditions.checkElementIndex;
21
import static com.google.common.base.Preconditions.checkNotNull;
22

23
import com.google.common.annotations.VisibleForTesting;
24
import com.google.common.base.MoreObjects;
25
import com.google.common.base.Preconditions;
26
import com.google.common.collect.ImmutableList;
27
import com.google.common.collect.Lists;
28
import io.grpc.ConnectivityState;
29
import io.grpc.ConnectivityStateInfo;
30
import io.grpc.Deadline.Ticker;
31
import io.grpc.DoubleHistogramMetricInstrument;
32
import io.grpc.EquivalentAddressGroup;
33
import io.grpc.ExperimentalApi;
34
import io.grpc.LoadBalancer;
35
import io.grpc.LoadBalancerProvider;
36
import io.grpc.LongCounterMetricInstrument;
37
import io.grpc.MetricInstrumentRegistry;
38
import io.grpc.NameResolver;
39
import io.grpc.Status;
40
import io.grpc.SynchronizationContext;
41
import io.grpc.SynchronizationContext.ScheduledHandle;
42
import io.grpc.services.MetricReport;
43
import io.grpc.util.ForwardingLoadBalancerHelper;
44
import io.grpc.util.ForwardingSubchannel;
45
import io.grpc.util.RoundRobinLoadBalancer;
46
import io.grpc.xds.orca.OrcaOobUtil;
47
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
48
import io.grpc.xds.orca.OrcaPerRequestUtil;
49
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
50
import java.util.Collection;
51
import java.util.HashMap;
52
import java.util.HashSet;
53
import java.util.List;
54
import java.util.Map;
55
import java.util.Random;
56
import java.util.Set;
57
import java.util.concurrent.ScheduledExecutorService;
58
import java.util.concurrent.TimeUnit;
59
import java.util.concurrent.atomic.AtomicInteger;
60
import java.util.logging.Level;
61
import java.util.logging.Logger;
62

63
/**
64
 * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over the
65
 * {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
66
 * determined by backend metrics using ORCA.
67
 */
68
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
69
final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
70

71
  private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER;
72
  private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER;
73
  private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_STALE_COUNTER;
74
  private static final DoubleHistogramMetricInstrument ENDPOINT_WEIGHTS_HISTOGRAM;
75
  private static final Logger log = Logger.getLogger(
1✔
76
      WeightedRoundRobinLoadBalancer.class.getName());
1✔
77
  private WeightedRoundRobinLoadBalancerConfig config;
78
  private final SynchronizationContext syncContext;
79
  private final ScheduledExecutorService timeService;
80
  private ScheduledHandle weightUpdateTimer;
81
  private final Runnable updateWeightTask;
82
  private final AtomicInteger sequence;
83
  private final long infTime;
84
  private final Ticker ticker;
85
  private String locality = "";
1✔
86

87
  // Metric instruments are registered a single time here and shared by every instance of this
  // LB. Each registration passes false as its last argument (presumably the
  // enabled-by-default flag -- confirm against MetricInstrumentRegistry).
  static {
    MetricInstrumentRegistry registry = MetricInstrumentRegistry.getDefaultRegistry();

    RR_FALLBACK_COUNTER = registry.registerLongCounter(
        "grpc.lb.wrr.rr_fallback",
        "EXPERIMENTAL. Number of scheduler updates in which there were not enough endpoints "
            + "with valid weight, which caused the WRR policy to fall back to RR behavior",
        "{update}",
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);

    ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = registry.registerLongCounter(
        "grpc.lb.wrr.endpoint_weight_not_yet_usable",
        "EXPERIMENTAL. Number of endpoints "
            + "from each scheduler update that don't yet have usable weight information",
        "{endpoint}",
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);

    ENDPOINT_WEIGHT_STALE_COUNTER = registry.registerLongCounter(
        "grpc.lb.wrr.endpoint_weight_stale",
        "EXPERIMENTAL. Number of endpoints from each scheduler update whose latest weight is "
            + "older than the expiration period",
        "{endpoint}",
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);

    ENDPOINT_WEIGHTS_HISTOGRAM = registry.registerDoubleHistogram(
        "grpc.lb.wrr.endpoint_weights",
        "EXPERIMENTAL. The histogram buckets will be endpoint weight ranges.",
        "{weight}",
        Lists.newArrayList(),
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);
  }
1✔
113

114
  public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
115
    this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, new Random());
1✔
116
  }
1✔
117

118
  public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random random) {
119
    super(helper);
1✔
120
    helper.setLoadBalancer(this);
1✔
121
    this.ticker = checkNotNull(ticker, "ticker");
1✔
122
    this.infTime = ticker.nanoTime() + Long.MAX_VALUE;
1✔
123
    this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
1✔
124
    this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
1✔
125
    this.updateWeightTask = new UpdateWeightTask();
1✔
126
    this.sequence = new AtomicInteger(random.nextInt());
1✔
127
    log.log(Level.FINE, "weighted_round_robin LB created");
1✔
128
  }
1✔
129

130
  @VisibleForTesting
131
  WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) {
132
    this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, random);
1✔
133
  }
1✔
134

135
  @Override
136
  protected ChildLbState createChildLbState(Object key, Object policyConfig,
137
      SubchannelPicker initialPicker, ResolvedAddresses unused) {
138
    ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig,
1✔
139
        initialPicker);
140
    return childLbState;
1✔
141
  }
142

143
  @Override
144
  public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
145
    if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
1✔
146
      Status unavailableStatus = Status.UNAVAILABLE.withDescription(
1✔
147
              "NameResolver returned no WeightedRoundRobinLoadBalancerConfig. addrs="
148
                      + resolvedAddresses.getAddresses()
1✔
149
                      + ", attrs=" + resolvedAddresses.getAttributes());
1✔
150
      handleNameResolutionError(unavailableStatus);
1✔
151
      return unavailableStatus;
1✔
152
    }
153
    String locality = resolvedAddresses.getAttributes().get(WeightedTargetLoadBalancer.CHILD_NAME);
1✔
154
    if (locality != null) {
1✔
155
      this.locality = locality;
1✔
156
    } else {
157
      this.locality = "";
1✔
158
    }
159
    config =
1✔
160
            (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
1✔
161
    AcceptResolvedAddrRetVal acceptRetVal;
162
    try {
163
      resolvingAddresses = true;
1✔
164
      acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses);
1✔
165
      if (!acceptRetVal.status.isOk()) {
1✔
166
        return acceptRetVal.status;
×
167
      }
168

169
      if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
1✔
170
        weightUpdateTimer.cancel();
1✔
171
      }
172
      updateWeightTask.run();
1✔
173

174
      createAndApplyOrcaListeners();
1✔
175

176
      // Must update channel picker before return so that new RPCs will not be routed to deleted
177
      // clusters and resolver can remove them in service config.
178
      updateOverallBalancingState();
1✔
179

180
      shutdownRemoved(acceptRetVal.removedChildren);
1✔
181
    } finally {
182
      resolvingAddresses = false;
1✔
183
    }
184

185
    return acceptRetVal.status;
1✔
186
  }
187

188
  @Override
189
  public SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
190
    return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
1✔
191
        config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(),
1✔
192
        locality);
193
  }
194

195
  @VisibleForTesting
196
  final class WeightedChildLbState extends ChildLbState {
197

198
    private final Set<WrrSubchannel> subchannels = new HashSet<>();
1✔
199
    private volatile long lastUpdated;
200
    private volatile long nonEmptySince;
201
    private volatile double weight = 0;
1✔
202

203
    private OrcaReportListener orcaReportListener;
204

205
    public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
206
        SubchannelPicker initialPicker) {
1✔
207
      super(key, policyProvider, childConfig, initialPicker);
1✔
208
    }
1✔
209

210
    private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) {
211
      if (config == null) {
1✔
212
        return 0;
×
213
      }
214
      long now = ticker.nanoTime();
1✔
215
      if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
1✔
216
        nonEmptySince = infTime;
1✔
217
        staleEndpoints.incrementAndGet();
1✔
218
        return 0;
1✔
219
      } else if (now - nonEmptySince < config.blackoutPeriodNanos
1✔
220
          && config.blackoutPeriodNanos > 0) {
1✔
221
        notYetUsableEndpoints.incrementAndGet();
1✔
222
        return 0;
1✔
223
      } else {
224
        return weight;
1✔
225
      }
226
    }
227

228
    public void addSubchannel(WrrSubchannel wrrSubchannel) {
229
      subchannels.add(wrrSubchannel);
1✔
230
    }
1✔
231

232
    public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) {
233
      if (orcaReportListener != null
1✔
234
          && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) {
1✔
235
        return orcaReportListener;
1✔
236
      }
237
      orcaReportListener = new OrcaReportListener(errorUtilizationPenalty);
1✔
238
      return orcaReportListener;
1✔
239
    }
240

241
    public void removeSubchannel(WrrSubchannel wrrSubchannel) {
242
      subchannels.remove(wrrSubchannel);
1✔
243
    }
1✔
244

245
    final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
246
      private final float errorUtilizationPenalty;
247

248
      OrcaReportListener(float errorUtilizationPenalty) {
1✔
249
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
250
      }
1✔
251

252
      @Override
253
      public void onLoadReport(MetricReport report) {
254
        double newWeight = 0;
1✔
255
        // Prefer application utilization and fallback to CPU utilization if unset.
256
        double utilization =
257
            report.getApplicationUtilization() > 0 ? report.getApplicationUtilization()
1✔
258
                : report.getCpuUtilization();
1✔
259
        if (utilization > 0 && report.getQps() > 0) {
1✔
260
          double penalty = 0;
1✔
261
          if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
1✔
262
            penalty = report.getEps() / report.getQps() * errorUtilizationPenalty;
1✔
263
          }
264
          newWeight = report.getQps() / (utilization + penalty);
1✔
265
        }
266
        if (newWeight == 0) {
1✔
267
          return;
1✔
268
        }
269
        if (nonEmptySince == infTime) {
1✔
270
          nonEmptySince = ticker.nanoTime();
1✔
271
        }
272
        lastUpdated = ticker.nanoTime();
1✔
273
        weight = newWeight;
1✔
274
      }
1✔
275
    }
276
  }
277

278
  private final class UpdateWeightTask implements Runnable {
1✔
279
    @Override
280
    public void run() {
281
      if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
1✔
282
        ((WeightedRoundRobinPicker) currentPicker).updateWeight();
1✔
283
      }
284
      weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
1✔
285
          TimeUnit.NANOSECONDS, timeService);
1✔
286
    }
1✔
287
  }
288

289
  private void createAndApplyOrcaListeners() {
290
    for (ChildLbState child : getChildLbStates()) {
1✔
291
      WeightedChildLbState wChild = (WeightedChildLbState) child;
1✔
292
      for (WrrSubchannel weightedSubchannel : wChild.subchannels) {
1✔
293
        if (config.enableOobLoadReport) {
1✔
294
          OrcaOobUtil.setListener(weightedSubchannel,
1✔
295
              wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty),
1✔
296
              OrcaOobUtil.OrcaReportingConfig.newBuilder()
1✔
297
                  .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
1✔
298
                  .build());
1✔
299
        } else {
300
          OrcaOobUtil.setListener(weightedSubchannel, null, null);
1✔
301
        }
302
      }
1✔
303
    }
1✔
304
  }
1✔
305

306
  @Override
307
  public void shutdown() {
308
    if (weightUpdateTimer != null) {
1✔
309
      weightUpdateTimer.cancel();
1✔
310
    }
311
    super.shutdown();
1✔
312
  }
1✔
313

314
  private static final class WrrHelper extends ForwardingLoadBalancerHelper {
315
    private final Helper delegate;
316
    private WeightedRoundRobinLoadBalancer wrr;
317

318
    WrrHelper(Helper helper) {
1✔
319
      this.delegate = helper;
1✔
320
    }
1✔
321

322
    void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) {
323
      this.wrr = lb;
1✔
324
    }
1✔
325

326
    @Override
327
    protected Helper delegate() {
328
      return delegate;
1✔
329
    }
330

331
    @Override
332
    public Subchannel createSubchannel(CreateSubchannelArgs args) {
333
      checkElementIndex(0, args.getAddresses().size(), "Empty address group");
1✔
334
      WeightedChildLbState childLbState =
1✔
335
          (WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0));
1✔
336
      return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState);
1✔
337
    }
338
  }
339

340
  @VisibleForTesting
341
  final class WrrSubchannel extends ForwardingSubchannel {
342
    private final Subchannel delegate;
343
    private final WeightedChildLbState owner;
344

345
    WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) {
1✔
346
      this.delegate = checkNotNull(delegate, "delegate");
1✔
347
      this.owner = checkNotNull(owner, "owner");
1✔
348
    }
1✔
349

350
    @Override
351
    public void start(SubchannelStateListener listener) {
352
      owner.addSubchannel(this);
1✔
353
      delegate().start(new SubchannelStateListener() {
1✔
354
        @Override
355
        public void onSubchannelState(ConnectivityStateInfo newState) {
356
          if (newState.getState().equals(ConnectivityState.READY)) {
1✔
357
            owner.nonEmptySince = infTime;
1✔
358
          }
359
          listener.onSubchannelState(newState);
1✔
360
        }
1✔
361
      });
362
    }
1✔
363

364
    @Override
365
    protected Subchannel delegate() {
366
      return delegate;
1✔
367
    }
368

369
    @Override
370
    public void shutdown() {
371
      super.shutdown();
1✔
372
      owner.removeSubchannel(this);
1✔
373
    }
1✔
374
  }
375

376
  @VisibleForTesting
377
  static final class WeightedRoundRobinPicker extends SubchannelPicker {
378
    private final List<ChildLbState> children;
379
    private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
1✔
380
        new HashMap<>();
381
    private final boolean enableOobLoadReport;
382
    private final float errorUtilizationPenalty;
383
    private final AtomicInteger sequence;
384
    private final int hashCode;
385
    private final LoadBalancer.Helper helper;
386
    private final String locality;
387
    private volatile StaticStrideScheduler scheduler;
388

389
    WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
390
        float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper,
391
        String locality) {
1✔
392
      checkNotNull(children, "children");
1✔
393
      Preconditions.checkArgument(!children.isEmpty(), "empty child list");
1✔
394
      this.children = children;
1✔
395
      for (ChildLbState child : children) {
1✔
396
        WeightedChildLbState wChild = (WeightedChildLbState) child;
1✔
397
        for (WrrSubchannel subchannel : wChild.subchannels) {
1✔
398
          this.subchannelToReportListenerMap
1✔
399
              .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
1✔
400
        }
1✔
401
      }
1✔
402
      this.enableOobLoadReport = enableOobLoadReport;
1✔
403
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
404
      this.sequence = checkNotNull(sequence, "sequence");
1✔
405
      this.helper = helper;
1✔
406
      this.locality = checkNotNull(locality, "locality");
1✔
407

408
      // For equality we treat children as a set; use hash code as defined by Set
409
      int sum = 0;
1✔
410
      for (ChildLbState child : children) {
1✔
411
        sum += child.hashCode();
1✔
412
      }
1✔
413
      this.hashCode = sum
1✔
414
          ^ Boolean.hashCode(enableOobLoadReport)
1✔
415
          ^ Float.hashCode(errorUtilizationPenalty);
1✔
416

417
      updateWeight();
1✔
418
    }
1✔
419

420
    @Override
421
    public PickResult pickSubchannel(PickSubchannelArgs args) {
422
      ChildLbState childLbState = children.get(scheduler.pick());
1✔
423
      WeightedChildLbState wChild = (WeightedChildLbState) childLbState;
1✔
424
      PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);
1✔
425
      Subchannel subchannel = pickResult.getSubchannel();
1✔
426
      if (subchannel == null) {
1✔
427
        return pickResult;
1✔
428
      }
429
      if (!enableOobLoadReport) {
1✔
430
        return PickResult.withSubchannel(subchannel,
1✔
431
            OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
1✔
432
                subchannelToReportListenerMap.getOrDefault(subchannel,
1✔
433
                    wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));
1✔
434
      } else {
435
        return PickResult.withSubchannel(subchannel);
1✔
436
      }
437
    }
438

439
    private void updateWeight() {
440
      float[] newWeights = new float[children.size()];
1✔
441
      AtomicInteger staleEndpoints = new AtomicInteger();
1✔
442
      AtomicInteger notYetUsableEndpoints = new AtomicInteger();
1✔
443
      for (int i = 0; i < children.size(); i++) {
1✔
444
        double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,
1✔
445
            notYetUsableEndpoints);
446
        // TODO: add locality label once available
447
        helper.getMetricRecorder()
1✔
448
            .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
1✔
449
                ImmutableList.of(helper.getChannelTarget()),
1✔
450
                ImmutableList.of(locality));
1✔
451
        newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
1✔
452
      }
453
      if (staleEndpoints.get() > 0) {
1✔
454
        // TODO: add locality label once available
455
        helper.getMetricRecorder()
1✔
456
            .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
1✔
457
                ImmutableList.of(helper.getChannelTarget()),
1✔
458
                ImmutableList.of(locality));
1✔
459
      }
460
      if (notYetUsableEndpoints.get() > 0) {
1✔
461
        // TODO: add locality label once available
462
        helper.getMetricRecorder()
1✔
463
            .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
1✔
464
                ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
1✔
465
      }
466

467
      this.scheduler = new StaticStrideScheduler(newWeights, sequence);
1✔
468
      if (this.scheduler.usesRoundRobin()) {
1✔
469
        // TODO: locality label once available
470
        helper.getMetricRecorder()
1✔
471
            .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
1✔
472
                ImmutableList.of(locality));
1✔
473
      }
474
    }
1✔
475

476
    @Override
477
    public String toString() {
478
      return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
1✔
479
          .add("enableOobLoadReport", enableOobLoadReport)
1✔
480
          .add("errorUtilizationPenalty", errorUtilizationPenalty)
1✔
481
          .add("list", children).toString();
1✔
482
    }
483

484
    @VisibleForTesting
485
    List<ChildLbState> getChildren() {
486
      return children;
1✔
487
    }
488

489
    @Override
490
    public int hashCode() {
491
      return hashCode;
×
492
    }
493

494
    @Override
495
    public boolean equals(Object o) {
496
      if (!(o instanceof WeightedRoundRobinPicker)) {
1✔
497
        return false;
×
498
      }
499
      WeightedRoundRobinPicker other = (WeightedRoundRobinPicker) o;
1✔
500
      if (other == this) {
1✔
501
        return true;
×
502
      }
503
      // the lists cannot contain duplicate subchannels
504
      return hashCode == other.hashCode
1✔
505
          && sequence == other.sequence
506
          && enableOobLoadReport == other.enableOobLoadReport
507
          && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
1✔
508
          && children.size() == other.children.size()
1✔
509
          && new HashSet<>(children).containsAll(other.children);
1✔
510
    }
511
  }
512

513
  /*
514
   * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler
515
   * in which each object's deadline is the multiplicative inverse of the object's weight.
516
   * <p>
517
   * The way in which this is implemented is through a static stride scheduler. 
518
   * The Static Stride Scheduler works by iterating through the list of subchannel weights
519
   * and using modular arithmetic to proportionally distribute picks, favoring entries 
520
   * with higher weights. It is based on the observation that the intended sequence generated 
521
   * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic. 
522
   * The Static Stride Scheduler is more performant than other implementations of the EDF
523
   * Scheduler, as it removes the need for a priority queue (and thus mutex locks).
524
   * <p>
525
   * go/static-stride-scheduler
526
   * <p>
527
   *
528
   * <ul>
529
   *  <li>nextSequence() - O(1)
530
   *  <li>pick() - O(n)
531
   */
532
  @VisibleForTesting
533
  static final class StaticStrideScheduler {
534
    private final short[] scaledWeights;
535
    private final AtomicInteger sequence;
536
    private final boolean usesRoundRobin;
537
    private static final int K_MAX_WEIGHT = 0xFFFF;
538

539
    // Assuming the mean of all known weights is M, StaticStrideScheduler will clamp
540
    // weights bigger than M*kMaxRatio and weights smaller than M*kMinRatio.
541
    //
542
    // This is done as a performance optimization by limiting the number of rounds for picks
543
    // for edge cases where channels have large differences in subchannel weights.
544
    // In this case, without these clips, it would potentially require the scheduler to
545
    // frequently traverse through the entire subchannel list within the pick method.
546
    //
547
    // The current values of 10 and 0.1 were chosen without any experimenting. It should
548
    // decrease the amount of sequences that the scheduler must traverse through in order
549
    // to pick a high weight subchannel in such corner cases.
550
    // But, it also makes WeightedRoundRobin to send slightly more requests to
551
    // potentially very bad tasks (that would have near-zero weights) than zero.
552
    // This is not necessarily a downside, though. Perhaps this is not a problem at
553
    // all, and we can increase this value if needed to save CPU cycles.
554
    private static final double K_MAX_RATIO = 10;
555
    private static final double K_MIN_RATIO = 0.1;
556

557
    StaticStrideScheduler(float[] weights, AtomicInteger sequence) {
1✔
558
      checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
1✔
559
      int numChannels = weights.length;
1✔
560
      int numWeightedChannels = 0;
1✔
561
      double sumWeight = 0;
1✔
562
      double unscaledMeanWeight;
563
      float unscaledMaxWeight = 0;
1✔
564
      for (float weight : weights) {
1✔
565
        if (weight > 0) {
1✔
566
          sumWeight += weight;
1✔
567
          unscaledMaxWeight = Math.max(weight, unscaledMaxWeight);
1✔
568
          numWeightedChannels++;
1✔
569
        }
570
      }
571

572
      // Adjust max value s.t. ratio does not exceed K_MAX_RATIO. This should
573
      // ensure that we on average do at most K_MAX_RATIO rounds for picks.
574
      if (numWeightedChannels > 0) {
1✔
575
        unscaledMeanWeight = sumWeight / numWeightedChannels;
1✔
576
        unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight));
1✔
577
        usesRoundRobin = false;
1✔
578
      } else {
579
        // Fall back to round robin if all values are non-positives
580
        usesRoundRobin = true;
1✔
581
        unscaledMeanWeight = 1;
1✔
582
        unscaledMaxWeight = 1;
1✔
583
      }
584

585
      // Scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly.
586
      // Note that, since we cap the weights to stay within K_MAX_RATIO, meanWeight might not
587
      // match the actual mean of the values that end up in the scheduler.
588
      double scalingFactor = K_MAX_WEIGHT / unscaledMaxWeight;
1✔
589
      // We compute weightLowerBound and clamp it to 1 from below so that in the
590
      // worst case, we represent tiny weights as 1.
591
      int weightLowerBound = (int) Math.ceil(scalingFactor * unscaledMeanWeight * K_MIN_RATIO);
1✔
592
      short[] scaledWeights = new short[numChannels];
1✔
593
      for (int i = 0; i < numChannels; i++) {
1✔
594
        if (weights[i] <= 0) {
1✔
595
          scaledWeights[i] = (short) Math.round(scalingFactor * unscaledMeanWeight);
1✔
596
        } else {
597
          int weight = (int) Math.round(scalingFactor * Math.min(weights[i], unscaledMaxWeight));
1✔
598
          scaledWeights[i] = (short) Math.max(weight, weightLowerBound);
1✔
599
        }
600
      }
601

602
      this.scaledWeights = scaledWeights;
1✔
603
      this.sequence = sequence;
1✔
604
    }
1✔
605

606
    // Without properly weighted channels, we do plain vanilla round_robin.
607
    boolean usesRoundRobin() {
608
      return usesRoundRobin;
1✔
609
    }
610

611
    /**
612
     * Returns the next sequence number and atomically increases sequence with wraparound.
613
     */
614
    private long nextSequence() {
615
      return Integer.toUnsignedLong(sequence.getAndIncrement());
1✔
616
    }
617

618
    /*
619
     * Selects index of next backend server.
620
     * <p>
621
     * A 2D array is compactly represented as a function of W(backend), where the row
622
     * represents the generation and the column represents the backend index:
623
     * X(backend,generation) | generation ∈ [0,kMaxWeight).
624
     * Each element in the conceptual array is a boolean indicating whether the backend at
625
     * this index should be picked now. If false, the counter is incremented again,
626
     * and the new element is checked. An atomically incremented counter keeps track of our
627
     * backend and generation through modular arithmetic within the pick() method.
628
     * <p>
629
     * Modular arithmetic allows us to evenly distribute picks and skips between
630
     * generations based on W(backend).
631
     * X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend)
632
     * If we have the same three backends with weights:
633
     * W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is:
634
     * <p>
635
     * B0    B1    B2
636
     * T     T     T
637
     * F     F     T
638
     * F     T     T
639
     * T     F     T
640
     * F     T     T
641
     * F     F     T
642
     * The sequence of picked backend indices is given by
643
     * walking across and down: {0,1,2,2,1,2,0,2,1,2,2}.
644
     * <p>
645
     * To reduce the variance and spread the wasted work among different picks,
646
     * an offset that varies per backend index is also included to the calculation.
647
     */
648
    int pick() {
649
      while (true) {
650
        long sequence = this.nextSequence();
1✔
651
        int backendIndex = (int) (sequence % scaledWeights.length);
1✔
652
        long generation = sequence / scaledWeights.length;
1✔
653
        int weight = Short.toUnsignedInt(scaledWeights[backendIndex]);
1✔
654
        long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
1✔
655
        if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
1✔
656
          continue;
1✔
657
        }
658
        return backendIndex;
1✔
659
      }
660
    }
661
  }
662

663
  static final class WeightedRoundRobinLoadBalancerConfig {
664
    final long blackoutPeriodNanos;
665
    final long weightExpirationPeriodNanos;
666
    final boolean enableOobLoadReport;
667
    final long oobReportingPeriodNanos;
668
    final long weightUpdatePeriodNanos;
669
    final float errorUtilizationPenalty;
670

671
    public static Builder newBuilder() {
672
      return new Builder();
1✔
673
    }
674

675
    private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos,
676
                                                 long weightExpirationPeriodNanos,
677
                                                 boolean enableOobLoadReport,
678
                                                 long oobReportingPeriodNanos,
679
                                                 long weightUpdatePeriodNanos,
680
                                                 float errorUtilizationPenalty) {
1✔
681
      this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
682
      this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
683
      this.enableOobLoadReport = enableOobLoadReport;
1✔
684
      this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
685
      this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
686
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
687
    }
1✔
688

689
    static final class Builder {
690
      long blackoutPeriodNanos = 10_000_000_000L; // 10s
1✔
691
      long weightExpirationPeriodNanos = 180_000_000_000L; //3min
1✔
692
      boolean enableOobLoadReport = false;
1✔
693
      long oobReportingPeriodNanos = 10_000_000_000L; // 10s
1✔
694
      long weightUpdatePeriodNanos = 1_000_000_000L; // 1s
1✔
695
      float errorUtilizationPenalty = 1.0F;
1✔
696

697
      private Builder() {
1✔
698

699
      }
1✔
700

701
      @SuppressWarnings("UnusedReturnValue")
702
      Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
703
        this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
704
        return this;
1✔
705
      }
706

707
      @SuppressWarnings("UnusedReturnValue")
708
      Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
709
        this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
710
        return this;
1✔
711
      }
712

713
      Builder setEnableOobLoadReport(boolean enableOobLoadReport) {
714
        this.enableOobLoadReport = enableOobLoadReport;
1✔
715
        return this;
1✔
716
      }
717

718
      Builder setOobReportingPeriodNanos(long oobReportingPeriodNanos) {
719
        this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
720
        return this;
1✔
721
      }
722

723
      Builder setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos) {
724
        this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
725
        return this;
1✔
726
      }
727

728
      Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) {
729
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
730
        return this;
1✔
731
      }
732

733
      WeightedRoundRobinLoadBalancerConfig build() {
734
        return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos,
1✔
735
                weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
736
                weightUpdatePeriodNanos, errorUtilizationPenalty);
737
      }
738
    }
739
  }
740
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc