You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

319 lines
12 KiB

  1. # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
  2. """DNS Versioned Zones."""
  3. import collections
  4. import threading
  5. from typing import Callable, Deque, Optional, Set, Union
  6. import dns.exception
  7. import dns.immutable
  8. import dns.name
  9. import dns.node
  10. import dns.rdataclass
  11. import dns.rdataset
  12. import dns.rdatatype
  13. import dns.rdtypes.ANY.SOA
  14. import dns.zone
# Raised by the direct-mutation methods of the versioned Zone below
# (find_node/find_rdataset/get_rdataset with create=True, delete_node,
# delete_rdataset, replace_rdataset); changes must go through a writer()
# transaction instead.
# NOTE(review): dns.exception.DNSException presumably uses the class
# docstring as the default exception message, so the docstring text below is
# treated as runtime output and is intentionally left unchanged.
class UseTransaction(dns.exception.DNSException):
    """To alter a versioned zone, use a transaction."""
  17. # Backwards compatibility
  18. Node = dns.zone.VersionedNode
  19. ImmutableNode = dns.zone.ImmutableVersionedNode
  20. Version = dns.zone.Version
  21. WritableVersion = dns.zone.WritableVersion
  22. ImmutableVersion = dns.zone.ImmutableVersion
  23. Transaction = dns.zone.Transaction
  24. class Zone(dns.zone.Zone): # lgtm[py/missing-equals]
  25. __slots__ = [
  26. "_versions",
  27. "_versions_lock",
  28. "_write_txn",
  29. "_write_waiters",
  30. "_write_event",
  31. "_pruning_policy",
  32. "_readers",
  33. ]
  34. node_factory = Node
  35. def __init__(
  36. self,
  37. origin: Optional[Union[dns.name.Name, str]],
  38. rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
  39. relativize: bool = True,
  40. pruning_policy: Optional[Callable[["Zone", Version], Optional[bool]]] = None,
  41. ):
  42. """Initialize a versioned zone object.
  43. *origin* is the origin of the zone. It may be a ``dns.name.Name``,
  44. a ``str``, or ``None``. If ``None``, then the zone's origin will
  45. be set by the first ``$ORIGIN`` line in a zone file.
  46. *rdclass*, an ``int``, the zone's rdata class; the default is class IN.
  47. *relativize*, a ``bool``, determine's whether domain names are
  48. relativized to the zone's origin. The default is ``True``.
  49. *pruning policy*, a function taking a ``Zone`` and a ``Version`` and returning
  50. a ``bool``, or ``None``. Should the version be pruned? If ``None``,
  51. the default policy, which retains one version is used.
  52. """
  53. super().__init__(origin, rdclass, relativize)
  54. self._versions: Deque[Version] = collections.deque()
  55. self._version_lock = threading.Lock()
  56. if pruning_policy is None:
  57. self._pruning_policy = self._default_pruning_policy
  58. else:
  59. self._pruning_policy = pruning_policy
  60. self._write_txn: Optional[Transaction] = None
  61. self._write_event: Optional[threading.Event] = None
  62. self._write_waiters: Deque[threading.Event] = collections.deque()
  63. self._readers: Set[Transaction] = set()
  64. self._commit_version_unlocked(
  65. None, WritableVersion(self, replacement=True), origin
  66. )
  67. def reader(
  68. self, id: Optional[int] = None, serial: Optional[int] = None
  69. ) -> Transaction: # pylint: disable=arguments-differ
  70. if id is not None and serial is not None:
  71. raise ValueError("cannot specify both id and serial")
  72. with self._version_lock:
  73. if id is not None:
  74. version = None
  75. for v in reversed(self._versions):
  76. if v.id == id:
  77. version = v
  78. break
  79. if version is None:
  80. raise KeyError("version not found")
  81. elif serial is not None:
  82. if self.relativize:
  83. oname = dns.name.empty
  84. else:
  85. assert self.origin is not None
  86. oname = self.origin
  87. version = None
  88. for v in reversed(self._versions):
  89. n = v.nodes.get(oname)
  90. if n:
  91. rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
  92. if rds and rds[0].serial == serial:
  93. version = v
  94. break
  95. if version is None:
  96. raise KeyError("serial not found")
  97. else:
  98. version = self._versions[-1]
  99. txn = Transaction(self, False, version)
  100. self._readers.add(txn)
  101. return txn
  102. def writer(self, replacement: bool = False) -> Transaction:
  103. event = None
  104. while True:
  105. with self._version_lock:
  106. # Checking event == self._write_event ensures that either
  107. # no one was waiting before we got lucky and found no write
  108. # txn, or we were the one who was waiting and got woken up.
  109. # This prevents "taking cuts" when creating a write txn.
  110. if self._write_txn is None and event == self._write_event:
  111. # Creating the transaction defers version setup
  112. # (i.e. copying the nodes dictionary) until we
  113. # give up the lock, so that we hold the lock as
  114. # short a time as possible. This is why we call
  115. # _setup_version() below.
  116. self._write_txn = Transaction(
  117. self, replacement, make_immutable=True
  118. )
  119. # give up our exclusive right to make a Transaction
  120. self._write_event = None
  121. break
  122. # Someone else is writing already, so we will have to
  123. # wait, but we want to do the actual wait outside the
  124. # lock.
  125. event = threading.Event()
  126. self._write_waiters.append(event)
  127. # wait (note we gave up the lock!)
  128. #
  129. # We only wake one sleeper at a time, so it's important
  130. # that no event waiter can exit this method (e.g. via
  131. # cancellation) without returning a transaction or waking
  132. # someone else up.
  133. #
  134. # This is not a problem with Threading module threads as
  135. # they cannot be canceled, but could be an issue with trio
  136. # tasks when we do the async version of writer().
  137. # I.e. we'd need to do something like:
  138. #
  139. # try:
  140. # event.wait()
  141. # except trio.Cancelled:
  142. # with self._version_lock:
  143. # self._maybe_wakeup_one_waiter_unlocked()
  144. # raise
  145. #
  146. event.wait()
  147. # Do the deferred version setup.
  148. self._write_txn._setup_version()
  149. return self._write_txn
  150. def _maybe_wakeup_one_waiter_unlocked(self):
  151. if len(self._write_waiters) > 0:
  152. self._write_event = self._write_waiters.popleft()
  153. self._write_event.set()
  154. # pylint: disable=unused-argument
  155. def _default_pruning_policy(self, zone, version):
  156. return True
  157. # pylint: enable=unused-argument
  158. def _prune_versions_unlocked(self):
  159. assert len(self._versions) > 0
  160. # Don't ever prune a version greater than or equal to one that
  161. # a reader has open. This pins versions in memory while the
  162. # reader is open, and importantly lets the reader open a txn on
  163. # a successor version (e.g. if generating an IXFR).
  164. #
  165. # Note our definition of least_kept also ensures we do not try to
  166. # delete the greatest version.
  167. if len(self._readers) > 0:
  168. least_kept = min(txn.version.id for txn in self._readers)
  169. else:
  170. least_kept = self._versions[-1].id
  171. while self._versions[0].id < least_kept and self._pruning_policy(
  172. self, self._versions[0]
  173. ):
  174. self._versions.popleft()
  175. def set_max_versions(self, max_versions: Optional[int]) -> None:
  176. """Set a pruning policy that retains up to the specified number
  177. of versions
  178. """
  179. if max_versions is not None and max_versions < 1:
  180. raise ValueError("max versions must be at least 1")
  181. if max_versions is None:
  182. def policy(zone, _): # pylint: disable=unused-argument
  183. return False
  184. else:
  185. def policy(zone, _):
  186. return len(zone._versions) > max_versions
  187. self.set_pruning_policy(policy)
  188. def set_pruning_policy(
  189. self, policy: Optional[Callable[["Zone", Version], Optional[bool]]]
  190. ) -> None:
  191. """Set the pruning policy for the zone.
  192. The *policy* function takes a `Version` and returns `True` if
  193. the version should be pruned, and `False` otherwise. `None`
  194. may also be specified for policy, in which case the default policy
  195. is used.
  196. Pruning checking proceeds from the least version and the first
  197. time the function returns `False`, the checking stops. I.e. the
  198. retained versions are always a consecutive sequence.
  199. """
  200. if policy is None:
  201. policy = self._default_pruning_policy
  202. with self._version_lock:
  203. self._pruning_policy = policy
  204. self._prune_versions_unlocked()
  205. def _end_read(self, txn):
  206. with self._version_lock:
  207. self._readers.remove(txn)
  208. self._prune_versions_unlocked()
  209. def _end_write_unlocked(self, txn):
  210. assert self._write_txn == txn
  211. self._write_txn = None
  212. self._maybe_wakeup_one_waiter_unlocked()
  213. def _end_write(self, txn):
  214. with self._version_lock:
  215. self._end_write_unlocked(txn)
  216. def _commit_version_unlocked(self, txn, version, origin):
  217. self._versions.append(version)
  218. self._prune_versions_unlocked()
  219. self.nodes = version.nodes
  220. if self.origin is None:
  221. self.origin = origin
  222. # txn can be None in __init__ when we make the empty version.
  223. if txn is not None:
  224. self._end_write_unlocked(txn)
  225. def _commit_version(self, txn, version, origin):
  226. with self._version_lock:
  227. self._commit_version_unlocked(txn, version, origin)
  228. def _get_next_version_id(self):
  229. if len(self._versions) > 0:
  230. id = self._versions[-1].id + 1
  231. else:
  232. id = 1
  233. return id
  234. def find_node(
  235. self, name: Union[dns.name.Name, str], create: bool = False
  236. ) -> dns.node.Node:
  237. if create:
  238. raise UseTransaction
  239. return super().find_node(name)
  240. def delete_node(self, name: Union[dns.name.Name, str]) -> None:
  241. raise UseTransaction
  242. def find_rdataset(
  243. self,
  244. name: Union[dns.name.Name, str],
  245. rdtype: Union[dns.rdatatype.RdataType, str],
  246. covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
  247. create: bool = False,
  248. ) -> dns.rdataset.Rdataset:
  249. if create:
  250. raise UseTransaction
  251. rdataset = super().find_rdataset(name, rdtype, covers)
  252. return dns.rdataset.ImmutableRdataset(rdataset)
  253. def get_rdataset(
  254. self,
  255. name: Union[dns.name.Name, str],
  256. rdtype: Union[dns.rdatatype.RdataType, str],
  257. covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
  258. create: bool = False,
  259. ) -> Optional[dns.rdataset.Rdataset]:
  260. if create:
  261. raise UseTransaction
  262. rdataset = super().get_rdataset(name, rdtype, covers)
  263. if rdataset is not None:
  264. return dns.rdataset.ImmutableRdataset(rdataset)
  265. else:
  266. return None
  267. def delete_rdataset(
  268. self,
  269. name: Union[dns.name.Name, str],
  270. rdtype: Union[dns.rdatatype.RdataType, str],
  271. covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
  272. ) -> None:
  273. raise UseTransaction
  274. def replace_rdataset(
  275. self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset
  276. ) -> None:
  277. raise UseTransaction