#include <MAFSearchTables.h>
#include <NucleotideTools.h>
#include <debug.h>
#include <limits.h>
#include <algorithm>
#include <iterator>
#include <StringTools.h>
#include <vector>

/** Builds the key under which the search table of an assembly pair / residue pair is stored.
 *
 * The key has the form "<assembly1>_<assembly2>_<C1>_<C2>", for example "hg18_mm8_A_C".
 * Keys are canonical with respect to argument order: if assembly2 compares less than assembly1,
 * the assemblies are swapped together with their residues, so (hg18, mm8, A, C) and
 * (mm8, hg18, C, A) produce the same key. Residue characters are converted to upper case.
 *
 * @param assembly1 name of the first assembly; must be non-empty and differ from assembly2
 * @param assembly2 name of the second assembly; must be non-empty
 * @param c1 residue character belonging to assembly1
 * @param c2 residue character belonging to assembly2
 * @return the canonical key string
 */
string
MAFSearchTables::createHashTableHash(const string& assembly1, const string& assembly2, char c1, char c2) {
  ASSERT(assembly1.size() > 0);
  ASSERT(assembly2.size() > 0);
  ASSERT(assembly1 != assembly2);
  if (assembly2 < assembly1) {
    // bring the pair into canonical (ascending) order; the residues travel with their assembly
    return createHashTableHash(assembly2, assembly1, c2, c1);
  }
  const string sep = "_";
  // toupper() is only defined for values representable as unsigned char (or EOF),
  // so a plain (possibly negative) char must be converted first
  c1 = static_cast<char>(toupper(static_cast<unsigned char>(c1)));
  c2 = static_cast<char>(toupper(static_cast<unsigned char>(c2)));
  return assembly1 + sep + assembly2 + sep + c1 + sep + c2;
}

/** Parses hash table name of form hg18_mm8_A_C to assembly1=hg18 and assembly2=mm8, c1=A, c2=C */
void
MAFSearchTables::parseHashTableHash(const string& hash, string& assembly1, string& assembly2, char& c1, char& c2) {
  vector<string> words = getTokens(hash, "_");
  ASSERT(words.size() == 4);
  assembly1 = words[0];
  assembly2 = words[1];
  ASSERT(words[2].size() == 1);
  ASSERT(words[3].size() == 1);
  c1 = words[2][0];
  c2 = words[3][0];
}

//  string
//  MAFSearchTables::createHashTableHash(const string& assembly1, char c1) {
//    ASSERT(assembly1.size() > 0);
//    string sep = "_";
//    c1 = toupper(c1);
//    string result = assembly1 + sep + c1;
//    return result;
//  }

void
MAFSearchTables::testCreateHashTableHash() {
  string assembly1 = "hg18";
  string assembly2 = "mm8";
  char c1 = 'A';
  char c2 = 'c';
  ASSERT(createHashTableHash(assembly1,assembly2, c1, c2) == "hg18_mm8_A_C");
}

/** Generates the search tables that return, for keys like "hg18_mm8_A_C", all column ids
 * at which the two assemblies carry a certain pair of (different) residues.
 *
 * Steps:
 *  1. Every unordered pair of distinct assemblies is scored with estimateAssemblyPairHashSize
 *     and the pairs are sorted (operator< of RankedSolution5; the intent stated in the code is
 *     "the smaller estimated size the better").
 *  2. Only the first assemblyPairFraction * (number of pairs) pairs are processed.
 *  3. For each processed pair, the columns in [searchRangeMin, searchRangeMax) are scanned and
 *     each column id is appended to the set that belongs to the observed residue pair.
 *     Columns whose characters are not contained in the residue alphabet, or where both
 *     assemblies carry the same residue, are skipped.
 *  4. Every off-diagonal set is converted to compressed_type and stored in positionHashes
 *     under the key produced by createHashTableHash.
 *
 * Progress and debug output is written to cout depending on the member "verbose".
 *
 * NOTE(review): estimateAssemblyPairHashSize reads an alignment named "maf" that is not this
 * function's parameter (presumably a member of the same name) - confirm that callers pass the
 * same alignment here.
 *
 * @param maf alignment providing the residue alphabet, the assembly sequences and total length
 * @param assemblies names of the assemblies to combine; needs at least two entries
 * @pre positionHashes is empty
 */
void
MAFSearchTables::createSearchHashTable(MAFAlignment *maf, const set<string>& assemblies) {
  PRECOND(positionHashes.size() == 0);
  PRECOND(assemblies.size() > 1);
  if (verbose > 0) {
    cout << "Starting createSearchHashTable for " << assemblies.size() << " assemblies: " << endl;
    for (set<string>::const_iterator it = assemblies.begin(); it != assemblies.end(); ++it) {
      cout << (*it) << " ";
    }
    cout << endl;
  }
  const string& residues = maf->getResidues();
  if (verbose > 1) {
    REMARK << "Generating column index vectors..." << endl;
  }
  // collect all unordered pairs of distinct assemblies together with their estimated table size
  Vec<RankedSolution5<string, string> > queue;
  queue.reserve((assemblies.size() * (assemblies.size() - 1)) / 2);
  for (set<string>::const_iterator i = assemblies.begin(); i != assemblies.end(); ++i) {
    ASSERT((*i).size() > 0);
    // loop over assembly 2
    for (set<string>::const_iterator j = i; j != assemblies.end(); ++j) {
      ASSERT((*j).size() > 0);
      if (i == j) {
        continue; // do not allow same assemblies
      }
      double score = estimateAssemblyPairHashSize(*i,*j); // the smaller estimated size the better
      RankedSolution5<string, string> item(score, *i, *j);
      queue.push_back(item);
    }
  }
  sort(queue.begin(), queue.end());
  if (verbose > 2) {
    cout << "Successfully sorted " << queue.size() << " assembly pairs!" << endl;
    for (size_type i = 0; i < queue.size(); ++i) {
      cout << (i+1) << " " << queue[i] << endl;
      if (i > 10) {
        cout << "..." << endl;
        break;
      }
    }
  }
  // only the best-ranked fraction of all pairs is turned into search tables
  size_type numAssemUsed = static_cast<size_type>(assemblyPairFraction * queue.size());
  ASSERT(numAssemUsed > 0);
  ASSERT(numAssemUsed <= queue.size()); // a fraction above 1 would read past the end of the queue
  map<string, string> assemSeqs; // cache: assembly name -> sequence over all alignment columns
  set_type emptySet;
  // setReferences[a][b] collects the column ids of the current assembly pair at which
  // assembly 1 carries residues[a] and assembly 2 carries residues[b]
  Vec<Vec<set_type> > setReferences(residues.size(), Vec<set_type>(residues.size(), emptySet));
  for (size_type i = 0; i < numAssemUsed; ++i) { // loop over pairs
    string assem1 = queue[i].getSecond();
    string assem2 = queue[i].getThird();
    ASSERT(assem1.size() > 0);
    ASSERT(assem2.size() > 0);
    if (assemSeqs.find(assem1) == assemSeqs.end()) {
      assemSeqs[assem1] = maf->generateAssemblySequence(assem1);
    }
    if (assemSeqs.find(assem2) == assemSeqs.end()) {
      assemSeqs[assem2] = maf->generateAssemblySequence(assem2);
    }
    ASSERT(assemSeqs.find(assem1) != assemSeqs.end());
    ASSERT(assemSeqs.find(assem2) != assemSeqs.end());
    ASSERT(static_cast<length_type>(assemSeqs[assem1].size()) == maf->getTotalLength());
    ASSERT(static_cast<length_type>(assemSeqs[assem2].size()) == maf->getTotalLength());
    ASSERT(assem1 != assem2);
    if (verbose > 1) {
      // no line break here: the line is completed by the "Actual size" output further below
      cout << "Working on assembly pair " << assem1 << " " << assem2 << " score: " << queue[i].getFirst();
    }
    const string& seq1 = assemSeqs[assem1];
    const string& seq2 = assemSeqs[assem2];
    // TRICKY Section!
    // The idea is to avoid building the key with "createHashTableHash" for each column.
    // Instead the column ids are collected in a residues x residues array of sets
    // (setReferences) that is indexed according to the order of the "residues" variable.
    // Example: setReferences[2][3] holds the set corresponding to the current assembly pair
    // and the residues 'G' and 'T' (assuming A,C,G,T alphabet)

    // count differing columns to get a rough per-set size for reserve()
    length_type countDiff = 0;
    for (length_type j = searchRangeMin; j < searchRangeMax; ++j) { // loop over whole searched region
      if (seq1[j] != seq2[j]) {
        ++countDiff;
      }
    }
    countDiff /= (residues.size() * residues.size()); // estimate number of hits for each set
    for (size_type j = 0; j < residues.size(); ++j) {
      for (size_type k = 0; k < residues.size(); ++k) {
        if (countDiff > 0) {
          setReferences[j][k].reserve(countDiff);
        }
        setReferences[j][k].clear(); // drop the content left over from the previous pair
      }
    }
    // again, avoid looking up the mapping from character to the numbers 0...3 in a costly fashion.
    // The table is indexed by the character's unsigned value: going through unsigned char keeps
    // the index inside 0..255 even if char is signed and the byte is >= 0x80.
    Vec<int> charHash(256, -1);
    for (string::size_type j = 0; j < residues.size(); ++j) {
      charHash[static_cast<int>(static_cast<unsigned char>(residues[j]))] = static_cast<int>(j);
    }
    ASSERT(charHash[static_cast<int>('A')] == 0);
    ASSERT(charHash[static_cast<int>('C')] == 1);
    ASSERT(charHash[static_cast<int>('G')] == 2);
    ASSERT(charHash[static_cast<int>('T')] == 3);
    int resId1, resId2;
    for (length_type j = searchRangeMin; j < searchRangeMax; ++j) { // loop over whole searched region
      resId1 = charHash[static_cast<int>(static_cast<unsigned char>(seq1[j]))];
      resId2 = charHash[static_cast<int>(static_cast<unsigned char>(seq2[j]))];
      if ((resId1 >= 0) && (resId2 >= 0) && (resId1 != resId2)) {
        ASSERT(resId1 < static_cast<int>(setReferences.size()));
        ASSERT(resId2 < static_cast<int>(setReferences[resId1].size()));
        setReferences[resId1][resId2].push_back(j);
      } // else : character outside the alphabet or identical residues
    }
    if (verbose > 2) {
      cout << "Computed hash tables for this assembly pair. Storing compressed version..." << endl; 
    }
    for (size_type j = 0; j < residues.size(); ++j) {
      for (size_type k = 0; k < residues.size(); ++k) {
        if (j != k) {
          ASSERT(assem1.size() > 0);
          ASSERT(assem2.size() > 0);
          // store COMPRESSED set!
          positionHashes[createHashTableHash(assem1,assem2,residues[j],residues[k])] = 
            compressed_type(setReferences[j][k]);
        }
      }
    }
    if (verbose > 2) {
      cout << "Stored compressed hash tables for this assembly pair!" << endl; 
    }
    if (verbose > 1) {
      // report the summed size of all tables just stored for this pair
      size_type sz = 0;
      for (size_type j = 0; j < residues.size(); ++j) {
        for (size_type k = 0; k < residues.size(); ++k) {
          if (j != k) {
            ASSERT(assem1.size() > 0);
            ASSERT(assem2.size() > 0);
            string hashhash = createHashTableHash(assem1,assem2,residues[j],residues[k]);
            ASSERT(positionHashes.find(hashhash) != positionHashes.end());
            sz += positionHashes[hashhash].size();
            if (verbose > 4) {
              cout << "Uncompressing " << positionHashes[hashhash] << endl;
              set_type outset = uncompressSet(positionHashes[hashhash]);
              
              cout << "Content of search table set " << hashhash << " : " << endl;
              for (set_type::iterator it = outset.begin(); it != outset.end(); ++it) {
                cout << (*it) << " ";
              }
              cout << endl;
            }
          }
        }
      }
      cout << " Actual size of " << ((residues.size()*(residues.size()-1))) << " sets: " << sz << endl;
    }
  }
  // resetPositionHashStarts();
  if (verbose > 0) {
    cout << "Generated " << positionHashes.size() << " hash tables." << endl;
  }
}

/** Estimates the potential number of hash table entries of two assemblies
 */
double
MAFSearchTables::estimateAssemblyPairHashSize(const string& assem1, const string& assem2) const {
  set<string> assems;
  assems.insert(assem1);
  assems.insert(assem2);
  length_type numDiv = 1000;
  length_type stride = maf->getTotalLength() / numDiv;
  if (stride == 0) {
    stride = 1;
  }
  length_type countStored = 0;
  length_type countTotal = 0;
  for (length_type colid = 0; colid < maf->getTotalLength(); colid += stride, ++countTotal) {
    string slice = maf->getSlice(colid, assems);
    ASSERT(slice.size() == 2);
    if (slice[0] != slice[1]) {
      if (!(NucleotideTools::isGap(slice[0]) || NucleotideTools::isGap(slice[1]) ) ) {
	++countStored;
      }
    }
  }
  // use pseudocounts for fraction:
  double result = (static_cast<double>(countStored+1)/static_cast<double>(countTotal+2)) * maf->getTotalLength();
  ASSERT((result >= 0) && (result <= maf->getTotalLength()));
  return result;
}

/** Returns the number of elements shared by the two search tables stored under hash1 and hash2.
 *
 * Results are memoized in intersectionSizeMap under the key produced by createDoubleHash,
 * so the intersection of a given pair is computed at most once.
 *
 * @param hash1 key of the first search table; must differ from hash2
 * @param hash2 key of the second search table
 * @return size of the intersection of both tables
 */
MAFSearchTables::size_type
MAFSearchTables::intersectionSize(const string& hash1, const string& hash2) const {
  ASSERT(hash1 != hash2);
  const string cacheKey = createDoubleHash(hash1, hash2);
  map<string, size_type>::iterator cached = intersectionSizeMap.find(cacheKey);
  if (cached == intersectionSizeMap.end()) {
    // not computed yet: intersect both tables and remember the size
    const compressed_type& first = getSet(hash1);
    const compressed_type& second = getSet(hash2);
    Vec<length_type> common;
    set_intersection(first.begin(), first.end(), second.begin(), second.end(), back_inserter(common));
    const size_type commonCount = common.size();
    intersectionSizeMap[cacheKey] = commonCount;
    return commonCount;
  }
  return cached->second;
}
