• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

payjoin / rust-payjoin / 16131346449

08 Jul 2025 01:01AM UTC coverage: 85.317% (-0.01%) from 85.329%
16131346449

Pull #852

github

web-flow
Merge 0eb74b9e5 into 25acd561b
Pull Request #852: Fragment parameter fixes

93 of 105 new or added lines in 2 files covered. (88.57%)

4 existing lines in 1 file now uncovered.

7455 of 8738 relevant lines covered (85.32%)

535.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.97
/payjoin/src/core/ohttp.rs
1
use std::ops::{Deref, DerefMut};
2
use std::{error, fmt};
3

4
use bitcoin::bech32::{self, EncodeError};
5
use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE;
6
use hpke::rand_core::{OsRng, RngCore};
7

8
use crate::directory::ENCAPSULATED_MESSAGE_BYTES;
9

10
const N_ENC: usize = UNCOMPRESSED_PUBLIC_KEY_SIZE;
11
const N_T: usize = crate::hpke::POLY1305_TAG_SIZE;
12
const OHTTP_REQ_HEADER_BYTES: usize = 7;
13
pub const PADDED_BHTTP_REQ_BYTES: usize =
14
    ENCAPSULATED_MESSAGE_BYTES - (N_ENC + N_T + OHTTP_REQ_HEADER_BYTES);
15

16
pub fn ohttp_encapsulate(
46✔
17
    ohttp_keys: &mut ohttp::KeyConfig,
46✔
18
    method: &str,
46✔
19
    target_resource: &str,
46✔
20
    body: Option<&[u8]>,
46✔
21
) -> Result<([u8; ENCAPSULATED_MESSAGE_BYTES], ohttp::ClientResponse), OhttpEncapsulationError> {
46✔
22
    use std::fmt::Write;
23

24
    let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?;
46✔
25
    let url = url::Url::parse(target_resource)?;
46✔
26
    let authority_bytes = url.host().map_or_else(Vec::new, |host| {
46✔
27
        let mut authority = host.to_string();
46✔
28
        if let Some(port) = url.port() {
46✔
29
            write!(authority, ":{port}").unwrap();
44✔
30
        }
44✔
31
        authority.into_bytes()
46✔
32
    });
46✔
33
    let mut bhttp_message = bhttp::Message::request(
46✔
34
        method.as_bytes().to_vec(),
46✔
35
        url.scheme().as_bytes().to_vec(),
46✔
36
        authority_bytes,
46✔
37
        url.path().as_bytes().to_vec(),
46✔
38
    );
39
    // None of our messages include headers, so we don't add them
40
    if let Some(body) = body {
46✔
41
        bhttp_message.write_content(body);
30✔
42
    }
30✔
43

44
    let mut bhttp_req = [0u8; PADDED_BHTTP_REQ_BYTES];
46✔
45
    OsRng.fill_bytes(&mut bhttp_req);
46✔
46
    bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req.as_mut_slice())?;
46✔
47
    let (encapsulated, ohttp_ctx) = ctx.encapsulate(&bhttp_req)?;
46✔
48

49
    let mut buffer = [0u8; ENCAPSULATED_MESSAGE_BYTES];
46✔
50
    let len = encapsulated.len().min(ENCAPSULATED_MESSAGE_BYTES);
46✔
51
    buffer[..len].copy_from_slice(&encapsulated[..len]);
46✔
52
    Ok((buffer, ohttp_ctx))
46✔
53
}
46✔
54

55
#[derive(Debug)]
56
pub enum DirectoryResponseError {
57
    InvalidSize(usize),
58
    OhttpDecapsulation(OhttpEncapsulationError),
59
    UnexpectedStatusCode(http::StatusCode),
60
}
61

62
impl fmt::Display for DirectoryResponseError {
63
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
×
64
        use DirectoryResponseError::*;
65

66
        match self {
×
67
            OhttpDecapsulation(e) => write!(f, "OHTTP Decapsulation Error: {e}"),
×
68
            InvalidSize(size) => write!(
×
69
                f,
×
70
                "Unexpected response size {}, expected {} bytes",
×
71
                size,
72
                crate::directory::ENCAPSULATED_MESSAGE_BYTES
73
            ),
74
            UnexpectedStatusCode(status) => write!(f, "Unexpected status code: {status}"),
×
75
        }
76
    }
×
77
}
78

79
impl error::Error for DirectoryResponseError {
80
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
×
81
        use DirectoryResponseError::*;
82

83
        match self {
×
84
            OhttpDecapsulation(e) => Some(e),
×
85
            InvalidSize(_) => None,
×
86
            UnexpectedStatusCode(_) => None,
×
87
        }
88
    }
×
89
}
90

91
pub fn process_get_res(
21✔
92
    res: &[u8],
21✔
93
    ohttp_context: ohttp::ClientResponse,
21✔
94
) -> Result<Option<Vec<u8>>, DirectoryResponseError> {
21✔
95
    let response = process_ohttp_res(res, ohttp_context)?;
21✔
96
    match response.status() {
21✔
97
        http::StatusCode::OK => Ok(Some(response.body().to_vec())),
18✔
98
        http::StatusCode::ACCEPTED => Ok(None),
3✔
99
        status_code => Err(DirectoryResponseError::UnexpectedStatusCode(status_code)),
×
100
    }
101
}
21✔
102

103
pub fn process_post_res(
18✔
104
    res: &[u8],
18✔
105
    ohttp_context: ohttp::ClientResponse,
18✔
106
) -> Result<(), DirectoryResponseError> {
18✔
107
    let response = process_ohttp_res(res, ohttp_context)?;
18✔
108
    match response.status() {
18✔
109
        http::StatusCode::OK => Ok(()),
18✔
110
        status_code => Err(DirectoryResponseError::UnexpectedStatusCode(status_code)),
×
111
    }
112
}
18✔
113

114
fn process_ohttp_res(
39✔
115
    res: &[u8],
39✔
116
    ohttp_context: ohttp::ClientResponse,
39✔
117
) -> Result<http::Response<Vec<u8>>, DirectoryResponseError> {
39✔
118
    let response_array: &[u8; crate::directory::ENCAPSULATED_MESSAGE_BYTES] =
39✔
119
        res.try_into().map_err(|_| DirectoryResponseError::InvalidSize(res.len()))?;
39✔
120
    log::trace!("decapsulating directory response");
39✔
121
    let res = ohttp_decapsulate(ohttp_context, response_array)
39✔
122
        .map_err(DirectoryResponseError::OhttpDecapsulation)?;
39✔
123
    Ok(res)
39✔
124
}
39✔
125

126
/// decapsulate ohttp, bhttp response and return http response body and status code
127
pub fn ohttp_decapsulate(
39✔
128
    res_ctx: ohttp::ClientResponse,
39✔
129
    ohttp_body: &[u8; ENCAPSULATED_MESSAGE_BYTES],
39✔
130
) -> Result<http::Response<Vec<u8>>, OhttpEncapsulationError> {
39✔
131
    let bhttp_body = res_ctx.decapsulate(ohttp_body)?;
39✔
132
    let mut r = std::io::Cursor::new(bhttp_body);
39✔
133
    let m: bhttp::Message = bhttp::Message::read_bhttp(&mut r)?;
39✔
134
    let mut builder = http::Response::builder();
39✔
135
    for field in m.header().iter() {
39✔
136
        builder = builder.header(field.name(), field.value());
×
137
    }
×
138
    builder
39✔
139
        .status(m.control().status().unwrap_or(http::StatusCode::INTERNAL_SERVER_ERROR.into()))
39✔
140
        .body(m.content().to_vec())
39✔
141
        .map_err(OhttpEncapsulationError::Http)
39✔
142
}
39✔
143

144
/// Error from de/encapsulating an Oblivious HTTP request or response.
145
#[derive(Debug)]
146
pub enum OhttpEncapsulationError {
147
    Http(http::Error),
148
    Ohttp(ohttp::Error),
149
    Bhttp(bhttp::Error),
150
    ParseUrl(url::ParseError),
151
}
152

153
impl From<http::Error> for OhttpEncapsulationError {
154
    fn from(value: http::Error) -> Self { Self::Http(value) }
×
155
}
156

157
impl From<ohttp::Error> for OhttpEncapsulationError {
158
    fn from(value: ohttp::Error) -> Self { Self::Ohttp(value) }
×
159
}
160

161
impl From<bhttp::Error> for OhttpEncapsulationError {
162
    fn from(value: bhttp::Error) -> Self { Self::Bhttp(value) }
×
163
}
164

165
impl From<url::ParseError> for OhttpEncapsulationError {
166
    fn from(value: url::ParseError) -> Self { Self::ParseUrl(value) }
×
167
}
168

169
impl fmt::Display for OhttpEncapsulationError {
170
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
×
171
        use OhttpEncapsulationError::*;
172

173
        match &self {
×
174
            Http(e) => e.fmt(f),
×
175
            Ohttp(e) => e.fmt(f),
×
176
            Bhttp(e) => e.fmt(f),
×
177
            ParseUrl(e) => e.fmt(f),
×
178
        }
179
    }
×
180
}
181

182
impl error::Error for OhttpEncapsulationError {
183
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
×
184
        use OhttpEncapsulationError::*;
185

186
        match &self {
×
187
            Http(e) => Some(e),
×
188
            Ohttp(e) => Some(e),
×
189
            Bhttp(e) => Some(e),
×
190
            ParseUrl(e) => Some(e),
×
191
        }
192
    }
×
193
}
194

195
#[derive(Debug, Clone)]
196
pub struct OhttpKeys(pub ohttp::KeyConfig);
197

198
impl OhttpKeys {
199
    /// Decode an OHTTP KeyConfig
200
    pub fn decode(bytes: &[u8]) -> Result<Self, ohttp::Error> {
26✔
201
        ohttp::KeyConfig::decode(bytes).map(Self)
26✔
202
    }
26✔
203
}
204

205
/// KEM identifier bytes: DHKEM(secp256k1, HKDF-SHA256).
const KEM_ID: &[u8] = &[0x00, 0x16];
/// Length prefix for the symmetric suite list that follows: 4 bytes.
const SYMMETRIC_LEN: &[u8] = &[0x00, 0x04];
/// One symmetric suite: KDF(HKDF-SHA256) followed by AEAD(ChaCha20Poly1305).
const SYMMETRIC_KDF_AEAD: &[u8] = &[0x00, 0x01, 0x00, 0x03];
208

209
impl fmt::Display for OhttpKeys {
210
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
23✔
211
        let bytes = self.encode().map_err(|_| fmt::Error)?;
23✔
212
        let key_id = bytes[0];
23✔
213
        let pubkey = &bytes[3..68];
23✔
214

215
        let compressed_pubkey =
23✔
216
            bitcoin::secp256k1::PublicKey::from_slice(pubkey).map_err(|_| fmt::Error)?.serialize();
23✔
217

218
        let mut buf = vec![key_id];
23✔
219
        buf.extend_from_slice(&compressed_pubkey);
23✔
220

221
        let oh_hrp: bech32::Hrp = bech32::Hrp::parse("OH").unwrap();
23✔
222

223
        crate::bech32::nochecksum::encode_to_fmt(f, oh_hrp, &buf).map_err(|e| match e {
23✔
224
            EncodeError::Fmt(e) => e,
×
225
            _ => fmt::Error,
×
226
        })
×
227
    }
23✔
228
}
229

230
impl TryFrom<&[u8]> for OhttpKeys {
231
    type Error = ParseOhttpKeysError;
232

233
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
30✔
234
        let key_id = *bytes.first().ok_or(ParseOhttpKeysError::InvalidFormat)?;
30✔
235
        let compressed_pk = bytes.get(1..34).ok_or(ParseOhttpKeysError::InvalidFormat)?;
30✔
236

237
        let pubkey = bitcoin::secp256k1::PublicKey::from_slice(compressed_pk)
30✔
238
            .map_err(|_| ParseOhttpKeysError::InvalidPublicKey)?;
30✔
239

240
        let mut buf = vec![key_id];
30✔
241
        buf.extend_from_slice(KEM_ID);
30✔
242
        buf.extend_from_slice(&pubkey.serialize_uncompressed());
30✔
243
        buf.extend_from_slice(SYMMETRIC_LEN);
30✔
244
        buf.extend_from_slice(SYMMETRIC_KDF_AEAD);
30✔
245

246
        ohttp::KeyConfig::decode(&buf).map(Self).map_err(ParseOhttpKeysError::DecodeKeyConfig)
30✔
247
    }
30✔
248
}
249

250
impl std::str::FromStr for OhttpKeys {
251
    type Err = ParseOhttpKeysError;
252

253
    /// Parses a base64URL-encoded string into OhttpKeys.
254
    /// The string format is: key_id || compressed_public_key
255
    fn from_str(s: &str) -> Result<Self, Self::Err> {
30✔
256
        // TODO extract to utility function
257
        let oh_hrp: bech32::Hrp = bech32::Hrp::parse("OH").unwrap();
30✔
258

259
        let (hrp, bytes) =
30✔
260
            crate::bech32::nochecksum::decode(s).map_err(ParseOhttpKeysError::DecodeBech32)?;
30✔
261

262
        if hrp != oh_hrp {
30✔
263
            return Err(ParseOhttpKeysError::InvalidFormat);
×
264
        }
30✔
265

266
        Self::try_from(&bytes[..])
30✔
267
    }
30✔
268
}
269

270
impl PartialEq for OhttpKeys {
271
    fn eq(&self, other: &Self) -> bool {
9✔
272
        match (self.encode(), other.encode()) {
9✔
273
            (Ok(self_encoded), Ok(other_encoded)) => self_encoded == other_encoded,
9✔
274
            // If OhttpKeys::encode(&self) is Err, return false
275
            _ => false,
×
276
        }
277
    }
9✔
278
}
279

280
// Marker impl on top of the encoding-based `PartialEq` for `OhttpKeys`.
// NOTE(review): `PartialEq::eq` returns `false` when encoding fails, so reflexivity relies on
// `encode()` succeeding for every value; confirm that holds.
impl Eq for OhttpKeys {}
281

282
impl Deref for OhttpKeys {
283
    type Target = ohttp::KeyConfig;
284

285
    fn deref(&self) -> &Self::Target { &self.0 }
56✔
286
}
287

288
impl DerefMut for OhttpKeys {
289
    fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
43✔
290
}
291

292
impl<'de> serde::Deserialize<'de> for OhttpKeys {
293
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
14✔
294
    where
14✔
295
        D: serde::Deserializer<'de>,
14✔
296
    {
297
        let bytes = Vec::<u8>::deserialize(deserializer)?;
14✔
298
        OhttpKeys::decode(&bytes).map_err(serde::de::Error::custom)
14✔
299
    }
14✔
300
}
301

302
impl serde::Serialize for OhttpKeys {
303
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
11✔
304
    where
11✔
305
        S: serde::Serializer,
11✔
306
    {
307
        let bytes = self.encode().map_err(serde::ser::Error::custom)?;
11✔
308
        bytes.serialize(serializer)
11✔
309
    }
11✔
310
}
311

312
#[derive(Debug)]
313
pub enum ParseOhttpKeysError {
314
    InvalidFormat,
315
    InvalidPublicKey,
316
    DecodeBech32(bech32::primitives::decode::CheckedHrpstringError),
317
    DecodeKeyConfig(ohttp::Error),
318
}
319

320
impl std::fmt::Display for ParseOhttpKeysError {
321
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
×
322
        match self {
×
323
            ParseOhttpKeysError::InvalidFormat => write!(f, "Invalid format"),
×
324
            ParseOhttpKeysError::InvalidPublicKey => write!(f, "Invalid public key"),
×
NEW
325
            ParseOhttpKeysError::DecodeBech32(e) => write!(f, "Failed to decode bech32: {e}"),
×
326
            ParseOhttpKeysError::DecodeKeyConfig(e) => write!(f, "Failed to decode KeyConfig: {e}"),
×
327
        }
328
    }
×
329
}
330

331
impl std::error::Error for ParseOhttpKeysError {
332
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
×
333
        match self {
×
334
            ParseOhttpKeysError::DecodeBech32(e) => Some(e),
×
335
            ParseOhttpKeysError::DecodeKeyConfig(e) => Some(e),
×
336
            ParseOhttpKeysError::InvalidFormat | ParseOhttpKeysError::InvalidPublicKey => None,
×
337
        }
338
    }
×
339
}
340

341
#[cfg(test)]
mod test {
    use super::*;

    /// Formatting keys to their text form and parsing that text back must yield keys with the
    /// same encoded key configuration.
    #[test]
    fn test_ohttp_keys_roundtrip() {
        use ohttp::hpke::{Aead, Kdf, Kem};
        use ohttp::{KeyId, SymmetricSuite};

        let key_id: KeyId = 1;
        let suites = vec![SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)];
        let keys = OhttpKeys(ohttp::KeyConfig::new(key_id, Kem::K256Sha256, suites).unwrap());

        let text = keys.to_string();
        let parsed: OhttpKeys = text.parse().unwrap();

        assert_eq!(keys.encode().unwrap(), parsed.encode().unwrap());
    }
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc