#!/usr/bin/env python
"""
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

Written (W) 2007,2011 Gunnar Raetsch
Written (W) 2006-2009 Soeren Sonnenburg
Copyright (C) 2006-2010 Fraunhofer Institute FIRST and Max-Planck-Society
"""

# All imports are wrapped in one try block so that a missing module is
# reported with a short error message and exit status 1.
try:
	import os
	import os.path
	import sys
	import pickle
	import bz2
	import numpy
	import optparse
	import array
	import math

	import genomic
	import model
	import seqdict
	import shogun.Kernel

	# A kernel object is created only to read the revision of the installed
	# SHOGUN library; revisions below 2997 are rejected.  The name `d` stays
	# behind as a module-level global.
	d=shogun.Kernel.WeightedDegreeStringKernel(1)
	if (d.version.get_version_revision() < 2997):
		print
		print "ERROR: SHOGUN VERSION 0.6.2 or later required"
		print
		sys.exit(1)
	from signal_detectors import signal_detectors
except ImportError:
	# NOTE(review): this handler uses `sys`, so it relies on `import sys`
	# above having succeeded before whichever import failed.
	print
	print "ERROR IMPORTING MODULES, MAKE SURE YOU HAVE SHOGUN INSTALLED"
	print
	sys.exit(1)

# Version string printed by print_version() and embedded in the GFF/SPF
# headers written by the asp class.
asp_version='v0.3'

