Skip to content

Instantly share code, notes, and snippets.

@samuelcolvin
Last active November 4, 2024 02:17
  • Select an option

Select an option

Revisions

  1. samuelcolvin revised this gist Feb 26, 2017. No changes.
  2. samuelcolvin created this gist Feb 26, 2017.
    129 changes: 129 additions & 0 deletions dns_server.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,129 @@
    from datetime import datetime
    from time import sleep

    from dnslib import DNSLabel, QTYPE, RD, RR
    from dnslib import A, AAAA, CNAME, MX, NS, SOA, TXT
    from dnslib.server import DNSServer

    EPOCH = datetime(1970, 1, 1)
    SERIAL = int((datetime.utcnow() - EPOCH).total_seconds())

    TYPE_LOOKUP = {
    A: QTYPE.A,
    AAAA: QTYPE.AAAA,
    CNAME: QTYPE.CNAME,
    MX: QTYPE.MX,
    NS: QTYPE.NS,
    SOA: QTYPE.SOA,
    TXT: QTYPE.TXT,
    }


    class Record:
    def __init__(self, rdata_type, *args, rtype=None, rname=None, ttl=None, **kwargs):
    if isinstance(rdata_type, RD):
    # actually an instance, not a type
    self._rtype = TYPE_LOOKUP[rdata_type.__class__]
    rdata = rdata_type
    else:
    self._rtype = TYPE_LOOKUP[rdata_type]
    if rdata_type == SOA and len(args) == 2:
    # add sensible times to SOA
    args += ((
    SERIAL, # serial number
    60 * 60 * 1, # refresh
    60 * 60 * 3, # retry
    60 * 60 * 24, # expire
    60 * 60 * 1, # minimum
    ),)
    rdata = rdata_type(*args)

    if rtype:
    self._rtype = rtype
    self._rname = rname
    self.kwargs = dict(
    rdata=rdata,
    ttl=self.sensible_ttl() if ttl is None else ttl,
    **kwargs,
    )

    def try_rr(self, q):
    if q.qtype == QTYPE.ANY or q.qtype == self._rtype:
    return self.as_rr(q.qname)

    def as_rr(self, alt_rname):
    return RR(rname=self._rname or alt_rname, rtype=self._rtype, **self.kwargs)

    def sensible_ttl(self):
    if self._rtype in (QTYPE.NS, QTYPE.SOA):
    return 60 * 60 * 24
    else:
    return 300

    @property
    def is_soa(self):
    return self._rtype == QTYPE.SOA

    def __str__(self):
    return '{} {}'.format(QTYPE[self._rtype], self.kwargs)


    ZONES = {
    'example.com': [
    Record(A, '1.2.3.4'),
    Record(CNAME, 'whever.com'),
    Record(MX, 'whatever.com.', 5),
    Record(MX, 'mx2.whatever.com.', 10),
    Record(MX, 'mx3.whatever.com.', 20),
    Record(NS, 'mx2.whatever.com.'),
    Record(NS, 'mx3.whatever.com.'),
    Record(TXT, 'hello this is some text'),
    Record(SOA, 'ns1.example.com', 'dns.example.com'),
    ]
    }


    class Resolver:
    def __init__(self):
    self.zones = {DNSLabel(k): v for k, v in ZONES.items()}

    def resolve(self, request, handler):
    reply = request.reply()
    zone = self.zones.get(request.q.qname)
    if zone is not None:
    for zone_records in zone:
    rr = zone_records.try_rr(request.q)
    rr and reply.add_answer(rr)
    else:
    # no direct zone so look for an SOA record for a higher level zone
    for zone_label, zone_records in self.zones.items():
    if request.q.qname.matchSuffix(zone_label):
    try:
    soa_record = next(r for r in zone_records if r.is_soa)
    except StopIteration:
    continue
    else:
    reply.add_answer(soa_record.as_rr(zone_label))
    break

    return reply


    resolver = Resolver()
    servers = [
    DNSServer(resolver, port=5053, address='localhost', tcp=True),
    DNSServer(resolver, port=5053, address='localhost', tcp=False),
    ]

    if __name__ == '__main__':
    for s in servers:
    s.start_thread()

    try:
    while 1:
    sleep(0.1)
    except KeyboardInterrupt:
    pass
    finally:
    for s in servers:
    s.stop()