# kate: syntax Python; space-indent off; tab-width 4; indent-width 4;
import json
import os
import re
import shlex
import shutil
import sys
import textwrap
from collections import OrderedDict

from sqt import FastaReader, SequenceReader
from sqt.dna import reverse_complement

import igdiscover
from igdiscover.config import Config, GlobalConfig
from igdiscover.utils import relative_symlink


# Load the IgDiscover pipeline configuration; abort with a readable message
# if the configuration file does not exist.
try:
	config = Config.from_default_path()
except FileNotFoundError as e:
	sys.exit("Pipeline configuration file {!r} not found. Please create it!".format(e.filename))

# Use pigz (parallel gzip) if available
GZIP = 'pigz' if shutil.which('pigz') is not None else 'gzip'

# Result of read preprocessing; used as the input of every IgBLAST run.
PREPROCESSED_READS = 'reads/sequences.fasta.gz'

if config.debug:
	# Do not delete intermediate files when debugging: rebinding temp() to the
	# identity function means outputs wrapped in temp(...) are no longer
	# marked as temporary.
	temp = lambda x: x

# Targets for each iteration
# (expand() is not imported in this file; it comes from the Snakemake runtime.)
ITERATION_TARGETS = [
	'clusterplots/done',
	'errorhistograms.pdf',
	'v-shm-distributions.pdf',
] + expand(['expressed_{gene}.tab', 'expressed_{gene}.pdf', 'dendrogram_{gene}.pdf'], gene=['V', 'D', 'J'])

# Targets for non-final iterations
DISCOVERY_TARGETS = [
	'candidates.tab',
	'new_V_germline.fasta',
	'new_V_pregermline.fasta',
]
# All targets of iterations 1..config.iterations (directories iteration-01,
# iteration-02, ...) plus read statistics and preprocessing checks.
TARGETS = expand('iteration-{nr:02d}/{path}', nr=range(1, config.iterations+1), path=ITERATION_TARGETS + DISCOVERY_TARGETS)
TARGETS += [
	'stats/readlengths.pdf',
	'stats/merging-successful',
	'stats/trimming-successful',
	'stats/stats_nofinal.json'
]
if config.iterations >= 1:
	TARGETS += ['iteration-01/new_J.fasta']

# The final/ directory gets the per-iteration targets, but no discovery targets.
FINAL_TARGETS = expand('final/{path}', path=ITERATION_TARGETS) + ['stats/stats.json']


rule all:
	"""Default target: all iteration targets, read statistics and the final/ outputs"""
	input:
		TARGETS + FINAL_TARGETS
	message: "IgDiscover finished."


rule nofinal:
	"""Like 'all', but without the final/ outputs and stats/stats.json"""
	input:
		TARGETS


# Stage 1: create reads/1-limited.<nr><ext>.gz from reads.<nr><ext>[.gz].
# The {nr} wildcard is '1.', '2.' or empty; {ext} is 'fasta' or 'fastq'.
# Each branch has one rule for compressed and one for uncompressed input.
if config.limit:
	rule limit_reads_gz:
		"""Apply the read limit (sqt fastxmod --limit) to compressed input reads"""
		output: 'reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		input: 'reads.{nr}{ext}.gz'
		shell:
			'sqt fastxmod -w 0 --limit {config.limit} {input} | {GZIP} > {output}'

	rule limit_reads:
		"""Apply the read limit (sqt fastxmod --limit) to uncompressed input reads"""
		output: 'reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		input: 'reads.{nr}{ext}'
		shell:
			'sqt fastxmod -w 0 --limit {config.limit} {input} | {GZIP} > {output}'

else:
	rule symlink_limited:
		"""No read limit: symlink already compressed input reads"""
		output: fastaq='reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		input: fastaq='reads.{nr}{ext}.gz'
		resources: time=1
		run:
			relative_symlink(input.fastaq, output.fastaq, force=True)

	# TODO compressing the input file is an unnecessary step
	rule gzip_limited:
		"""No read limit: compress uncompressed input reads"""
		output: fastaq='reads/1-limited.{nr,([12]\\.|)}{ext,(fasta|fastq)}.gz'
		input: fastaq='reads.{nr}{ext}'
		shell:
			'{GZIP} < {input} > {output}'


# After the rules above, we either end up with
#
# 'reads/1-limited.1.fastq.gz' and 'reads/1-limited.2.fastq.gz'
# or
# 'reads/1-limited.fasta.gz'
# or
# 'reads/1-limited.fastq.gz'


if config.merge_program == 'flash':
	rule flash_merge:
		"""Use FLASH to merge paired-end reads"""
		output: fastqgz='reads/2-merged.fastq.gz', log='reads/2-flash.log'
		input: 'reads/1-limited.1.fastq.gz', 'reads/1-limited.2.fastq.gz'
		resources: time=60
		threads: 8
		shell:
			# -M: maximal overlap (2x300, 420-450bp expected fragment size)
			"time flash -t {threads} -c -M {config.flash_maximum_overlap} {input} 2> "
			">(tee {output.log} >&2) | {GZIP} > {output.fastqgz}"

	rule parse_flash_stats:
		"""Extract the total and merged read counts from the FLASH log into stats/reads.json"""
		input: log='reads/2-flash.log'
		output:
			json='stats/reads.json'
		run:
			total_ex = re.compile(r'\[FLASH\]\s*Total reads:\s*([0-9]+)')
			merged_ex = re.compile(r'\[FLASH\]\s*Combined reads:\s*([0-9]+)')
			# Both counts are required. Track them explicitly, so that a log that
			# lacks either line ends with the error message below instead of an
			# unbound variable, regardless of the order of the lines.
			total = merged = None
			with open(input.log) as f:
				for line in f:
					match = total_ex.search(line)
					if match:
						total = int(match.group(1))
					else:
						match = merged_ex.search(line)
						if match:
							merged = int(match.group(1))
					if total is not None and merged is not None:
						break
			if total is None or merged is None:
				sys.exit('Could not parse the FLASH log file')
			d = OrderedDict({'total': total})
			d['merged'] = merged
			d['merging_was_done'] = True
			with open(output.json, 'w') as f:
				json.dump(d, f)


elif config.merge_program == 'pear':

	rule pear_merge:
		"""
		Use pear to merge paired-end reads

		Note that reads/2-pear.log is declared both as an output and as the log.
		"""
		output:
			fastq='reads/2-merged.fastq.gz',
			log='reads/2-pear.log'
		input:
			fastq1='reads/1-limited.1.fastq.gz',
			fastq2='reads/1-limited.2.fastq.gz'
		log: 'reads/2-pear.log'

		resources: time=60
		threads: 20
		shell:
			"igdiscover merge -j {threads} {input.fastq1} {input.fastq2} {output.fastq} | tee {log}"

	rule parse_pear_stats:
		"""Extract the merged and total read counts from the PEAR log into stats/reads.json"""
		input: log='reads/2-pear.log'
		output:
			json='stats/reads.json'
		run:
			# The counts may contain thousands separators (','), which are removed.
			expression = re.compile(r"Assembled reads \.*: (?P<merged>[0-9,]*) / (?P<total>[0-9,]*)")
			with open(input.log) as f:
				for line in f:
					match = expression.search(line)
					if match:
						merged = int(match.group('merged').replace(',', ''))
						total = int(match.group('total').replace(',', ''))
						break
				else:
					sys.exit('Could not parse the PEAR log file')
			d = OrderedDict({'total': total})
			d['merged'] = merged
			d['merging_was_done'] = True
			with open(output.json, 'w') as f:
				json.dump(d, f)
else:
	sys.exit("merge_program {config.merge_program!r} given in configuration file not recognized".format(config=config))


# This rule applies only when the input is single-end or already merged
rule symlink_merged:
	"""Make single-end or already merged reads available as reads/2-merged.* (via a symlink)"""
	output:
		fastaq='reads/2-merged.{ext,(fasta|fastq)}.gz'
	input: fastaq='reads/1-limited.{ext}.gz'
	run:
		relative_symlink(input.fastaq, output.fastaq, force=True)


# After the rules above, we end up with
#
# 'reads/2-merged.fasta.gz'
# or
# 'reads/2-merged.fastq.gz'


rule read_stats_single_fasta:
	"""Compute statistics if no merging was done (FASTA input)"""
	output: json='stats/reads.json',
	input: fastagz='reads/1-limited.fasta.gz'
	run:
		# Without merging, every input read is counted as "merged".
		total = count_sequences(input.fastagz)
		d = OrderedDict({'total': total})
		d['merged'] = total
		d['merging_was_done'] = False
		with open(output.json, 'w') as f:
			json.dump(d, f)


rule read_stats_single_fastq:
	"""Compute statistics if no merging was done (FASTQ input)"""
	output: json='stats/reads.json',
	input: fastqgz='reads/1-limited.fastq.gz'
	run:
		# Without merging, every input read is counted as "merged".
		total = count_sequences(input.fastqgz)
		d = OrderedDict({'total': total})
		d['merged'] = total
		d['merging_was_done'] = False
		with open(output.json, 'w') as f:
			json.dump(d, f)


rule check_merging:
	"""Ensure the merging succeeded (at least 30% of the input reads were merged)"""
	output: success='stats/merging-successful'
	input:
		json='stats/reads.json'
	run:
		with open(input.json) as f:
			d = json.load(f)
		total = d['total']
		merged = d['merged']
		# Empty input is accepted; checking total == 0 first also avoids a
		# division by zero.
		if total == 0 or merged / total >= 0.3:
			with open(output.success, 'w') as f:
				print('This marker file exists if at least 30% of the input '
					'reads could be merged.', file=f)
		else:
			sys.exit('Less than 30% of the input reads could be merged. Please '
				'check whether there is an issue with your input data. To skip '
				'this check, create the file "stats/merging-successful" and '
				're-start "igdiscover run".')


rule merged_read_length_histogram:
	"""Length histogram (text and PDF) of the merged reads, via 'sqt readlenhisto'"""
	output:
		txt="stats/merged.readlengths.txt",
		pdf="stats/merged.readlengths.pdf"
	input:
		fastq='reads/2-merged.fastq.gz'
	shell:
		"""sqt readlenhisto --bins 100 --left {config.minimum_merged_read_length} --title "Lengths of merged reads" --plot {output.pdf} {input}  > {output.txt}"""


rule read_length_histogram:
	"""Length histogram (text and PDF) of the preprocessed reads, via 'sqt readlenhisto'"""
	output:
		txt="stats/readlengths.txt",
		pdf="stats/readlengths.pdf"
	input:
		fastq=PREPROCESSED_READS
	shell:
		"""sqt readlenhisto --bins 100 --left {config.minimum_merged_read_length} --title "Lengths of pre-processed reads" --plot {output.pdf} {input}  > {output.txt}"""


rule reads_stats_fasta:
	"""
	TODO implement this

	Currently a placeholder: it only touches stats/reads.txt.
	"""
	output: txt="stats/reads.txt"
	input:
		merged='reads/1-limited.fasta.gz'
	shell: "touch {output}"


# Remove primer sequences

if config.forward_primers:
	# At least one forward primer is to be removed
	rule trim_forward_primers:
		"""
		Remove forward primers with cutadapt and discard reads in which none is found.

		Each primer is searched anchored at the 5' end (-g ^PRIMER). Unless
		config.stranded is set, the reverse complement of each primer is also
		searched anchored at the 3' end (-a REVCOMP$), so that reads in
		reverse-complement orientation can match as well.
		The cutadapt report is written to the log file.
		"""
		output: fastaq=temp('reads/3-forward-primer-trimmed.{ext,(fasta|fastq)}.gz')
		input: fastaq='reads/2-merged.{ext}.gz', mergesuccess='stats/merging-successful'
		resources: time=120
		log: 'reads/3-forward-primer-trimmed.{ext}.log'
		params:
			fwd_primers=''.join(' -g ^{}'.format(seq) for seq in config.forward_primers),
			rev_primers=''.join(' -a {}$'.format(reverse_complement(seq)) for seq in config.forward_primers) if not config.stranded else '',
		shell:
			"cutadapt --discard-untrimmed"
			"{params.fwd_primers}"
			"{params.rev_primers}"
			" -o {output.fastaq} {input.fastaq} | tee {log}"
else:
	# No trimming, just symlink the file
	rule dont_trim_forward_primers:
		"""No forward primers configured: symlink the merged reads"""
		output: fastaq='reads/3-forward-primer-trimmed.{ext,(fasta|fastq)}.gz'
		input: fastaq='reads/2-merged.{ext}.gz', mergesuccess='stats/merging-successful'
		resources: time=1
		run:
			relative_symlink(input.fastaq, output.fastaq, force=True)


if config.reverse_primers:
	# At least one reverse primer is to be removed
	rule trim_reverse_primers:
		"""
		Remove reverse primers with cutadapt and discard reads in which none is found.

		The reverse complement of each primer is searched anchored at the 3' end
		(-a REVCOMP$). Unless config.stranded is set, the primer itself is also
		searched anchored at the 5' end (-g ^PRIMER).
		The cutadapt report is written to the log file.
		"""
		output: fastaq='reads/4-trimmed.{ext,(fasta|fastq)}.gz'
		input: fastaq='reads/3-forward-primer-trimmed.{ext}.gz'
		resources: time=120
		log: 'reads/4-trimmed.{ext}.log'
		params:
			# Reverse primers should appear reverse-complemented at the 3' end
			# of the merged read.
			fwd_primers=''.join(' -a {}$'.format(reverse_complement(seq)) for seq in config.reverse_primers),
			rev_primers=''.join(' -g ^{}'.format(seq) for seq in config.reverse_primers) if not config.stranded else ''
		shell:
			"cutadapt --discard-untrimmed"
			"{params.fwd_primers}"
			"{params.rev_primers}"
			" -o {output.fastaq} {input.fastaq} | tee {log}"

else:
	# No trimming, just symlink the file
	rule dont_trim_reverse_primers:
		"""No reverse primers configured: symlink the forward-trimmed reads"""
		output: fastaq='reads/4-trimmed.{ext,(fasta|fastq)}.gz'
		input: fastaq='reads/3-forward-primer-trimmed.{ext}.gz'
		resources: time=1
		run:
			relative_symlink(input.fastaq, output.fastaq, force=True)


rule trimmed_fasta_stats:
	"""Count the reads left after primer trimming (FASTA input)"""
	output: json='stats/trimmed.json',
	input: fastagz='reads/4-trimmed.fasta.gz'
	run:
		with open(output.json, 'w') as f:
			json.dump({'trimmed': count_sequences(input.fastagz)}, f)


