• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

microsoft / botbuilder-dotnet / 385666

04 Mar 2024 04:19PM UTC coverage: 78.385% (-0.006%) from 78.391%
385666

push

CI-PR build

web-flow
Microsoft.IdentityModel.Protocols.OpenIdConnect bump (#6756)

Co-authored-by: Tracy Boehrer <trboehre@microsoft.com>

26179 of 33398 relevant lines covered (78.38%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.09
/libraries/Microsoft.Bot.Connector.Streaming/Session/StreamingSession.cs
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
using System;
5
using System.Buffers;
6
using System.Collections.Concurrent;
7
using System.Collections.Generic;
8
using System.IO;
9
using System.Linq;
10
using System.Net;
11
using System.Runtime.InteropServices;
12
using System.Threading;
13
using System.Threading.Tasks;
14
using Microsoft.Bot.Connector.Streaming.Payloads;
15
using Microsoft.Bot.Connector.Streaming.Transport;
16
using Microsoft.Bot.Streaming;
17
using Microsoft.Bot.Streaming.Payloads;
18
using Microsoft.Extensions.Logging;
19
using Microsoft.Extensions.Logging.Abstractions;
20
using Microsoft.Net.Http.Headers;
21
using Newtonsoft.Json;
22
using Newtonsoft.Json.Serialization;
23

24
namespace Microsoft.Bot.Connector.Streaming.Session
25
{
26
    internal class StreamingSession
27
    {
28
        // UTF-8 byte order mark (BOM) constant, as defined by:
        // Dotnet runtime: https://github.com/dotnet/runtime/blob/main/src/libraries/System.Text.Json/src/System/Text/Json/JsonConstants.cs#L35
        // Unicode.org spec: https://www.unicode.org/faq/utf_bom.html#bom5
        // readonly: the field is only ever read, so the reference must not be re-pointable at another array.
        private static readonly byte[] _utf8Bom = { 0xEF, 0xBB, 0xBF };

        // Placeholder streams announced by received request/response headers, keyed by stream id.
        private readonly Dictionary<Guid, StreamDefinition> _streamDefinitions = new Dictionary<Guid, StreamDefinition>();

        // Received requests, keyed by the id of the header they arrived with.
        private readonly Dictionary<Guid, ReceiveRequest> _requests = new Dictionary<Guid, ReceiveRequest>();

        // Received responses that announced streams, keyed by the id of the header they arrived with.
        private readonly Dictionary<Guid, ReceiveResponse> _responses = new Dictionary<Guid, ReceiveResponse>();

        // Completion sources for requests sent through SendRequestAsync, keyed by request id.
        private readonly ConcurrentDictionary<Guid, TaskCompletionSource<ReceiveResponse>> _pendingResponses = new ConcurrentDictionary<Guid, TaskCompletionSource<ReceiveResponse>>();

        // Handler that processes fully received requests.
        private readonly RequestHandler _receiver;

        // Transport used to send requests, responses and stream content.
        private readonly TransportHandler _sender;

        private readonly ILogger _logger;

        // Token passed to the responses sent on behalf of received requests.
        private readonly CancellationToken _connectionCancellationToken;

        // Lock object taken by the Receive* methods while they update the dictionaries above.
        private readonly object _receiveSync = new object();
1✔
45

46
        /// <summary>
        /// Creates the shared <see cref="Serializer"/> from the default serialization settings,
        /// configured to ignore null values.
        /// </summary>
        static StreamingSession()
        {
            // NOTE(review): if DefaultSerializationSettings returns a shared instance, the assignment below
            // changes it for every other consumer too -- confirm against SerializationSettings.
            var settings = SerializationSettings.DefaultSerializationSettings;
            settings.NullValueHandling = NullValueHandling.Ignore;

            Serializer = JsonSerializer.Create(settings);
        }
1✔
52

53
        /// <summary>
        /// Initializes a new instance of the <see cref="StreamingSession"/> class and subscribes it to the transport.
        /// </summary>
        /// <param name="receiver">Handler that processes requests once they have been fully received.</param>
        /// <param name="sender">Transport used to send payloads and from which incoming frames are observed.</param>
        /// <param name="logger">Optional logger; a no-op logger is used when null.</param>
        /// <param name="connectionCancellationToken">Token passed to the responses sent for received requests.</param>
        /// <exception cref="ArgumentNullException"><paramref name="receiver"/> or <paramref name="sender"/> is null.</exception>
        public StreamingSession(RequestHandler receiver, TransportHandler sender, ILogger logger, CancellationToken connectionCancellationToken = default)
        {
            _receiver = receiver ?? throw new ArgumentNullException(nameof(receiver));
            _sender = sender ?? throw new ArgumentNullException(nameof(sender));
            _logger = logger ?? NullLogger.Instance;
            _connectionCancellationToken = connectionCancellationToken;

            // Subscribe last: the dispatcher hands `this` to the transport, so every field it may touch
            // (notably _logger) has to be assigned before the first frame can possibly be delivered.
            _sender.Subscribe(new ProtocolDispatcher(this));
        }
1✔
62

63
        // Shared Json.NET serializer; assigned once by the static constructor and used by ProtocolDispatcher.DeserializeTo.
        private static JsonSerializer Serializer { get; }
1✔
64

65
        /// <summary>
        /// Sends a request (header payload first, then each content stream) and waits for the matching response.
        /// </summary>
        /// <param name="request">The request to send.</param>
        /// <param name="cancellationToken">Token passed to the transport send operations.</param>
        /// <returns>The response received for this request.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
        public async Task<ReceiveResponse> SendRequestAsync(StreamingRequest request, CancellationToken cancellationToken)
        {
            if (request == null)
            {
                throw new ArgumentNullException(nameof(request));
            }

            var payload = new RequestPayload()
            {
                Verb = request.Verb,
                Path = request.Path,
            };

            if (request.Streams != null)
            {
                payload.Streams = new List<StreamDescription>();
                foreach (var contentStream in request.Streams)
                {
                    payload.Streams.Add(GetStreamDescription(contentStream));
                }
            }

            var requestId = Guid.NewGuid();

            // The receive path completes this source while holding its lock; asynchronous continuations keep
            // the awaiting caller's code from running inline on the receiving thread under that lock.
            var responseCompletionSource = new TaskCompletionSource<ReceiveResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
            _pendingResponses.TryAdd(requestId, responseCompletionSource);

            try
            {
                // Send request
                await _sender.SendRequestAsync(requestId, payload, cancellationToken).ConfigureAwait(false);

                if (request.Streams != null)
                {
                    foreach (var stream in request.Streams)
                    {
                        var content = await stream.Content.ReadAsStreamAsync().ConfigureAwait(false);
                        await _sender.SendStreamAsync(stream.Id, content, cancellationToken).ConfigureAwait(false);
                    }
                }

                return await responseCompletionSource.Task.DefaultTimeOutAsync().ConfigureAwait(false);
            }
            finally
            {
                // On success the receive path has already removed the entry and this is a no-op. If sending
                // threw or the wait failed, this keeps the abandoned completion source from staying registered.
                _pendingResponses.TryRemove(requestId, out _);
            }
        }
1✔
107

108
        /// <summary>
        /// Sends a response payload for the given header, followed by the content of each response stream.
        /// </summary>
        /// <param name="header">Header identifying the response; its type must be <c>PayloadTypes.Response</c>.</param>
        /// <param name="response">The response to send.</param>
        /// <param name="cancellationToken">Token passed to the transport send operations.</param>
        /// <returns>A task that completes once the payload and all streams have been handed to the transport.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="header"/> or <paramref name="response"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The header is not a response header.</exception>
        public async Task SendResponseAsync(Header header, StreamingResponse response, CancellationToken cancellationToken)
        {
            if (header == null)
            {
                throw new ArgumentNullException(nameof(header));
            }

            if (header.Type != PayloadTypes.Response)
            {
                throw new InvalidOperationException($"StreamingSession SendResponseAsync expected Response payload, but instead received a payload of type {header.Type}");
            }

            if (response == null)
            {
                throw new ArgumentNullException(nameof(response));
            }

            var responsePayload = new ResponsePayload() { StatusCode = response.StatusCode };

            // Describe every outgoing stream in the payload so the peer knows what content follows.
            if (response.Streams != null)
            {
                responsePayload.Streams = response.Streams.Select(item => GetStreamDescription(item)).ToList();
            }

            await _sender.SendResponseAsync(header.Id, responsePayload, cancellationToken).ConfigureAwait(false);

            if (response.Streams == null)
            {
                return;
            }

            foreach (var outgoing in response.Streams)
            {
                var content = await outgoing.Content.ReadAsStreamAsync().ConfigureAwait(false);
                await _sender.SendStreamAsync(outgoing.Id, content, cancellationToken).ConfigureAwait(false);
            }
        }
1✔
151

152
        /// <summary>
        /// Accepts a request header. A request without streams is dispatched immediately; otherwise the
        /// request and its placeholder streams are recorded until <see cref="ReceiveStream"/> has received them all.
        /// </summary>
        /// <param name="header">Header of the received frame; its type must be <c>PayloadTypes.Request</c>.</param>
        /// <param name="request">The received request, including its placeholder streams.</param>
        /// <exception cref="ArgumentNullException"><paramref name="header"/> or <paramref name="request"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The header is not a request header.</exception>
        public virtual void ReceiveRequest(Header header, ReceiveRequest request)
        {
            if (header == null)
            {
                throw new ArgumentNullException(nameof(header));
            }

            if (header.Type != PayloadTypes.Request)
            {
                throw new InvalidOperationException($"StreamingSession cannot receive payload of type {header.Type} as request.");
            }

            if (request == null)
            {
                throw new ArgumentNullException(nameof(request));
            }

            Log.PayloadReceived(_logger, header);

            lock (_receiveSync)
            {
                if (!request.Streams.Any())
                {
                    // Nothing to wait for. The request is deliberately not added to _requests: entries there
                    // are only removed by ReceiveStream when the last stream completes, so a stream-less
                    // request would otherwise stay in the dictionary for the lifetime of the session.
                    ProcessRequest(header.Id, request);
                    return;
                }

                _requests.Add(header.Id, request);

                foreach (var streamDefinition in request.Streams)
                {
                    _streamDefinitions.Add(streamDefinition.Id, streamDefinition as StreamDefinition);
                }
            }
        }
1✔
188

189
        /// <summary>
        /// Accepts a response header. HTTP 202 (Accepted) responses are ignored. A response without streams
        /// completes the pending request with the same id; a response with streams is recorded, together with
        /// its placeholder streams, until the stream content arrives.
        /// </summary>
        /// <param name="header">Header of the received frame; its type must be <c>PayloadTypes.Response</c>.</param>
        /// <param name="response">The received response, including its placeholder streams.</param>
        /// <exception cref="ArgumentNullException"><paramref name="header"/> or <paramref name="response"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The header is not a response header.</exception>
        public virtual void ReceiveResponse(Header header, ReceiveResponse response)
        {
            if (header == null)
            {
                throw new ArgumentNullException(nameof(header));
            }

            if (header.Type != PayloadTypes.Response)
            {
                throw new InvalidOperationException($"StreamingSession cannot receive payload of type {header.Type} as response");
            }

            if (response == null)
            {
                throw new ArgumentNullException(nameof(response));
            }

            Log.PayloadReceived(_logger, header);

            var isInterimAck = response.StatusCode == (int)HttpStatusCode.Accepted;
            if (isInterimAck)
            {
                return;
            }

            lock (_receiveSync)
            {
                if (response.Streams.Any())
                {
                    // Park the response until its stream content has been received.
                    _responses.Add(header.Id, response);

                    foreach (var expected in response.Streams)
                    {
                        _streamDefinitions.Add(expected.Id, expected as StreamDefinition);
                    }

                    return;
                }

                if (_pendingResponses.TryGetValue(header.Id, out TaskCompletionSource<ReceiveResponse> waiter))
                {
                    waiter.SetResult(response);
                    _pendingResponses.TryRemove(header.Id, out _);
                }
            }
        }
1✔
234

235
        /// <summary>
        /// Appends a received stream frame to its placeholder stream. When the frame is the last one of the
        /// last outstanding stream of a request or response, that request is dispatched or that response is
        /// handed to the caller waiting for it.
        /// </summary>
        /// <param name="header">Header of the received frame; its type must be <c>PayloadTypes.Stream</c>.</param>
        /// <param name="payload">The bytes carried by the frame.</param>
        /// <exception cref="ArgumentNullException"><paramref name="header"/> is null, or <paramref name="payload"/> has no backing array.</exception>
        /// <exception cref="InvalidOperationException">The header is not a stream header.</exception>
        public virtual void ReceiveStream(Header header, ArraySegment<byte> payload)
        {
            if (header == null)
            {
                throw new ArgumentNullException(nameof(header));
            }

            if (header.Type != PayloadTypes.Stream)
            {
                throw new InvalidOperationException($"StreamingSession cannot receive payload of type {header.Type} as stream");
            }

            // ArraySegment<byte> is a struct, so comparing it with null is always false. What can actually be
            // missing is the backing array (a default-initialized segment), so that is what gets validated.
            if (payload.Array == null)
            {
                throw new ArgumentNullException(nameof(payload));
            }

            Log.PayloadReceived(_logger, header);

            // The dictionaries are plain Dictionary instances that other threads mutate, so every access
            // happens under the same lock the other Receive* methods use.
            StreamDefinition streamDefinition;
            lock (_receiveSync)
            {
                _streamDefinitions.TryGetValue(header.Id, out streamDefinition);
            }

            if (streamDefinition == null)
            {
                Log.OrphanedStream(_logger, header);
                return;
            }

            streamDefinition.Stream.Write(payload.Array, payload.Offset, payload.Count);

            // Is this the end of this stream?
            if (!header.End)
            {
                return;
            }

            TaskCompletionSource<ReceiveResponse> waiter = null;
            ReceiveResponse completedResponse = null;

            lock (_receiveSync)
            {
                // Mark this stream as completed and rewind it so the consumer reads from the start.
                streamDefinition.Complete = true;
                streamDefinition.Stream.Seek(0, SeekOrigin.Begin);

                var payloadId = streamDefinition.PayloadId;

                if (streamDefinition.PayloadType == PayloadTypes.Request)
                {
                    // If we received all the streams, then it's time to pass this request to the request handler!
                    // For example, if this request is a send activity, the request handler will deserialize the
                    // first stream into an activity and pass to the adapter.
                    if (_requests.TryGetValue(payloadId, out ReceiveRequest request) && AllStreamsComplete(request.Streams))
                    {
                        ProcessRequest(payloadId, request);
                        _requests.Remove(payloadId);
                    }
                }
                else if (streamDefinition.PayloadType == PayloadTypes.Response)
                {
                    if (_responses.TryGetValue(payloadId, out ReceiveResponse response) && AllStreamsComplete(response.Streams))
                    {
                        // The response is finished whether or not a caller is still waiting for it, so its
                        // bookkeeping is always released instead of staying in the dictionaries.
                        _responses.Remove(payloadId);
                        foreach (var finished in response.Streams)
                        {
                            _streamDefinitions.Remove(finished.Id);
                        }

                        if (_pendingResponses.TryRemove(payloadId, out waiter))
                        {
                            completedResponse = response;
                        }
                    }
                }
            }

            // Completed outside the lock so continuations of the waiting caller never run while it is held.
            waiter?.SetResult(completedResponse);
        }

        // True when the list exists and every entry is a StreamDefinition whose final frame has been received.
        private static bool AllStreamsComplete(List<IContentStream> streams)
        {
            return streams != null && streams.All(s => s is StreamDefinition definition && definition.Complete);
        }
1✔
329

330
        /// <summary>
        /// Builds the description (id, content type, length) that announces an outgoing stream.
        /// </summary>
        /// <param name="stream">The outgoing stream to describe.</param>
        /// <returns>A description carrying the stream id plus whatever content type and length the content headers provide.</returns>
        private static StreamDescription GetStreamDescription(ResponseMessageStream stream)
        {
            var result = new StreamDescription() { Id = stream.Id.ToString("D") };
            var headers = stream.Content.Headers;

            if (headers.TryGetValues(HeaderNames.ContentType, out IEnumerable<string> typeValues))
            {
                result.ContentType = typeValues?.FirstOrDefault();
            }

            if (!headers.TryGetValues(HeaderNames.ContentLength, out IEnumerable<string> lengthValues))
            {
                // No explicit header value: fall back to the typed ContentLength property.
                result.Length = (int?)headers.ContentLength;
                return result;
            }

            // int.TryParse returns false for a null string, so a missing first value leaves Length unset.
            if (int.TryParse(lengthValues?.FirstOrDefault(), out int parsedLength))
            {
                result.Length = parsedLength;
            }

            return result;
        }
357

358
        /// <summary>
        /// Exposes a sequence as an array segment, reusing the underlying array when the sequence is a
        /// single array-backed segment and copying it into a new array otherwise.
        /// </summary>
        /// <param name="sequence">The bytes to expose.</param>
        /// <returns>A segment over the original array when possible, otherwise over a fresh copy.</returns>
        private static ArraySegment<byte> GetArraySegment(ReadOnlySequence<byte> sequence)
        {
            if (!sequence.IsSingleSegment || !MemoryMarshal.TryGetArray(sequence.First, out ArraySegment<byte> direct))
            {
                // Can be optimized by not copying but should be uncommon. If perf data shows that we are hitting this
                // code branch, then we can optimize and avoid copies and heap allocations.
                return new ArraySegment<byte>(sequence.ToArray());
            }

            return direct;
        }
372

373
        /// <summary>
        /// Handles a fully received request on the thread pool: acknowledges it with HTTP 202, runs the
        /// request handler, sends the handler's response, and finally releases the request's stream placeholders.
        /// </summary>
        /// <param name="id">Id of the request; reused as the id of both responses.</param>
        /// <param name="request">The request to hand to the request handler.</param>
        private void ProcessRequest(Guid id, ReceiveRequest request)
        {
            // Fire-and-forget: the task is intentionally discarded, so an exception thrown inside it is not
            // observed by the caller of this method.
            _ = Task.Run(async () =>
            {
                try
                {
                    // Send an HTTP 202 (Accepted) response right away, otherwise, while under high streaming load, the conversation times out due to not having a response in the request/response time frame.
                    await SendResponseAsync(new Header { Id = id, Type = PayloadTypes.Response }, new StreamingResponse { StatusCode = (int)HttpStatusCode.Accepted }, _connectionCancellationToken).ConfigureAwait(false);
                    var streamingResponse = await _receiver.ProcessRequestAsync(request, null).ConfigureAwait(false);
                    await SendResponseAsync(new Header() { Id = id, Type = PayloadTypes.Response }, streamingResponse, _connectionCancellationToken).ConfigureAwait(false);
                }
                finally
                {
                    // Release the placeholders even when sending or handling failed, and do it under the
                    // receive lock: _streamDefinitions is a plain Dictionary that the Receive* methods
                    // modify from other threads.
                    lock (_receiveSync)
                    {
                        request.Streams.ForEach(s => _streamDefinitions.Remove(s.Id));
                    }
                }
            });
        }
1✔
385

386
        /// <summary>
        /// Observer of transport frames that turns each frame into the matching
        /// <c>ReceiveStream</c> / <c>ReceiveRequest</c> / <c>ReceiveResponse</c> call on the owning session.
        /// </summary>
        internal class ProtocolDispatcher : IObserver<(Header Header, ReadOnlySequence<byte> Payload)>
        {
            private readonly StreamingSession _streamingSession;

            /// <summary>
            /// Initializes a new instance of the <see cref="ProtocolDispatcher"/> class.
            /// </summary>
            /// <param name="streamingSession">The session that receives the dispatched payloads.</param>
            /// <exception cref="ArgumentNullException"><paramref name="streamingSession"/> is null.</exception>
            public ProtocolDispatcher(StreamingSession streamingSession)
            {
                _streamingSession = streamingSession ?? throw new ArgumentNullException(nameof(streamingSession));
            }

            /// <summary>
            /// Not implemented; always throws.
            /// </summary>
            public void OnCompleted()
            {
                // NOTE(review): throws if the transport ever signals completion -- confirm it never does.
                throw new NotImplementedException();
            }

            /// <summary>
            /// Not implemented; always throws.
            /// </summary>
            /// <param name="error">The error reported by the observed source.</param>
            public void OnError(Exception error)
            {
                // NOTE(review): throws if the transport ever reports an error -- confirm it never does.
                throw new NotImplementedException();
            }

            /// <summary>
            /// Dispatches one frame according to its header type. Cancel frames are accepted and ignored.
            /// </summary>
            /// <param name="frame">The frame header and its payload bytes.</param>
            public void OnNext((Header Header, ReadOnlySequence<byte> Payload) frame)
            {
                var header = frame.Header;
                var payload = frame.Payload;

                switch (header.Type)
                {
                    case PayloadTypes.Stream:
                        _streamingSession.ReceiveStream(header, GetArraySegment(payload));

                        break;
                    case PayloadTypes.Request:

                        var requestPayload = DeserializeTo<RequestPayload>(payload);
                        var request = new ReceiveRequest()
                        {
                            Verb = requestPayload.Verb,
                            Path = requestPayload.Path,
                            Streams = new List<IContentStream>(),
                        };

                        CreatePlaceholderStreams(header, request.Streams, requestPayload.Streams);
                        _streamingSession.ReceiveRequest(header, request);

                        break;

                    case PayloadTypes.Response:

                        var responsePayload = DeserializeTo<ResponsePayload>(payload);
                        var response = new ReceiveResponse()
                        {
                            StatusCode = responsePayload.StatusCode,
                            Streams = new List<IContentStream>(),
                        };

                        CreatePlaceholderStreams(header, response.Streams, responsePayload.Streams);
                        _streamingSession.ReceiveResponse(header, response);

                        break;

                    case PayloadTypes.CancelAll:
                        break;

                    case PayloadTypes.CancelStream:
                        break;
                }
            }

            /// <summary>
            /// Deserializes a JSON payload with the session's shared serializer, skipping a leading UTF-8
            /// byte order mark when one is present.
            /// </summary>
            /// <typeparam name="T">The type to deserialize into.</typeparam>
            /// <param name="payload">The raw JSON bytes, optionally prefixed with a UTF-8 BOM.</param>
            /// <returns>The deserialized object.</returns>
            private static T DeserializeTo<T>(ReadOnlySequence<byte> payload)
            {
                // The payload here will likely start with a UTF-8 byte-order-mark (BOM); strip it before parsing.
                // Background: https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-use-dom-utf8jsonreader-utf8jsonwriter?pivots=dotnet-5-0#filter-data-using-utf8jsonreader
                ReadOnlySequence<byte> mainPayload = payload;

                // Only inspect the prefix when the payload is long enough to hold a BOM: slicing more bytes
                // than the sequence contains throws, which used to reject short but valid payloads such as "{}".
                if (payload.Length >= _utf8Bom.Length)
                {
                    var potentialBomSequence = payload.Slice(payload.Start, _utf8Bom.Length);
                    var potentialBomSpan = potentialBomSequence.IsSingleSegment
                        ? potentialBomSequence.First.Span
                        : potentialBomSequence.ToArray();

                    if (potentialBomSpan.StartsWith(_utf8Bom))
                    {
                        mainPayload = payload.Slice(_utf8Bom.Length);
                    }
                }

                // Note: this copies the payload into a new array before handing it to Json.NET.
                using (var ms = new MemoryStream(mainPayload.ToArray()))
                using (var sr = new StreamReader(ms))
                using (var jsonReader = new JsonTextReader(sr))
                {
                    return Serializer.Deserialize<T>(jsonReader);
                }
            }

            /// <summary>
            /// Adds one empty placeholder stream per announced stream description, tagged with the type and
            /// id of the header that announced it.
            /// </summary>
            /// <param name="header">Header of the request or response that announced the streams.</param>
            /// <param name="placeholders">List that receives the created placeholders.</param>
            /// <param name="streamInfo">The announced stream descriptions; may be null.</param>
            /// <exception cref="InvalidDataException">A stream description id is not a valid Guid.</exception>
            private static void CreatePlaceholderStreams(Header header, List<IContentStream> placeholders, List<StreamDescription> streamInfo)
            {
                if (streamInfo == null)
                {
                    return;
                }

                foreach (var streamDescription in streamInfo)
                {
                    if (!Guid.TryParse(streamDescription.Id, out Guid id))
                    {
                        throw new InvalidDataException($"Stream description id '{streamDescription.Id}' is not a Guid");
                    }

                    placeholders.Add(new StreamDefinition()
                    {
                        ContentType = streamDescription.ContentType,
                        Length = streamDescription.Length,

                        // Reuse the value parsed above instead of parsing the same string a second time.
                        Id = id,
                        Stream = new MemoryStream(),
                        PayloadType = header.Type,
                        PayloadId = header.Id
                    });
                }
            }
        }
510

511
        /// <summary>
        /// Placeholder for a stream that a request or response header announced and whose content
        /// arrives later in stream frames.
        /// </summary>
        internal class StreamDefinition : IContentStream
        {
            /// <summary>Gets or sets the stream id, parsed from the announcing stream description.</summary>
            public Guid Id { get; set; }

            /// <summary>Gets or sets the content type taken from the stream description.</summary>
            public string ContentType { get; set; }

            /// <summary>Gets or sets the length taken from the stream description, if it declared one.</summary>
            public int? Length { get; set; }

            /// <summary>Gets or sets the stream the received frames are written to.</summary>
            public Stream Stream { get; set; }

            /// <summary>Gets or sets a value indicating whether the final frame of this stream has been received.</summary>
            public bool Complete { get; set; }

            /// <summary>Gets or sets the type of the header (request or response) that announced this stream.</summary>
            public char PayloadType { get; set; }

            /// <summary>Gets or sets the id of the header that announced this stream.</summary>
            public Guid PayloadId { get; set; }
        }
527

528
        /// <summary>
        /// Pre-defined log messages for the streaming session.
        /// </summary>
        private class Log
        {
            // Error (event id 1): a stream frame arrived for which no placeholder stream is registered.
            private static readonly Action<ILogger, Guid, char, int, bool, Exception> _orphanedStream =
                LoggerMessage.Define<Guid, char, int, bool>(
                    LogLevel.Error,
                    new EventId(1, nameof(OrphanedStream)),
                    "Stream has no associated payload. Header: ID {Guid} Type: {char} Payload length: {int} End: {bool}");

            // Debug (event id 2): a frame reached the session.
            private static readonly Action<ILogger, Guid, char, int, bool, Exception> _payloadReceived =
                LoggerMessage.Define<Guid, char, int, bool>(
                    LogLevel.Debug,
                    new EventId(2, nameof(PayloadReceived)),
                    "Payload received in session. Header: ID {Guid} Type: {char} Payload length: {int} End: {bool}");

            /// <summary>Logs the header of a stream frame that has no registered placeholder stream.</summary>
            /// <param name="logger">The logger to write to.</param>
            /// <param name="header">The header of the orphaned frame.</param>
            public static void OrphanedStream(ILogger logger, Header header)
            {
                _orphanedStream(logger, header.Id, header.Type, header.PayloadLength, header.End, null);
            }

            /// <summary>Logs the header of a frame received by the session.</summary>
            /// <param name="logger">The logger to write to.</param>
            /// <param name="header">The header of the received frame.</param>
            public static void PayloadReceived(ILogger logger, Header header)
            {
                _payloadReceived(logger, header.Id, header.Type, header.PayloadLength, header.End, null);
            }
        }
540
    }
541
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc