You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
414 lines
16 KiB
414 lines
16 KiB
# Copyright 2019-present MongoDB, Inc. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); you |
|
# may not use this file except in compliance with the License. You |
|
# may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
# implied. See the License for the specific language governing |
|
# permissions and limitations under the License. |
|
|
|
"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's |
|
context. |
|
""" |
|
from __future__ import annotations |
|
|
|
import socket as _socket |
|
import ssl as _stdlibssl |
|
import sys as _sys |
|
import time as _time |
|
from errno import EINTR as _EINTR |
|
from ipaddress import ip_address as _ip_address |
|
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, Union |
|
|
|
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate |
|
from OpenSSL import SSL as _SSL |
|
from OpenSSL import crypto as _crypto |
|
from service_identity import CertificateError as _SICertificateError |
|
from service_identity import VerificationError as _SIVerificationError |
|
from service_identity.pyopenssl import verify_hostname as _verify_hostname |
|
from service_identity.pyopenssl import verify_ip_address as _verify_ip_address |
|
|
|
from pymongo.errors import ConfigurationError as _ConfigurationError |
|
from pymongo.errors import _CertificateError |
|
from pymongo.ocsp_cache import _OCSPCache |
|
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback |
|
from pymongo.socket_checker import SocketChecker as _SocketChecker |
|
from pymongo.socket_checker import _errno_from_exception |
|
from pymongo.write_concern import validate_boolean |
|
|
|
if TYPE_CHECKING: |
|
from ssl import VerifyMode |
|
|
|
from cryptography.x509 import Certificate |
|
|
|
_T = TypeVar("_T") |
|
|
|
try: |
|
import certifi |
|
|
|
_HAVE_CERTIFI = True |
|
except ImportError: |
|
_HAVE_CERTIFI = False |
|
|
|
# Protocol constant re-exported from PyOpenSSL.
PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD

# Option flags re-exported from PyOpenSSL. These are always available.
OP_NO_SSLv2 = _SSL.OP_NO_SSLv2
OP_NO_SSLv3 = _SSL.OP_NO_SSLv3
OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION
# This isn't currently documented for PyOpenSSL, so fall back to 0 (no
# option bits) when the attribute is missing.
OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0)

# Always available
HAS_SNI = True
IS_PYOPENSSL = True

# Base Exception class
SSLError = _SSL.Error

# Maps the stdlib ssl verify modes onto the PyOpenSSL verify flags.
# https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002
_VERIFY_MAP = {
    _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE,
    _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER,
    _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}

# Inverse lookup: PyOpenSSL verify flags back to the stdlib ssl verify mode.
_REVERSE_VERIFY_MAP = {
    openssl_flags: stdlib_mode for stdlib_mode, openssl_flags in _VERIFY_MAP.items()
}
|
|
|
|
|
# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are |
|
# not permitted for SNI hostname. |
|
def _is_ip_address(address: Any) -> bool: |
|
try: |
|
_ip_address(address) |
|
return True |
|
except (ValueError, UnicodeError): # noqa: B014 |
|
return False |
|
|
|
|
|
# PyOpenSSL errors that _sslConn._call treats as "retry once the socket is
# ready" rather than as failures.
# According to the docs for socket.send it can raise
# WantX509LookupError and should be retried.
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
|
|
|
|
|
def _ragged_eof(exc: BaseException) -> bool: |
|
"""Return True if the OpenSSL.SSL.SysCallError is a ragged EOF.""" |
|
return exc.args == (-1, "Unexpected EOF") |
|
|
|
|
|
# https://github.com/pyca/pyopenssl/issues/168 |
|
# https://github.com/pyca/pyopenssl/issues/176 |
|
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets |
|
class _sslConn(_SSL.Connection):
    """PyOpenSSL connection whose I/O calls wait for socket readiness.

    Calls routed through :meth:`_call` are retried whenever PyOpenSSL raises
    one of ``BLOCKING_IO_ERRORS``, until the call succeeds, the socket's
    timeout elapses, or the socket is found to be closed.
    """

    def __init__(
        self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
    ):
        """Wrap ``sock`` using ``ctx``.

        :param ctx: the PyOpenSSL context for this connection.
        :param sock: the underlying socket (may be None).
        :param suppress_ragged_eofs: when true, ``recv``/``recv_into`` report
            a ragged EOF as an empty read instead of raising.
        """
        self.suppress_ragged_eofs = suppress_ragged_eofs
        self.socket_checker = _SocketChecker()
        super().__init__(ctx, sock)

    def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
        """Invoke ``call(*args, **kwargs)``, retrying on blocking I/O errors.

        :raises socket.timeout: if the socket timeout elapsed while waiting.
        :raises SSLError: if the underlying socket was closed before the
            timeout elapsed.
        """
        timeout = self.gettimeout()
        if timeout:
            # Only needed (and only read) when a non-zero timeout is set.
            start = _time.monotonic()
        while True:
            try:
                return call(*args, **kwargs)
            except BLOCKING_IO_ERRORS as exc:
                if self.fileno() == -1:
                    # The socket is closed: report a timeout if the deadline
                    # has passed, otherwise report the closure itself.
                    if timeout and _time.monotonic() - start > timeout:
                        raise _socket.timeout("timed out")
                    raise SSLError("Underlying socket has been closed")
                # WantReadError -> wait for readability only,
                # WantWriteError -> wait for writability only,
                # anything else (WantX509LookupError) -> wait for either.
                want_read = not isinstance(exc, _SSL.WantWriteError)
                want_write = not isinstance(exc, _SSL.WantReadError)
                self.socket_checker.select(self, want_read, want_write, timeout)
                if timeout and _time.monotonic() - start > timeout:
                    raise _socket.timeout("timed out")

    def do_handshake(self, *args: Any, **kwargs: Any) -> None:
        """Perform the TLS handshake, retrying on blocking I/O errors."""
        return self._call(super().do_handshake, *args, **kwargs)

    def recv(self, *args: Any, **kwargs: Any) -> bytes:
        """Receive data, returning ``b""`` for a suppressed ragged EOF."""
        try:
            return self._call(super().recv, *args, **kwargs)
        except _SSL.SysCallError as exc:
            if not (self.suppress_ragged_eofs and _ragged_eof(exc)):
                raise
            # Suppress ragged EOFs to match the stdlib.
            return b""

    def recv_into(self, *args: Any, **kwargs: Any) -> int:
        """Receive data into a buffer, returning 0 for a suppressed ragged EOF."""
        try:
            return self._call(super().recv_into, *args, **kwargs)
        except _SSL.SysCallError as exc:
            if not (self.suppress_ragged_eofs and _ragged_eof(exc)):
                raise
            # Suppress ragged EOFs to match the stdlib.
            return 0

    def sendall(self, buf: bytes, flags: int = 0) -> None:  # type: ignore[override]
        """Send all of ``buf``, looping over partial sends.

        :raises OSError: ``"connection closed"`` if a send reports that no
            bytes were written; other ``OSError`` instances propagate unless
            their errno is ``EINTR``, in which case the send is retried.
        """
        data = memoryview(buf)
        size = len(buf)
        offset = 0
        while offset < size:
            try:
                written = self._call(super().send, data[offset:], flags)
            # XXX: It's not clear if this can actually happen. PyOpenSSL
            # doesn't appear to have any interrupt handling, nor any interrupt
            # errors for OpenSSL connections.
            except OSError as exc:  # noqa: B014
                if _errno_from_exception(exc) != _EINTR:
                    raise
                continue
            # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756
            # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html
            if written <= 0:
                raise OSError("connection closed")
            offset += written
|
|
|
|
|
class _CallbackData:
    """Container for the state handed to the OCSP callback."""

    def __init__(self) -> None:
        """Start with a fresh OCSP cache and no CA certs or endpoint preference."""
        self.ocsp_response_cache = _OCSPCache()
        # Both remain None until the owning SSLContext configures them.
        self.check_ocsp_endpoint: Optional[bool] = None
        self.trusted_ca_certs: Optional[List[Certificate]] = None
|
|
|
|
|
class SSLContext:
    """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
    context.
    """

    __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")

    def __init__(self, protocol: int):
        """Create the PyOpenSSL context and register the OCSP client callback.

        :param protocol: the protocol constant passed to ``OpenSSL.SSL.Context``.
        """
        self._protocol = protocol
        self._ctx = _SSL.Context(self._protocol)
        self._callback_data = _CallbackData()
        self._check_hostname = True
        # OCSP
        # XXX: Find a better place to do this someday, since this is client
        # side configuration and wrap_socket tries to support both client and
        # server side sockets.
        self._callback_data.check_ocsp_endpoint = True
        self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)

    @property
    def protocol(self) -> int:
        """The protocol version chosen when constructing the context.
        This attribute is read-only.
        """
        return self._protocol

    def __get_verify_mode(self) -> VerifyMode:
        """Whether to try to verify other peers' certificates and how to
        behave if verification fails. This attribute must be one of
        ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
        """
        return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]

    def __set_verify_mode(self, value: VerifyMode) -> None:
        """Setter for verify_mode.

        :raises KeyError: if ``value`` is not a key of ``_VERIFY_MAP``.
        """

        def _cb(
            connobj: _SSL.Connection,
            x509obj: _crypto.X509,
            errnum: int,
            errdepth: int,
            retcode: int,
        ) -> bool:
            # It seems we don't need to do anything here. Twisted doesn't,
            # and OpenSSL's SSL_CTX_set_verify lets you pass NULL
            # for the callback option. It's weird that PyOpenSSL requires
            # this.
            # This is optional in pyopenssl >= 20 and can be removed once minimum
            # supported version is bumped
            # See: pyopenssl.org/en/latest/changelog.html#id47
            return bool(retcode)

        self._ctx.set_verify(_VERIFY_MAP[value], _cb)

    verify_mode = property(__get_verify_mode, __set_verify_mode)

    def __get_check_hostname(self) -> bool:
        """Whether wrap_socket verifies the peer certificate against server_hostname."""
        return self._check_hostname

    def __set_check_hostname(self, value: Any) -> None:
        """Setter for check_hostname; the value is passed through validate_boolean."""
        validate_boolean("check_hostname", value)
        self._check_hostname = value

    check_hostname = property(__get_check_hostname, __set_check_hostname)

    def __get_check_ocsp_endpoint(self) -> Optional[bool]:
        """The check_ocsp_endpoint flag stored on the OCSP callback data."""
        return self._callback_data.check_ocsp_endpoint

    def __set_check_ocsp_endpoint(self, value: bool) -> None:
        """Setter for check_ocsp_endpoint; the value is passed through validate_boolean."""
        validate_boolean("check_ocsp", value)
        self._callback_data.check_ocsp_endpoint = value

    check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)

    def __get_options(self) -> int:
        """Return the context's current option bitmask."""
        # Calling set_options adds the option to the existing bitmask and
        # returns the new bitmask.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
        return self._ctx.set_options(0)

    def __set_options(self, value: int) -> None:
        """Add ``value`` to the context's option bitmask."""
        # Explicitly convert to int, since newer CPython versions
        # use enum.IntFlag for options. The values are the same
        # regardless of implementation.
        self._ctx.set_options(int(value))

    options = property(__get_options, __set_options)

    def load_cert_chain(
        self,
        certfile: Union[str, bytes],
        keyfile: Union[str, bytes, None] = None,
        password: Optional[str] = None,
    ) -> None:
        """Load a private key and the corresponding certificate. The certfile
        string must be the path to a single file in PEM format containing the
        certificate as well as any number of CA certificates needed to
        establish the certificate's authenticity. The keyfile string, if
        present, must point to a file containing the private key. Otherwise
        the private key will be taken from certfile as well.
        """
        # Match CPython behavior
        # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971
        # Password callback MUST be set first or it will be ignored.
        if password:

            def _pwcb(max_length: int, prompt_twice: bool, user_data: bytes) -> bytes:
                # XXX:We could check the password length against what OpenSSL
                # tells us is the max, but we can't raise an exception, so...
                # warn?
                assert password is not None
                return password.encode("utf-8")

            self._ctx.set_passwd_cb(_pwcb)
        self._ctx.use_certificate_chain_file(certfile)
        self._ctx.use_privatekey_file(keyfile or certfile)
        self._ctx.check_privatekey()

    def load_verify_locations(
        self, cafile: Optional[str] = None, capath: Optional[str] = None
    ) -> None:
        """Load a set of "certification authority"(CA) certificates used to
        validate other peers' certificates when `~verify_mode` is other than
        ssl.CERT_NONE.
        """
        self._ctx.load_verify_locations(cafile, capath)
        # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
        if not hasattr(_SSL.Connection, "get_verified_chain"):
            # NOTE(review): on this fallback path a capath-only call trips the
            # assert; kept as-is so the raised exception type is unchanged.
            assert cafile is not None
            self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)

    def _load_certifi(self) -> None:
        """Attempt to load CA certs from certifi.

        :raises ConfigurationError: if certifi could not be imported.
        """
        if _HAVE_CERTIFI:
            self.load_verify_locations(certifi.where())
        else:
            raise _ConfigurationError(
                "tlsAllowInvalidCertificates is False but no system "
                "CA certificates could be loaded. Please install the "
                "certifi package, or provide a path to a CA file using "
                "the tlsCAFile option"
            )

    def _load_wincerts(self, store: str) -> None:
        """Attempt to load CA certs from Windows trust store."""
        cert_store = self._ctx.get_cert_store()
        oid = _stdlibssl.Purpose.SERVER_AUTH.oid
        for cert, encoding, trust in _stdlibssl.enum_certificates(store):  # type: ignore
            # Only DER-encoded X.509 certs that are trusted outright or
            # trusted for server authentication are added.
            if encoding == "x509_asn":
                if trust is True or oid in trust:
                    cert_store.add_cert(
                        _crypto.X509.from_cryptography(_load_der_x509_certificate(cert))
                    )

    def load_default_certs(self) -> None:
        """A PyOpenSSL version of load_default_certs from CPython."""
        # PyOpenSSL is incapable of loading CA certs from Windows, and mostly
        # incapable on macOS.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths
        if _sys.platform == "win32":
            try:
                for storename in ("CA", "ROOT"):
                    self._load_wincerts(storename)
            except PermissionError:
                # Fall back to certifi
                self._load_certifi()
        elif _sys.platform == "darwin":
            self._load_certifi()
        self._ctx.set_default_verify_paths()

    def set_default_verify_paths(self) -> None:
        """Specify that the platform provided CA certificates are to be used
        for verification purposes.
        """
        # Note: See PyOpenSSL's docs for limitations, which are similar
        # but not the same as CPython's.
        self._ctx.set_default_verify_paths()

    def wrap_socket(
        self,
        sock: _socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: Optional[str] = None,
        session: Optional[_SSL.Session] = None,
    ) -> _sslConn:
        """Wrap an existing Python socket connection and return a TLS socket
        object.

        :param sock: the socket to wrap.
        :param server_side: if ``True`` the connection is put in accept
            state; otherwise it is configured as a client (SNI, OCSP request,
            connect state).
        :param do_handshake_on_connect: perform the handshake, and the
            hostname check when enabled, before returning.
        :param suppress_ragged_eofs: forwarded to the returned connection.
        :param server_hostname: name used for SNI and hostname verification.
        :param session: an existing session to reuse, if any.
        :raises _CertificateError: if hostname/IP verification fails; the
            original service_identity error is attached as ``__cause__``.
        """
        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
        if session:
            ssl_conn.set_session(session)
        if server_side is True:
            ssl_conn.set_accept_state()
        else:
            # SNI
            if server_hostname and not _is_ip_address(server_hostname):
                # XXX: Do this in a callback registered with
                # SSLContext.set_info_callback? See Twisted for an example.
                ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
            if self.verify_mode != _stdlibssl.CERT_NONE:
                # Request a stapled OCSP response.
                ssl_conn.request_ocsp()
            ssl_conn.set_connect_state()
        # If this wasn't true the caller of wrap_socket would call
        # do_handshake()
        if do_handshake_on_connect:
            # XXX: If we do hostname checking in a callback we can get rid
            # of this call to do_handshake() since the handshake
            # will happen automatically later.
            ssl_conn.do_handshake()
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            if self.check_hostname and server_hostname is not None:
                try:
                    if _is_ip_address(server_hostname):
                        _verify_ip_address(ssl_conn, server_hostname)
                    else:
                        _verify_hostname(ssl_conn, server_hostname)
                except (_SICertificateError, _SIVerificationError) as exc:
                    # Chain explicitly so the underlying verification error
                    # stays visible as the cause.
                    raise _CertificateError(str(exc)) from exc
        return ssl_conn
|
|
|