/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#include "Constants.h"

#include <algorithm>
#ifdef PLATFORM_LINUX
#include <fstream>
#endif
#include <iostream>
#include <math.h>
#include <stdio.h>

using namespace std;

#include "../../headers/CircBufPacket.h"
#include "../../headers/field.h"
#include "../../headers/random.h"
#include "Environment.h"
#include "Functions.h"
#include "LocalizationEngine.h"
#include "Primitives.h"

#ifdef PLATFORM_APERIOS
#include "../../headers/DogTypes.h"
//#include "../headers/CircularBuffer.h"
#include "../../headers/Config.h"
//extern CircularBuffer *OutputBuf;
#endif

// Declared here, presumably defined in another translation unit.  None of
// the functions visible in this file read it.
extern bool sample_from_sensors;

// When true, updateForSensors() averages each primitive's probability over
// all samples and prints the averages with pprintf().
static const bool dump_primitive_effects=false;

// Enables the cout diagnostics in updateForSensors().
#ifdef DEBUG_LOCALIZATION_LINUX
static const bool debug=true;
#else
static const bool debug=false;
#endif
// Only mentioned in the commented-out Aperios dump code in updateForSensors().
static const bool debug_aperios=true;
// On PLATFORM_LINUX builds: when true, the sample set is written to a new
// makeFileName() file at several points of each update.
static const bool dump_dist_files=false;
// Not referenced by the code visible in this file.
static const bool dump_marker_conf=false;

// Blend factors of the two running averages kept by updateForSensors():
// avg_prob_fast uses AMCLFastFrac, avg_prob_slow uses AMCLSlowFrac.
static const double AMCLFastFrac  = .2  ;
static const double AMCLSlowFrac  = .001;
// The fraction of samples replaced after a sensor update is
// bound(1 - AMCLResetMult*avg_prob_fast/avg_prob_slow, 0, 1).
static const double AMCLResetMult = 2.0;
static const int MaxSampleGen = 50; // maximum number of samples to generate in one frame

// Namespace-scope definition of the static data member; no initializer is
// given here, so its value must come from the class declaration.
const int LocaleSampled::numSamples;

// Returns the next name in the sequence "dist_data/d0000.dat",
// "dist_data/d0001.dat", ...  The returned pointer refers to a static
// buffer that is overwritten by the next call.
char *
makeFileName() {
  static int next_index=0;
  static char name[256];

  const int index = next_index;
  next_index = index + 1;

  sprintf(name,"dist_data/d%04d.dat",index);
  return name;
}

// Deviation values copied into near_old_eval.dev by the constructor.
// NOTE(review): the initializers are positional, so their meaning depends on
// the member order of Sample, which is declared elsewhere -- presumably the
// weight, location and angle members come first; confirm against the Sample
// declaration before changing any value.
static const Sample LENearOldValueDev = {
  1.0,
  Sample::vector2e(200.0,200.0),
  .5,
  Sample::vector2e(0.0,0.0),
  0.0
};

// Builds an engine bound to the given environment: installs the deviation
// table of the near-old-value evaluator, allocates the sample array and
// puts the bookkeeping into its reset() state.  The samples themselves are
// not given values here; setPos() or setPosUniform() do that.
LocalizationEngine::LocalizationEngine(Environment *environ_param) :
  environ(environ_param)
{
  near_old_eval.dev = LENearOldValueDev;

  // one slot per sample; the count is fixed for the life of the engine
  locale.samples = new Sample[locale.numSamples];

  reset();
}

// Initializes the sampler member.  This is separate from the constructor,
// which does not call sampler.init() itself.
void LocalizationEngine::init()
{
  sampler.init();
}

// Restores the engine's bookkeeping to its starting values.  The sample
// set itself is left untouched.
void LocalizationEngine::reset()
{
  // forget any cached position summary
  pos_cache_valid = false;
  pos_cache_cov_xy = 0.0;
  mzero(pos_cache_std_dev);
  mzero(pos_cache_mean);

  // no sensor-generated samples yet, and both running averages start at 1
  num_samples_generated_last_update = 0;
  avg_prob_slow = 1.0;
  avg_prob_fast = 1.0;
}

void
LocalizationEngine::setPos(double *pos, double *dev) {
  pos_cache_valid = false;

  for(int sample_idx=0; sample_idx<locale.numSamples; sample_idx++) {
    Sample *samp=&locale.samples[sample_idx];
    samp->weight = 1.0;
    samp->loc.x = pos[0]+dev[0]*GlobalRandom.gaussian32();
    samp->loc.y = pos[1]+dev[1]*GlobalRandom.gaussian32();
    samp->angle = pos[2]+dev[2]*GlobalRandom.gaussian32();
  }
  locale.totalWeight=locale.numSamples;
  locale.totalWeightValid=true;
  locale.samplesNormalized=true;
}

void
LocalizationEngine::setPosUniform() {
  pos_cache_valid = false;

  for(int sample_idx=0; sample_idx<locale.numSamples; sample_idx++) {
    Sample *sample=&locale.samples[sample_idx];
    sample->weight = 1.0;
    environ->genUniformSample(sample);
  }
  locale.totalWeight=locale.numSamples;
  locale.totalWeightValid=true;
  locale.samplesNormalized=true;
}

// Applies a movement primitive to every sample.  The set is resampled to
// unit weights first if it is not already normalized.  On Linux builds with
// dump_dist_files enabled, the set is dumped before and after that step.
void
LocalizationEngine::updateForMovement(MovementPrimitive *move_updater) {
  pos_cache_valid = false;

#ifdef PLATFORM_LINUX
  if(dump_dist_files) {
    ofstream before(makeFileName());
    dumpSamples(before);
    before.close();
  }
#endif

  if(!locale.samplesNormalized) {
    normalizeSamples();
  }

#ifdef PLATFORM_LINUX
  if(dump_dist_files) {
    ofstream after(makeFileName());
    dumpSamples(after);
    after.close();
  }
#endif

  // Only the counter reset is still active; the overrides of the movement
  // deviations that used to accompany it are disabled:
  //move_updater->data.vel_walk_data.x_dev             =50.0;
  //move_updater->data.vel_walk_data.y_dev             =50.0;
  //move_updater->data.vel_walk_data.heading_change_dev=.15;
  if(num_samples_generated_last_update > 0){
    num_samples_generated_last_update = 0;
  }

  // sampleUpdater is a pointer-to-member selected by the primitive itself
  Sample *const end = locale.samples + locale.numSamples;
  for(Sample *cur = locale.samples; cur < end; ++cur) {
    (move_updater->*(move_updater->sampleUpdater))(cur);
  }
}

void
LocalizationEngine::updateForSensors(int num_primitives, ProbEvaluator **primitives_p) {
  pos_cache_valid = false;

#ifdef PLATFORM_LINUX
  if(dump_dist_files) {ofstream os(makeFileName()); dumpSamples(os); os.close();}
#endif

  if(!locale.totalWeightValid)
    calcTotalSampleWeight();
  double old_total_weight=locale.totalWeight;

  if(num_primitives>0) {
    locale.samplesNormalized=false;
    locale.totalWeightValid=false;
    sortPrimitives(num_primitives,primitives_p);
  }

  double total_weight=0.0;
  static double effects[MaxNumPrimitives];
  if(dump_primitive_effects)
    for(int i=0; i<num_primitives; i++)
      effects[i] = 0.0;
  for(int samp_idx=0; samp_idx<locale.numSamples; samp_idx++) {
    Sample *sample=&locale.samples[samp_idx];
    
    total_weight += weightSampleForPrimitives(sample,
                                              num_primitives,primitives_p,
                                              effects);
  }
  if(dump_primitive_effects){
    for(int i=0; i<num_primitives; i++)
      effects[i] /= locale.numSamples;
    for(int i=0; i<num_primitives; i++){
      pprintf(TextOutputStream,"primitive[%d]=%g\n",i,effects[i]);
    }
  }

  if(num_primitives>0) {
    locale.totalWeight=total_weight;
    locale.totalWeightValid=true;
  }

#ifdef PLATFORM_LINUX
  if(dump_dist_files) {ofstream os(makeFileName()); dumpSamples(os); os.close();}
#endif

  //if(debug) {cout << "doing normalization" << endl; cout.flush();}
  double good_samples_prob;
  good_samples_prob=locale.totalWeight;
  good_samples_prob/=old_total_weight;
  if(debug) cout << "avg relative weight=" << good_samples_prob << endl;

  double good_expectation_corrected;
  double expectation=1.0;
  int num_sensor_samples;

  for(int i=0; i<num_primitives; i++){
    double exp_prob=1.0;
    primitives_p[i]->calc_exp_prob(&exp_prob,false);
    expectation *= exp_prob;
  }
  good_expectation_corrected = good_samples_prob / expectation;

  if(debug)
    cout << "good_expectation_corrected=" << good_expectation_corrected 
         << " expectation=" << expectation 
         << " good_samples_prob=" << good_samples_prob << endl;
  avg_prob_fast = avg_prob_fast * (1.0-AMCLFastFrac) + AMCLFastFrac * good_expectation_corrected;
  avg_prob_slow = avg_prob_slow * (1.0-AMCLSlowFrac) + AMCLSlowFrac * good_expectation_corrected;
  if(debug)
    cout << "fast_avg=" << avg_prob_fast << " slow_avg=" << avg_prob_slow << " samples to replace=" <<
      locale.numSamples * bound(1 - AMCLResetMult*avg_prob_fast/avg_prob_slow, 0.0, 1.0) << endl;
  num_sensor_samples = (int)(locale.numSamples * bound(1 - AMCLResetMult*avg_prob_fast/avg_prob_slow, 0.0, 1.0));
  num_sensor_samples = bound(num_sensor_samples,0,MaxSampleGen);

  num_samples_generated_last_update = num_sensor_samples;

#ifdef PLATFORM_APERIOS
  //if(config.spoutConfig.dumpLocalization) {  
  //  if(debug_aperios) {
  //    if(OutputBuf!=NULL) {
  //      if(num_sensor_samples > 0) {
  //        uchar buf[64];
  //        sprintf((char *)buf,"L:r%d\n\xFF",num_sensor_samples);
  //        OutputBuf->write(buf);
  //      }
  //    }
  //  }
  //}
#endif

  // sensor sampling
  //if(debug) {cout << "doing sensor sampling" << endl; cout.flush();}
  if(num_sensor_samples>0) {
    if(!locale.samplesNormalized)
      normalizeSamples();
  }

#ifdef PLATFORM_LINUX
  if(dump_dist_files) {ofstream os(makeFileName()); dumpSamples(os); os.close();}
#endif

  generateSamples(num_sensor_samples,
                  locale.numSamples,&locale.samples[0],
                  num_primitives,primitives_p);

  if(debug) {cout << "done updating" << endl; cout.flush();}
}

// Multiplies the sample's weight by the probability each primitive assigns
// to it and returns the resulting weight.
//   sample         - sample to reweight (modified in place)
//   num_primitives - number of entries in primitives_p
//   primitives_p   - evaluators to apply, in order
//   effects        - per-primitive accumulator; only touched when
//                    dump_primitive_effects is enabled
double
LocalizationEngine::weightSampleForPrimitives(Sample *sample,
                                              int num_primitives, ProbEvaluator **primitives_p,
                                              double *effects)
{
  for(int i=0; i<num_primitives; i++) {
    // each evaluator starts from 1.0 and writes its result into prob
    double prob=1.0;
    primitives_p[i]->calc_prob(&prob,sample,false);

    if(dump_primitive_effects)
      effects[i] += prob;

    sample->weight *= prob;
  }

  return sample->weight;
}

void
LocalizationEngine::sortPrimitives(int num_primitives, ProbEvaluator **primitives_p) {
  if(num_primitives<=0)
    return;

  // divide into nonambiguous and ambiguous
  ProbEvaluator **non_ambig_end;
  ProbEvaluator **cur_primitive_p;
  ProbEvaluator **end_primitive_p;

  end_primitive_p = primitives_p+num_primitives;
  non_ambig_end   = end_primitive_p;

  // simple bubble sort for nonambiguous
  for(cur_primitive_p=primitives_p; cur_primitive_p!=non_ambig_end-1; cur_primitive_p++) {
    for(ProbEvaluator **other_primitive_p=cur_primitive_p+1; other_primitive_p!=non_ambig_end; other_primitive_p++) {
      if((*other_primitive_p)->observationId < (*cur_primitive_p)->observationId)
        swap(*other_primitive_p,*cur_primitive_p);
    }
  }
}

// Replaces num_samples_to_generate samples with samples drawn by the
// sampler from the current sensor primitives.
//
//   num_samples_to_generate - how many samples to replace; clamped to
//                             num_samples
//   num_samples, samples    - the whole sample array
//   num_primitives,
//   primitives_p            - evaluators for this frame's observations
//
// Randomly chosen samples are first swapped into the front of the array.
// The sampler then overwrites that front section, and every regenerated
// sample gets weight 1.
void
LocalizationEngine::generateSamples(int num_samples_to_generate,
                                    int num_samples,Sample *samples,
                                    int num_primitives,ProbEvaluator **primitives_p)
{
  // nothing to do, or no array to work on (also avoids a modulo by zero)
  if(num_samples_to_generate <= 0 || num_samples <= 0)
    return;

  // never touch slots beyond the end of the array
  if(num_samples_to_generate > num_samples)
    num_samples_to_generate = num_samples;

  //pprintf(TextOutputStream,"generating samples from %d primitives\n",num_primitives);

  // randomize which samples get replaced
  for(int i=0; i<num_samples_to_generate; i++){
    int samp_to_replace;

    samp_to_replace = GlobalRandom.uint32() % num_samples;

    swap(samples[i],samples[samp_to_replace]);
  }

  // set up evaluators
  // max_evals is the total handed to the sampler, including the two fixed
  // evaluators appended below.  The loop used to reserve only one slot, so
  // five or more primitives produced seven entries.
  // NOTE(review): confirm that sampler.evals has room for max_evals entries.
  static const int max_evals=6;
  static const int num_fixed_evals=2;

  int num_evals=0;
  for(int primitive_idx=0;
      primitive_idx < num_primitives && num_evals < max_evals-num_fixed_evals;
      primitive_idx++){
    sampler.evals[num_evals++] = primitives_p[primitive_idx];
  }
  sampler.evals[num_evals++] = &on_field_eval;
  sampler.evals[num_evals++] = &near_old_eval;

  sampler.num_evals = num_evals;
  sampler.NumSamples = num_samples_to_generate;
  sampler.samples = samples;
  sampler.init_samples();
  //sampler.randomize_samples();
  sampler.step_samples(10);

  for(int i=0; i<num_samples_to_generate; i++)
    samples[i].weight = 1.0;
}

// Summarizes the sample set as a weighted mean pose, a per-component spread
// and the x/y covariance.  The result is cached until an update or a
// re-seed clears pos_cache_valid.
//
//   mean    - receives the weighted mean of x and y and the direction of
//             the weighted mean unit vector for the angle
//   std_dev - receives the weighted standard deviation of x and y; for the
//             angle, sqrt(1 - length of the mean unit vector)
//   cov_xy  - receives the weighted covariance of x and y
//
// If the samples carry no weight at all, the result is a zero mean with a
// spread of (halfLength, halfWidth, 2*pi) and zero covariance.
void
LocalizationEngine::getPosition(Sample *mean, Sample *std_dev, double *cov_xy) {
  // serve the cached summary while it is still valid
  if(pos_cache_valid){
    *cov_xy  = pos_cache_cov_xy ;
    *std_dev = pos_cache_std_dev;
    *mean    = pos_cache_mean   ;
    return;
  }

  // The output objects double as accumulators during the pass below:
  //   mean->loc     sum of w*x, w*y         std_dev->loc    sum of w*x^2, w*y^2
  //   mean->angle   sum of w*cos(angle)     std_dev->angle  sum of w*sin(angle)
  //   *cov_xy       sum of w*x*y
  *cov_xy = 0.0;
  std_dev->angle = 0.0;
  std_dev->loc.set(0.0,0.0);
  mean->angle = 0.0;
  mean->loc.set(0.0,0.0);

  double weight_sum=0.0;
  Sample *const end = locale.samples + locale.numSamples;
  for(Sample *cur = locale.samples; cur < end; ++cur) {
    double w=cur->weight;
    weight_sum+=w;

    mean   ->loc.x += w*   cur->loc.x;
    std_dev->loc.x += w*sq(cur->loc.x);
    mean   ->loc.y += w*   cur->loc.y;
    std_dev->loc.y += w*sq(cur->loc.y);

    mean   ->angle += w*cosf(cur->angle);
    std_dev->angle += w*sinf(cur->angle);

    *cov_xy += w * cur->loc.x * cur->loc.y;
  }

  if(weight_sum == 0.0){
    // no weight anywhere: report a zero pose with the widest spread
    mean->loc.set(0.0,0.0);
    mean->angle = 0.0;
    std_dev->loc.set(halfLength,halfWidth);
    std_dev->angle = 2*M_PI;
    *cov_xy = 0.0;
  }
  else{
    // the pass above also produced the total weight, so record it
    locale.totalWeight=weight_sum;
    locale.totalWeightValid=true;

    // x: mean, then variance as E[x^2]-E[x]^2, clamped against round-off
    mean->loc.x = mean->loc.x/weight_sum;
    std_dev->loc.x /= weight_sum;
    std_dev->loc.x -= mean->loc.x*mean->loc.x;
    if(std_dev->loc.x<0.0)
      std_dev->loc.x=0.0;
    std_dev->loc.x=sqrt(std_dev->loc.x);

    // y: same treatment
    mean->loc.y = mean->loc.y/weight_sum;
    std_dev->loc.y /= weight_sum;
    std_dev->loc.y -= mean->loc.y*mean->loc.y;
    if(std_dev->loc.y<0.0)
      std_dev->loc.y=0.0;
    std_dev->loc.y=sqrt(std_dev->loc.y);

    // angle: direction and length of the mean unit vector
    const double mean_cos = mean   ->angle/weight_sum;
    const double mean_sin = std_dev->angle/weight_sum;
    mean->angle = atan2(mean_sin,mean_cos);
    std_dev->angle = 1 - hypot(mean_cos,mean_sin);
    if(std_dev->angle < 0.0) // round off error protection
      std_dev->angle = 0.0;
    std_dev->angle = sqrt(std_dev->angle);

    // covariance: E[xy]-E[x]E[y], using the final means
    *cov_xy = *cov_xy/weight_sum - mean->loc.x * mean->loc.y;
  }

  // remember the summary for later calls
  pos_cache_mean    = *mean   ;
  pos_cache_std_dev = *std_dev;
  pos_cache_cov_xy  = *cov_xy ;
  pos_cache_valid = true;
}

double
LocalizationEngine::calcTotalSampleWeight() {
  if(locale.totalWeightValid)
    return locale.totalWeight;

  double total_weight=0.0;
  
  for(int sample_idx=0; sample_idx<locale.numSamples; sample_idx++) {
    Sample *sample=&locale.samples[sample_idx];
    
    double weight = sample->weight;
    total_weight += weight;
  }
  
  locale.totalWeight = total_weight;
  locale.totalWeightValid = true;

  return locale.totalWeight;
}

void
LocalizationEngine::normalizeSamples() {
  if(locale.samplesNormalized)
    return;

  // does not support dynamic resizing of number of samples
  static double *cum_weights=NULL;
  if(cum_weights==NULL) {
    cum_weights=new double[locale.numSamples+1];
  }
  
  static LocaleSampled *new_locale=NULL;
  if(new_locale==NULL) {
    new_locale=new LocaleSampled;
    new_locale->samples=new Sample[new_locale->numSamples];
  }
  
  double cum_weight=0.0;
  for(int i=0; i<locale.numSamples; i++) {
    cum_weights[i] = cum_weight;
    cum_weight += locale.samples[i].weight;
  }
  cum_weights[locale.numSamples]=cum_weight;
  
  Sample *new_locale_iter;
  Sample *end_locale_iter;
  end_locale_iter=(&new_locale->samples[new_locale->numSamples-1])+1;
  for(new_locale_iter = new_locale->samples;
      new_locale_iter < end_locale_iter;
      new_locale_iter++) {
    double cum_weight_to_find=GlobalRandom.uniform(0.0,cum_weight);
    
    // binary search for right cumulative weight
    int low_idx,high_idx,mid_idx;
    low_idx=0;
    high_idx=locale.numSamples-1;
    
    while(low_idx<high_idx-1) {
      mid_idx = (low_idx+high_idx) / 2;
      
      if(cum_weight_to_find < cum_weights[mid_idx]) {
        high_idx=mid_idx;
      }
      else {
        low_idx=mid_idx;
      }
    }
    
    int sampled_idx=low_idx;
    
    //new_locales[new_sample_idx].first=locales[sampled_idx].first;
    //new_locales[new_sample_idx].second=1.0;
    
    Sample *old_sample = &locale.samples[sampled_idx];
    Sample *new_sample = new_locale_iter;
    *new_sample = *old_sample;
    new_sample->weight = 1.0;
  }
  
  Sample *tmp=locale.samples;
  locale.samples=new_locale->samples;
  new_locale->samples=tmp;
  
  locale.samplesNormalized=true;
  locale.totalWeight=locale.numSamples;
  locale.totalWeightValid=true;
}

#ifdef PLATFORM_LINUX
// Writes one line per sample to os: weight, x, y and angle separated by
// single spaces, in scientific notation with a precision of 20.  The stream
// is flushed after every line.
void
LocalizationEngine::dumpSamples(ostream &os) const {
  os.setf(ios::scientific);
  os.precision(20);

  const Sample *const end = locale.samples + locale.numSamples;
  for(const Sample *cur = locale.samples; cur < end; ++cur) {
    os << cur->weight << " "
       << cur->loc.x  << " "
       << cur->loc.y  << " "
       << cur->angle;
    os << endl;
  }
}
#endif

// Copies the whole sample set, in order, into the caller's array.  The
// destination must have room for locale.numSamples entries.
void
LocalizationEngine::copySamples(Sample *samples) const {
  const Sample *const first = locale.samples;
  std::copy(first, first + locale.numSamples, samples);
}
