/***************************************************************************
 *   Copyright (C) 2005 by Andreas Pokorny                                 *
 *   andreas.pokorny@biozentrum.uni-wuerzburg.de                           *
 *                                                                         *
 *   This file is part of profdist and cbcanalyzer                         *
 *                                                                         *
 *   Both profdist and cbcanalyzer are free software; you can redistribute *
 *   it and/or modify it under the terms of the GNU General Public License *
 *   as published by the Free Software Foundation; either version 2 of the *
 *   License, or (at your option) any later version.                       *
 *                                                                         *
 *   Profdist and cbcanalyzer are distributed in the hope that it will be  *
 *   useful, but WITHOUT ANY WARRANTY; without even the implied warranty   *
 *   of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the      *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <stdexcept>
#include <algorithm>
#include <iterator>
#include <limits>
#include <utility>
#include <boost/lexical_cast.hpp>
#include <iostream>
#include "tree.h"
#include "newickize.h"

// Uncomment to compile the PROFDIST_DEBUG-only sections of this file.
// #define PROFDIST_DEBUG

// Trace helpers for node construction. The "#if 0" branch holds the verbose
// variants that print to std::cout; the active "#else" branch makes ADD_LEAF
// and ADD_INNER expand to nothing and DEBUG_LINE(a) expand to just "a".
// ADD_LEAF / ADD_INNER refer to names from the scope they are expanded in
// (reference_position, n_index, split_set).
// NOTE(review): the verbose DEBUG_LINE uses an unqualified "cout"; it would
// only compile if that name is visible at the point of use - confirm before
// enabling.
#if 0
#define ADD_LEAF  std::cout << "Adding leaf: " << reference_position << " to the tree" << std::endl;
#define ADD_INNER std::cout << "Adding internal node: " << n_index << " with -:"; \
  std::copy( split_set.begin(), split_set.end(), std::ostream_iterator<size_t>( std::cout, "|" ) ); \
  std::cout << "- to the tree" << std::endl; 
#define DEBUG_LINE(a) cout << __LINE__ << " does ";a
#else
#define ADD_LEAF  
#define ADD_INNER 
#define DEBUG_LINE(a) a

#endif

// Default constructor: a node that is not a profile, has reason None and
// all counters / indices set to zero.
Node::Node()
  : is_profile( false )
  , reason( None )
  , bootstrap_value( 0 )
  , reference_position( 0 )
  , node_index( 0 )
  , sequence_index( 0 )
{}

// Constructs a node carrying a split set.
//   set  - stored as split_set
//   bs   - stored as bootstrap_value
//   node - stored as node_index
// The node starts as a non-profile with reason None; reference_position and
// sequence_index are zero.
Node::Node( set_type const& set, size_t bs, size_t node )
  : is_profile( false )
  , reason( None )
  , bootstrap_value( bs )
  , reference_position( 0 )
  , node_index( node )
  , sequence_index( 0 )
  , split_set( set )
{
}

// Constructs a leaf node for sequence `seq`: reason is Leaf and both
// reference_position and sequence_index are set to `seq`; bootstrap_value
// and node_index are zero.
Node::Node( size_t seq )
  : is_profile( false )
  , reason( Leaf )
  , bootstrap_value( 0 )
  , reference_position( seq )
  , node_index( 0 )
  , sequence_index( seq )
{
}


std::size_t Node::get_reference_position() const
{
  return reference_position;
}
std::size_t Node::get_sequence_index() const
{
  return sequence_index;
}
// Returns a const reference to the split set stored in this node.
Node::set_type const& Node::get_split_set() const
{
  return this->split_set;
}


void Node::propagate_pos() {
  for( const_iterator it = children.begin(), e = children.end(); it!=e;++it) 
    (*it)->propagate_pos();
  reason = WasAProfile;
}

std::size_t Node::get_node_index() const
{
  return node_index;
}
// Maps the node's reason to its short textual tag:
//   None -> "N", Leaf -> "L", Bootstrap -> "Pb", Identity -> "Pi",
//   WasAProfile -> "Po"; any other value yields "EEK".
std::string Node::get_reason_str() const
{
  if( reason == None )        return "N";
  if( reason == Leaf )        return "L";
  if( reason == Bootstrap )   return "Pb";
  if( reason == Identity )    return "Pi";
  if( reason == WasAProfile ) return "Po";
  return "EEK";
}

// Re-initialises this node as a (non-profile) leaf.
//   n_index   - new node_index
//   seq_index - stored as both sequence_index and reference_position
// split_set and bootstrap_value are left untouched.
void Node::init_leaf_seq( std::size_t n_index, std::size_t seq_index )
{
  sequence_index     = seq_index;
  reference_position = seq_index;
  node_index         = n_index;
  is_profile         = false;

  // Optional trace output; must follow the assignments above because the
  // verbose variant prints reference_position.
  ADD_LEAF;
}
  
// Re-initialises this node as a (non-profile) inner node.
//   n_index - new node_index
//   bs      - new bootstrap_value
//   s_set   - copied into split_set
// reference_position is left untouched.
void Node::init_inner_seq( std::size_t n_index, std::size_t bs, set_type const& s_set )
{
  bootstrap_value = bs;
  split_set       = s_set;
  node_index      = n_index;
  is_profile      = false;

  // Optional trace output; the verbose variant prints n_index and split_set.
  ADD_INNER;
}


// Writes the subtree rooted at this node to `out` in a parenthesised,
// comma separated form (children first, then the node label).
//
//   out                 - destination stream, also the return value
//   names               - sequence names, indexed by sequence_index for leaves
//   num_bootstraps      - if non-zero, ".<percent>" is appended to every inner
//                         node, where percent is bootstrap_value relative to
//                         num_bootstraps, truncated to an integer
//   print_with_profiles - when true (or as soon as a profile node is met on
//                         the way down) inner nodes that are profiles or have
//                         a reason other than None/Leaf print their reason tag
//                         instead of node_index
//
// A node without children prints only names[sequence_index], passed through
// profdist::newickize.
std::ostream& Node::post_order( std::ostream& out, std::vector<std::string> const& names, std::size_t num_bootstraps, bool print_with_profiles ) const
{
  if( children.empty() ) {
    assert( sequence_index < names.size() );
    return out << profdist::newickize( names[sequence_index] );
  }

  // Once a profile node has been entered, the flag stays set for its subtree.
  print_with_profiles = print_with_profiles || is_profile;

  out << '(';
  const std::size_t child_count = children.size();
  for( std::size_t child = 0; child != child_count; ++child ) {
    if( child != 0 )
      out << ',';
    children[child]->post_order( out, names, num_bootstraps, print_with_profiles );
  }
  out << ')';

  const bool has_profile_reason = ( reason != None && reason != Leaf );
  if( print_with_profiles && ( has_profile_reason || is_profile ) ) {
    out << get_reason_str();
  }
  else {
    out << node_index;
  }

  if( num_bootstraps ) {
    out << "." << size_t( 100.0f * float(bootstrap_value) / float(num_bootstraps) );
  }
  return out;
}

// Creates an empty tree.
Tree::Tree()
{
}
// Creates a tree from counted split sets by delegating to
// create_tree_bottom_up_without_consense with the same arguments.
Tree::Tree( CountedSets const& sets, std::size_t num_leaves, std::size_t num_bootstraps )
{
  this->create_tree_bottom_up_without_consense( sets, num_leaves, num_bootstraps );
}

// Builds the top-level children of this tree from counted split sets by
// joining nodes bottom-up.
//
//   sets           - iterated as (split set, count) pairs
//   num_leaves     - one leaf Node is created for every index in
//                    [0, num_leaves)
//   num_bootstraps - only used to derive the majority limit
//                    (num_bootstraps/2 rounded up)
//
// A map from "set of leaf indices" to node holds the nodes that still have no
// parent. While more than three of them remain, the loop cycles through three
// modes:
//   Majority - every unvisited set whose count is above the majority limit and
//              that is exactly the union of two remaining nodes joins those
//              two nodes. Several joins may happen in one pass. Without any
//              join the mode switches to Minority.
//   Minority - among the unvisited sets with a count up to the majority limit
//              that are the union of two remaining nodes, the one with the
//              highest count is joined; then back to Majority. Without a
//              candidate the mode switches to Multi.
//   Multi    - among the unvisited sets that can be covered completely by two
//              or more remaining nodes, the one needing the fewest nodes
//              (ties: highest count) is joined into a multifurcation; then
//              back to Majority. Without a candidate the loop stops.
// Whatever nodes remain afterwards become the children of the tree.
void Tree::create_tree_bottom_up_without_consense( CountedSets const& sets, std::size_t num_leaves, std::size_t num_bootstraps ) {
#ifdef PROFDIST_DEBUG
  try {
#endif

  enum AlgMode { Majority, Minority,Multi };
  // One flag per entry of `sets` (in iteration order): set once the split has
  // been turned into an inner node.
  std::vector<char> visited( sets.size() , '\0' );
  std::size_t node_index = 0; 
  std::size_t majority = num_bootstraps/2 + num_bootstraps%2;

  typedef std::map<CountedSets::set_type, Node::ptr> nc_type;
  nc_type remaining_nodes;
  // Start with one single-element entry per leaf.
  for( std::size_t i = 0; i < num_leaves; ++i ) {
    CountedSets::set_type s;
    s.insert(i);
    remaining_nodes[s] = Node::ptr( new Node( i ) ); 
  }

 
  AlgMode mode = Majority;
  bool done = false;
  while( remaining_nodes.size() > 3 && !done)  {
    bool improvement = false;
    size_t index = 0;
    switch(mode) {
      case Majority: 
        {
          for( CountedSets::const_iterator it = sets.begin(), e = sets.end();remaining_nodes.size() > 3 && it != e ; ++it, ++index ) {
            if( ! visited[index] && it->second > majority) {
              CountedSets::set_type::const_iterator cons_set_begin = it->first.begin(), cons_set_end = it->first.end();

              for( nc_type::iterator node_it = remaining_nodes.begin(); node_it != remaining_nodes.end(); ++node_it ) {

                CountedSets::set_type::const_iterator node_begin = node_it->first.begin(), node_end = node_it->first.end();

                if( std::includes( cons_set_begin, cons_set_end, node_begin, node_end ) ) {
                  // The split contains this node; look for a second node that
                  // covers exactly the rest of the split.
                  CountedSets::set_type temp;
                  set_difference( 
                      cons_set_begin, cons_set_end,
                      node_begin, node_end, 
                      std::insert_iterator<CountedSets::set_type>( temp, temp.begin() )
                      );

                  nc_type::iterator other_node = remaining_nodes.find( temp );

                  if( other_node != remaining_nodes.end() ) {
                    visited[index]=1;
                    Node::ptr cn( new Node( it->first, it->second, node_index++ ) );
                    cn->add_child(node_it->second);
                    cn->add_child(other_node->second);
                    remaining_nodes[it->first] = cn;
                    remaining_nodes.erase( node_it );
                    remaining_nodes.erase( other_node );
                    improvement = true;
                    // node_it is invalid now, so leave the node loop.
                    break;
                  }
                }
              }
            }
          }

          if(!improvement )
            mode = Minority;
          break;
        }
      case Minority:
        {
          std::size_t biggest_bs = 0
            , biggest_index = 0;
          std::pair<nc_type::iterator, nc_type::iterator> biggest_relevant(remaining_nodes.end(), remaining_nodes.end());
          CountedSets::set_type set_with_biggest;
          // Only collect the best candidate here; the map is modified after
          // the scan so the stored iterators stay valid.
          for( CountedSets::const_iterator it = sets.begin(), e = sets.end();remaining_nodes.size() > 3 && it != e ; ++it, ++index ) {
            if( ! visited[index] && it->second <= majority) {
              CountedSets::set_type::const_iterator cons_set_begin = it->first.begin(), cons_set_end = it->first.end();

              for( nc_type::iterator node_it = remaining_nodes.begin(); node_it != remaining_nodes.end(); ++node_it ) {

                CountedSets::set_type::const_iterator node_begin = node_it->first.begin(), node_end = node_it->first.end();

                if( std::includes( cons_set_begin, cons_set_end, node_begin, node_end ) ) {
                  CountedSets::set_type temp;
                  set_difference( 
                      cons_set_begin, cons_set_end,
                      node_begin, node_end, 
                      std::insert_iterator<CountedSets::set_type>( temp, temp.begin() )
                      );

                  nc_type::iterator other_node = remaining_nodes.find( temp );

                  if( other_node != remaining_nodes.end() && it->second > biggest_bs ) {
                    improvement = true;
                    biggest_bs = it->second;
                    biggest_index = index;
                    biggest_relevant = std::make_pair( node_it, other_node );
                    set_with_biggest = it->first;
                    break;
                  }
                }
              }
            }
          }

          if(!improvement ) {
            mode = Multi;
          }
          else {
            visited[biggest_index]=1;
            Node::ptr cn( new Node( set_with_biggest, biggest_bs, node_index++ ) );
            cn->add_child( biggest_relevant.first->second ); 
            cn->add_child( biggest_relevant.second->second );
            remaining_nodes[set_with_biggest] = cn;
            remaining_nodes.erase( biggest_relevant.first );
            remaining_nodes.erase( biggest_relevant.second );

            mode = Majority;
          }
          break;
        }
      case Multi:
        {
          std::size_t best_index = 0
            , best_bs = 0
            , best_split_size = std::numeric_limits<std::size_t>::max();
          CountedSets::set_type best_set;
          std::list<nc_type::iterator> smallest_furcation;
          for( CountedSets::const_iterator it = sets.begin(), e = sets.end();remaining_nodes.size() > 3 && it != e ; ++it, ++index ) {
            if( ! visited[index] ) {

              // `temp` holds the part of the split not yet covered by a node.
              CountedSets::set_type temp = it->first;
              std::list<nc_type::iterator> temp_furcation;

              // Stop early once this candidate needs more nodes than the best
              // one found so far.
              for( nc_type::iterator node_it = remaining_nodes.begin(); temp_furcation.size() <= best_split_size 
                  && temp.size() != 0 
                  && node_it != remaining_nodes.end(); ++node_it ) {
                if( std::includes( temp.begin(), temp.end(), node_it->first.begin(), node_it->first.end() ) ) {
                  temp_furcation.push_back(node_it);
                  CountedSets::set_type temp2;
                  set_difference( 
                      temp.begin(), temp.end(), node_it->first.begin(), node_it->first.end(),
                      std::insert_iterator<CountedSets::set_type>( temp2, temp2.begin() )
                      );
                  temp = temp2;
                }
              }

              const std::size_t furcation_size = temp_furcation.size();
              // A furcation of fewer than two nodes is no join: the split
              // would be empty or identical to an existing node, and creating
              // a node for it would overwrite that node's map entry.
              if( temp.empty() && furcation_size > 1 
                  && ( best_split_size > furcation_size 
                    ||  (best_split_size == furcation_size && best_bs < it->second ) )  ){
                best_set = it->first;
                best_index = index;
                best_bs = it->second;
                // Remember the size so later candidates really have to be
                // smaller (or equally small with a higher count).
                best_split_size = furcation_size;
                smallest_furcation = temp_furcation;
                improvement = true;
              }
            }
          }
          if(!improvement ) {
            done = true;
          }
          else {
            visited[best_index] = 1;
            Node::ptr cn (new Node( best_set, best_bs, node_index++ ) );
            remaining_nodes[best_set] = cn;
            for( std::list<nc_type::iterator>::const_iterator it = smallest_furcation.begin(), e = smallest_furcation.end();
                it != e; ++it ) {
              cn->add_child( (*it)->second );
              remaining_nodes.erase( *it );
            }

            mode = Majority;
          }
          break;
        }
    }
  }
  
  // Disabled diagnostics: dump the unused splits and throw when more than
  // three nodes remain.
#if 0
  if(remaining_nodes.size() > 3 ) {
 #ifdef PROFDIST_DEBUG
    ofstream deb("debug_data", std::ios::app); 
    if(deb) {
      deb << "Remaining Nodes : \n";
      for( nc_type::iterator node_it = remaining_nodes.begin(); node_it != remaining_nodes.end(); ++node_it ) {
        print_split_set( deb, node_it->first );
        deb << "\n";
      }
      size_t index = 0;
      deb << "Majority Splits not used:\n";
      for( CountedSets::const_iterator it = sets.begin(), e = sets.end();remaining_nodes.size() > 3 && it != e ; ++it, ++index ) 
        if( ! visited[index] && it->second >= majority) {
          print_split_set( deb, it->first );
          deb << " BS = " << it->second << '\n';
          CountedSets::set_type temp = it->first;
          for( nc_type::iterator node_it = remaining_nodes.begin(); temp.size() != 0 && node_it != remaining_nodes.end(); ++node_it ) {
            if( std::includes( temp.begin(), temp.end(), node_it->first.begin(), node_it->first.end() ) ) {
              deb << "    found ";
              print_split_set( deb, node_it->first );
              CountedSets::set_type temp2;
              set_difference( 
                  temp.begin(), temp.end(), node_it->first.begin(), node_it->first.end(),
                  std::insert_iterator<CountedSets::set_type>( temp2, temp2.begin() )
                  );
              temp = temp2;
              deb << '\n';
            }
          }
          deb << "    remains: required: ";
          print_split_set( deb, temp );
          deb << '\n';

        }
      index = 0;
      deb << "Minority Splits not used:\n";
      for( CountedSets::const_iterator it = sets.begin(), e = sets.end();remaining_nodes.size() > 3 && it != e ; ++it, ++index ) 
        if( ! visited[index] && it->second < majority) {
          print_split_set( deb, it->first );
          deb << " BS = " << it->second << '\n';
          CountedSets::set_type temp = it->first;
          for( nc_type::iterator node_it = remaining_nodes.begin(); temp.size() != 0 && node_it != remaining_nodes.end(); ++node_it ) {
            if( std::includes( temp.begin(), temp.end(), node_it->first.begin(), node_it->first.end() ) ) {
              deb << "  found ";
              print_split_set( deb, node_it->first );
              CountedSets::set_type temp2;
              set_difference( 
                  temp.begin(), temp.end(), node_it->first.begin(), node_it->first.end(),
                  std::insert_iterator<CountedSets::set_type>( temp2, temp2.begin() )
                  );
              temp = temp2;
            }
          }
          deb << " remains: required: ";
          print_split_set( deb, temp );
          deb << '\n';
        }
    }
#endif
    throw runtime_error("More than three sub nodes remaining!\nYou just detected a bug in profdist, please send the input data to profdist@biozentrum.uni-wuerzburg.de.");
  }
  else {
#endif
    // Everything that is still parentless hangs directly below the root.
    for( nc_type::iterator node_it = remaining_nodes.begin(); node_it != remaining_nodes.end(); ++node_it ) {
      children.push_back( node_it->second);
    }
#if 0
  }
#endif
#ifdef PROFDIST_DEBUG
  }catch(runtime_error &e ) {
    ofstream deb("debug_data", std::ios::app); 
    if(deb) deb << "CountedSetsobjekt:" << sets << std::endl;
    throw e;
  }
#endif
}

// Trace helpers for the profile search. The "#if 0" branch holds verbose
// variants that print the node's indices together with the stringized macro
// argument; the active "#else" branch makes all three macros expand to
// nothing. The argument is free text (it is only ever stringized), so it does
// not have to be a valid expression.
#if 0
#define CREATE_PROFILE(a) std::cout << "======" << a->node_index << ":" << a->reference_position << ":" << a->sequence_index << " yields a profile at: " << #a << std::endl 
#define PROFILE_REASON(a) std::cout << "++++++" << node_index << ":" << reference_position << ":" << sequence_index << " is a profile because : " << #a << std::endl 
#define NO_PROFILE_REASON(a) std::cout << "------"<< node_index << ":" << reference_position << ":" << sequence_index << " is no profile because : " << #a << std::endl 
#else
#define CREATE_PROFILE(a) 
#define PROFILE_REASON(a) 
#define NO_PROFILE_REASON(a) 
#endif

// Determines whether the subtree rooted at this node forms a profile and
// detaches profile-forming children of nodes that do not.
//
//   prof_count     - running profile counter; used as key for `profiles` and
//                    incremented for every child turned into a profile
//   profiles       - receives (profile index, child node) for every child
//                    that is detached as a profile
//   known_profiles - node indices that are treated as profiles right away
//   identicals     - pairs of reference positions regarded as identical
//   threshold      - minimum bootstrap_value for a Bootstrap profile
//   pnj_method     - enables the Identity and Bootstrap rules
//
// Returns (and stores in `reason`) the result for this node:
//   - WasAProfile if the node already is a profile or its node_index is in
//     known_profiles (the whole subtree is marked via propagate_pos);
//   - Leaf if it has no children;
//   - Identity if pnj_method is set, every child currently has reason
//     WasAProfile or Leaf, and every pair of children is listed in identicals;
//   - Bootstrap if pnj_method is set, every child reported a reason other than
//     None and bootstrap_value >= threshold;
//   - None otherwise. In that case every child that reported a reason other
//     than None is turned into a profile, registered in `profiles` and removed
//     from this node's children.
Node::profile_reason Node::find_profile( std::size_t& prof_count, profile_map & profiles, profile_set const& known_profiles, identical_seq_set const& identicals, std::size_t threshold, bool pnj_method  )
{
  if( is_profile ) {
    split_set.clear(); // no longer in this step
    PROFILE_REASON(was a profile earlier);
    propagate_pos();
    return reason;
  }
  if( known_profiles.find(node_index) != known_profiles.end()  ) {
    PROFILE_REASON(is a known profile);
    propagate_pos();
    return reason;
  }
  if( children.empty() )
    return reason=Leaf;

  if( pnj_method ) { // identity method is enabled
    // The children's reasons read here are the values they carry before this
    // call descends into them.
    bool all_are_identical = children[children.size()-1]->reason == WasAProfile 
      ||  children[children.size()-1]->reason == Leaf; // last child is a leaf
    for( std::size_t index = 0, e = children.size(); all_are_identical && index != e - 1; ++index ) 
    {
      all_are_identical =  all_are_identical && ( children[index]->reason == WasAProfile  || children[index]->reason == Leaf  ); 
      // all children are leaves or were profiles in previous steps
      for( std::size_t index_2 = index+1 ; all_are_identical && index_2 != e ; ++index_2 )  {
        all_are_identical = all_are_identical
          && ( 
              identicals.end() != identicals.find( std::make_pair( 
                  children[index]->reference_position
                  , children[index_2]->reference_position
                  ) ) 
             );
      }
    }
    if( all_are_identical ) {
      PROFILE_REASON(leaves found are identical);
      return reason=Identity;
    }
  }

  // Descend into every child; the subtree can only become a Bootstrap
  // profile when each child reported some reason.
  bool creating_profile = true;
  std::vector<Node::profile_reason> reasons( children.size(), None );
  for( std::size_t index = 0, e = children.size(); index!=e;++index ) {
    reasons[index] = children[index]->find_profile( prof_count, profiles, known_profiles, identicals, threshold, pnj_method );
    // Compare against None explicitly instead of relying on the numeric
    // value of the enumerator.
    creating_profile = creating_profile && ( reasons[index] != None );
  }
  if( bootstrap_value >= threshold && pnj_method && creating_profile ) {
    PROFILE_REASON(all children are profiles and bootstrap beyond thresold);
    return reason=Bootstrap;
  }

  // This node is no profile: detach every child that reported a reason and
  // register it as a profile of its own. The index only advances when nothing
  // was erased, so the element that moved into the slot is examined next.
  for( std::size_t index = 0; index != children.size(); ) {
    if( reasons[index] == None ) {
      ++index;
      continue;
    }
    children[index]->turn_into_profile( prof_count, reasons[index] );
    profiles.insert( std::make_pair( prof_count++, children[index] ) );
    children.erase( children.begin() + index );
    reasons.erase( reasons.begin() + index );
  }
  return reason=None;
}

// Variant of the profile search that recurses through find_profile_first and
// converts detached children with turn_into_first_profile.
//
//   prof_count     - running profile counter; used as key for `profiles` and
//                    incremented for every child turned into a profile
//   profiles       - receives (profile index, child node) for detached children
//   known_profiles - node indices that are treated as profiles right away
//   identicals     - pairs of reference positions regarded as identical
//   threshold      - minimum bootstrap_value for a Bootstrap profile
//   pnj_method     - enables the Identity and Bootstrap rules
//
// The result is stored in `reason` and returned: WasAProfile for existing or
// known profiles, Leaf for childless nodes, Identity / Bootstrap when the
// respective rule applies, None otherwise (after detaching every child that
// reported a reason other than None).
Node::profile_reason Node::find_profile_first( std::size_t& prof_count, profile_map & profiles, profile_set const& known_profiles, identical_seq_set const& identicals, std::size_t threshold, bool pnj_method  )
{
  if( is_profile ) {
    split_set.clear(); // no longer in this step
    PROFILE_REASON(was a profile earlier);
    propagate_pos();
    return reason;
  }
  if( known_profiles.find(node_index) != known_profiles.end()  ) {
    PROFILE_REASON(is a known profile);
    propagate_pos();
    return reason;
  }

  // A node without children is a leaf.
  if( children.empty() )
    return reason=Leaf;

  const std::size_t child_count = children.size();

  if( pnj_method ) { // identity method is enabled
    // Step 1: every child must currently be a leaf or a former profile.
    bool all_are_identical = true;
    for( std::size_t pos = 0; all_are_identical && pos != child_count; ++pos ) {
      const Node::profile_reason child_reason = children[pos]->reason;
      all_are_identical = ( child_reason == WasAProfile || child_reason == Leaf );
    }
    // Step 2: every pair (earlier child, later child) must be listed in
    // identicals.
    for( std::size_t first = 0; all_are_identical && first + 1 < child_count; ++first ) {
      for( std::size_t second = first + 1; all_are_identical && second != child_count; ++second ) {
        all_are_identical = identicals.find( std::make_pair( 
              children[first]->reference_position
              , children[second]->reference_position
              ) ) != identicals.end();
      }
    }

    if( all_are_identical ) {
      PROFILE_REASON(leaves found are identical);
      return reason=Identity;
    }
  }

  // Descend into all children and remember what each of them reported.
  std::vector<Node::profile_reason> reasons( child_count, None );
  bool creating_profile = true;
  for( std::size_t pos = 0; pos != child_count; ++pos ) {
    reasons[pos] = children[pos]->find_profile_first( prof_count, profiles, known_profiles, identicals, threshold, pnj_method );
    if( reasons[pos] == None )
      creating_profile = false;
  }

  if( bootstrap_value >= threshold && pnj_method && creating_profile  ) {
    PROFILE_REASON(all children are profiles and bootstrap beyond thresold);
    return reason=Bootstrap;
  }

  // Detach every child that reported a reason; only step forward when the
  // current slot was kept.
  std::size_t pos = 0;
  while( pos != children.size() ) {
    if( reasons[pos] == None ) {
      ++pos;
      continue;
    }
    children[pos]->turn_into_first_profile( prof_count, reasons[pos] );
    profiles.insert( std::make_pair( prof_count++, children[pos] ) );
    children.erase( children.begin() + pos );
    reasons.erase( reasons.begin() + pos );
  }

  return reason=None;
}


// Marks this node as a profile.
//   prof_index - becomes the new node_index
//   r          - becomes the new reason
// The previous node_index is preserved in reference_position.
void Node::turn_into_profile( std::size_t prof_index, Node::profile_reason r )
{
  const std::size_t previous_node_index = node_index;
  node_index         = prof_index;
  reference_position = previous_node_index;
  reason             = r;
  is_profile         = true;
}

// Marks this node as a profile without touching reference_position.
//   prof_index - becomes the new node_index
//   r          - becomes the new reason
void Node::turn_into_first_profile( std::size_t prof_index, Node::profile_reason r )
{
  reason     = r;
  node_index = prof_index;
  is_profile = true;
}



// Runs Node::find_profile on every top-level child, starting the profile
// counter at zero. A child that reports a reason other than Node::None is
// turned into a profile, stored in `profiles` under the current counter value
// and removed from the tree's children.
void Tree::find_profile( profile_map & profiles, profile_set const& known_profiles, identical_seq_set const& identicals, std::size_t threshold, bool pnj_method  )
{
  std::size_t prof_count = 0;
  std::size_t pos = 0;
  while( pos != children.size() ) {
    const Node::profile_reason found = children[pos]->find_profile( prof_count, profiles, known_profiles, identicals, threshold, pnj_method );
    if( found == Node::None ) {
      // Child stays in the tree; move on.
      ++pos;
      continue;
    }
    children[pos]->turn_into_profile( prof_count, found );
    profiles.insert( std::make_pair( prof_count++, children[pos] ) );
    // After the erase the next child occupies `pos`, so do not advance.
    children.erase( children.begin() + pos );
  }
}

// Runs Node::find_profile_first on every top-level child, starting the
// profile counter at zero. A child that reports a reason other than
// Node::None is converted with turn_into_first_profile, stored in `profiles`
// under the current counter value and removed from the tree's children.
void Tree::find_profile_first( profile_map & profiles, profile_set const& known_profiles, identical_seq_set const& identicals, std::size_t threshold, bool pnj_method  )
{
  std::size_t prof_count = 0;
  std::size_t pos = 0;
  while( pos != children.size() ) {
    const Node::profile_reason found = children[pos]->find_profile_first( prof_count, profiles, known_profiles, identicals, threshold, pnj_method );
    if( found == Node::None ) {
      // Child stays in the tree; move on.
      ++pos;
      continue;
    }
    children[pos]->turn_into_first_profile( prof_count, found );
    profiles.insert( std::make_pair( prof_count++, children[pos] ) );
    // After the erase the next child occupies `pos`, so do not advance.
    children.erase( children.begin() + pos );
  }
}


// Replaces childless nodes below this node by the matching entries of
// `profiles`.
// Returns true when this node itself has no children (so the caller has to
// replace it), false otherwise. Throws std::runtime_error (via union_node)
// when a childless child has no entry in `profiles`.
bool Node::union_tree( profile_map const & profiles )
{
  if( children.empty() ) {
    return true;
  }

  const std::size_t child_count = children.size();
  for( std::size_t pos = 0; pos != child_count; ++pos ) {
    boost::shared_ptr<Node> & child = children[pos];
    if( child->union_tree( profiles ) ) {
      union_node( child, profiles );
    }
  }
  return false;
}

// Replaces `node` by the profile stored in `profiles` under the key
// node->reference_position.
// Throws std::runtime_error when no such entry exists.
void Node::union_node( boost::shared_ptr<Node> &node, profile_map const& profiles )
{
  const profile_map::const_iterator found = profiles.find( node->reference_position );
  if( found != profiles.end() ) {
    node = found->second;
    return;
  }
  throw std::runtime_error("node not found while createing union");
}

// Runs Node::union_tree on every non-null top-level child and replaces those
// children that turn out to be childless by their entry in `profs`.
void Tree::union_tree( profile_map const& profs )
{
  const std::size_t child_count = children.size();
  for( std::size_t pos = 0; pos != child_count; ++pos ) {
    if( !children[pos] )
      continue; // empty slot, nothing to replace
    if( children[pos]->union_tree( profs ) )
      Node::union_node( children[pos], profs );
  }
}

// Iterator to the first top-level child of the tree.
Node::const_iterator Tree::begin() const
{
  return children.begin();
}
// Iterator one past the last top-level child of the tree.
Node::const_iterator Tree::end() const
{
  return children.end();
}


std::ostream& Tree::print( std::ostream& out, std::vector<std::string> const& names, std::size_t num_bootstraps ) const
{
  out << '(';
  for( Node::const_iterator it = children.begin(), e = children.end(); it!=e;++it) {
    (*it)->post_order( out, names, num_bootstraps, false );
    if(  (it + 1 ) != e ) 
      out << ',';
  }
  return out << ");";
}

// Writes the tree as a Graphviz "digraph" to `out`: a "virtual_root" node,
// the description of every top-level child (via Node::print_graphviz_debug)
// and an edge from virtual_root to each of them. Nodes are named after their
// address. The stream is flushed at the end. Returns `out`.
std::ostream& Tree::print_graphviz_debug( std::ostream & out,  std::vector<std::string> const& names ) const
{
  out << "digraph { \n node [shape=Mrecord]; \n virtual_root ;\n";

  const std::size_t child_count = children.size();
  for( std::size_t pos = 0; pos != child_count; ++pos ) {
    children[pos]->print_graphviz_debug( out, names );
    out << " virtual_root -> node_" << get_pointer(children[pos]) << ";\n";
  }

  out << "}\n" << std::endl;
  return out;
}

// Writes this node and its subtree in Graphviz syntax to `out`.
// The node is named "node_<address>"; its record label shows, in this order,
// the sequence name (childless nodes only), the reason tag (profiles and
// nodes with a reason other than None/Leaf), node_index, reference_position,
// sequence_index and bootstrap_value. Every child is written recursively,
// followed by an edge from this node to the child. Returns `out`.
std::ostream& Node::print_graphviz_debug( std::ostream & out,  std::vector<std::string> const& names ) const
{
  out << " node_" << this << "[ label=\"";

  if( children.empty() ) {
    out << names[sequence_index] << "|";
  }

  const bool has_profile_reason = ( reason != None && reason != Leaf );
  if( has_profile_reason || is_profile ) {
    out << get_reason_str() << " |";
  }

  out << "{" << "node_index " << node_index << " | "
      << "ref_pos " << reference_position << " | "
      << "seq_index " << sequence_index << " | "
      << "bs_v " << bootstrap_value << " }\"];\n";

  const std::size_t child_count = children.size();
  for( std::size_t pos = 0; pos != child_count; ++pos ) {
    children[pos]->print_graphviz_debug( out, names );
    out << " node_" << this << " -> node_" << get_pointer(children[pos]) << ";\n";
  }
  return out;
}


