Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.
 
 
 
 

856 Zeilen
23 KiB

  1. import collections.abc
  2. import itertools
  3. import reprlib
  4. import sys
  5. __all__ = ('Map',)
  6. # Thread-safe counter.
  7. _mut_id = itertools.count(1).__next__
  8. # Python version of _map.c. The topmost comment there explains
  9. # all datastructures and algorithms.
  10. # The code here follows C code closely on purpose to make
  11. # debugging and testing easier.
  12. def map_hash(o):
  13. x = hash(o)
  14. if sys.hash_info.width > 32:
  15. return (x & 0xffffffff) ^ ((x >> 32) & 0xffffffff)
  16. else:
  17. return x
  18. def map_mask(hash, shift):
  19. return (hash >> shift) & 0x01f
  20. def map_bitpos(hash, shift):
  21. return 1 << map_mask(hash, shift)
  22. def map_bitcount(v):
  23. v = v - ((v >> 1) & 0x55555555)
  24. v = (v & 0x33333333) + ((v >> 2) & 0x33333333)
  25. v = (v & 0x0F0F0F0F) + ((v >> 4) & 0x0F0F0F0F)
  26. v = v + (v >> 8)
  27. v = (v + (v >> 16)) & 0x3F
  28. return v
  29. def map_bitindex(bitmap, bit):
  30. return map_bitcount(bitmap & (bit - 1))
  31. W_EMPTY, W_NEWNODE, W_NOT_FOUND = range(3)
  32. void = object()
  33. class _Unhashable:
  34. __slots__ = ()
  35. __hash__ = None
  36. _NULL = _Unhashable()
  37. del _Unhashable
  38. class BitmapNode:
  39. def __init__(self, size, bitmap, array, mutid):
  40. self.size = size
  41. self.bitmap = bitmap
  42. assert isinstance(array, list) and len(array) == size
  43. self.array = array
  44. self.mutid = mutid
  45. def clone(self, mutid):
  46. return BitmapNode(self.size, self.bitmap, self.array.copy(), mutid)
  47. def assoc(self, shift, hash, key, val, mutid):
  48. bit = map_bitpos(hash, shift)
  49. idx = map_bitindex(self.bitmap, bit)
  50. if self.bitmap & bit:
  51. key_idx = 2 * idx
  52. val_idx = key_idx + 1
  53. key_or_null = self.array[key_idx]
  54. val_or_node = self.array[val_idx]
  55. if key_or_null is _NULL:
  56. sub_node, added = val_or_node.assoc(
  57. shift + 5, hash, key, val, mutid)
  58. if val_or_node is sub_node:
  59. return self, added
  60. if mutid and mutid == self.mutid:
  61. self.array[val_idx] = sub_node
  62. return self, added
  63. else:
  64. ret = self.clone(mutid)
  65. ret.array[val_idx] = sub_node
  66. return ret, added
  67. if key == key_or_null:
  68. if val is val_or_node:
  69. return self, False
  70. if mutid and mutid == self.mutid:
  71. self.array[val_idx] = val
  72. return self, False
  73. else:
  74. ret = self.clone(mutid)
  75. ret.array[val_idx] = val
  76. return ret, False
  77. existing_key_hash = map_hash(key_or_null)
  78. if existing_key_hash == hash:
  79. sub_node = CollisionNode(
  80. 4, hash, [key_or_null, val_or_node, key, val], mutid)
  81. else:
  82. sub_node = BitmapNode(0, 0, [], mutid)
  83. sub_node, _ = sub_node.assoc(
  84. shift + 5, existing_key_hash,
  85. key_or_null, val_or_node,
  86. mutid)
  87. sub_node, _ = sub_node.assoc(
  88. shift + 5, hash, key, val,
  89. mutid)
  90. if mutid and mutid == self.mutid:
  91. self.array[key_idx] = _NULL
  92. self.array[val_idx] = sub_node
  93. return self, True
  94. else:
  95. ret = self.clone(mutid)
  96. ret.array[key_idx] = _NULL
  97. ret.array[val_idx] = sub_node
  98. return ret, True
  99. else:
  100. key_idx = 2 * idx
  101. val_idx = key_idx + 1
  102. n = map_bitcount(self.bitmap)
  103. new_array = self.array[:key_idx]
  104. new_array.append(key)
  105. new_array.append(val)
  106. new_array.extend(self.array[key_idx:])
  107. if mutid and mutid == self.mutid:
  108. self.size = 2 * (n + 1)
  109. self.bitmap |= bit
  110. self.array = new_array
  111. return self, True
  112. else:
  113. return BitmapNode(
  114. 2 * (n + 1), self.bitmap | bit, new_array, mutid), True
  115. def find(self, shift, hash, key):
  116. bit = map_bitpos(hash, shift)
  117. if not (self.bitmap & bit):
  118. raise KeyError
  119. idx = map_bitindex(self.bitmap, bit)
  120. key_idx = idx * 2
  121. val_idx = key_idx + 1
  122. key_or_null = self.array[key_idx]
  123. val_or_node = self.array[val_idx]
  124. if key_or_null is _NULL:
  125. return val_or_node.find(shift + 5, hash, key)
  126. if key == key_or_null:
  127. return val_or_node
  128. raise KeyError(key)
  129. def without(self, shift, hash, key, mutid):
  130. bit = map_bitpos(hash, shift)
  131. if not (self.bitmap & bit):
  132. return W_NOT_FOUND, None
  133. idx = map_bitindex(self.bitmap, bit)
  134. key_idx = 2 * idx
  135. val_idx = key_idx + 1
  136. key_or_null = self.array[key_idx]
  137. val_or_node = self.array[val_idx]
  138. if key_or_null is _NULL:
  139. res, sub_node = val_or_node.without(shift + 5, hash, key, mutid)
  140. if res is W_EMPTY:
  141. raise RuntimeError('unreachable code') # pragma: no cover
  142. elif res is W_NEWNODE:
  143. if (type(sub_node) is BitmapNode and
  144. sub_node.size == 2 and
  145. sub_node.array[0] is not _NULL):
  146. if mutid and mutid == self.mutid:
  147. self.array[key_idx] = sub_node.array[0]
  148. self.array[val_idx] = sub_node.array[1]
  149. return W_NEWNODE, self
  150. else:
  151. clone = self.clone(mutid)
  152. clone.array[key_idx] = sub_node.array[0]
  153. clone.array[val_idx] = sub_node.array[1]
  154. return W_NEWNODE, clone
  155. if mutid and mutid == self.mutid:
  156. self.array[val_idx] = sub_node
  157. return W_NEWNODE, self
  158. else:
  159. clone = self.clone(mutid)
  160. clone.array[val_idx] = sub_node
  161. return W_NEWNODE, clone
  162. else:
  163. assert sub_node is None
  164. return res, None
  165. else:
  166. if key == key_or_null:
  167. if self.size == 2:
  168. return W_EMPTY, None
  169. new_array = self.array[:key_idx]
  170. new_array.extend(self.array[val_idx + 1:])
  171. if mutid and mutid == self.mutid:
  172. self.size -= 2
  173. self.bitmap &= ~bit
  174. self.array = new_array
  175. return W_NEWNODE, self
  176. else:
  177. new_node = BitmapNode(
  178. self.size - 2, self.bitmap & ~bit, new_array, mutid)
  179. return W_NEWNODE, new_node
  180. else:
  181. return W_NOT_FOUND, None
  182. def keys(self):
  183. for i in range(0, self.size, 2):
  184. key_or_null = self.array[i]
  185. if key_or_null is _NULL:
  186. val_or_node = self.array[i + 1]
  187. yield from val_or_node.keys()
  188. else:
  189. yield key_or_null
  190. def values(self):
  191. for i in range(0, self.size, 2):
  192. key_or_null = self.array[i]
  193. val_or_node = self.array[i + 1]
  194. if key_or_null is _NULL:
  195. yield from val_or_node.values()
  196. else:
  197. yield val_or_node
  198. def items(self):
  199. for i in range(0, self.size, 2):
  200. key_or_null = self.array[i]
  201. val_or_node = self.array[i + 1]
  202. if key_or_null is _NULL:
  203. yield from val_or_node.items()
  204. else:
  205. yield key_or_null, val_or_node
  206. def dump(self, buf, level): # pragma: no cover
  207. buf.append(
  208. ' ' * (level + 1) +
  209. 'BitmapNode(size={} count={} bitmap={} id={:0x}):'.format(
  210. self.size, self.size / 2, bin(self.bitmap), id(self)))
  211. for i in range(0, self.size, 2):
  212. key_or_null = self.array[i]
  213. val_or_node = self.array[i + 1]
  214. pad = ' ' * (level + 2)
  215. if key_or_null is _NULL:
  216. buf.append(pad + 'NULL:')
  217. val_or_node.dump(buf, level + 2)
  218. else:
  219. buf.append(pad + '{!r}: {!r}'.format(key_or_null, val_or_node))
  220. class CollisionNode:
  221. def __init__(self, size, hash, array, mutid):
  222. self.size = size
  223. self.hash = hash
  224. self.array = array
  225. self.mutid = mutid
  226. def find_index(self, key):
  227. for i in range(0, self.size, 2):
  228. if self.array[i] == key:
  229. return i
  230. return -1
  231. def find(self, shift, hash, key):
  232. for i in range(0, self.size, 2):
  233. if self.array[i] == key:
  234. return self.array[i + 1]
  235. raise KeyError(key)
  236. def assoc(self, shift, hash, key, val, mutid):
  237. if hash == self.hash:
  238. key_idx = self.find_index(key)
  239. if key_idx == -1:
  240. new_array = self.array.copy()
  241. new_array.append(key)
  242. new_array.append(val)
  243. if mutid and mutid == self.mutid:
  244. self.size += 2
  245. self.array = new_array
  246. return self, True
  247. else:
  248. new_node = CollisionNode(
  249. self.size + 2, hash, new_array, mutid)
  250. return new_node, True
  251. val_idx = key_idx + 1
  252. if self.array[val_idx] is val:
  253. return self, False
  254. if mutid and mutid == self.mutid:
  255. self.array[val_idx] = val
  256. return self, False
  257. else:
  258. new_array = self.array.copy()
  259. new_array[val_idx] = val
  260. return CollisionNode(self.size, hash, new_array, mutid), False
  261. else:
  262. new_node = BitmapNode(
  263. 2, map_bitpos(self.hash, shift), [_NULL, self], mutid)
  264. return new_node.assoc(shift, hash, key, val, mutid)
  265. def without(self, shift, hash, key, mutid):
  266. if hash != self.hash:
  267. return W_NOT_FOUND, None
  268. key_idx = self.find_index(key)
  269. if key_idx == -1:
  270. return W_NOT_FOUND, None
  271. new_size = self.size - 2
  272. if new_size == 0:
  273. # Shouldn't be ever reachable
  274. return W_EMPTY, None # pragma: no cover
  275. if new_size == 2:
  276. if key_idx == 0:
  277. new_array = [self.array[2], self.array[3]]
  278. else:
  279. assert key_idx == 2
  280. new_array = [self.array[0], self.array[1]]
  281. new_node = BitmapNode(
  282. 2, map_bitpos(hash, shift), new_array, mutid)
  283. return W_NEWNODE, new_node
  284. new_array = self.array[:key_idx]
  285. new_array.extend(self.array[key_idx + 2:])
  286. if mutid and mutid == self.mutid:
  287. self.array = new_array
  288. self.size -= 2
  289. return W_NEWNODE, self
  290. else:
  291. new_node = CollisionNode(
  292. self.size - 2, self.hash, new_array, mutid)
  293. return W_NEWNODE, new_node
  294. def keys(self):
  295. for i in range(0, self.size, 2):
  296. yield self.array[i]
  297. def values(self):
  298. for i in range(1, self.size, 2):
  299. yield self.array[i]
  300. def items(self):
  301. for i in range(0, self.size, 2):
  302. yield self.array[i], self.array[i + 1]
  303. def dump(self, buf, level): # pragma: no cover
  304. pad = ' ' * (level + 1)
  305. buf.append(
  306. pad + 'CollisionNode(size={} id={:0x}):'.format(
  307. self.size, id(self)))
  308. pad = ' ' * (level + 2)
  309. for i in range(0, self.size, 2):
  310. key = self.array[i]
  311. val = self.array[i + 1]
  312. buf.append('{}{!r}: {!r}'.format(pad, key, val))
  313. class MapKeys:
  314. def __init__(self, c, m):
  315. self.__count = c
  316. self.__root = m
  317. def __len__(self):
  318. return self.__count
  319. def __iter__(self):
  320. return iter(self.__root.keys())
  321. class MapValues:
  322. def __init__(self, c, m):
  323. self.__count = c
  324. self.__root = m
  325. def __len__(self):
  326. return self.__count
  327. def __iter__(self):
  328. return iter(self.__root.values())
  329. class MapItems:
  330. def __init__(self, c, m):
  331. self.__count = c
  332. self.__root = m
  333. def __len__(self):
  334. return self.__count
  335. def __iter__(self):
  336. return iter(self.__root.items())
  337. class Map:
  338. def __init__(self, *args, **kw):
  339. if not args:
  340. col = None
  341. elif len(args) == 1:
  342. col = args[0]
  343. else:
  344. raise TypeError(
  345. "immutables.Map expected at most 1 arguments, "
  346. "got {}".format(len(args))
  347. )
  348. self.__count = 0
  349. self.__root = BitmapNode(0, 0, [], 0)
  350. self.__hash = -1
  351. if isinstance(col, Map):
  352. self.__count = col.__count
  353. self.__root = col.__root
  354. self.__hash = col.__hash
  355. col = None
  356. elif isinstance(col, MapMutation):
  357. raise TypeError('cannot create Maps from MapMutations')
  358. if col or kw:
  359. init = self.update(col, **kw)
  360. self.__count = init.__count
  361. self.__root = init.__root
  362. @classmethod
  363. def _new(cls, count, root):
  364. m = Map.__new__(Map)
  365. m.__count = count
  366. m.__root = root
  367. m.__hash = -1
  368. return m
  369. def __reduce__(self):
  370. return (type(self), (dict(self.items()),))
  371. def __len__(self):
  372. return self.__count
  373. def __eq__(self, other):
  374. if not isinstance(other, Map):
  375. return NotImplemented
  376. if len(self) != len(other):
  377. return False
  378. for key, val in self.__root.items():
  379. try:
  380. oval = other.__root.find(0, map_hash(key), key)
  381. except KeyError:
  382. return False
  383. else:
  384. if oval != val:
  385. return False
  386. return True
  387. def update(self, *args, **kw):
  388. if not args:
  389. col = None
  390. elif len(args) == 1:
  391. col = args[0]
  392. else:
  393. raise TypeError(
  394. "update expected at most 1 arguments, got {}".format(len(args))
  395. )
  396. it = None
  397. if col is not None:
  398. if hasattr(col, 'items'):
  399. it = iter(col.items())
  400. else:
  401. it = iter(col)
  402. if it is not None:
  403. if kw:
  404. it = iter(itertools.chain(it, kw.items()))
  405. else:
  406. if kw:
  407. it = iter(kw.items())
  408. if it is None:
  409. return self
  410. mutid = _mut_id()
  411. root = self.__root
  412. count = self.__count
  413. i = 0
  414. while True:
  415. try:
  416. tup = next(it)
  417. except StopIteration:
  418. break
  419. try:
  420. tup = tuple(tup)
  421. except TypeError:
  422. raise TypeError(
  423. 'cannot convert map update '
  424. 'sequence element #{} to a sequence'.format(i)) from None
  425. key, val, *r = tup
  426. if r:
  427. raise ValueError(
  428. 'map update sequence element #{} has length '
  429. '{}; 2 is required'.format(i, len(r) + 2))
  430. root, added = root.assoc(0, map_hash(key), key, val, mutid)
  431. if added:
  432. count += 1
  433. i += 1
  434. return Map._new(count, root)
  435. def mutate(self):
  436. return MapMutation(self.__count, self.__root)
  437. def set(self, key, val):
  438. new_count = self.__count
  439. new_root, added = self.__root.assoc(0, map_hash(key), key, val, 0)
  440. if new_root is self.__root:
  441. assert not added
  442. return self
  443. if added:
  444. new_count += 1
  445. return Map._new(new_count, new_root)
  446. def delete(self, key):
  447. res, node = self.__root.without(0, map_hash(key), key, 0)
  448. if res is W_EMPTY:
  449. return Map()
  450. elif res is W_NOT_FOUND:
  451. raise KeyError(key)
  452. else:
  453. return Map._new(self.__count - 1, node)
  454. def get(self, key, default=None):
  455. try:
  456. return self.__root.find(0, map_hash(key), key)
  457. except KeyError:
  458. return default
  459. def __getitem__(self, key):
  460. return self.__root.find(0, map_hash(key), key)
  461. def __contains__(self, key):
  462. try:
  463. self.__root.find(0, map_hash(key), key)
  464. except KeyError:
  465. return False
  466. else:
  467. return True
  468. def __iter__(self):
  469. yield from self.__root.keys()
  470. def keys(self):
  471. return MapKeys(self.__count, self.__root)
  472. def values(self):
  473. return MapValues(self.__count, self.__root)
  474. def items(self):
  475. return MapItems(self.__count, self.__root)
  476. def __hash__(self):
  477. if self.__hash != -1:
  478. return self.__hash
  479. MAX = sys.maxsize
  480. MASK = 2 * MAX + 1
  481. h = 1927868237 * (self.__count * 2 + 1)
  482. h &= MASK
  483. for key, value in self.__root.items():
  484. hx = hash(key)
  485. h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167
  486. h &= MASK
  487. hx = hash(value)
  488. h ^= (hx ^ (hx << 16) ^ 89869747) * 3644798167
  489. h &= MASK
  490. h = h * 69069 + 907133923
  491. h &= MASK
  492. if h > MAX:
  493. h -= MASK + 1 # pragma: no cover
  494. if h == -1:
  495. h = 590923713 # pragma: no cover
  496. self.__hash = h
  497. return h
  498. @reprlib.recursive_repr("{...}")
  499. def __repr__(self):
  500. items = []
  501. for key, val in self.items():
  502. items.append("{!r}: {!r}".format(key, val))
  503. return 'immutables.Map({{{}}})'.format(', '.join(items))
  504. def __dump__(self): # pragma: no cover
  505. buf = []
  506. self.__root.dump(buf, 0)
  507. return '\n'.join(buf)
  508. def __class_getitem__(cls, item):
  509. return cls
  510. class MapMutation:
  511. def __init__(self, count, root):
  512. self.__count = count
  513. self.__root = root
  514. self.__mutid = _mut_id()
  515. def set(self, key, val):
  516. self[key] = val
  517. def __enter__(self):
  518. return self
  519. def __exit__(self, *exc):
  520. self.finish()
  521. return False
  522. def __iter__(self):
  523. raise TypeError('{} is not iterable'.format(type(self)))
  524. def __delitem__(self, key):
  525. if self.__mutid == 0:
  526. raise ValueError('mutation {!r} has been finished'.format(self))
  527. res, new_root = self.__root.without(
  528. 0, map_hash(key), key, self.__mutid)
  529. if res is W_EMPTY:
  530. self.__count = 0
  531. self.__root = BitmapNode(0, 0, [], self.__mutid)
  532. elif res is W_NOT_FOUND:
  533. raise KeyError(key)
  534. else:
  535. self.__root = new_root
  536. self.__count -= 1
  537. def __setitem__(self, key, val):
  538. if self.__mutid == 0:
  539. raise ValueError('mutation {!r} has been finished'.format(self))
  540. self.__root, added = self.__root.assoc(
  541. 0, map_hash(key), key, val, self.__mutid)
  542. if added:
  543. self.__count += 1
  544. def pop(self, key, *args):
  545. if self.__mutid == 0:
  546. raise ValueError('mutation {!r} has been finished'.format(self))
  547. if len(args) > 1:
  548. raise TypeError(
  549. 'pop() accepts 1 to 2 positional arguments, '
  550. 'got {}'.format(len(args) + 1))
  551. elif len(args) == 1:
  552. default = args[0]
  553. else:
  554. default = void
  555. val = self.get(key, default)
  556. try:
  557. del self[key]
  558. except KeyError:
  559. if val is void:
  560. raise
  561. return val
  562. else:
  563. assert val is not void
  564. return val
  565. def get(self, key, default=None):
  566. try:
  567. return self.__root.find(0, map_hash(key), key)
  568. except KeyError:
  569. return default
  570. def __getitem__(self, key):
  571. return self.__root.find(0, map_hash(key), key)
  572. def __contains__(self, key):
  573. try:
  574. self.__root.find(0, map_hash(key), key)
  575. except KeyError:
  576. return False
  577. else:
  578. return True
  579. def update(self, *args, **kw):
  580. if not args:
  581. col = None
  582. elif len(args) == 1:
  583. col = args[0]
  584. else:
  585. raise TypeError(
  586. "update expected at most 1 arguments, got {}".format(len(args))
  587. )
  588. if self.__mutid == 0:
  589. raise ValueError('mutation {!r} has been finished'.format(self))
  590. it = None
  591. if col is not None:
  592. if hasattr(col, 'items'):
  593. it = iter(col.items())
  594. else:
  595. it = iter(col)
  596. if it is not None:
  597. if kw:
  598. it = iter(itertools.chain(it, kw.items()))
  599. else:
  600. if kw:
  601. it = iter(kw.items())
  602. if it is None:
  603. return
  604. root = self.__root
  605. count = self.__count
  606. i = 0
  607. while True:
  608. try:
  609. tup = next(it)
  610. except StopIteration:
  611. break
  612. try:
  613. tup = tuple(tup)
  614. except TypeError:
  615. raise TypeError(
  616. 'cannot convert map update '
  617. 'sequence element #{} to a sequence'.format(i)) from None
  618. key, val, *r = tup
  619. if r:
  620. raise ValueError(
  621. 'map update sequence element #{} has length '
  622. '{}; 2 is required'.format(i, len(r) + 2))
  623. root, added = root.assoc(0, map_hash(key), key, val, self.__mutid)
  624. if added:
  625. count += 1
  626. i += 1
  627. self.__root = root
  628. self.__count = count
  629. def finish(self):
  630. self.__mutid = 0
  631. return Map._new(self.__count, self.__root)
  632. @reprlib.recursive_repr("{...}")
  633. def __repr__(self):
  634. items = []
  635. for key, val in self.__root.items():
  636. items.append("{!r}: {!r}".format(key, val))
  637. return 'immutables.MapMutation({{{}}})'.format(', '.join(items))
  638. def __len__(self):
  639. return self.__count
  640. def __reduce__(self):
  641. raise TypeError("can't pickle {} objects".format(type(self).__name__))
  642. def __hash__(self):
  643. raise TypeError('unhashable type: {}'.format(type(self).__name__))
  644. def __eq__(self, other):
  645. if not isinstance(other, MapMutation):
  646. return NotImplemented
  647. if len(self) != len(other):
  648. return False
  649. for key, val in self.__root.items():
  650. try:
  651. oval = other.__root.find(0, map_hash(key), key)
  652. except KeyError:
  653. return False
  654. else:
  655. if oval != val:
  656. return False
  657. return True
# Register Map as a virtual subclass of collections.abc.Mapping so that
# isinstance(m, collections.abc.Mapping) is True.  Registration does not add
# any of the Mapping mixin methods to Map.
collections.abc.Mapping.register(Map)