class asp:
	"""Splice site predictor built around a parsed model file.

	An instance holds the parsed model (``self.model``), the
	``signal_detectors`` object created from it (``self.signal``) and the
	name of the model file (``self.model_name``).  ``load_model`` has to be
	called before ``predict_file``.
	"""

	def __init__(self):
		self.model = None
		self.signal = None
		self.model_name = None

	def sigmoid_transform(self, x):
		"""Map a raw score to a value between 0 and 1.

		Computes the logistic function of ``x + 2``.  For scores so negative
		that ``math.exp`` would overflow, 0.0 is returned instead of raising
		OverflowError (the exact value is indistinguishable from zero there).
		"""
		try:
			return 1.0/(1.0+math.exp(-(x+2)))
		except OverflowError:
			return 0.0

	def _score(self, raw, score_type):
		"""Return `raw` unchanged for score_type 'output', else its sigmoid transform."""
		if score_type == 'output':
			return raw
		return self.sigmoid_transform(raw)

	def _shifted_position(self, site, strand, pos):
		"""Return the position reported in SPF and binary output for one prediction.

		Acceptor sites ('acc') are shifted by +2 on the '+' strand and by -1
		otherwise; all other sites are unchanged on the '+' strand and shifted
		by +1 otherwise.
		"""
		if site == 'acc':
			if strand == '+':
				return pos+2
			return pos-1
		if strand == '+':
			return pos
		return pos+1

	def _write_model_cache(self, picklefile):
		"""Pickle ``self.model`` to `picklefile`; a failure to write is only a warning."""
		try:
			f = open(picklefile, 'wb')
			try:
				pickle.dump(self.model, f)
			finally:
				f.close()
		except (IOError, OSError):
			sys.stderr.write('WARNING: could not write model cache %s (%s)\n'
					% (picklefile, sys.exc_info()[1]))
			# Do not leave a truncated cache behind; it is fine if there is
			# nothing to remove or removal is not permitted either.
			try:
				os.remove(picklefile)
			except OSError:
				pass

	def load_model(self, filename):
		"""Load the model from `filename` and create the signal detectors.

		If ``filename + '.pickle'`` exists it is used as a cache of the parsed
		model.  Otherwise (or when the cache cannot be read) the model file is
		parsed with ``model.parse_file`` -- through ``bz2.BZ2File`` when the
		name ends in '.bz2' -- and the cache is (re)written.  A cache that
		cannot be written does not abort loading.
		"""
		self.model_name = filename
		picklefile = filename+'.pickle'

		cached = False
		if os.path.isfile(picklefile):
			# NOTE(security): unpickling can execute arbitrary code; the cache
			# file must come from a location only trusted users can write to.
			try:
				f = open(picklefile, 'rb')
				try:
					self.model = pickle.load(f)
				finally:
					f.close()
				cached = True
			except Exception:
				# Unpickling a damaged or stale file can raise many different
				# exception types; in every case fall back to the model file.
				sys.stderr.write('WARNING: could not read model cache %s (%s), parsing model file instead\n'
						% (picklefile, sys.exc_info()[1]))

		if not cached:
			if filename.endswith('.bz2'):
				f = bz2.BZ2File(filename)
			else:
				f = open(filename)
			try:
				self.model = model.parse_file(f)
			finally:
				f.close()
			self._write_model_cache(picklefile)

		self.signal = signal_detectors(self.model)

	def write_gff(self, outfile, preds, name, score_type, skipheader, strand):
		"""Write predictions for one sequence to `outfile` in GFF format.

		`preds` is a (features, positions, scores) triple of equally long
		sequences.  Each prediction becomes one line spanning
		``position .. position+1``.  Scores are written unchanged when
		`score_type` is 'output' and sigmoid-transformed otherwise.  The header
		is written unless `skipheader` is true.
		"""
		if not skipheader:
			genomic.write_gff_header(outfile, ('asp',asp_version + ' ' + self.model_name),
					('DNA', name))

		for i, feature in enumerate(preds[0]):
			position = preds[1][i]
			d = dict()
			d['seqname'] = name
			d['source'] = 'asp'
			d['feature'] = feature
			d['start'] = position
			d['end'] = position+1
			d['score'] = self._score(preds[2][i], score_type)
			d['strand'] = strand
			d['frame'] = 0
			genomic.write_gff_line(outfile, d)

	def write_spf(self, outfile, preds, name, score_type, skipheader, strand):
		"""Write predictions for one sequence to `outfile` in SPF format.

		`preds` is a (features, positions, scores) triple of equally long
		sequences.  A feature of 'AG' is reported as 'acc', anything else as
		'don'; positions are shifted as described in ``_shifted_position``.
		`score_type` is written as the source column and selects raw
		('output') or sigmoid-transformed scores.  The header is written
		unless `skipheader` is true.
		"""
		if not skipheader:
			genomic.write_spf_header(outfile, ('asp', asp_version + ' ' + self.model_name),
					('DNA', name))

		for i, feature in enumerate(preds[0]):
			if feature == 'AG':
				site = 'acc'
			else:
				site = 'don'
			d = dict()
			d['seqname'] = name
			d['source'] = score_type
			d['feature'] = site
			d['position'] = self._shifted_position(site, strand, preds[1][i])
			d['score'] = self._score(preds[2][i], score_type)
			d['strand'] = strand
			genomic.write_spf_line(outfile, d)

	def write_binary(self, preds, site, strand, score_type, binary_out, binary_pos):
		"""Write scores and positions as raw machine values.

		Scores go to `binary_out` as C floats (array type 'f'), positions to
		`binary_pos` as C ints (array type 'i').  Positions are shifted exactly
		as in the SPF output.  Both file objects must be opened in binary mode.
		"""
		out = array.array('f', [self._score(o, score_type) for o in preds[2]])

		# move positions consistent with spf output
		pos = array.array('i', [self._shifted_position(site, strand, i) for i in preds[1]])

		out.tofile(binary_out)
		pos.tofile(binary_pos)

	def _orient(self, features, positions, scores, seqlen, strand):
		"""Return the (features, positions, scores) triple for `strand`.

		For strand '-' all three sequences are reversed and every position p
		is replaced by ``seqlen - p``; otherwise they are returned unchanged.
		"""
		if strand == '-':
			positions = [seqlen-i for i in positions[::-1]]
			scores = scores[::-1]
			features = features[::-1]
		return (features, positions, scores)

	def _donor_preds(self, seq, strand):
		"""Collect the donor predictions of `seq` as a (features, positions, scores) triple."""
		positions = [i+1 for i in seq.preds['donor'].get_positions()]
		scores = seq.preds['donor'].get_scores()
		features = []
		for pos in positions:
			dinucleotide = seq.seq[pos-1:pos+1]
			if dinucleotide == 'GT':
				features.append('GT')
			else:
				# every predicted donor is expected to sit on GT or GC
				assert(dinucleotide == 'GC')
				features.append('GC')
		return self._orient(features, positions, scores, len(seq.seq), strand)

	def _acceptor_preds(self, seq, strand):
		"""Collect the acceptor predictions of `seq` as a (features, positions, scores) triple."""
		positions = [i-1 for i in seq.preds['acceptor'].get_positions()]
		scores = seq.preds['acceptor'].get_scores()
		features = len(positions)*['AG']
		return self._orient(features, positions, scores, len(seq.seq), strand)

	def _write_binary_site(self, preds, site, strand, score_type, binary_basename, contig_no):
		"""Write one contig's predictions for `site` below ``binary_basename/site/``.

		Creates ``contig_<no><strand>.<score_type>`` (scores) and
		``contig_<no><strand>.pos`` (positions).  The files are opened in
		binary mode and closed even if writing fails.
		"""
		prefix = '%s/%s/contig_%i%c' % (binary_basename, site, contig_no, strand)
		binary_out = open(prefix + '.%s' % score_type, 'wb')
		try:
			binary_pos = open(prefix + '.pos', 'wb')
			try:
				self.write_binary(preds, site, strand, score_type, binary_out, binary_pos)
			finally:
				binary_pos.close()
		finally:
			binary_out.close()

	def predict_file(self, fname, start_end, output_format, score_type, strand='+', outfile=None, binary_basename=None):
		"""Predict splice sites for all sequences in a FASTA file and write them out.

		fname           -- path of the FASTA file
		start_end       -- (start, end) pair handed to ``seqdict.seqdict``
		output_format   -- 'binary', 'gff' or 'spf'
		score_type      -- 'output' for raw scores, anything else selects the
		                   sigmoid transform (and names the binary score file)
		strand          -- '+' or '-'; for '-' every sequence is reverse
		                   complemented before prediction
		outfile         -- open file for 'gff'/'spf' output
		binary_basename -- directory prefix for 'binary' output; the
		                   sub-directories 'acc' and 'don' must already exist

		When `outfile` / `binary_basename` are not given, module-level names of
		the same spelling are used, which is how the command line entry point
		supplies them.

		Raises ValueError for an unknown `output_format` or when the
		destination required by the format is missing.
		"""
		start, end = start_end

		# Backward compatibility: older callers only set these as module globals.
		if outfile is None:
			outfile = globals().get('outfile')
		if binary_basename is None:
			binary_basename = globals().get('binary_basename')

		# Fail before the expensive prediction step rather than after it.
		if output_format == 'binary':
			if not binary_basename:
				raise ValueError('binary output requires a non-empty binary_basename')
		elif output_format in ('gff', 'spf'):
			if outfile is None:
				raise ValueError('%s output requires an open outfile' % output_format)
		else:
			raise ValueError('unknown output format %r' % (output_format,))

		skipheader = False
		fasta_file = open(fname)
		try:
			fasta_dict = genomic.read_fasta(fasta_file)

			if strand == '-':
				for k, kseq in fasta_dict.ordered_items():
					fasta_dict[k] = genomic.reverse_complement(kseq)

			sys.stdout.write('found fasta file with %d sequence(s) (strand=%s)\n' % (len(fasta_dict), strand))
			seqs = seqdict.seqdict(fasta_dict, (start, end))

			#get donor/acceptor signal predictions for all sequences
			self.signal.predict_acceptor_sites_from_seqdict(seqs)
			self.signal.predict_donor_sites_from_seqdict(seqs)

			contig_no = 0
			for seq in seqs:
				contig_no += 1
				don_preds = self._donor_preds(seq, strand)
				acc_preds = self._acceptor_preds(seq, strand)

				if output_format == 'binary':
					self._write_binary_site(acc_preds, 'acc', strand, score_type, binary_basename, contig_no)
					self._write_binary_site(don_preds, 'don', strand, score_type, binary_basename, contig_no)
				elif output_format == 'gff':
					self.write_gff(outfile, acc_preds, seq.name, score_type, skipheader, strand)
					self.write_gff(outfile, don_preds, seq.name, score_type, skipheader, strand)
				else:
					self.write_spf(outfile, acc_preds, seq.name, score_type, skipheader, strand)
					self.write_spf(outfile, don_preds, seq.name, score_type, skipheader, strand)
		finally:
			# closed only after all work is done, in case the FASTA reader is lazy
			fasta_file.close()


