/* -*- mode: c++; c-basic-offset: 3; -*- */
#include "wtile.hh"
#include "matlab_fcs.hh"
#include <constant.hh>
#include "DVecType.hh"
#include "Bits.hh"
#include "fSeries/DFT.hh"
#include <cmath>
#include <iostream>
#include <limits>
#include <sstream>

using namespace std;
using namespace wpipe;
using namespace containers;

//
//    Notes on variables used here:
//    timeRange: duration of the time series that will be FFTed
//               (timeRange * sampleFrequency samples)
//    minimumFrequencyStep: frequency resolution of the DFT of that time
//               series, i.e. 1 / timeRange
//

// conversion factor from Q prime to true Q
const double qPrimeToQ = std::sqrt(11.0);

// Positive infinity, used as the "no upper frequency limit" marker.
// (1.0 / 0.0 is undefined behavior in C++ even though IEEE hardware yields
// +inf; numeric_limits provides the value portably.)
const double dInf = std::numeric_limits<double>::infinity();

//======================================  Empty wtile constructor.
//
//  Construct an empty tiling; init() fills it in.  The scalar summary
//  fields are zeroed so that display() and nearest_plane() on an object
//  that was never initialized read defined values (in particular
//  _numberOfPlanes == 0 keeps nearest_plane() from indexing _planes).
wtile::wtile(void) 
   : _debugLevel(0)
{
   _duration             = 0;
   _minimumQ             = 0;
   _maximumQ             = 0;
   _minimumFrequency     = 0;
   _maximumFrequency     = 0;
   _sampleFrequency      = 0;
   _maximumMismatch      = 0;
   _numberOfPlanes       = 0;
   _numberOfTiles        = 0;
   _numberOfIndependents = 0;
   _numberOfFlops        = 0;
   _highPassCutoff       = 0;
   _lowPassCutoff        = 0;
   _whiteningDuration    = 0;
   _transientDuration    = 0;
}

//======================================  wtile constructor.
//
//  Construct and immediately build a tiling.  Every argument is forwarded
//  unchanged, in order, to init().
wtile::wtile(double timeRange, const dble_vect& qRange, 
	     const dble_vect& frequencyRange, double sampleFrequency, 
	     double maximumMismatch, double highPassCutoff, 
	     double lowPassCutoff, double whiteningDuration, 
	     double transientFactor, int debug)
{
   init(timeRange,
        qRange,
        frequencyRange,
        sampleFrequency,
        maximumMismatch,
        highPassCutoff,
        lowPassCutoff,
        whiteningDuration,
        transientFactor,
        debug);
}

//======================================  Initialize w-transform tiling 
void
wtile::init(double timeRange, const dble_vect& qRange, 
	    const dble_vect& frequencyRange, double sampleFrequency, 
	    double maximumMismatch, double highPassCutoff, 
	    double lowPassCutoff, double whiteningDuration, 
	    double transientFactor, int debug) {
   
   // apply default arguments
   _debugLevel = debug;


   // extract minimum and maximum Q from Q range
   double minimumQ = qRange[0];
   double maximumQ = qRange[1];

   // extract minimum and maximum frequency from frequency range
   double minimumFrequency = 0;
   double maximumFrequency = dInf;
   if (!frequencyRange.empty()) {
      minimumFrequency = frequencyRange[0];
      if (frequencyRange.size() == 1) maximumFrequency = frequencyRange[0];
      else                            maximumFrequency = frequencyRange[1];
   }

   ////////////////////////////////////////////////////////////////////////////
   //                          compute derived parameters                    //
   ////////////////////////////////////////////////////////////////////////////

   // nyquist frequency
   double nyquistFrequency = sampleFrequency / 2;

   // maximum mismatch between neighboring tiles
   double mismatchStep = 2 * sqrt(maximumMismatch / 3);

   // maximum possible time resolution
   //double minimumTimeStep = 1 / sampleFrequency;

   // total number of samples in input data
   double numberOfSamples = timeRange * sampleFrequency;

   /////////////////////////////////////////////////////////////////////////////
   //                       determine parameter constraints                   //
   /////////////////////////////////////////////////////////////////////////////

   // minimum allowable Q prime to prevent window aliasing at zero frequency
   const double minimumAllowableQPrime = 1.0;

   // minimum allowable Q to avoid window aliasing at zero frequency
   double minimumAllowableQ = minimumAllowableQPrime * qPrimeToQ;

   // reasonable number of statistically independent tiles in a frequency row
   double minimumAllowableIndependents = 50;

   // maximum allowable mismatch parameter for reasonable performance
   double maximumAllowableMismatch = 0.5;

   /////////////////////////////////////////////////////////////////////////////
   //                             validate parameters                         //
   /////////////////////////////////////////////////////////////////////////////

   // check for valid time range
   if (timeRange < 0) {
      error("negative time range");
   }

   // check for valid Q range
   if (minimumQ > maximumQ) {
      error("minimum Q exceeds maximum Q");
   }

   // check for valid frequency range
   if (minimumFrequency > maximumFrequency) {
      error("minimum frequency exceeds maximum frequency");
   }

   // check for valid minimum Q
   if (minimumQ < minimumAllowableQ) {
      ostringstream msg;
      msg << "minimum Q (" << minimumQ << ") less than minimum allowable: " 
	  << minimumAllowableQ;
      error(msg.str());
   }

   // check for reasonable maximum mismatch parameter
   if ( maximumMismatch > maximumAllowableMismatch) {
      ostringstream msg;
      msg << "maximum mismatch (" << maximumMismatch 
	  << ") exceeds maximum allowable: " << maximumAllowableMismatch;
      error(msg.str());
   }

   // check for integer power of two data length
   //if (fmod(log(timeRange * sampleFrequency) / log(2), 1) != 0) {
   if (!is_power_of_2(timeRange * sampleFrequency)) {
      error("data length is not an integer power of two");
   }

   ////////////////////////////////////////////////////////////////////////////
   //                              determine Q planes                        //
   ////////////////////////////////////////////////////////////////////////////

   // cumulative mismatch across Q range
   double qCumulativeMismatch = log(maximumQ / minimumQ) / sqrt(2);

   // number of Q planes
   int numberOfPlanes = int(ceil(qCumulativeMismatch / mismatchStep));

   // insure at least one plane
   if (numberOfPlanes == 0) {
      numberOfPlanes = 1;
   }

   // mismatch between neighboring planes
   double qMismatchStep = qCumulativeMismatch / numberOfPlanes;

   // index of Q planes
   // double qIndices = 0.5 : numberOfPlanes - 0.5;

   // vector of Qs
   //double qs = minimumQ * exp(sqrt(2) * qIndices * qMismatchStep);
   double minQ = minimumQ * exp(sqrt(2) * 0.5 * qMismatchStep);
   double maxQ = minimumQ * exp(sqrt(2) * (numberOfPlanes-0.5) * qMismatchStep);

   ////////////////////////////////////////////////////////////////////////////
   //                             validate frequencies                       //
   ////////////////////////////////////////////////////////////////////////////

   // minimum allowable frequency to provide sufficient statistics
   double minimumAllowableFrequency = minimumAllowableIndependents * maxQ / 
      (2 * pi * timeRange);

   // maximum allowable frequency to avoid window aliasing
   double maximumAllowableFrequency = nyquistFrequency / (1 + qPrimeToQ / minQ);

   // check for valid minimum frequency
   if (maximumFrequency < 0.0) maximumFrequency = dInf;
   if ((minimumFrequency!=0 && minimumFrequency<minimumAllowableFrequency) ||
       (maximumFrequency!=dInf && maximumFrequency>maximumAllowableFrequency)) {
      cerr << "Requested frequency range (" << minimumFrequency << "-" 
	   << maximumFrequency <<") is not within allowable range (" 
	   << minimumAllowableFrequency << "-" << maximumAllowableFrequency
	   << ")" << endl;
      error("Requested frequency range is not within allowable range");
   }

   ////////////////////////////////////////////////////////////////////////////
   //                     create Q transform tiling structure                //
   ////////////////////////////////////////////////////////////////////////////

   // structure type identifier
   _id = "Discrete Q-transform tile structure";

   // insert duration into tiling structure
   _duration = timeRange;

   // insert minimum Q into tiling structure
   _minimumQ = minimumQ;

   // insert maximum Q into tiling structure
   _maximumQ = maximumQ;

   // insert minimum frequency into tiling structure
   _minimumFrequency = minimumFrequency;

   // insert maximum frequency into tiling structure
   _maximumFrequency = maximumFrequency;

   // insert sample frequency into tiling structure
   _sampleFrequency = sampleFrequency;

   // insert maximum loss due to mismatch into tiling structure
   _maximumMismatch = maximumMismatch;

   // insert Q vector into tiling structure
   // _qs = qs;

   // insert number of Q planes into tiling structure
   _numberOfPlanes = numberOfPlanes;

   // initialize cell array of Q plans in tiling structure
   //_planes = cell(1, numberOfPlanes);
   _planes.resize(numberOfPlanes);

   // initialize total number of tiles counter
   _numberOfTiles = 0;

   // initialize total number of independent tiles counter
   _numberOfIndependents = 0;

   // initialize total number of flops counter
   _numberOfFlops = double(numberOfSamples) * log(double(numberOfSamples));

   ////////////////////////////////////////////////////////////////////////////
   //                           begin loop over Q planes                     //
   ////////////////////////////////////////////////////////////////////////////

   // begin loop over Q planes
   for (int plane=0; plane < numberOfPlanes; plane++) {

      // calculate Q of plane
      double q = minimumQ * exp(sqrt(2) * (plane + 0.5) * qMismatchStep);
      _planes[plane].init(q, timeRange, minimumAllowableIndependents,
			  nyquistFrequency, minimumFrequency, maximumFrequency,
			  mismatchStep);

      // increment total number of tiles counter
      _numberOfTiles += _planes[plane].numberOfTiles;

      // increment total number of independent tiles counter
      _numberOfIndependents += int(_planes[plane].numberOfIndependents * 
				   (1 + qCumulativeMismatch) / numberOfPlanes);

      // increment total number of flops counter
      _numberOfFlops += _planes[plane].numberOfFlops;

   } //                             end loop over Q planes

   /////////////////////////////////////////////////////////////////////////////
   //                         determine filter properties                     //
   /////////////////////////////////////////////////////////////////////////////
   //
   // high pass filter cutoff frequency [lowest minimum frequency]
   if (highPassCutoff < 0) _highPassCutoff = defaultHighPassCutoff();
   else                    _highPassCutoff = highPassCutoff;

   // low pass filter cutoff frequency [highest maximum frequency]
   if (lowPassCutoff < 0) _lowPassCutoff = defaultLowPassCutoff();
   else                   _lowPassCutoff = lowPassCutoff;
 
   // whitening filter duration [maximum q / (2 * fMin)
   if (whiteningDuration <= 0) _whiteningDuration = defaultWhiteningDuration();
   else                        _whiteningDuration = whiteningDuration;

   // estimated duration of filter transients to supress
   _transientDuration = transientFactor * _whiteningDuration;

   // test for insufficient data
   if ((2 * _transientDuration) >= _duration) {
      error("duration of filter transients equals or exceeds data duration");
   }

   /////////////////////////////////////////////////////////////////////////////
   //                          return Q transform tiling                      //
   /////////////////////////////////////////////////////////////////////////////
}

