#ifndef _STEM_TOOLS_H_
#define _STEM_TOOLS_H_

#include <iostream>
#include <string>
#include <Vec.h>
#include <Stem.h>
#include <stemhelp.h>
#include <StringTools.h>
#include <vector>
#include <math.h>
#include <RnaSecondaryStructure.h>
#include <vectornumerics.h> // only needed for minimum

class StemTools {

 public:

  typedef Stem::index_type index_type;
  typedef unsigned int size_type;
  typedef Stem stem_type;
  typedef Vec<Stem> stem_set;
  typedef Vec<Vec<double> > matrix_type;
  typedef Vec<double> matrix_row_type;
  typedef pair<double, double> point_type;

  /** Converts a 2 stems into a distance. Each stems is treated like two 2D points. This leads to 4 possible distances, the smallest ist used  */
  static double stemEndDist(const stem_type& stem1, const stem_type& stem2) {
    point_type s1b(static_cast<double>(stem1.getStart()),
		   static_cast<double>(stem1.getStop()));  // point corresponding to first base pair of stem1
    point_type s1e(static_cast<double>(stem1.getStart() + stem1.getLength() - 1),
		   static_cast<double>(stem1.getStop()  - stem1.getLength() + 1));  // point corresponding to first base pair of stem1
    point_type s2b(static_cast<double>(stem2.getStart()),
		   static_cast<double>(stem2.getStop()));  // point corresponding to first base pair of stem1
    point_type s2e(static_cast<double>(stem2.getStart() + stem2.getLength() - 1),
		   static_cast<double>(stem2.getStop()  - stem2.getLength() + 1));  // point corresponding to first base pair of stem1
    double d1b2b = sqrt((s1b.first-s2b.first)*(s1b.first-s2b.first) + (s1b.second-s2b.second)*(s1b.second-s2b.second));
    double d1b2e = sqrt((s1b.first-s2e.first)*(s1b.first-s2e.first) + (s1b.second-s2e.second)*(s1b.second-s2e.second));
    double d1e2b = sqrt((s1e.first-s2b.first)*(s1e.first-s2b.first) + (s1e.second-s2b.second)*(s1e.second-s2b.second));    
    double d1e2e = sqrt((s1e.first-s2e.first)*(s1e.first-s2e.first) + (s1e.second-s2e.second)*(s1e.second-s2e.second));    
    return minimum(minimum(d1b2b,d1b2e), minimum(d1e2b,d1e2e)); // return overall minimum
  }

  /**
   * Euclidean distance between two stems, where each stem is reduced to the
   * single 2D point (start, stop).
   */
  static double stemStartDist(const stem_type& stem1, const stem_type& stem2) {
    const double startDelta = stem2.getStart() - stem1.getStart();
    const double stopDelta = stem2.getStop() - stem1.getStop();
    const double squaredDist = startDelta * startDelta + stopDelta * stopDelta;
    return sqrt(squaredDist);
  }

  /**
   * Self-test for stemEndDist and stemStartDist using two fixed stems.
   * Prints both distances to cout and returns true if the end distance is
   * below 3.0 and the start distance is above 90.
   */
  static bool stemDistTest() {
    Stem stem1(1000,1000,99);
    Stem stem2(1100,900,50);
    const double endDistance = stemEndDist(stem1, stem2);
    const double startDistance = stemStartDist(stem1, stem2);
    cout << "Distance between stem 1: " << stem1 << " and stem 2: " << stem2 << " : " << endDistance << " start distance: " << startDistance << endl;
    const bool endsAreClose = (endDistance < 3.0);
    const bool startsAreFar = (startDistance > 90);
    return endsAreClose && startsAreFar;
  }

  /**
   * Builds the symmetric matrix of pairwise stemStartDist values for a
   * non-empty list of stems. Entry [i][j] holds the distance between
   * stems[i] and stems[j]; the diagonal stays 0.
   */
  static matrix_type convertStemStartsToDistanceMatrix(const stem_set& stems) {
    ASSERT(stems.size() > 0);
    const stem_set::size_type n = stems.size();
    matrix_type result(n, matrix_row_type(n, 0));
    ASSERT(result.size() == stems.size());
    ASSERT(result[0].size() == stems.size());
    // only the upper triangle is computed; each value is mirrored below the diagonal
    for (stem_set::size_type row = 0; row < n; ++row) {
      for (stem_set::size_type col = row + 1; col < n; ++col) {
        const double dist = stemStartDist(stems[row], stems[col]);
        result[row][col] = dist;
        result[col][row] = dist;
      }
    }
    return result;
  }

  /**
   * Builds the symmetric matrix of pairwise stemEndDist values for a
   * non-empty list of stems. Entry [i][j] holds the end distance between
   * stems[i] and stems[j]; the diagonal stays 0.
   */
  static matrix_type convertStemEndsToDistanceMatrix(const stem_set& stems) {
    ASSERT(stems.size() > 0);
    const stem_set::size_type n = stems.size();
    matrix_type result(n, matrix_row_type(n, 0));
    ASSERT(result.size() == stems.size());
    ASSERT(result[0].size() == stems.size());
    // only the upper triangle is computed; each value is mirrored below the diagonal
    for (stem_set::size_type row = 0; row < n; ++row) {
      for (stem_set::size_type col = row + 1; col < n; ++col) {
        const double dist = stemEndDist(stems[row], stems[col]);
        result[row][col] = dist;
        result[col][row] = dist;
      }
    }
    return result;
  }

  /**
   * Splits a stem into consecutive substems.
   * The stem is cut into floor(length / stemLength) pieces. All pieces have
   * length stemLength except the last one, which also takes the remainder
   * (so a stem of length 5 with stemLength 2 yields lengths 2 and 3).
   * With stemLength 1 a stem of length n becomes n stems of length 1.
   *
   * @param stem        the stem to split; must be at least stemLength long
   * @param stemLength  length of each substem; must be greater than zero
   * @param reverseMode true for regular reverse-complementary stems (the stop
   *                    position decreases while the start position increases);
   *                    false for "forward-matching" stems as used by covarna
   *                    for opposing strand directionalities (both increase)
   * @return the substems, ordered by increasing start position
   */
  static Vec<Stem> expandStem(const Stem& stem, index_type stemLength, bool reverseMode) {
    ASSERT(stemLength > 0); // stemLength is the divisor below
    Vec<Stem> result;
    index_type totLen = 0;
    int numStems = stem.getLength() / stemLength;
    ASSERT(numStems > 0);
    for (int i = 0; i < numStems; ++i) {
      index_type offset = i * stemLength;
      // every substem has the nominal length, except the last one which takes all that is left
      index_type subLength = stemLength;
      if (i == (numStems - 1)) {
        subLength = stem.getLength() - totLen;
      }
      ASSERT(subLength >= stemLength);
      Stem newStem(stem.getStart() + offset,
                   reverseMode ? (stem.getStop() - offset) : (stem.getStop() + offset),
                   subLength);
      ASSERT(newStem.getLength() >= stemLength);
      result.push_back(newStem);
      totLen += newStem.getLength();
    }
    // the substems must cover the original stem exactly (comparison, not assignment)
    ASSERT(totLen == stem.getLength());
    return result;
  }

  /**
   * Self-test for expandStem: splits a stem of length 5 with substem length 2,
   * first in reverse-complementary mode, then in forward mode, prints each
   * result to cout and asserts that two substems of lengths 2 and 3 come back.
   * Returns true if no assertion fired.
   */
  static bool expandStemTest() {
    Stem stem(100, 200, 5);
    const bool modes[2] = { true, false }; // reverse-complementary first, then forward
    for (int m = 0; m < 2; ++m) {
      Vec<Stem> stems = expandStem(stem, 2, modes[m]);
      cout << "Expanded stem " << stem << " to " << stems << " of stem lengths 2 to 3" << endl;
      ASSERT(stems.size() == 2);
      ASSERT(stems[0].getLength() == 2);
      ASSERT(stems[1].getLength() == 3);
    }
    return true;
  }

  /**
   * Reads one sequence/structure record from the stream and returns its stems.
   * The first line may be a header starting with '>'; it is skipped. The next
   * line is stored in the output parameter sequence, and the line after that
   * is passed as a whole to secToMatrix as bracket notation.
   *
   * @param is        input stream positioned at the start of the record
   * @param sequence  receives the sequence line (previous content is replaced)
   * @return the stems produced by generateStemsFromMatrix for the structure
   */
  static stem_set readRNAfold(istream& is, string& sequence) {
    string line = getLine(is);
    ERROR_IF(line.size() == 0, "readRNAfold: Unexpected empty first line encountered.");
    if (line[0] == '>') {
      line = getLine(is); // skip header line
    }
    sequence = line;
    string bracket = getLine(is);
    unsigned int pkCounter = 0; // handed to secToMatrix; its result is not used here
    Vec<Vec<double> > matrix = secToMatrix(bracket, 1.0, pkCounter);
    return generateStemsFromMatrix(matrix, 1, 0.5, sequence);
  }

/*   static RnaSecondaryStructure readRNAfold2(istream& is) { */
/* /\*     string line = getLine(is); *\/ */
/* /\*     ERROR_IF(line.size() == 0, "readRNAfoldUnexcected empty first line encountered."); *\/ */
/* /\*     if (line[0] == '>') { *\/ */
/* /\*       line = getLine(is); // skip header line *\/ */
/* /\*     } *\/ */
/*     string sequence; //  = line; */
/*     string bracket; // = getLine(is); */
/*     is >> sequence; */
/*     is >> bracket; */
/*     unsigned int pkCounter = 0; */
/*     Vec<Vec<double> > matrix = secToMatrix(bracket, 1.0, pkCounter); */
/* /\*     cout << "Sequence: " << sequence << endl; *\/ */
/* /\*     cout << "Bracket: " << bracket << endl; *\/ */
/* /\*     cout << "Matrix size: " << matrix.size() << " Sequnece size: " << sequence.size() << " " *\/ */
/* /\* 	 << " Bracket size: " << bracket.size() << endl; *\/ */
/*     ASSERT(matrix.size() == sequence.size()); */

/*     Vec<Stem> stems = generateStemsFromMatrix(matrix, 1, 0.5, sequence); */
/*     return RnaSecondaryStructure(stems, sequence); */
/*   } */

  /**
   * Reads one record in RNAfold output format and returns it as a secondary
   * structure. The first line may be a header starting with '>'; it is
   * skipped. The next line is the sequence. The line after that is split
   * with getTokens (presumably at whitespace); its first token is used as
   * bracket notation, further tokens (the energy) are ignored.
   *
   * An empty first line or a structure line without any token is reported
   * through ERROR_IF. A warning is written to cerr if the structure line has
   * fewer than two tokens, i.e. no energy follows the bracket notation.
   */
  static RnaSecondaryStructure readRNAfold2(istream& is) {
    string line = getLine(is);
    ERROR_IF(line.size() == 0, "readRNAfold2: Unexpected empty first line encountered.");
    if (line[0] == '>') {
      line = getLine(is); // skip header line
    }
    string sequence = line;
    vector<string> tokens = getTokens(getLine(is));
    // tokens[0] is accessed below, so an empty structure line has to be rejected first
    ERROR_IF(tokens.size() == 0, "readRNAfold2: Bracket notation expected in last line of RNAfold format.");
    if (tokens.size() < 2) {
      cerr << "Warning: Bracket notation and energy expected in last line of RNAfold format." << endl;
    }
    string bracket = tokens[0];
    unsigned int pkCounter = 0; // handed to secToMatrix; its result is not used here
    Vec<Vec<double> > matrix = secToMatrix(bracket, 1.0, pkCounter);
    // sequence and structure have to describe the same number of positions
    ASSERT(matrix.size() == sequence.size());

    Vec<Stem> stems = generateStemsFromMatrix(matrix, 1, 0.5, sequence);
    return RnaSecondaryStructure(stems, sequence);
  }

  /**
   * Reads data of the format used by RADAR as well as RNAeval:
   * >sequence-name
   * ACUGUUAACCUUUUCCGAU
   * (((--)))----((--))-
   *
   * The header line starting with '>' is optional. The structure line is
   * passed as a whole to secToMatrix. An empty first line is reported
   * through ERROR_IF.
   */
  static RnaSecondaryStructure readRNAeval(istream& is) {
    string line = getLine(is);
    ERROR_IF(line.size() == 0, "readRNAeval: Unexpected empty first line encountered.");
    if (line[0] == '>') {
      line = getLine(is); // skip header line
    }
    string sequence = line;
    string bracket = getLine(is);
    unsigned int pkCounter = 0; // handed to secToMatrix; its result is not used here
    Vec<Vec<double> > matrix = secToMatrix(bracket, 1.0, pkCounter);
    // sequence and structure have to describe the same number of positions
    ASSERT(matrix.size() == sequence.size());
    Vec<Stem> stems = generateStemsFromMatrix(matrix, 1, 0.5, sequence);
    return RnaSecondaryStructure(stems, sequence);
  }
  
  /**
   * Reads n pairs of 1-based ids from the stream and returns an n x n matrix
   * that is 1.0 at [a][b] and [b][a] for every pair (a, b) read and 0.0
   * elsewhere. A second id of 0 marks the first id as having no partner;
   * nothing is entered for such a pair.
   *
   * @param is  input stream holding n whitespace-separated id pairs
   * @param n   number of pairs to read and dimension of the result matrix
   * @return the symmetric 0/1 matrix
   *
   * Unreadable or missing input, a first id of 0, and ids larger than n are
   * reported through ERROR_IF.
   */
  static matrix_type readWMatch(istream& is, size_type n) {
    matrix_type matrix(n, matrix_row_type(n, 0.0));
    for (size_type i = 0; i < n; ++i) {
      size_type id1 = 0;
      size_type id2 = 0;
      is >> id1 >> id2;
      // without this check a truncated or non-numeric input would be processed with stale ids
      ERROR_IF(is.fail(), "Could not read id pair!");
      ERROR_IF(id1 < 1, "First id has to be greater zero!");
      --id1; // convert to 0-based index
      if (id2 > 0) {
        --id2; // convert to 0-based index
        ERROR_IF(id1 >= n, "First index too large!");
        ERROR_IF(id2 >= n, "Second index too large!");
        matrix[id1][id2] = 1.0;
        matrix[id2][id1] = 1.0;
      }
    }
    return matrix;
  }

  
};

#endif
