You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

508 regels
18 KiB

  1. """
  2. lxml-based doctest output comparison.
  3. Note: normally, you should just import the `lxml.usedoctest` and
  4. `lxml.html.usedoctest` modules from within a doctest, instead of this
  5. one::
  6. >>> import lxml.usedoctest # for XML output
  7. >>> import lxml.html.usedoctest # for HTML output
  8. To use this module directly, you must call ``doctestcompare.install()``,
  9. which will cause doctest to use this checker in all subsequent calls.
  10. This changes the way output is checked and comparisons are made for
  11. XML or HTML-like content.
  12. XML or HTML content is detected when both the example and the actual output start with ``<``
  13. (it's HTML if they start with ``<html``). You can also use the
  14. ``PARSE_HTML`` and ``PARSE_XML`` flags to force parsing.
  15. Some rough wildcard-like things are allowed. Whitespace is generally
  16. ignored (except in attributes). In text (attributes and text in the
  17. body) you can use ``...`` as a wildcard. In an example it also
  18. matches any trailing tags in the element, though it does not match
  19. leading tags. You may create a tag ``<any>`` or include an ``any``
  20. attribute in the tag. An ``any`` tag matches any tag, while the
  21. attribute matches any and all attributes.
  22. When a match fails, the reformatted example and the actual output are
  23. displayed (indented), and a rough diff-like output is given. Anything
  24. marked with ``+`` is in the output but wasn't supposed to be, and
  25. similarly ``-`` means it's in the example but wasn't in the output.
  26. You can disable parsing on one line with ``# doctest:+NOPARSE_MARKUP``
  27. """
  28. from lxml import etree
  29. import sys
  30. import re
  31. import doctest
  32. try:
  33. from html import escape as html_escape
  34. except ImportError:
  35. from cgi import escape as html_escape
  36. __all__ = ['PARSE_HTML', 'PARSE_XML', 'NOPARSE_MARKUP', 'LXMLOutputChecker',
  37. 'LHTMLOutputChecker', 'install', 'temp_install']
  38. try:
  39. _basestring = basestring
  40. except NameError:
  41. _basestring = (str, bytes)
  42. _IS_PYTHON_3 = sys.version_info[0] >= 3
  43. PARSE_HTML = doctest.register_optionflag('PARSE_HTML')
  44. PARSE_XML = doctest.register_optionflag('PARSE_XML')
  45. NOPARSE_MARKUP = doctest.register_optionflag('NOPARSE_MARKUP')
  46. OutputChecker = doctest.OutputChecker
  47. def strip(v):
  48. if v is None:
  49. return None
  50. else:
  51. return v.strip()
  52. def norm_whitespace(v):
  53. return _norm_whitespace_re.sub(' ', v)
  54. _html_parser = etree.HTMLParser(recover=False, remove_blank_text=True)
  55. def html_fromstring(html):
  56. return etree.fromstring(html, _html_parser)
  57. # We use this to distinguish repr()s from elements:
  58. _repr_re = re.compile(r'^<[^>]+ (at|object) ')
  59. _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
  60. class LXMLOutputChecker(OutputChecker):
  61. empty_tags = (
  62. 'param', 'img', 'area', 'br', 'basefont', 'input',
  63. 'base', 'meta', 'link', 'col')
  64. def get_default_parser(self):
  65. return etree.XML
  66. def check_output(self, want, got, optionflags):
  67. alt_self = getattr(self, '_temp_override_self', None)
  68. if alt_self is not None:
  69. super_method = self._temp_call_super_check_output
  70. self = alt_self
  71. else:
  72. super_method = OutputChecker.check_output
  73. parser = self.get_parser(want, got, optionflags)
  74. if not parser:
  75. return super_method(
  76. self, want, got, optionflags)
  77. try:
  78. want_doc = parser(want)
  79. except etree.XMLSyntaxError:
  80. return False
  81. try:
  82. got_doc = parser(got)
  83. except etree.XMLSyntaxError:
  84. return False
  85. return self.compare_docs(want_doc, got_doc)
  86. def get_parser(self, want, got, optionflags):
  87. parser = None
  88. if NOPARSE_MARKUP & optionflags:
  89. return None
  90. if PARSE_HTML & optionflags:
  91. parser = html_fromstring
  92. elif PARSE_XML & optionflags:
  93. parser = etree.XML
  94. elif (want.strip().lower().startswith('<html')
  95. and got.strip().startswith('<html')):
  96. parser = html_fromstring
  97. elif (self._looks_like_markup(want)
  98. and self._looks_like_markup(got)):
  99. parser = self.get_default_parser()
  100. return parser
  101. def _looks_like_markup(self, s):
  102. s = s.strip()
  103. return (s.startswith('<')
  104. and not _repr_re.search(s))
  105. def compare_docs(self, want, got):
  106. if not self.tag_compare(want.tag, got.tag):
  107. return False
  108. if not self.text_compare(want.text, got.text, True):
  109. return False
  110. if not self.text_compare(want.tail, got.tail, True):
  111. return False
  112. if 'any' not in want.attrib:
  113. want_keys = sorted(want.attrib.keys())
  114. got_keys = sorted(got.attrib.keys())
  115. if want_keys != got_keys:
  116. return False
  117. for key in want_keys:
  118. if not self.text_compare(want.attrib[key], got.attrib[key], False):
  119. return False
  120. if want.text != '...' or len(want):
  121. want_children = list(want)
  122. got_children = list(got)
  123. while want_children or got_children:
  124. if not want_children or not got_children:
  125. return False
  126. want_first = want_children.pop(0)
  127. got_first = got_children.pop(0)
  128. if not self.compare_docs(want_first, got_first):
  129. return False
  130. if not got_children and want_first.tail == '...':
  131. break
  132. return True
  133. def text_compare(self, want, got, strip):
  134. want = want or ''
  135. got = got or ''
  136. if strip:
  137. want = norm_whitespace(want).strip()
  138. got = norm_whitespace(got).strip()
  139. want = '^%s$' % re.escape(want)
  140. want = want.replace(r'\.\.\.', '.*')
  141. if re.search(want, got):
  142. return True
  143. else:
  144. return False
  145. def tag_compare(self, want, got):
  146. if want == 'any':
  147. return True
  148. if (not isinstance(want, _basestring)
  149. or not isinstance(got, _basestring)):
  150. return want == got
  151. want = want or ''
  152. got = got or ''
  153. if want.startswith('{...}'):
  154. # Ellipsis on the namespace
  155. return want.split('}')[-1] == got.split('}')[-1]
  156. else:
  157. return want == got
  158. def output_difference(self, example, got, optionflags):
  159. want = example.want
  160. parser = self.get_parser(want, got, optionflags)
  161. errors = []
  162. if parser is not None:
  163. try:
  164. want_doc = parser(want)
  165. except etree.XMLSyntaxError:
  166. e = sys.exc_info()[1]
  167. errors.append('In example: %s' % e)
  168. try:
  169. got_doc = parser(got)
  170. except etree.XMLSyntaxError:
  171. e = sys.exc_info()[1]
  172. errors.append('In actual output: %s' % e)
  173. if parser is None or errors:
  174. value = OutputChecker.output_difference(
  175. self, example, got, optionflags)
  176. if errors:
  177. errors.append(value)
  178. return '\n'.join(errors)
  179. else:
  180. return value
  181. html = parser is html_fromstring
  182. diff_parts = ['Expected:',
  183. self.format_doc(want_doc, html, 2),
  184. 'Got:',
  185. self.format_doc(got_doc, html, 2),
  186. 'Diff:',
  187. self.collect_diff(want_doc, got_doc, html, 2)]
  188. return '\n'.join(diff_parts)
  189. def html_empty_tag(self, el, html=True):
  190. if not html:
  191. return False
  192. if el.tag not in self.empty_tags:
  193. return False
  194. if el.text or len(el):
  195. # This shouldn't happen (contents in an empty tag)
  196. return False
  197. return True
  198. def format_doc(self, doc, html, indent, prefix=''):
  199. parts = []
  200. if not len(doc):
  201. # No children...
  202. parts.append(' '*indent)
  203. parts.append(prefix)
  204. parts.append(self.format_tag(doc))
  205. if not self.html_empty_tag(doc, html):
  206. if strip(doc.text):
  207. parts.append(self.format_text(doc.text))
  208. parts.append(self.format_end_tag(doc))
  209. if strip(doc.tail):
  210. parts.append(self.format_text(doc.tail))
  211. parts.append('\n')
  212. return ''.join(parts)
  213. parts.append(' '*indent)
  214. parts.append(prefix)
  215. parts.append(self.format_tag(doc))
  216. if not self.html_empty_tag(doc, html):
  217. parts.append('\n')
  218. if strip(doc.text):
  219. parts.append(' '*indent)
  220. parts.append(self.format_text(doc.text))
  221. parts.append('\n')
  222. for el in doc:
  223. parts.append(self.format_doc(el, html, indent+2))
  224. parts.append(' '*indent)
  225. parts.append(self.format_end_tag(doc))
  226. parts.append('\n')
  227. if strip(doc.tail):
  228. parts.append(' '*indent)
  229. parts.append(self.format_text(doc.tail))
  230. parts.append('\n')
  231. return ''.join(parts)
  232. def format_text(self, text, strip=True):
  233. if text is None:
  234. return ''
  235. if strip:
  236. text = text.strip()
  237. return html_escape(text, 1)
  238. def format_tag(self, el):
  239. attrs = []
  240. if isinstance(el, etree.CommentBase):
  241. # FIXME: probably PIs should be handled specially too?
  242. return '<!--'
  243. for name, value in sorted(el.attrib.items()):
  244. attrs.append('%s="%s"' % (name, self.format_text(value, False)))
  245. if not attrs:
  246. return '<%s>' % el.tag
  247. return '<%s %s>' % (el.tag, ' '.join(attrs))
  248. def format_end_tag(self, el):
  249. if isinstance(el, etree.CommentBase):
  250. # FIXME: probably PIs should be handled specially too?
  251. return '-->'
  252. return '</%s>' % el.tag
  253. def collect_diff(self, want, got, html, indent):
  254. parts = []
  255. if not len(want) and not len(got):
  256. parts.append(' '*indent)
  257. parts.append(self.collect_diff_tag(want, got))
  258. if not self.html_empty_tag(got, html):
  259. parts.append(self.collect_diff_text(want.text, got.text))
  260. parts.append(self.collect_diff_end_tag(want, got))
  261. parts.append(self.collect_diff_text(want.tail, got.tail))
  262. parts.append('\n')
  263. return ''.join(parts)
  264. parts.append(' '*indent)
  265. parts.append(self.collect_diff_tag(want, got))
  266. parts.append('\n')
  267. if strip(want.text) or strip(got.text):
  268. parts.append(' '*indent)
  269. parts.append(self.collect_diff_text(want.text, got.text))
  270. parts.append('\n')
  271. want_children = list(want)
  272. got_children = list(got)
  273. while want_children or got_children:
  274. if not want_children:
  275. parts.append(self.format_doc(got_children.pop(0), html, indent+2, '+'))
  276. continue
  277. if not got_children:
  278. parts.append(self.format_doc(want_children.pop(0), html, indent+2, '-'))
  279. continue
  280. parts.append(self.collect_diff(
  281. want_children.pop(0), got_children.pop(0), html, indent+2))
  282. parts.append(' '*indent)
  283. parts.append(self.collect_diff_end_tag(want, got))
  284. parts.append('\n')
  285. if strip(want.tail) or strip(got.tail):
  286. parts.append(' '*indent)
  287. parts.append(self.collect_diff_text(want.tail, got.tail))
  288. parts.append('\n')
  289. return ''.join(parts)
  290. def collect_diff_tag(self, want, got):
  291. if not self.tag_compare(want.tag, got.tag):
  292. tag = '%s (got: %s)' % (want.tag, got.tag)
  293. else:
  294. tag = got.tag
  295. attrs = []
  296. any = want.tag == 'any' or 'any' in want.attrib
  297. for name, value in sorted(got.attrib.items()):
  298. if name not in want.attrib and not any:
  299. attrs.append('+%s="%s"' % (name, self.format_text(value, False)))
  300. else:
  301. if name in want.attrib:
  302. text = self.collect_diff_text(want.attrib[name], value, False)
  303. else:
  304. text = self.format_text(value, False)
  305. attrs.append('%s="%s"' % (name, text))
  306. if not any:
  307. for name, value in sorted(want.attrib.items()):
  308. if name in got.attrib:
  309. continue
  310. attrs.append('-%s="%s"' % (name, self.format_text(value, False)))
  311. if attrs:
  312. tag = '<%s %s>' % (tag, ' '.join(attrs))
  313. else:
  314. tag = '<%s>' % tag
  315. return tag
  316. def collect_diff_end_tag(self, want, got):
  317. if want.tag != got.tag:
  318. tag = '%s (got: %s)' % (want.tag, got.tag)
  319. else:
  320. tag = got.tag
  321. return '</%s>' % tag
  322. def collect_diff_text(self, want, got, strip=True):
  323. if self.text_compare(want, got, strip):
  324. if not got:
  325. return ''
  326. return self.format_text(got, strip)
  327. text = '%s (got: %s)' % (want, got)
  328. return self.format_text(text, strip)
  329. class LHTMLOutputChecker(LXMLOutputChecker):
  330. def get_default_parser(self):
  331. return html_fromstring
  332. def install(html=False):
  333. """
  334. Install doctestcompare for all future doctests.
  335. If html is true, then by default the HTML parser will be used;
  336. otherwise the XML parser is used.
  337. """
  338. if html:
  339. doctest.OutputChecker = LHTMLOutputChecker
  340. else:
  341. doctest.OutputChecker = LXMLOutputChecker
def temp_install(html=False, del_module=None):
    """
    Use this *inside* a doctest to enable this checker for this
    doctest only.

    If html is true, then by default the HTML parser will be used;
    otherwise the XML parser is used.

    :param html: use LHTMLOutputChecker instead of LXMLOutputChecker.
    :param del_module: optional ``sys.modules`` key.  Once the doctest's
        outcome is recorded, that module is removed from ``sys.modules``
        and detached from its parent package (see
        ``_RestoreChecker.uninstall_module``).
    :raises LookupError: if not called from within a running doctest
        (raised by ``_find_doctest_frame``).
    """
    if html:
        Checker = LHTMLOutputChecker
    else:
        Checker = LXMLOutputChecker
    # Find the doctest frame and the runner object bound to ``self`` in it.
    frame = _find_doctest_frame()
    dt_self = frame.f_locals['self']
    checker = Checker()
    old_checker = dt_self._checker
    # Replace the runner's checker object.  NOTE(review): _RestoreChecker
    # does not assign old_checker back to dt_self._checker afterwards.
    dt_self._checker = checker
    # The unfortunate thing is that there is a local variable 'check'
    # in the function that runs the doctests, that is a bound method
    # into the output checker. We have to update that. We can't
    # modify the frame, so we have to modify the object in place. The
    # only way to do this is to actually change the func_code
    # attribute of the method. We change it, and then wait for
    # __record_outcome to be run, which signals the end of the __run
    # method, at which point we restore the previous check_output
    # implementation.
    if _IS_PYTHON_3:
        check_func = frame.f_locals['check'].__func__
        checker_check_func = checker.check_output.__func__
    else:
        check_func = frame.f_locals['check'].im_func
        checker_check_func = checker.check_output.im_func
    # Because we can't patch up func_globals, this is the only global
    # in check_output that we care about:
    doctest.etree = etree
    # The constructor applies the patches.  The instance stays reachable
    # because it installs itself as the runner's __record_outcome.
    _RestoreChecker(dt_self, old_checker, checker,
                    check_func, checker_check_func,
                    del_module)
class _RestoreChecker(object):
    """Apply, and later undo, the patches made by ``temp_install``.

    Construction does three things:

    (1) sets ``_temp_call_super_check_output`` and ``_temp_override_self``
        on the runner's old checker;
    (2) replaces the code of ``check_func`` with ``clone_func``'s code;
    (3) installs this object as the runner's
        ``_DocTestRunner__record_outcome``.

    Calling the object reverts all three, forwards to the real
    ``__record_outcome`` and optionally unloads ``del_module``.
    """

    def __init__(self, dt_self, old_checker, new_checker, check_func, clone_func,
                 del_module):
        # dt_self: the runner object taken from the doctest frame.
        # old_checker / new_checker: the runner's previous checker and the
        #   checker created by temp_install().
        # check_func: function behind the doctest frame's local ``check``.
        # clone_func: new_checker's check_output function.
        # del_module: optional sys.modules key to remove afterwards.
        self.dt_self = dt_self
        self.checker = old_checker
        # LXMLOutputChecker.check_output reads these attributes (via
        # getattr) to detect that it is running as a transplanted clone.
        self.checker._temp_call_super_check_output = self.call_super
        self.checker._temp_override_self = new_checker
        self.check_func = check_func
        self.clone_func = clone_func
        self.del_module = del_module
        self.install_clone()
        self.install_dt_self()

    def install_clone(self):
        # Save check_func's original code (and globals), then give it the
        # clone's code.  NOTE(review): func_globals is saved but never read
        # in this class.
        if _IS_PYTHON_3:
            self.func_code = self.check_func.__code__
            self.func_globals = self.check_func.__globals__
            self.check_func.__code__ = self.clone_func.__code__
        else:
            self.func_code = self.check_func.func_code
            self.func_globals = self.check_func.func_globals
            self.check_func.func_code = self.clone_func.func_code

    def uninstall_clone(self):
        # Put check_func's original code back.
        if _IS_PYTHON_3:
            self.check_func.__code__ = self.func_code
        else:
            self.check_func.func_code = self.func_code

    def install_dt_self(self):
        # Shadow the runner's name-mangled __record_outcome with this
        # object, so __call__ runs when the outcome is recorded.
        self.prev_func = self.dt_self._DocTestRunner__record_outcome
        self.dt_self._DocTestRunner__record_outcome = self

    def uninstall_dt_self(self):
        # Restore the runner's own __record_outcome.
        self.dt_self._DocTestRunner__record_outcome = self.prev_func

    def uninstall_module(self):
        # If del_module was given, drop it from sys.modules and remove it
        # as an attribute of its parent package (presumably so a later
        # import executes it afresh).
        if self.del_module:
            import sys
            del sys.modules[self.del_module]
            if '.' in self.del_module:
                package, module = self.del_module.rsplit('.', 1)
                package_mod = sys.modules[package]
                delattr(package_mod, module)

    def __call__(self, *args, **kw):
        # Invoked in place of __record_outcome.  Undo every patch first,
        # then forward to the real method and return its result.
        # NOTE(review): if prev_func raises, uninstall_module() is skipped.
        self.uninstall_clone()
        self.uninstall_dt_self()
        del self.checker._temp_override_self
        del self.checker._temp_call_super_check_output
        result = self.prev_func(*args, **kw)
        self.uninstall_module()
        return result

    def call_super(self, *args, **kw):
        # Call check_func with its original code, restored temporarily.
        # The clone is re-installed even if the call raises.
        self.uninstall_clone()
        try:
            return self.check_func(*args, **kw)
        finally:
            self.install_clone()
  432. def _find_doctest_frame():
  433. import sys
  434. frame = sys._getframe(1)
  435. while frame:
  436. l = frame.f_locals
  437. if 'BOOM' in l:
  438. # Sign of doctest
  439. return frame
  440. frame = frame.f_back
  441. raise LookupError(
  442. "Could not find doctest (only use this function *inside* a doctest)")
  443. __test__ = {
  444. 'basic': '''
  445. >>> temp_install()
  446. >>> print """<xml a="1" b="2">stuff</xml>"""
  447. <xml b="2" a="1">...</xml>
  448. >>> print """<xml xmlns="http://example.com"><tag attr="bar" /></xml>"""
  449. <xml xmlns="...">
  450. <tag attr="..." />
  451. </xml>
  452. >>> print """<xml>blahblahblah<foo /></xml>""" # doctest: +NOPARSE_MARKUP, +ELLIPSIS
  453. <xml>...foo /></xml>
  454. '''}
  455. if __name__ == '__main__':
  456. import doctest
  457. doctest.testmod()