rule trimmed_fastq_stats:
	"""Count the reads left after primer trimming (FASTQ input)"""
	output: json='stats/trimmed.json',
	input: fastqgz='reads/4-trimmed.fastq.gz'
	run:
		with open(output.json, 'w') as f:
			json.dump({'trimmed': count_sequences(input.fastqgz)}, f)


rule check_trimming:
	"""Ensure that some reads are left after trimming"""
	output: success='stats/trimming-successful'
	input:
		reads_json='stats/reads.json',
		trimmed_json='stats/trimmed.json'
	run:
		with open(input.reads_json) as f:
			total = json.load(f)['total']
		with open(input.trimmed_json) as f:
			trimmed = json.load(f)['trimmed']
		# Compare against the number of input reads ('total', before merging).
		# Empty input is accepted; checking total == 0 first also avoids a
		# division by zero.
		if total == 0 or trimmed / total >= 0.1:
			with open(output.success, 'w') as f:
				print('This marker file exists if at least 10% of input '
					'reads contain the required primer sequences.', file=f)
		else:
			print(*textwrap.wrap(
				'Less than 10% of the input reads contain the required primer '
				'sequences. Please check whether you have specified the '
				'correct primer sequences in the configuration file. To skip '
				'this check, create the file "stats/trimming-successful" and '
				're-start "igdiscover run".'), sep='\n')
			sys.exit(1)


def group_cdr3_arg():
	"""
	Return the CDR3 option string for 'igdiscover group'.

	The result depends on config.cdr3_location:
	- unset/empty: '' (no option)
	- 'detect': ' --real-cdr3'
	- anything else: ' --pseudo-cdr3=A:B' built from its first two elements
	"""
	location = config.cdr3_location
	if not location:
		return ''
	if location == 'detect':
		return ' --real-cdr3'
	return ' --pseudo-cdr3={}:{}'.format(*location)


# Last preprocessing step: turn reads/4-trimmed.<ext>.gz into PREPROCESSED_READS.
# One anonymous rule is created per input format, because the output path has no
# {ext} wildcard that could select the input. The extension is substituted
# immediately with .format(), so each rule keeps its own value of ext.
for ext in ('fasta', 'fastq'):
	if config.barcode_length and config.barcode_consensus:
		rule:
			"""Group by barcode and CDR3 (also implicitly removes duplicates)"""
			output:
				fastagz=PREPROCESSED_READS,
				pdf="stats/groupsizes.pdf",
				groups="reads/4-groups.tab.gz",
				json="stats/groups.json"
			input:
				fastaq='reads/4-trimmed.{ext}.gz'.format(ext=ext), success='stats/trimming-successful'
			log: 'reads/4-sequences.fasta.gz.log'
			params:
				race_arg=' --trim-g' if config.race_g else '',
				cdr3_arg=group_cdr3_arg(),
			shell:
				"igdiscover group"
				"{params.cdr3_arg}{params.race_arg}"
				" --json={output.json}"
				" --minimum-length={config.minimum_merged_read_length}"
				" --groups-output={output.groups}"
				" --barcode-length={config.barcode_length}"
				" --plot-sizes={output.pdf}"
				" {input.fastaq} 2> {log} | {GZIP} > {output.fastagz}"

	else:
		rule:
			"""Collapse identical sequences, remove barcodes"""
			output:
				fastagz=PREPROCESSED_READS,
				json="stats/groups.json"
			input:
				fastaq='reads/4-trimmed.{ext}.gz'.format(ext=ext), success='stats/trimming-successful',
			params:
				barcode_length=' --barcode-length={}'.format(config.barcode_length) if config.barcode_length else '',
				race_arg = ' --trim-g' if config.race_g else '',
			shell:
				"igdiscover dereplicate"
				"{params.barcode_length}{params.race_arg}"
				" --json={output.json}"
				" --minimum-length={config.minimum_merged_read_length}"
				" {input.fastaq} | {GZIP} > {output.fastagz}"


rule copy_d_database:
	"""
	Copy D gene database into the iteration folder

	{base} is an iteration directory (iteration-NN) or final. Every
	iteration uses the same database/D.fasta.
	"""
	output:
		fasta="{base}/database/D.fasta"
	input:
		fasta="database/D.fasta"
	shell:
		"cp -p {input} {output}"


rule vj_database_iteration_1:
	"""
	Copy original V or J gene database into the iteration 1 folder

	The V/J databases of later iterations are created by the rules defined
	in the loop over iterations 2..config.iterations.
	"""
	output:
		fasta="iteration-01/database/{gene,[VJ]}.fasta"
	input:
		fasta="database/{gene}.fasta"
	shell:
		"cp -p {input} {output}"


def ensure_fasta_not_empty(path, gene):
	"""
	Stop the workflow if the FASTA file at *path* contains no records.

	*gene* (e.g. 'V' or 'J') is used only in the error message, which is
	printed to stderr before exiting with status 1.
	"""
	with FastaReader(path) as reader:
		# Stops after the first record; the rest of the file is not read.
		found_record = any(True for _ in reader)
	if found_record:
		return
	print(
		'ERROR: No {gene} genes were discovered in this iteration (file '
		'{path!r} is empty)! Cannot continue.\n'
		'Check whether the starting database is of the correct chain type '
		'(heavy, light lambda, light kappa). It needs to match the type '
		'of sequences you analyze.'.format(gene=gene, path=path), file=sys.stderr)
	sys.exit(1)


# Databases of iterations 2..config.iterations. Paths are substituted immediately
# with .format(), so each generated rule refers to its own iteration number.
for i in range(2, config.iterations + 1):
	rule:
		"""V database of iteration i: the new_V_pregermline.fasta of iteration i-1"""
		output:
			fasta='iteration-{nr:02d}/database/V.fasta'.format(nr=i)
		input:
			fasta='iteration-{nr:02d}/new_V_pregermline.fasta'.format(nr=i-1)
		run:
			ensure_fasta_not_empty(input.fasta, 'V')
			shell("cp -p {input.fasta} {output.fasta}")

	rule:
		"""J database of iteration i: iteration-01/new_J.fasta if J discovery results are propagated, else database/J.fasta"""
		# Even with multiple iterations, J genes are discovered only once
		output:
			fasta='iteration-{nr:02d}/database/J.fasta'.format(nr=i)
		input:
			fasta='iteration-01/new_J.fasta' if config.j_discovery['propagate'] else 'database/J.fasta'
		run:
			ensure_fasta_not_empty(input.fasta, 'J')
			shell("cp -p {input.fasta} {output.fasta}")


# Rules for last iteration

if config.iterations == 0:
	# Copy over the input database (would be nice to avoid this)
	rule copy_database:
		"""Without iterations, the final V and J databases are the input databases"""
		output:
			fasta='final/database/{gene,[VJ]}.fasta'
		input:
			fasta='database/{gene}.fasta'
		shell:
			"cp -p {input.fasta} {output.fasta}"
else:
	rule copy_final_v_database:
		"""Final V database: new_V_germline.fasta of the last iteration"""
		output:
			fasta='final/database/V.fasta'
		input:
			fasta='iteration-{nr:02d}/new_V_germline.fasta'.format(nr=config.iterations)
		run:
			ensure_fasta_not_empty(input.fasta, 'V')
			shell("cp -p {input.fasta} {output.fasta}")

	rule copy_final_j_database:
		"""Final J database: iteration-01/new_J.fasta if J discovery results are propagated, else database/J.fasta"""
		output:
			fasta='final/database/J.fasta'
		input:
			fasta=('iteration-01/new_J.fasta'
					if config.j_discovery['propagate'] else 'database/J.fasta')
		run:
			ensure_fasta_not_empty(input.fasta, 'J')
			shell("cp -p {input.fasta} {output.fasta}")


rule igdiscover_igblast:
	"""
	Run 'igdiscover igblast' on the preprocessed reads against {dir}/database
	and write the gzip-compressed assignment table and JSON statistics.
	"""
	output:
		tabgz="{dir}/assigned.tab.gz",
		json="{dir}/stats/assigned.json"
	input:
		fastagz=PREPROCESSED_READS,
		db_v="{dir}/database/V.fasta",
		db_d="{dir}/database/D.fasta",
		db_j="{dir}/database/J.fasta"
	params:
		penalty=' --penalty {}'.format(config.mismatch_penalty) if config.mismatch_penalty is not None else '',
		database='{dir}/database',
		species=' --species={}'.format(config.species) if config.species else '',
		sequence_type=' --sequence-type={}'.format(config.sequence_type),
		rename=' --rename {}'.format(shlex.quote(os.path.basename(os.getcwd()) + '_')) if config.rename else ''
	log: '{dir}/igblast.log'
	threads: 16
	shell:
		# The stderr of 'igdiscover igblast' is copied to the log (and, through
		# tee, still shown on stderr), while its stdout is compressed into the
		# table. This is the same pattern as in flash_merge.
		# The --rename prefix (working directory name plus '_') is quoted with
		# shlex.quote, so unusual directory names reach the program intact.
		"time igdiscover igblast{params.sequence_type}{params.penalty}{params.rename} --threads={threads}"
		"{params.species} --stats={output.json} {params.database} {input.fastagz}"
		" 2> >(tee {log} >&2) | {GZIP} > {output.tabgz}"


rule check_parsing:
	"""
	Ensure that IgBLAST produced assignments and that a CDR3 was detected in
	at least 10% of them; otherwise stop with an explanation.
	"""
	output:
		success="{dir}/stats/parsing-successful"
	input:
		json="{dir}/stats/assigned.json"
	run:
		with open(input.json) as f:
			d = json.load(f)
		n = d['total']
		detected_cdr3s = d['detected_cdr3s']
		if n == 0:
			print('No IgBLAST assignments found, something is wrong.')
			sys.exit(1)
		elif detected_cdr3s / n >= 0.1:
			with open(output.success, 'w') as f:
				print('This marker file exists if a CDR3 sequence could be '
					'detected for at least 10% of IgBLAST-assigned sequences.',
					file=f)
		else:
			print(*textwrap.wrap(
				'A CDR3 sequence could be detected in less than 10% of the '
				'IgBLAST-assigned sequences. Possibly there is a problem with '
				'the starting database. To skip this check and continue anyway, '
				'create the file "{}" and re-start "igdiscover run".'.format(
					output.success)), sep='\n')
			sys.exit(1)



