#!/usr/bin/env python
#
# This file is a part of DNSViz, a tool suite for DNS/DNSSEC monitoring,
# analysis, and visualization.
# Created by Casey Deccio (casey@deccio.net)
#
# Copyright 2014-2016 VeriSign, Inc.
#
# Copyright 2016-2017 Casey Deccio.
#
# DNSViz is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# DNSViz is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with DNSViz.  If not, see <http://www.gnu.org/licenses/>.
#

from __future__ import unicode_literals

import datetime
import errno
import socket
import sys
import time

import dns.flags, dns.exception, dns.name, dns.opcode, dns.rdataclass, dns.rdatatype

from dnsviz.ipaddr import IPAddr
from dnsviz import query as Q
from dnsviz import resolver as Resolver
from dnsviz import transport
from dnsviz.util import get_trusted_keys

# Module-level transport manager.  It is passed as ``transport_manager`` to
# every Resolver built in DigCommandLineQuery._get_resolver() and is closed
# exactly once, in the ``finally`` clause of main().
tm = transport.DNSQueryTransportManager()

class CommandLineException(Exception):
    """Raised when a command-line option is unknown or its argument is unusable.

    main() catches this, writes the message to stderr and exits with status 1.
    """

class SemanticException(Exception):
    """Raised when the command line parses but a value in it cannot be used.

    Examples in this file: a server name that does not resolve, an invalid IP
    address, an unknown type/class, or no name servers left to query.
    main() catches this, writes the message to stderr and exits with status 1.
    """

def _get_nameservers_for_name(addr):
    """Resolve *addr* (a host name or address string) to a list of IPAddr.

    The lookup is done with socket.getaddrinfo() for port 53, restricted to
    the TCP protocol.

    Raises SemanticException if the lookup fails with socket.gaierror.
    """
    try:
        results = socket.getaddrinfo(addr, 53, 0, 0, socket.IPPROTO_TCP)
    except socket.gaierror:
        raise SemanticException('Unable to resolve "%s"' % addr)

    # every getaddrinfo() entry is (family, type, proto, canonname, sockaddr);
    # the first member of sockaddr is the textual address
    return [IPAddr(entry[4][0]) for entry in results]

class DigCommandLineQuery:
    """One dig-style query: a name, optional type/class, and its options.

    An instance collects the server names (``@server``) and ``+option``
    strings that followed the query name on the command line.  After
    process_nameservers() and process_query_options() have been called, the
    query can be issued with query() and rendered with display().
    """

    def __init__(self, qname, rdtype, rdclass):
        """Create a query for *qname*.

        *rdtype* and *rdclass* may be None, in which case the values from the
        global options dictionary are used at query time (see _get_rdtype()
        and _get_rdclass()).
        """
        self.qname = qname
        self.rdtype = rdtype
        self.rdclass = rdclass

        # raw, unprocessed per-query input gathered by the argument parser
        self.nameservers = []
        self.query_options = []

        # default query options
        self.handlers = [Q.UseTCPOnTCFlagHandler()]
        self.flags = dns.flags.RD | dns.flags.AD
        self.edns = 0
        self.edns_max_udp_payload = 4096
        self.edns_flags = 0
        self.edns_options = []
        self.tcp = False
        self.query_timeout = 5.0
        self.max_attempts = 3
        self.lifetime = None

        # default display options
        self.show_additional = True
        self.show_answer = True
        self.show_authority = True
        self.show_cmd = True
        self.show_comments = True
        self.show_question = True
        self.show_stats = True
        self.show_class = True
        self.show_ip_port = False
        self.multiline = False
        self.show_rr_comments = False
        self.short = False
        self.trusted_keys = ()
        self.show_ttl = True
        self.lg_url = None

    @staticmethod
    def _option_value(arg):
        """Return the text following the first '=' in the option *arg*.

        Only the first '=' separates the option name from its value, so the
        value itself may contain '=' (e.g. a URL with a query string).

        Raises ValueError if *arg* contains no '='.
        """
        _name, sep, value = arg.partition('=')
        if not sep:
            raise ValueError()
        return value

    def process_query_options(self, global_options):
        """Apply the ``+option`` strings to this query.

        The options in *global_options* are applied first, followed by this
        query's own ``query_options``, so a later option overrides an earlier
        one.

        Raises CommandLineException for an unrecognized option or an unusable
        option argument, and SemanticException if a trusted keys file cannot
        be parsed.
        """
        for arg in global_options + self.query_options:
            if arg in ('+aaflag', '+aaonly', '+aa'):
                self.flags |= dns.flags.AA
            elif arg in ('+noaaflag', '+noaaonly', '+noaa'):
                self.flags &= ~dns.flags.AA
            elif arg == '+additional':
                self.show_additional = True
            elif arg == '+noadditional':
                self.show_additional = False
            elif arg in ('+adflag', '+ad'):
                self.flags |= dns.flags.AD
            elif arg in ('+noadflag', '+noad'):
                self.flags &= ~dns.flags.AD
            elif arg == '+all':
                self.show_additional = True
                self.show_answer = True
                self.show_authority = True
                self.show_cmd = True
                self.show_comments = True
                self.show_question = True
                self.show_stats = True
            elif arg == '+noall':
                self.show_additional = False
                self.show_answer = False
                self.show_authority = False
                self.show_cmd = False
                self.show_comments = False
                self.show_question = False
                self.show_stats = False
            elif arg == '+answer':
                self.show_answer = True
            elif arg == '+noanswer':
                self.show_answer = False
            elif arg == '+authority':
                self.show_authority = True
            elif arg == '+noauthority':
                self.show_authority = False
            #TODO +[no]besteffort
            elif arg.startswith('+bufsize') and \
                    (len(arg) <= 8 or arg[8] == '='):
                # setting a buffer size implies that EDNS is in use
                if self.edns < 0:
                    self.edns = 0
                try:
                    self.edns_max_udp_payload = int(self._option_value(arg))
                    if self.edns_max_udp_payload < 0 or self.edns_max_udp_payload > 65535:
                        raise ValueError()
                except ValueError:
                    raise CommandLineException('+bufsize requires an integer argument between 0 and 65535')
            elif arg in ('+cdflag', '+cd'):
                self.flags |= dns.flags.CD
            elif arg in ('+nocdflag', '+nocd'):
                self.flags &= ~dns.flags.CD
            elif arg == '+cl':
                self.show_class = True
            elif arg == '+nocl':
                self.show_class = False
            elif arg == '+cmd':
                self.show_cmd = True
            elif arg == '+nocmd':
                self.show_cmd = False
            elif arg == '+comments':
                self.show_comments = True
            elif arg == '+nocomments':
                self.show_comments = False
            #TODO +[no]crypto
            #TODO +[no]defname
            elif arg == '+dnssec':
                # the DO bit lives in the EDNS flags, so EDNS must be enabled
                if self.edns < 0:
                    self.edns = 0
                self.edns_flags |= dns.flags.DO
            elif arg == '+nodnssec':
                self.edns_flags &= ~dns.flags.DO
            #TODO +domain=somename
            elif arg.startswith('+edns') and \
                    (len(arg) <= 5 or arg[5] == '='):
                try:
                    version = int(self._option_value(arg))
                    # a negative version means "EDNS disabled" internally;
                    # that must be requested with +noedns, not +edns=-1
                    if version < 0:
                        raise ValueError()
                except ValueError:
                    raise CommandLineException('+edns requires an integer argument greater than or equal to 0')
                self.edns = version
            elif arg == '+noedns':
                self.edns = -1
            #TODO +[no]expire
            #TODO +[no]fail
            elif arg == '+identify':
                self.show_ip_port = True
            elif arg == '+noidentify':
                self.show_ip_port = False
            elif arg == '+ignore':
                # no response handlers: a truncated response is not retried
                # over TCP
                self.handlers = []
            elif arg == '+noignore':
                self.handlers = [Q.UseTCPOnTCFlagHandler()]
            #TODO +[no]keepopen
            elif arg == '+multiline':
                self.multiline = True
            elif arg == '+nomultiline':
                self.multiline = False
            #TODO +ndots=D
            #TODO +[no]nsid
            #TODO +[no]nssearch
            #TODO +[no]onesoa
            #TODO +[no]qr
            elif arg == '+question':
                self.show_question = True
            elif arg == '+noquestion':
                self.show_question = False
            elif arg in ('+recurse', '+rec'):
                self.flags |= dns.flags.RD
            elif arg in ('+norecurse', '+norec'):
                self.flags &= ~dns.flags.RD
            elif arg.startswith('+retry') and \
                    (len(arg) <= 6 or arg[6] == '='):
                try:
                    # retries are in addition to the initial attempt
                    self.max_attempts = int(self._option_value(arg)) + 1
                    if self.max_attempts < 1:
                        self.max_attempts = 1
                except ValueError:
                    raise CommandLineException('+retry requires an integer argument')
            elif arg == '+rrcomments':
                self.show_rr_comments = True
            elif arg == '+norrcomments':
                self.show_rr_comments = False
            #TODO +[no]search
            elif arg == '+short':
                self.short = True
            elif arg == '+noshort':
                self.short = False
            #TODO +[no]showsearch
            #TODO +[no]sigchase
            #TODO +[no]sit[=####]
            #TODO +split=W
            elif arg == '+stats':
                self.show_stats = True
            elif arg == '+nostats':
                self.show_stats = False
            #TODO +[no]subnet=addr/prefix
            elif arg in ('+tcp', '+vc'):
                self.tcp = True
            elif arg in ('+notcp', '+novc'):
                self.tcp = False
            elif arg.startswith('+timeout') and \
                    (len(arg) <= 8 or arg[8] == '='):
                try:
                    self.query_timeout = float(self._option_value(arg))
                    # timeouts below one second are raised to one second
                    if self.query_timeout < 1.0:
                        self.query_timeout = 1.0
                except ValueError:
                    raise CommandLineException('+timeout requires a numerical argument')
            #TODO +[no]topdown
            #TODO +[no]trace
            elif arg.startswith('+tries') and \
                    (len(arg) <= 6 or arg[6] == '='):
                try:
                    self.max_attempts = int(self._option_value(arg))
                    if self.max_attempts < 1:
                        self.max_attempts = 1
                except ValueError:
                    raise CommandLineException('+tries requires an integer argument')
            elif arg.startswith('+trusted-key') and \
                    (len(arg) <= 12 or arg[12] == '='):
                try:
                    filename = self._option_value(arg)
                    if not filename:
                        raise ValueError()
                except ValueError:
                    raise CommandLineException('+trusted-key requires a filename argument.')

                try:
                    # the context manager closes the file even if read() fails
                    with open(filename) as fh:
                        tk_str = fh.read()
                except IOError as e:
                    raise CommandLineException('%s: "%s"' % (e.strerror, filename))
                try:
                    self.trusted_keys = get_trusted_keys(tk_str)
                except dns.exception.DNSException:
                    raise SemanticException('There was an error parsing the trusted keys file: "%s"' % filename)
            elif arg in ('+ttlid', '+ttl'):
                self.show_ttl = True
            elif arg in ('+nottlid', '+nottl'):
                self.show_ttl = False
            elif arg.startswith('+lg') and \
                    (len(arg) <= 3 or arg[3] == '='):
                try:
                    url = self._option_value(arg)
                    if not url:
                        raise ValueError()
                except ValueError:
                    raise CommandLineException('+lg requires a URL argument.')
                self.lg_url = url
            else:
                raise CommandLineException('Option "%s" not recognized.' % arg)

    def process_nameservers(self, nameservers, use_ipv4, use_ipv6):
        """Resolve this query's own servers and combine them with *nameservers*.

        The per-query server names are resolved to addresses and filtered by
        the *use_ipv4* / *use_ipv6* switches; the result is appended to the
        already-processed global list *nameservers* and stored in
        ``self.nameservers``.

        Raises SemanticException if a server name cannot be resolved.
        """
        processed_nameservers = []
        for addr in self.nameservers:
            processed_nameservers.extend(_get_nameservers_for_name(addr))

        if not use_ipv4:
            processed_nameservers = [x for x in processed_nameservers if x.version != 4]
        if not use_ipv6:
            processed_nameservers = [x for x in processed_nameservers if x.version != 6]

        self.nameservers = nameservers + processed_nameservers

    def _get_resolver(self, options):
        """Build a Resolver configured from this query and the global *options*.

        *options* must supply the 'client_ipv4', 'client_ipv6' and 'port'
        keys.
        """
        # A query factory subclass carrying this query's settings as class
        # attributes.
        class CustomQuery(Q.DNSQueryFactory):
            flags = self.flags
            edns = self.edns
            edns_max_udp_payload = self.edns_max_udp_payload
            edns_flags = self.edns_flags
            edns_options = self.edns_options
            tcp = self.tcp
            response_handlers = self.handlers

        # when +lg=URL was given, queries go through an HTTP transport
        # handler for that URL
        if self.lg_url is not None:
            th_factories = (transport.DNSQueryTransportHandlerHTTPFactory(self.lg_url),)
        else:
            th_factories = None

        return Resolver.Resolver(self.nameservers, CustomQuery,
                timeout=self.query_timeout, max_attempts=self.max_attempts,
                lifetime=self.lifetime, shuffle=False, client_ipv4=options['client_ipv4'],
                client_ipv6=options['client_ipv6'], port=options['port'], transport_manager=tm,
                th_factories=th_factories)

    def _get_name(self):
        """Return the query name as a dns.name.Name."""
        #TODO qualify name, if necessary
        #TODO check name syntax, etc.
        return dns.name.from_text(self.qname)

    def _get_rdtype(self, options):
        """Return this query's type, falling back to ``options['rdtype']``."""
        if self.rdtype is None:
            return options['rdtype']
        return self.rdtype

    def _get_rdclass(self, options):
        """Return this query's class, falling back to ``options['rdclass']``."""
        if self.rdclass is None:
            return options['rdclass']
        return self.rdclass

    def query(self, options):
        """Issue the query and return whatever the resolver's query() returns.

        query_and_display() unpacks the result as ``(server, response)``.
        """
        res = self._get_resolver(options)
        qname = self._get_name()
        rdtype = self._get_rdtype(options)
        rdclass = self._get_rdclass(options)
        return res.query(qname, rdtype, rdclass)

    def display(self, response, server, options):
        """Return the dig-like text rendering of *response*.

        *server* is the server that produced the response and *options* is
        the global options dictionary (only 'port' is read).  Every returned
        string ends with a newline.
        """
        if response is None:
            return ';; no servers were queried\n'

        if response.message is None:
            # no parsable message: either nothing came back, or it was garbage
            if response.error in (Q.RESPONSE_ERROR_TIMEOUT, Q.RESPONSE_ERROR_NETWORK_ERROR):
                return ';; connection timed out; no servers could be reached\n'
            return ';; the response from %s was malformed\n' % server

        message = response.message

        if self.short:
            # +short: only the answer rdata, optionally tagged with the server
            if self.show_ip_port:
                identity = ' from server %s in %d ms.' % (server, int(response.response_time*1000))
            else:
                identity = ''
            lines = []
            for rrset in message.answer:
                for rr in rrset:
                    lines.append('%s%s\n' % (rr.to_text(), identity))
            return ''.join(lines)

        # get counts
        if message.question:
            question_ct = 1
        else:
            question_ct = 0
        answer_ct = sum(len(rrset) for rrset in message.answer)
        authority_ct = sum(len(rrset) for rrset in message.authority)
        additional_ct = sum(len(rrset) for rrset in message.additional)
        if message.edns >= 0:
            # the OPT pseudo-record counts as one additional record
            additional_ct += 1

        #TODO show_cmd, multiline, show_rr_comments

        s = ''
        if self.show_comments:
            s += ';; Got answer:\n'
            s += ';; ->>HEADER<<- opcode: %s, status: %s, id: %d\n' % (dns.opcode.to_text(message.opcode()), dns.rcode.to_text(message.rcode()), message.id)
            s += ';; flags: %s; QUERY: %d, ANSWER: %d, AUTHORITY: %d, ADDITIONAL: %d\n' % (dns.flags.to_text(message.flags).lower(), question_ct, answer_ct, authority_ct, additional_ct)
            if (self.flags & dns.flags.RD) and not (message.flags & dns.flags.RA):
                s += ';; WARNING: recursion requested but not available\n'
            s += '\n'
            if message.edns >= 0:
                s += ';; OPT PSEUDOSECTION:\n'
                s += '; EDNS: version: %d, flags: %s; udp: %d\n' % (message.edns, dns.flags.edns_to_text(message.ednsflags).lower(), message.payload)

        if message.question and self.show_question:
            question = message.question[0]
            if self.show_comments:
                s += ';; QUESTION SECTION:\n'
            s += ';%s          %s %s\n' % (question.name, dns.rdataclass.to_text(question.rdclass), dns.rdatatype.to_text(question.rdtype))
            if self.show_comments:
                s += '\n'

        for section, title in ((message.answer, 'ANSWER'), (message.authority, 'AUTHORITY'), (message.additional, 'ADDITIONAL')):
            # the matching switch is show_answer / show_authority /
            # show_additional
            if section and getattr(self, 'show_%s' % title.lower()):
                if self.show_comments:
                    s += ';; %s SECTION:\n' % title
                for rrset in section:
                    if self.show_ttl:
                        ttl = '\t%d' % rrset.ttl
                    else:
                        ttl = ''
                    if self.show_class:
                        cls = '\t%s' % dns.rdataclass.to_text(rrset.rdclass)
                    else:
                        cls = ''
                    rdtype_text = dns.rdatatype.to_text(rrset.rdtype)
                    for rr in rrset:
                        s += '%s\t%s%s\t%s\t%s\n' % (rrset.name, ttl, cls, rdtype_text, rr.to_text())
                if self.show_comments:
                    s += '\n'

        if self.show_stats:
            s += ';; Query time: %d msec\n' % int(response.response_time*1000)
            s += ';; SERVER: %s#%d\n' % (server, options['port'])
            # the line is labelled UTC, so format the current UTC time rather
            # than local time
            s += ';; WHEN: %s\n' % time.strftime('%a %b %d %H:%M:%S %Y UTC', time.gmtime())
            s += ';; MSG SIZE  rcvd: %d\n' % response.msg_size

        return s

    def query_and_display(self, options, filehandle):
        """Run the query and write its rendering to *filehandle*.

        A transport.RemoteQueryTransportError is reported on stderr instead
        of being propagated.
        """
        try:
            server, response = self.query(options)
        except transport.RemoteQueryTransportError as e:
            sys.stderr.write('%s\n' % e)
        else:
            output = self.display(response, server, options)
            filehandle.write(output)
            filehandle.flush()

class DigCommandLine:
    """Parser and driver for a complete dig-style command line.

    Construction parses *args*, resolves the name servers and prepares one
    DigCommandLineQuery per requested name; query_and_display() then runs
    them in order.
    """

    def __init__(self, args):
        """Parse *args* (the command line without the program name).

        Raises CommandLineException or SemanticException if the command line
        cannot be used, including when a query ends up with no name servers.
        """
        self.args = args
        self.arg_index = 0

        self.options = {
            'rdtype': None,
            'rdclass': None,
            'use_ipv4': None,
            'use_ipv6': None,
            'client_ipv4': None,
            'client_ipv6': None,
            'port': 53,
        }

        self.nameservers = []
        self.global_query_options = ['+cmd']

        self.queries = []

        self._process_args()
        self._process_network()
        self._process_nameservers()

        # with no name on the command line, ask for the root NS set
        if not self.queries:
            self.queries.append(DigCommandLineQuery('.', dns.rdatatype.NS, dns.rdataclass.IN))

        for q in self.queries:
            q.process_nameservers(self.nameservers, self.options['use_ipv4'], self.options['use_ipv6'])
            q.process_query_options(self.global_query_options)

            if not q.nameservers:
                raise SemanticException('No nameservers to query')

        if self.options['rdtype'] is None:
            self.options['rdtype'] = dns.rdatatype.A
        if self.options['rdclass'] is None:
            self.options['rdclass'] = dns.rdataclass.IN

    def query_and_display(self):
        """Run every query in order, writing the results to sys.stdout."""
        for q in self.queries:
            q.query_and_display(self.options, sys.stdout)

    def _get_arg(self, has_arg):
        """Consume the current ``-X`` option and return its argument.

        The argument is either attached to the option (``-p53``) or is the
        next command-line word (``-p 53``).  When *has_arg* is false the
        option takes no argument and None is returned.  ``arg_index`` is
        always advanced past everything consumed, even on error.

        Raises CommandLineException if an argument is given to an option that
        takes none, or is missing from one that requires it.
        """
        try:
            if len(self.args[self.arg_index]) > 2:
                if not has_arg:
                    raise CommandLineException('"%s" option does not take arguments' % self.args[self.arg_index][:2])
                return self.args[self.arg_index][2:]
            else:
                if not has_arg:
                    return None
                else:
                    self.arg_index += 1
                    if self.arg_index >= len(self.args):
                        raise CommandLineException('"%s" option requires an argument' % self.args[self.arg_index - 1])
                    return self.args[self.arg_index]
        finally:
            self.arg_index += 1

    def _add_server_to_options(self, query):
        """Consume an ``@server`` word.

        The server is recorded globally when no query name has been seen yet
        (*query* is None), otherwise on *query* only.
        """
        addr = self.args[self.arg_index][1:]
        self.arg_index += 1
        if query is None:
            self.nameservers.append(addr)
        else:
            query.nameservers.append(addr)

    def _add_reverse_query(self):
        """Consume ``-x addr`` and return the corresponding PTR query.

        Raises SemanticException if the argument is not a valid IP address.
        """
        arg = self._get_arg(True)
        try:
            addr = IPAddr(arg)
        except ValueError:
            raise SemanticException('Invalid IP address: "%s"' % arg)
        else:
            qname = addr.arpa_name()

        return DigCommandLineQuery(qname, dns.rdatatype.PTR, dns.rdataclass.IN)

    def _add_qname_from_opt(self):
        """Consume ``-q name`` and return a query using the default type/class."""
        qname = self._get_arg(True)
        return DigCommandLineQuery(qname, None, None)

    def _add_default_option(self):
        """Try to consume the current word as the global class or type.

        Returns True (and advances ``arg_index``) if the word was accepted as
        a class or, failing that, a type that has not been set yet; returns
        False otherwise.
        """
        if self.options['rdclass'] is None:
            try:
                self.options['rdclass'] = dns.rdataclass.from_text(self.args[self.arg_index])
            except dns.rdataclass.UnknownRdataclass:
                pass
            else:
                self.arg_index += 1
                return True

        if self.options['rdtype'] is None:
            try:
                self.options['rdtype'] = dns.rdatatype.from_text(self.args[self.arg_index])
            except dns.rdatatype.UnknownRdatatype:
                pass
            else:
                self.arg_index += 1
                return True

        return False

    def _add_qname(self):
        """Consume a query name with its optional type and class words.

        Returns a DigCommandLineQuery; type and/or class are None when the
        following words are absent or not recognized.
        """
        qname = self.args[self.arg_index]
        self.arg_index += 1

        # check for optional type
        try:
            rdtype = dns.rdatatype.from_text(self.args[self.arg_index])
        except (IndexError, dns.rdatatype.UnknownRdatatype):
            # no type detected; use default rdtype
            rdtype = None
        else:
            self.arg_index += 1

        # now check for optional class
        try:
            rdclass = dns.rdataclass.from_text(self.args[self.arg_index])
        except (IndexError, dns.rdataclass.UnknownRdataclass):
            # no class detected; use default rdclass
            rdclass = None
        else:
            self.arg_index += 1

        return DigCommandLineQuery(qname, rdtype, rdclass)

    def _add_option(self):
        """Consume one ``-X`` option (-b, -c, -p, -t, -4, -6).

        Raises CommandLineException for an unknown option or a bad port, and
        SemanticException for an invalid address, class or type.
        """
        if self.args[self.arg_index].startswith('-b'):
            arg = self._get_arg(True)
            try:
                addr = IPAddr(arg)
            except ValueError:
                raise SemanticException('Invalid IP address: "%s"' % arg)

            if addr.version == 6:
                family = socket.AF_INET6
            else:
                family = socket.AF_INET

            # Probe whether the address can actually be bound locally.  The
            # probe socket is always closed explicitly rather than left to
            # garbage collection.
            s = None
            try:
                s = socket.socket(family)
                s.bind((addr, 0))
            except socket.error as e:
                if e.errno == errno.EADDRNOTAVAIL:
                    raise SemanticException('Cannot bind to specified IP address: "%s"' % addr)
                # NOTE(review): any other socket error is ignored and the
                # client address is left unset -- confirm this is intended.
            else:
                if addr.version == 6:
                    self.options['client_ipv6'] = addr
                else:
                    self.options['client_ipv4'] = addr
            finally:
                if s is not None:
                    s.close()
        elif self.args[self.arg_index].startswith('-c'):
            arg = self._get_arg(True)
            try:
                self.options['rdclass'] = dns.rdataclass.from_text(arg)
            except dns.rdataclass.UnknownRdataclass:
                raise SemanticException('Unknown class: "%s".' % arg)
        #TODO -f
        #TODO -k
        #TODO -m
        elif self.args[self.arg_index].startswith('-p'):
            arg = self._get_arg(True)
            try:
                self.options['port'] = int(arg)
                if self.options['port'] < 0 or self.options['port'] > 65535:
                    raise ValueError()
            except ValueError:
                raise CommandLineException('-p requires an integer argument between 0 and 65535')
        #TODO -v
        elif self.args[self.arg_index].startswith('-t'):
            arg = self._get_arg(True)
            try:
                self.options['rdtype'] = dns.rdatatype.from_text(arg)
            except dns.rdatatype.UnknownRdatatype:
                raise SemanticException('Unknown type: "%s".' % arg)
        #TODO -y
        elif self.args[self.arg_index].startswith('-6'):
            self._get_arg(False)
            self.options['use_ipv6'] = True
        elif self.args[self.arg_index].startswith('-4'):
            self._get_arg(False)
            self.options['use_ipv4'] = True
        else:
            raise CommandLineException('Option "%s" not recognized.' % self.args[self.arg_index][:2])

    def _add_query_option(self, query):
        """Consume a ``+option`` word.

        The option is recorded globally when no query name has been seen yet
        (*query* is None), otherwise on *query* only.
        """
        if query is None:
            self.global_query_options.append(self.args[self.arg_index])
        else:
            query.query_options.append(self.args[self.arg_index])
        self.arg_index += 1

    def _process_args(self):
        """Walk the command line, dispatching each word to its handler.

        Words seen before the first query name apply globally; ``@server``
        and ``+option`` words after a name apply to the most recent query.
        """
        query = None
        while self.arg_index < len(self.args):
            # startswith() is used instead of indexing [0] so that an empty
            # argument falls through to the name handling instead of raising
            # IndexError
            current = self.args[self.arg_index]

            # server address
            if current.startswith('@'):
                self._add_server_to_options(query)

            # reverse lookup
            elif current.startswith('-x'):
                query = self._add_reverse_query()
                self.queries.append(query)

            # forward lookup (with -q)
            elif current.startswith('-q'):
                query = self._add_qname_from_opt()
                self.queries.append(query)

            # options
            elif current.startswith('-'):
                self._add_option()

            # query options
            elif current.startswith('+'):
                self._add_query_option(query)

            # global query class/type
            elif query is None and self._add_default_option():
                pass

            # name to be queried
            else:
                query = self._add_qname()
                self.queries.append(query)

    def _process_network(self):
        """Normalize the 'use_ipv4' / 'use_ipv6' options to booleans.

        With neither -4 nor -6 given, both families are enabled; otherwise
        only the requested one(s) are.
        """
        if self.options['use_ipv4'] is None and self.options['use_ipv6'] is None:
            self.options['use_ipv4'] = True
            self.options['use_ipv6'] = True
        if not self.options['use_ipv4']:
            self.options['use_ipv4'] = False
        if not self.options['use_ipv6']:
            self.options['use_ipv6'] = False

    def _process_nameservers(self):
        """Replace the global server names with usable server addresses.

        Without any global ``@server`` the servers of the standard resolver
        are used.  Addresses of a disabled IP family are dropped.

        Raises SemanticException if a server name cannot be resolved.
        """
        if not self.nameservers:
            processed_nameservers = Resolver.get_standard_resolver()._servers
        else:
            processed_nameservers = []
            for addr in self.nameservers:
                processed_nameservers.extend(_get_nameservers_for_name(addr))

        if not self.options['use_ipv4']:
            processed_nameservers = [x for x in processed_nameservers if x.version != 4]
        if not self.options['use_ipv6']:
            processed_nameservers = [x for x in processed_nameservers if x.version != 6]

        self.nameservers = processed_nameservers

def main():
    """Entry point: run the queries described by sys.argv.

    Command-line and semantic errors are printed to stderr and end the
    program with exit status 1; a keyboard interrupt ends it quietly.  The
    module-level transport manager is closed in every case.
    """
    try:
        try:
            cmdline = DigCommandLine(sys.argv[1:])
            cmdline.query_and_display()
        except KeyboardInterrupt:
            pass
        except (CommandLineException, SemanticException) as err:
            sys.stderr.write('%s\n' % err)
            sys.exit(1)
    finally:
        # explicitly close tm here
        tm.close()

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
