• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alandtse / auth_capture_proxy / 21794459388

08 Feb 2026 07:36AM UTC coverage: 65.436% (+14.7%) from 50.687%
21794459388

push

github

web-flow
build: fix tox failures (#40)

16 of 23 new or added lines in 4 files covered. (69.57%)

155 existing lines in 1 file now uncovered.

638 of 975 relevant lines covered (65.44%)

0.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

55.77
/authcaptureproxy/auth_capture_proxy.py
1
#  SPDX-License-Identifier: Apache-2.0
2
"""Python Package for auth capture proxy."""
3
import asyncio
1✔
4
import logging
1✔
5
import posixpath
1✔
6
import re
1✔
7
from functools import partial
1✔
8
from json import JSONDecodeError
1✔
9
from ssl import SSLContext, create_default_context
1✔
10
from typing import Any, Callable, Dict, List, Optional, Set, Text, Tuple, Union
1✔
11

12
import httpx
1✔
13
from aiohttp import MultipartReader, MultipartWriter, hdrs, web
1✔
14
from multidict import CIMultiDict
1✔
15
from yarl import URL
1✔
16

17
from authcaptureproxy.const import SKIP_AUTO_HEADERS
1✔
18
from authcaptureproxy.examples.modifiers import (
1✔
19
    prepend_relative_urls,
20
    replace_empty_action_urls,
21
    replace_matching_urls,
22
)
23
from authcaptureproxy.helper import (
1✔
24
    convert_multidict_to_dict,
25
    get_content_type,
26
    get_nested_dict_keys,
27
    print_resp,
28
    run_func,
29
    swap_url,
30
)
31
from authcaptureproxy.interceptor import BaseInterceptor, InterceptContext
1✔
32
from authcaptureproxy.stackoverflow import get_open_port
1✔
33

34
# Pre-configure SSL context
35
ssl_context = create_default_context()
1✔
36

37
_LOGGER = logging.getLogger(__name__)
1✔
38

39

40
class AuthCaptureProxy:
1✔
41
    """Class to handle proxy login connections.
42

43
    This class relies on tests to be provided to indicate the proxy has completed. At proxy completion all data can be found in self.session, self.data, and self.query.
44
    """
45

46
    def __init__(
        self,
        proxy_url: URL,
        host_url: URL,
        session: Optional[httpx.AsyncClient] = None,
        session_factory: Optional[Callable[[], httpx.AsyncClient]] = None,
        preserve_headers: bool = False,
    ) -> None:
        """Initialize proxy object.

        Args:
            proxy_url (URL): url for proxy location. e.g., http://192.168.1.1/. If there is any path, the path is considered part of the base url. If no explicit port is specified, a random port will be generated. If https is passed in, ssl_context must be provided at start_proxy() or the url will be downgraded to http.
            host_url (URL): original url for login, e.g., http://example.com
            session (httpx.AsyncClient): httpx client to make queries. Optional
            session_factory (lambda: httpx.AsyncClient): factory to create the aforementioned httpx client if having one fixed session is insufficient.
            preserve_headers (bool): Whether to preserve headers from the backend. Useful in circumventing CSRF protection. Defaults to False.
        """
        self._preserve_headers = preserve_headers
        # Fall back to a default factory that builds a client with the
        # module-level SSL context when the caller supplies none.
        self.session_factory: Callable[[], httpx.AsyncClient] = session_factory or (
            lambda: httpx.AsyncClient(verify=ssl_context)
        )
        self.session: httpx.AsyncClient = session or self.session_factory()
        self._proxy_url: URL = proxy_url
        self._host_url: URL = host_url
        # 0 means "no explicit port yet"; a random open port is chosen at start.
        self._port: int = proxy_url.explicit_port or 0  # type: ignore
        self.runner: Optional[web.AppRunner] = None
        self.last_resp: Optional[httpx.Response] = None
        self.init_query: Dict[Text, Any] = {}
        self.query: Dict[Text, Any] = {}
        self.data: Dict[Text, Any] = {}
        # Tests and modifiers are really initialized once the port is known,
        # not during init. Dummy keys guarantee the defaults sort first.
        self._tests: Dict[Text, Callable] = {}
        self._modifiers: Dict[Text, Union[Callable, Dict[Text, Callable]]] = {
            "text/html": {
                "prepend_relative_urls": lambda x: x,
                "change_host_to_proxy": lambda x: x,
            }
        }
        self._old_tests: Dict[Text, Callable] = {}
        self._old_modifiers: Dict[Text, Union[Callable, Dict[Text, Callable]]] = {}
        self._active = False
        self._all_handler_active = True
        self.headers: Dict[Text, Text] = {}
        # Lists of regex strings; matching redirects are filtered out.
        self.redirect_filters: Dict[Text, List[Text]] = {"url": []}
        self._background_tasks: Set[asyncio.Task] = set()
        self._interceptors: List[BaseInterceptor] = []
1✔
95

96
    @property
    def interceptors(self) -> List[BaseInterceptor]:
        """Return interceptors list.

        :setter: value (List[BaseInterceptor]): A list of interceptors to run during request processing. See :mod:`authcaptureproxy.examples.amazon_waf` for an example.
        """
        return self._interceptors
×
103

104
    @interceptors.setter
    def interceptors(self, value: List[BaseInterceptor]) -> None:
        """Replace the interceptor pipeline.

        Args:
            value (List[BaseInterceptor]): A list of interceptors.
        """
        self._interceptors = value
1✔
112

113
    @property
    def active(self) -> bool:
        """Return whether proxy is started."""
        return self._active
1✔
117

118
    @property
    def all_handler_active(self) -> bool:
        """Return whether all handler is active."""
        return self._all_handler_active
1✔
122

123
    @all_handler_active.setter
    def all_handler_active(self, value: bool) -> None:
        """Enable or disable the all handler."""
        self._all_handler_active = value
×
127

128
    @property
    def port(self) -> int:
        """Return port setting."""
        return self._port
1✔
132

133
    @property
    def tests(self) -> Dict[Text, Callable]:
        """Return tests setting.

        :setter: value (Dict[Text, Any]): A dictionary of tests. The key should be the name of the test and the value should be a function or coroutine that takes a httpx.Response, a dictionary of post variables, and a dictionary of query variables and returns a URL or string. See :mod:`authcaptureproxy.examples.testers` for examples.
        """
        return self._tests
1✔
140

141
    @tests.setter
    def tests(self, value: Dict[Text, Callable]) -> None:
        """Set tests.

        Args:
            value (Dict[Text, Any]): A dictionary of tests.
        """
        # Apply any pending refresh before snapshotting the outgoing tests.
        self.refresh_tests()
        self._old_tests = self._tests.copy()
        self._tests = value
×
151

152
    @property
    def modifiers(self) -> Dict[Text, Union[Callable, Dict[Text, Callable]]]:
        """Return modifiers setting.

        :setter: value (Dict[Text, Dict[Text, Callable]): A nested dictionary of modifiers. The key should be a MIME type and the value should be a dictionary of modifiers for that MIME type where the key should be the name of the modifier and the value should be a function or coroutine that takes a string and returns a modified string. If parameters are necessary, functools.partial should be used. See :mod:`authcaptureproxy.examples.modifiers` for examples.
        """
        return self._modifiers
1✔
159

160
    @modifiers.setter
    def modifiers(self, value: Dict[Text, Union[Callable, Dict[Text, Callable]]]) -> None:
        """Set modifiers.

        Args:
            value (Dict[Text, Any]): A dictionary of modifiers.
        """
        self.refresh_modifiers()  # refresh in case of pending change
        # Snapshot a copy (consistent with the tests setter) so later in-place
        # mutation of the incoming dict cannot alias the saved snapshot.
        self._old_modifiers = self._modifiers.copy()
        self._modifiers = value
×
170

171
    def access_url(self) -> URL:
        """Return access url for proxy with port."""
        # Port 0 means "not yet assigned"; don't encode it in the URL.
        if self.port == 0:
            return self._proxy_url
        return self._proxy_url.with_port(self.port)
1✔
174

175
    async def change_host_url(self, new_url: URL) -> None:
        """Change the host url of the proxy.

        This will also reset all stored data.

        Args:
            new_url (URL): original url for login, e.g., http://example.com

        Raises:
            ValueError: if new_url is not a yarl URL.
        """
        if not isinstance(new_url, URL):
            raise ValueError("URL required")
        self._host_url = new_url
        await self.reset_data()
×
187

188
    async def reset_data(self) -> None:
        """Reset all stored data.

        A proxy may need to service multiple login requests if the route is not torn down. This function will reset all data between logins.
        """
        # Close the old client before replacing it so connections are released.
        if self.session:
            await self.session.aclose()
        self.session = self.session_factory()
        self.last_resp = None
        self.init_query, self.query, self.data = {}, {}, {}
        self._active = False
        self._all_handler_active = True
        _LOGGER.debug("Proxy data reset.")
×
203

204
    def refresh_tests(self) -> None:
        """Refresh tests.

        Because tests may use partials, they will freeze their parameters which is a problem with self.access() if the port hasn't been assigned.
        """
        if self._tests != self._old_tests:
            # Bug fix: snapshot into the private attribute the comparison above
            # actually reads. The old code assigned to `self.old_tests`, which
            # created a brand-new public attribute, so the snapshot never took
            # effect and this branch (and its debug log) fired on every call.
            # The previous `self.tests.update({})` no-op was also removed.
            self._old_tests = self._tests.copy()
            _LOGGER.debug("Refreshed %s tests: %s", len(self.tests), list(self.tests.keys()))
1✔
213

214
    def refresh_modifiers(self, site: Optional[URL] = None) -> None:
        """Refresh modifiers.

        Because modifiers may use partials, they will freeze their parameters which is a problem with self.access() if the port hasn't been assigned.

        Args:
            site (Optional[URL], optional): The current site. Defaults to None.
        """
        DEFAULT_MODIFIERS = {  # noqa: N806
            "prepend_relative_urls": partial(prepend_relative_urls, self.access_url()),
            "change_host_to_proxy": partial(
                replace_matching_urls,
                self._host_url.with_query({}).with_path("/"),
                self.access_url(),
            ),
        }
        # Nothing pending: modifiers unchanged since the last refresh.
        if self._modifiers == self._old_modifiers:
            return
        html_modifiers = self.modifiers.get("text/html")
        if html_modifiers is None:
            self.modifiers["text/html"] = DEFAULT_MODIFIERS  # type: ignore
        elif html_modifiers and isinstance(html_modifiers, dict):
            html_modifiers.update(DEFAULT_MODIFIERS)
        if site and isinstance(self.modifiers["text/html"], dict):
            # Rewrite empty form actions to point back at the proxy for the
            # page currently being served.
            self.modifiers["text/html"].update(
                {
                    "change_empty_to_proxy": partial(
                        replace_empty_action_urls,
                        swap_url(
                            old_url=self._host_url.with_query({}),
                            new_url=self.access_url().with_query({}),
                            url=site,
                        ),
                    ),
                }
            )
        self._old_modifiers = self.modifiers.copy()
        refreshed_modifers = get_nested_dict_keys(self.modifiers)
        _LOGGER.debug("Refreshed %s modifiers: %s", len(refreshed_modifers), refreshed_modifers)
1✔
251

252
    @staticmethod
1✔
253
    def _filter_ajax_headers(resp: httpx.Response) -> dict:
1✔
254
        """Filter headers for AJAX responses, removing hop-by-hop and CSP headers."""
255
        _skip_headers = {
1✔
256
            "content-type",
257
            "content-length",
258
            "content-encoding",
259
            "transfer-encoding",
260
            "connection",
261
            "x-connection-hash",
262
            "content-security-policy",
263
            "content-security-policy-report-only",
264
        }
265
        filtered = {}
1✔
266
        for k, v in resp.headers.items():
1✔
267
            if k.lower() not in _skip_headers:
1✔
UNCOV
268
                filtered[k] = v
×
269
        filtered["Cache-Control"] = "no-cache, no-store, must-revalidate"
1✔
270
        return filtered
1✔
271

272
    async def _build_response(
        self, response: Optional[httpx.Response] = None, *args, **kwargs
    ) -> web.Response:
        """Build an aiohttp web.Response.

        When the caller passes no explicit headers and a backend response is
        available, headers are seeded from it (only if header preservation is
        enabled), then content-framing/transport headers, cookies, and cache
        directives are stripped before a strict no-cache policy is applied.
        """
        if "headers" not in kwargs and response is not None:
            headers = response.headers.copy() if self._preserve_headers else CIMultiDict()
            kwargs["headers"] = headers

            # Content-Type is only dropped when an explicit content_type kwarg
            # will supply it instead.
            if hdrs.CONTENT_TYPE in headers and "content_type" in kwargs:
                del headers[hdrs.CONTENT_TYPE]

            # Framing/transport headers no longer describe the rebuilt body.
            for name in (
                hdrs.CONTENT_LENGTH,
                hdrs.CONTENT_ENCODING,
                hdrs.CONTENT_TRANSFER_ENCODING,
                hdrs.TRANSFER_ENCODING,
                "x-connection-hash",
                hdrs.CACHE_CONTROL,
            ):
                if name in headers:
                    del headers[name]

            # Set-Cookie may appear multiple times; drop every occurrence —
            # cookies are managed by the proxy's own client session.
            while hdrs.SET_COOKIE in headers:
                del headers[hdrs.SET_COOKIE]

            headers[hdrs.CACHE_CONTROL] = "no-cache, no-store, must-revalidate"

        return web.Response(*args, **kwargs)
1✔
310

311
    async def all_handler(self, request: web.Request, **kwargs) -> web.Response:
1✔
312
        """Handle all requests.
313

314
        This handler will exit on successful test found in self.tests or if a /stop url is seen. This handler can be used with any aiohttp webserver and disabled after registered using self.all_handler_active.
315

316
        The handler supports an interceptor pipeline for extending behavior
317
        without modifying core proxy code. See :class:`BaseInterceptor`.
318

319
        Args
320
            request (web.Request): The request to process
321
            **kwargs: Additional keyword arguments
322
                access_url (URL): The access url for the proxy. Defaults to self.access_url()
323
                host_url (URL): The host url for the proxy. Defaults to self._host_url
324

325
        Returns
326
            web.Response: The webresponse to the browser
327

328
        Raises
329
            web.HTTPFound: Redirect URL upon success
330
            web.HTTPNotFound: Return 404 when all_handler is disabled
331

332
        """
333
        if "access_url" in kwargs:
1✔
UNCOV
334
            access_url = kwargs.pop("access_url")
×
335
        else:
336
            access_url = self.access_url()
1✔
337

338
        if "host_url" in kwargs:
1✔
339
            host_url = kwargs.pop("host_url")
×
340
        else:
341
            host_url = self._host_url
1✔
342

343
        async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
1✔
344
            """Process multipart.
345

346
            Args:
347
                reader (MultipartReader): Response multipart to process.
348
                writer (MultipartWriter): Multipart to write out.
349
            """
UNCOV
350
            while True:
×
351
                part = await reader.next()  # noqa: B305
×
352
                # https://github.com/PyCQA/flake8-bugbear/issues/59
353
                if part is None:
×
354
                    break
×
355
                if isinstance(part, MultipartReader):
×
356
                    await _process_multipart(part, writer)
×
357
                elif hdrs.CONTENT_TYPE in part.headers:
×
UNCOV
358
                    content_type = part.headers.get(hdrs.CONTENT_TYPE, "")
×
UNCOV
359
                    mime_type = content_type.split(";", 1)[0].strip()
×
360
                    if mime_type == "application/json":
×
361
                        try:
×
UNCOV
362
                            part_data: Optional[
×
363
                                Union[Text, Dict[Text, Any], List[Tuple[Text, Text]], bytes]
364
                            ] = await part.json()
UNCOV
365
                            writer.append_json(part_data)
×
UNCOV
366
                        except (JSONDecodeError, ValueError, TypeError):
×
367
                            # Best-effort fallback: text, then bytes
UNCOV
368
                            try:
×
UNCOV
369
                                part_text = await part.text()
×
UNCOV
370
                                writer.append(part_text)
×
NEW
371
                            except ValueError:
×
UNCOV
372
                                part_data = await part.read()
×
UNCOV
373
                                writer.append(part_data)
×
374
                    elif mime_type.startswith("text"):
×
375
                        part_data = await part.text()
×
UNCOV
376
                        writer.append(part_data)
×
UNCOV
377
                    elif mime_type == "application/x-www-form-urlencoded":
×
UNCOV
378
                        part_data = await part.form()
×
UNCOV
379
                        writer.append_form(part_data)
×
380
                    else:
UNCOV
381
                        part_data = await part.read()
×
UNCOV
382
                        writer.append(part_data)
×
383
                else:
UNCOV
384
                    part_data = await part.read()
×
UNCOV
385
                    if part.name:
×
UNCOV
386
                        self.data.update({part.name: part_data})
×
UNCOV
387
                    elif part.filename:
×
UNCOV
388
                        part_data = await part.read()
×
UNCOV
389
                        self.data.update({part.filename: part_data})
×
UNCOV
390
                    writer.append(part_data)
×
391

392
        if not self.all_handler_active:
1✔
UNCOV
393
            _LOGGER.debug("%s all_handler is disabled; returning 404.", self)
×
UNCOV
394
            raise web.HTTPNotFound()
×
395
        method = request.method.lower()
1✔
396
        _LOGGER.debug("Received %s: %s for %s", method, str(request.url), host_url)
1✔
397
        resp: Optional[httpx.Response] = None
1✔
398
        # Create interceptor context for the request pipeline
399
        ctx = InterceptContext(
1✔
400
            request=request,
401
            proxy=self,
402
            access_url=access_url,
403
            host_url=host_url,
404
            method=method,
405
        )
406
        # Run on_request interceptors (can set ctx.site for custom URL routing)
407
        for interceptor in self._interceptors:
1✔
408
            await interceptor.on_request(ctx)
1✔
409
            if ctx.short_circuit is not None:
1✔
410
                return ctx.short_circuit
1✔
411
        if ctx.site:
1✔
412
            # Interceptor set the target URL (e.g., multi-host routing)
413
            site = ctx.site
1✔
414
        else:
415
            # Generic URL resolution
416
            old_url: URL = (
1✔
417
                access_url.with_host(request.url.host)
418
                if request.url.host and request.url.host != access_url.host
419
                else access_url
420
            )
421
            if request.scheme == "http" and access_url.scheme == "https":
1✔
422
                _LOGGER.debug("Detected http while should be https; switching to https")
×
423
                site = str(
×
424
                    swap_url(
425
                        ignore_query=True,
426
                        old_url=old_url.with_scheme("https"),
427
                        new_url=host_url.with_path("/"),
428
                        url=URL(str(request.url)).with_scheme("https"),
429
                    ),
430
                )
431
            else:
432
                site = str(
1✔
433
                    swap_url(
434
                        ignore_query=True,
435
                        old_url=old_url,
436
                        new_url=host_url.with_path("/"),
437
                        url=URL(str(request.url)),
438
                    ),
439
                )
440
        self.query.update(request.query)
1✔
441
        data: Optional[Dict] = None
1✔
442
        raw_body: Optional[bytes] = None
1✔
443
        mpwriter = None
1✔
444
        if request.content_type == "multipart/form-data":
1✔
UNCOV
445
            mpwriter = MultipartWriter()
×
UNCOV
446
            await _process_multipart(await request.multipart(), mpwriter)
×
447
        elif (
1✔
448
            request.has_body
449
            and request.content_type
450
            and "x-www-form-urlencoded" not in request.content_type
451
            and "json" not in request.content_type
452
        ):
453
            # Raw body (text/plain, binary, etc.) - forward as-is.
UNCOV
454
            raw_body = await request.read()
×
UNCOV
455
            _LOGGER.debug(
×
456
                "Read raw body (%s bytes, type=%s) for %s",
457
                len(raw_body) if raw_body else 0,
458
                request.content_type,
459
                site,
460
            )
461
        else:
462
            data = convert_multidict_to_dict(await request.post())
1✔
463
        json_data = None
1✔
464
        # Only attempt JSON decoding for JSON requests; avoid raising for form posts.
465
        if request.has_body and (
1✔
466
            request.content_type == "application/json" or request.content_type.endswith("+json")
467
        ):
468
            try:
1✔
469
                json_data = await request.json()
1✔
UNCOV
470
            except (JSONDecodeError, ValueError):
×
UNCOV
471
                json_data = None
×
472
        if data:
1✔
473
            self.data.update(data)
1✔
474
            _LOGGER.debug("Storing data %s", data)
1✔
475
            # Run on_request_data interceptors (can modify data before HTTP request)
476
            ctx.site = site
1✔
477
            ctx.data = data
1✔
478
            ctx.json_data = json_data
1✔
479
            for interceptor in self._interceptors:
1✔
480
                await interceptor.on_request_data(ctx)
1✔
481
                if ctx.short_circuit is not None:
1✔
482
                    return ctx.short_circuit
×
483
            data = ctx.data
1✔
484
        elif json_data:
1✔
485
            self.data.update(json_data)
1✔
486
            _LOGGER.debug("Storing json %s", json_data)
1✔
487
        if URL(str(request.url)).path == re.sub(
1✔
488
            r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/stop").path
489
        ):
490
            self.all_handler_active = False
×
491
            if self.active:
×
UNCOV
492
                task = asyncio.create_task(self.stop_proxy(3))
×
UNCOV
493
                self._background_tasks.add(task)
×
UNCOV
494
                task.add_done_callback(self._background_tasks.discard)
×
UNCOV
495
            return await self._build_response(text="Proxy stopped.")
×
496
        elif (
1✔
497
            URL(str(request.url)).path
498
            == re.sub(r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/resume").path)
499
            and self.last_resp
500
            and isinstance(self.last_resp, httpx.Response)
501
        ):
UNCOV
502
            self.init_query = self.query.copy()
×
UNCOV
503
            _LOGGER.debug("Resuming request: %s", self.last_resp)
×
504
            resp = self.last_resp
×
505
        else:
506
            if URL(str(request.url)).path in [
1✔
507
                self._proxy_url.path,
508
                re.sub(
509
                    r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/resume").path
510
                ),
511
            ]:
512
                # either base path or resume without anything to resume
513
                site = str(URL(host_url))
×
514
                if method == "get":
×
515
                    self.init_query = self.query.copy()
×
516
                    _LOGGER.debug(
×
517
                        "Starting auth capture proxy for %s",
518
                        host_url,
519
                    )
520
            headers = await self.modify_headers(URL(site), request)
1✔
521
            skip_auto_headers: List[str] = headers.get(SKIP_AUTO_HEADERS, [])
1✔
522
            if skip_auto_headers:
1✔
523
                _LOGGER.debug("Discovered skip_auto_headers %s", skip_auto_headers)
×
524
                headers.pop(SKIP_AUTO_HEADERS)
×
525
            # Avoid accidental header mutation across branches/calls
526
            req_headers: dict[str, Any] = dict(headers)
1✔
527
            _LOGGER.debug(
1✔
528
                "Attempting %s to %s\nheaders: %s \ncookies: %s",
529
                method,
530
                site,
531
                req_headers,
532
                self.session.cookies.jar,
533
            )
534
            try:
1✔
535
                if mpwriter:
1✔
536
                    resp = await getattr(self.session, method)(
×
537
                        site, data=mpwriter, headers=req_headers, follow_redirects=True
538
                    )
539
                elif data:
1✔
540
                    resp = await getattr(self.session, method)(
1✔
541
                        site, data=data, headers=req_headers, follow_redirects=True
542
                    )
543
                elif raw_body is not None:
1✔
544
                    _LOGGER.debug(
×
545
                        "Sending raw body (%s bytes, Content-Type: %s) to %s",
546
                        len(raw_body),
547
                        request.content_type,
548
                        site,
549
                    )
550
                    # Preserve the original Content-Type for raw body requests
551
                    if request.content_type and "Content-Type" not in req_headers:
×
552
                        req_headers["Content-Type"] = request.content_type
×
553
                    resp = await getattr(self.session, method)(
×
554
                        site, content=raw_body, headers=req_headers, follow_redirects=True
555
                    )
556
                elif json_data:
1✔
557
                    for item in ["Host", "Origin", "User-Agent", "dnt", "Accept-Encoding"]:
1✔
558
                        # remove proxy headers
559
                        if req_headers.get(item):
1✔
560
                            req_headers.pop(item)
1✔
561
                    resp = await getattr(self.session, method)(
1✔
562
                        site, json=json_data, headers=req_headers, follow_redirects=True
563
                    )
564
                else:
565
                    resp = await getattr(self.session, method)(
1✔
566
                        site, headers=req_headers, follow_redirects=True
567
                    )
UNCOV
568
            except httpx.ConnectError as ex:
×
UNCOV
569
                return await self._build_response(
×
570
                    text=f"Error connecting to {site}; please retry: {ex}"
571
                )
UNCOV
572
            except httpx.TooManyRedirects as ex:
×
UNCOV
573
                return await self._build_response(
×
574
                    text=f"Error connecting to {site}; too many redirects: {ex}"
575
                )
UNCOV
576
            except httpx.TimeoutException as ex:
×
NEW
577
                _LOGGER.warning("Timeout connecting to %s: %s", site, ex)
×
UNCOV
578
                return await self._build_response(
×
579
                    text=(
580
                        f"Timeout connecting to {site}: {ex}. "
581
                        "Please try again. If this persists, check your network "
582
                        "and that the service endpoint is reachable from this host."
583
                    )
584
                )
UNCOV
585
            except httpx.HTTPError as ex:
×
NEW
586
                return await self._build_response(text=f"Error connecting to {site}: {ex}")
×
587
        if resp is None:
1✔
UNCOV
588
            return await self._build_response(text=f"Error connecting to {site}; please retry")
×
589
        self.last_resp = resp
1✔
590
        print_resp(resp)
1✔
591
        # Run on_response interceptors (post-response, pre-tests)
592
        ctx.response = resp
1✔
593
        ctx.site = site
1✔
594
        for interceptor in self._interceptors:
1✔
595
            await interceptor.on_response(ctx)
1✔
596
            if ctx.short_circuit is not None:
1✔
597
                return ctx.short_circuit
1✔
598
        self.check_redirects()
1✔
599
        self.refresh_tests()
1✔
600
        if self.tests:
1✔
601
            for test_name, test in self.tests.items():
1✔
602
                result = None
1✔
603
                result = await run_func(test, test_name, resp, self.data, self.query)
1✔
604
                if result:
1✔
UNCOV
605
                    _LOGGER.debug("Test %s triggered", test_name)
×
UNCOV
606
                    if isinstance(result, URL):
×
UNCOV
607
                        _LOGGER.debug(
×
608
                            "Redirecting to callback: %s",
609
                            result,
610
                        )
611
                        raise web.HTTPFound(location=result)
×
612
                    elif isinstance(result, str):
×
613
                        _LOGGER.debug("Displaying page:\n%s", result)
×
614
                        return await self._build_response(
×
615
                            resp, text=result, content_type="text/html"
616
                        )
617
        else:
618
            _LOGGER.warning("Proxy has no tests; please set.")
1✔
619
        content_type = get_content_type(resp)
1✔
620
        # Detect AJAX requests using Fetch Metadata headers (W3C standard).
621
        # Sec-Fetch-Mode is set by the browser and cannot be spoofed by JS.
622
        # 'navigate' = top-level page navigation; anything else = AJAX/subresource.
623
        # Fall back to Upgrade-Insecure-Requests for older clients.
624
        _sec_fetch_mode = request.headers.get("Sec-Fetch-Mode")
1✔
625
        if _sec_fetch_mode is not None:
1✔
626
            _is_ajax = _sec_fetch_mode != "navigate"
1✔
627
        else:
628
            # Legacy fallback for clients without Sec-Fetch-Mode
629
            _is_ajax = request.headers.get("Upgrade-Insecure-Requests") != "1"
1✔
630
        if _is_ajax:
1✔
631
            _LOGGER.debug(
1✔
632
                "AJAX response for %s: status=%s, content_type=%s",
633
                URL(str(request.url)).path,
634
                resp.status_code,
635
                content_type,
636
            )
637
        if _is_ajax and content_type == "text/html":
1✔
638
            _ajax_body = resp.content
1✔
639
            # Run on_ajax_html interceptors (can modify AJAX HTML body)
640
            ctx.is_ajax = True
1✔
641
            ctx.content_type = content_type
1✔
642
            ctx.body = _ajax_body
1✔
643
            for interceptor in self._interceptors:
1✔
644
                await interceptor.on_ajax_html(ctx)
1✔
645
            _ajax_body = ctx.body if ctx.body is not None else _ajax_body
1✔
646
            _LOGGER.debug(
1✔
647
                "AJAX HTML response for %s - skipping modifiers",
648
                URL(str(request.url)).path,
649
            )
650
            # Forward original headers for AJAX responses.
651
            # Client-side JavaScript may check response headers (e.g., for CAPTCHA
652
            # initialization). Without them, it may fail silently.
653
            _ajax_headers = self._filter_ajax_headers(resp) if resp is not None else {}
1✔
654
            return await self._build_response(
1✔
655
                resp,
656
                body=_ajax_body,
657
                content_type=content_type,
658
                headers=_ajax_headers,
659
            )
660
        # Also skip modifiers for non-HTML AJAX responses (JSON, binary, etc.)
661
        if _is_ajax and content_type != "text/html":
1✔
662
            _LOGGER.debug(
1✔
663
                "AJAX non-HTML response (%s) for %s - skipping modifiers",
664
                content_type,
665
                URL(str(request.url)).path,
666
            )
667
            _resp_body = resp.content
1✔
668
            _ajax_headers_nh = self._filter_ajax_headers(resp) if resp is not None else {}
1✔
669
            return await self._build_response(
1✔
670
                resp,
671
                body=_resp_body,
672
                content_type=content_type,
673
                headers=_ajax_headers_nh,
674
            )
675
        self.refresh_modifiers(URL(str(resp.url)))
1✔
676
        if self.modifiers:
1✔
677
            modified: bool = False
1✔
678
            if content_type != "text/html" and content_type not in self.modifiers.keys():
1✔
UNCOV
679
                text: Text = ""
×
680
            elif content_type != "text/html" and content_type in self.modifiers.keys():
1✔
UNCOV
681
                text = resp.text
×
682
            else:
683
                text = resp.text
1✔
684
            if not isinstance(text, str):  # process aiohttp text
1✔
UNCOV
685
                text = await resp.text()
×
686
            # Resolve relative form actions BEFORE modifiers run.
687
            if text and content_type == "text/html" and resp and resp.url:
1✔
688
                _resp_url = URL(str(resp.url))
1✔
689
                _resp_dir = _resp_url.path.rsplit("/", 1)[0] + "/" if "/" in _resp_url.path else "/"
1✔
690

691
                def _resolve_form_action(form_match):
1✔
692
                    """Resolve relative action URLs only inside <form> tags."""
693
                    form_tag = form_match.group(0)
×
NEW
694
                    action_m = re.search(r'(\s+action=["\'])([^"\']*?)(["\'])', form_tag)
×
UNCOV
695
                    if not action_m:
×
696
                        return form_tag
×
UNCOV
697
                    action = action_m.group(2)
×
UNCOV
698
                    if action and not action.startswith(
×
699
                        ("http://", "https://", "//", "#", "javascript:", "/")
700
                    ):
UNCOV
701
                        resolved_path = posixpath.normpath(_resp_dir + action)
×
702
                        _proxy_base = self.access_url().path.rstrip("/")
×
703
                        abs_url = str(
×
704
                            self.access_url().with_path(_proxy_base + resolved_path).with_query({})
705
                        )
UNCOV
706
                        _LOGGER.debug(
×
707
                            "Resolved relative form action '%s' -> '%s' (page: %s)",
708
                            action,
709
                            abs_url,
710
                            _resp_url.path,
711
                        )
NEW
712
                        return form_tag[: action_m.start(2)] + abs_url + form_tag[action_m.end(2) :]
×
713
                    return form_tag
×
714

715
                text = re.sub(
1✔
716
                    r"<form\b[^>]*>",
717
                    _resolve_form_action,
718
                    text,
719
                    flags=re.IGNORECASE,
720
                )
721
            # Run on_page_html interceptors (can inject scripts before modifiers)
722
            if text and content_type == "text/html":
1✔
723
                ctx.text = text
1✔
724
                ctx.content_type = content_type
1✔
725
                ctx.is_ajax = False
1✔
726
                for interceptor in self._interceptors:
1✔
727
                    await interceptor.on_page_html(ctx)
1✔
728
                text = ctx.text if ctx.text is not None else text
1✔
729
            if text:
1✔
730
                for name, modifier in self.modifiers.items():
1✔
731
                    if isinstance(modifier, dict):
1✔
732
                        if name != content_type:
1✔
UNCOV
733
                            continue
×
734
                        for sub_name, sub_modifier in modifier.items():
1✔
735
                            try:
1✔
736
                                text = await run_func(sub_modifier, sub_name, text)
1✔
737
                                modified = True
1✔
738
                            except TypeError as ex:
×
UNCOV
739
                                _LOGGER.warning("Modifier %s is not callable: %s", sub_name, ex)
×
740
                    else:
741
                        # default run against text/html only
UNCOV
742
                        if content_type == "text/html":
×
UNCOV
743
                            try:
×
UNCOV
744
                                text = await run_func(modifier, name, text)
×
UNCOV
745
                                modified = True
×
746
                            except TypeError as ex:
×
UNCOV
747
                                _LOGGER.warning("Modifier %s is not callable: %s", name, ex)
×
748
            if modified:
1✔
749
                return await self._build_response(
1✔
750
                    resp,
751
                    text=text,
752
                    content_type=content_type,
753
                )
754
        # pass through non parsed content
UNCOV
755
        _LOGGER.debug(
×
756
            "Passing through %s as %s",
757
            URL(str(request.url)).name
758
            if URL(str(request.url)).name
759
            else URL(str(request.url)).path,
760
            content_type,
761
        )
762
        return await self._build_response(resp, body=resp.content, content_type=content_type)
×
763

764
    async def start_proxy(
        self, host: Optional[Text] = None, ssl_context: Optional[SSLContext] = None
    ) -> None:
        """Start the proxy web server.

        Builds an aiohttp application with a single catch-all route, sets up
        the runner, picks an open port if none is configured, and begins
        serving. If the proxy url is https but no SSL context was supplied,
        the proxy url is downgraded to http.

        Args:
            host (Optional[Text], optional): The host interface to bind to. Defaults to None which is "0.0.0.0" all interfaces.
            ssl_context (Optional[SSLContext], optional): SSL Context for the server. Defaults to None.
        """
        app = web.Application()
        # Every method and path is funneled through the single proxy handler.
        catch_all_route = web.route("*", "/{tail:.*}", self.all_handler)
        app.add_routes([catch_all_route])
        self.runner = web.AppRunner(app)
        await self.runner.setup()
        if not self.port:
            # No port configured yet; grab an ephemeral open one.
            self._port = get_open_port()
        if ssl_context is None and self._proxy_url.scheme == "https":
            _LOGGER.debug("Proxy url is https but no SSL Context set, downgrading to http")
            self._proxy_url = self._proxy_url.with_scheme("http")
        tcp_site = web.TCPSite(
            runner=self.runner, host=host, port=self.port, ssl_context=ssl_context
        )
        await tcp_site.start()
        self._active = True
        _LOGGER.debug("Started proxy at %s", self.access_url())
790

791
    async def stop_proxy(self, delay: int = 0) -> None:
        """Stop proxy server.

        Gracefully shuts down the aiohttp runner, closes the httpx session,
        and marks the proxy inactive. No-op if the proxy is not running.

        Args:
            delay (int, optional): How many seconds to delay. Defaults to 0.
        """
        if not self.active:
            _LOGGER.debug("Proxy is not started; ignoring stop command")
            return
        _LOGGER.debug("Stopping proxy at %s after %s seconds", self.access_url(), delay)
        await asyncio.sleep(delay)
        _LOGGER.debug("Closing site runner")
        if self.runner:
            # aiohttp's documented lifecycle is shutdown() then cleanup();
            # cleanup() releases the server, so the previous cleanup-then-
            # shutdown order made the shutdown call a no-op.
            await self.runner.shutdown()
            await self.runner.cleanup()
        _LOGGER.debug("Site runner closed")
        # Close the httpx session so pooled connections are released.
        if self.session:
            _LOGGER.debug("Closing session")
            await self.session.aclose()
            _LOGGER.debug("Session closed")
        self._active = False
        _LOGGER.debug("Proxy stopped")
814

815
    def _swap_proxy_and_host(self, text: Text, domain_only: bool = False) -> Text:
        """Replace host with proxy address or proxy with host address.

        Args
            text (Text): text to replace
            domain (bool): Whether only the domains should be swapped.

        Returns
            Text: Result of replacing

        """
        host_root: Text = str(self._host_url.with_path("/"))
        proxy_ref: Text = str(
            self.access_url().with_path("/") if domain_only else self.access_url()
        )
        proxy_root = str(self.access_url().with_path("/"))
        # First normalize any http-scheme copy of the proxy root back to https.
        downgraded_root = proxy_root.replace("https", "http")
        if downgraded_root in text:
            _LOGGER.debug(
                "Replacing %s with %s",
                downgraded_root,
                proxy_root,
            )
            text = text.replace(downgraded_root, proxy_root)
        # Whether the proxy string needs a trailing slash appended to line up
        # with the host root before substitution.
        wants_slash = host_root[-1] == "/" and (
            not proxy_ref or proxy_ref == "/" or proxy_ref[-1] != "/"
        )
        if proxy_ref in text:
            if wants_slash:
                proxy_ref = f"{proxy_ref}/"
            _LOGGER.debug("Replacing %s with %s in %s", proxy_ref, host_root, text)
            return text.replace(proxy_ref, host_root)
        if host_root in text:
            if wants_slash:
                proxy_ref = f"{proxy_ref}/"
            _LOGGER.debug("Replacing %s with %s", host_root, proxy_ref)
            return text.replace(host_root, proxy_ref)
        _LOGGER.debug("Unable to find %s and %s in %s", host_root, proxy_ref, text)
        return text
857

858
    async def modify_headers(self, site: URL, request: web.Request) -> dict:
        """Modify headers.

        Return modified headers based on site and request. To disable auto header generation,
        pass in to the header a key const.SKIP_AUTO_HEADERS with a list of keys to not generate.

        For example, to prevent User-Agent generation: {SKIP_AUTO_HEADERS : ["User-Agent"]}

        Args:
            site (URL): URL of the next host request.
            request (web.Request): Proxy directed request. This will need to be changed for the actual host request.

        Returns:
            dict: Headers after modifications
        """
        headers: Dict[str, Any] = dict(request.headers)
        if headers.get("Host"):
            headers.pop("Host")
        if headers.get("Origin"):
            # Use the configured host URL as Origin for cross-origin requests.
            # Third-party services may validate Origin against the page that loaded them.
            headers["Origin"] = f"{self._host_url.with_path('')}"
        # Remove any cookies in header received from browser. If not removed,
        # httpx will not send session cookies.
        if headers.get("Cookie"):
            headers.pop("Cookie")
        referer = headers.get("Referer", "")
        if referer and (
            URL(referer).query == self.init_query
            or URL(referer).path == "/config/integrations"  # home-assistant referer
        ):
            # Change referer for starting request; this may have query items we shouldn't pass
            headers["Referer"] = str(self._host_url)
        elif referer:
            headers["Referer"] = self._swap_proxy_and_host(referer, domain_only=True)
        # Strip headers that reverse proxies add; they would leak the proxy
        # setup to the upstream host.
        for proxy_header in (
            "Content-Length",
            "X-Forwarded-For",
            "X-Forwarded-Host",
            "X-Forwarded-Port",
            "X-Forwarded-Proto",
            "X-Forwarded-Scheme",
            "X-Forwarded-Server",
            "X-Real-IP",
        ):
            if headers.get(proxy_header):
                headers.pop(proxy_header)
        # User-supplied headers always win.
        headers.update(self.headers or {})
        _LOGGER.debug("Final headers %s", headers)
        return headers
912

913
    def check_redirects(self) -> None:
1✔
914
        """Change host if redirect detected and regex does not match self.redirect_filters.
915

916
        Self.redirect_filters is a dict with key as attr in resp and value as list of regex expressions to filter against.
917
        """
918
        if not self.last_resp:
1✔
UNCOV
919
            return
×
920
        resp: httpx.Response = self.last_resp
1✔
921
        if resp.history:
1✔
UNCOV
922
            for item in resp.history:
×
UNCOV
923
                if (
×
924
                    item.status_code in [301, 302, 303, 304, 305, 306, 307, 308]
925
                    and item.url
926
                    and resp.url
927
                    and resp.url.host != self._host_url.host
928
                ):
UNCOV
929
                    filtered = False
×
UNCOV
930
                    for attr, regex_list in self.redirect_filters.items():
×
UNCOV
931
                        if getattr(resp, attr) and list(
×
932
                            filter(
933
                                lambda regex_string: re.search(
934
                                    regex_string, str(getattr(resp, attr))
935
                                ),
936
                                regex_list,
937
                            )
938
                        ):
UNCOV
939
                            _LOGGER.debug(
×
940
                                "Check_redirects: Filtered out on %s in %s for resp attribute %s",
941
                                list(
942
                                    filter(
943
                                        lambda regex_string: re.search(
944
                                            regex_string, str(getattr(resp, attr))
945
                                        ),
946
                                        regex_list,
947
                                    )
948
                                ),
949
                                str(getattr(resp, attr)),
950
                                attr,
951
                            )
UNCOV
952
                            filtered = True
×
UNCOV
953
                    if filtered:
×
UNCOV
954
                        return
×
UNCOV
955
                    _LOGGER.debug(
×
956
                        "Detected %s redirect from %s to %s; changing proxy host",
957
                        item.status_code,
958
                        item.url.host,
959
                        resp.url.host,
960
                    )
UNCOV
961
                    self._host_url = self._host_url.with_host(resp.url.host)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc