#!/usr/bin/env python3
# -*- coding: utf-8 -*-

""" Python code to start a RIPE Atlas UDM (User-Defined
Measurement). This one is for running DNS to resolve a name from many
places, in order to survey local cache poisonings, effect of
hijackings and other DNS rejuvenation effects.

You'll need an API key in ~/.atlas/auth.

After launching the measurement, it downloads the results and analyzes
them.

Stéphane Bortzmeyer <stephane+frama@bortzmeyer.org>
"""

import json
import sys
import time
import base64
import re
import copy
import collections

# DNS Python http://www.dnspython.org/
import dns.message

import Blaeu

config = Blaeu.Config()

# Default values of the options specific to this program. They can be
# overridden from the command line (see specificParse).

# The question
config.qtype = "AAAA"
config.qclass = "IN"
config.recursive = True
config.probe_id = False
config.nameserver = None

# The transport and the EDNS/DNSSEC knobs
config.protocol = "UDP"
config.tls = False
config.edns_size = None
config.nsid = False
config.dnssec = False
config.dnssec_checking = True

# The section of the response which is analyzed
config.answer_section = True
config.authority_section = False
config.additional_section = False

# The way results are counted and displayed
config.only_one_per_probe = True
config.sort = False
config.display_resolvers = False
config.display_rtt = False
config.display_validation = False

# Local values
# EDNS payload size really used for the measurement (None: none requested)
edns_size = None

# Constants
# Maximum length of a displayed resource record
MAXLEN = 80

class Set():
    """Counters attached to one distinct set of answers.

    total counts every occurrence of the set, successes counts the
    occurrences which came without an error, and rtt accumulates the
    response times (it is divided by successes to display an average).
    """

    def __init__(self):
        # Every counter starts from zero
        self.total = self.successes = self.rtt = 0

def usage(msg=None):
    """Print the help text on the standard error.

    The generic part is delegated to config.usage(), which receives msg
    (an optional message, None by default). The options specific to
    this program are then listed, with the current default query type
    and class.
    """
    program = sys.argv[0]
    print("Usage: %s domain-name" % program, file=sys.stderr)
    config.usage(msg)
    specific_help = """Also:
    --displayresolvers or -l : display the resolvers IP addresses (WARNING: big lists)
    --norecursive or -Z : asks the resolver to NOT recurse (default is to recurse, note --norecursive works ONLY if asking a specific resolver, not with the default one)
    --dnssec or -D : asks the resolver the DNSSEC records
    --nsid : asks the resolver with NSID (name server identification)
    --ednssize=N or -B N : asks for EDNS with the "payload size" option (default is very old DNS, without EDNS)
    --tcp: uses TCP (default is UDP)
    --tls: uses TLS (implies TCP)
    --checkingdisabled or -k : asks the resolver to NOT perform DNSSEC validation
    --displayvalidation or -j : displays the DNSSEC validation status
    --displayrtt : displays the average RTT
    --authority : displays the Authority section of the answer
    --additional : displays the Additional section of the answer 
    --sort or -S : sort the result sets
    --type or -q : query type (default is %s)
    --class : query class (default is %s)
    --severalperprobe : count all the resolvers of each probe (default is to count only the first to reply)
    --nameserver=name_or_IPaddr[,...] or -x name_or_IPaddr : query this name server (default is to query the probe's resolver)
    --probe_id : prepend probe ID (and timestamp) to the domain name (default is to abstain)
    """ % (config.qtype, config.qclass)
    print(specific_help, file=sys.stderr)

def specificParse(config, option, value):
    """Store one program-specific command-line option into config.

    config is the object receiving the settings, option is the option
    as typed (long or short form) and value is its argument (only used
    by the options which take one). This function is handed to
    config.parse() together with the option lists.

    Returns True when the option was recognized, False otherwise.
    Raises ValueError if the argument of --ednssize is not an integer.
    """
    # Options without an argument: each one maps to the attribute
    # assignments it triggers, in order.
    switches = {
        "--norecursive": (("recursive", False),),
        "-Z": (("recursive", False),),
        "--dnssec": (("dnssec", True),),
        "-D": (("dnssec", True),),
        "--nsid": (("nsid", True),),
        "--probe_id": (("probe_id", True),),
        "--tcp": (("protocol", "TCP"),),
        "--tls": (("tls", True),),
        "--checkingdisabled": (("dnssec_checking", False),),
        "-k": (("dnssec_checking", False),),
        "--sort": (("sort", True),),
        "-S": (("sort", True),),
        "--authority": (("answer_section", False),
                        ("authority_section", True)),
        "--additional": (("answer_section", False),
                         ("additional_section", True)),
        "--displayresolvers": (("display_resolvers", True),),
        "-l": (("display_resolvers", True),),
        "--displayvalidation": (("display_validation", True),),
        "-j": (("display_validation", True),),
        "--displayrtt": (("display_rtt", True),),
        "--severalperprobe": (("only_one_per_probe", False),),
    }
    if option in ("--type", "-q"):
        config.qtype = value
    elif option == "--class":  # For Chaos, use "CHAOS", not "CH"
        config.qclass = value
    elif option in ("--ednssize", "-B"):
        config.edns_size = int(value)
    elif option in ("--nameserver", "-x"):
        config.nameserver = value
        # Several name servers can be given, separated by commas
        config.nameservers = config.nameserver.split(",")
    elif option in switches:
        for attribute, setting in switches[option]:
            setattr(config, attribute, setting)
    else:
        return False
    return True
    
# Parse the command line. config.parse() returns the remaining arguments
# and the measurement request (a dictionary) which is completed below.
args, data = config.parse("q:ZDkSx:ljB:", ["type=", "class=", "ednssize=",
                                     "displayresolvers", "probe_id",
                                     "displayrtt", "displayvalidation",
                                     "dnssec", "nsid", "norecursive",
                                     "authority", "additional",
                                     "tcp", "tls", "checkingdisabled",
                                     "nameserver=", "sort",
                                     "severalperprobe"], specificParse,
                    usage)

# Exactly one argument is expected: the domain name to resolve
if len(args) != 1:
    usage()
    sys.exit(1)

domainname = args[0]

if config.tls:
    config.protocol = "TCP"
    # We don't set the port (853) but Atlas does it for us

# Turn the generic definition into a DNS one
definition = data["definitions"][0]
definition["type"] = "dns"
del definition["size"]
del definition["port"]
definition["query_argument"] = domainname
definition["description"] = ("DNS resolution of %s/%s" % (domainname, config.qtype)) + definition["description"]
definition["query_class"] = config.qclass
definition["query_type"] = config.qtype

# EDNS payload size (only meaningful over UDP). An explicit --ednssize
# always wins. Otherwise DNSSEC gets 4096 and NSID alone gets 1024; when
# both are requested, the larger DNSSEC value is kept instead of being
# overwritten by the NSID one.
if config.protocol == "UDP":
    if config.edns_size is not None:
        edns_size = config.edns_size
    elif config.dnssec or config.display_validation:
        edns_size = 4096
    elif config.nsid:
        edns_size = 1024
if config.dnssec or config.display_validation: # https://atlas.ripe.net/docs/api/v2/reference/#!/measurements/Dns_Type_Measurement_List_POST
    definition["set_do_bit"] = True
if config.nsid:
    definition["set_nsid_bit"] = True
if edns_size is not None and config.protocol == "UDP":
    definition["udp_payload_size"] = edns_size
if not config.dnssec_checking:
    definition["set_cd_bit"] = True
definition["set_rd_bit"] = bool(config.recursive)
if config.tls:
    definition["tls"] = True
if config.probe_id:
    definition["prepend_probe_id"] = True
definition["protocol"] = config.protocol

# Refuse combinations of display options which make no sense together
if config.verbose and config.machine_readable:
    usage("Specify verbose *or* machine-readable output")
    sys.exit(1)
if (config.display_probes or config.display_resolvers or config.display_rtt) and config.machine_readable:
    usage("Display probes/resolvers/RTT *or* machine-readable output")
    sys.exit(1)

# Without --nameserver, a single pass is done, using the resolver(s) of
# each probe (represented by None).
if config.nameserver is None:
    config.nameservers = [None,]

# Resolve the query type before anything is launched: an unknown type
# now stops the program before a measurement is created.
qtype_num = dns.rdatatype.from_text(config.qtype) # Raises dns.rdatatype.UnknownRdatatype if unknown

definition = data["definitions"][0]
# Description without any nameserver suffix, so that each nameserver gets
# its own suffix instead of piling up after the previous ones.
base_description = definition["description"]

# One measurement (or one analysis of an existing measurement) per
# nameserver; None means "the resolver(s) of the probe".
for nameserver in config.nameservers:
    if nameserver is None:
        definition["use_probe_resolver"] = True
        # Exclude probes which do not have at least one working resolver
        data["probes"][0]["tags"]["include"].append("system-resolves-a-correctly")
        data["probes"][0]["tags"]["include"].append("system-resolves-aaaa-correctly")
    else:
        definition["use_probe_resolver"] = False
        definition["target"] = nameserver
        definition["description"] = base_description + (" via nameserver %s" % nameserver)
        if nameserver.find(':') > -1: # TODO: or use is_ip_address(str) from blaeu-reach?
            config.ipv4 = False
            definition['af'] = 6
            if config.include is not None:
                data["probes"][0]["tags"]["include"] = copy.copy(config.include)
                data["probes"][0]["tags"]["include"].append("system-ipv6-works")
            else:
                data["probes"][0]["tags"]["include"] = ["system-ipv6-works",]
        elif re.match("^[0-9.]+$", nameserver):
            config.ipv4 = True
            definition['af'] = 4
            if config.include is not None:
                data["probes"][0]["tags"]["include"] = copy.copy(config.include)
                data["probes"][0]["tags"]["include"].append("system-ipv4-works")
            else:
                data["probes"][0]["tags"]["include"] = ["system-ipv4-works",]
        # Otherwise, it is probably a host name: address family and tags
        # are left as they are.

    # Either start a new measurement and wait for it, or fetch the
    # results of an existing one.
    if config.measurement_id is None:
        if config.verbose:
            print(data)
        measurement = Blaeu.Measurement(data,
                                        lambda delay: sys.stderr.write(
                "Sleeping %i seconds...\n" % delay))

        if not config.machine_readable and config.verbose:
            print("Measurement #%s for %s/%s uses %i probes" % \
            (measurement.id, domainname, config.qtype, measurement.num_probes))

        results = measurement.results(wait=True)
    else:
        measurement = Blaeu.Measurement(data=None, id=config.measurement_id)
        results = measurement.results(wait=False)
        if config.verbose:
            print("%i results from already-done measurement %s" % (len(results), measurement.id))
    if len(results) == 0:
        print("Warning: zero results. Measurement not terminated? May be retry later with --measurement-ID=%s ?" % (measurement.id), file=sys.stderr)
    probes = 0
    successes = 0

    # Each distinct set of answers (as a string) is a key of "sets". The
    # probes and resolvers which produced it are kept in plain lists, so
    # a set without any recorded probe/resolver is displayed as [].
    sets = collections.defaultdict(Set)
    if config.display_probes:
        probes_sets = collections.defaultdict(list)
    if config.display_resolvers:
        resolvers_sets = collections.defaultdict(list)
    for result in results:
        probes += 1
        probe_id = result["prb_id"]
        first_error = ""
        probe_resolves = False
        resolver_responds = False
        all_timeout = True
        # Always defined, even if the probe reports an empty result set
        myset = []
        if "result" in result:
            result_set = [{'result': result['result']},]
        elif "resultset" in result:
            result_set = result['resultset']
        elif "error" in result:
            result_set = []
            if "timeout" in result['error']:
                myset.append("TIMEOUT")
            elif "socket" in result['error']:
                all_timeout = False
                myset.append("NETWORK PROBLEM WITH RESOLVER")
            elif "TUCONNECT" in result['error']:
                all_timeout = False
                myset.append("TUCONNECT (may be a TLS negotiation error or a TCP connection issue)")
            else:
                all_timeout = False
                myset.append("NO RESPONSE FOR UNKNOWN REASON at probe %s" % probe_id)
        else:
            raise Blaeu.WrongAssumption("Neither result not resultset member")
        if len(result_set) == 0:
            # Nothing to decode: count the error itself as a set
            myset.sort()
            set_str = " ".join(myset)
            sets[set_str].total += 1
            if config.display_probes:
                probes_sets[set_str].append(probe_id)
        for result_i in result_set:
            try:
                if "dst_addr" in result_i:
                    resolver = str(result_i['dst_addr'])
                elif "dst_name" in result_i: # Apparently, used when there was a problem
                    resolver = str(result_i['dst_name'])
                elif "dst_addr" in result: # Used when specifying a name server
                    resolver = str(result['dst_addr'])
                elif "dst_name" in result: # Apparently, used when there was a problem
                    resolver = str(result['dst_name'])
                else:
                    resolver = "UNKNOWN RESOLUTION ERROR"
                myset = []
                if "result" not in result_i:
                    if config.only_one_per_probe:
                        continue # Try the next resolver of this probe
                    if "timeout" in result_i['error']:
                        myset.append("TIMEOUT")
                    elif "socket" in result_i['error']:
                        all_timeout = False
                        myset.append("NETWORK PROBLEM WITH RESOLVER")
                    else:
                        all_timeout = False
                        myset.append("NO RESPONSE FOR UNKNOWN REASON at probe %s" % probe_id)
                else:
                    all_timeout = False
                    resolver_responds = True
                    # The DNS message is Base64-encoded; extra padding is
                    # added in case the original one is missing.
                    answer = result_i['result']['abuf'] + "=="
                    content = base64.b64decode(answer)
                    msg = dns.message.from_wire(content)
                    if config.nsid:
                        for opt in msg.options:
                            if opt.otype == dns.edns.NSID:
                                # NOTE(review): recent dnspython is understood
                                # to return a dedicated option with a "nsid"
                                # attribute, older ones a generic option with
                                # "data"; both are accepted here.
                                nsid = getattr(opt, "nsid", None)
                                if nsid is None:
                                    nsid = opt.data
                                # A NSID is arbitrary bytes: do not crash if
                                # it is not valid UTF-8.
                                myset.append("NSID: %s;" % nsid.decode(errors="replace"))
                    successes += 1
                    if msg.rcode() == dns.rcode.NOERROR:
                        probe_resolves = True
                        if config.answer_section:
                            if result_i['result']['ANCOUNT'] == 0 and config.verbose:
                                # If we test an authoritative server, and it returns a delegation, we won't see anything...
                                print("Warning: reply at probe %s has no answers: may be the server returned a delegation, or does not have data of type %s? For the first case, you may want to use --authority." % (probe_id, config.qtype), file=sys.stderr)
                            interesting_section = msg.answer
                        elif config.authority_section:
                            interesting_section = msg.authority
                        elif config.additional_section:
                            interesting_section = msg.additional
                        for rrset in interesting_section:
                            for rdata in rrset:
                                if rdata.rdtype == qtype_num:
                                    myset.append(str(rdata)[0:MAXLEN].lower()) # We truncate because DNSKEY can be very long
                        if config.display_validation and (msg.flags & dns.flags.AD):
                            myset.append(" (Authentic Data flag) ")
                        if (msg.flags & dns.flags.TC):
                            if edns_size is not None:
                                myset.append(" (TRUNCATED - EDNS buffer size was %d ) " % edns_size)
                            else:
                                myset.append(" (TRUNCATED - May have to use --ednssize) ")
                    else:
                        if msg.rcode() == dns.rcode.REFUSED: # Not SERVFAIL since
                            # it can be legitimate (DNSSEC problem, for instance)
                            if config.only_one_per_probe and len(result_set) > 1: # It
                                # does not handle the case where there
                                # are several resolvers and all say
                                # REFUSED (probably a rare case).
                                if first_error == "":
                                    first_error = "ERROR: %s" % dns.rcode.to_text(msg.rcode())
                                continue # Try again
                        else:
                            probe_resolves = True # NXDOMAIN or SERVFAIL are legitimate
                        myset.append("ERROR: %s" % dns.rcode.to_text(msg.rcode()))
                myset.sort()
                set_str = " ".join(myset)
                sets[set_str].total += 1
                if "error" not in result_i:
                    sets[set_str].successes += 1
                if config.display_probes:
                    probes_sets[set_str].append(probe_id)
                if config.display_resolvers:
                    if resolver not in resolvers_sets[set_str]:
                        resolvers_sets[set_str].append(resolver)
                if config.display_rtt:
                    if "error" not in result_i:
                        if "result" not in result_i:
                            sets[set_str].rtt += result_i['rt']
                        else:
                            sets[set_str].rtt += result_i['result']['rt']
            except dns.name.BadLabelType:
                if not config.machine_readable:
                    print("Probe %s failed (bad label in name)" % probe_id, file=sys.stderr)
            except dns.message.TrailingJunk:
                if not config.machine_readable:
                    print("Probe %s failed (trailing junk)" % probe_id, file=sys.stderr)
            except dns.exception.FormError:
                if not config.machine_readable:
                    print("Probe %s failed (malformed DNS message)" % probe_id, file=sys.stderr)
            # Reached only when the resolver was counted (or its answer was
            # unparsable): the "continue" statements above skip it.
            if config.only_one_per_probe:
                break
        if not probe_resolves and first_error != "" and config.verbose:
            print("Warning, probe %s has no working resolver (first error is \"%s\")" % (probe_id, first_error), file=sys.stderr)
        if not resolver_responds:
            if all_timeout and not config.only_one_per_probe:
                if config.verbose:
                    print("Warning, probe %s never got reply from any resolver" % (probe_id), file=sys.stderr)
                set_str = "TIMEOUT(S) on all resolvers"
                sets[set_str].total += 1
                if config.display_probes:
                    probes_sets[set_str].append(probe_id)

    # Display of the sets
    if config.sort:
        sets_data = sorted(sets, key=lambda s: sets[s].total, reverse=True)
    else:
        sets_data = sets
    details = []
    if not config.machine_readable and config.nameserver is not None:
        # The nameserver of this iteration, not the whole list
        print("Nameserver %s" % nameserver)
    if not config.answer_section:
        if config.authority_section:
            print("Authority section of the DNS responses")
        elif config.additional_section:
            print("Additional section of the DNS responses")
        else:
            print("INTERNAL PROBLEM: no section to display?")
    for answer_set in sets_data:
        detail = ""
        if config.display_probes:
            detail = "(probes %s)" % probes_sets[answer_set]
        if config.display_resolvers:
            detail += "(resolvers %s)" % resolvers_sets[answer_set]
        if config.display_rtt and sets[answer_set].successes > 0:
            detail += "Average RTT %i ms" % (sets[answer_set].rtt/sets[answer_set].successes)
        if not config.machine_readable:
            print("[%s] : %i occurrences %s" % (answer_set, sets[answer_set].total, detail))
        else:
            details.append("[%s];%i" % (answer_set, sets[answer_set].total))

    if not config.machine_readable:
        print(("Test #%s done at %s" % (measurement.id, time.strftime("%Y-%m-%dT%H:%M:%SZ", measurement.time))))
        print("")
    else:
        # TODO: what if we analyzed an existing measurement?
        if nameserver is None:
            ns = "DEFAULT RESOLVER"
        else:
            ns = nameserver
        print(",".join([domainname, config.qtype, str(measurement.id), "%s/%s" % (len(results), measurement.num_probes), \
                        time.strftime("%Y-%m-%dT%H:%M:%SZ", measurement.time), ns] + details))