/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "../../headers/system_config.h"

#include "../../headers/CircBufPacket.h"
#include "../../headers/field.h"
#include "../../headers/Util.h"

// to get definition of Sample
#include "LocalizationEngine.h"
#include "Sampler.h"

// Gaussian kernel evaluated at an offset expressed in standard deviations.
// Defined in another translation unit. The energy functions in this file use
// sigmas^2/2, which presumably is -log of this kernel up to a constant.
// NOTE(review): confirm against the definition.
double eval_gaussian(double sigmas);

// Uncomment to have Sampler::hybrid_mcmc_update_sample print finite-difference
// energy derivatives next to the analytic ones (debugging aid).
//#define CHECK_DERIVATIVES

// Hard-coded test observations consumed by Sampler::init(). Each Num*Readings
// constant is how many leading entries of the matching arrays are turned into
// evaluators. 0 disables that evaluator kind, and the value must not exceed
// the array length. Coordinates are presumably field coordinates in mm
// (same frame as field.h). TODO confirm.

// Distance-to-marker observations -> PointDistProbEvaluator
// (ExpDists[i] is the observed distance to DistMarkerLocs[i]).
const int NumMarkerDistReadings = 0;
double ExpDists[2] = {1000.0,1850.0};
vector2d DistMarkerLocs[2] = {//vector2d(2100.0, 500.0),
                              vector2d(0.0, 1450.0),
                              vector2d(2100.0,-1450.0)};

// Distance-to-goal observations -> PointDistWithUniformProbEvaluator
// (GoalDistUnif is copied into each evaluator's dist_unif).
const int NumGoalDistReadings = 0;
double GoalExpDists[2] = {1350.0, 2500.0};
vector2d DistGoalLocs[2] = {vector2d(2100.0, 300.0),
                            vector2d(2100.0, 300.0)};
double GoalDistUnif = 450.0;

// Bearing-to-marker observations -> PointAngleProbEvaluator
// (ExpAngles[i] is the observed angle, presumably radians, to AngleMarkerLocs[i]).
const int NumMarkerAngleReadings = 0;
double ExpAngles[2] = {1.57, 1.57};
vector2d AngleMarkerLocs[2] = {vector2d(0.0, 1450.0),
                               vector2d(0.0,-1450.0)};
// Alternative test values, kept for experimentation:
//double ExpAngles[2] = {-0.4, 0.0};
//vector2d AngleMarkerLocs[2] = {vector2d(0.0, 1450.0),
//                               vector2d(2100.0, 300.0)};

// Line-point observations -> LinePointProbEvaluator. Every evaluator gets all
// four model line segments (LineStartLocs[j] -> LineEndLocs[j]) plus one
// observed point LineObsEgo[i]. The name suggests that point is in the
// robot's egocentric frame. TODO confirm.
// NumLineReadings must not exceed the length of LineObsEgo (3).
const int NumLineReadings = 0;
vector2d LineStartLocs[4] = {
  vector2d(-2100.0,-1350.0),
  vector2d(-2100.0, 1350.0),
  vector2d( 2100.0,-1350.0),
  vector2d(-2100.0,-1350.0)
};
vector2d LineEndLocs  [4] = {
  vector2d(-2100.0, 1350.0),
  vector2d( 2100.0, 1350.0),
  vector2d( 2100.0, 1350.0),
  vector2d( 2100.0,-1350.0)
};
vector2d LineObsEgo[3] = {vector2d(   0.0, 600.0),
                          vector2d( 300.0,   0.0),
                          vector2d( 300.0, 300.0)};

// Corner observations -> CornerProbEvaluator. Reading i is an observed
// distance CornerDistObs[i] and angle CornerAngleObs[i] to an unidentified
// corner. Each evaluator is given all ten candidate corners below.
const int NumCornerReadings=1;
double CornerDistObs[1] = {
  300.0
};
double CornerAngleObs[1] = {
  0.0
};
// Candidate corner positions built from field.h dimensions, mirrored in y:
// penalty-region corners at +/-halfLength and +/-penaltyRegionLengthOffset,
// and points on x = 0 at +/-centerCircleRadius.
vector2d CornerLocs[10] = {
  vector2d(                halfLength, penaltyRegionHalfWidth),
  vector2d( penaltyRegionLengthOffset, penaltyRegionHalfWidth),
  vector2d(                       0.0,     centerCircleRadius),
  vector2d(-penaltyRegionLengthOffset, penaltyRegionHalfWidth),
  vector2d(               -halfLength, penaltyRegionHalfWidth),

  vector2d(                halfLength,-penaltyRegionHalfWidth),
  vector2d( penaltyRegionLengthOffset,-penaltyRegionHalfWidth),
  vector2d(                       0.0,    -centerCircleRadius),
  vector2d(-penaltyRegionLengthOffset,-penaltyRegionHalfWidth),
  vector2d(               -halfLength,-penaltyRegionHalfWidth)
};

// Adds an OnFieldProbEvaluator (penalises poses outside the field rectangle).
bool UseOnFieldProbEvaluator=true;

// Adds a NearOldValueProbEvaluator (penalises moving away from the pose
// recorded by Sampler::init_samples()).
bool UseNearOldValueProbEvaluator=true;
// Per-axis standard deviations for NearOldValueProbEvaluator. Only dev.loc.x,
// dev.loc.y and dev.angle are read by that evaluator, presumably the
// (200,200) and .5 entries here. The meaning of the other members depends on
// Sample's layout. TODO confirm.
static const Sample NearOldValueDev = {
  1.0,
  Sample::vector2e(200.0,200.0),
  .5,
  Sample::vector2e(0.0,0.0),
  0.0
};

// Debug tracing switches.
static const bool PrintOuts = false;
static const bool PrintAcceptRate = false;

// Number of leapfrog position updates per Hybrid MCMC proposal.
static const int num_steps=5;
// Initial leapfrog step size at the start of each Sampler::step_samples call.
static const double NormalTimeStepMagnitude=29.0;
// Step-size multiplier applied when the acceptance rate drops to 0. It is
// interpolated linearly up to 1.0 at MinOKAcceptFrac.
static const double MaxTimeStepDownAdjustFrac= .5;
// Counterpart for increasing the step size. Only referenced by the disabled
// (commented-out) branch in Sampler::step_samples.
static const double MaxTimeStepUpAdjustFrac  =2.0;
// If the acceptance rate falls below this, the step size is reduced.
static const double MinOKAcceptFrac=.90;
// If accept rate goes over this, the time step would be increased. NOTE: that
// branch is currently commented out in Sampler::step_samples, so this is unused.
static const double MaxOKAcceptFrac=.98; // must be less than 1.0

void OnFieldProbEvaluator::calc_prob(double *prob,const Sample *samp,bool scaled)
{
  // Probability that the sample is on the field. It is eval_gaussian() of the
  // summed distance outside the |x| <= max_field_x, |y| <= max_field_y
  // rectangle, in units of stddev_dist. `scaled` is not used.
  //
  // The extents must match calc_energy()/calc_energy_der(). Those use 2100.0
  // for x, while this function previously used 2400.0, so the probability and
  // the energy described different distributions.
  static const double max_field_x = 2100.0;
  static const double max_field_y = 1350.0;
  static const double stddev_dist = 100.0;

  double x_off,y_off;

  // Distance beyond the boundary along each axis (0 when inside).
  x_off = fabs(samp->loc.x) - max_field_x;
  if(x_off < 0.0)
    x_off = 0.0;

  y_off = fabs(samp->loc.y) - max_field_y;
  if(y_off < 0.0)
    y_off = 0.0;

  double sigmas;
  sigmas = x_off/stddev_dist + y_off/stddev_dist;

  *prob = eval_gaussian(sigmas);
}

void OnFieldProbEvaluator::calc_exp_prob(double *exp_prob,bool scaled)
{
  // Placeholder expected probability. Callers do not actually use this value,
  // so a constant is reported and `scaled` is ignored.
  const double unused_expectation = 1.0;
  exp_prob[0] = unused_expectation;
}

void OnFieldProbEvaluator::calc_energy(double *energy,const Sample *samp)
{
  // Off-field penalty. Zero inside the |x| <= 2100, |y| <= 1350 rectangle.
  // Outside it, it is s^2/2, where s is the summed per-axis distance past the
  // boundary divided by the falloff scale.
  static const double half_length = 2100.0;
  static const double half_width  = 1350.0;
  static const double falloff     = 100.0;

  double out_x = fabs(samp->loc.x) - half_length;
  double out_y = fabs(samp->loc.y) - half_width;

  // Points inside the boundary contribute nothing along that axis.
  out_x = (out_x < 0.0) ? 0.0 : out_x;
  out_y = (out_y < 0.0) ? 0.0 : out_y;

  const double s = out_x/falloff + out_y/falloff;
  *energy = (s*s)/2.0;
}

