You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
477 lines
15 KiB
477 lines
15 KiB
import codecs |
|
import email.message |
|
import ipaddress |
|
import mimetypes |
|
import os |
|
import re |
|
import time |
|
import typing |
|
from pathlib import Path |
|
from urllib.request import getproxies |
|
|
|
import sniffio |
|
|
|
from ._types import PrimitiveData |
|
|
|
if typing.TYPE_CHECKING: # pragma: no cover |
|
from ._urls import URL |
|
|
|
|
|
# Escape table for serializing names/values inside a multipart form
# `Content-Disposition` header, per the HTML5 form-encoding rules:
# '"' -> %22, '\' -> '\\', and every C0 control character in the range
# 0x00-0x1F except ESC (0x1B) -> its "%XX" hex form.
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
    {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
)
# A single alternation regex matching any character that has an entry in
# the replacement table above; used by `format_form_param`.
_HTML5_FORM_ENCODING_RE = re.compile(
    r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
)
|
|
|
|
|
def normalize_header_key(
    value: typing.Union[str, bytes],
    lower: bool,
    encoding: typing.Optional[str] = None,
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header key.

    Strings are encoded using `encoding` (defaulting to ASCII), and the
    result is lower-cased when `lower` is true.
    """
    if isinstance(value, str):
        raw = value.encode(encoding or "ascii")
    else:
        raw = value

    if lower:
        return raw.lower()
    return raw
|
|
|
|
|
def normalize_header_value(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header value.

    Strings are encoded using `encoding`, defaulting to ASCII.
    """
    return value if isinstance(value, bytes) else value.encode(encoding or "ascii")
|
|
|
|
|
def primitive_value_to_str(value: "PrimitiveData") -> str:
    """
    Coerce a primitive data type into a string value.

    Note that we prefer JSON-style 'true'/'false' for boolean values here,
    and `None` maps to the empty string.
    """
    if value is None:
        return ""
    if value is True:
        return "true"
    if value is False:
        return "false"
    return str(value)
|
|
|
|
|
def is_known_encoding(encoding: str) -> bool:
    """
    Return `True` if `encoding` names a codec that `codecs.lookup` can resolve.
    """
    try:
        codecs.lookup(encoding)
        return True
    except LookupError:
        return False
|
|
|
|
|
def format_form_param(name: str, value: str) -> bytes:
    """
    Encode a name/value pair within a multipart form.

    The value is escaped according to the HTML5 form-encoding rules
    (see `_HTML5_FORM_ENCODING_REPLACEMENTS`) before being rendered as
    `name="value"` and encoded to bytes.
    """
    escaped = _HTML5_FORM_ENCODING_RE.sub(
        lambda match: _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)],
        value,
    )
    return f'{name}="{escaped}"'.encode()
|
|
|
|
|
# Null bytes; no need to recreate these on each call to guess_json_utf
_null = b"\x00"
_null2 = _null * 2  # two consecutive null bytes
_null3 = _null * 3  # three consecutive null bytes
|
|
|
|
|
def guess_json_utf(data: bytes) -> typing.Optional[str]:
    """
    Guess the Unicode encoding of a JSON document from its leading bytes.

    JSON always starts with two ASCII characters, so the position and the
    number of null bytes within the first four bytes is enough to determine
    the encoding. Explicit BOMs are also detected and take precedence.
    Returns a codec name, or `None` when no encoding could be determined.
    """
    head = data[:4]

    # Byte-order marks first.
    if head in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
        return "utf-32"  # BOM included
    if head[:3] == codecs.BOM_UTF8:
        return "utf-8-sig"  # BOM included, MS style (discouraged)
    if head[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
        return "utf-16"  # BOM included

    nullcount = head.count(b"\x00")
    if nullcount == 0:
        return "utf-8"
    if nullcount == 2:
        if head[::2] == b"\x00\x00":  # 1st and 3rd bytes are null
            return "utf-16-be"
        if head[1::2] == b"\x00\x00":  # 2nd and 4th bytes are null
            return "utf-16-le"
        # Did not detect 2 valid UTF-16 ascii-range characters.
    elif nullcount == 3:
        if head[:3] == b"\x00\x00\x00":
            return "utf-32-be"
        if head[1:] == b"\x00\x00\x00":
            return "utf-32-le"
        # Did not detect a valid UTF-32 ascii-range character.
    return None
|
|
|
|
|
def get_ca_bundle_from_env() -> typing.Optional[str]:
    """
    Return a CA bundle location from the `SSL_CERT_FILE` / `SSL_CERT_DIR`
    environment variables.

    `SSL_CERT_FILE` wins when it points at an existing file; otherwise
    `SSL_CERT_DIR` is used when it points at an existing directory.
    Returns `None` when neither applies.
    """
    cert_file = os.environ.get("SSL_CERT_FILE")
    if cert_file is not None:
        ssl_file = Path(cert_file)
        if ssl_file.is_file():
            return str(ssl_file)

    cert_dir = os.environ.get("SSL_CERT_DIR")
    if cert_dir is not None:
        ssl_path = Path(cert_dir)
        if ssl_path.is_dir():
            return str(ssl_path)

    return None
|
|
|
|
|
def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]:
    """
    Returns a list of parsed link headers, for more info see:
    https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
    The generic syntax of those is:
    Link: < uri-reference >; param1=value1; param2="value2"
    So for instance:
    Link: '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;'
    would return
        [
            {"url": "http:/.../front.jpeg", "type": "image/jpeg"},
            {"url": "http://.../back.jpeg"},
        ]
    :param value: HTTP Link entity-header field
    :return: list of parsed link headers
    """
    links: typing.List[typing.Dict[str, str]] = []
    strip_chars = " '\""

    value = value.strip(strip_chars)
    if not value:
        return links

    for segment in re.split(", *<", value):
        # Split off the parameter list, when present.
        if ";" in segment:
            url, params = segment.split(";", 1)
        else:
            url, params = segment, ""

        link = {"url": url.strip("<> '\"")}
        for param in params.split(";"):
            pieces = param.split("=")
            if len(pieces) != 2:
                # Malformed or empty parameter: stop parsing this segment,
                # matching the original try/except ValueError behaviour.
                break
            key, param_value = pieces
            link[key.strip(strip_chars)] = param_value.strip(strip_chars)
        links.append(link)

    return links
|
|
|
|
|
def parse_content_type_charset(content_type: str) -> typing.Optional[str]:
    """
    Extract the `charset` parameter from a `Content-Type` header value,
    returning `None` when no charset is present.
    """
    # We used to use `cgi.parse_header()` here, but `cgi` became a dead battery.
    # See: https://peps.python.org/pep-0594/#cgi
    message = email.message.Message()
    message["content-type"] = content_type
    return message.get_content_charset(failobj=None)
|
|
|
|
|
# Header names (lower-cased) whose values must never appear in logs or repr.
SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
    items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
    """
    Yield the given (key, value) header pairs, replacing the value of any
    sensitive header with the placeholder "[secure]" (coerced to match the
    value's str/bytes type).
    """
    for key, value in items:
        if to_str(key.lower()) in SENSITIVE_HEADERS:
            value = to_bytes_or_str("[secure]", match_type_of=value)
        yield key, value
|
|
|
|
|
def port_or_default(url: "URL") -> typing.Optional[int]:
    """
    Return the URL's explicit port, or the default port for its scheme
    (80 for "http", 443 for "https"), or `None` for other schemes.
    """
    if url.port is None:
        return {"http": 80, "https": 443}.get(url.scheme)
    return url.port
|
|
|
|
|
def same_origin(url: "URL", other: "URL") -> bool:
    """
    Return 'True' if the given URLs share the same origin, i.e. the same
    scheme, host, and (explicit or scheme-default) port.
    """
    if url.scheme != other.scheme:
        return False
    if url.host != other.host:
        return False
    return port_or_default(url) == port_or_default(other)
|
|
|
|
|
def is_https_redirect(url: "URL", location: "URL") -> bool:
    """
    Return 'True' if 'location' is a HTTPS upgrade of 'url': the same host,
    going from plain HTTP on the default port to HTTPS on the default port.
    """
    if url.host != location.host:
        return False
    if url.scheme != "http" or location.scheme != "https":
        return False
    return port_or_default(url) == 80 and port_or_default(location) == 443
|
|
|
|
|
def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
    """
    Build a mapping of proxy mount patterns from the environment.

    Returns `{pattern: proxy_url_or_None}`, where a `None` value marks a
    host that must bypass any proxy.
    """
    # urllib.request.getproxies() falls back on System
    # Registry and Config for proxies on Windows and macOS.
    # We don't want to propagate non-HTTP proxies into
    # our configuration such as 'TRAVIS_APT_PROXY'.
    proxy_info = getproxies()
    mounts: typing.Dict[str, typing.Optional[str]] = {}

    for scheme in ("http", "https", "all"):
        proxy_url = proxy_info.get(scheme)
        if not proxy_url:
            continue
        # Bare "host:port" values get a default http:// scheme.
        if "://" not in proxy_url:
            proxy_url = f"http://{proxy_url}"
        mounts[f"{scheme}://"] = proxy_url

    # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details
    # on how names in `NO_PROXY` are handled.
    for hostname in (host.strip() for host in proxy_info.get("no", "").split(",")):
        if hostname == "*":
            # If NO_PROXY=* is used, or "*" occurs as any one of the comma
            # separated hostnames, ignore HTTP_PROXY, HTTPS_PROXY and
            # ALL_PROXY entirely.
            return {}
        if not hostname:
            continue
        # NO_PROXY entries can be domains, IPv4/IPv6 addresses, or
        # "localhost". Domain entries get a leading "*" wildcard so that:
        #   NO_PROXY=.google.com -> "all://*.google.com"
        #     (disables "www.google.com" but not "google.com")
        #   NO_PROXY=google.com  -> "all://*google.com"
        #     (disables "www.google.com" and "google.com", but not
        #      "wwwgoogle.com")
        # IP addresses and "localhost" are matched exactly.
        if is_ipv4_hostname(hostname):
            mounts[f"all://{hostname}"] = None
        elif is_ipv6_hostname(hostname):
            mounts[f"all://[{hostname}]"] = None
        elif hostname.lower() == "localhost":
            mounts[f"all://{hostname}"] = None
        else:
            mounts[f"all://*{hostname}"] = None

    return mounts
|
|
|
|
|
def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
    """Coerce str/bytes to bytes, encoding strings with `encoding`."""
    if isinstance(value, str):
        return value.encode(encoding)
    return value
|
|
|
|
|
def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
    """Coerce str/bytes to str, decoding bytes with `encoding`."""
    if isinstance(value, bytes):
        return value.decode(encoding)
    return value
|
|
|
|
|
def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
    """Return `value` coerced to the same type (str or bytes) as `match_type_of`."""
    if isinstance(match_type_of, str):
        return value
    return value.encode()
|
|
|
|
|
def unquote(value: str) -> str:
    """
    Strip one pair of surrounding double-quotes from `value`, if present.

    Returns the value unchanged when it is not quoted. The empty string is
    returned as-is (the previous implementation raised IndexError on "").
    """
    # The `value and` guard protects the value[0]/value[-1] indexing.
    if value and value[0] == value[-1] == '"':
        return value[1:-1]
    return value
|
|
|
|
|
def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
    """
    Guess a MIME type from a filename's extension, falling back to
    "application/octet-stream" for unknown extensions. Returns `None`
    when no filename is given.
    """
    if not filename:
        return None
    guessed, _ = mimetypes.guess_type(filename)
    return guessed or "application/octet-stream"
|
|
|
|
|
def peek_filelike_length(stream: typing.Any) -> typing.Optional[int]:
    """
    Given a file-like stream object, return its length in number of bytes
    without reading it into memory.

    Returns `None` when the length cannot be determined.
    """
    # Real files expose a file descriptor we can fstat().
    try:
        return os.fstat(stream.fileno()).st_size
    except (AttributeError, OSError):
        pass

    # Otherwise, for random-access objects such as `io.BytesIO`, seek to
    # the end to discover the length, then restore the original position.
    try:
        position = stream.tell()
        length = stream.seek(0, os.SEEK_END)
        stream.seek(position)
        return length
    except (AttributeError, OSError):
        # Neither approach worked.
        return None
|
|
|
|
|
class Timer:
    """
    Measures elapsed time, with synchronous and asynchronous entry points.

    The async variants read the clock of whichever async library is
    currently running (trio, curio, or asyncio).
    """

    async def _get_time(self) -> float:
        # Pick the clock matching the running event loop.
        library = sniffio.current_async_library()
        if library == "trio":
            import trio

            return trio.current_time()
        if library == "curio":  # pragma: no cover
            import curio

            return typing.cast(float, await curio.clock())

        import asyncio

        return asyncio.get_event_loop().time()

    def sync_start(self) -> None:
        # Record the starting instant for `sync_elapsed`.
        self.started = time.perf_counter()

    async def async_start(self) -> None:
        # Record the starting instant for `async_elapsed`.
        self.started = await self._get_time()

    def sync_elapsed(self) -> float:
        return time.perf_counter() - self.started

    async def async_elapsed(self) -> float:
        return await self._get_time() - self.started
|
|
|
|
|
class URLPattern:
    """
    A utility class currently used for making lookups against proxy keys...

    # Wildcard matching...
    >>> pattern = URLPattern("all")
    >>> pattern.matches(httpx.URL("http://example.com"))
    True

    # With scheme matching...
    >>> pattern = URLPattern("https")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False

    # With domain matching...
    >>> pattern = URLPattern("https://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # Wildcard scheme, with domain matching...
    >>> pattern = URLPattern("all://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    True
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # With port matching...
    >>> pattern = URLPattern("https://example.com:1234")
    >>> pattern.matches(httpx.URL("https://example.com:1234"))
    True
    >>> pattern.matches(httpx.URL("https://example.com"))
    False
    """

    def __init__(self, pattern: str) -> None:
        from ._urls import URL

        if pattern and ":" not in pattern:
            raise ValueError(
                f"Proxy keys should use proper URL forms rather "
                f"than plain scheme strings. "
                f'Instead of "{pattern}", use "{pattern}://"'
            )

        url = URL(pattern)
        self.pattern = pattern
        # "all" is the wildcard scheme; "*" is the wildcard host.
        self.scheme = "" if url.scheme == "all" else url.scheme
        self.host = "" if url.host == "*" else url.host
        self.port = url.port

        host = url.host
        if not host or host == "*":
            # No host restriction.
            self.host_regex: typing.Optional[typing.Pattern[str]] = None
        elif host.startswith("*."):
            # "*.example.com" should match "www.example.com", but not
            # "example.com".
            self.host_regex = re.compile(f"^.+\\.{re.escape(host[2:])}$")
        elif host.startswith("*"):
            # "*example.com" should match "www.example.com" and "example.com".
            self.host_regex = re.compile(f"^(.+\\.)?{re.escape(host[1:])}$")
        else:
            # "example.com" should match "example.com" but not
            # "www.example.com".
            self.host_regex = re.compile(f"^{re.escape(host)}$")

    def matches(self, other: "URL") -> bool:
        """Return `True` when `other` falls within this pattern."""
        if self.scheme and other.scheme != self.scheme:
            return False
        if self.host and self.host_regex is not None:
            if not self.host_regex.match(other.host):
                return False
        if self.port is not None and other.port != self.port:
            return False
        return True

    @property
    def priority(self) -> typing.Tuple[int, int, int]:
        """
        The priority allows URLPattern instances to be sortable, so that
        we can match from most specific to least specific.
        """
        # Explicit ports take priority over no port; longer hostnames and
        # longer schemes should match first, hence the negated lengths.
        return (
            0 if self.port is not None else 1,
            -len(self.host),
            -len(self.scheme),
        )

    def __hash__(self) -> int:
        return hash(self.pattern)

    def __lt__(self, other: "URLPattern") -> bool:
        return self.priority < other.priority

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, URLPattern) and self.pattern == other.pattern
|
|
|
|
|
def is_ipv4_hostname(hostname: str) -> bool:
    """
    Return `True` if `hostname` — ignoring any "/prefix" CIDR suffix —
    parses as an IPv4 address.
    """
    address = hostname.split("/")[0]
    try:
        ipaddress.IPv4Address(address)
        return True
    except Exception:
        return False
|
|
|
|
|
def is_ipv6_hostname(hostname: str) -> bool:
    """
    Return `True` if `hostname` — ignoring any "/prefix" CIDR suffix —
    parses as an IPv6 address.
    """
    address = hostname.split("/")[0]
    try:
        ipaddress.IPv6Address(address)
        return True
    except Exception:
        return False
|
|
|