• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

grpc / grpc-java / #20271

07 May 2026 09:27AM UTC coverage: 88.828% (+0.01%) from 88.816%
#20271

push

github

web-flow
xds: pre-parse custom metric names in WRR load balancer (#12773)

Introduce ParsedMetricName in MetricReportUtils to pre-parse configured
custom metric names into Enums and key Strings on config initialization
in WeightedRoundRobinLoadBalancerConfig, avoiding String parsing
operations in the data path.

This was achieved through a combination of the following changes:

- Streams -> loop
- OptionalDouble -> double : We decided to take a hit here because it
provides semantic correctness over using sentinels.
- Pre parsing instead of hot path substring

OrcaReportListener now utilizes pre-parsed ParsedMetricName objects
during getCustomMetricUtilization to prevent OptionalDouble heap
allocations on the hot path.

Updated test coverage in MetricReportUtilsTest and
WeightedRoundRobinLoadBalancerTest.

# JMH Benchmark Report: MetricReportUtils Optimization

We performed a benchmark comparison of four different custom metric
resolution implementations in the Weighted Round Robin (WRR) load
balancer.

## Benchmark Results

| Benchmark Variant | Average Latency | Normalized Heap Allocations | Speedup |
| :------------------------------------ | :-------------- | :-------------------------- | :-------- |
| **Baseline (`String` + Streams)** | 174.46 ns/op | 704.00 B/op | 1x |
| **`ParsedMetricName` + Streams** | 148.95 ns/op | 608.00 B/op | ~1.1x |
| **`String` + Loop** | 81.61 ns/op | 240.00 B/op | ~2.1x |
| **`ParsedMetricName` + Loop** | 52.92 ns/op | 144.00 B/op | ~3.2x |
| **`ParsedMetricName` + Unboxed Loop** | **43.76 ns/op** | **≈ 0.00 B/op** | **~4.0x** |

---

36247 of 40806 relevant lines covered (88.83%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.6
/../xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
1
/*
2
 * Copyright 2023 The gRPC Authors
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing, software
11
 * distributed under the License is distributed on an "AS IS" BASIS,
12
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
 * See the License for the specific language governing permissions and
14
 * limitations under the License.
15
 */
16

17
package io.grpc.xds;
18

19
import static com.google.common.base.Preconditions.checkArgument;
20
import static com.google.common.base.Preconditions.checkNotNull;
21

22
import com.google.common.annotations.VisibleForTesting;
23
import com.google.common.base.MoreObjects;
24
import com.google.common.base.Preconditions;
25
import com.google.common.collect.ImmutableList;
26
import com.google.common.collect.Lists;
27
import io.grpc.ConnectivityState;
28
import io.grpc.ConnectivityStateInfo;
29
import io.grpc.Deadline.Ticker;
30
import io.grpc.DoubleHistogramMetricInstrument;
31
import io.grpc.EquivalentAddressGroup;
32
import io.grpc.LoadBalancer;
33
import io.grpc.LoadBalancerProvider;
34
import io.grpc.LongCounterMetricInstrument;
35
import io.grpc.MetricInstrumentRegistry;
36
import io.grpc.NameResolver;
37
import io.grpc.Status;
38
import io.grpc.SynchronizationContext;
39
import io.grpc.SynchronizationContext.ScheduledHandle;
40
import io.grpc.services.MetricReport;
41
import io.grpc.util.ForwardingSubchannel;
42
import io.grpc.util.MultiChildLoadBalancer;
43
import io.grpc.xds.internal.MetricReportUtils;
44
import io.grpc.xds.internal.MetricReportUtils.ParsedMetricName;
45
import io.grpc.xds.orca.OrcaOobUtil;
46
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
47
import io.grpc.xds.orca.OrcaPerRequestUtil;
48
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
49
import java.util.ArrayList;
50
import java.util.Collection;
51
import java.util.HashSet;
52
import java.util.List;
53
import java.util.Objects;
54
import java.util.OptionalDouble;
55
import java.util.Random;
56
import java.util.Set;
57
import java.util.concurrent.ScheduledExecutorService;
58
import java.util.concurrent.TimeUnit;
59
import java.util.concurrent.atomic.AtomicInteger;
60
import java.util.logging.Level;
61
import java.util.logging.Logger;
62

63
/**
64
 * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over the
65
 * {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
66
 * determined by backend metrics using ORCA.
67
 * To use WRR, users may configure through channel serviceConfig. Example config:
68
 * <pre> {@code
69
 *       String wrrConfig = "{\"loadBalancingConfig\":" +
70
 *           "[{\"weighted_round_robin\":{\"enableOobLoadReport\":true, " +
71
 *           "\"blackoutPeriod\":\"10s\"," +
72
 *           "\"oobReportingPeriod\":\"10s\"," +
73
 *           "\"weightExpirationPeriod\":\"180s\"," +
74
 *           "\"errorUtilizationPenalty\":\"1.0\"," +
75
 *           "\"weightUpdatePeriod\":\"1s\"}}]}";
76
 *        serviceConfig = (Map<String, ?>) JsonParser.parse(wrrConfig);
77
 *        channel = ManagedChannelBuilder.forTarget("test:///lb.test.grpc.io")
78
 *            .defaultServiceConfig(serviceConfig)
79
 *            .build();
80
 *  }
81
 *  </pre>
82
 *  Users may also configure through xDS control plane via custom lb policy. But that is much more
83
 *  complex to set up. Example config:
84
 *  <pre>
85
 *  localityLbPolicies:
86
 *   - customPolicy:
87
 *       name: weighted_round_robin
88
 *       data: '{ "enableOobLoadReport": true }'
89
 *  </pre>
90
 *  See related documentation: https://cloud.google.com/service-mesh/legacy/load-balancing-apis/proxyless-configure-advanced-traffic-management#custom-lb-config
91
 */
92
final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer {
93

94
  private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER;
95
  private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER;
96
  private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_STALE_COUNTER;
97
  private static final DoubleHistogramMetricInstrument ENDPOINT_WEIGHTS_HISTOGRAM;
98
  private static final Logger log = Logger.getLogger(
1✔
99
      WeightedRoundRobinLoadBalancer.class.getName());
1✔
100
  private WeightedRoundRobinLoadBalancerConfig config;
101
  private final SynchronizationContext syncContext;
102
  private final ScheduledExecutorService timeService;
103
  private ScheduledHandle weightUpdateTimer;
104
  private final Runnable updateWeightTask;
105
  private final AtomicInteger sequence;
106
  private final long infTime;
107
  private final Ticker ticker;
108
  private String locality = "";
1✔
109
  private String backendService = "";
1✔
110
  private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult());
1✔
111

112
  // The metric instruments are only registered once and shared by all instances of this LB.
113
  static {
114
    MetricInstrumentRegistry metricInstrumentRegistry
115
        = MetricInstrumentRegistry.getDefaultRegistry();
1✔
116
    RR_FALLBACK_COUNTER = metricInstrumentRegistry.registerLongCounter(
1✔
117
        "grpc.lb.wrr.rr_fallback",
118
        "EXPERIMENTAL. Number of scheduler updates in which there were not enough endpoints "
119
            + "with valid weight, which caused the WRR policy to fall back to RR behavior",
120
        "{update}",
121
        Lists.newArrayList("grpc.target"),
1✔
122
        Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"),
1✔
123
        false);
124
    ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = metricInstrumentRegistry.registerLongCounter(
1✔
125
        "grpc.lb.wrr.endpoint_weight_not_yet_usable",
126
        "EXPERIMENTAL. Number of endpoints from each scheduler update that don't yet have usable "
127
            + "weight information",
128
        "{endpoint}",
129
        Lists.newArrayList("grpc.target"),
1✔
130
        Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"),
1✔
131
        false);
132
    ENDPOINT_WEIGHT_STALE_COUNTER = metricInstrumentRegistry.registerLongCounter(
1✔
133
        "grpc.lb.wrr.endpoint_weight_stale",
134
        "EXPERIMENTAL. Number of endpoints from each scheduler update whose latest weight is "
135
            + "older than the expiration period",
136
        "{endpoint}",
137
        Lists.newArrayList("grpc.target"),
1✔
138
        Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"),
1✔
139
        false);
140
    ENDPOINT_WEIGHTS_HISTOGRAM = metricInstrumentRegistry.registerDoubleHistogram(
1✔
141
        "grpc.lb.wrr.endpoint_weights",
142
        "EXPERIMENTAL. The histogram buckets will be endpoint weight ranges.",
143
        "{weight}",
144
        Lists.newArrayList(),
1✔
145
        Lists.newArrayList("grpc.target"),
1✔
146
        Lists.newArrayList("grpc.lb.locality", "grpc.lb.backend_service"),
1✔
147
        false);
148
  }
1✔
149

150
  public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
151
    this(helper, ticker, new Random());
1✔
152
  }
1✔
153

154
  @VisibleForTesting
155
  WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) {
156
    super(OrcaOobUtil.newOrcaReportingHelper(helper));
1✔
157
    this.ticker = checkNotNull(ticker, "ticker");
1✔
158
    this.infTime = ticker.nanoTime() + Long.MAX_VALUE;
1✔
159
    this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
1✔
160
    this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
1✔
161
    this.updateWeightTask = new UpdateWeightTask();
1✔
162
    this.sequence = new AtomicInteger(random.nextInt());
1✔
163
    log.log(Level.FINE, "weighted_round_robin LB created");
1✔
164
  }
1✔
165

166
  @Override
167
  protected ChildLbState createChildLbState(Object key) {
168
    return new WeightedChildLbState(key, pickFirstLbProvider);
1✔
169
  }
170

171
  @Override
172
  public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
173
    if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
1✔
174
      Status unavailableStatus = Status.UNAVAILABLE.withDescription(
1✔
175
              "NameResolver returned no WeightedRoundRobinLoadBalancerConfig. addrs="
176
                      + resolvedAddresses.getAddresses()
1✔
177
                      + ", attrs=" + resolvedAddresses.getAttributes());
1✔
178
      handleNameResolutionError(unavailableStatus);
1✔
179
      return unavailableStatus;
1✔
180
    }
181
    String locality = resolvedAddresses.getAttributes().get(WeightedTargetLoadBalancer.CHILD_NAME);
1✔
182
    if (locality != null) {
1✔
183
      this.locality = locality;
1✔
184
    } else {
185
      this.locality = "";
1✔
186
    }
187
    String backendService
1✔
188
        = resolvedAddresses.getAttributes().get(NameResolver.ATTR_BACKEND_SERVICE);
1✔
189
    if (backendService != null) {
1✔
190
      this.backendService = backendService;
1✔
191
    } else {
192
      this.backendService = "";
1✔
193
    }
194
    config =
1✔
195
        (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
1✔
196

197
    if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
1✔
198
      weightUpdateTimer.cancel();
1✔
199
    }
200
    updateWeightTask.run();
1✔
201

202
    Status status = super.acceptResolvedAddresses(resolvedAddresses);
1✔
203

204
    createAndApplyOrcaListeners();
1✔
205

206
    return status;
1✔
207
  }
208

209
  /**
210
   * Updates picker with the list of active subchannels (state == READY).
211
   */
212
  @Override
213
  protected void updateOverallBalancingState() {
214
    List<ChildLbState> activeList = getReadyChildren();
1✔
215
    if (activeList.isEmpty()) {
1✔
216
      // No READY subchannels
217

218
      // MultiChildLB will request connection immediately on subchannel IDLE.
219
      boolean isConnecting = false;
1✔
220
      for (ChildLbState childLbState : getChildLbStates()) {
1✔
221
        ConnectivityState state = childLbState.getCurrentState();
1✔
222
        if (state == ConnectivityState.CONNECTING || state == ConnectivityState.IDLE) {
1✔
223
          isConnecting = true;
1✔
224
          break;
1✔
225
        }
226
      }
1✔
227

228
      if (isConnecting) {
1✔
229
        updateBalancingState(
1✔
230
            ConnectivityState.CONNECTING, new FixedResultPicker(PickResult.withNoResult()));
1✔
231
      } else {
232
        updateBalancingState(
1✔
233
            ConnectivityState.TRANSIENT_FAILURE, createReadyPicker(getChildLbStates()));
1✔
234
      }
235
    } else {
1✔
236
      updateBalancingState(ConnectivityState.READY, createReadyPicker(activeList));
1✔
237
    }
238
  }
1✔
239

240
  private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
241
    WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
1✔
242
        config.enableOobLoadReport, config.errorUtilizationPenalty, sequence,
243
        config.parsedMetricNamesForComputingUtilization);
244
    updateWeight(picker);
1✔
245
    return picker;
1✔
246
  }
247

248
  private void updateWeight(WeightedRoundRobinPicker picker) {
249
    Helper helper = getHelper();
1✔
250
    float[] newWeights = new float[picker.children.size()];
1✔
251
    AtomicInteger staleEndpoints = new AtomicInteger();
1✔
252
    AtomicInteger notYetUsableEndpoints = new AtomicInteger();
1✔
253
    for (int i = 0; i < picker.children.size(); i++) {
1✔
254
      double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints,
1✔
255
          notYetUsableEndpoints);
256
      helper.getMetricRecorder()
1✔
257
          .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
1✔
258
              ImmutableList.of(helper.getChannelTarget()),
1✔
259
              ImmutableList.of(locality, backendService));
1✔
260
      newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
1✔
261
    }
262

263
    if (staleEndpoints.get() > 0) {
1✔
264
      helper.getMetricRecorder()
1✔
265
          .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
1✔
266
              ImmutableList.of(helper.getChannelTarget()),
1✔
267
              ImmutableList.of(locality, backendService));
1✔
268
    }
269
    if (notYetUsableEndpoints.get() > 0) {
1✔
270
      helper.getMetricRecorder()
1✔
271
          .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
1✔
272
              ImmutableList.of(helper.getChannelTarget()),
1✔
273
              ImmutableList.of(locality, backendService));
1✔
274
    }
275
    boolean weightsEffective = picker.updateWeight(newWeights);
1✔
276
    if (!weightsEffective) {
1✔
277
      helper.getMetricRecorder()
1✔
278
          .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
1✔
279
              ImmutableList.of(locality, backendService));
1✔
280
    }
281
  }
1✔
282

283
  private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) {
284
    if (state != currentConnectivityState || !picker.equals(currentPicker)) {
1✔
285
      getHelper().updateBalancingState(state, picker);
1✔
286
      currentConnectivityState = state;
1✔
287
      currentPicker = picker;
1✔
288
    }
289
  }
1✔
290

291
  @VisibleForTesting
292
  final class WeightedChildLbState extends ChildLbState {
293

294
    private final Set<WrrSubchannel> subchannels = new HashSet<>();
1✔
295
    private volatile long lastUpdated;
296
    private volatile long nonEmptySince;
297
    private volatile double weight = 0;
1✔
298

299
    private OrcaReportListener orcaReportListener;
300

301
    public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider) {
1✔
302
      super(key, policyProvider);
1✔
303
    }
1✔
304

305
    @Override
306
    protected ChildLbStateHelper createChildHelper() {
307
      return new WrrChildLbStateHelper();
1✔
308
    }
309

310
    private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) {
311
      if (config == null) {
1✔
312
        return 0;
×
313
      }
314
      long now = ticker.nanoTime();
1✔
315
      if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
1✔
316
        nonEmptySince = infTime;
1✔
317
        staleEndpoints.incrementAndGet();
1✔
318
        return 0;
1✔
319
      } else if (now - nonEmptySince < config.blackoutPeriodNanos
1✔
320
          && config.blackoutPeriodNanos > 0) {
1✔
321
        notYetUsableEndpoints.incrementAndGet();
1✔
322
        return 0;
1✔
323
      } else {
324
        return weight;
1✔
325
      }
326
    }
327

328
    public void addSubchannel(WrrSubchannel wrrSubchannel) {
329
      subchannels.add(wrrSubchannel);
1✔
330
    }
1✔
331

332
    public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty,
333
        ImmutableList<ParsedMetricName> parsedMetricNamesForComputingUtilization) {
334
      if (orcaReportListener != null
1✔
335
          && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty
1✔
336
          && orcaReportListener.parsedMetricNamesForComputingUtilization
1✔
337
              .equals(parsedMetricNamesForComputingUtilization)) {
1✔
338
        return orcaReportListener;
1✔
339
      }
340
      orcaReportListener =
1✔
341
          new OrcaReportListener(errorUtilizationPenalty, parsedMetricNamesForComputingUtilization);
342
      return orcaReportListener;
1✔
343
    }
344

345
    public void removeSubchannel(WrrSubchannel wrrSubchannel) {
346
      subchannels.remove(wrrSubchannel);
1✔
347
    }
1✔
348

349
    final class WrrChildLbStateHelper extends ChildLbStateHelper {
1✔
350
      @Override
351
      public Subchannel createSubchannel(CreateSubchannelArgs args) {
352
        return new WrrSubchannel(super.createSubchannel(args), WeightedChildLbState.this);
1✔
353
      }
354

355
      @Override
356
      public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
357
        super.updateBalancingState(newState, newPicker);
1✔
358
        if (!resolvingAddresses && newState == ConnectivityState.IDLE) {
1✔
359
          getLb().requestConnection();
×
360
        }
361
      }
1✔
362
    }
363

364
    final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
365
      private final float errorUtilizationPenalty;
366
      private final ImmutableList<ParsedMetricName> parsedMetricNamesForComputingUtilization;
367

368
      OrcaReportListener(float errorUtilizationPenalty,
369
          ImmutableList<ParsedMetricName> parsedMetricNamesForComputingUtilization) {
1✔
370
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
371
        this.parsedMetricNamesForComputingUtilization = parsedMetricNamesForComputingUtilization;
1✔
372
      }
1✔
373

374
      @Override
375
      public void onLoadReport(MetricReport report) {
376
        double utilization = getUtilization(report);
1✔
377

378
        double newWeight = 0;
1✔
379
        if (utilization > 0 && report.getQps() > 0) {
1✔
380
          double penalty = 0;
1✔
381
          if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
1✔
382
            penalty = report.getEps() / report.getQps() * errorUtilizationPenalty;
1✔
383
          }
384
          newWeight = report.getQps() / (utilization + penalty);
1✔
385
        }
386
        if (newWeight == 0) {
1✔
387
          return;
1✔
388
        }
389
        if (nonEmptySince == infTime) {
1✔
390
          nonEmptySince = ticker.nanoTime();
1✔
391
        }
392
        lastUpdated = ticker.nanoTime();
1✔
393
        weight = newWeight;
1✔
394
      }
1✔
395

396
      /**
397
       * Returns the utilization value computed from the specified metric names. If the custom
398
       * metrics are present and valid, the maximum of the custom metrics is returned. Otherwise,
399
       * if application utilization is > 0, it is returned. If neither are present, the CPU
400
       * utilization is returned.
401
       */
402
      private double getUtilization(MetricReport report) {
403
        OptionalDouble customUtil = getCustomMetricUtilization(report);
1✔
404
        if (customUtil.isPresent()) {
1✔
405
          return customUtil.getAsDouble();
1✔
406
        }
407
        double appUtil = report.getApplicationUtilization();
1✔
408
        if (appUtil > 0) {
1✔
409
          return appUtil;
1✔
410
        }
411
        return report.getCpuUtilization();
1✔
412
      }
413

414
      /**
415
       * Returns the maximum utilization value among the parsed metric names.
416
       * Returns OptionalDouble.empty() if NONE of the specified metrics are present in the report,
417
       * or if all present metrics are NaN or non positive.
418
       */
419
      private OptionalDouble getCustomMetricUtilization(MetricReport report) {
420
        OptionalDouble max = OptionalDouble.empty();
1✔
421
        for (int i = 0; i < parsedMetricNamesForComputingUtilization.size(); i++) {
1✔
422
          OptionalDouble opt = MetricReportUtils.getMetricValue(report,
1✔
423
              parsedMetricNamesForComputingUtilization.get(i));
1✔
424
          if (opt.isPresent()) {
1✔
425
            double d = opt.getAsDouble();
1✔
426
            if (!Double.isNaN(d) && d > 0 && (!max.isPresent() || d > max.getAsDouble())) {
1✔
427
              max = opt;
1✔
428
            }
429
          }
430
        }
431
        return max;
1✔
432
      }
