You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
4.7 KiB
138 lines
4.7 KiB
# Copyright 2019-present MongoDB, Inc. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); you |
|
# may not use this file except in compliance with the License. You |
|
# may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
|
# implied. See the License for the specific language governing |
|
# permissions and limitations under the License. |
|
|
|
"""Support for resolving hosts and options from mongodb+srv:// URIs.""" |
|
from __future__ import annotations |
|
|
|
import ipaddress |
|
import random |
|
from typing import Any, List, Optional, Tuple, Union |
|
|
|
try: |
|
from dns import resolver |
|
|
|
_HAVE_DNSPYTHON = True |
|
except ImportError: |
|
_HAVE_DNSPYTHON = False |
|
|
|
from pymongo.common import CONNECT_TIMEOUT |
|
from pymongo.errors import ConfigurationError |
|
|
|
|
|
# dnspython can return bytes or str from various parts |
|
# of its API depending on version. We always want str. |
|
def maybe_decode(text: Union[str, bytes]) -> str: |
|
if isinstance(text, bytes): |
|
return text.decode() |
|
return text |
|
|
|
|
|
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. |
|
def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: |
|
if hasattr(resolver, "resolve"): |
|
# dnspython >= 2 |
|
return resolver.resolve(*args, **kwargs) |
|
# dnspython 1.X |
|
return resolver.query(*args, **kwargs) |
|
|
|
|
|
_INVALID_HOST_MSG = ( |
|
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " |
|
"Did you mean to use 'mongodb://'?" |
|
) |
|
|
|
|
|
class _SrvResolver: |
|
def __init__( |
|
self, |
|
fqdn: str, |
|
connect_timeout: Optional[float], |
|
srv_service_name: str, |
|
srv_max_hosts: int = 0, |
|
): |
|
self.__fqdn = fqdn |
|
self.__srv = srv_service_name |
|
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT |
|
self.__srv_max_hosts = srv_max_hosts or 0 |
|
# Validate the fully qualified domain name. |
|
try: |
|
ipaddress.ip_address(fqdn) |
|
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) |
|
except ValueError: |
|
pass |
|
|
|
try: |
|
self.__plist = self.__fqdn.split(".")[1:] |
|
except Exception: |
|
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) |
|
self.__slen = len(self.__plist) |
|
if self.__slen < 2: |
|
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) |
|
|
|
def get_options(self) -> Optional[str]: |
|
try: |
|
results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) |
|
except (resolver.NoAnswer, resolver.NXDOMAIN): |
|
# No TXT records |
|
return None |
|
except Exception as exc: |
|
raise ConfigurationError(str(exc)) |
|
if len(results) > 1: |
|
raise ConfigurationError("Only one TXT record is supported") |
|
return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") |
|
|
|
def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer: |
|
try: |
|
results = _resolve( |
|
"_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout |
|
) |
|
except Exception as exc: |
|
if not encapsulate_errors: |
|
# Raise the original error. |
|
raise |
|
# Else, raise all errors as ConfigurationError. |
|
raise ConfigurationError(str(exc)) |
|
return results |
|
|
|
def _get_srv_response_and_hosts( |
|
self, encapsulate_errors: bool |
|
) -> Tuple[resolver.Answer, List[Tuple[str, Any]]]: |
|
results = self._resolve_uri(encapsulate_errors) |
|
|
|
# Construct address tuples |
|
nodes = [ |
|
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results |
|
] |
|
|
|
# Validate hosts |
|
for node in nodes: |
|
try: |
|
nlist = node[0].lower().split(".")[1:][-self.__slen :] |
|
except Exception: |
|
raise ConfigurationError(f"Invalid SRV host: {node[0]}") |
|
if self.__plist != nlist: |
|
raise ConfigurationError(f"Invalid SRV host: {node[0]}") |
|
if self.__srv_max_hosts: |
|
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) |
|
return results, nodes |
|
|
|
def get_hosts(self) -> List[Tuple[str, Any]]: |
|
_, nodes = self._get_srv_response_and_hosts(True) |
|
return nodes |
|
|
|
def get_hosts_and_min_ttl(self) -> Tuple[List[Tuple[str, Any]], int]: |
|
results, nodes = self._get_srv_response_and_hosts(False) |
|
rrset = results.rrset |
|
ttl = rrset.ttl if rrset else 0 |
|
return nodes, ttl
|
|
|