• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

grpc / grpc-java / #19416

12 Aug 2024 06:23PM UTC coverage: 84.469% (-0.002%) from 84.471%
#19416

push

github

web-flow
xds: WRRPicker must not access unsynchronized data in ChildLbState

There was no point in using subchannels as keys to
subchannelToReportListenerMap, as the listener is per-child. Keying by
child means the keys are known ahead of time, so the unsynchronized
getOrCreateOrcaListener() call during picking was unnecessary.

The picker still stores ChildLbStates to make sure that updating weights
uses the correct children, but the picker itself no longer references
ChildLbStates except in the constructor. That means weight calculation
is moved into the LB policy, as child.getWeight() is unsynchronized, and
the picker no longer needs a reference to helper.

33389 of 39528 relevant lines covered (84.47%)

0.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.38
/../xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
1
/*
2
 * Copyright 2023 The gRPC Authors
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package io.grpc.xds;
18

19
import static com.google.common.base.Preconditions.checkArgument;
20
import static com.google.common.base.Preconditions.checkNotNull;
21

22
import com.google.common.annotations.VisibleForTesting;
23
import com.google.common.base.MoreObjects;
24
import com.google.common.base.Preconditions;
25
import com.google.common.collect.ImmutableList;
26
import com.google.common.collect.Lists;
27
import io.grpc.ConnectivityState;
28
import io.grpc.ConnectivityStateInfo;
29
import io.grpc.Deadline.Ticker;
30
import io.grpc.DoubleHistogramMetricInstrument;
31
import io.grpc.EquivalentAddressGroup;
32
import io.grpc.LoadBalancer;
33
import io.grpc.LoadBalancerProvider;
34
import io.grpc.LongCounterMetricInstrument;
35
import io.grpc.MetricInstrumentRegistry;
36
import io.grpc.NameResolver;
37
import io.grpc.Status;
38
import io.grpc.SynchronizationContext;
39
import io.grpc.SynchronizationContext.ScheduledHandle;
40
import io.grpc.services.MetricReport;
41
import io.grpc.util.ForwardingSubchannel;
42
import io.grpc.util.MultiChildLoadBalancer;
43
import io.grpc.xds.orca.OrcaOobUtil;
44
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
45
import io.grpc.xds.orca.OrcaPerRequestUtil;
46
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
47
import java.util.ArrayList;
48
import java.util.Collection;
49
import java.util.HashSet;
50
import java.util.List;
51
import java.util.Random;
52
import java.util.Set;
53
import java.util.concurrent.ScheduledExecutorService;
54
import java.util.concurrent.TimeUnit;
55
import java.util.concurrent.atomic.AtomicInteger;
56
import java.util.logging.Level;
57
import java.util.logging.Logger;
58

59
/**
60
 * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over the
61
 * {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
62
 * determined by backend metrics using ORCA.
63
 * To use WRR, configure it through the channel's service config. Example config:
64
 * <pre> {@code
65
 *       String wrrConfig = "{\"loadBalancingConfig\":" +
66
 *           "[{\"weighted_round_robin\":{\"enableOobLoadReport\":true, " +
67
 *           "\"blackoutPeriod\":\"10s\"," +
68
 *           "\"oobReportingPeriod\":\"10s\"," +
69
 *           "\"weightExpirationPeriod\":\"180s\"," +
70
 *           "\"errorUtilizationPenalty\":\"1.0\"," +
71
 *           "\"weightUpdatePeriod\":\"1s\"}}]}";
72
 *        serviceConfig = (Map<String, ?>) JsonParser.parse(wrrConfig);
73
 *        channel = ManagedChannelBuilder.forTarget("test:///lb.test.grpc.io")
74
 *            .defaultServiceConfig(serviceConfig)
75
 *            .build();
76
 *  }
77
 *  </pre>
78
 *  Users may also configure WRR through the xDS control plane via a custom LB policy, but that is much more
79
 *  complex to set up. Example config:
80
 *  <pre>
81
 *  localityLbPolicies:
82
 *   - customPolicy:
83
 *       name: weighted_round_robin
84
 *       data: '{ "enableOobLoadReport": true }'
85
 *  </pre>
86
 *  See related documentation: https://cloud.google.com/service-mesh/legacy/load-balancing-apis/proxyless-configure-advanced-traffic-management#custom-lb-config
87
 */
88
final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
89

90
  private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER;
91
  private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER;
92
  private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_STALE_COUNTER;
93
  private static final DoubleHistogramMetricInstrument ENDPOINT_WEIGHTS_HISTOGRAM;
94
  private static final Logger log = Logger.getLogger(
1✔
95
      WeightedRoundRobinLoadBalancer.class.getName());
1✔
96
  private WeightedRoundRobinLoadBalancerConfig config;
97
  private final SynchronizationContext syncContext;
98
  private final ScheduledExecutorService timeService;
99
  private ScheduledHandle weightUpdateTimer;
100
  private final Runnable updateWeightTask;
101
  private final AtomicInteger sequence;
102
  private final long infTime;
103
  private final Ticker ticker;
104
  private String locality = "";
1✔
105
  private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult());
1✔
106

107
  // Metric instruments are registered once, when this class is initialized, and are shared by
  // all instances of this LB.
  static {
    MetricInstrumentRegistry registry = MetricInstrumentRegistry.getDefaultRegistry();

    RR_FALLBACK_COUNTER = registry.registerLongCounter(
        "grpc.lb.wrr.rr_fallback",
        "EXPERIMENTAL. Number of scheduler updates in which there were not enough endpoints "
            + "with valid weight, which caused the WRR policy to fall back to RR behavior",
        "{update}",
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);

    ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = registry.registerLongCounter(
        "grpc.lb.wrr.endpoint_weight_not_yet_usable",
        "EXPERIMENTAL. Number of endpoints "
            + "from each scheduler update that don't yet have usable weight information",
        "{endpoint}",
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);

    ENDPOINT_WEIGHT_STALE_COUNTER = registry.registerLongCounter(
        "grpc.lb.wrr.endpoint_weight_stale",
        "EXPERIMENTAL. Number of endpoints from each scheduler update whose latest weight is "
            + "older than the expiration period",
        "{endpoint}",
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);

    // The empty list is the (unspecified) set of histogram bucket boundaries.
    ENDPOINT_WEIGHTS_HISTOGRAM = registry.registerDoubleHistogram(
        "grpc.lb.wrr.endpoint_weights",
        "EXPERIMENTAL. The histogram buckets will be endpoint weight ranges.",
        "{weight}",
        Lists.newArrayList(),
        Lists.newArrayList("grpc.target"),
        Lists.newArrayList("grpc.lb.locality"),
        false);
  }
1✔
133

134
  /**
   * Creates a WRR policy whose shared pick sequence is seeded from a new {@link Random}.
   *
   * @param helper parent helper
   * @param ticker time source used for weight expiration and blackout tracking
   */
  public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
    this(helper, ticker, new Random());
  }
1✔
137

138
  @VisibleForTesting
139
  WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) {
140
    super(OrcaOobUtil.newOrcaReportingHelper(helper));
1✔
141
    this.ticker = checkNotNull(ticker, "ticker");
1✔
142
    this.infTime = ticker.nanoTime() + Long.MAX_VALUE;
1✔
143
    this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
1✔
144
    this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
1✔
145
    this.updateWeightTask = new UpdateWeightTask();
1✔
146
    this.sequence = new AtomicInteger(random.nextInt());
1✔
147
    log.log(Level.FINE, "weighted_round_robin LB created");
1✔
148
  }
1✔
149

150
  @Override
151
  protected ChildLbState createChildLbState(Object key, Object policyConfig,
152
      SubchannelPicker initialPicker, ResolvedAddresses unused) {
153
    ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig,
1✔
154
        initialPicker);
155
    return childLbState;
1✔
156
  }
157

158
  @Override
159
  public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
160
    if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
1✔
161
      Status unavailableStatus = Status.UNAVAILABLE.withDescription(
1✔
162
              "NameResolver returned no WeightedRoundRobinLoadBalancerConfig. addrs="
163
                      + resolvedAddresses.getAddresses()
1✔
164
                      + ", attrs=" + resolvedAddresses.getAttributes());
1✔
165
      handleNameResolutionError(unavailableStatus);
1✔
166
      return unavailableStatus;
1✔
167
    }
168
    String locality = resolvedAddresses.getAttributes().get(WeightedTargetLoadBalancer.CHILD_NAME);
1✔
169
    if (locality != null) {
1✔
170
      this.locality = locality;
1✔
171
    } else {
172
      this.locality = "";
1✔
173
    }
174
    config =
1✔
175
            (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
1✔
176
    AcceptResolvedAddrRetVal acceptRetVal;
177
    try {
178
      resolvingAddresses = true;
1✔
179
      acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses);
1✔
180
      if (!acceptRetVal.status.isOk()) {
1✔
181
        return acceptRetVal.status;
×
182
      }
183

184
      if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
1✔
185
        weightUpdateTimer.cancel();
1✔
186
      }
187
      updateWeightTask.run();
1✔
188

189
      createAndApplyOrcaListeners();
1✔
190

191
      // Must update channel picker before return so that new RPCs will not be routed to deleted
192
      // clusters and resolver can remove them in service config.
193
      updateOverallBalancingState();
1✔
194

195
      shutdownRemoved(acceptRetVal.removedChildren);
1✔
196
    } finally {
197
      resolvingAddresses = false;
1✔
198
    }
199

200
    return acceptRetVal.status;
1✔
201
  }
202

203
  /**
204
   * Updates picker with the list of active subchannels (state == READY).
205
   */
206
  @Override
207
  protected void updateOverallBalancingState() {
208
    List<ChildLbState> activeList = getReadyChildren();
1✔
209
    if (activeList.isEmpty()) {
1✔
210
      // No READY subchannels
211

212
      // MultiChildLB will request connection immediately on subchannel IDLE.
213
      boolean isConnecting = false;
1✔
214
      for (ChildLbState childLbState : getChildLbStates()) {
1✔
215
        ConnectivityState state = childLbState.getCurrentState();
1✔
216
        if (state == ConnectivityState.CONNECTING || state == ConnectivityState.IDLE) {
1✔
217
          isConnecting = true;
1✔
218
          break;
1✔
219
        }
220
      }
1✔
221

222
      if (isConnecting) {
1✔
223
        updateBalancingState(
1✔
224
            ConnectivityState.CONNECTING, new FixedResultPicker(PickResult.withNoResult()));
1✔
225
      } else {
226
        updateBalancingState(
1✔
227
            ConnectivityState.TRANSIENT_FAILURE, createReadyPicker(getChildLbStates()));
1✔
228
      }
229
    } else {
1✔
230
      updateBalancingState(ConnectivityState.READY, createReadyPicker(activeList));
1✔
231
    }
232
  }
1✔
233

234
  private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
235
    WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
1✔
236
        config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
237
    updateWeight(picker);
1✔
238
    return picker;
1✔
239
  }
240

241
  private void updateWeight(WeightedRoundRobinPicker picker) {
242
    Helper helper = getHelper();
1✔
243
    float[] newWeights = new float[picker.children.size()];
1✔
244
    AtomicInteger staleEndpoints = new AtomicInteger();
1✔
245
    AtomicInteger notYetUsableEndpoints = new AtomicInteger();
1✔
246
    for (int i = 0; i < picker.children.size(); i++) {
1✔
247
      double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints,
1✔
248
          notYetUsableEndpoints);
249
      helper.getMetricRecorder()
1✔
250
          .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
1✔
251
              ImmutableList.of(helper.getChannelTarget()),
1✔
252
              ImmutableList.of(locality));
1✔
253
      newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
1✔
254
    }
255

256
    if (staleEndpoints.get() > 0) {
1✔
257
      helper.getMetricRecorder()
1✔
258
          .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
1✔
259
              ImmutableList.of(helper.getChannelTarget()),
1✔
260
              ImmutableList.of(locality));
1✔
261
    }
262
    if (notYetUsableEndpoints.get() > 0) {
1✔
263
      helper.getMetricRecorder()
1✔
264
          .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
1✔
265
              ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
1✔
266
    }
267
    boolean weightsEffective = picker.updateWeight(newWeights);
1✔
268
    if (!weightsEffective) {
1✔
269
      helper.getMetricRecorder()
1✔
270
          .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
1✔
271
              ImmutableList.of(locality));
1✔
272
    }
273
  }
1✔
274

275
  private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) {
276
    if (state != currentConnectivityState || !picker.equals(currentPicker)) {
1✔
277
      getHelper().updateBalancingState(state, picker);
1✔
278
      currentConnectivityState = state;
1✔
279
      currentPicker = picker;
1✔
280
    }
281
  }
1✔
282

283
  @VisibleForTesting
284
  final class WeightedChildLbState extends ChildLbState {
285

286
    private final Set<WrrSubchannel> subchannels = new HashSet<>();
1✔
287
    private volatile long lastUpdated;
288
    private volatile long nonEmptySince;
289
    private volatile double weight = 0;
1✔
290

291
    private OrcaReportListener orcaReportListener;
292

293
    public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
294
        SubchannelPicker initialPicker) {
1✔
295
      super(key, policyProvider, childConfig, initialPicker);
1✔
296
    }
1✔
297

298
    @Override
299
    protected ChildLbStateHelper createChildHelper() {
300
      return new WrrChildLbStateHelper();
1✔
301
    }
302

303
    private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) {
304
      if (config == null) {
1✔
305
        return 0;
×
306
      }
307
      long now = ticker.nanoTime();
1✔
308
      if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
1✔
309
        nonEmptySince = infTime;
1✔
310
        staleEndpoints.incrementAndGet();
1✔
311
        return 0;
1✔
312
      } else if (now - nonEmptySince < config.blackoutPeriodNanos
1✔
313
          && config.blackoutPeriodNanos > 0) {
1✔
314
        notYetUsableEndpoints.incrementAndGet();
1✔
315
        return 0;
1✔
316
      } else {
317
        return weight;
1✔
318
      }
319
    }
320

321
    public void addSubchannel(WrrSubchannel wrrSubchannel) {
322
      subchannels.add(wrrSubchannel);
1✔
323
    }
1✔
324

325
    public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) {
326
      if (orcaReportListener != null
1✔
327
          && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) {
1✔
328
        return orcaReportListener;
1✔
329
      }
330
      orcaReportListener = new OrcaReportListener(errorUtilizationPenalty);
1✔
331
      return orcaReportListener;
1✔
332
    }
333

334
    public void removeSubchannel(WrrSubchannel wrrSubchannel) {
335
      subchannels.remove(wrrSubchannel);
1✔
336
    }
1✔
337

338
    final class WrrChildLbStateHelper extends ChildLbStateHelper {
1✔
339
      @Override
340
      public Subchannel createSubchannel(CreateSubchannelArgs args) {
341
        return new WrrSubchannel(super.createSubchannel(args), WeightedChildLbState.this);
1✔
342
      }
343
    }
344

345
    final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
346
      private final float errorUtilizationPenalty;
347

348
      OrcaReportListener(float errorUtilizationPenalty) {
1✔
349
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
350
      }
1✔
351

352
      @Override
353
      public void onLoadReport(MetricReport report) {
354
        double newWeight = 0;
1✔
355
        // Prefer application utilization and fallback to CPU utilization if unset.
356
        double utilization =
357
            report.getApplicationUtilization() > 0 ? report.getApplicationUtilization()
1✔
358
                : report.getCpuUtilization();
1✔
359
        if (utilization > 0 && report.getQps() > 0) {
1✔
360
          double penalty = 0;
1✔
361
          if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
1✔
362
            penalty = report.getEps() / report.getQps() * errorUtilizationPenalty;
1✔
363
          }
364
          newWeight = report.getQps() / (utilization + penalty);
1✔
365
        }
366
        if (newWeight == 0) {
1✔
367
          return;
1✔
368
        }
369
        if (nonEmptySince == infTime) {
1✔
370
          nonEmptySince = ticker.nanoTime();
1✔
371
        }
372
        lastUpdated = ticker.nanoTime();
1✔
373
        weight = newWeight;
1✔
374
      }
1✔
375
    }
376
  }
377

378
  private final class UpdateWeightTask implements Runnable {
1✔
379
    @Override
380
    public void run() {
381
      if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
1✔
382
        updateWeight((WeightedRoundRobinPicker) currentPicker);
1✔
383
      }
384
      weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
1✔
385
          TimeUnit.NANOSECONDS, timeService);
1✔
386
    }
1✔
387
  }
388

389
  private void createAndApplyOrcaListeners() {
390
    for (ChildLbState child : getChildLbStates()) {
1✔
391
      WeightedChildLbState wChild = (WeightedChildLbState) child;
1✔
392
      for (WrrSubchannel weightedSubchannel : wChild.subchannels) {
1✔
393
        if (config.enableOobLoadReport) {
1✔
394
          OrcaOobUtil.setListener(weightedSubchannel,
1✔
395
              wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty),
1✔
396
              OrcaOobUtil.OrcaReportingConfig.newBuilder()
1✔
397
                  .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
1✔
398
                  .build());
1✔
399
        } else {
400
          OrcaOobUtil.setListener(weightedSubchannel, null, null);
1✔
401
        }
402
      }
1✔
403
    }
1✔
404
  }
1✔
405

406
  @Override
407
  public void shutdown() {
408
    if (weightUpdateTimer != null) {
1✔
409
      weightUpdateTimer.cancel();
1✔
410
    }
411
    super.shutdown();
1✔
412
  }
1✔
413

414
  @VisibleForTesting
415
  final class WrrSubchannel extends ForwardingSubchannel {
416
    private final Subchannel delegate;
417
    private final WeightedChildLbState owner;
418

419
    WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) {
1✔
420
      this.delegate = checkNotNull(delegate, "delegate");
1✔
421
      this.owner = checkNotNull(owner, "owner");
1✔
422
    }
1✔
423

424
    @Override
425
    public void start(SubchannelStateListener listener) {
426
      owner.addSubchannel(this);
1✔
427
      delegate().start(new SubchannelStateListener() {
1✔
428
        @Override
429
        public void onSubchannelState(ConnectivityStateInfo newState) {
430
          if (newState.getState().equals(ConnectivityState.READY)) {
1✔
431
            owner.nonEmptySince = infTime;
1✔
432
          }
433
          listener.onSubchannelState(newState);
1✔
434
        }
1✔
435
      });
436
    }
1✔
437

438
    @Override
439
    protected Subchannel delegate() {
440
      return delegate;
1✔
441
    }
442

443
    @Override
444
    public void shutdown() {
445
      super.shutdown();
1✔
446
      owner.removeSubchannel(this);
1✔
447
    }
1✔
448
  }
449

450
  @VisibleForTesting
451
  static final class WeightedRoundRobinPicker extends SubchannelPicker {
452
    // Parallel lists (column-based storage instead of normal row-based storage of List<Struct>).
453
    // The ith element of children corresponds to the ith element of pickers, listeners, and even
454
    // updateWeight(float[]).
455
    private final List<ChildLbState> children; // May only be accessed from sync context
456
    private final List<SubchannelPicker> pickers;
457
    private final List<OrcaPerRequestReportListener> reportListeners;
458
    private final boolean enableOobLoadReport;
459
    private final float errorUtilizationPenalty;
460
    private final AtomicInteger sequence;
461
    private final int hashCode;
462
    private volatile StaticStrideScheduler scheduler;
463

464
    WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
465
        float errorUtilizationPenalty, AtomicInteger sequence) {
1✔
466
      checkNotNull(children, "children");
1✔
467
      Preconditions.checkArgument(!children.isEmpty(), "empty child list");
1✔
468
      this.children = children;
1✔
469
      List<SubchannelPicker> pickers = new ArrayList<>(children.size());
1✔
470
      List<OrcaPerRequestReportListener> reportListeners = new ArrayList<>(children.size());
1✔
471
      for (ChildLbState child : children) {
1✔
472
        WeightedChildLbState wChild = (WeightedChildLbState) child;
1✔
473
        pickers.add(wChild.getCurrentPicker());
1✔
474
        reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
1✔
475
      }
1✔
476
      this.pickers = pickers;
1✔
477
      this.reportListeners = reportListeners;
1✔
478
      this.enableOobLoadReport = enableOobLoadReport;
1✔
479
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
480
      this.sequence = checkNotNull(sequence, "sequence");
1✔
481

482
      // For equality we treat pickers as a set; use hash code as defined by Set
483
      int sum = 0;
1✔
484
      for (SubchannelPicker picker : pickers) {
1✔
485
        sum += picker.hashCode();
1✔
486
      }
1✔
487
      this.hashCode = sum
1✔
488
          ^ Boolean.hashCode(enableOobLoadReport)
1✔
489
          ^ Float.hashCode(errorUtilizationPenalty);
1✔
490
    }