433
    }
434
  }
435

436
  private final class UpdateWeightTask implements Runnable {
1✔
437
    @Override
438
    public void run() {
439
      if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
1✔
440
        updateWeight((WeightedRoundRobinPicker) currentPicker);
1✔
441
      }
442
      weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
1✔
443
          TimeUnit.NANOSECONDS, timeService);
1✔
444
    }
1✔
445
  }
446

447
  private void createAndApplyOrcaListeners() {
448
    for (ChildLbState child : getChildLbStates()) {
1✔
449
      WeightedChildLbState wChild = (WeightedChildLbState) child;
1✔
450
      for (WrrSubchannel weightedSubchannel : wChild.subchannels) {
1✔
451
        if (config.enableOobLoadReport) {
1✔
452
          OrcaOobUtil.setListener(weightedSubchannel,
1✔
453
              wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty,
1✔
454
                      config.parsedMetricNamesForComputingUtilization),
455
              OrcaOobUtil.OrcaReportingConfig.newBuilder()
1✔
456
                  .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS).build());
1✔
457
        } else {
458
          OrcaOobUtil.setListener(weightedSubchannel, null, null);
1✔
459
        }
460
      }
1✔
461
    }
1✔
462
  }
1✔
463

464
  @Override
465
  public void shutdown() {
466
    if (weightUpdateTimer != null) {
1✔
467
      weightUpdateTimer.cancel();
1✔
468
    }
469
    super.shutdown();
1✔
470
  }
1✔
471

472
  @VisibleForTesting
473
  final class WrrSubchannel extends ForwardingSubchannel {
474
    private final Subchannel delegate;
475
    private final WeightedChildLbState owner;
476

477
    WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) {
1✔
478
      this.delegate = checkNotNull(delegate, "delegate");
1✔
479
      this.owner = checkNotNull(owner, "owner");
1✔
480
    }
1✔
481

482
    @Override
483
    public void start(SubchannelStateListener listener) {
484
      owner.addSubchannel(this);
1✔
485
      delegate().start(new SubchannelStateListener() {
1✔
486
        @Override
487
        public void onSubchannelState(ConnectivityStateInfo newState) {
488
          if (newState.getState().equals(ConnectivityState.READY)) {
1✔
489
            owner.nonEmptySince = infTime;
1✔
490
          }
491
          listener.onSubchannelState(newState);
1✔
492
        }
1✔
493
      });
494
    }
1✔
495

496
    @Override
497
    protected Subchannel delegate() {
498
      return delegate;
1✔
499
    }
500

501
    @Override
502
    public void shutdown() {
503
      super.shutdown();
1✔
504
      owner.removeSubchannel(this);
1✔
505
    }
1✔
506
  }
507

508
  @VisibleForTesting
509
  static final class WeightedRoundRobinPicker extends SubchannelPicker {
510
    // Parallel lists (column-based storage instead of normal row-based storage of List<Struct>).
511
    // The ith element of children corresponds to the ith element of pickers, listeners, and even
512
    // updateWeight(float[]).
513
    private final List<ChildLbState> children; // May only be accessed from sync context
514
    private final List<SubchannelPicker> pickers;
515
    private final List<OrcaPerRequestReportListener> reportListeners;
516
    private final boolean enableOobLoadReport;
517
    private final float errorUtilizationPenalty;
518
    private final AtomicInteger sequence;
519
    private final int hashCode;
520
    private volatile StaticStrideScheduler scheduler;
521

522
    WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
523
        float errorUtilizationPenalty, AtomicInteger sequence,
524
        ImmutableList<ParsedMetricName> parsedMetricNamesForComputingUtilization) {
1✔
525
      checkNotNull(children, "children");
1✔
526
      Preconditions.checkArgument(!children.isEmpty(), "empty child list");
1✔
527
      this.children = children;
1✔
528
      List<SubchannelPicker> pickers = new ArrayList<>(children.size());
1✔
529
      List<OrcaPerRequestReportListener> reportListeners = new ArrayList<>(children.size());
1✔
530
      for (ChildLbState child : children) {
1✔
531
        WeightedChildLbState wChild = (WeightedChildLbState) child;
1✔
532
        pickers.add(wChild.getCurrentPicker());
1✔
533
        reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty,
1✔
534
            parsedMetricNamesForComputingUtilization));
535
      }
1✔
536
      this.pickers = pickers;
1✔
537
      this.reportListeners = reportListeners;
1✔
538
      this.enableOobLoadReport = enableOobLoadReport;
1✔
539
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
540
      this.sequence = checkNotNull(sequence, "sequence");
1✔
541

542
      // For equality we treat pickers as a set; use hash code as defined by Set
543
      int sum = 0;
1✔
544
      for (SubchannelPicker picker : pickers) {
1✔
545
        sum += picker.hashCode();
1✔
546
      }
1✔
547
      this.hashCode = sum
1✔
548
          ^ Boolean.hashCode(enableOobLoadReport)
1✔
549
          ^ Float.hashCode(errorUtilizationPenalty);
1✔
550
    }
1✔
551

552
    @Override
553
    public PickResult pickSubchannel(PickSubchannelArgs args) {
554
      int pick = scheduler.pick();
1✔
555
      PickResult pickResult = pickers.get(pick).pickSubchannel(args);
1✔
556
      Subchannel subchannel = pickResult.getSubchannel();
1✔
557
      if (subchannel == null) {
1✔
558
        return pickResult;
1✔
559
      }
560
      
561
      subchannel = ((WrrSubchannel) subchannel).delegate();
1✔
562
      if (!enableOobLoadReport) {
1✔
563
        return pickResult.copyWithSubchannel(subchannel)
1✔
564
            .copyWithStreamTracerFactory(
1✔
565
                OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
1✔
566
                    reportListeners.get(pick)));
1✔
567
      } else {
568
        return pickResult.copyWithSubchannel(subchannel);
1✔
569
      }
570
    }
571

572
    /** Returns {@code true} if weights are different than round_robin. */
573
    private boolean updateWeight(float[] newWeights) {
574
      this.scheduler = new StaticStrideScheduler(newWeights, sequence);
1✔
575
      return !this.scheduler.usesRoundRobin();
1✔
576
    }
577

578
    @Override
579
    public String toString() {
580
      return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
1✔
581
          .add("enableOobLoadReport", enableOobLoadReport)
1✔
582
          .add("errorUtilizationPenalty", errorUtilizationPenalty)
1✔
583
          .add("pickers", pickers)
1✔
584
          .toString();
1✔
585
    }
586

587
    @VisibleForTesting
588
    List<ChildLbState> getChildren() {
589
      return children;
1✔
590
    }
591

592
    @Override
593
    public int hashCode() {
594
      return hashCode;
×
595
    }
596

597
    @Override
598
    public boolean equals(Object o) {
599
      if (!(o instanceof WeightedRoundRobinPicker)) {
1✔
600
        return false;
×
601
      }
602
      WeightedRoundRobinPicker other = (WeightedRoundRobinPicker) o;
1✔
603
      if (other == this) {
1✔
604
        return true;
×
605
      }
606
      // the lists cannot contain duplicate subchannels
607
      return hashCode == other.hashCode
1✔
608
          && sequence == other.sequence
609
          && enableOobLoadReport == other.enableOobLoadReport
610
          && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
1✔
611
          && pickers.size() == other.pickers.size()
1✔
612
          && new HashSet<>(pickers).containsAll(other.pickers);
1✔
613
    }
614
  }
615

616
  /*
617
   * The Static Stride Scheduler is an implementation of an earliest deadline first (EDF) scheduler
618
   * in which each object's deadline is the multiplicative inverse of the object's weight.
619
   * <p>
620
   * The way in which this is implemented is through a static stride scheduler. 
621
   * The Static Stride Scheduler works by iterating through the list of subchannel weights
622
   * and using modular arithmetic to proportionally distribute picks, favoring entries 
623
   * with higher weights. It is based on the observation that the intended sequence generated 
624
   * from an EDF scheduler is a periodic one that can be achieved through modular arithmetic. 
625
   * The Static Stride Scheduler is more performant than other implementations of the EDF
626
   * Scheduler, as it removes the need for a priority queue (and thus mutex locks).
627
   * <p>
628
   * go/static-stride-scheduler
629
   * <p>
630
   *
631
   * <ul>
632
   *  <li>nextSequence() - O(1)
633
   *  <li>pick() - O(n)
634
   */
635
  @VisibleForTesting
636
  static final class StaticStrideScheduler {
637
    private final short[] scaledWeights;
638
    private final AtomicInteger sequence;
639
    private final boolean usesRoundRobin;
640
    private static final int K_MAX_WEIGHT = 0xFFFF;
641

642
    // Assuming the mean of all known weights is M, StaticStrideScheduler will clamp
643
    // weights bigger than M*kMaxRatio and weights smaller than M*kMinRatio.
644
    //
645
    // This is done as a performance optimization by limiting the number of rounds for picks
646
    // for edge cases where channels have large differences in subchannel weights.
647
    // In this case, without these clips, it would potentially require the scheduler to
648
    // frequently traverse through the entire subchannel list within the pick method.
649
    //
650
    // The current values of 10 and 0.1 were chosen without any experimenting. It should
651
    // decrease the amount of sequences that the scheduler must traverse through in order
652
    // to pick a high weight subchannel in such corner cases.
653
    // But, it also makes WeightedRoundRobin to send slightly more requests to
654
    // potentially very bad tasks (that would have near-zero weights) than zero.
655
    // This is not necessarily a downside, though. Perhaps this is not a problem at
656
    // all, and we can increase this value if needed to save CPU cycles.
657
    private static final double K_MAX_RATIO = 10;
658
    private static final double K_MIN_RATIO = 0.1;
659

660
    StaticStrideScheduler(float[] weights, AtomicInteger sequence) {
1✔
661
      checkArgument(weights.length >= 1, "Couldn't build scheduler: requires at least one weight");
1✔
662
      int numChannels = weights.length;
1✔
663
      int numWeightedChannels = 0;
1✔
664
      double sumWeight = 0;
1✔
665
      double unscaledMeanWeight;
666
      float unscaledMaxWeight = 0;
1✔
667
      for (float weight : weights) {
1✔
668
        if (weight > 0) {
1✔
669
          sumWeight += weight;
1✔
670
          unscaledMaxWeight = Math.max(weight, unscaledMaxWeight);
1✔
671
          numWeightedChannels++;
1✔
672
        }
673
      }
674

675
      // Adjust max value s.t. ratio does not exceed K_MAX_RATIO. This should
676
      // ensure that we on average do at most K_MAX_RATIO rounds for picks.
677
      if (numWeightedChannels > 0) {
1✔
678
        unscaledMeanWeight = sumWeight / numWeightedChannels;
1✔
679
        unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight));
1✔
680
      } else {
681
        // Fall back to round robin if all values are non-positives. Note that
682
        // numWeightedChannels == 1 also behaves like RR because the weights are all the same, but
683
        // the weights aren't 1, so it doesn't go through this path.
684
        unscaledMeanWeight = 1;
1✔
685
        unscaledMaxWeight = 1;
1✔
686
      }
687
      // We need at least two weights for WRR to be distinguishable from round_robin.
688
      usesRoundRobin = numWeightedChannels < 2;
1✔
689

690
      // Scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly.
691
      // Note that, since we cap the weights to stay within K_MAX_RATIO, meanWeight might not
692
      // match the actual mean of the values that end up in the scheduler.
693
      double scalingFactor = K_MAX_WEIGHT / unscaledMaxWeight;
1✔
694
      // We compute weightLowerBound and clamp it to 1 from below so that in the
695
      // worst case, we represent tiny weights as 1.
696
      int weightLowerBound = (int) Math.ceil(scalingFactor * unscaledMeanWeight * K_MIN_RATIO);
1✔
697
      short[] scaledWeights = new short[numChannels];
1✔
698
      for (int i = 0; i < numChannels; i++) {
1✔
699
        if (weights[i] <= 0) {
1✔
700
          scaledWeights[i] = (short) Math.round(scalingFactor * unscaledMeanWeight);
1✔
701
        } else {
702
          int weight = (int) Math.round(scalingFactor * Math.min(weights[i], unscaledMaxWeight));
1✔
703
          scaledWeights[i] = (short) Math.max(weight, weightLowerBound);
1✔
704
        }
705
      }
706

707
      this.scaledWeights = scaledWeights;
1✔
708
      this.sequence = sequence;
1✔
709
    }
1✔
710

711
    // Without properly weighted channels, we do plain vanilla round_robin.
712
    boolean usesRoundRobin() {
713
      return usesRoundRobin;
1✔
714
    }
715

716
    /**
717
     * Returns the next sequence number and atomically increases sequence with wraparound.
718
     */
719
    private long nextSequence() {
720
      return Integer.toUnsignedLong(sequence.getAndIncrement());
1✔
721
    }
722

723
    /*
724
     * Selects index of next backend server.
725
     * <p>
726
     * A 2D array is compactly represented as a function of W(backend), where the row
727
     * represents the generation and the column represents the backend index:
728
     * X(backend,generation) | generation ∈ [0,kMaxWeight).
729
     * Each element in the conceptual array is a boolean indicating whether the backend at
730
     * this index should be picked now. If false, the counter is incremented again,
731
     * and the new element is checked. An atomically incremented counter keeps track of our
732
     * backend and generation through modular arithmetic within the pick() method.
733
     * <p>
734
     * Modular arithmetic allows us to evenly distribute picks and skips between
735
     * generations based on W(backend).
736
     * X(backend,generation) = (W(backend) * generation) % kMaxWeight >= kMaxWeight - W(backend)
737
     * If we have the same three backends with weights:
738
     * W(backend) = {2,3,6} scaled to max(W(backend)) = 6, then X(backend,generation) is:
739
     * <p>
740
     * B0    B1    B2
741
     * T     T     T
742
     * F     F     T
743
     * F     T     T
744
     * T     F     T
745
     * F     T     T
746
     * F     F     T
747
     * The sequence of picked backend indices is given by
748
     * walking across and down: {0,1,2,2,1,2,0,2,1,2,2}.
749
     * <p>
750
     * To reduce the variance and spread the wasted work among different picks,
751
     * an offset that varies per backend index is also included to the calculation.
752
     */
753
    int pick() {
754
      while (true) {
755
        long sequence = this.nextSequence();
1✔
756
        int backendIndex = (int) (sequence % scaledWeights.length);
1✔
757
        long generation = sequence / scaledWeights.length;
1✔
758
        int weight = Short.toUnsignedInt(scaledWeights[backendIndex]);
1✔
759
        long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
1✔
760
        if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
1✔
761
          continue;
1✔
762
        }
763
        return backendIndex;
1✔
764
      }
765
    }
766
  }
767

768
  static final class WeightedRoundRobinLoadBalancerConfig {
769
    final long blackoutPeriodNanos;
770
    final long weightExpirationPeriodNanos;
771
    final boolean enableOobLoadReport;
772
    final long oobReportingPeriodNanos;
773
    final long weightUpdatePeriodNanos;
774
    final float errorUtilizationPenalty;
775
    final ImmutableList<ParsedMetricName> parsedMetricNamesForComputingUtilization;
776

777
    public static Builder newBuilder() {
778
      return new Builder();
1✔
779
    }
780

781
    private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos,
782
        long weightExpirationPeriodNanos, boolean enableOobLoadReport, long oobReportingPeriodNanos,
783
        long weightUpdatePeriodNanos, float errorUtilizationPenalty,
784
        ImmutableList<String> metricNamesForComputingUtilization) {
1✔
785
      this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
786
      this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
787
      this.enableOobLoadReport = enableOobLoadReport;
1✔
788
      this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
789
      this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
790
      this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
791

792
      ImmutableList.Builder<ParsedMetricName> builder = ImmutableList.builder();
1✔
793
      if (metricNamesForComputingUtilization != null) {
1✔
794
        for (int i = 0; i < metricNamesForComputingUtilization.size(); i++) {
1✔
795
          String metricName = metricNamesForComputingUtilization.get(i);
1✔
796
          ParsedMetricName parsed = MetricReportUtils.ParsedMetricName.parse(metricName);
1✔
797
          if (parsed.getMetricType() != MetricReportUtils.MetricType.INVALID) {
1✔
798
            builder.add(parsed);
1✔
799
          } else {
800
            log.log(Level.FINE, "Invalid custom metric name configured and ignored: " + metricName);
1✔
801
          }
802
        }
803
      }
804
      this.parsedMetricNamesForComputingUtilization = builder.build();
1✔
805
    }
1✔
806

807
    @Override
808
    public boolean equals(Object o) {
809
      if (!(o instanceof WeightedRoundRobinLoadBalancerConfig)) {
1✔
810
        return false;
1✔
811
      }
812
      WeightedRoundRobinLoadBalancerConfig that = (WeightedRoundRobinLoadBalancerConfig) o;
1✔
813
      return this.blackoutPeriodNanos == that.blackoutPeriodNanos
1✔
814
          && this.weightExpirationPeriodNanos == that.weightExpirationPeriodNanos
815
          && this.enableOobLoadReport == that.enableOobLoadReport
816
          && this.oobReportingPeriodNanos == that.oobReportingPeriodNanos
817
          && this.weightUpdatePeriodNanos == that.weightUpdatePeriodNanos
818
          // Float.compare considers NaNs equal
819
          && Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0
1✔
820
          && Objects.equals(this.parsedMetricNamesForComputingUtilization,
1✔
821
              that.parsedMetricNamesForComputingUtilization);
822
    }
823

824
    @Override
825
    public int hashCode() {
826
      return Objects.hash(blackoutPeriodNanos, weightExpirationPeriodNanos, enableOobLoadReport,
1✔
827
          oobReportingPeriodNanos, weightUpdatePeriodNanos, errorUtilizationPenalty,
1✔
828
          parsedMetricNamesForComputingUtilization);
829
    }
830

831
    static final class Builder {
832
      long blackoutPeriodNanos = 10_000_000_000L; // 10s
1✔
833
      long weightExpirationPeriodNanos = 180_000_000_000L; // 3min
1✔
834
      boolean enableOobLoadReport = false;
1✔
835
      long oobReportingPeriodNanos = 10_000_000_000L; // 10s
1✔
836
      long weightUpdatePeriodNanos = 1_000_000_000L; // 1s
1✔
837
      float errorUtilizationPenalty = 1.0F;
1✔
838
      ImmutableList<String> metricNamesForComputingUtilization = ImmutableList.of();
1✔
839

840
      private Builder() {
1✔
841

842
      }
1✔
843

844
      @SuppressWarnings("UnusedReturnValue")
845
      Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
846
        this.blackoutPeriodNanos = blackoutPeriodNanos;
1✔
847
        return this;
1✔
848
      }
849

850
      @SuppressWarnings("UnusedReturnValue")
851
      Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
852
        this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
1✔
853
        return this;
1✔
854
      }
855

856
      Builder setEnableOobLoadReport(boolean enableOobLoadReport) {
857
        this.enableOobLoadReport = enableOobLoadReport;
1✔
858
        return this;
1✔
859
      }
860

861
      Builder setOobReportingPeriodNanos(long oobReportingPeriodNanos) {
862
        this.oobReportingPeriodNanos = oobReportingPeriodNanos;
1✔
863
        return this;
1✔
864
      }
865

866
      Builder setWeightUpdatePeriodNanos(long weightUpdatePeriodNanos) {
867
        this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
1✔
868
        return this;
1✔
869
      }
870

871
      Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) {
872
        this.errorUtilizationPenalty = errorUtilizationPenalty;
1✔
873
        return this;
1✔
874
      }
875

876
      Builder setMetricNamesForComputingUtilization(
877
          List<String> metricNamesForComputingUtilization) {
878
        this.metricNamesForComputingUtilization =
1✔
879
            ImmutableList.copyOf(metricNamesForComputingUtilization);
1✔
880
        return this;
1✔
881
      }
882

883
      WeightedRoundRobinLoadBalancerConfig build() {
884
        return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos,
1✔
885
            weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
886
            weightUpdatePeriodNanos, errorUtilizationPenalty, metricNamesForComputingUtilization);
887
      }
888
    }
889
  }
890
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc