• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tari-project / tari / 17297453740

28 Aug 2025 01:33PM UTC coverage: 61.046% (+0.9%) from 60.14%
17297453740

push

github

web-flow
chore(ci): add a wasm build step in ci (#7448)

Description
Add a WebAssembly (wasm) build step to CI.

Motivation and Context
Verify in CI that the wasm targets build.

How Has This Been Tested?
Built successfully in a local fork.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Chores**
* Added CI workflow to build WebAssembly targets with optimized caching
on both hosted and self-hosted runners, improving build consistency and
speed.
* **Tests**
* Expanded automated checks to include WebAssembly build verification
for multiple modules, increasing coverage and early detection of build
issues.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

72582 of 118897 relevant lines covered (61.05%)

301536.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.53
/comms/core/src/noise/socket.rs
1
// Copyright 2019, The Tari Project
2
//
3
// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
4
// following conditions are met:
5
//
6
// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
7
// disclaimer.
8
//
9
// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
10
// following disclaimer in the documentation and/or other materials provided with the distribution.
11
//
12
// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
13
// products derived from this software without specific prior written permission.
14
//
15
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
16
// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
18
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
19
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
20
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
21
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22

23
// This file is a slightly modified version of the Libra NoiseSocket implementation.
24
// Copyright (c) The Libra Core Contributors
25
// SPDX-License-Identifier: Apache-2.0
26

27
//! Noise Socket
28

29
use std::{
30
    cmp,
31
    convert::TryInto,
32
    io,
33
    pin::Pin,
34
    task::{Context, Poll},
35
    time::Duration,
36
};
37

38
use futures::ready;
39
use log::*;
40
use snow::{error::StateProblem, HandshakeState, TransportState};
41
use tari_utilities::ByteArray;
42
use tokio::{
43
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
44
    time,
45
};
46

47
use crate::types::CommsPublicKey;
48

49
const LOG_TARGET: &str = "comms::noise::socket";

/// Largest frame that can be described by the u16 big-endian length prefix (65535 bytes).
const MAX_PAYLOAD_LENGTH: usize = u16::MAX as usize;

// Plaintext is buffered up to 16 bytes short of `MAX_PAYLOAD_LENGTH` (i.e. `u16::MAX - 16`): encrypted
// messages carry a tag alongside the payload, and the resulting frame must still fit the u16 length.
const MAX_WRITE_BUFFER_LENGTH: usize = MAX_PAYLOAD_LENGTH - 16; // 65519
56

57
/// Collection of buffers used for buffering data during the various read/write states of a
/// NoiseSocket.
///
/// Every buffer is a fixed-size byte array (three of `MAX_PAYLOAD_LENGTH` bytes and one of
/// `MAX_WRITE_BUFFER_LENGTH` bytes), so the struct is roughly 256 KiB; `NoiseSocket` keeps it
/// behind a `Box`.
struct NoiseBuffers {
    /// Encrypted frame read from the wire
    read_encrypted: [u8; MAX_PAYLOAD_LENGTH],
    /// Decrypted data read from the wire (produced by having snow decrypt the `read_encrypted`
    /// buffer)
    read_decrypted: [u8; MAX_PAYLOAD_LENGTH],
    /// Unencrypted data intended to be written to the wire. It is smaller than the other buffers
    /// so that the encrypted result (payload plus tag) still fits in a u16-length frame.
    write_decrypted: [u8; MAX_WRITE_BUFFER_LENGTH],
    /// Encrypted data to write to the wire (produced by having snow encrypt the `write_decrypted`
    /// buffer)
    write_encrypted: [u8; MAX_PAYLOAD_LENGTH],
}
71

72
impl NoiseBuffers {
73
    fn new() -> Self {
128✔
74
        Self {
128✔
75
            read_encrypted: [0; MAX_PAYLOAD_LENGTH],
128✔
76
            read_decrypted: [0; MAX_PAYLOAD_LENGTH],
128✔
77
            write_decrypted: [0; MAX_WRITE_BUFFER_LENGTH],
128✔
78
            write_encrypted: [0; MAX_PAYLOAD_LENGTH],
128✔
79
        }
128✔
80
    }
128✔
81
}
82

83
/// Hand written Debug implementation in order to omit the printing of huge buffers of data
84
impl ::std::fmt::Debug for NoiseBuffers {
85
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
×
86
        f.debug_struct("NoiseBuffers").finish()
×
87
    }
×
88
}
89

90
/// Possible read states for a [NoiseSocket]
///
/// `NoiseSocket::poll_read` normally cycles `Init` -> `ReadFrameLen` -> `ReadFrame` ->
/// `CopyDecryptedFrame` -> `Init` (a zero-length frame goes straight back to `Init`).
/// Once entered, `Eof` and `DecryptionError` are never left: every later read returns the same
/// result.
#[derive(Debug)]
enum ReadState {
    /// Initial State; the next step is to start reading a new frame length
    Init,
    /// Read frame length: the 2-byte big-endian prefix, `offset` bytes of which are filled so far
    ReadFrameLen { buf: [u8; 2], offset: usize },
    /// Read encrypted frame of `frame_len` bytes, `offset` of which have been received so far
    ReadFrame { frame_len: u16, offset: usize },
    /// Copy decrypted frame to provided buffer; `offset` bytes of the `decrypted_len` have already
    /// been handed to the caller
    CopyDecryptedFrame { decrypted_len: usize, offset: usize },
    /// End of file reached, result indicated if EOF was expected or not: `Ok(())` for EOF at a
    /// frame boundary, `Err(())` for EOF in the middle of a frame
    Eof(Result<(), ()>),
    /// Decryption Error
    DecryptionError(snow::Error),
}
106

107
/// Possible write states for a [NoiseSocket]
///
/// Plaintext accumulates in `BufferData` until the write buffer is full or a flush is requested.
/// It is then encrypted and sent via `WriteFrameLen` -> `WriteEncryptedFrame` -> `Flush`, after
/// which the state returns to `Init`. Once entered, `Eof` and `EncryptionError` are never left.
#[derive(Debug)]
enum WriteState {
    /// Initial State; nothing is buffered
    Init,
    /// Buffer provided data; `offset` is the number of plaintext bytes buffered so far
    BufferData { offset: usize },
    /// Write frame length to the wire: `buf` is `frame_len` in big-endian form and `offset`
    /// counts how many of its two bytes have been written
    WriteFrameLen {
        frame_len: u16,
        buf: [u8; 2],
        offset: usize,
    },
    /// Write encrypted frame to the wire; `offset` counts encrypted bytes written so far
    WriteEncryptedFrame { frame_len: u16, offset: usize },
    /// Flush the underlying socket
    Flush,
    /// End of file reached (the underlying socket accepted zero bytes)
    Eof,
    /// Encryption Error
    EncryptionError(snow::Error),
}
129

130
/// A Noise session with a remote
///
/// Encrypts data to be written to and decrypts data that is read from the underlying socket using
/// the noise protocol. This is done by wrapping noise payloads in u16 (big endian) length prefix
/// frames.
#[derive(Debug)]
pub struct NoiseSocket<TSocket> {
    /// Underlying transport that carries the length-prefixed noise frames
    socket: TSocket,
    /// snow session state: handshake mode until `Handshake::build` switches it to transport mode
    state: NoiseState,
    /// Scratch buffers for reading and writing frames (boxed because they are large)
    buffers: Box<NoiseBuffers>,
    /// Progress of the current read, so that a `Pending` poll can resume where it stopped
    read_state: ReadState,
    /// Progress of the current write/flush, so that a `Pending` poll can resume where it stopped
    write_state: WriteState,
}
143

144
impl<TSocket> NoiseSocket<TSocket> {
145
    fn new(socket: TSocket, session: NoiseState) -> Self {
128✔
146
        Self {
128✔
147
            socket,
128✔
148
            state: session,
128✔
149
            buffers: Box::new(NoiseBuffers::new()),
128✔
150
            read_state: ReadState::Init,
128✔
151
            write_state: WriteState::Init,
128✔
152
        }
128✔
153
    }
128✔
154

155
    /// Get the raw remote static key
156
    pub fn get_remote_static(&self) -> Option<&[u8]> {
118✔
157
        self.state.get_remote_static()
118✔
158
    }
118✔
159

160
    /// Get the remote static key as a CommsPublicKey
161
    pub fn get_remote_public_key(&self) -> Option<CommsPublicKey> {
116✔
162
        self.get_remote_static()
116✔
163
            .and_then(|s| CommsPublicKey::from_canonical_bytes(s).ok())
116✔
164
    }
116✔
165
}
166

167
fn poll_write_all<TSocket>(
297,060✔
168
    context: &mut Context,
297,060✔
169
    mut socket: Pin<&mut TSocket>,
297,060✔
170
    buf: &[u8],
297,060✔
171
    offset: &mut usize,
297,060✔
172
) -> Poll<io::Result<()>>
297,060✔
173
where
297,060✔
174
    TSocket: AsyncWrite,
297,060✔
175
{
297,060✔
176
    loop {
177
        let bytes = match buf.get(*offset..) {
297,060✔
178
            Some(bytes) => bytes,
297,060✔
179
            None => {
180
                return Poll::Ready(Err(io::Error::new(
×
181
                    io::ErrorKind::InvalidInput,
×
182
                    "Offset exceeds buffer length",
×
183
                )));
×
184
            },
185
        };
186
        let n = ready!(socket.as_mut().poll_write(context, bytes))?;
297,060✔
187
        trace!(
287,654✔
188
            target: LOG_TARGET,
×
189
            "poll_write_all: wrote {}/{} bytes",
×
190
            *offset + n,
×
191
            buf.len()
×
192
        );
193
        if n == 0 {
287,654✔
194
            return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
×
195
        }
287,654✔
196
        *offset += n;
287,654✔
197
        assert!(*offset <= buf.len());
287,654✔
198

199
        if *offset == buf.len() {
287,654✔
200
            return Poll::Ready(Ok(()));
287,654✔
201
        }
×
202
    }
203
}
297,060✔
204

205
/// Read a u16 frame length from `socket`.
206
///
207
/// Can result in the following output:
208
/// 1) Ok(None) => EOF; remote graceful shutdown
209
/// 2) Err(UnexpectedEOF) => read 1 byte then hit EOF; remote died
210
/// 3) Ok(Some(n)) => new frame of length n
211
fn poll_read_u16frame_len<TSocket>(
150,758✔
212
    context: &mut Context,
150,758✔
213
    socket: Pin<&mut TSocket>,
150,758✔
214
    buf: &mut [u8; 2],
150,758✔
215
    offset: &mut usize,
150,758✔
216
) -> Poll<io::Result<Option<u16>>>
150,758✔
217
where
150,758✔
218
    TSocket: AsyncRead,
150,758✔
219
{
150,758✔
220
    match ready!(poll_read_exact(context, socket, buf, offset)) {
150,758✔
221
        Ok(()) => Poll::Ready(Ok(Some(u16::from_be_bytes(*buf)))),
143,749✔
222
        Err(e) => {
30✔
223
            if *offset == 0 && e.kind() == io::ErrorKind::UnexpectedEof {
30✔
224
                return Poll::Ready(Ok(None));
30✔
225
            }
×
226
            Poll::Ready(Err(e))
×
227
        },
228
    }
229
}
150,758✔
230

231
fn poll_read_exact<TSocket>(
294,906✔
232
    context: &mut Context,
294,906✔
233
    mut socket: Pin<&mut TSocket>,
294,906✔
234
    buf: &mut [u8],
294,906✔
235
    offset: &mut usize,
294,906✔
236
) -> Poll<io::Result<()>>
294,906✔
237
where
294,906✔
238
    TSocket: AsyncRead,
294,906✔
239
{
294,906✔
240
    loop {
241
        let bytes = match buf.get_mut(*offset..) {
294,911✔
242
            Some(bytes) => bytes,
294,911✔
243
            None => {
244
                return Poll::Ready(Err(io::Error::new(
×
245
                    io::ErrorKind::InvalidInput,
×
246
                    "Offset exceeds buffer length",
×
247
                )));
×
248
            },
249
        };
250
        let mut read_buf = ReadBuf::new(bytes);
294,911✔
251
        let prev_rem = read_buf.remaining();
294,911✔
252
        ready!(socket.as_mut().poll_read(context, &mut read_buf))?;
294,911✔
253
        let n = prev_rem
287,533✔
254
            .checked_sub(read_buf.remaining())
287,533✔
255
            .ok_or_else(|| io::Error::other("buffer underflow: prev_rem < read_buf.remaining()"))?;
287,533✔
256
        trace!(
287,533✔
257
            target: LOG_TARGET,
×
258
            "poll_read_exact: read {}/{} bytes",
×
259
            *offset + n,
×
260
            buf.len()
×
261
        );
262
        if n == 0 {
287,533✔
263
            return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
30✔
264
        }
287,503✔
265
        *offset += n;
287,503✔
266
        assert!(*offset <= buf.len());
287,503✔
267

268
        if *offset == buf.len() {
287,503✔
269
            return Poll::Ready(Ok(()));
287,498✔
270
        }
5✔
271
    }
272
}
294,906✔
273
impl<TSocket> NoiseSocket<TSocket>
where TSocket: AsyncRead + Unpin
{
    /// Reads length-prefixed noise frames from the underlying socket, decrypts them and copies the
    /// plaintext into `buf`.
    ///
    /// Each successful call returns `Ok(n)`, the number of bytes copied, which is at most the
    /// remainder of one decrypted frame. Progress is stored in `self.read_state`, so a `Pending`
    /// from the underlying socket can be resumed by polling again.
    ///
    /// * `Ok(0)` is returned after a clean EOF at a frame boundary. It is also returned when `buf`
    ///   is empty: a frame is still read and decrypted but nothing is copied. The handshake relies
    ///   on this by reading into an empty slice.
    /// * An EOF in the middle of a frame, or a decryption failure, is sticky: every later call
    ///   returns an error of the same kind. Other socket errors leave the state untouched, so the
    ///   read can be retried.
    #[allow(clippy::too_many_lines)]
    fn poll_read(&mut self, context: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
        loop {
            trace!(target: LOG_TARGET, "NoiseSocket ReadState::{:?}", self.read_state);
            match self.read_state {
                ReadState::Init => {
                    self.read_state = ReadState::ReadFrameLen { buf: [0, 0], offset: 0 };
                },
                ReadState::ReadFrameLen {
                    ref mut buf,
                    ref mut offset,
                } => {
                    // `buf` here is the 2-byte length prefix held in the state, not the caller's buffer.
                    match ready!(poll_read_u16frame_len(context, Pin::new(&mut self.socket), buf, offset)) {
                        Ok(Some(frame_len)) => {
                            // Empty Frame: nothing to decrypt, start on the next frame
                            if frame_len == 0 {
                                self.read_state = ReadState::Init;
                            } else {
                                self.read_state = ReadState::ReadFrame { frame_len, offset: 0 };
                            }
                        },
                        Ok(None) => {
                            // Remote closed cleanly between frames
                            self.read_state = ReadState::Eof(Ok(()));
                        },
                        Err(e) => {
                            if e.kind() == io::ErrorKind::UnexpectedEof {
                                self.read_state = ReadState::Eof(Err(()));
                            }
                            return Poll::Ready(Err(e));
                        },
                    }
                },
                ReadState::ReadFrame {
                    frame_len,
                    ref mut offset,
                } => {
                    let bytes = match self.buffers.read_encrypted.get_mut(..(frame_len as usize)) {
                        Some(bytes) => bytes,
                        None => {
                            return Poll::Ready(Err(io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "frame length exceeds buffer length",
                            )));
                        },
                    };
                    match ready!(poll_read_exact(context, Pin::new(&mut self.socket), bytes, offset)) {
                        // The whole encrypted frame is in `read_encrypted`; have snow decrypt it
                        Ok(()) => match self.state.read_message(bytes, &mut self.buffers.read_decrypted) {
                            Ok(decrypted_len) => {
                                self.read_state = ReadState::CopyDecryptedFrame {
                                    decrypted_len,
                                    offset: 0,
                                };
                            },
                            Err(e) => {
                                warn!(target: LOG_TARGET, "Decryption Error: {e}");
                                self.read_state = ReadState::DecryptionError(e);
                            },
                        },
                        Err(e) => {
                            if e.kind() == io::ErrorKind::UnexpectedEof {
                                self.read_state = ReadState::Eof(Err(()));
                            }
                            return Poll::Ready(Err(e));
                        },
                    }
                },
                ReadState::CopyDecryptedFrame {
                    decrypted_len,
                    ref mut offset,
                } => {
                    // Hand over as much of the remaining plaintext as the caller's buffer can take
                    let num_bytes_to_copy = cmp::min(decrypted_len - *offset, buf.len());
                    let bytes_to_copy = match self.buffers.read_decrypted.get(*offset..(*offset + num_bytes_to_copy)) {
                        Some(bytes) => bytes,
                        None => {
                            return Poll::Ready(Err(io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "Offset exceeds buffer length",
                            )));
                        },
                    };
                    // Cannot fail: `num_bytes_to_copy <= buf.len()` by the `min` above
                    buf.get_mut(..num_bytes_to_copy)
                        .expect("this is checked")
                        .copy_from_slice(bytes_to_copy);
                    trace!(
                        target: LOG_TARGET,
                        "CopyDecryptedFrame: copied {}/{} bytes",
                        *offset + num_bytes_to_copy,
                        decrypted_len
                    );
                    *offset += num_bytes_to_copy;
                    // Only move on to the next frame once this one has been fully handed over
                    if *offset == decrypted_len {
                        self.read_state = ReadState::Init;
                    }
                    return Poll::Ready(Ok(num_bytes_to_copy));
                },
                ReadState::Eof(Ok(())) => return Poll::Ready(Ok(0)),
                ReadState::Eof(Err(())) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
                ReadState::DecryptionError(ref e) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("DecryptionError: {e}"),
                    )))
                },
            }
        }
    }
}
383

384
impl<TSocket> AsyncRead for NoiseSocket<TSocket>
385
where TSocket: AsyncRead + Unpin
386
{
387
    fn poll_read(self: Pin<&mut Self>, context: &mut Context, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
296,731✔
388
        let slice = buf.initialize_unfilled();
296,731✔
389
        let n = futures::ready!(self.get_mut().poll_read(context, slice))?;
296,731✔
390
        buf.advance(n);
289,353✔
391
        Poll::Ready(Ok(()))
289,353✔
392
    }
296,731✔
393
}
394

395
impl<TSocket> NoiseSocket<TSocket>
where TSocket: AsyncWrite + Unpin
{
    /// Shared write state machine behind `poll_write` (`buf` is `Some`) and `poll_flush` (`buf` is
    /// `None`). Progress is stored in `self.write_state`, so a `Pending` can be resumed by polling
    /// again.
    ///
    /// * Write: any frame already in flight is finished first. As much of `buf` as fits is then
    ///   copied into the plaintext buffer, and `Ok(Some(n))` reports how many bytes were taken. If
    ///   that fills the buffer, its contents are encrypted immediately; sending the frame continues
    ///   on the next call.
    /// * Flush: whatever is buffered (possibly nothing) is encrypted, the length prefix and the
    ///   frame are written, and the underlying socket is flushed. Returns `Ok(None)` once back in
    ///   `Init`. From `Init` it returns `Ok(None)` straight away.
    ///
    /// A zero-byte write from the socket moves to the sticky `Eof` state, and an encryption failure
    /// to the sticky `EncryptionError` state. Other socket errors leave the state as-is, so the
    /// call can be retried.
    #[allow(clippy::too_many_lines)]
    fn poll_write_or_flush(&mut self, context: &mut Context, buf: Option<&[u8]>) -> Poll<io::Result<Option<usize>>> {
        loop {
            trace!(
                target: LOG_TARGET,
                "NoiseSocket {} WriteState::{:?}",
                if buf.is_some() { "poll_write" } else { "poll_flush" },
                self.write_state,
            );
            match self.write_state {
                WriteState::Init => {
                    if buf.is_some() {
                        self.write_state = WriteState::BufferData { offset: 0 };
                    } else {
                        // Nothing buffered and nothing in flight: flush is a no-op
                        return Poll::Ready(Ok(None));
                    }
                },
                WriteState::BufferData { ref mut offset } => {
                    let bytes_buffered = if let Some(buf) = buf {
                        // Copy only what fits in the remaining plaintext buffer space
                        let num_bytes_to_copy = ::std::cmp::min(MAX_WRITE_BUFFER_LENGTH - *offset, buf.len());
                        let bytes = match buf.get(..num_bytes_to_copy) {
                            Some(bytes) => bytes,
                            None => {
                                return Poll::Ready(Err(io::Error::new(
                                    io::ErrorKind::InvalidInput,
                                    "frame length exceeds buffer length",
                                )));
                            },
                        };
                        self.buffers
                            .write_decrypted
                            .get_mut(*offset..(*offset + num_bytes_to_copy))
                            .expect("this is checked")
                            .copy_from_slice(bytes);
                        trace!(
                            target: LOG_TARGET,
                            "BufferData: buffered {}/{} bytes",
                            num_bytes_to_copy,
                            buf.len()
                        );
                        *offset += num_bytes_to_copy;
                        Some(num_bytes_to_copy)
                    } else {
                        None
                    };

                    // Encrypt when flushing, or when the plaintext buffer is full
                    if buf.is_none() || *offset == MAX_WRITE_BUFFER_LENGTH {
                        let bytes = match self.buffers.write_decrypted.get(..*offset) {
                            Some(bytes) => bytes,
                            None => {
                                return Poll::Ready(Err(io::Error::new(
                                    io::ErrorKind::InvalidInput,
                                    "frame length exceeds buffer length",
                                )));
                            },
                        };
                        match self.state.write_message(bytes, &mut self.buffers.write_encrypted) {
                            Ok(encrypted_len) => {
                                // NOTE(review): `write_encrypted` is `MAX_PAYLOAD_LENGTH` (u16::MAX) bytes,
                                // so presumably this conversion cannot fail as long as snow never reports
                                // more bytes than the output buffer holds. If it did fail, `write_state`
                                // would remain `BufferData`. The error text says "offset" but refers to
                                // the encrypted length.
                                let frame_len = encrypted_len
                                    .try_into()
                                    .map_err(|_| io::Error::other("offset should be able to fit in u16"))?;
                                self.write_state = WriteState::WriteFrameLen {
                                    frame_len,
                                    buf: u16::to_be_bytes(frame_len),
                                    offset: 0,
                                };
                            },
                            Err(e) => {
                                warn!(target: LOG_TARGET, "Encryption Error: {e}");
                                let err = io::Error::new(io::ErrorKind::InvalidData, format!("EncryptionError: {e}"));
                                self.write_state = WriteState::EncryptionError(e);
                                return Poll::Ready(Err(err));
                            },
                        }
                    }

                    // A write reports the bytes it accepted; the encrypted frame (if any) is sent on
                    // the next call. A flush instead keeps looping to send the frame now.
                    if let Some(bytes_buffered) = bytes_buffered {
                        return Poll::Ready(Ok(Some(bytes_buffered)));
                    }
                },
                WriteState::WriteFrameLen {
                    frame_len,
                    ref buf,
                    ref mut offset,
                } => match ready!(poll_write_all(context, Pin::new(&mut self.socket), buf, offset)) {
                    Ok(()) => {
                        self.write_state = WriteState::WriteEncryptedFrame { frame_len, offset: 0 };
                    },
                    Err(e) => {
                        if e.kind() == io::ErrorKind::WriteZero {
                            self.write_state = WriteState::Eof;
                        }
                        return Poll::Ready(Err(e));
                    },
                },
                WriteState::WriteEncryptedFrame {
                    frame_len,
                    ref mut offset,
                } => {
                    let bytes = match self.buffers.write_encrypted.get(..(frame_len as usize)) {
                        Some(bytes) => bytes,
                        None => {
                            return Poll::Ready(Err(io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "frame length exceeds buffer length",
                            )));
                        },
                    };
                    match ready!(poll_write_all(context, Pin::new(&mut self.socket), bytes, offset)) {
                        Ok(()) => {
                            self.write_state = WriteState::Flush;
                        },
                        Err(e) => {
                            if e.kind() == io::ErrorKind::WriteZero {
                                self.write_state = WriteState::Eof;
                            }
                            return Poll::Ready(Err(e));
                        },
                    }
                },
                WriteState::Flush => {
                    ready!(Pin::new(&mut self.socket).poll_flush(context))?;
                    self.write_state = WriteState::Init;
                },
                WriteState::Eof => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
                WriteState::EncryptionError(ref e) => {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("EncryptionError: {e}"),
                    )))
                },
            }
        }
    }

    /// Buffers (and, when the buffer fills, encrypts) bytes from `buf`, returning how many were taken.
    fn poll_write(&mut self, context: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        // With `Some(buf)` the state machine only returns after buffering, i.e. with `Some(n)`
        if let Some(bytes_written) = ready!(self.poll_write_or_flush(context, Some(buf)))? {
            Poll::Ready(Ok(bytes_written))
        } else {
            unreachable!();
        }
    }

    /// Encrypts and sends any buffered plaintext, then flushes the underlying socket.
    fn poll_flush(&mut self, context: &mut Context) -> Poll<io::Result<()>> {
        // With `None` the state machine only returns successfully from `Init`, i.e. with `None`
        if ready!(self.poll_write_or_flush(context, None))?.is_none() {
            Poll::Ready(Ok(()))
        } else {
            unreachable!();
        }
    }
}
549

550
impl<TSocket> AsyncWrite for NoiseSocket<TSocket>
551
where TSocket: AsyncWrite + Unpin
552
{
553
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
297,138✔
554
        self.get_mut().poll_write(cx, buf)
297,138✔
555
    }
297,138✔
556

557
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
294,242✔
558
        self.get_mut().poll_flush(cx)
294,242✔
559
    }
294,242✔
560

561
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
56✔
562
        Pin::new(&mut self.socket).poll_shutdown(cx)
56✔
563
    }
56✔
564
}
565

566
/// Drives a noise handshake over a [NoiseSocket] whose session is still in handshake mode.
pub struct Handshake<TSocket> {
    /// Socket carrying the handshake messages; switched to transport mode by `build`
    socket: NoiseSocket<TSocket>,
    /// Maximum time to wait for each incoming handshake message
    recv_timeout: Duration,
}
570

571
impl<TSocket> Handshake<TSocket> {
572
    pub fn new(socket: TSocket, state: HandshakeState, recv_timeout: Duration) -> Self {
116✔
573
        Self {
116✔
574
            socket: NoiseSocket::new(socket, state.into()),
116✔
575
            recv_timeout,
116✔
576
        }
116✔
577
    }
116✔
578
}
579

580
impl<TSocket> Handshake<TSocket>
581
where TSocket: AsyncRead + AsyncWrite + Unpin
582
{
583
    /// Perform a Single Round-Trip noise IX handshake returning the underlying [NoiseSocket]
584
    /// (switched to transport mode) upon success.
585
    pub async fn perform_handshake(mut self) -> io::Result<NoiseSocket<TSocket>> {
128✔
586
        match self.handshake_1_5rtt().await {
128✔
587
            Ok(_) => self.build(),
128✔
588
            Err(err) => {
×
589
                info!(
×
590
                    target: LOG_TARGET,
×
591
                    "Noise handshake failed because '{err:?}'. Closing socket."
×
592
                );
593
                self.socket.shutdown().await?;
×
594
                Err(err)
×
595
            },
596
        }
597
    }
128✔
598

599
    /// Performs a 1.5 RTT handshake. For example, the noise XX handshake.
600
    async fn handshake_1_5rtt(&mut self) -> io::Result<()> {
128✔
601
        if self.socket.state.is_initiator() {
128✔
602
            //   -> e
603
            self.send().await?;
64✔
604
            self.flush().await?;
64✔
605

606
            // <- e, ee, s, es
607
            self.receive().await?;
64✔
608

609
            //   -> s, se
610
            self.send().await?;
64✔
611
            self.flush().await?;
64✔
612
        } else {
613
            //   -> e
614
            self.receive().await?;
64✔
615

616
            // <- e, ee, s, es
617
            self.send().await?;
64✔
618
            self.flush().await?;
64✔
619

620
            //   -> s, se
621
            self.receive().await?;
64✔
622
        }
623

624
        Ok(())
128✔
625
    }
128✔
626

627
    async fn send(&mut self) -> io::Result<usize> {
192✔
628
        self.socket.write(&[]).await
192✔
629
    }
192✔
630

631
    async fn flush(&mut self) -> io::Result<()> {
192✔
632
        self.socket.flush().await
192✔
633
    }
192✔
634

635
    async fn receive(&mut self) -> io::Result<usize> {
192✔
636
        time::timeout(self.recv_timeout, self.socket.read(&mut []))
192✔
637
            .await
192✔
638
            .map_err(|_| io::Error::from(io::ErrorKind::TimedOut))?
192✔
639
    }
192✔
640

641
    fn build(self) -> io::Result<NoiseSocket<TSocket>> {
128✔
642
        let transport_state = self
128✔
643
            .socket
128✔
644
            .state
128✔
645
            .into_transport_mode()
128✔
646
            .map_err(|err| io::Error::other(format!("Invalid snow state: {err}")))?;
128✔
647

648
        Ok(NoiseSocket {
128✔
649
            state: transport_state,
128✔
650
            ..self.socket
128✔
651
        })
128✔
652
    }
128✔
653
}
654

655
/// The snow session state held by a [NoiseSocket]: still handshaking, or in transport mode once
/// `into_transport_mode` has been called.
#[derive(Debug)]
enum NoiseState {
    /// Handshake in progress
    HandshakeState(Box<HandshakeState>),
    /// Handshake finished; used to encrypt and decrypt transport messages
    TransportState(Box<TransportState>),
}
660

661
/// Generates a `NoiseState` method that forwards to the method of the same name on whichever snow
/// state (handshake or transport) is currently held, passing the arguments through in order.
macro_rules! proxy_state_method {
    // Methods taking `&mut self`
    (pub fn $name:ident(&mut self$(,)? $($arg_name:ident : $arg_type:ty),*) -> $ret:ty) => {
        pub fn $name(&mut self, $($arg_name:$arg_type),*) -> $ret {
            match self {
                NoiseState::HandshakeState(state) => state.$name($($arg_name),*),
                NoiseState::TransportState(state) => state.$name($($arg_name),*),
            }
        }
    };
    // Methods taking `&self`
    (pub fn $name:ident(&self$(,)? $($arg_name:ident : $arg_type:ty),*) -> $ret:ty) => {
        pub fn $name(&self, $($arg_name:$arg_type),*) -> $ret {
            match self {
                NoiseState::HandshakeState(state) => state.$name($($arg_name),*),
                NoiseState::TransportState(state) => state.$name($($arg_name),*),
            }
        }
    }
}
679

680
impl NoiseState {
681
    proxy_state_method!(pub fn write_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, snow::Error>);
682

683
    proxy_state_method!(pub fn is_initiator(&self) -> bool);
684

685
    proxy_state_method!(pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, snow::Error>);
686

687
    proxy_state_method!(pub fn get_remote_static(&self) -> Option<&[u8]>);
688

689
    pub fn into_transport_mode(self) -> Result<Self, snow::Error> {
128✔
690
        match self {
128✔
691
            NoiseState::HandshakeState(state) => Ok(NoiseState::TransportState(Box::new(state.into_transport_mode()?))),
128✔
692
            _ => Err(snow::Error::State(StateProblem::HandshakeAlreadyFinished)),
×
693
        }
694
    }
128✔
695
}
696

697
impl From<HandshakeState> for NoiseState {
    /// Boxes the given handshake state and wraps it in the `HandshakeState` variant.
    fn from(handshake: HandshakeState) -> Self {
        let boxed = Box::new(handshake);
        Self::HandshakeState(boxed)
    }
}
702

703
impl From<TransportState> for NoiseState {
    /// Boxes the given transport state and wraps it in the `TransportState` variant.
    fn from(transport: TransportState) -> Self {
        let boxed = Box::new(transport);
        Self::TransportState(boxed)
    }
}
708

709
#[cfg(test)]
mod test {
    use futures::future::join;
    use snow::{params::NoiseParams, Builder, Error, Keypair};

    use super::*;
    use crate::{memsocket::MemorySocket, noise::config::NOISE_PARAMETERS};

    /// One side of a test connection: its static keypair and the not-yet-run handshake.
    type TestPeer = (Keypair, Handshake<MemorySocket>);

    /// Creates a dialer and a listener connected by an in-memory socket pair, before the Noise
    /// handshake has run. Each side gets a freshly generated static keypair, which is returned next
    /// to its `Handshake` so tests can check the exchanged remote static keys afterwards.
    async fn build_test_connection() -> Result<(TestPeer, TestPeer), Error> {
        let params: NoiseParams = NOISE_PARAMETERS.parse().expect("Invalid protocol name");

        let dialer_keypair = Builder::new(params.clone()).generate_keypair()?;
        let listener_keypair = Builder::new(params.clone()).generate_keypair()?;

        let initiator_session = Builder::new(params.clone())
            .local_private_key(&dialer_keypair.private)
            .build_initiator()?;
        let responder_session = Builder::new(params)
            .local_private_key(&listener_keypair.private)
            .build_responder()?;

        let (dialer_io, listener_io) = MemorySocket::new_pair();

        let dialer = Handshake {
            socket: NoiseSocket::new(dialer_io, initiator_session.into()),
            recv_timeout: Duration::from_secs(1),
        };
        let listener = Handshake {
            socket: NoiseSocket::new(listener_io, responder_session.into()),
            recv_timeout: Duration::from_secs(1),
        };

        Ok(((dialer_keypair, dialer), (listener_keypair, listener)))
    }

    /// Runs both sides of the handshake concurrently and returns the resulting sockets as
    /// `(dialer, listener)`. If both sides fail, the dialer's error is the one returned.
    async fn perform_handshake(
        dialer: Handshake<MemorySocket>,
        listener: Handshake<MemorySocket>,
    ) -> io::Result<(NoiseSocket<MemorySocket>, NoiseSocket<MemorySocket>)> {
        let (dialer_outcome, listener_outcome) =
            join(dialer.perform_handshake(), listener.perform_handshake()).await;
        let dialer_socket = dialer_outcome?;
        let listener_socket = listener_outcome?;
        Ok((dialer_socket, listener_socket))
    }

    /// Builds a fresh connection and completes its handshake, returning `(dialer, listener)`.
    /// Panics if the connection cannot be built.
    async fn connected_pair() -> io::Result<(NoiseSocket<MemorySocket>, NoiseSocket<MemorySocket>)> {
        let ((_, dialer), (_, listener)) = build_test_connection().await.unwrap();
        perform_handshake(dialer, listener).await
    }

    /// Sends `len` bytes (all `1`) from the dialer, flushes, and asserts that the listener reads
    /// back exactly the same bytes.
    async fn assert_round_trip(len: usize) -> io::Result<()> {
        let (mut sender, mut receiver) = connected_pair().await?;

        let sent = vec![1u8; len];
        sender.write_all(&sent).await?;
        sender.flush().await?;

        let mut received = vec![0u8; len];
        receiver.read_exact(&mut received).await?;
        assert_eq!(received, sent);

        Ok(())
    }

    #[tokio::test]
    async fn test_handshake() {
        let ((dialer_keypair, dialer), (listener_keypair, listener)) = build_test_connection().await.unwrap();
        let (dialer_socket, listener_socket) = perform_handshake(dialer, listener).await.unwrap();

        // After the handshake, each side must report the other's static public key.
        assert_eq!(
            dialer_socket.get_remote_static(),
            Some(listener_keypair.public.as_ref())
        );
        assert_eq!(
            listener_socket.get_remote_static(),
            Some(dialer_keypair.public.as_ref())
        );
    }

    #[tokio::test]
    async fn simple_test() -> io::Result<()> {
        let (mut dialer_socket, mut listener_socket) = connected_pair().await?;

        // Several small writes, then shutdown: the listener must read the concatenation up to EOF.
        dialer_socket.write_all(b"stormlight").await?;
        dialer_socket.write_all(b" ").await?;
        dialer_socket.write_all(b"archive").await?;
        dialer_socket.flush().await?;
        dialer_socket.shutdown().await?;

        let mut everything = Vec::new();
        listener_socket.read_to_end(&mut everything).await?;
        assert_eq!(everything, b"stormlight archive");

        Ok(())
    }

    #[tokio::test]
    async fn interleaved_writes() -> io::Result<()> {
        let (mut a, mut b) = connected_pair().await?;

        a.write_all(b"The Name of the Wind").await?;
        a.flush().await?;
        a.write_all(b"The Wise Man's Fear").await?;
        a.flush().await?;

        b.write_all(b"The Doors of Stone").await?;
        b.flush().await?;

        let mut first = [0; 20];
        b.read_exact(&mut first).await?;
        assert_eq!(&first, b"The Name of the Wind");

        let mut second = [0; 19];
        b.read_exact(&mut second).await?;
        assert_eq!(&second, b"The Wise Man's Fear");

        let mut reply = [0; 18];
        a.read_exact(&mut reply).await?;
        assert_eq!(&reply, b"The Doors of Stone");

        Ok(())
    }

    #[tokio::test]
    async fn u16_max_writes() -> io::Result<()> {
        // One byte more than the maximum payload length.
        assert_round_trip(MAX_PAYLOAD_LENGTH + 1).await
    }

    #[tokio::test]
    async fn larger_writes() -> io::Result<()> {
        // More than twice the maximum payload length.
        assert_round_trip(MAX_PAYLOAD_LENGTH * 2 + 1024).await
    }

    #[tokio::test]
    async fn unexpected_eof() -> io::Result<()> {
        let (mut a, mut b) = connected_pair().await?;

        let sent = &[1; MAX_PAYLOAD_LENGTH];
        a.write_all(sent).await?;
        a.flush().await?;

        // Shut down the socket wrapped by the dialer's `NoiseSocket` (its `socket` field), then
        // drop the dialer, so the listener has no more data after the payload above.
        a.socket.shutdown().await.unwrap();
        drop(a);

        let mut received = vec![0; MAX_PAYLOAD_LENGTH];
        b.read_exact(&mut received).await.unwrap();
        assert_eq!(&received[..], &sent[..]);

        // A second full read has nothing left to consume and must fail with `UnexpectedEof`.
        let err = b.read_exact(&mut received).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);

        Ok(())
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc