You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
318 lines
12 KiB
318 lines
12 KiB
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license |
|
|
|
"""DNS Versioned Zones.""" |
|
|
|
import collections |
|
import threading |
|
from typing import Callable, Deque, Optional, Set, Union |
|
|
|
import dns.exception |
|
import dns.immutable |
|
import dns.name |
|
import dns.node |
|
import dns.rdataclass |
|
import dns.rdataset |
|
import dns.rdatatype |
|
import dns.rdtypes.ANY.SOA |
|
import dns.zone |
|
|
|
|
|
# Raised by the versioned Zone's direct-mutation entry points (e.g.
# find_node(create=True), delete_node, delete_rdataset, replace_rdataset):
# a versioned zone may only be modified through a write transaction.
# NOTE(review): the docstring below is left verbatim because
# dns.exception.DNSException presumably uses the class docstring as the
# default exception message -- confirm before rewording it.
class UseTransaction(dns.exception.DNSException):
    """To alter a versioned zone, use a transaction."""
|
|
|
|
|
# Backwards compatibility: these classes are now defined in dns.zone.  The
# aliases keep the old names importable from this module, and the Zone class
# in this module also uses them (node_factory, annotations, Transaction(...)).
# Keep them as simple one-name assignments so static type checkers continue
# to treat them as type aliases.
Node = dns.zone.VersionedNode
ImmutableNode = dns.zone.ImmutableVersionedNode
Version = dns.zone.Version
WritableVersion = dns.zone.WritableVersion
ImmutableVersion = dns.zone.ImmutableVersion
Transaction = dns.zone.Transaction
|
|
|
|
|
class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
    """A zone that may only be modified through transactions.

    Each committed write transaction appends a new version to the zone's
    version history, and each reader transaction is bound to one version.
    Old versions are discarded according to the zone's pruning policy.
    Methods that would modify the zone directly raise ``UseTransaction``.
    """

    __slots__ = [
        "_versions",
        # Must match the attribute assigned in __init__; this was previously
        # misspelled "_versions_lock", which nothing ever set.
        "_version_lock",
        "_write_txn",
        "_write_waiters",
        "_write_event",
        "_pruning_policy",
        "_readers",
    ]

    node_factory = Node

    def __init__(
        self,
        origin: Optional[Union[dns.name.Name, str]],
        rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
        relativize: bool = True,
        pruning_policy: Optional[Callable[["Zone", Version], Optional[bool]]] = None,
    ):
        """Initialize a versioned zone object.

        *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
        a ``str``, or ``None``.  If ``None``, then the zone's origin will
        be set by the first ``$ORIGIN`` line in a zone file.

        *rdclass*, an ``int``, the zone's rdata class; the default is class IN.

        *relativize*, a ``bool``, determines whether domain names are
        relativized to the zone's origin.  The default is ``True``.

        *pruning_policy*, a function taking a ``Zone`` and a ``Version`` and
        returning a ``bool``, or ``None``.  The function answers whether the
        given version should be pruned.  If ``None``, the default policy is
        used, which retains only the latest version (plus any versions still
        pinned by open readers).
        """
        super().__init__(origin, rdclass, relativize)
        self._versions: Deque[Version] = collections.deque()
        self._version_lock = threading.Lock()
        if pruning_policy is None:
            self._pruning_policy = self._default_pruning_policy
        else:
            self._pruning_policy = pruning_policy
        self._write_txn: Optional[Transaction] = None
        self._write_event: Optional[threading.Event] = None
        self._write_waiters: Deque[threading.Event] = collections.deque()
        self._readers: Set[Transaction] = set()
        # Start the history with an empty version so there is always at
        # least one version to read.
        self._commit_version_unlocked(
            None, WritableVersion(self, replacement=True), origin
        )

    def reader(
        self, id: Optional[int] = None, serial: Optional[int] = None
    ) -> Transaction:  # pylint: disable=arguments-differ
        """Start a read-only transaction on one version of the zone.

        If *id* is given, the retained version with that id is used.  If
        *serial* is given, the most recent retained version whose origin SOA
        rdataset has that serial is used.  Otherwise the latest version is
        used.

        While the transaction is open, its version and every later version
        are protected from pruning.

        Raises ``ValueError`` if both *id* and *serial* are specified, and
        ``KeyError`` if no retained version matches.
        """
        if id is not None and serial is not None:
            raise ValueError("cannot specify both id and serial")
        with self._version_lock:
            if id is not None:
                version = None
                for v in reversed(self._versions):
                    if v.id == id:
                        version = v
                        break
                if version is None:
                    raise KeyError("version not found")
            elif serial is not None:
                # Node names are stored relative to the origin when the
                # zone relativizes, so the origin itself is the empty name.
                if self.relativize:
                    oname = dns.name.empty
                else:
                    assert self.origin is not None
                    oname = self.origin
                version = None
                for v in reversed(self._versions):
                    n = v.nodes.get(oname)
                    if n:
                        rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
                        if rds and rds[0].serial == serial:
                            version = v
                            break
                if version is None:
                    raise KeyError("serial not found")
            else:
                version = self._versions[-1]
            txn = Transaction(self, False, version)
            self._readers.add(txn)
            return txn

    def writer(self, replacement: bool = False) -> Transaction:
        """Start a write transaction, blocking while another write is active.

        Only one write transaction exists at a time.  Blocked writers are
        served in the order they started waiting.  *replacement* is passed
        through to the new ``Transaction``.
        """
        event = None
        while True:
            with self._version_lock:
                # Checking event == self._write_event ensures that either
                # no one was waiting before we got lucky and found no write
                # txn, or we were the one who was waiting and got woken up.
                # This prevents "taking cuts" when creating a write txn.
                if self._write_txn is None and event == self._write_event:
                    # Creating the transaction defers version setup
                    # (i.e. copying the nodes dictionary) until we
                    # give up the lock, so that we hold the lock as
                    # short a time as possible.  This is why we call
                    # _setup_version() below.
                    txn = Transaction(self, replacement, make_immutable=True)
                    self._write_txn = txn
                    # give up our exclusive right to make a Transaction
                    self._write_event = None
                    break
                # Someone else is writing already, so we will have to
                # wait, but we want to do the actual wait outside the
                # lock.
                event = threading.Event()
                self._write_waiters.append(event)
            # wait (note we gave up the lock!)
            #
            # We only wake one sleeper at a time, so it's important
            # that no event waiter can exit this method (e.g. via
            # cancellation) without returning a transaction or waking
            # someone else up.
            #
            # This is not a problem with Threading module threads as
            # they cannot be canceled, but could be an issue with trio
            # tasks when we do the async version of writer().
            # I.e. we'd need to do something like:
            #
            # try:
            #     event.wait()
            # except trio.Cancelled:
            #     with self._version_lock:
            #         self._maybe_wakeup_one_waiter_unlocked()
            #     raise
            #
            event.wait()
        # Do the deferred version setup outside the lock.  If it fails, the
        # caller never receives the transaction and so can never end it;
        # release the write slot (waking the next waiter) before re-raising
        # so that later writers are not blocked forever.
        try:
            txn._setup_version()
        except BaseException:
            self._end_write(txn)
            raise
        return txn

    def _maybe_wakeup_one_waiter_unlocked(self):
        """Wake the longest-waiting writer, if any.

        The woken writer's event becomes ``_write_event``, giving it the
        exclusive right to create the next write transaction.  Must be
        called with ``_version_lock`` held.
        """
        if self._write_waiters:
            self._write_event = self._write_waiters.popleft()
            self._write_event.set()

    # pylint: disable=unused-argument
    def _default_pruning_policy(self, zone, version):
        """Default policy: prune every version that may be pruned."""
        return True

    # pylint: enable=unused-argument

    def _prune_versions_unlocked(self):
        """Discard old versions according to the pruning policy.

        Must be called with ``_version_lock`` held.
        """
        assert len(self._versions) > 0
        # Don't ever prune a version greater than or equal to one that
        # a reader has open.  This pins versions in memory while the
        # reader is open, and importantly lets the reader open a txn on
        # a successor version (e.g. if generating an IXFR).
        #
        # Note our definition of least_kept also ensures we do not try to
        # delete the greatest version.
        if self._readers:
            least_kept = min(txn.version.id for txn in self._readers)
        else:
            least_kept = self._versions[-1].id
        # Prune from the oldest version forward, stopping at the first
        # version the policy wants to keep, so the retained versions are
        # always consecutive.
        while self._versions[0].id < least_kept and self._pruning_policy(
            self, self._versions[0]
        ):
            self._versions.popleft()

    def set_max_versions(self, max_versions: Optional[int]) -> None:
        """Set a pruning policy that retains up to *max_versions* versions.

        If *max_versions* is ``None``, versions are never pruned.  Versions
        pinned by open readers are retained regardless of this limit.

        Raises ``ValueError`` if *max_versions* is less than 1.
        """
        if max_versions is not None and max_versions < 1:
            raise ValueError("max versions must be at least 1")
        if max_versions is None:

            def policy(zone, _):  # pylint: disable=unused-argument
                return False

        else:

            def policy(zone, _):
                return len(zone._versions) > max_versions

        self.set_pruning_policy(policy)

    def set_pruning_policy(
        self, policy: Optional[Callable[["Zone", Version], Optional[bool]]]
    ) -> None:
        """Set the pruning policy for the zone.

        The *policy* function takes a ``Zone`` and a ``Version`` and returns
        ``True`` if the version should be pruned, and ``False`` otherwise.
        ``None`` may also be specified for *policy*, in which case the
        default policy is used.

        Pruning checking proceeds from the least version, and the first
        time the function returns ``False`` the checking stops.  I.e. the
        retained versions are always a consecutive sequence.  The new policy
        is applied immediately.
        """
        if policy is None:
            policy = self._default_pruning_policy
        with self._version_lock:
            self._pruning_policy = policy
            self._prune_versions_unlocked()

    def _end_read(self, txn):
        """Forget reader *txn* and prune versions it no longer pins."""
        with self._version_lock:
            self._readers.remove(txn)
            self._prune_versions_unlocked()

    def _end_write_unlocked(self, txn):
        """Release the write slot held by *txn* and wake the next writer.

        Must be called with ``_version_lock`` held.
        """
        assert self._write_txn == txn
        self._write_txn = None
        self._maybe_wakeup_one_waiter_unlocked()

    def _end_write(self, txn):
        """Locking wrapper around ``_end_write_unlocked()``."""
        with self._version_lock:
            self._end_write_unlocked(txn)

    def _commit_version_unlocked(self, txn, version, origin):
        """Make *version* the zone's latest version.

        Appends *version* to the history, prunes, exposes its nodes as the
        zone's current nodes, adopts *origin* if the zone has none yet, and
        ends the write transaction *txn* if one is given.  Must be called
        with ``_version_lock`` held, except from ``__init__``.
        """
        self._versions.append(version)
        self._prune_versions_unlocked()
        self.nodes = version.nodes
        if self.origin is None:
            self.origin = origin
        # txn can be None in __init__ when we make the empty version.
        if txn is not None:
            self._end_write_unlocked(txn)

    def _commit_version(self, txn, version, origin):
        """Locking wrapper around ``_commit_version_unlocked()``."""
        with self._version_lock:
            self._commit_version_unlocked(txn, version, origin)

    def _get_next_version_id(self):
        """Return the id for the next version: latest id + 1, or 1."""
        if self._versions:
            return self._versions[-1].id + 1
        return 1

    def find_node(
        self, name: Union[dns.name.Name, str], create: bool = False
    ) -> dns.node.Node:
        """Find the node *name* in the current version.

        Raises ``UseTransaction`` if *create* is true; nodes can only be
        created within a transaction.
        """
        if create:
            raise UseTransaction
        return super().find_node(name)

    def delete_node(self, name: Union[dns.name.Name, str]) -> None:
        """Always raises ``UseTransaction``; delete within a transaction."""
        raise UseTransaction

    def find_rdataset(
        self,
        name: Union[dns.name.Name, str],
        rdtype: Union[dns.rdatatype.RdataType, str],
        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
        create: bool = False,
    ) -> dns.rdataset.Rdataset:
        """Find an rdataset in the current version, as an immutable copy.

        Raises ``UseTransaction`` if *create* is true; rdatasets can only be
        created within a transaction.
        """
        if create:
            raise UseTransaction
        rdataset = super().find_rdataset(name, rdtype, covers)
        return dns.rdataset.ImmutableRdataset(rdataset)

    def get_rdataset(
        self,
        name: Union[dns.name.Name, str],
        rdtype: Union[dns.rdatatype.RdataType, str],
        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
        create: bool = False,
    ) -> Optional[dns.rdataset.Rdataset]:
        """Get an rdataset from the current version as an immutable copy.

        Returns ``None`` if the base class lookup finds nothing.  Raises
        ``UseTransaction`` if *create* is true.
        """
        if create:
            raise UseTransaction
        rdataset = super().get_rdataset(name, rdtype, covers)
        if rdataset is not None:
            return dns.rdataset.ImmutableRdataset(rdataset)
        else:
            return None

    def delete_rdataset(
        self,
        name: Union[dns.name.Name, str],
        rdtype: Union[dns.rdatatype.RdataType, str],
        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
    ) -> None:
        """Always raises ``UseTransaction``; delete within a transaction."""
        raise UseTransaction

    def replace_rdataset(
        self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
    ) -> None:
        """Always raises ``UseTransaction``; replace within a transaction."""
        raise UseTransaction
|
|
|