• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alandtse / auth_capture_proxy / 21736981058

06 Feb 2026 02:56AM UTC coverage: 50.687% (+9.4%) from 41.243%
21736981058

push

github

web-flow
fix: header duplication & JSON parsing (#37)

15 of 37 new or added lines in 1 file covered. (40.54%)

108 existing lines in 1 file now uncovered.

369 of 728 relevant lines covered (50.69%)

0.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

39.3
/authcaptureproxy/auth_capture_proxy.py
1
#  SPDX-License-Identifier: Apache-2.0
2
"""Python Package for auth capture proxy."""
3
import asyncio
1✔
4
import logging
1✔
5
import re
1✔
6
from json import JSONDecodeError
1✔
7
from functools import partial
1✔
8
from ssl import SSLContext, create_default_context
1✔
9
from typing import Any, Callable, Dict, List, Optional, Set, Text, Tuple, Union
1✔
10

11
import httpx
1✔
12
from aiohttp import (
1✔
13
    MultipartReader,
1✔
14
    MultipartWriter,
1✔
15
    hdrs,
16
    web,
1✔
17
)
1✔
18
from multidict import CIMultiDict
19
from yarl import URL
20

21
from authcaptureproxy.const import SKIP_AUTO_HEADERS
22
from authcaptureproxy.examples.modifiers import (
1✔
23
    prepend_relative_urls,
24
    replace_empty_action_urls,
25
    replace_matching_urls,
26
)
27
from authcaptureproxy.helper import (
28
    convert_multidict_to_dict,
29
    get_content_type,
30
    get_nested_dict_keys,
1✔
31
    print_resp,
32
    run_func,
33
    swap_url,
1✔
34
)
35
from authcaptureproxy.stackoverflow import get_open_port
1✔
36

37
# Pre-configure SSL context once at import time; it is shared by every default
# httpx.AsyncClient produced by AuthCaptureProxy.session_factory (verify=ssl_context).
ssl_context = create_default_context()

# Module-level logger; all proxy lifecycle and request handling goes through it.
_LOGGER = logging.getLogger(__name__)

42

43
class AuthCaptureProxy:
44
    """Class to handle proxy login connections.
1✔
45

46
    This class relies on tests to be provided to indicate the proxy has completed. At proxy completion all data can be found in self.session, self.data, and self.query.
47
    """
48

49
    def __init__(self,
        proxy_url: URL,
        host_url: URL,
        session: Optional[httpx.AsyncClient] = None,
        session_factory: Optional[Callable[[], httpx.AsyncClient]] = None,
        preserve_headers: bool = False,
    ) -> None:
        """Initialize proxy object.

        Args:
            proxy_url (URL): url for proxy location. e.g., http://192.168.1.1/.
                If there is any path, the path is considered part of the base url.
                If no explicit port is specified, a random port will be generated.
                If https is passed in, ssl_context must be provided at start_proxy() or the url will be downgraded to http.
            host_url (URL): original url for login, e.g., http://amazon.com
            session (httpx.AsyncClient): httpx client to make queries. Optional
            session_factory (lambda: httpx.AsyncClient): factory to create the aforementioned httpx client if having one fixed session is insufficient.
            preserve_headers (bool): Whether to preserve headers from the backend. Useful in circumventing CSRF protection. Defaults to False.
        """

        def _default_client() -> httpx.AsyncClient:
            # Default factory: module-level ssl_context plus explicit timeouts.
            return httpx.AsyncClient(
                verify=ssl_context,
                timeout=httpx.Timeout(
                    connect=10.0,
                    read=30.0,
                    write=10.0,
                    pool=10.0,
                ),
            )

        # Session management.
        self._preserve_headers = preserve_headers
        self.session_factory: Callable[[], httpx.AsyncClient] = session_factory or _default_client
        self.session: httpx.AsyncClient = session if session else self.session_factory()

        # URLs and port; port 0 means "assign an open port at start_proxy()".
        self._proxy_url: URL = proxy_url
        self._host_url: URL = host_url
        self._port: int = proxy_url.explicit_port if proxy_url.explicit_port else 0  # type: ignore

        # Captured state from the login flow.
        self.runner: Optional[web.AppRunner] = None
        self.last_resp: Optional[httpx.Response] = None
        self.init_query: Dict[Text, Any] = {}
        self.query: Dict[Text, Any] = {}
        self.data: Dict[Text, Any] = {}

        # tests and modifiers should be initialized after port is actually assigned and not during init.
        # however, to ensure defaults go first, they should have a dummy key set
        self._tests: Dict[Text, Callable] = {}
        self._modifiers: Dict[Text, Union[Callable, Dict[Text, Callable]]] = {
            "text/html": {
                "prepend_relative_urls": lambda x: x,
                "change_host_to_proxy": lambda x: x,
            }
        }
        self._old_tests: Dict[Text, Callable] = {}
        self._old_modifiers: Dict[Text, Union[Callable, Dict[Text, Callable]]] = {}

        # Lifecycle flags and misc.
        self._active = False
        self._all_handler_active = True
        self.headers: Dict[Text, Text] = {}
        self.redirect_filters: Dict[Text, List[Text]] = {
            "url": []
        }  # dictionary of lists of regex strings to filter against
        self._background_tasks: Set[asyncio.Task] = set()
108
    @property
    def active(self) -> bool:
        """Return whether proxy is started.

        Set True by start_proxy() and back to False by stop_proxy() and reset_data().
        """
        return self._active
113
    @property
    def all_handler_active(self) -> bool:
        """Return whether all handler is active.

        While False, all_handler() raises HTTPNotFound instead of proxying.
        """
        return self._all_handler_active

    @all_handler_active.setter
    def all_handler_active(self, value: bool) -> None:
        """Set all handler to value."""
        self._all_handler_active = value
123
    @property
    def port(self) -> int:
        """Return port setting.

        0 until a port is assigned; start_proxy() picks an open port when unset.
        """
        return self._port
    @property
    def tests(self) -> Dict[Text, Callable]:
        """Return tests setting.

        :setter: value (Dict[Text, Any]): A dictionary of tests. The key should be the name of the test and the value should be a function or coroutine that takes a httpx.Response, a dictionary of post variables, and a dictionary of query variables and returns a URL or string. See :mod:`authcaptureproxy.examples.testers` for examples.
        """
        return self._tests

    @tests.setter
    def tests(self, value: Dict[Text, Callable]) -> None:
        """Set tests.

        Args:
            value (Dict[Text, Any]): A dictionary of tests keyed by test name.
        """
        self.refresh_tests()  # refresh in case of pending change
        # Snapshot the outgoing tests so refresh_tests() can detect the change.
        self._old_tests = self._tests.copy()
        self._tests = value
    @property
    def modifiers(self) -> Dict[Text, Union[Callable, Dict[Text, Callable]]]:
        """Return modifiers setting.

        :setter: value (Dict[Text, Dict[Text, Callable]): A nested dictionary of modifiers. The key should be a MIME type and the value should be a dictionary of modifiers for that MIME type where the key should be the name of the modifier and the value should be a function or coroutine that takes a string and returns a modified string. If parameters are necessary, functools.partial should be used. See :mod:`authcaptureproxy.examples.modifiers` for examples.
        """
        return self._modifiers

    @modifiers.setter
    def modifiers(self, value: Dict[Text, Union[Callable, Dict[Text, Callable]]]) -> None:
        """Set modifiers.

        Args:
            value (Dict[Text, Any]): A nested dictionary of modifiers keyed by MIME type.
        """
        self.refresh_modifiers()  # refresh in case of pending change
        # Copy for consistency with the tests setter and refresh_modifiers(), so the
        # saved snapshot cannot alias a dict the caller later mutates in place.
        self._old_modifiers = self._modifiers.copy()
        self._modifiers = value
    def access_url(self) -> URL:
        """Return the externally reachable proxy URL, including the bound port.

        While the port is still unassigned (0), the bare proxy URL is returned.
        """
        if self.port == 0:
            return self._proxy_url
        return self._proxy_url.with_port(self.port)
    async def change_host_url(self, new_url: URL) -> None:
        """Change the host url of the proxy.

        This will also reset all stored data.

        Args:
            new_url (URL): original url for login, e.g., http://amazon.com

        Raises:
            ValueError: If new_url is not a yarl URL.
        """
        # Guard clause: only yarl URLs are accepted.
        if not isinstance(new_url, URL):
            raise ValueError("URL required")
        self._host_url = new_url
        await self.reset_data()
    async def reset_data(self) -> None:
        """Reset all stored data.

        A proxy may need to service multiple login requests if the route is not torn down. This function will reset all data between logins.
        """
        # Close the old client before swapping in a fresh one from the factory.
        old_session = self.session
        if old_session:
            await old_session.aclose()
        self.session = self.session_factory()
        self.last_resp = None
        # Clear all captured login state.
        self.init_query = {}
        self.query = {}
        self.data = {}
        # Return lifecycle flags to their pre-start values.
        self._active = False
        self._all_handler_active = True
        _LOGGER.debug("Proxy data reset.")
    def refresh_tests(self) -> None:
        """Refresh tests.

        Because tests may use partials, they will freeze their parameters which is a problem with self.access() if the port hasn't been assigned.
        """
        if self._tests != self._old_tests:
            # Snapshot into _old_tests — the attribute the dirty-check above reads.
            # The previous code assigned to a stray public ``old_tests`` attribute,
            # so the check never reset and this branch fired on every call.
            # (A no-op ``self.tests.update({})`` was also removed.)
            self._old_tests = self._tests.copy()
            _LOGGER.debug("Refreshed %s tests: %s", len(self.tests), list(self.tests.keys()))
    def refresh_modifiers(self, site: Optional[URL] = None) -> None:
        """Refresh modifiers.

        Because modifiers may use partials, they will freeze their parameters which is a problem with self.access() if the port hasn't been assigned.

        Args:
            site (Optional[URL], optional): The current site. Defaults to None.
        """
        # Built fresh each call so the partials capture the *current* access_url()
        # (the port may have been assigned since the last refresh).
        DEFAULT_MODIFIERS = {  # noqa: N806
            "prepend_relative_urls": partial(prepend_relative_urls, self.access_url()),
            "change_host_to_proxy": partial(
                replace_matching_urls,
                self._host_url.with_query({}).with_path("/"),
                self.access_url(),
            ),
        }
        # NOTE(review): functools.partial objects never compare equal, so once any
        # partial is stored this inequality is presumably always True — confirm.
        if self._modifiers != self._old_modifiers:
            if self.modifiers.get("text/html") is None:
                self.modifiers["text/html"] = DEFAULT_MODIFIERS  # type: ignore
            elif self.modifiers.get("text/html") and isinstance(self.modifiers["text/html"], dict):
                # Defaults overwrite same-named keys (including the dummy lambdas
                # installed by __init__).
                self.modifiers["text/html"].update(DEFAULT_MODIFIERS)
            if site and isinstance(self.modifiers["text/html"], dict):
                # Rewrite empty form action URLs to point back at the proxy for
                # the page currently being served.
                self.modifiers["text/html"].update(
                    {
                        "change_empty_to_proxy": partial(
                            replace_empty_action_urls,
                            swap_url(
                                old_url=self._host_url.with_query({}),
                                new_url=self.access_url().with_query({}),
                                url=site,
                            ),
                        ),
                    }
                )
            # Shallow copy: nested per-MIME dicts remain shared with _modifiers.
            self._old_modifiers = self.modifiers.copy()
            refreshed_modifers = get_nested_dict_keys(self.modifiers)
            _LOGGER.debug("Refreshed %s modifiers: %s", len(refreshed_modifers), refreshed_modifers)
    async def _build_response(
        self, response: Optional[httpx.Response] = None, *args, **kwargs
    ) -> web.Response:
        """
        Build an aiohttp response, sanitizing headers copied from the backend.
        """
        if "headers" not in kwargs and response is not None:
            headers = response.headers.copy() if self._preserve_headers else CIMultiDict()
            kwargs["headers"] = headers

            # Content-Type is supplied separately through the content_type kwarg;
            # keeping both would duplicate the header.
            if hdrs.CONTENT_TYPE in headers and "content_type" in kwargs:
                del headers[hdrs.CONTENT_TYPE]

            # Drop framing/encoding headers — aiohttp recomputes these for the
            # (possibly modified) body — plus a server-specific connection header.
            for header_name in (
                hdrs.CONTENT_LENGTH,
                hdrs.CONTENT_ENCODING,
                hdrs.CONTENT_TRANSFER_ENCODING,
                hdrs.TRANSFER_ENCODING,
                "x-connection-hash",
            ):
                if header_name in headers:
                    del headers[header_name]

            # Strip every Set-Cookie occurrence (the header may repeat).
            while hdrs.SET_COOKIE in headers:
                del headers[hdrs.SET_COOKIE]

            # cache control: force revalidation so login pages are never cached.
            if hdrs.CACHE_CONTROL in headers:
                del headers[hdrs.CACHE_CONTROL]
            headers[hdrs.CACHE_CONTROL] = "no-cache, no-store, must-revalidate"

        return web.Response(*args, **kwargs)
    async def all_handler(self, request: web.Request, **kwargs) -> web.Response:
        """Handle all requests.

        This handler will exit on successful test found in self.tests or if a /stop url is seen. This handler can be used with any aiohttp webserver and disabled after registered using self.all_handler_active.

        Args
            request (web.Request): The request to process
            **kwargs: Additional keyword arguments
                access_url (URL): The access url for the proxy. Defaults to self.access_url()
                host_url (URL): The host url for the proxy. Defaults to self._host_url

        Returns
            web.Response: The webresponse to the browser

        Raises
            web.HTTPFound: Redirect URL upon success
            web.HTTPNotFound: Return 404 when all_handler is disabled

        """
        # Allow callers to override the proxy/host URLs per request.
        if "access_url" in kwargs:
            access_url = kwargs.pop("access_url")
        else:
            access_url = self.access_url()

        if "host_url" in kwargs:
            host_url = kwargs.pop("host_url")
        else:
            host_url = self._host_url

        async def _process_multipart(reader: MultipartReader, writer: MultipartWriter) -> None:
            """Process multipart.

            Re-encodes each part into ``writer`` for forwarding; parts without a
            Content-Type header are also captured into self.data by name/filename.

            Args:
                reader (MultipartReader): Response multipart to process.
                writer (MultipartWriter): Multipart to write out.
            """
            while True:
                part = await reader.next()  # noqa: B305
                # https://github.com/PyCQA/flake8-bugbear/issues/59
                if part is None:
                    break
                if isinstance(part, MultipartReader):
                    # Nested multipart: recurse into it.
                    await _process_multipart(part, writer)
                elif hdrs.CONTENT_TYPE in part.headers:
                    content_type = part.headers.get(hdrs.CONTENT_TYPE, "")
                    # Drop any "; charset=..." parameters before comparing.
                    mime_type = content_type.split(";", 1)[0].strip()
                    if mime_type == "application/json":
                        try:
                            part_data: Optional[
                                Union[Text, Dict[Text, Any], List[Tuple[Text, Text]], bytes]
                            ] = await part.json()
                            writer.append_json(part_data)
                        except Exception:
                            # Best-effort fallback: text, then bytes
                            try:
                                part_text = await part.text()
                                writer.append(part_text)
                            except Exception:
                                part_data = await part.read()
                                writer.append(part_data)
                    elif mime_type.startswith("text"):
                        part_data = await part.text()
                        writer.append(part_data)
                    elif mime_type == "application/x-www-form-urlencoded":
                        part_data = await part.form()
                        writer.append_form(part_data)
                    else:
                        part_data = await part.read()
                        writer.append(part_data)
                else:
                    # Untyped part: capture it as login data keyed by field name.
                    part_data = await part.read()
                    if part.name:
                        self.data.update({part.name: part_data})
                    elif part.filename:
                        part_data = await part.read()
                        self.data.update({part.filename: part_data})
                    writer.append(part_data)

        if not self.all_handler_active:
            _LOGGER.debug("%s all_handler is disabled; returning 404.", self)
            raise web.HTTPNotFound()
        # if not self.session:
        #     self.session = httpx.AsyncClient()
        method = request.method.lower()
        _LOGGER.debug("Received %s: %s for %s", method, str(request.url), host_url)
        resp: Optional[httpx.Response] = None
        # Preserve the hostname the browser actually used (e.g. behind a
        # reverse proxy) when mapping proxy URLs back to host URLs.
        old_url: URL = (
            access_url.with_host(request.url.host)
            if request.url.host and request.url.host != access_url.host
            else access_url
        )
        if request.scheme == "http" and access_url.scheme == "https":
            # detect reverse proxy downgrade
            _LOGGER.debug("Detected http while should be https; switching to https")
            site: str = str(
                swap_url(
                    ignore_query=True,
                    old_url=old_url.with_scheme("https"),
                    new_url=host_url.with_path("/"),
                    url=URL(str(request.url)).with_scheme("https"),
                ),
            )
        else:
            site = str(
                swap_url(
                    ignore_query=True,
                    old_url=old_url,
                    new_url=host_url.with_path("/"),
                    url=URL(str(request.url)),
                ),
            )
        self.query.update(request.query)
        # Decode the request body: multipart is re-encoded via _process_multipart,
        # everything else is read as form data.
        data: Optional[Dict] = None
        mpwriter = None
        if request.content_type == "multipart/form-data":
            mpwriter = MultipartWriter()
            await _process_multipart(await request.multipart(), mpwriter)
        else:
            data = convert_multidict_to_dict(await request.post())
        json_data = None
        # Only attempt JSON decoding for JSON requests; avoid raising for form posts.
        if request.has_body and (
            request.content_type == "application/json"
            or request.content_type.endswith("+json")
        ):
            try:
                json_data = await request.json()
            except (JSONDecodeError, ValueError):
                json_data = None
        if data:
            self.data.update(data)
            _LOGGER.debug("Storing data %s", data)
        elif json_data:
            self.data.update(json_data)
            _LOGGER.debug("Storing json %s", json_data)
        # Special control paths: <proxy>/stop shuts the proxy down,
        # <proxy>/resume replays the last backend response.
        if URL(str(request.url)).path == re.sub(
            r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/stop").path
        ):
            self.all_handler_active = False
            if self.active:
                # Keep a strong reference so the task isn't garbage-collected early.
                task = asyncio.create_task(self.stop_proxy(3))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
            return await self._build_response(text="Proxy stopped.")
        elif (
            URL(str(request.url)).path
            == re.sub(r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/resume").path)
            and self.last_resp
            and isinstance(self.last_resp, httpx.Response)
        ):
            self.init_query = self.query.copy()
            _LOGGER.debug("Resuming request: %s", self.last_resp)
            resp = self.last_resp
        else:
            if URL(str(request.url)).path in [
                self._proxy_url.path,
                re.sub(
                    r"/+", "/", self._proxy_url.with_path(f"{self._proxy_url.path}/resume").path
                ),
            ]:
                # either base path or resume without anything to resume
                site = str(URL(host_url))
                if method == "get":
                    self.init_query = self.query.copy()
                    _LOGGER.debug(
                        "Starting auth capture proxy for %s",
                        host_url,
                    )
            headers = await self.modify_headers(URL(site), request)
            skip_auto_headers: List[str] = headers.get(SKIP_AUTO_HEADERS, [])
            if skip_auto_headers:
                _LOGGER.debug("Discovered skip_auto_headers %s", skip_auto_headers)
                headers.pop(SKIP_AUTO_HEADERS)
            # Avoid accidental header mutation across branches/calls
            req_headers: dict[str, Any] = dict(headers)
            _LOGGER.debug(
                "Attempting %s to %s\nheaders: %s \ncookies: %s",
                method,
                site,
                req_headers,
                self.session.cookies.jar,
            )
            try:
                # Dispatch to the matching httpx verb method (get/post/...).
                if mpwriter:
                    resp = await getattr(self.session, method)(
                        site, data=mpwriter, headers=req_headers, follow_redirects=True
                    )
                elif data:
                    resp = await getattr(self.session, method)(
                        site, data=data, headers=req_headers, follow_redirects=True
                    )
                elif json_data:
                    for item in ["Host", "Origin", "User-Agent", "dnt", "Accept-Encoding"]:
                        # remove proxy headers
                        if req_headers.get(item):
                            req_headers.pop(item)
                    resp = await getattr(self.session, method)(
                        site, json=json_data, headers=req_headers, follow_redirects=True
                    )
                else:
                    resp = await getattr(self.session, method)(
                        site, headers=req_headers, follow_redirects=True
                    )
            except httpx.ConnectError as ex:
                return await self._build_response(
                    text=f"Error connecting to {site}; please retry: {ex}"
                )
            except httpx.TooManyRedirects as ex:
                return await self._build_response(
                    text=f"Error connecting to {site}; too many redirects: {ex}"
                )
            except httpx.TimeoutException as ex:
                _LOGGER.warning(
                    "Timeout during proxy request to %s: %s",
                    site,
                    ex.__class__.__name__,
                )
                return await self._build_response(
                    text=(
                        "Timed out while contacting the service during login.\n\n"
                        "This is usually caused by slow or blocked network access. "
                        "Please retry, or check DNS resolution, firewall rules, proxy/VPN settings, "
                        "and that the service endpoint is reachable from this host."
                    )
                )
            except httpx.HTTPError as ex:
                return await self._build_response(
                    text=f"Error connecting to {site}: {ex}"
                )
        if resp is None:
            return await self._build_response(text=f"Error connecting to {site}; please retry")
        self.last_resp = resp
        print_resp(resp)
        # NOTE(review): check_redirects is defined elsewhere in this file —
        # presumably applies self.redirect_filters; confirm before relying on it.
        self.check_redirects()
        self.refresh_tests()
        # Run success tests: a URL result redirects the browser, a string result
        # is rendered as an HTML page.
        if self.tests:
            for test_name, test in self.tests.items():
                result = None
                result = await run_func(test, test_name, resp, self.data, self.query)
                if result:
                    _LOGGER.debug("Test %s triggered", test_name)
                    if isinstance(result, URL):
                        _LOGGER.debug(
                            "Redirecting to callback: %s",
                            result,
                        )
                        raise web.HTTPFound(location=result)
                    elif isinstance(result, str):
                        _LOGGER.debug("Displaying page:\n%s", result)
                        return await self._build_response(
                            resp, text=result, content_type="text/html"
                        )
        else:
            _LOGGER.warning("Proxy has no tests; please set.")
        content_type = get_content_type(resp)
        self.refresh_modifiers(URL(str(resp.url)))
        # Apply modifiers: only content types with registered modifiers (or
        # text/html for the bare-callable defaults) get their text rewritten.
        if self.modifiers:
            modified: bool = False
            if content_type != "text/html" and content_type not in self.modifiers.keys():
                text: Text = ""
            elif content_type != "text/html" and content_type in self.modifiers.keys():
                text = resp.text
            else:
                text = resp.text
            if not isinstance(text, str):  # process aiohttp text
                text = await resp.text()
            if text:
                for name, modifier in self.modifiers.items():
                    if isinstance(modifier, dict):
                        # Nested dict: keyed by MIME type, applies only on match.
                        if name != content_type:
                            continue
                        for sub_name, sub_modifier in modifier.items():
                            try:
                                text = await run_func(sub_modifier, sub_name, text)
                                modified = True
                            except TypeError as ex:
                                _LOGGER.warning("Modifier %s is not callable: %s", sub_name, ex)
                    else:
                        # default run against text/html only
                        if content_type == "text/html":
                            try:
                                text = await run_func(modifier, name, text)
                                modified = True
                            except TypeError as ex:
                                _LOGGER.warning("Modifier %s is not callable: %s", name, ex)
                # _LOGGER.debug("Returning modified text:\n%s", text)
            if modified:
                return await self._build_response(
                    resp,
                    text=text,
                    content_type=content_type,
                )
        # pass through non parsed content
        _LOGGER.debug(
            "Passing through %s as %s",
            URL(str(request.url)).name
            if URL(str(request.url)).name
            else URL(str(request.url)).path,
            content_type,
        )
        return await self._build_response(resp, body=resp.content, content_type=content_type)
    async def start_proxy(
        self, host: Optional[Text] = None, ssl_context: Optional[SSLContext] = None
    ) -> None:
        """Start proxy.

        Args:
            host (Optional[Text], optional): The host interface to bind to. Defaults to None which is "0.0.0.0" all interfaces.
            ssl_context (Optional[SSLContext], optional): SSL Context for the server. Defaults to None.
        """
        # Route every method and path into all_handler.
        app = web.Application()
        app.add_routes([web.route("*", "/{tail:.*}", self.all_handler)])
        runner = web.AppRunner(app)
        self.runner = runner
        await runner.setup()
        # Lazily pick an open port if none was specified in proxy_url.
        if not self.port:
            self._port = get_open_port()
        # Without a server-side SSL context we cannot actually serve https.
        if self._proxy_url.scheme == "https" and ssl_context is None:
            _LOGGER.debug("Proxy url is https but no SSL Context set, downgrading to http")
            self._proxy_url = self._proxy_url.with_scheme("http")
        site = web.TCPSite(runner=runner, host=host, port=self.port, ssl_context=ssl_context)
        await site.start()
        self._active = True
        _LOGGER.debug("Started proxy at %s", self.access_url())
    async def stop_proxy(self, delay: int = 0) -> None:
        """Stop proxy server.

        Args:
            delay (int, optional): How many seconds to delay. Defaults to 0.
        """
        # Guard clause: nothing to do when the proxy never started.
        if not self.active:
            _LOGGER.debug("Proxy is not started; ignoring stop command")
            return
        _LOGGER.debug("Stopping proxy at %s after %s seconds", self.access_url(), delay)
        await asyncio.sleep(delay)
        _LOGGER.debug("Closing site runner")
        runner = self.runner
        if runner:
            await runner.cleanup()
            await runner.shutdown()
        _LOGGER.debug("Site runner closed")
        # close session
        if self.session:
            _LOGGER.debug("Closing session")
            await self.session.aclose()
            _LOGGER.debug("Session closed")
        self._active = False
        _LOGGER.debug("Proxy stopped")
    def _swap_proxy_and_host(self, text: Text, domain_only: bool = False) -> Text:
        """Replace host with proxy address or proxy with host address.

        Args
            text (Text): text to replace
            domain (bool): Whether only the domains should be swapped.

        Returns
            Text: Result of replacing

        """
        host_string: Text = str(self._host_url.with_path("/"))
        access_root = str(self.access_url().with_path("/"))
        downgraded_root = access_root.replace("https", "http")
        proxy_string: Text = str(
            self.access_url().with_path("/") if domain_only else self.access_url()
        )

        # Undo a reverse-proxy scheme downgrade before swapping.
        if downgraded_root in text:
            _LOGGER.debug(
                "Replacing %s with %s",
                downgraded_root,
                access_root,
            )
            text = text.replace(downgraded_root, access_root)

        def _with_host_slash(candidate: Text) -> Text:
            # Mirror the host string's trailing slash on the proxy string
            # when the proxy string lacks one.
            if host_string[-1] == "/" and (
                not candidate or candidate == "/" or candidate[-1] != "/"
            ):
                return f"{candidate}/"
            return candidate

        if proxy_string in text:
            proxy_string = _with_host_slash(proxy_string)
            _LOGGER.debug("Replacing %s with %s in %s", proxy_string, host_string, text)
            return text.replace(proxy_string, host_string)
        if host_string in text:
            proxy_string = _with_host_slash(proxy_string)
            _LOGGER.debug("Replacing %s with %s", host_string, proxy_string)
            return text.replace(host_string, proxy_string)
        _LOGGER.debug("Unable to find %s and %s in %s", host_string, proxy_string, text)
        return text
    async def modify_headers(self, site: URL, request: web.Request) -> dict:
682
        """Modify headers.
683

684
        Return modified headers based on site and request. To disable auto header generation,
685
        pass in to the header a key const.SKIP_AUTO_HEADERS with a list of keys to not generate.
686

687
        For example, to prevent User-Agent generation: {SKIP_AUTO_HEADERS : ["User-Agent"]}
688

689
        Args:
UNCOV
690
            site (URL): URL of the next host request.
×
UNCOV
691
            request (web.Request): Proxy directed request. This will need to be changed for the actual host request.
×
692

UNCOV
693
        Returns:
×
UNCOV
694
            dict: Headers after modifications
×
UNCOV
695
        """
×
696
        result: Dict[str, Any] = {}
×
697
        result.update(request.headers)
UNCOV
698
        # _LOGGER.debug("Original headers %s", headers)
×
699
        if result.get("Host"):
×
700
            result.pop("Host")
×
701
        if result.get("Origin"):
702
            result["Origin"] = f"{site.with_path('')}"
703
        # remove any cookies in header received from browser. If not removed, httpx will not send session cookies
704
        if result.get("Cookie"):
705
            result.pop("Cookie")
706
        if result.get("Referer") and (
×
UNCOV
707
            URL(result.get("Referer", "")).query == self.init_query
×
UNCOV
708
            or URL(result.get("Referer", "")).path
×
709
            == "/config/integrations"  # home-assistant referer
710
        ):
UNCOV
711
            # Change referer for starting request; this may have query items we shouldn't pass
×
712
            result["Referer"] = str(self._host_url)
713
        elif result.get("Referer"):
714
            result["Referer"] = self._swap_proxy_and_host(
715
                result.get("Referer", ""), domain_only=True
716
            )
717
        for item in [
718
            "Content-Length",
719
            "X-Forwarded-For",
720
            "X-Forwarded-Host",
721
            "X-Forwarded-Port",
UNCOV
722
            "X-Forwarded-Proto",
×
UNCOV
723
            "X-Forwarded-Scheme",
×
UNCOV
724
            "X-Forwarded-Server",
×
UNCOV
725
            "X-Real-IP",
×
UNCOV
726
        ]:
×
727
            # remove proxy headers
728
            if result.get(item):
1✔
729
                result.pop(item)
730
        result.update(self.headers if self.headers else {})
731
        _LOGGER.debug("Final headers %s", result)
732
        return result
733

1✔
UNCOV
734
    def check_redirects(self) -> None:
×
735
        """Change host if redirect detected and regex does not match self.redirect_filters.
1✔
736

1✔
UNCOV
737
        Self.redirect_filters is a dict with key as attr in resp and value as list of regex expressions to filter against.
×
UNCOV
738
        """
×
739
        if not self.last_resp:
740
            return
741
        resp: httpx.Response = self.last_resp
742
        if resp.history:
743
            for item in resp.history:
744
                if (
×
UNCOV
745
                    item.status_code in [301, 302, 303, 304, 305, 306, 307, 308]
×
UNCOV
746
                    and item.url
×
747
                    and resp.url
748
                    and resp.url.host != self._host_url.host
749
                ):
750
                    filtered = False
751
                    for attr, regex_list in self.redirect_filters.items():
752
                        if getattr(resp, attr) and list(
753
                            filter(
UNCOV
754
                                lambda regex_string: re.search(
×
755
                                    regex_string, str(getattr(resp, attr))
756
                                ),
757
                                regex_list,
758
                            )
759
                        ):
760
                            _LOGGER.debug(
761
                                "Check_redirects: Filtered out on %s in %s for resp attribute %s",
762
                                list(
763
                                    filter(
764
                                        lambda regex_string: re.search(
765
                                            regex_string, str(getattr(resp, attr))
766
                                        ),
UNCOV
767
                                        regex_list,
×
UNCOV
768
                                    )
×
UNCOV
769
                                ),
×
UNCOV
770
                                str(getattr(resp, attr)),
×
771
                                attr,
772
                            )
773
                            filtered = True
774
                    if filtered:
775
                        return
776
                    _LOGGER.debug(
×
777
                        "Detected %s redirect from %s to %s; changing proxy host",
778
                        item.status_code,
779
                        item.url.host,
780
                        resp.url.host,
781
                    )
782
                    self._host_url = self._host_url.with_host(resp.url.host)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc