From 7952e311d47f6b34f5dc050681001fd34ab60d01 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 16 Feb 2024 07:14:49 -0800 Subject: [PATCH] test IgnoreErrors (cherry picked from commit ac6763f1018458835201b38cae848e4d261f3e5c) --- tests/test_query.py | 140 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/tests/test_query.py b/tests/test_query.py index 1116b2d12..a47daa459 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -15,6 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +import contextlib import socket import sys import time @@ -32,6 +33,7 @@ import dns.message import dns.name import dns.query +import dns.rcode import dns.rdataclass import dns.rdatatype import dns.tsigkeyring @@ -659,3 +661,141 @@ def test_matches_destination(self): dns.query._matches_destination( socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False ) + + +@contextlib.contextmanager +def mock_udp_recv(wire1, from1, wire2, from2): + saved = dns.query._udp_recv + first_time = True + + def mock(sock, max_size, expiration): + nonlocal first_time + if first_time: + first_time = False + return wire1, from1 + else: + return wire2, from2 + + try: + dns.query._udp_recv = mock + yield None + finally: + dns.query._udp_recv = saved + + +class IgnoreErrors(unittest.TestCase): + def setUp(self): + self.q = dns.message.make_query("example.", "A") + self.good_r = dns.message.make_response(self.q) + self.good_r.set_rcode(dns.rcode.NXDOMAIN) + self.good_r_wire = self.good_r.to_wire() + + def mock_receive( + self, + wire1, + from1, + wire2, + from2, + ignore_unexpected=True, + ignore_errors=True, + ): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + with mock_udp_recv(wire1, from1, wire2, from2): + (r, when) = dns.query.receive_udp( + s, + ("127.0.0.1", 53), + time.time() + 2, + ignore_unexpected=ignore_unexpected, + ignore_errors=ignore_errors, + query=self.q, + ) + self.assertEqual(r, self.good_r) + finally: + s.close() + + def test_good_mock(self): + self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None) + + def test_bad_address(self): + self.mock_receive( + self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + def test_bad_address_not_ignored(self): + def bad(): + self.mock_receive( + self.good_r_wire, + ("127.0.0.2", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_unexpected=False, + ) + + self.assertRaises(dns.query.UnexpectedSource, bad) + + def test_bad_id(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + self.mock_receive( + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + def test_bad_id_not_ignored(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + + def bad(): + (r, wire) = self.mock_receive( + bad_r_wire, + ("127.0.0.1", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_errors=False, + ) + + self.assertRaises(AssertionError, bad) + + def test_bad_wire(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + self.mock_receive( + bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + def test_bad_wire_not_ignored(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + + def bad(): + self.mock_receive( + bad_r_wire[:10], + ("127.0.0.1", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_errors=False, + ) + + self.assertRaises(dns.message.ShortHeader, bad) + + def test_trailing_wire(self): + wire = self.good_r_wire + b"abcd" + self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)) + + def test_trailing_wire_not_ignored(self): + wire = self.good_r_wire + b"abcd" + + def bad(): + self.mock_receive( + wire, + ("127.0.0.1", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_errors=False, + ) + + self.assertRaises(dns.message.TrailingJunk, bad)