rule igdiscover_filter:
	"""
	Filter the IgBLAST assignments ('igdiscover filter') using the V/J coverage
	and V E-value thresholds from config.preprocessing_filter.
	"""
	output:
		filtered="{dir}/filtered.tab.gz",
		json="{dir}/stats/filtered.json"
	input:
		assigned="{dir}/assigned.tab.gz",
		success="{dir}/stats/parsing-successful"
	run:
		conf = config.preprocessing_filter
		criteria = ['--v-coverage={}'.format(conf['v_coverage'])]
		criteria += ['--j-coverage={}'.format(conf['j_coverage'])]
		criteria += ['--v-evalue={}'.format(conf['v_evalue'])]
		criteria = ' '.join(criteria)
		# shell() fills in {criteria} from the local variable above.
		shell("igdiscover filter --json={output.json} {criteria} {input.assigned} | {GZIP} > {output.filtered}")


rule igdiscover_exact:
	"""Write the header and all rows of filtered.tab.gz whose V_errors value is 0"""
	output:
		exact="{dir}/exact.tab"
	input:
		filtered="{dir}/filtered.tab.gz"
	shell:
		# extract rows where V_errors == 0
		# The .tab tables are tab-separated and may contain empty fields. awk's
		# default splitting on runs of whitespace would merge empty fields and
		# shift the column positions, so split on single tabs (-F '\t').
		"""zcat {input.filtered} |"""
		""" awk -F '\\t' 'NR==1 {{ for(i=1;i<=NF;i++) if ($i == "V_errors") col=i}};NR==1 || $col == 0' > {output}"""


rule igdiscover_count:
	"""Run 'igdiscover count' for one gene type (V, D or J) on the filtered table, writing a table and a plot"""
	output:
		plot="{dir}/expressed_{gene,[VDJ]}.pdf",
		counts="{dir}/expressed_{gene}.tab"
	input:
		tab="{dir}/filtered.tab.gz"
	shell:
		"igdiscover count --gene={wildcards.gene} "
		"--allele-ratio=0.2 "
		"--plot={output.plot} {input.tab} > {output.counts}"


rule igdiscover_clusterplot:
	"""
	Write cluster plots into {dir}/clusterplots/; the 'done' file marks
	that 'igdiscover clusterplot' finished successfully.
	"""
	output:
		done="{dir}/clusterplots/done"
	input:
		tab="{dir}/filtered.tab.gz"
	params:
		clusterplots="{dir}/clusterplots/",
		ignore_j=' --ignore-J' if config.ignore_j else ''
	shell:
		"igdiscover clusterplot{params.ignore_j} {input.tab} {params.clusterplots} && touch {output.done}"


rule igdiscover_discover:
	"""
	Discover potential new V gene sequences

	Runs 'igdiscover discover' on the filtered table, using the V database of
	the same directory. The optional params expand to '' when not configured.
	"""
	output:
		tab="{dir}/candidates.tab",
		read_names="{dir}/read_names_map.tab",
	input:
		v_reference="{dir}/database/V.fasta",
		tab="{dir}/filtered.tab.gz"
	params:
		ignore_j=' --ignore-J' if config.ignore_j else '',
		seed=' --seed={}'.format(config.seed) if config.seed is not None else '',
		exact_copies=' --exact-copies={}'.format(config.exact_copies) if config.exact_copies is not None else ''
	threads: 128
	shell:
		"time igdiscover discover -j {threads}{params.seed}{params.ignore_j}{params.exact_copies}"
		" --d-coverage={config.d_coverage}"
		" --read-names={output.read_names}"
		" --subsample={config.subsample} --database={input.v_reference}"
		" {input.tab} > {output.tab}"


def db_whitelist_or_not(wildcards):
	"""
	Input function for the germline filter rule.

	Returns 'database/V.fasta' (the original, non-iteration-specific V
	database) if the applicable filter configuration (pre-germline or germline,
	selected by the {pre} wildcard) enables 'whitelist', else an empty list.
	"""
	if wildcards.pre == 'pre':
		filterconf = config.pre_germline_filter
	else:
		filterconf = config.germline_filter
	return 'database/V.fasta' if filterconf['whitelist'] else []


def germlinefilter_criteria(wildcards, input):
	"""
	Build the option string passed to 'igdiscover germlinefilter'.

	The settings come from config.pre_germline_filter when the {pre} wildcard is
	'pre', and from config.germline_filter otherwise. A --whitelist option is
	added for each non-empty entry among input.db_whitelist and input.whitelist.
	The option --allow-chimeras is currently never passed.
	"""
	conf = config.pre_germline_filter if wildcards.pre == 'pre' else config.germline_filter
	criteria = ['--whitelist=' + path for path in (input.db_whitelist, input.whitelist) if path]
	if conf['allow_stop']:
		criteria.append('--allow-stop')
	# (command-line option, configuration key), in the order they are passed
	options = (
		('--unique-CDR3', 'unique_cdr3s'),
		('--cluster-size', 'cluster_size'),
		('--unique-J', 'unique_js'),
		('--cross-mapping-ratio', 'cross_mapping_ratio'),
		('--clonotype-ratio', 'clonotype_ratio'),
		('--exact-ratio', 'exact_ratio'),
		('--cdr3-shared-ratio', 'cdr3_shared_ratio'),
		('--unique-D-ratio', 'unique_d_ratio'),
		('--unique-D-threshold', 'unique_d_threshold'),
	)
	criteria.extend('{}={}'.format(option, conf[key]) for option, key in options)
	return ' '.join(criteria)


rule igdiscover_germlinefilter:
	"""
	Construct a new database out of the discovered sequences

	The {pre} wildcard is either 'pre' (new_V_pregermline.*) or empty
	(new_V_germline.*); it selects the filter settings used by
	germlinefilter_criteria(). whitelist.fasta is used as an extra whitelist
	only if it exists when this file is evaluated.
	"""
	output:
		tab='iteration-{nr}/new_V_{pre,(pre|)}germline.tab',
		fasta='iteration-{nr}/new_V_{pre,(pre|)}germline.fasta',
		annotated_tab='iteration-{nr}/annotated_V_{pre,(pre|)}germline.tab',
	input:
		tab='iteration-{nr}/candidates.tab',
		db_whitelist=db_whitelist_or_not,
		whitelist='whitelist.fasta' if os.path.exists('whitelist.fasta') else [],
	params:
		criteria=germlinefilter_criteria
	log:
		'iteration-{nr}/new_V_{pre,(pre|)}germline.log'
	shell:
		"igdiscover germlinefilter {params.criteria}"
		" --annotate={output.annotated_tab}"
		" --fasta={output.fasta} {input.tab} "
		" 2> >(tee {log} >&2) "
		" > {output.tab}"


rule igdiscover_discover_j:
	"""
	Discover potential new J gene sequences

	Runs only on iteration 1 ('igdiscover discoverj' on its filtered table and
	J database).
	"""
	output:
		tab="iteration-01/new_J.tab",
		fasta="iteration-01/new_J.fasta",
	input:
		j_reference="iteration-01/database/J.fasta",
		tab="iteration-01/filtered.tab.gz",
	params:
		allele_ratio='--allele-ratio={}'.format(config.j_discovery['allele_ratio']),
		cross_mapping_ratio=' --cross-mapping-ratio={}'.format(config.j_discovery['cross_mapping_ratio'])
	shell:
		"time igdiscover discoverj {params.allele_ratio}{params.cross_mapping_ratio} "
		"--database={input.j_reference} "
		"--fasta={output.fasta} "
		"{input.tab} > {output.tab}"


rule stats_correlation_V_J:
	"""Plot a heatmap of V errors vs. J errors over the unfiltered IgBLAST assignments"""
	output:
		pdf="{dir}/correlationVJ.pdf"
	input:
		table="{dir}/assigned.tab.gz"
	run:
		import matplotlib
		matplotlib.use('pdf')
		# sns.heatmap will not work properly with the object-oriented interface,
		# so use pyplot
		import matplotlib.pyplot as plt
		import seaborn as sns
		import numpy as np
		import pandas as pd
		from collections import Counter
		table = igdiscover.read_table(input.table)
		fig = plt.figure(figsize=(29.7/2.54, 21/2.54))
		# counts[v, j]: number of sequences with v V errors and j J errors, for
		# 0 <= v <= 20 and 0 <= j <= 10 (larger error counts are not shown)
		counts = np.zeros((21, 11), dtype=np.int64)
		counter = Counter(zip(table['V_errors'], table['J_errors']))
		for (v, j), count in counter.items():
			# Skip missing values (None or NaN). pandas stores numeric columns
			# that contain missing values as floats, and NumPy does not accept
			# float indices, so convert the remaining values to int.
			if pd.isnull(v) or pd.isnull(j):
				continue
			v, j = int(v), int(j)
			if 0 <= v < counts.shape[0] and 0 <= j < counts.shape[1]:
				counts[v, j] = count
		df = pd.DataFrame(counts.T)[::-1]
		df.index.name = 'J errors'
		df.columns.name = 'V errors'
		sns.heatmap(df, annot=True, fmt=',d', cbar=False)
		fig.suptitle('V errors vs. J errors in unfiltered sequences')
		fig.set_tight_layout(True)
		fig.savefig(output.pdf)
		# Release the figure so that pyplot does not keep it alive
		plt.close(fig)


rule plot_errorhistograms:
	"""
	Plot error histograms and V SHM distributions of the filtered table
	('igdiscover errorplot'); with config.ignore_j, --max-j-shm=0 is passed.
	"""
	output:
		multi_pdf='{dir}/errorhistograms.pdf',
		boxplot_pdf='{dir}/v-shm-distributions.pdf'
	input:
		table='{dir}/filtered.tab.gz'
	params:
		ignore_j=' --max-j-shm=0' if config.ignore_j else ''
	shell:
		'igdiscover errorplot{params.ignore_j} --multi={output.multi_pdf} --boxplot={output.boxplot_pdf} {input.table}'


rule dendrogram:
	"""
	Draw a dendrogram of the {gene} database of {dir} ('igdiscover dendrogram').

	The original database/{gene}.fasta is passed via --mark (presumably to
	highlight its sequences; confirm against the command's help).
	"""
	output:
		pdf='{dir}/dendrogram_{gene}.pdf'
	input:
		fasta='{dir}/database/{gene}.fasta'
	shell:
		'igdiscover dendrogram --mark database/{wildcards.gene}.fasta {input.fasta} {output.pdf}'


rule version:
	"""Record the IgDiscover version in stats/version.txt"""
	output: txt='stats/version.txt'
	run:
		with open(output.txt, 'w') as f:
			print('IgDiscover version', igdiscover.__version__, file=f)


def get_sequences(path):
	"""Return the upper-cased sequences of all records in the file at *path*, in file order."""
	with SequenceReader(path) as reader:
		return [record.sequence.upper() for record in reader]


def count_sequences(path):
	"""Return the number of records in the sequence file at *path*."""
	with SequenceReader(path) as reader:
		return sum(1 for _ in reader)


rule json_stats_nofinal:
	"""
	Write stats/stats_nofinal.json: the IgDiscover version, read preprocessing
	statistics and, per iteration, the assignment filtering statistics and the
	V database size with the number of sequences gained/lost relative to the
	database the iteration started from.
	"""
	output: json='stats/stats_nofinal.json'
	input:
		original_db='database/V.fasta',
		v_pre_germline=['iteration-{:02d}/new_V_pregermline.fasta'.format(i+1) for i in range(config.iterations)],
		v_germline=['iteration-{:02d}/new_V_germline.fasta'.format(i+1) for i in range(config.iterations)],
		filtered_stats=['iteration-{:02d}/stats/filtered.json'.format(i+1) for i in range(config.iterations)],
		group_stats='stats/groups.json',
		reads='stats/reads.json',
		trimmed='stats/trimmed.json'
	run:
		d = OrderedDict()
		d['version'] = igdiscover.__version__

		# stats/reads.json contains 'total', 'merged' and 'merging_was_done';
		# those entries are kept and the other preprocessing statistics are added.
		with open(input.reads) as f:
			rp = json.load(f)
		rp['raw_reads'] = rp['total']
		with open(input.trimmed) as f:
			rp['after_primer_trimming'] = json.load(f)['trimmed']

		with open(input.group_stats) as f:
			rp['grouping'] = json.load(f)
		d['read_preprocessing'] = rp

		# Sequences of the V database the current iteration starts from
		prev_sequences = set(get_sequences(input.original_db))
		iterations = [{'database': {'size': len(prev_sequences)}}]
		for pre_germline_path, germline_path, filtered_json_path in zip(
				input.v_pre_germline, input.v_germline, input.filtered_stats):

			pre_germline_sequences = set(get_sequences(pre_germline_path))
			germline_sequences = set(get_sequences(germline_path))

			iteration_info = OrderedDict()
			with open(filtered_json_path) as f:
				iteration_info['assignment_filtering'] = json.load(f)
			db_info = OrderedDict()
			db_info['size'] = len(germline_sequences)
			db_info['gained'] = len(germline_sequences - prev_sequences)
			db_info['lost'] = len(prev_sequences - germline_sequences)
			db_info['size_pre'] = len(pre_germline_sequences)
			db_info['gained_pre'] = len(pre_germline_sequences - prev_sequences)
			db_info['lost_pre'] = len(prev_sequences - pre_germline_sequences)
			iteration_info['database'] = db_info
			iterations.append(iteration_info)
			# The V database of the next iteration is this iteration's
			# pre-germline file.
			prev_sequences = pre_germline_sequences
		d['iterations'] = iterations

		with open(output.json, 'w') as f:
			json.dump(d, f, indent=2)
			print(file=f)
			print(file=f)


rule json_stats:
	"""Write stats/stats.json: stats/stats_nofinal.json plus the assignment filtering statistics of final/"""
	output: json='stats/stats.json'
	input:
		stats_nofinal='stats/stats_nofinal.json',
		final_stats='final/stats/filtered.json',
	run:
		with open(input.stats_nofinal) as f:
			d = json.load(f)

		with open(input.final_stats) as f:
			d['assignment_filtering'] = json.load(f)

		with open(output.json, 'w') as f:
			json.dump(d, f, indent=2)
			print(file=f)