def print_version():
	"""Write the program name followed by its version to standard output."""
	version_line = ''.join(('asp ', asp_version, '\n'))
	sys.stdout.write(version_line)

def parse_options():
	"""Parse and validate the command line.

	Returns the tuple
	``((start, stop), fafname, modelfname, output_format, score_type, outfile, binary_basename)``
	where start and stop are the --start/--stop values decreased by one,
	output_format is 'spf', 'gff' or 'binary', score_type is 'output' or
	'Conf_cum' (with --transform), and outfile is an open file (or
	``sys.stdout``) for the text formats and None for binary output.

	For binary output the directories ``<basename>/acc`` and
	``<basename>/don`` are created.  Usage errors end the program through
	``parser.error``; --version prints the version and exits with status 0.
	All numeric checks run before any file or directory is created.
	"""
	parser = optparse.OptionParser(usage="usage: %prog [options] seq.fa")

	parser.add_option("-g", "--gff-file", type="str",
							  help="File to write the results in GFF format to the given file")
	parser.add_option("-s", "--spf-file", type="str", default='stdout',
							  help="File to write the results in SPF format to the given file")
	parser.add_option("-b", "--binary-basename", type="str",
							  help="Write results in binary format to file starting with this basename")
	parser.add_option("-v", "--version", dest='version', default=False, action='store_true',
							  help="Show some more information")
	parser.add_option("-t", "--transform", dest='transform', default=False, action='store_true',
							  help="Apply sigmoid transform to scale predictions between 0 and 1")
	parser.add_option("--start", type="int", default=499,
							  help="coding start (zero based, relative to sequence start)")
	parser.add_option("--stop", type="int", default=-499,
							  help="""coding stop (zero based, if positive relative to
							  sequence start, if negative relative to sequence end)""")
	parser.add_option("--organism", type="str", default='Worm',
							  help="""use asp model for organism when predicting
							  (one of Cress, Fish, Fly, Human, Worm)""")

	(options, args) = parser.parse_args()
	if options.version:
		print_version()
		sys.exit(0)

	score_type = 'output'
	if options.transform:
		score_type = 'Conf_cum'

	if len(args) != 1:
		parser.error("incorrect number of arguments")

	fafname = args[0]
	if not os.path.isfile(fafname):
		parser.error("fasta file does not exist")

	modelfname = 'data/%s.dat.bz2' % options.organism
	# NOTE: deliberately written without a trailing newline, matching the
	# long-standing console output of this tool.
	sys.stdout.write("loading model file " + modelfname)

	if not os.path.isfile(modelfname):
		sys.stdout.write(" ...not found!\n\n")
		parser.error("""model should be one of:

Cress, Fish, Fly, Human, Worm
""")

	# '-s stdout' is the default and therefore counts as "not given"
	spf_given = options.spf_file != 'stdout'
	destinations = [bool(options.gff_file), spf_given, bool(options.binary_basename)]
	if destinations.count(True) > 1:
		parser.error("Only one of the options --binary-basename, --spf-file, or --gff-file may be given")

	outfile_fname = None
	if options.binary_basename:
		output_format = 'binary'
	elif options.gff_file:
		output_format = 'gff'
		outfile_fname = options.gff_file
	else:
		output_format = 'spf'
		outfile_fname = options.spf_file

	# Validate the numeric options before touching the file system so that a
	# usage error does not leave empty output files or directories behind.
	if options.start < 80:
		parser.error("--start value must be >=80")

	# stop == 0 is treated like a positive stop; it would otherwise slip
	# through both range checks.
	if options.stop >= 0 and options.start >= options.stop - 80:
		parser.error("--stop value must be > start + 80")

	if options.stop < 0 and options.stop > -80:
		parser.error("--stop value must be <= - 80")

	if output_format == 'binary':
		outfile = None
		for site in ('acc', 'don'):
			site_dir = '%s/%s' % (options.binary_basename, site)
			# os.makedirs instead of a shell 'mkdir -p': no quoting problems
			# with spaces or shell metacharacters in the basename.
			try:
				os.makedirs(site_dir)
			except OSError:
				# an already existing directory is fine, as with 'mkdir -p'
				if not os.path.isdir(site_dir):
					parser.error("could not create directory %s" % site_dir)
	elif outfile_fname == 'stdout':
		outfile = sys.stdout
	else:
		try:
			outfile = open(outfile_fname, 'w')
		except IOError:
			parser.error("could not open %s for writing" % outfile_fname)

	# shift the start and stop a bit
	options.start -= 1
	options.stop -= 1

	return ((options.start,options.stop), fafname, modelfname, output_format, score_type, outfile, options.binary_basename)


# Script entry point: parse the command line, load the model and run the
# prediction once for each strand.
# NOTE: `outfile` and `binary_basename` are not passed to predict_file();
# asp.predict_file picks these two names up from module scope, so they must
# keep exactly this spelling.
if __name__ == '__main__':
	(startstop, fafname, modelfname, output_format, score_type, outfile, binary_basename ) = parse_options()
	p=asp()
	p.load_model(modelfname);
	p.predict_file(fafname, startstop, output_format, score_type, '+')
	p.predict_file(fafname, startstop, output_format, score_type, '-')