void OnFieldProbEvaluator::calc_energy_der(Sample *der,const Sample *samp)
{
  // Gradient of calc_energy() with respect to loc.x, loc.y and angle.
  //
  // With off_v = max(|v| - max_v, 0) and s = off_x/sd + off_y/sd, the energy
  // is E = s^2/2, so dE/dv = s * d(off_v)/dv / sd, where
  //   d(off_v)/dv = sign(v)  if |v| > max_v,
  //                 0        otherwise.
  // The derivative of |v| follows the sign of v itself. The clamped offset is
  // always >= 0, so using sign(off_v) made the gradient point the wrong way
  // for v < -max_v. It also gave a spurious x-gradient when only y was
  // off-field, if sign(0) != 0.
  static const double max_field_x = 2100.0;
  static const double max_field_y = 1350.0;
  static const double stddev_dist = 100.0;

  double x_off = fabs(samp->loc.x) - max_field_x;
  if(x_off < 0.0)
    x_off = 0.0;

  double y_off = fabs(samp->loc.y) - max_field_y;
  if(y_off < 0.0)
    y_off = 0.0;

  const double sigmas = x_off/stddev_dist + y_off/stddev_dist;

  // d(off_v)/dv: only nonzero outside the boundary, signed like v.
  const double dx_off = (x_off > 0.0) ? (samp->loc.x < 0.0 ? -1.0 : 1.0) : 0.0;
  const double dy_off = (y_off > 0.0) ? (samp->loc.y < 0.0 ? -1.0 : 1.0) : 0.0;

  der->loc.x = sigmas * dx_off / stddev_dist;
  der->loc.y = sigmas * dy_off / stddev_dist;
  der->angle = 0.0;   // the field-boundary penalty does not depend on heading
}

void NearOldValueProbEvaluator::calc_prob(double *prob,const Sample *samp,bool scaled)
{
  // Product of independent Gaussian factors on the displacement of x, y and
  // heading from the sample's recorded original pose. Each displacement is
  // divided by the matching entry of `dev`. The heading is first divided by
  // Sample::AngleScale and wrapped with norm_angle(). `scaled` is ignored.
  const double dx = samp->loc.x - samp->orig_loc.x;
  const double dy = samp->loc.y - samp->orig_loc.y;
  const double da = norm_angle(samp->angle/Sample::AngleScale - samp->orig_angle);

  const double p_x = eval_gaussian(dx / dev.loc.x);
  const double p_y = eval_gaussian(dy / dev.loc.y);
  const double p_a = eval_gaussian(da / dev.angle);

  *prob = p_x * p_y * p_a;
}

void NearOldValueProbEvaluator::calc_exp_prob(double *exp_prob,bool scaled)
{
  // Placeholder expected probability. Callers do not actually use this value,
  // so a constant is reported and `scaled` is ignored.
  const double unused_expectation = 1.0;
  exp_prob[0] = unused_expectation;
}

void NearOldValueProbEvaluator::calc_energy(double *energy,const Sample *samp)
{
  // Energy = sum over x, y and heading of (displacement / dev)^2 / 2, where
  // displacement is measured from the recorded original pose. The heading is
  // un-scaled by Sample::AngleScale and wrapped before comparison.
  const double dx = samp->loc.x - samp->orig_loc.x;
  const double dy = samp->loc.y - samp->orig_loc.y;
  const double da = norm_angle(samp->angle/Sample::AngleScale - samp->orig_angle);

  // Standardized displacements.
  const double z_x = dx / dev.loc.x;
  const double z_y = dy / dev.loc.y;
  const double z_a = da / dev.angle;

  *energy = sq(z_x)/2.0 + sq(z_y)/2.0 + sq(z_a)/2.0;
}

void NearOldValueProbEvaluator::calc_energy_der(Sample *der,const Sample *samp)
{
  // Gradient of calc_energy(): d/dv of (dv/dev_v)^2/2 is (dv/dev_v)/dev_v.
  // samp->angle is stored multiplied by Sample::AngleScale, so the heading
  // term picks up an extra 1/AngleScale from the chain rule.
  const double dx = samp->loc.x - samp->orig_loc.x;
  const double dy = samp->loc.y - samp->orig_loc.y;
  const double da = norm_angle(samp->angle/Sample::AngleScale - samp->orig_angle);

  der->loc.x = (dx / dev.loc.x) / dev.loc.x;
  der->loc.y = (dy / dev.loc.y) / dev.loc.y;
  der->angle = (da / dev.angle) / (dev.angle*Sample::AngleScale);
}

void Sampler::init()
{
  samples = new Sample[NumSamples];
  for(int samp_idx=0; samp_idx<NumSamples; samp_idx++)
    samples[samp_idx].init();

  random.seed(1);
  sample_steps = 0;
  AcceptCnt=0;
  TotalCnt =0;

  num_evals = 0;

  for(int i=0; i<NumMarkerDistReadings; i++){
    PointDistProbEvaluator *pt_dist_eval;
    pt_dist_eval = new PointDistProbEvaluator;
    pt_dist_eval->loc           = DistMarkerLocs[i];
    pt_dist_eval->dist_obs      = ExpDists[i];
    pt_dist_eval->dist_dev_frac = .1;

    evals[num_evals++] = pt_dist_eval;
  }

  for(int i=0; i<NumGoalDistReadings; i++){
    PointDistWithUniformProbEvaluator *pt_dist_with_uniform_eval;
    pt_dist_with_uniform_eval = new PointDistWithUniformProbEvaluator;
    pt_dist_with_uniform_eval->loc           = DistGoalLocs[i];
    pt_dist_with_uniform_eval->dist_obs      = GoalExpDists[i];
    pt_dist_with_uniform_eval->dist_unif     = GoalDistUnif;
    pt_dist_with_uniform_eval->dist_dev_frac = .1;

    evals[num_evals++] = pt_dist_with_uniform_eval;
  }

  for(int i=0; i<NumMarkerAngleReadings; i++){
    PointAngleProbEvaluator *pt_angle_eval;
    pt_angle_eval = new PointAngleProbEvaluator;
    pt_angle_eval->loc       = AngleMarkerLocs[i];
    pt_angle_eval->angle_obs = ExpAngles[i];
    pt_angle_eval->angle_dev = .1;

    evals[num_evals++] = pt_angle_eval;
  }

  for(int i=0; i<NumLineReadings; i++){
    LinePointProbEvaluator *line_eval;
    line_eval = new LinePointProbEvaluator;
    line_eval->num_lines = 4;
    for(int j=0; j<4; j++){
      line_eval->p0[j]      = LineStartLocs[j];
      line_eval->p1[j]      = LineEndLocs  [j];
    }
    line_eval->obs_ego = LineObsEgo[i];
    line_eval->init_internals();

    evals[num_evals++] = line_eval;
  }

  for(int i=0; i<NumCornerReadings; i++){
    CornerProbEvaluator *corner_eval;
    corner_eval = new CornerProbEvaluator;

    corner_eval->dist_obs = CornerDistObs[i];
    corner_eval->dist_dev_frac = .1;
    corner_eval->angle_obs = CornerAngleObs[i];
    corner_eval->angle_dev = .1;

    corner_eval->num_corners = 10;
    corner_eval->corners = CornerLocs;

    evals[num_evals++] = corner_eval;
  }

  if(UseOnFieldProbEvaluator){
    OnFieldProbEvaluator *on_field;
    on_field = new OnFieldProbEvaluator;

    evals[num_evals++] = on_field;
  }

  if(UseNearOldValueProbEvaluator){
    NearOldValueProbEvaluator *near_old;
    near_old = new NearOldValueProbEvaluator;
    near_old->dev = NearOldValueDev;

    evals[num_evals++] = near_old;
  }
}

void Sampler::init_samples()
{
  for(int samp_idx=0; samp_idx<NumSamples; samp_idx++){
    Sample *samp;

    samp = &samples[samp_idx];
    samp->orig_loc   = samp->loc;
    samp->orig_angle = samp->angle;
  }
}

void Sampler::randomize_samples()
{
  // Scatter every sample uniformly over the field rectangle with a uniform
  // heading. This assumes random.real32() returns values in [0,1), giving
  // x in [-2100,2100) and y in [-1350,1350). TODO confirm the real32 range.
  // Also resets the step counter.
  for(int samp_idx=0; samp_idx<NumSamples; samp_idx++){
    Sample *samp = &samples[samp_idx];

    // Draw heading, x, then y in separate statements. Function-argument
    // evaluation order is unspecified in C++, so passing two random.real32()
    // calls directly to loc.set() would make which draw becomes x vs. y
    // compiler-dependent, breaking reproducibility of the seeded stream.
    double angle = norm_angle(random.real32()*2*M_PI);
    double x = random.real32()*4200 - 2100;
    double y = random.real32()*2700 - 1350;

    samp->loc.set(x,y);
    samp->angle = angle;
  }
  sample_steps = 0;
}

void Sampler::eval_sample_prob(double *prob,const Sample *samp)
{
  // Joint likelihood of a sample: the product of every evaluator's
  // probability, taken in evals[] order with scaled == true.
  // An empty evaluator list yields 1.0.
  double product = 1.0;
  for(int k=0; k<num_evals; k++){
    double factor;
    evals[k]->calc_prob(&factor,samp,true);
    product *= factor;
  }
  *prob = product;
}

void Sampler::eval_sample_energy(double *energy,const Sample *samp)
{
  // Total potential energy of a sample: the sum of every evaluator's energy,
  // taken in evals[] order. An empty evaluator list yields 0.0.
  double total = 0.0;
  for(int k=0; k<num_evals; k++){
    double term;
    evals[k]->calc_energy(&term,samp);
    total += term;
  }
  *energy = total;
}

void Sampler::eval_sample_energy_der(Sample *der,const Sample *samp)
{
  // Gradient of the total energy. It starts from zero and accumulates each
  // evaluator's gradient into *der with Sample::add_der.
  der->set_der_zero();
  for(int k=0; k<num_evals; k++){
    Sample partial;
    evals[k]->calc_energy_der(&partial,samp);
    der->add_der(&partial);
  }
}

void Sampler::sim_sample(Sample *samp,Sample *momentum,double time_step_magnitude,bool forward)
{
  static const bool print_outs=true && PrintOuts;

  // Signed integration step; simulating backward in time just negates it.
  const double dt = forward ? time_step_magnitude : -time_step_magnitude;

  Sample step;   // increment applied to the momentum or the position
  Sample grad;   // energy gradient at the current position

  if(print_outs){
    pprintf(TextOutputStream,"loc=(%g,%g) angle=(%g)\n",V2COMP(samp->loc),samp->angle);
    pprintf(TextOutputStream,"momentum=(%g,%g) (%g)\n",V2COMP(momentum->loc),momentum->angle);
  }

  // Leapfrog integration:
  //   half kick, then num_steps drifts separated by full kicks, then half kick.
  // (At least one drift is always performed.)

  // Opening half kick: p -= (dt/2) * dE/dq
  eval_sample_energy_der(&grad,samp);
  step.copy_der(&grad);
  step.scale_der(-dt/2.0);
  momentum->add_der(&step);
  if(print_outs){
    pprintf(TextOutputStream,"hs mom der=(%10.4g,%10.4g) (%10.4g) delta=(%10.4g,%10.4g) (%10.4g)\n",
            V2COMP(grad.loc),grad.angle,V2COMP(step.loc),step.angle);
    pprintf(TextOutputStream,"momentum=(%g,%g) (%g)\n",V2COMP(momentum->loc),momentum->angle);
  }

  int drifts_done = 0;
  for(;;){
    // Drift: q += dt * p
    step.copy_der(momentum);
    step.scale_der(dt);
    samp->add_der_to_sample(&step);
    if(print_outs)
      pprintf(TextOutputStream,"fs loc delta=(%g,%g) (%g) loc=(%g,%g) (%g)\n",
              V2COMP(step.loc),step.angle,V2COMP(samp->loc),samp->angle);

    if(++drifts_done >= num_steps)
      break;

    // Full kick between drifts: p -= dt * dE/dq
    eval_sample_energy_der(&grad,samp);
    step.copy_der(&grad);
    step.scale_der(-dt);
    momentum->add_der(&step);
    if(print_outs){
      pprintf(TextOutputStream,"fs mom der=(%10.4g,%10.4g) (%10.4g) delta=(%10.4g,%10.4g) (%10.4g)\n",
              V2COMP(grad.loc),grad.angle,V2COMP(step.loc),step.angle);
      pprintf(TextOutputStream,"momentum=(%g,%g) (%g)\n",V2COMP(momentum->loc),momentum->angle);
    }
  }

  // Closing half kick: p -= (dt/2) * dE/dq
  eval_sample_energy_der(&grad,samp);
  step.copy_der(&grad);
  step.scale_der(-dt/2.0);
  momentum->add_der(&step);
  if(print_outs){
    pprintf(TextOutputStream,"hs mom der=(%10.4g,%10.4g) (%10.4g) delta=(%10.4g,%10.4g) (%10.4g)\n",
            V2COMP(grad.loc),grad.angle,V2COMP(step.loc),step.angle);
    pprintf(TextOutputStream,"momentum=(%g,%g) (%g)\n",V2COMP(momentum->loc),momentum->angle);
  }
}

void Sampler::hybrid_mcmc_update_sample(Sample *samp,double time_step_magnitude)
{
  static const bool print_outs=true && PrintOuts;

  // Keep the state as it was before normalize_for_mcmc() so a rejected
  // proposal can be restored exactly.
  const Sample start = *samp;

  samp->normalize_for_mcmc();

  // Potential energy at the starting position.
  double old_energy=0.0;
  eval_sample_energy(&old_energy,samp);
  if(print_outs)
    pprintf(TextOutputStream,"loc=(%g,%g) (%g) energy=%g\n",V2COMP(samp->loc),samp->angle,old_energy);

#ifdef CHECK_DERIVATIVES
  // Debug: compare finite-difference energy derivatives with the analytic ones.
  Sample delta_samp;
  double delta_energy=0.0,delta_energy2=0.0;
  static const double epsilon = 1e-3;
  delta_samp = *samp;
  delta_samp.loc.x += epsilon;
  eval_sample_energy(&delta_energy,&delta_samp);
  delta_samp = *samp;
  delta_samp.loc.x -= epsilon;
  eval_sample_energy(&delta_energy2,&delta_samp);
  pprintf(TextOutputStream,"delta_energy=%g delta_energy2=%g old_energy=%g diff=%g epsilon=%g\n",delta_energy,delta_energy2,old_energy,delta_energy-delta_energy2,epsilon);
  pprintf(TextOutputStream,"dE/dloc.x = %g\n",(delta_energy - old_energy) / epsilon);
  pprintf(TextOutputStream,"dE/dloc.x = %g\n",(delta_energy - delta_energy2) / (2*epsilon));
  delta_samp = *samp;
  delta_samp.loc.y += epsilon;
  eval_sample_energy(&delta_energy,&delta_samp);
  pprintf(TextOutputStream,"dE/dloc.y = %g\n",(delta_energy - old_energy) / epsilon);
  delta_samp = *samp;
  delta_samp.angle += epsilon;
  eval_sample_energy(&delta_energy,&delta_samp);
  pprintf(TextOutputStream,"dE/dangle = %g\n",(delta_energy - old_energy) / epsilon);

  Sample der;
  eval_sample_energy_der(&der,samp);
  pprintf(TextOutputStream,"dE/dloc.x = %g, dE/dloc.y = %g, dE/dangle = %g\n",der.loc.x,der.loc.y,der.angle);
#endif

  // Draw a fresh momentum; total energy = potential + |p|^2 / 2.
  Sample momentum;
  momentum.random_momentum(random);
  old_energy += momentum.sqlength()/2.0;
  if(print_outs)
    pprintf(TextOutputStream,"momentum=(%g,%g) (%g) energy=%g\n",V2COMP(momentum.loc),momentum.angle,old_energy);

  // Integrate forward or backward in time with equal probability.
  const bool forward = (random.real32() < .5);
  if(print_outs)
    pprintf(TextOutputStream,"forward=%d\n",forward);
  sim_sample(samp,&momentum,time_step_magnitude,forward);

  // Total energy of the proposal.
  double new_energy=0.0;
  eval_sample_energy(&new_energy,samp);
  if(print_outs)
    pprintf(TextOutputStream,"loc=(%g,%g) (%g) energy=%g\n",V2COMP(samp->loc),samp->angle,new_energy);
  new_energy += momentum.sqlength()/2.0;
  if(print_outs)
    pprintf(TextOutputStream,"momentum=(%g,%g) (%g) energy=%g\n",V2COMP(momentum.loc),momentum.angle,new_energy);

  // Metropolis test on the change in total energy; update acceptance stats.
  const double accept_val = exp(-new_energy + old_energy);
  const bool accept = (random.real32() < accept_val);
  TotalCnt++;
  if(accept) AcceptCnt++;
  if(print_outs)
    pprintf(TextOutputStream,"###accept_val=%g accept=%d\n",accept_val,accept);

  if(accept)
    samp->normalize();   // leave the MCMC-scaled representation
  else
    *samp = start;       // rejected: restore the untouched original
}

void Sampler::step_sample(Sample *samp,double time_step_magnitude)
{
  // Advance one sample by one MCMC update. Hybrid (Hamiltonian) MCMC is the
  // active update rule.
  //
  // A former "stochastic dynamics" alternative was removed: it called
  // sim_sample() with a vector2d momentum and no time step, so it no longer
  // matched the current signature and could not be re-enabled.
  hybrid_mcmc_update_sample(samp,time_step_magnitude);

#if 0
  // Alternative: plain Metropolis random-walk update of the location only
  // (heading unchanged), using eval_sample_prob(). Not compiled; kept for
  // experimentation. Note it does not update AcceptCnt/TotalCnt.
  Sample old_samp;
  double old_prob=0.0;

  old_samp = *samp;
  eval_sample_prob(&old_prob,samp);

  vector2d delta;
  delta.set(random.gaussian32()*200.0,random.gaussian32()*200.0);
  samp->loc += delta;

  double prob=0.0;
  eval_sample_prob(&prob,samp);

  bool accept;
  accept = (prob > old_prob) || (random.real32() <= prob/old_prob);
  if(!accept){
    *samp = old_samp;
  }
#endif
}

void Sampler::step_samples(int step_cnt)
{
  // Run step_cnt sweeps, each applying step_sample() to every sample.
  // After each sweep the acceptance rate is measured, and the leapfrog step
  // size is shrunk if the rate fell below MinOKAcceptFrac. The step size
  // restarts at NormalTimeStepMagnitude on every call.
  if(PrintOuts)
    pprintf(TextOutputStream,"using %d evals to step %d samples for %d steps\n",
            num_evals,NumSamples,step_cnt);

  double time_step = NormalTimeStepMagnitude;
  for(int step_num=0; step_num<step_cnt; step_num++){
    AcceptCnt = 0;
    TotalCnt  = 0;

    if(PrintAcceptRate)
      pprintf(TextOutputStream,"stepping samples with step size %g\n",time_step);

    for(int samp_idx=0; samp_idx<NumSamples; samp_idx++){
      Sample *samp;

      samp = &samples[samp_idx];
      step_sample(samp,time_step);
    }

    // No proposals were counted (e.g. NumSamples == 0, or an update rule
    // that does not touch the counters). The acceptance rate is undefined,
    // so leave the step size alone instead of computing 0/0.
    if(TotalCnt == 0)
      continue;

    double accept_frac;
    accept_frac = ((double)AcceptCnt)/TotalCnt;
    if(accept_frac < MinOKAcceptFrac){
      // Shrink by a factor interpolated linearly from 1.0 (at MinOKAcceptFrac)
      // down to MaxTimeStepDownAdjustFrac (at an acceptance rate of 0).
      time_step *=
        MaxTimeStepDownAdjustFrac * (MinOKAcceptFrac - accept_frac)/MinOKAcceptFrac + 
        accept_frac / MinOKAcceptFrac;
      if(PrintAcceptRate)
        pprintf(TextOutputStream,"adjusted step size to %g\n",time_step);
    }
    // Disabled: grow the step size when the acceptance rate is very high.
    //else if(accept_frac > MaxOKAcceptFrac){
    //  time_step *=
    //    MaxTimeStepUpAdjustFrac * (accept_frac - MaxOKAcceptFrac)/(1.0 - MaxOKAcceptFrac) + 
    //    (1.0 - accept_frac) / (1.0 - MaxOKAcceptFrac);
    //}
    
    if(PrintAcceptRate)
      pprintf(TextOutputStream,"accepted %d/%d samples (%10.4g%%)\n",AcceptCnt,TotalCnt,
              100.0*accept_frac);
  }

  sample_steps += step_cnt;
}
