#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import socket
import random
import sys
import re
import subprocess

from twisted.internet import protocol
import twisted.protocols.basic as twisted_protocols_basic
import twisted.internet.error as twisted_error

from winswitch.consts import LOCALHOST, PORT_START, DISPLAY_TUNNEL_PORT_BASE, X_PORT_BASE, XNEST_OFFSET, \
			IPP_TUNNEL_PORT_BASE, PULSE_TUNNEL_PORT_BASE, COMMAND_PORT_BASE, SAMBA_TUNNEL_PORT_BASE
from winswitch.globals import HOSTNAME, SUBPROCESS_CREATION_FLAGS, WIN32, OSX
from winswitch.util.common import no_newlines, csv_list, is_valid_file
from winswitch.util.main_loop import callLater, connectTCP, listenTCP
from winswitch.util.simple_logger import Logger, msig

#module-level logger shared by the free functions of this file
logger = Logger("net_util", log_colour=Logger.MAGENTA)



def win32_netstat_parse(udp_or_tcp=None):
	"""
	Parses netstat output to find the list of listening ports.
	An alternative can be found here: http://code.activestate.com/recipes/392572/
	(but is actually longer)

	Runs "netstat.exe -a -b -n" (restricted with "-p udp_or_tcp" when given) and pairs
	every "[command]" line with the line printed just before it.
	Only 5-column lines in the "LISTENING" state are kept.
	The local address is split on its *last* colon, so IPv6 entries such as
	"[::]:135" are parsed too (host "::", port 135).

	@param udp_or_tcp: optional protocol name, passed to netstat's "-p" option
	@return: an array of (command, host, port, pid), ie: [("System", "0.0.0.0", 139, 4), ..]
		The array is empty if netstat fails or writes anything to stderr.
	"""
	netstat = []
	args = ["netstat.exe", "-a", "-b", "-n"]
	if udp_or_tcp:
		args += ["-p", udp_or_tcp]
	try:
		proc = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, creationflags=SUBPROCESS_CREATION_FLAGS)
		(out,err) = proc.communicate()
		if proc.returncode!=0 or err:
			return	netstat
		prev_line = None
		for line in out.split("\n"):
			line = line.strip()
			if prev_line and line.startswith("[") and line.endswith("]"):
				try:
					command = line[1:-1]
					logger.sdebug("%s: %s" % (command, prev_line), udp_or_tcp)
					netstat_data = prev_line.split()
					if len(netstat_data)==5 and netstat_data[3]=="LISTENING":
						#split on the last ':' only, so that IPv6 addresses ("[::]:135") work too
						host, port_str = netstat_data[1].rsplit(":", 1)
						if host.startswith("[") and host.endswith("]"):
							host = host[1:-1]
						port = int(port_str)
						pid = int(netstat_data[4])
						netstat.append((command, host, port, pid))
				except Exception as e:
					logger.serr("error on line=%s, prev_line=%s" % (line, prev_line), e)
			prev_line = line
	except Exception as e:
		logger.serr("args=%s" % csv_list(args), e)
	return	netstat



#netifaces is optional: without it, get_interfaces() returns an empty list
#and get_bind_IPs() falls back to ["127.0.0.1"]
has_netifaces = True
try:
	import netifaces
	logger.slog("netifaces loaded sucessfully")
except Exception, e:
	#any failure (not just ImportError) disables the netifaces code paths
	has_netifaces = False
	logger.serror("netifaces package is missing!")
#interface name -> list of (ip, netmask) tuples, filled in by do_get_bind_IPs()
iface_ipmasks = {}
#cached result of get_bind_IPs(): None until first computed, recomputed while empty
bind_IPs = None

def get_interfaces():
	"""
	Returns the interface names reported by netifaces,
	or an empty list when the netifaces package could not be loaded.
	"""
	return netifaces.interfaces() if has_netifaces else []

def get_bind_IPs():
	"""
	Returns the cached list of local IPv4 addresses we can bind to.
	The list is computed by do_get_bind_IPs() on first use, and computed again
	for as long as the cached value is empty.
	Without netifaces, the list only contains "127.0.0.1".
	"""
	global bind_IPs
	if bind_IPs:
		return	bind_IPs
	if has_netifaces:
		bind_IPs = do_get_bind_IPs()
	else:
		bind_IPs = ["127.0.0.1"]
	return	bind_IPs

def do_get_bind_IPs():
	"""
	Enumerates all the interfaces known to netifaces and collects their addresses.

	Side effect: the module-level iface_ipmasks dictionary is updated so that
	each interface name maps to its list of unique (ip, netmask) tuples.
	If querying an interface fails, the error is logged and that interface
	keeps whatever had been collected before the failure.

	@return: the list of unique IP addresses found across all interfaces
	"""
	global iface_ipmasks
	found_ips = []
	interfaces = netifaces.interfaces()
	logger.sdebug("ifaces=%s" % str(interfaces))
	for name in interfaces:
		masks_for_iface = []
		try:
			for entry in do_get_bind_ifacemask(name):
				(address, _) = entry
				if address not in found_ips:
					found_ips.append(address)
				if entry not in masks_for_iface:
					masks_for_iface.append(entry)
		except Exception as e:
			logger.serr("error on %s" % name, e)
		iface_ipmasks[name] = masks_for_iface
	logger.slog("=%s" % str(found_ips))
	return found_ips

def get_iface_ipmasks():
	"""
	Returns the interface name -> [(ip, netmask), ..] dictionary.
	When netifaces is available, get_bind_IPs() is called first:
	computing the bind IPs is what populates this dictionary.
	"""
	if not has_netifaces:
		return	iface_ipmasks
	get_bind_IPs()
	return	iface_ipmasks

def do_get_bind_ifacemask(iface):
	"""
	Returns the (address, netmask) pairs of the given interface which socket.inet_aton() accepts.

	Entries without both an 'addr' and a 'netmask' are ignored, and so are
	"::1", "0.0.0.0" and scoped addresses (containing '%').
	Addresses which inet_aton() rejects (IPv6, link-layer, ..) are expected
	on most hosts, so they are only logged at debug level. Any other error
	is still logged as an error.

	@param iface: interface name, as returned by netifaces.interfaces()
	@return: list of (addr, netmask) tuples
	"""
	ipmasks = []
	sig = msig(iface)
	address_types = netifaces.ifaddresses(iface)
	for addresses in address_types.values():
		for address in addresses:
			if 'netmask' not in address or 'addr' not in address:
				continue
			addr = address['addr']
			mask = address['netmask']
			if addr in ('::1', '0.0.0.0') or addr.find("%")>=0:
				continue
			try:
				socket.inet_aton(addr)
			except socket.error:
				#not an IPv4 address: this is normal, not an error
				logger.debug(sig+" ignoring non IPv4 address %s" % addr)
				continue
			except Exception as e:
				logger.error(sig+" error on %s" % addr, e)
				continue
			ipmasks.append((addr,mask))
	logger.debug(sig+"=%s" % str(ipmasks))
	return ipmasks


def get_iface(ip):
	"""
	Returns the name of the local interface matching the given IPv4 address.

	An interface which owns this exact address is always returned first.
	Otherwise the result is an interface whose subnet (ip & netmask) contains
	the address. When several subnets match, the most specific one wins,
	ie: the netmask with the most bits set.
	Only the module-level iface_ipmasks cache is consulted, so it must already
	be populated (ie: via get_iface_ipmasks() or get_bind_IPs()).

	@param ip: dotted-quad IPv4 address string
	@return: the interface name, or None if ip is empty, invalid or not on any local subnet
	"""
	if not ip:
		return	None
	ip_parts = ip.split(".")
	if len(ip_parts)!=4:
		logger.serror("invalid ip! (%d parts)" % len(ip_parts), ip)
		return	None

	#an interface owning this very address is always the best answer
	for (iface, ipmasks) in iface_ipmasks.items():
		for (test_ip, _) in ipmasks:
			if test_ip == ip:
				#exact match
				logger.sdebug("=%s" % iface, ip)
				return	iface

	try:
		ip_values = [int(x) for x in ip_parts]
	except ValueError:
		logger.serror("invalid ip! (non numeric parts)", ip)
		return	None

	best_match = None
	best_bits = -1
	for (iface, ipmasks) in iface_ipmasks.items():
		for (test_ip,mask) in ipmasks:
			test_ip_parts = test_ip.split(".")
			mask_parts = mask.split(".")
			if len(test_ip_parts)!=4 or len(mask_parts)!=4:
				logger.serror("incorrect ip or mask: %s/%s" % (test_ip, mask), ip)
				continue
			try:
				test_values = [int(x) for x in test_ip_parts]
				mask_values = [int(x) for x in mask_parts]
			except ValueError as e:
				logger.serr("error parsing ip (%s) or its mask (%s)" % (test_ip, mask), e, ip)
				continue
			if any((a & m) != (b & m) for (a, b, m) in zip(ip_values, test_values, mask_values)):
				continue
			#prefer the most specific subnet: the netmask with the most bits set
			bits = sum(bin(m & 0xff).count("1") for m in mask_values)
			if bits>=best_bits:
				best_bits = bits
				best_match = iface
	logger.sdebug("=%s" % best_match, ip)
	return	best_match



# Found this recipe here:
# http://code.activestate.com/recipes/442490/
# if_nametoindex(name) and if_indextoname(index) call the C library functions
# of the same name. Both stay None on win32, and when the C library or one of
# those two symbols cannot be loaded.
if_nametoindex = None
if_indextoname = None
if if_nametoindex is None and not WIN32:
	library = "libc.so.6"
	if OSX:
		library = "/usr/lib/libc.dylib"
	elif sys.platform.startswith("sunos"):
		library = "libsocket.so.1"
	elif sys.platform.startswith("freebsd"):
		library = "/usr/lib/libc.so"
	elif sys.platform.startswith("openbsd"):
		library = "libc.so"
	try:
		from ctypes import CDLL, c_char_p, c_uint, create_string_buffer
		#<CDLL 'libc.so.6', handle 7fcac419b000 at 7fcac1ab0c10>
		_libc = CDLL(library)
		#looking up both symbols here means a missing one raises AttributeError
		#inside this try, instead of aborting the import of this module
		_libc.if_nametoindex.restype = c_uint
		_libc.if_indextoname.restype = c_char_p
		logger.slog("successfully loaded C library from "+library)
	except (ImportError, OSError, AttributeError) as e:
		logger.error("loading "+library, e)
	else:
		def if_nametoindex(interfaceName):
			""" Returns the index of the named interface (per POSIX, 0 when the name is unknown) """
			return _libc.if_nametoindex(interfaceName)
		def if_indextoname(index):
			""" Returns the name of the interface with the given index, or None if the C call returns NULL """
			s = create_string_buffer(256)
			return _libc.if_indextoname(c_uint(index), s)

def is_localhost(hostname):
	"""
	Returns True if hostname refers to this machine.
	Checks, in order: "127.0.0.1", "localhost", our HOSTNAME (also with the
	".local" and ".local." suffixes), then each of our bind IPs.
	Empty or None hostnames are never local.
	"""
	if not hostname:
		return False
	local_names = ("127.0.0.1", "localhost", HOSTNAME, "%s.local" % HOSTNAME, "%s.local." % HOSTNAME)
	if hostname in local_names:
		return True
	return hostname in get_bind_IPs()


def get_interface_speed(iface_index, iface, default_value):
	""" Returns the speed of the interface given, or the default_value if we cant find out what it is
		This is only implemented for win32 via wmi so far...
		Must be called from the main thread
		Note: iface_index and iface are only used for logging. The wmi query looks at
		all IP-enabled adapters, and a value is only returned when exactly one
		bandwidth figure is found and it is greater than 1000.
	"""
	ctx = (iface_index, iface, default_value)
	logger.sdebug(None, *ctx)
	speed = default_value
	if WIN32:
		#try to get perf info from wmi
		#the interface index does not seem to match what we find, so dont try to match that
		try:
			import wmi		#@UnresolvedImport
			conn = wmi.WMI()
			bandwidths = []
			for nic in conn.Win32_NetworkAdapterConfiguration(IPEnabled=True):
				logger.sdebug("nic=%s" % nic, *ctx)
				for perf in conn.Win32_PerfRawData_Tcpip_NetworkInterface(Name=nic.Description):
					logger.sdebug("perf=%s" % perf, *ctx)
					logger.sdebug("bandwidth(%s)=%s" % (nic.Description, perf.CurrentBandwidth), *ctx)
					bandwidths.append(perf.CurrentBandwidth)
			#only trust the value when there is a single candidate
			if len(bandwidths)==1 and bandwidths[0]>1000:
				speed = bandwidths[0]
		except Exception as e:
			logger.serr(None, e, *ctx)
	logger.slog("=%s" % speed, *ctx)
	return	speed






def tcp_listen_on(factory, listen_on_str):
	"""
	Starts listening with the given factory on all the addresses described by listen_on_str.

	listen_on_str is a comma separated list of "HOST:PORT" specs, ie:
	"127.0.0.1:12321, 192.168.0.2:0", "0.0.0.0:", "*:", "*:45443", "eth0:,eth1:", "eth*:", "wlan0:1234,eth0:"
	- HOST can be:
	  - an IP address;
	  - "*" or "" (both meaning "0.0.0.0");
	  - an interface name pattern, when it starts with a letter. '*' is a wildcard and
	    the pattern must match the whole interface name (so "eth1" does not also
	    select "eth10"). All the IPs of the matching interfaces are used.
	- PORT can be empty or 0 to let the OS pick one. The port assigned for the first
	  host of a spec is then reused for the other hosts of that same spec.
	Specs without a ':' separator or with a non numeric port are logged and skipped.
	Empty specs (ie: from a trailing comma) are ignored.
	Every port used is registered as taken with the PortMapper.

	@return: list of (host, port) tuples we are now listening on
	"""
	from fnmatch import fnmatchcase
	listen_on = []
	logger.sdebug(None, factory, listen_on_str)
	for spec in listen_on_str.split(","):
		spec = spec.replace(" ", "")
		if not spec:
			continue
		pos = spec.find(":")
		if pos<0:
			logger.serror("Invalid listen_on spec, ':' separator not found in '%s', ignoring it." % spec, factory, listen_on_str)
			continue

		#parse port (if set), an invalid port invalidates the whole spec
		port_str = spec[pos+1:]
		port = 0
		if port_str:
			try:
				port = int(port_str)
			except ValueError:
				logger.serror("Invalid listen_on spec, port '%s' is not a number in '%s', ignoring it." % (port_str, spec), factory, listen_on_str)
				continue

		host = spec[:pos]
		if host=="*" or host=="":
			host = "0.0.0.0"
		if host[0].isalpha():
			#must be an interface name (pattern): glob-style, anchored and case sensitive
			hosts = []
			for iface,ipmasks in get_iface_ipmasks().items():
				if fnmatchcase(iface, host):
					logger.sdebug("found interface %s matching spec %s, will listen on %s" % (iface, host, csv_list(ipmasks)), factory, listen_on_str)
					for ip,_ in ipmasks:
						hosts.append(ip)
			if len(hosts)==0:
				logger.serror("not found any matching IPs for %s" % host, factory, listen_on_str)
		else:
			#just the one IP specified
			hosts = [host]

		for bind_host in hosts:
			p = listenTCP(port, factory, interface=bind_host)
			if port <= 0:
				#the OS picked the port: reuse it for the remaining hosts of this spec
				port = p.getHost().port
			get_port_mapper().add_taken_port(port)
			listen_on.append((bind_host, port))
			logger.sdebug("listening on '%s' port '%d'" % (bind_host, port))
	return listen_on



class ConnectTestClient(protocol.Protocol):
	"""
	Protocol used by ConnectTestFactory: the connection is closed again
	as soon as it has been made, we only wanted to know that it could be made.
	"""
	def connectionMade(self):
		logger.sdebug("ConnectTestClient closing connection immediately")
		transport = self.transport
		transport.loseConnection()

class ConnectTestFactory(protocol.ClientFactory):
	"""
	Tests a TCP connection to host:port using ConnectTestClient.
	A failed connection is retried after self.wait seconds, up to max_attempts times.
	Retries stop without firing any callback once abort_test() returns a true value.
	success_callback fires when the test connection is closed after being made.
	error_callback fires once all the retries have failed.
	"""
	protocol = ConnectTestClient

	def __init__(self, max_attempts, success_callback, error_callback, abort_test, host, port):
		Logger(self, log_colour=Logger.MAGENTA)
		self.host, self.port = host, port
		self.max_attempts = max_attempts
		self.success_callback = success_callback
		self.error_callback = error_callback
		self.abort_test = abort_test
		self.attempt = 0		#number of retries so far
		self.wait = 1			#delay in seconds before each retry

	def clientConnectionFailed(self, connector, reason):
		if self.abort_test and self.abort_test():
			self.sdebug("aborting %s:%s test with %s" % (self.host, self.port, self.abort_test), connector, reason)
			return
		if self.attempt>=self.max_attempts:
			#out of retries: report the failure
			self.slog("connection to %s:%s failed (%s attempts), calling %s " % (self.host, self.port, self.attempt, self.error_callback), connector, reason)
			if self.error_callback:
				self.error_callback()
			return
		self.attempt += 1
		self.sdebug("retrying %s:%s (attempt=%s)" % (self.host, self.port, self.attempt), connector, reason)
		callLater(self.wait, connector.connect)

	def clientConnectionLost(self, connector, reason):
		#ConnectTestClient closes the connection itself, so getting here means it worked
		self.slog("successfully tested connection to %s:%s calling=%s" % (self.host, self.port, self.success_callback), connector, reason)
		if self.success_callback:
			self.success_callback()

def wait_for_socket(host, port, max_wait=10, success_callback=None, error_callback=None, abort_test=None, timeout=1):
	"""
	Retries up to max_wait times to connect with a timeout, then fires the appropriate callback.
	The result is only reported through success_callback / error_callback (nothing is returned).
	The wildcard address "0.0.0.0" is tested through LOCALHOST instead.
	If abort_test is given and returns a true value after a failure, the test stops without any callback.
	"""
	logger.slog(None, host, port, max_wait, success_callback, error_callback, abort_test)
	target = LOCALHOST if host == "0.0.0.0" else host
	factory = ConnectTestFactory(max_wait, success_callback, error_callback, abort_test, target, port)
	connectTCP(target, port, factory, timeout=timeout)




#shared PortMapper instance, created lazily by get_port_mapper()
port_mapper = None
def get_port_mapper():
	""" Returns the shared PortMapper, creating it on first use. """
	global port_mapper
	port_mapper = port_mapper or PortMapper()
	return port_mapper

class PortMapper:
	"""
	Hands out local TCP ports for the various services and tunnels.

	Candidate ports are base + offset + self.index, where self.index starts at
	PORT_START and keeps growing. A candidate is skipped if it is already taken by
	us, blacklisted (see load_blacklisted_ports) or in use by someone else
	(see is_really_free).
	Once more than free_threshold ports have been taken, get_free_port()
	recycles the ports released via free_port().
	"""
	def __init__(self):
		Logger("PortMapper.", log_colour=Logger.MAGENTA).add_methods(self)
		self.index = PORT_START
		self.free_ports = []
		self.taken_ports = []
		self.free_threshold = 500		#start freeing ports after 500 are used
		self.locked_ports = {}			#port -> value returned by listenTCP, see get_locked_port()
		self.blacklisted_ports = []
		self.load_blacklisted_ports()

	def load_blacklisted_ports(self):
		"""
		Loads a list of blacklisted ports that should never be used.
		This list can be generated on machines with SELinux enabled with (as root):
		semanage port -l  | tail -n +2 | awk '{$1=$2=""}1' | sed 's+^  ++g' >> /etc/winswitch/ports.conf
		Or you can just create it by hand.
		Each line may contain a port, a range ("start-end") or a comma separated list of those.
		Empty lines and lines starting with '#' (after stripping whitespace) are ignored.
		Invalid entries are logged and skipped.
		"""
		#the SELinux list can contain tens of thousands of ports (large ranges):
		#check for duplicates with a set rather than scanning the list every time
		known = set(self.blacklisted_ports)

		def add_to_blacklist(port):
			if not port:
				return
			iport = int(port)
			if iport not in known:
				known.add(iport)
				self.blacklisted_ports.append(iport)

		def parse_port_spec(line):
			try:
				if line.find(",")>0:
					for v in line.split(","):
						parse_port_spec(v)
				elif line.find("-")>0:
					#range:
					sep = line.find("-")
					start = int(line[:sep].strip())
					end = int(line[sep+1:].strip())
					for port in range(start, end+1):
						add_to_blacklist(port)
				else:
					add_to_blacklist(line.strip())
			except Exception as e:
				self.serr("invalid line", e, no_newlines(line))

		self.slog("this may take a few seconds... please be patient")
		try:
			from winswitch.util.file_io import get_server_blocked_ports
			filename = get_server_blocked_ports()
			if is_valid_file(filename):
				with open(filename, "r") as inp:
					self.sdebug("loading from file %s" % filename)
					lines = inp.readlines()
				self.sdebug("loaded %s lines" % len(lines))
				for line in lines:
					line = no_newlines(line).strip()
					if not line or line.startswith("#"):
						continue
					parse_port_spec(line)
			else:
				self.slog("blacklisted ports file '%s' does not exist" % filename)
		except Exception as e:
			self.serr(None, e)
		self.slog("blacklisted_ports above 1024: "+csv_list([x for x in self.blacklisted_ports if x>1023]))

	def add_taken_port(self, *ports):
		""" Records all the given ports as taken (in use by us) """
		for port in ports:
			self.do_add_taken_port(port)

	def do_add_taken_port(self, port):
		iport = int(port)
		if iport not in self.taken_ports:
			self.taken_ports.append(iport)

	def free_X_display(self, display):
		"""
		Releases the port of an X display.
		The first character of the display string is dropped before parsing the number
		(presumably a display like ":10" - confirm with callers).
		"""
		d_no = int(display[1:])
		self.free_port(X_PORT_BASE + d_no)

	def free_port(self, port):
		"""
		Marks a taken, non-blacklisted port as reusable. It is only really
		released once the free_threshold is reached (see get_free_port).
		@raise Exception: for ports below 1024
		"""
		iport = int(port)
		self.sdebug(None, port)
		if iport<1024:
			raise Exception("cannot free ports below 1024: we should not use them as client ports!")
		if iport in self.taken_ports and iport not in self.blacklisted_ports:
			self.free_ports.append(iport)

	def get_free_port(self, base, offset=0, host="0.0.0.0", step_start=0):
		"""
		Returns a port which we were able to bind to on host, and records it as taken.
		Candidates start at base + offset + self.index. Ports past 65535 wrap
		around above 1024. Taken or blacklisted ports are skipped.
		@raise Exception: if no free port is found after 500 attempts
		"""
		if len(self.taken_ports)>self.free_threshold:
			self.sdebug("threshold %s reached, freeing some ports" % self.free_threshold, base, offset, host)
			if len(self.free_ports)>0:
				for port in self.free_ports:
					if port in self.taken_ports:
						self.taken_ports.remove(port)
				self.free_ports = []
			else:
				#just remove a few and hope they have become free... (they will be re-tested)
				#blacklisted ports are never released: check the port value, not its index
				for _ in range(10):
					if not self.taken_ports:
						break
					index = random.randrange(len(self.taken_ports))
					if self.taken_ports[index] not in self.blacklisted_ports:
						self.taken_ports.pop(index)
		attempts = 0			#total attempts
		take_attempts = 0		#count attempts to bind to port already taken
		step = step_start		#increment for index (step_start=0: the first candidate does not advance the index)
		while attempts<500:
			port = int(base + offset + self.index)
			try:
				while port >=  65536:
					port = (port % 65536) + 1024
				if port not in self.taken_ports and port not in self.blacklisted_ports:
					self.do_add_taken_port(port)
					if self.is_really_free(host, port):
						return port
					else:
						take_attempts += 1
			finally:
				self.index += step
			if take_attempts>5:		#too many taken ports, let's skip a few
				step = 1+int(10*random.random())
			else:
				step = 1
			attempts += 1
		raise Exception("Failed to find a free port!")

	def is_really_free(self, host, port):
		"""
		Returns True if a TCP socket can be bound to (host, port) right now.
		The test socket is always closed before returning.
		"""
		sock = None		#so the cleanup below works even if socket() itself fails
		try:
			sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
			sock.bind((host, int(port)))
			ret = True
		except Exception as e:
			self.slog("%s" % e, host, port)
			ret = False
		if sock is not None:
			try:
				sock.close()
			except Exception as e:
				self.serr("error closing socket %s" % sock, e, host, port)
		return	ret

	def get_free_command_port(self):
		return	self.get_free_port(COMMAND_PORT_BASE)

	def get_free_display_tunnel_port(self):
		return self.get_free_port(DISPLAY_TUNNEL_PORT_BASE)

	def get_free_X_display(self):
		return self.get_free_port(X_PORT_BASE, step_start=1) - X_PORT_BASE

	def get_free_Xnest_display(self):
		return self.get_free_port(X_PORT_BASE, XNEST_OFFSET) - X_PORT_BASE

	def get_free_ipp_tunnel_port(self):
		return self.get_free_port(IPP_TUNNEL_PORT_BASE)

	def get_free_sound_tunnel_port(self):
		return self.get_free_port(PULSE_TUNNEL_PORT_BASE)

	def get_free_samba_tunnel_port(self):
		return self.get_free_port(SAMBA_TUNNEL_PORT_BASE)


	def get_sound_tunnel_port(self):
		return	self.get_free_port(PULSE_TUNNEL_PORT_BASE)
	def get_locked_sound_tunnel_port(self):
		return	self.get_locked_port(PULSE_TUNNEL_PORT_BASE)

	def get_locked_port(self, base, offset=0, host="0.0.0.0", attempt=0):
		"""
		Returns a port which has an associated ReservePortFactory/ReservePortChannel attached.
		This guarantees that the operating system cannot assign it to some other process/task.
		If listening fails, another port is tried (up to 5 retries).
		@raise Exception: when all the retries failed
		"""
		port = self.get_free_port(base, offset, host)
		factory = ReservePortFactory(port)
		try:
			listening_port = listenTCP(port, factory, interface=host)
			self.locked_ports[port] = listening_port
			return	port
		except twisted_error.CannotListenError as e:
			self.serror("failure on port %s: %s" % (port, e), base, offset, host)
			if attempt<5:
				return	self.get_locked_port(base, offset, host, attempt+1)
			else:
				raise Exception("failed to get locked port (base=%s, offset=%s, host=%s, attempt=%s)" % (base, offset, host, attempt))

	def unlock_port(self, port):
		"""
		Unlock a port reserved via get_locked_port():
		stops listening on it and forgets about it, so it is not kept forever.
		Ports which are not (or no longer) locked are logged and ignored.
		"""
		listening_port = self.locked_ports.pop(port, None)
		if listening_port is None:
			self.serror("port %s is not locked" % port)
			return
		listening_port.stopListening()


class ReservePortChannel(twisted_protocols_basic.LineReceiver):
	"""
	Protocol of the placeholder service attached to ports reserved by
	PortMapper.get_locked_port(). No real service runs there, so any
	connection is logged as an error and closed straight away.
	"""

	def __init__(self):
		Logger(self)

	def __str__(self):
		return	"ReservePortChannel(%s)" % self.factory.port

	def connectionMade(self):
		self.serror("unexpected connection - not a real service on port=%s, connection from: %s" % (self.factory.port, self.transport.getPeer()))
		self.transport.loseConnection()

	def connectionLost(self, reason):
		#include the reason, otherwise this log line carries no information
		self.serror(None, reason)

class ReservePortFactory(protocol.ClientFactory):
	"""
	Factory which PortMapper.get_locked_port() listens with on a reserved port.
	It builds a ReservePortChannel for any (unexpected) incoming connection.
	"""
	# the class of the protocol to build when new connection is made
	protocol = ReservePortChannel

	def __init__(self, port):
		Logger(self)
		self.port = port		#the reserved port, used in log messages
		self.sdebug(None, port)

	def __str__(self):
		return	"ReservePortFactory(%s)" % self.port

	#NOTE(review): these are client-side callbacks; since this factory is only
	#passed to listenTCP they are presumably never invoked - confirm
	def clientConnectionLost(self, connector, reason):
		self.serror(None, connector, reason)

	def clientConnectionFailed(self, connector, reason):
		#same (msg, *args) logging call as clientConnectionLost
		self.serror(None, connector, reason)