1✔
491

492
    @Override
493
    public PickResult pickSubchannel(PickSubchannelArgs args) {
494
      int pick = scheduler.pick();
1✔
495
      PickResult pickResult = pickers.get(pick).pickSubchannel(args);
1✔
496
      Subchannel subchannel = pickResult.getSubchannel();
1✔
497
      if (subchannel == null) {
1✔
498
        return pickResult;
1✔
499
      }
500
      if (!enableOobLoadReport) {
1✔
501
        return PickResult.withSubchannel(subchannel,
1✔
502
            OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
1✔
503
                reportListeners.get(pick)));
1✔
504
      } else {
505
        return PickResult.withSubchannel(subchannel);
1✔
506
      }
507
    }
508

509
    /** Returns {@code true} if weights are different than round_robin. */
510
    private boolean updateWeight(float[] newWeights) {
511
      this.scheduler = new StaticStrideScheduler(newWeights, sequence);
1✔
512
      return !this.scheduler.usesRoundRobin();
1✔
513
    }
514

515
    @Override
516
    public String toString() {
517
      return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
1✔
518
          .add("enableOobLoadReport", enableOobLoadReport)
1✔
519
          .add("errorUtilizationPenalty", errorUtilizationPenalty)
1✔
520
          .add("pickers", pickers)
1✔
521
          .toString();
1✔
522
    }
523

524
    @VisibleForTesting
525
    List<ChildLbState> getChildren() {
526
      return children;
1✔
527
    }
528

529
    @Override
530
    public int hashCode() {
531
      return hashCode;
×
532
    }
533

534
    @Override
535
    public boolean equals(Object o) {
536
      if (!(o instanceof WeightedRoundRobinPicker)) {
1✔
537
        return false;
×
538
      }
539
      WeightedRoundRobinPicker other = (WeightedRoundRobinPicker) o;
1✔
540
      if (other == this) {
1✔
541
        return true;
×
542
      }
543
      // the lists cannot contain duplicate subchannels
544
      return hashCode == other.hashCode
1✔
545
          && sequence == other.sequence
546
          && enableOobLoadReport == other.enableOobLoadReport
547
          && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
1✔
548
          && pickers.size() == other.pickers.size()
1✔
549
          && new HashSet<>(pickers).containsAll(other.pickers);
1✔
550
    }
551
  }
552

553
  /*
554
   * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler
555
   * in which each object's deadline is the multiplicative inverse of the object's weight.
556
   * <p>
557
   * The way in which this is implemented is through a static stride scheduler. 
558
   * The Static Stride Scheduler works by iterating through the list of subchannel weights
559
   * and using modular arithmetic to proportionally distribute picks, favoring entries 
560
   * with higher weights. It is based on the observation that the intended sequence generated 
561
   * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic. 
562
   * The Static Stride Scheduler is more performant than other implementations of the EDF
563
   * Scheduler, as it removes the need for a priority queue (and thus mutex locks).
564
   * <p>
565
   * go/static-stride-scheduler
566
   * <p>
567
   *
568
   * <ul>
569
   *  <li>nextSequence() - O(1)
570
   *  <li>pick() - O(n)
571
   */
572
  @VisibleForTesting
