#include <iostream>
#include <boost/lexical_cast.hpp>
#include "types.h"
#include "parser.h"
#include "distance.h"
#include "traits.hpp"
#include "aligncode.h"

using namespace profdist;
using namespace boost;
/**
 * Log-likelihood of a substitution count matrix N at divergence time t:
 *
 *   L(t) = sum_{i,j} N(i,j) * log P(i,j),   with P = exp( Q, t, Pade )
 *
 * (exp(Q, t, Pade) is defined elsewhere; presumably a Pade-approximant of
 * the matrix exponential exp(Q*t) -- confirm.)
 *
 * @param N  substitution counts (callers in this file pass the symmetrised
 *           matrix M + M^T converted to double)
 * @param Q  rate matrix
 * @param t  time at which the transition matrix P is evaluated
 * @return   the log-likelihood; -inf if a cell with a non-zero count has
 *           probability 0.
 *
 * Every non-finite log(P(i,j)) is reported on std::cout as a diagnostic.
 */
template<class T, size_t Dim>
double logLike( profdist::fixed_matrix<T,Dim,Dim> const& N, profdist::fixed_matrix<double,Dim,Dim> const& Q, T const& t )
{
  double Lt = 0;

  profdist::fixed_matrix<double,Dim,Dim> P( exp( Q, t, Pade ) );

  for( std::size_t i = 0; i < Dim; ++i )
    for( std::size_t j = 0; j < Dim; ++j ) {
      // Keep the logarithm in double precision: T is the count/time type and
      // would truncate the result if it were integral.
      double const temp = log( P( i, j ) );
      if( !isfinite( temp ) )
        std::cout << "log(" << P(i,j) <<  ") -> " << temp << std::endl;
      // A zero count contributes nothing (P^0 == 1). Skipping it explicitly
      // avoids 0 * -inf == NaN when P(i,j) is 0 (e.g. off-diagonal at t == 0),
      // which would otherwise turn the whole sum into NaN.
      if( N( i, j ) != T( 0 ) )
        Lt += N( i, j ) * temp;
    }

  return Lt;
}

template<class T, size_t Dim>
void plot_exp( std::ostream & file, profdist::fixed_matrix<T,Dim,Dim> const& Q ) {
  for( double t = 1; t < 300; t+=10.0 ) {
    try {
      fixed_matrix<T,Dim,Dim> em( exp(Q,t,Pade) );
    file << t << "\n" << em << std::endl;
    }catch( ... ) {}
  }
}

template<class T, size_t Dim>
void plot( std::ofstream & file, profdist::fixed_matrix<T,Dim,Dim> const& M, profdist::fixed_matrix<double,Dim,Dim> const& Q ) {
  fixed_matrix<double,Dim,Dim> N(  M  );
  N = N + transpose( N );
  for( double t = 0; t < 3; t+=0.1 ) {
    try {
    file << t << " " << logLike(N,Q,t) << std::endl;
    }catch( ... ) {}
  }
  for( double t = 3; t < 300; t+=1.0 ) {
    try {
    file << t << " " << logLike(N,Q,t) << std::endl;
    }catch( ... ) {}
  }
}

/**
 * Writes the symmetrised count matrix M + M^T, converted to double, to
 * @p file, followed by std::endl.
 *
 * @param file destination stream (any std::ostream, not only std::ofstream)
 * @param M    raw substitution counts
 */
template<class T, size_t Dim>
void print_N( std::ostream & file, profdist::fixed_matrix<T,Dim,Dim> const& M ) {
  fixed_matrix<double,Dim,Dim> N(  M  + transpose( M ) );
  file << N << std::endl;
}


  /**
   * Increments the count mat[i][j]. Indices outside the N x N matrix are
   * ignored (the call becomes a no-op) instead of writing out of bounds.
   */
  template<size_t N>
  inline void checked_inc( fixed_matrix<size_t,N,N> & mat, size_t i, size_t j ) {
    bool const in_range = ( i < N ) && ( j < N );
    if( !in_range )
      return;
    mat[i][j] += 1;
  }
  /**
   * Decrements the count mat[i][j]. Indices outside the N x N matrix are
   * ignored (the call becomes a no-op) instead of writing out of bounds.
   *
   * The cells are unsigned counts, so a cell that is already 0 stays 0.
   * Letting it wrap to SIZE_MAX would produce a ~1.8e19 "count" that silently
   * dominates any likelihood computed from the matrix.
   */
  template<size_t N>
  inline void checked_dec( fixed_matrix<size_t,N,N> & mat, size_t i, size_t j ) {
    if( i < N && j < N && mat[i][j] > 0 )
      --mat[i][j];
  }


/**
 * Builds a substitution count matrix for every pair of sequences in @p source
 * and writes, per pair, a log-likelihood-vs-time plot plus the symmetrised
 * count matrix.
 *
 * Sequences are numbered from 1; sequence 1 is the reference the difference
 * lists are taken against.
 * - Pairs (1, i+2) use source.get_matrix(i) directly.
 * - Pairs (i+2, j+2) are reconstructed by merging the difference lists of
 *   sequences i+2 and j+2.
 *
 * Files written to the working directory:
 * - "N_matrices"   : every symmetrised count matrix, in loop order
 * - "plot_1_<k>"   : likelihood curve for sequences 1 and k
 * - "plot_<k>_<l>" : likelihood curve for sequences k and l
 *
 * @param source encoded alignment
 * @param Q      rate matrix forwarded to plot()
 */
template<typename Traits>
void compute_distance( AlignCode<Traits> const& source,  typename Traits::rate_matrix const& Q)
{
  std::size_t const num_seq = source.get_num_sequences();

  // Fewer than two sequences means there is no pair to compare. The guard
  // also keeps the unsigned expression `num_seq - 1` below from wrapping
  // around to SIZE_MAX, which would drive get_matrix() out of range.
  if( num_seq < 2 )
    return;

  ofstream Nm( "N_matrices" );
  for ( std::size_t i = 0; i < num_seq - 1; ++i )
  {
    // Counts between sequence 1 and sequence i+2.
    typename AlignCode<Traits>::count_matrix const& A = source.get_matrix(i);
    ofstream plot_file(("plot_1_" + lexical_cast<std::string>(i+2)).c_str());
    plot(plot_file, A , Q );
    plot_file.close();
    print_N( Nm, A );

    for( std::size_t j = i + 1; j < num_seq - 1; ++j )
    {
      // Initialize the substitution matrix N_{i+2,j+2} with the diagonal of
      // N_{1,i+2}. Sites in neither difference list keep that diagonal count.
      // The merge below adds every site that appears in at least one list,
      // and removes from the diagonal the sites counted there that only
      // sequence j+2 changes.
      typename AlignCode<Traits>::count_matrix B( 0U );
      for( int k = 0; k < Traits::num_relevant_elements; ++k )
        B[k][k]= A[k][k];

      // Difference list k belongs to sequence k+2, because sequence 1 is the
      // reference. B is filled with rows = sequence i+2 and
      // columns = sequence j+2. Orientation does not affect the output,
      // since plot() and print_N() symmetrise.
      // NOTE(review): the merge assumes both lists are sorted ascending by
      // position (it->first), and that get_reference_element(pos) is the
      // first sequence's element at pos -- confirm in AlignCode.
      typename AlignCode<Traits>::const_diff_iterator it_b = source.begin_difference( j )
        , it_a = source.begin_difference( i )
        , a_end = source.end_difference( i )
        , b_end = source.end_difference( j )
        ;

      while ( it_a != a_end && it_b != b_end  )
      {
        if ( it_a->first == it_b->first ) // both differ from the first sequence at this site
        {
          checked_inc(B, it_a->second,  it_b->second );
          ++it_a;
          ++it_b;
        }
        else if( it_a->first < it_b->first ) // only sequence i+2 differs; sequence j+2 equals the reference here
        {
          checked_inc(B, it_a->second, source.get_reference_element( it_a->first ));
          ++it_a;
        }
        else // only sequence j+2 differs; sequence i+2 equals the reference, so this site sits on the copied diagonal
        {
          checked_inc( B, source.get_reference_element( it_b->first ), it_b->second );
          checked_dec( B, source.get_reference_element( it_b->first ), source.get_reference_element( it_b->first ));
          ++it_b;
        }
      } //end of while

      // Remaining sites where only sequence i+2 differs.
      while ( it_a != a_end  )
      {
        checked_inc(B, it_a->second, source.get_reference_element( it_a->first ));
        ++it_a;
      }

      // Remaining sites where only sequence j+2 differs.
      while ( it_b != b_end )
      {
        checked_inc( B, source.get_reference_element( it_b->first ), it_b->second );
        checked_dec( B, source.get_reference_element( it_b->first ), source.get_reference_element( it_b->first ));
        ++it_b;
      }

      ofstream plot_file(("plot_" + lexical_cast<std::string>(i+2) + "_" + lexical_cast<std::string>(j+2)).c_str());
      plot(plot_file, B, Q );
      plot_file.close();
      print_N( Nm, B );
    } // end for j
  } //end for  i
}

/**
 * Entry point.
 *
 * Usage: <program> <rate-matrix-file> <fasta-file>
 *
 * 1. Reads the protein rate matrix Q from argv[1] and the FASTA alignment
 *    from argv[2].
 * 2. Writes matrix-exponential samples to "plot_exp_qt" (see plot_exp).
 * 3. Writes the pairwise likelihood plots and count matrices
 *    (see compute_distance).
 *
 * @return 0 on success; 1 on missing arguments, an unreadable rate matrix,
 *         or any exception.
 */
int main(int argc, char ** argv ) {
  // Both file arguments are dereferenced below. Without this check a short
  // command line reads argv out of bounds (undefined behaviour).
  if( argc < 3 ) {
    std::cout << "Usage: <program> <rate-matrix-file> <fasta-file>" << std::endl;
    return 1;
  }
  try {
    alignment a;
    parse_fasta(argv[2], a );
    protein_traits::rate_matrix Q(0);
    ifstream in( argv[1] );
    read_rate_matrix<protein_traits>( in, Q );
    // Also covers a file that could not be opened: the read then fails too.
    if( ! in ) { std::cout <<"Fail" << std::endl; return 1;}
//  Q*=100.0f;
    ofstream  plot_exp_f( "plot_exp_qt" );
    plot_exp( plot_exp_f, Q );
    AlignCode<protein_traits> code;
    code.read_sequences( a );
    compute_distance( code, Q );
  }catch( std::exception const&e ) {
    std::cout << e.what() << std::endl;
    return 1;
  }catch( ... ) {
    // Non-std exception; the message is German for "caught something".
    std::cout << "Was eingefangen" << std::endl;
    return 1;
  }
}