//======================================  Filter properties...
double 
wtile::defaultHighPassCutoff(void) const {
   double dMin = 0.0;
   for (const_qplane_iter p=_planes.begin(); p != _planes.end(); p++) {
      if (p == _planes.begin()) dMin = p->minimumFrequency;
      else                      dMin = min(dMin, p->minimumFrequency );
   }
   return dMin;
}

double 
wtile::defaultLowPassCutoff(void) const {
   double dMax = 0.0;
   for (const_qplane_iter p=_planes.begin(); p != _planes.end(); p++) {
      if (p == _planes.begin()) dMax = p->maximumFrequency;
      else                      dMax = max(dMax, p->maximumFrequency );
   }
   return dMax;
}

double 
wtile::defaultWhiteningDuration(void) const {
   double dMax = 0.0;
   for (const_qplane_iter p=_planes.begin(); p != _planes.end(); p++) {
      dMax = max(dMax, p->defaultWhiteningDt() );
   }
   return dMax;
}

//======================================  Destructor.
//  The body is empty; member destructors run automatically.
wtile::~wtile(void)
{}

//======================================  Find nearest q-plane
//
//  Return the index of the plane whose Q is closest to qTest, where
//  distance is |log(q_plane / qTest)|.  On a tie the lower index wins.
//  Returns 0 when there are no planes or qTest is not positive.
size_t
wtile::nearest_plane(double qTest) const {
   if (_numberOfPlanes == 0 || !(qTest > 0)) return 0;

   size_t closest  = 0;
   double bestDist = fabs(log(_planes[0].q / qTest));
   for (int k = 1; k < _numberOfPlanes; ++k) {
      const double dist = fabs(log(_planes[k].q / qTest));
      if (dist < bestDist) {
         bestDist = dist;
         closest  = k;
      }
   }
   return closest;
}

//======================================  Calculate threshold
double
wtile:: threshold_from_rate(double rate) const {
   double ind_rate = independentsRate();
   if (rate <= 0.0 || ind_rate == 0) {
      cout << "rate = " << rate << " ind_rate = " << ind_rate << endl;
      error("Insufficient information to calculate event threshold");
   }
   return -log(rate / ind_rate);
}

//======================================  Construct a row structure.
//
//  Construct an empty row; init() fills it in.  Scalar members are zeroed
//  so that display() on an un-initialized row reads defined values.
qrow::qrow(void)
{
   frequency            = 0;
   duration             = 0;
   bandwidth            = 0;
   timeStep             = 0;
   frequencyStep        = 0;
   zeroPadLength        = 0;
   numberOfTiles        = 0;
   numberOfIndependents = 0;
   numberOfFlops        = 0;
}

//======================================  Destroy a row structure.
//  Empty body; member destructors (including the window pointer) run
//  automatically.
qrow::~qrow(void)
{}

//======================================  Initialize a row structure
//
//  Initialize one frequency row of a Q plane.
//
//  Parameters:
//    q                 Quality factor of the enclosing plane.
//    f                 Center frequency of the row.
//    timeRange         Duration of the data; 1/timeRange is the DFT
//                      frequency resolution used for the window.
//    nyquistFrequency  Nyquist frequency; 2*nyquist*timeRange (rounded)
//                      is taken as the number of data samples.
//    fStep             Frequency step for integration (stored unchanged).
//    mismatchStep      Mismatch step between neighboring time tiles.
//
//  Sets the coincidence bandwidth and duration, the number of time tiles
//  (a power of two), the time step, flop and independent-tile estimates,
//  and builds the normalized bi-square frequency-domain window.
void 
qrow::init(double q, double f, double timeRange, double nyquistFrequency,
	   double fStep, double mismatchStep) {
   ///////////////////////////////////////////////////////////////////////
   //                      determine tile properties                    //
   ///////////////////////////////////////////////////////////////////////
   frequency = f;
   double qPrime = q / qPrimeToQ;
   double minimumFrequencyStep = 1.0 / timeRange;
   // rounded to the nearest integer sample count
   long numberOfSamples = long(2.0 * nyquistFrequency * timeRange + 0.5);

   // bandwidth for coincidence testing
   bandwidth = 2 * sqrt(pi) * frequency / q;

   // duration for coincidence testing
   duration = 1 / bandwidth;

   // frequency step for integration
   frequencyStep = fStep;

   ///////////////////////////////////////////////////////////////////////
   //                         determine tile times                      //
   ///////////////////////////////////////////////////////////////////////

   // cumulative mismatch across time range
   double timeCumulativeMismatch = timeRange * 2 * pi * frequency / q;

   // number of time tiles, rounded up to a power of two.
   // NOTE(review): nextpow2 presumably matches MATLAB (ceil(log2(x))); the
   // int shift below assumes 0 <= logNumberOfTiles < 31 — not checked here.
   long logNumberOfTiles = nextpow2(timeCumulativeMismatch/mismatchStep);
   numberOfTiles = 1 << logNumberOfTiles;

   // mismatch between neighboring time tiles
   // double timeMismatchStep = timeCumulativeMismatch / numberOfTiles;

   // time step for integration
   // double timeStep = q * timeMismatchStep / (2 * pi * frequency);
   timeStep = timeRange / numberOfTiles;

   // number of flops to compute row (numberOfTiles * ln(numberOfTiles))
   numberOfFlops = numberOfTiles * logNumberOfTiles * log(2.0);

   // number of independent tiles in row
   numberOfIndependents = 1 + timeCumulativeMismatch;

   //////////////////////////////////////////////////////////////////////
   //                           generate window                        //
   //////////////////////////////////////////////////////////////////////

   // half length of window in samples (frequency bins, truncated)
   int halfWindowLength = int((frequency / qPrime) / minimumFrequencyStep);

   // full length of window in samples (symmetric about the center bin)
   int windowLength = 2 * halfWindowLength + 1;

   // sample index vector for window construction
   // windowIndices = -halfWindowLength : halfWindowLength;

   // window buffer, zero-filled over its full length
   DVectD dvd(windowLength);
   dvd.replace_with_zeros(0, windowLength, windowLength);

   // dimensionless frequency vector for window construction
   // double windowArgument = windowFrequencies * qPrime / frequency;

   // bi square window function
   // window = (1 - windowArgument.^2).^2;
   // The center bin is 1; bins +/-i get the same value (even window).
   dvd[halfWindowLength] = 1.0;
   for (int i=1; i<=halfWindowLength; ++i) {
      double wFreq = minimumFrequencyStep * double(i);
      double wArg  = wFreq * qPrime / frequency;
      double wVal  = pow(1.0 - wArg * wArg, 2);
      dvd[halfWindowLength+i] = wVal;
      dvd[halfWindowLength-i] = wVal;
   }

   // row normalization factor
   double rowNormalization = sqrt((315 * qPrime) / (128 * frequency));

   // inverse fft normalization factor
   double ifftNormalization = double(numberOfTiles) / numberOfSamples;

   // normalize window
   // window = window * ifftNormalization * rowNormalization * 
   //          planeNormalization;
   // (no plane normalization factor is applied here)
   dvd *= ifftNormalization * rowNormalization;

   //  Build the fSeries: a DFT object whose fSeries part starts at
   //  -halfWindowLength bins (relative to the row frequency) with a step of
   //  1/timeRange, holding the window values.
   
   window = winptr_type(new DFT);
   static_cast<fSeries&>(*window) =
      fSeries(-halfWindowLength * minimumFrequencyStep, minimumFrequencyStep,
	      Time(0), timeRange, dvd);

   // number of zeros to append to windowed data
   // NOTE(review): negative if windowLength > numberOfTiles; not checked here.
   zeroPadLength = numberOfTiles - windowLength;
}

//======================================  display a row structure.
//
//  Write the row parameters to out, one "name: value" line each, and
//  return out.  The window itself is not printed.
ostream& 
qrow::display(std::ostream& out) const {
   out << "frequency: "            << frequency                 << endl
       << "duration: "             << duration                  << endl
       << "bandwidth: "            << bandwidth                 << endl
       << "timeStep: "             << timeStep                  << endl
       << "frequencyStep: "        << frequencyStep             << endl
       //<< "window:        ";   window->Dump(out) << endl
       << "zeroPadLength: "        << zeroPadLength             << endl
       << "numberOfTiles: "        << numberOfTiles             << endl
       << "numberOfIndependents: " << int(numberOfIndependents) << endl
       << "numberOfFlops: "        << numberOfFlops             << endl;
   return out;
}

//=======================================  Calculate a q-transform for one row
//
//  Compute this row's complex tile coefficients from the frequency-domain
//  data.  The process has four steps:
//    1. Extract the bins under the row window.
//    2. Multiply them by the window.
//    3. Extend the series to +/- (numberOfTiles/2) bins.
//    4. Inverse-transform, then rescale.
//  Unless UNNORMALIZED_DFTS is defined, the result is also divided by the
//  data sample time.
TSeries
qrow::tileCoeffs(const DFT& data) const {

   //-----------------------------------  Pull the in-band bins under the window
   const double winF0    = window->getLowFreq();
   const double winStep  = window->getFStep();
   const long   firstBin = data.getBin(frequency + winF0);

   DFT band;
   static_cast<fSeries&>(band) = 
      fSeries(winF0, winStep, data.getStartTime(), data.getDt(),
              data.refDVect().Extract(firstBin, window->size()));
   band.setSampleTime(data.getSampleTime());
   band.refDVect().reserve(numberOfTiles);
   band *= *window;

   //-----------------------------------  Extend to +/- halfSpan in frequency
   const double halfSpan = 0.5 * numberOfTiles * winStep;
   band.extend(halfSpan);
   band.extend(-halfSpan);

   //-----------------------------------  Inverse transform: complex coefficients
   TSeries coeffs = band.iFFT();

   //-----------------------------------  Match matlab's unnormalized DFTs
   double scale = 0.5 / halfSpan;
#ifndef UNNORMALIZED_DFTS
   scale /= double(data.getSampleTime());
#endif
   coeffs *= scale;
   return coeffs;
}

//======================================  Tile plane constructor
//
//  Construct an empty plane; init() fills it in.  Besides the counters,
//  the scalar plane properties are zeroed.  This lets display() on an
//  un-initialized plane read defined values; in particular,
//  numberOfRows == 0 keeps it from indexing rows.
qplane::qplane(void) 
   : numberOfTiles(0), numberOfIndependents(0), numberOfFlops(0)
{
   q                = 0;
   normalization    = 0;
   minimumFrequency = 0;
   maximumFrequency = 0;
   numberOfRows     = 0;
}

//======================================  Tile plane destructor
//  Empty body; the rows container and other members are destroyed
//  automatically.
qplane::~qplane(void)
{}

//======================================  Set plane properties
//
//  Initialize one Q plane: compute the plane normalization, choose the
//  plane frequency limits, and lay out logarithmically spaced frequency
//  rows, each initialized by qrow::init().
//
//  Parameters:
//    q0                            Quality factor of the plane.
//    timeRange                     Duration of the data.
//    minimumAllowableIndependents  Minimum number of independent tiles per
//                                  row; sets the lowest allowable frequency.
//    nyquistFrequency              Nyquist frequency of the data.
//    reqFmin                       Requested minimum frequency; 0 selects
//                                  the plane's minimum allowable frequency.
//    reqFmax                       Requested maximum frequency; negative or
//                                  infinite selects the plane's maximum
//                                  allowable frequency.
//    mismatchStep                  Mismatch step between neighboring tiles.
//
//  The tile/independent/flop counters are reset on entry, so a plane may be
//  re-initialized (wtile::init resizes its plane vector and initializes
//  the planes in place).  If the plane's frequency band is empty
//  (maximum < minimum, e.g. large Q with short data), the plane gets no
//  rows.  NOTE(review): callers must tolerate numberOfRows == 0.
void
qplane::init(double q0, double timeRange, long minimumAllowableIndependents,
	     double nyquistFrequency, double reqFmin, double reqFmax, 
	     double mismatchStep) {

   //-----------------------------------  Reset the per-plane counters.
   numberOfTiles        = 0;
   numberOfIndependents = 0;
   numberOfFlops        = 0;

   // best possible frequency resolution
   double minimumFrequencyStep = 1 / timeRange;

   //-----------------------------------  Store the Q value, calculate Q'.
   q = q0;
   double qPrime = q / qPrimeToQ;

   //-----------------------------------  Plane normalization factor
   if (qPrime > 10) {
      // for large qPrime use asymptotic value of planeNormalization
      normalization = 1;
   } else {
      // polynomial coefficients (highest power first) for the plane
      // normalization; the log term is shared by the odd-index entries
      double logRatio = log((qPrime + 1) / (qPrime - 1));
      double coefficients[9] = {
         + 1. * logRatio,   - 2.,
         - 4. * logRatio,   22.0 / 3.0,
         + 6. * logRatio,   - 146.0 / 15.0,
         - 4. * logRatio,   + 186.0 / 35.0,
         + 1. * logRatio
      };

      // plane normalization factor
      double pval = polyval(coefficients, 9, qPrime);
      normalization = sqrt(256.0 / (315.0 * qPrime * pval));
   }

   /////////////////////////////////////////////////////////////////////////
   //                        determine frequency rows                     //
   /////////////////////////////////////////////////////////////////////////
   
   // minimum allowable frequency to provide sufficient statistics
   double minimumAllowableFrequency = minimumAllowableIndependents*q / 
      (2*pi * timeRange);

   // plane specific maximum allowable frequency to avoid window aliasing
   double maximumAllowableFrequency = nyquistFrequency / (1 + qPrimeToQ / q);

   // use plane specific minimum allowable frequency if requested
   minimumFrequency = reqFmin;
   if (reqFmin == 0) {
      minimumFrequency = minimumAllowableFrequency;
   }

   // use plane specific maximum allowable frequency if requested
   maximumFrequency = reqFmax;
   if (reqFmax < 0 || reqFmax == dInf) {
      maximumFrequency = maximumAllowableFrequency;
   }

   // An inverted band would give a negative cumulative mismatch and a
   // negative row count, which rows.resize() cannot accept.  Such a plane
   // gets no rows.
   if (maximumFrequency < minimumFrequency) {
      numberOfRows = 0;
      rows.resize(0);
      return;
   }

   // Q-dependent factor shared by the mismatch and frequency-ratio formulas
   double qFactor = sqrt(2 + q*q);

   // cumulative mismatch across frequency range
   double frequencyCumulativeMismatch = 
      log(maximumFrequency / minimumFrequency) * qFactor / 2;

   // number of frequency rows
   numberOfRows = int(ceil(frequencyCumulativeMismatch / mismatchStep));

   // insure at least one row
   if (numberOfRows == 0) {
      numberOfRows = 1;
   }

   // mismatch between neighboring frequency rows
   double frequencyMismatchStep = frequencyCumulativeMismatch / numberOfRows;

   // ratio between successive frequencies
   double logfRatio = (2 / qFactor) * frequencyMismatchStep;
   double fRatio = exp(logfRatio);

   /////////////////////////////////////////////////////////////////////////
   //                   create Q transform plane structure                //
   /////////////////////////////////////////////////////////////////////////

   rows.resize(numberOfRows);

   /////////////////////////////////////////////////////////////////////////
   //                      loop over frequency rows                       //
   /////////////////////////////////////////////////////////////////////////

   for (int row=0; row < numberOfRows; row++) {

      // row frequency, snapped to the nearest DFT bin
      double frequency = minimumFrequency * exp(logfRatio * (row + 0.5));
      frequency=long(frequency/minimumFrequencyStep+0.5)*minimumFrequencyStep;

      double fStep = frequency * (fRatio - 1) / sqrt(fRatio);
      rows[row].init(q, frequency, timeRange, nyquistFrequency, fStep, mismatchStep);

      // increment number of tiles in plane counter
      numberOfTiles += rows[row].numberOfTiles;

      // increment number of independent tiles in plane counter
      numberOfIndependents += rows[row].numberOfIndependents * 
         (1 + frequencyCumulativeMismatch) / numberOfRows;

      // increment number of flops in plane counter
      numberOfFlops += rows[row].numberOfFlops;
   }
}

//======================================  display a plane structure.
//
//  Write the plane properties to out, then each row (via qrow::display)
//  and the plane counters.  Returns out.
ostream& 
qplane::display(std::ostream& out) const {
   out << "q:             "        << q                << endl
       << "minimumFrequency:     " << minimumFrequency << endl
       << "maximumFrequency:     " << maximumFrequency << endl
       << "normalization:        " << normalization    << endl
       << "numberOfRows:         " << numberOfRows     << endl;
   for (int r = 0; r < numberOfRows; ++r) {
      out << "rows[" << r << "]:" << endl;
      rows[r].display(out);
   }
   out << "numberOfTiles:        " << numberOfTiles        << endl
       << "numberOfIndependents: " << numberOfIndependents << endl
       << "numberOfFlops:        " << numberOfFlops        << endl;
   return out;
}


//======================================  display a tiling structure.
//
//  Write the tiling summary to out, then each plane (via qplane::display),
//  followed by the counters and filter settings.  Returns out.
ostream& 
wtile::display(std::ostream& out) const {
   out << "id:                   " << _id               << endl
       << "duration:             " << _duration         << endl
       << "minimumQ:             " << _minimumQ         << endl
       << "maximumQ:             " << _maximumQ         << endl
       << "minimumFrequency:     " << _minimumFrequency << endl
       << "maximumFrequency:     " << _maximumFrequency << endl
       << "sampleFrequency:      " << _sampleFrequency  << endl
       << "maximumMismatch:      " << _maximumMismatch  << endl
       << "numberOfPlanes:       " << _numberOfPlanes   << endl;
   for (int p = 0; p < _numberOfPlanes; ++p) {
      out << "planes[" << p << "]:" << endl;
      _planes[p].display(out);
   }
   out << "numberOfTiles:        " << _numberOfTiles        << endl
       << "numberOfIndependents: " << _numberOfIndependents << endl
       << "numberOfFlops:        " << _numberOfFlops        << endl
       << "highPassCutoff:       " << _highPassCutoff       << endl
       << "lowPassCutoff:        " << _lowPassCutoff        << endl
       << "whiteningDuration:    " << _whiteningDuration    << endl
       << "transientDuration:    " << _transientDuration    << endl;
   return out;
}