573
  static final class StaticStrideScheduler {
574
    private final short[] scaledWeights;
575
    private final AtomicInteger sequence;
576
    private final boolean usesRoundRobin;
577
    private static final int K_MAX_WEIGHT = 0xFFFF;
578

579
    // Assuming the mean of all known weights is M, StaticStrideScheduler will clamp
580
    // weights bigger than M*kMaxRatio and weights smaller than M*kMinRatio.
581
    //
582
    // This is done as a performance optimization by limiting the number of rounds for picks
583
    // for edge cases where channels have large differences in subchannel weights.
584
    // In this case, without these clips, it would potentially require the scheduler to
585
    // frequently traverse through the entire subchannel list within the pick method.
586
    //
587
    // The current values of 10 and 0.1 were chosen without any experimenting. It should
588
    // decrease the amount of sequences that the scheduler must traverse through in order
589
    // to pick a high weight subchannel in such corner cases.
590
    // But, it also makes WeightedRoundRobin to send slightly more requests to
591
    // potentially very bad tasks (that would have near-zero weights) than zero.
592
    // This is not necessarily a downside, though. Perhaps this is not a problem at
593
    // all, and we can increase this value if needed to save CPU cycles.
594
    private static final double K_MAX_RATIO = 10;
595
    private static final double K_MIN_RATIO = 0.1;
596

597
    StaticStrideScheduler(float[] weights, AtomicInteger sequence) {
1✔
598
      checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
1✔
599
      int numChannels = weights.length;
1✔
600
      int numWeightedChannels = 0;
1✔
601
      double sumWeight = 0;
1✔
602
      double unscaledMeanWeight;
603
      float unscaledMaxWeight = 0;
1✔
604
      for (float weight : weights) {
1✔
605
        if (weight > 0) {
1✔
606
          sumWeight += weight;
1✔
607
          unscaledMaxWeight = Math.max(weight, unscaledMaxWeight);
1✔
608
          numWeightedChannels++;
1✔
609
        }
610
      }
611

612
      // Adjust max value s.t. ratio does not exceed K_MAX_RATIO. This should
613
      // ensure that we on average do at most K_MAX_RATIO rounds for picks.
614
      if (numWeightedChannels > 0) {
1✔
615
        unscaledMeanWeight = sumWeight / numWeightedChannels;
1✔
616
        unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight));
1✔
617
      } else {
618
        // Fall back to round robin if all values are non-positives. Note that
619
        // numWeightedChannels == 1 also behaves like RR because the weights are all the same, but
620
        // the weights aren't 1, so it doesn't go through this path.
621
        unscaledMeanWeight = 1;
1✔
622
        unscaledMaxWeight = 1;
1✔
623
      }
624
      // We need at least two weights for WRR to be distinguishable from round_robin.
625
      usesRoundRobin = numWeightedChannels < 2;
1✔
626

627
      // Scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly.
628
      // Note that, since we cap the weights to stay within K_MAX_RATIO, meanWeight might not
629
      // match the actual mean of the values that end up in the scheduler.
630
      double scalingFactor = K_MAX_WEIGHT / unscaledMaxWeight;
1✔
631
      // We compute weightLowerBound and clamp it to 1 from below so that in the
632
      // worst case, we represent tiny weights as 1.
633
      int weightLowerBound = (int) Math.ceil(scalingFactor * unscaledMeanWeight * K_MIN_RATIO);
1✔
634
      short[] scaledWeights = new short[numChannels];
1✔
635
      for (int i = 0; i < numChannels; i++) {
1✔
636
        if (weights[i] <= 0) {
1✔
637
          scaledWeights[i] = (short) Math.round(scalingFactor * unscaledMeanWeight);
1✔
638
        } else {
639
          int weight = (int) Math.round(scalingFactor * Math.min(weights[i], unscaledMaxWeight));
1✔
640
          scaledWeights[i] = (short) Math.max(weight, weightLowerBound);
1✔
641
        }
642
      }
643

644
      this.scaledWeights = scaledWeights;
1✔
645
      this.sequence = sequence;
1✔
646
    }
1✔
647

648
    // Without properly weighted channels, we do plain vanilla round_robin.
649
    boolean usesRoundRobin() {
650
      return usesRoundRobin;
1✔
651
    }
652

653
    /**
654
     * Returns the next sequence number and atomically increases sequence with wraparound.
655
     */
656
    private long nextSequence() {
657
      return Integer.toUnsignedLong(sequence.getAndIncrement());
1✔
658
    }
659

