#!/usr/bin/env perl

use strict ;
use warnings ;

use File::Basename;
use Cwd;
use Cwd 'cwd' ;
use Cwd 'abs_path' ;


# Print the usage text and stop unless all three positional arguments are
# present: the script reads $ARGV[0] (index name), $ARGV[1] (Centrifuge
# output) and $ARGV[2] (level) unconditionally further down.
die "Usage: centrifuge-promote centrifuge_index_name centrifuge_output level > output\n\n".
	"Promote the taxonomy id to specified level in Centrifuge output.\n".
	"\tIf level equals \"lca\", this will merge the multiassignment to their lowest common ancestor.\n" if ( @ARGV < 3 ) ;

# Directory containing this script; centrifuge-inspect is run from the same place.
my $CWD = dirname( abs_path( $0 ) ) ;
# Go through the index to obtain the taxonomy tree
my %taxParent ;    # taxonomy id => parent id   (columns 1 and 2 of --taxonomy-tree)
my %taxIdToSeqId ; # column 2 => column 1 of --conversion-table
my %taxLevel ;     # taxonomy id => level name  (column 3 of --taxonomy-tree)

my $centrifuge_index = $ARGV[0] ;
my $inspect = "$CWD/centrifuge-inspect" ;

# List-form pipe opens hand the arguments directly to the program without
# going through a shell, so an index path containing spaces or shell
# metacharacters reaches centrifuge-inspect unchanged.
open( my $treeFh, "-|", $inspect, "--taxonomy-tree", $centrifuge_index )
	or die "can't open $!\n" ;
while ( my $line = <$treeFh> )
{
	chomp $line ;
	# Fields are separated by "<tab>|<tab>".
	my @cols = split /\t\|\t/, $line ;
	$taxParent{ $cols[0] } = $cols[1] ;
	$taxLevel{ $cols[0] } = $cols[2] ;
}
# Closing a pipe reports a non-zero exit of the child; without this check a
# failed centrifuge-inspect would silently leave the tables empty.
close( $treeFh )
	or warn "centrifuge-inspect --taxonomy-tree did not finish cleanly (exit status " . ( $? >> 8 ) . "); the taxonomy tree may be incomplete\n" ;

open( my $tableFh, "-|", $inspect, "--conversion-table", $centrifuge_index )
	or die "can't open $!\n" ;
while ( my $line = <$tableFh> )
{
	chomp $line ;
	my @cols = split /\t/, $line ;
	$taxIdToSeqId{ $cols[1] } = $cols[0] ;
}
close( $tableFh )
	or warn "centrifuge-inspect --conversion-table did not finish cleanly (exit status " . ( $? >> 8 ) . "); the conversion table may be incomplete\n" ;

# Go through the output of centrifuge
# Requested level: compared against the level names stored in %taxLevel, or
# the literal "lca" (that case is handled by OutputPromotedLines).
my $level = $ARGV[2] ;

# Climb from a taxonomy id towards the root and return the first id (the
# starting id included) whose entry in %taxLevel equals $level.
#
#   $_[0]   taxonomy id to start from
#   returns the matching id, or 0 when no id on the chain has that level
#
# The climb gives up (returns 0) when the id is not positive, has no entry in
# %taxLevel, is the root (1), has no parent recorded, or is its own parent.
# The last two guards keep a damaged tree from producing warnings or an
# endless walk.
sub PromoteTaxId
{
	my $tid = $_[0] ;
	return 0 if ( !defined( $tid ) ) ;

	while ( 1 )
	{
		return 0 if ( $tid <= 0 || !defined( $taxLevel{ $tid } ) ) ;
		return $tid if ( $taxLevel{ $tid } eq $level ) ;
		return 0 if ( $tid <= 1 ) ;

		my $parent = $taxParent{ $tid } ;
		return 0 if ( !defined( $parent ) || $parent eq $tid ) ;
		$tid = $parent ;
	}
}

# Lowest common ancestor of two taxonomy ids according to %taxParent.
#
#   - If either id is 0 the other one is returned; equal ids are returned as is.
#   - Otherwise every id on the chain from the first id upwards is recorded,
#     and the second id is walked upwards until it reaches a recorded id.
#   - 1 is returned when the second chain ends without meeting the first.
#
# A chain ends at an id without a %taxParent entry (reported on STDERR) or at
# an id that is its own parent.
sub lca 
{
	# Not named $a/$b: those are the variables sort() uses.
	my ( $tidA, $tidB ) = @_ ;
	return $tidB if ( $tidA eq "0" ) ;
	return $tidA if ( $tidB eq "0" ) ;
	return $tidA if ( $tidA eq $tidB ) ;

	# Ids on the chain from $tidA upwards, $tidA included.
	my %pathOfA ;
	# Numeric comparison, matching the "> 1" test used for the second chain.
	while ( $tidA >= 1 ) 
	{
		$pathOfA{ $tidA } = 1 ;
		if ( !defined $taxParent{ $tidA } )
		{
			print STDERR "Couldn't find parent of taxID $tidA - directly assigned to root.\n" ;
			last ;
		}
		last if ( $tidA eq $taxParent{ $tidA } ) ;
		$tidA = $taxParent{ $tidA } ;
	}

	while ( $tidB > 1 ) 
	{
		return $tidB if ( defined $pathOfA{ $tidB } ) ;
		if ( !defined $taxParent{ $tidB } )
		{
			print STDERR "Couldn't find parent of taxID $tidB - directly assigned to root.\n" ;
			last ;
		}
		last if ( $tidB eq $taxParent{ $tidB } ) ;
		$tidB = $taxParent{ $tidB } ;
	}
	return 1 ;
}

# Rewrite and print one group of Centrifuge output lines (the caller passes
# the lines collected for a single read id).
#
#   $_[0]   reference to an array of chomped, tab-separated lines
#
# Column 2 of each printed row holds a level name, column 3 the taxonomy id
# and the last column the number of rows printed for the group.
#
#   - $level ne "lca": each line's taxonomy id is promoted with PromoteTaxId
#     (the original id is kept when promotion yields nothing above the root),
#     and only the first line for each resulting id is printed.
#   - $level eq "lca": the ids of all lines are folded with lca() and a single
#     row, based on the first line, is printed.
#
# Rows are kept as column arrays until they are printed. Joining them and
# splitting again on runs of tabs would drop any empty column and shift the
# remaining ones to the left.
sub OutputPromotedLines
{
	my @lines = @{ $_[0] } ;
	return if ( scalar( @lines ) <= 0 ) ;

	my @newRows ; # one array reference of columns per row to print
	my $numMatches = 0 ;

	if ( $level ne "lca" )
	{
		my %showedUpTaxId ;
		foreach my $line ( @lines )
		{
			my @cols = split /\t+/, $line ;
			my $newTid = PromoteTaxId( $cols[2] ) ;
			# Nothing usable found: keep the id the line came with.
			$newTid = $cols[2] if ( $newTid <= 1 ) ;

			my $newLevel = $cols[1] ;
			$newLevel = $taxLevel{ $newTid } if ( $newTid >= 1 && defined $taxLevel{ $newTid } ) ;

			next if ( defined $showedUpTaxId{ $newTid } ) ;

			$showedUpTaxId{ $newTid } = 1 ;
			++$numMatches ;

			$cols[2] = $newTid ;
			$cols[1] = $newLevel ;
			push @newRows, \@cols ;
		}
	}
	else
	{
		$numMatches = 1 ;
		my @cols = split /\t+/, $lines[0] ;
		my $ancestor = $cols[2] ;
		foreach my $i ( 1 .. $#lines )
		{
			my @other = split /\t+/, $lines[ $i ] ;
			$ancestor = lca( $ancestor, $other[2] ) ;
		}

		if ( $ancestor ne $cols[2] )
		{
			# An ancestor without a %taxLevel entry gets an empty level
			# field, so the row keeps its column count.
			$cols[1] = defined( $taxLevel{ $ancestor } ) ? $taxLevel{ $ancestor } : "" ;
		}
		$cols[2] = $ancestor ;
		push @newRows, \@cols ;
	}

	foreach my $row ( @newRows )
	{
		# The last column carries the number of rows printed for this group.
		$row->[-1] = $numMatches ;
		print join( "\t", @{ $row } ), "\n" ;
	}
}

# Stream the Centrifuge output: echo the header line, then gather consecutive
# lines sharing the same first column (read id) and pass each group to
# OutputPromotedLines.
my $outputFile = $ARGV[1] ;
die "can't open: no centrifuge output file given\n" if ( !defined( $outputFile ) ) ;

my $outputFh ;
if ( $outputFile eq "-" )
{
	# "-" keeps meaning standard input, as it did with the two-argument open.
	$outputFh = \*STDIN ;
}
else
{
	# Three-argument open: the name is always treated as a file to read,
	# and a missing or unreadable file stops the script with a message.
	open( $outputFh, "<", $outputFile ) or die "can't open $outputFile: $!\n" ;
}

my $header = <$outputFh> ;
my $prevReadId = "" ;
my @lines ;

# An empty input has no header line to echo.
print $header if ( defined( $header ) ) ;
while ( my $line = <$outputFh> )
{
	chomp $line ;
	next if ( $line eq "" ) ; # blank lines carry no read id

	my @cols = split /\t/, $line ;
	if ( $cols[0] ne $prevReadId )
	{
		# A new read id starts: flush the group collected so far.
		OutputPromotedLines( \@lines ) ;
		@lines = () ;
		$prevReadId = $cols[0] ;
	}
	push @lines, $line ;
}
# Flush the final group.
OutputPromotedLines( \@lines ) ;
close( $outputFh ) ;