660
    /*
661
     * Selects index of next backend server.
662
     * <p>
663
     * A 2D array is compactly represented as a function of W(backend), where the row
664
     * represents the generation and the column represents the backend index:
665
     * X(backend,generation) | generation ∈ [0,kMaxWeight).
666
     * Each element in the conceptual array is a boolean indicating whether the backend at
667
     * this index should be picked now. If false, the counter is incremented again,
668
     * and the new element is checked. An atomically incremented counter keeps track of our
669
     * backend and generation through modular arithmetic within the pick() method.
670
     * <p>
671
     * Modular arithmetic allows us to evenly distribute picks and skips between
672
     * generations based on W(backend).
673
     * X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend)
674
     * If we have the same three backends with weights:
675
     * W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is:
676
     * <p>
677
     * B0    B1    B2
678
     * T     T     T
679
     * F     F     T
680
     * F     T     T
681
     * T     F     T
682
     * F     T     T
683
     * F     F     T
684
     * The sequence of picked backend indices is given by
685
     * walking across and down: {0,1,2,2,1,2,0,2,1,2,2}.
686
     * <p>
687
     * To reduce the variance and spread the wasted work among different picks,
688
     * an offset that varies per backend index is also included to the calculation.
689
     */
690
    int pick() {
691
      while (true) {
692
        long sequence = this.nextSequence();
1✔
693
        int backendIndex = (int) (sequence % scaledWeights.length);
1✔
694
        long generation = sequence / scaledWeights.length;
1✔
695
        int weight = Short.toUnsignedInt(scaledWeights[backendIndex]);
1✔
696
        long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
1✔
697
        if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
1✔
698
          continue;
1✔
699
        }
700
        return backendIndex;
1✔
701
      }
702
    }
703
  }
704

705
  static final class WeightedRoundRobinLoadBalancerConfig {
706
    final long blackoutPeriodNanos;
707
    final long weightExpirationPeriodNanos;
708
    final boolean enableOobLoadReport;
709
    final long oobReportingPeriodNanos;
710
    final long weightUpdatePeriodNanos;
711
    final float errorUtilizationPenalty;
712

713
    public static Builder newBuilder() {
714
      return new Builder();
1✔
715
    }
716

717
    private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos,
718
                                                 long weightExpirationPeriodNanos,
719
                                                 boolean enableOobLoadReport,
720
                                                 long oobReportingPeriodNanos,
721
                                                 long weightUpdatePeriodNanos,
722
                                                 float errorUtilizationPenalty) {
1✔
723
      this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
724
      this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
725
      this.enableOobLoadReport = enableOobLoadReport;
1✔
726
      this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
727
      this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
728
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
729
    }
1✔
730

731
    static final class Builder {
732
      long blackoutPeriodNanos = 10_000_000_000L; // 10s
1✔
733
      long weightExpirationPeriodNanos = 180_000_000_000L; //3min
1✔
734
      boolean enableOobLoadReport = false;
1✔
735
      long oobReportingPeriodNanos = 10_000_000_000L; // 10s
1✔
736
      long weightUpdatePeriodNanos = 1_000_000_000L; // 1s
1✔
737
      float errorUtilizationPenalty = 1.0F;
1✔
738

739
      private Builder() {
1✔
740

741
      }
1✔
742

743
      @SuppressWarnings("UnusedReturnValue")
744
      Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
745
        this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
746
        return this;
1✔
747
      }
748

749
      @SuppressWarnings("UnusedReturnValue")
750
      Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
751
        this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
752
        return this;
1✔
753
      }
754

755
      Builder setEnableOobLoadReport(boolean enableOobLoadReport) {
756
        this.enableOobLoadReport = enableOobLoadReport;
1✔
757
        return this;
1✔
758
      }
759

760
      Builder setOobReportingPeriodNanos(long oobReportingPeriodNanos) {
761
        this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
762
        return this;
1✔
763
      }
764

765
      Builder setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos) {
766
        this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
767
        return this;
1✔
768
      }
769

770
      Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) {
771
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
772
        return this;
1✔
773
      }
774

775
      WeightedRoundRobinLoadBalancerConfig build() {
776
        return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos,
1✔
777
                weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
778
                weightUpdatePeriodNanos, errorUtilizationPenalty);
779
      }
780
    }
781
  }
782
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc