/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#include <assert.h>
#include <math.h>
#include <stdio.h>

#include <vector>

#include <gsl/gsl_multifit.h>

#include "../../agent/headers/Geometry.h"
#include "../../agent/Motion/Motion.h"
#include "../../agent/Motion/MotionInterface.h"

using std::vector;

// Largest commanded turn rate (da) assumed by the regression weighting in
// calc_regression_params().
// FIXME: set from walk parameters
double MaxDA(2.2);

// Duration (in the same time unit as the logged velocities) of each recorded
// calibration movement; passed to calc_delta() as `time`.
static const double MovementTime(7.0);

// One raw calibration sample as parsed by load_data(): the commanded walk
// velocities plus the observed pose before and after the movement.
struct MotCalData {
  double dx;              // commanded forward velocity
  double dy;              // commanded sideways velocity
  double da;              // commanded turn rate
  vector2d pre_loc;       // observed position before the movement
  double pre_angle;       // observed heading before the movement
  vector2d post_loc;      // observed position after the movement
  double post_angle;      // observed heading after the movement
  double neck_offset_x;   // x offset removed (after rotation by heading) in calc_examples()
};

// A processed training example: the commanded velocities ("ask"), the
// velocities actually achieved as estimated by calc_delta() ("real"), and the
// neck-corrected body poses the estimate was computed from.
struct MotCalExample {
  double ask_dx;
  double ask_dy;
  double ask_da;
  double real_dx;
  double real_dy;
  double real_da;
  vector2d pre_loc_body;
  double pre_angle;
  vector2d post_loc_body;
  double post_angle;
};

// Fitted regression coefficients, one array of NumBases entries per output.
// The arrays are allocated with new[] by calc_regression_params(); nothing in
// this type frees them.
struct MotCal {
  double *dx_coeffs;   // coefficients predicting real dx
  double *dy_coeffs;   // coefficients predicting real dy
  double *da_coeffs;   // coefficients predicting real da
};

// Number of regression features produced by get_basis() (indices 0..5).
static const int NumBases(6);

/* Returns feature number `basis_idx` of the regression design row for `ex`,
   computed from the commanded velocities only:
     0: positive part of ask_dx (0 when ask_dx < 0)
     1: negative part of ask_dx (0 when ask_dx > 0)
     2: ask_dy
     3: ask_da
     4: ask_dx * ask_dy
     5 and any other index: constant 1.0 (bias term) */
double get_basis(const MotCalExample *ex,int basis_idx)
{
  switch(basis_idx){
    case 0: return (ex->ask_dx >= 0.0) ? ex->ask_dx : 0.0;
    case 1: return (ex->ask_dx <= 0.0) ? ex->ask_dx : 0.0;
    case 2: return ex->ask_dy;
    case 3: return ex->ask_da;
    case 4: return ex->ask_dx * ex->ask_dy;
    default: break;
  }
  // bias term
  return 1.0;
}

/* Parses calibration samples from `data_file` and appends every fully
   matched record (all 10 fields) to `mot_cal_data`.

   Record format (whitespace-insensitive, as accepted by fscanf):
     MT: motion: dx=<f> dy=<f> da=<f> pre (<f>,<f>) <<f> post (<f>,<f>) <<f> neck <f>

   Lines that do not start a record are reported and skipped. Partially
   matched records are reported and dropped. Reading stops at end of file or
   on a stream error; only a real stream error is reported as an error. */
void load_data(vector<MotCalData> &mot_cal_data,FILE *data_file)
{
  MotCalData mot_cal_data_item;

  mot_cal_data.reserve(100);

  for(;;){
    int item_cnt = fscanf(data_file," MT: motion: dx=%lf dy=%lf da=%lf pre (%lf,%lf) <%lf post (%lf,%lf) <%lf neck %lf",
                          &mot_cal_data_item.dx,&mot_cal_data_item.dy,&mot_cal_data_item.da,
                          &mot_cal_data_item.pre_loc.x,&mot_cal_data_item.pre_loc.y,
                          &mot_cal_data_item.pre_angle,
                          &mot_cal_data_item.post_loc.x,&mot_cal_data_item.post_loc.y,
                          &mot_cal_data_item.post_angle,
                          &mot_cal_data_item.neck_offset_x);
    if(item_cnt == EOF){
      // EOF is returned both at normal end of input and on a read error.
      // Stop in either case; a persistent error would otherwise make
      // fscanf return EOF forever while feof() stays false.
      if(ferror(data_file))
        printf("error reading data\n");
      break;
    }

    if(item_cnt == 0){
      printf("did not match line\n");
      // Skip the rest of the offending line so the next fscanf makes progress.
      char buf[1024];
      if(fgets(buf,sizeof(buf),data_file) == NULL)
        break;
      printf("line was '%s'\n",buf);
    }else if(item_cnt == 10){
      mot_cal_data.push_back(mot_cal_data_item);
    }else{
      // The next fscanf call fails on the literal "MT:" and then falls into
      // the skip-line branch above, so no separate resync is needed here.
      printf("partial match\n");
    }
  }
}

/* Estimates the walk velocities actually achieved during one calibration
   movement, from the body pose observed before and after it.

   The movement is modelled as travel along a circular arc at constant speed
   and constant turn rate for `time` units.

   Outputs:
     da     - (post_angle - pre_angle), unwrapped to the revolution count closest
              to time*exp_da, divided by time.
     dx, dy - arc speed (arc length / time) expressed as a vector along the arc
              tangent at the start point, rotated into the frame of the start
              heading h_1.

   Inputs:
     exp_da - the commanded turn rate. It is used only to choose how many whole
              revolutions the heading change contains.
     exp_dx, exp_dy - currently unused; they appear only in the commented-out
              projection near the end.

   NOTE(review): when pre_loc_body == post_loc_body (pure rotation), p_delta is
   zero and p_delta.perp().norm() is degenerate. What that yields depends on
   vector2d::norm() in Geometry.h; verify before trusting results for
   turn-in-place samples.

   Also prints several debug lines to stdout. */
void calc_delta(double &dx,double &dy,double &da,
                double exp_dx,double exp_dy,double exp_da,
                double time,
                vector2d pre_loc_body,double pre_angle,vector2d post_loc_body,double post_angle)
{
  double angle_diff;
  double exp_angle_diff;

  angle_diff = post_angle - pre_angle;
  // calculate most likely number of extra rotations
  // (shift angle_diff by whole turns until it lies within +/-pi of the
  //  heading change the commanded turn rate would have produced)
  exp_angle_diff = time * exp_da;
  while(angle_diff - exp_angle_diff > M_PI)
    angle_diff -= 2*M_PI;
  while(angle_diff - exp_angle_diff < -M_PI)
    angle_diff += 2*M_PI;
  printf("angle diff=%g revs=%g\n",angle_diff,angle_diff/(2*M_PI));
  // calculate da implied by angle_diff
  da = angle_diff / time;

  vector2d &p1=pre_loc_body;
  vector2d &p2=post_loc_body;
  vector2d p_m; // midpoint between start and end points (only used by the commented-out code)
  vector2d p_delta; // delta from pre to post loc
  vector2d h_1,h_2; // unit vectors in direction of pre/post angles
  vector2d p_c; // the center of the circle traced out

  p_m = (p1 + p2)/2.0;
  p_delta = p2 - p1;
  h_1.set(cos(pre_angle), sin(pre_angle ));
  h_2.set(cos(post_angle),sin(post_angle));

  // Earlier closed-form attempt at the circle center, kept for reference.
  //p_c.y = (p_delta.x * (p2.dot(h_2) - p1.dot(h_1)) - (h_2.x - h_1.x)*p_delta.dot(p_m)) /
  //  ((h_2 - h_1).cross(p_delta));
  //p_c.x = (p_delta.dot(p_m) - p_delta.y * p_c.y) / p_delta.x;
  //printf("p_c = (%g,%g)\n",V2COMP(p_c));

  // The center of the circle traced out by the robot is found thru a 3 step process.
  // First, the offset of the robot's motion relative to a radius of the circle is found.
  // This offset is calculated by finding the offset between the observed avg heading (pre/post)
  //   and the direction implied by the bisection line between the pre/post locations.
  //   The result is stored in the basis vector proj_basis.
  // Second, the heading vectors pre/post movement are rotated to remove the offset from
  //   motion making them perpendicular to the circle (they go thru the circle center).
  // Third, the 2 points (pre/post) and 2 radial vectors are used to intersect the 2 lines
  //   and find the center of the circle traced out.
  // NOTE(review): the code below relies on rebase(b)/project(b) acting as
  //   "rotate into b's frame" / "rotate out of b's frame" for a unit vector b.
  //   This is inferred from how the results are used; confirm against Geometry.h.
  vector2d h_avg;
  vector2d center_locus_dir;
  vector2d proj_basis;
  vector2d perp1,perp2;
  h_avg = (h_1 + h_2) / 2;
  h_avg = h_avg.norm();
  center_locus_dir = p_delta.perp().norm();
  proj_basis = center_locus_dir.rebase(h_avg);
  perp1 = h_1.project(proj_basis);
  perp2 = h_2.project(proj_basis);
  printf("h_avg=(%g,%g) center_locus_dir=(%g,%g) proj_basis=(%g,%g) perp1=(%g,%g) perp2=(%g,%g)\n",
         V2COMP(h_avg),V2COMP(center_locus_dir),V2COMP(proj_basis),V2COMP(perp1),V2COMP(perp2));
  //GVector::intersect_ray_plane(p1,-perp1,p_m,p_delta,p_c);
  //printf("p_c = (%g,%g)\n",V2COMP(p_c));
  //GVector::intersect_ray_plane(p2,-perp2,p_m,p_delta,p_c);
  //printf("p_c = (%g,%g)\n",V2COMP(p_c));
  // Intersect the line through p1 along -perp1 with the line through p2
  // (given as a "plane" through p2 whose normal is -perp2.perp()).
  GVector::intersect_ray_plane(p1,-perp1,p2,-perp2.perp(),p_c);
  printf("p_c = (%g,%g)\n",V2COMP(p_c));

  double radius;
  vector2d p_1_tan; // tangent vector to circle at p_1 in direction of travel
  double perimeter; // distance traced out
  vector2d walk_vel; // walk velocities
  // estimate of amount of angle swept out of circle, ideally equal to angle_diff
  double angle_sweep;

  radius = GVector::distance(p_c,p1);
  // Swept angle around the center, normalized, then given the same whole
  // number of revolutions as angle_diff.
  angle_sweep = norm_angle((p2-p_c).angle() - (p1-p_c).angle());
  angle_sweep += 2*M_PI * (rint(angle_diff/(2*M_PI)));
  perimeter = fabs(radius * angle_sweep);
  // Tangent at p1; its sign follows the sweep direction.
  p_1_tan = (p1 - p_c).perp().norm() * sign(angle_sweep);

  printf("circle: center (%g,%g), radius %g, perimeter %g, h_1 (%g,%g), p_1_tan (%g,%g)\n",
         V2COMP(p_c),radius,perimeter,V2COMP(h_1),V2COMP(p_1_tan));

  // Arc speed laid along x, then turned to point along the tangent and
  // finally expressed relative to the start heading h_1.
  walk_vel.set(perimeter/time,0.0);

  printf("walk_vel pre=(%g,%g)\n",V2COMP(walk_vel));

  walk_vel = walk_vel.project(p_1_tan);
  printf("walk_vel mid=(%g,%g)\n",V2COMP(walk_vel));
  walk_vel = walk_vel.rebase(h_1);

  //vector2d exp_walk_dir;
  //exp_walk_dir.set(exp_dx,exp_dy);
  //if(exp_walk_dir.sqlength() > 0)
  //  exp_walk_dir = exp_walk_dir.norm();
  //walk_vel = walk_vel.project(exp_walk_dir);

  printf("walk_vel post=(%g,%g)\n",V2COMP(walk_vel));

  dx = walk_vel.x;
  dy = walk_vel.y;
}

/* Converts each raw sample into a training example, appending one example
   per sample to `mot_cal_examples`, in input order.

   Each example gets the commanded velocities and the poses with the neck
   offset removed. Each pose has neck_offset_x (rotated by that pose's heading)
   subtracted from it. The achieved velocities are then filled in by
   calc_delta() over MovementTime. Prints one summary line per example. */
void calc_examples(vector<MotCalExample> &mot_cal_examples,const vector<MotCalData> &mot_cal_data)
{
  // Only .x is assigned per sample below; .y keeps its initial 0.0 here.
  vector2d neck_offset(0.0,0.0);

  mot_cal_examples.reserve(mot_cal_data.size());

  vector<MotCalData>::const_iterator sample;
  for(sample = mot_cal_data.begin(); sample != mot_cal_data.end(); ++sample){
    const MotCalData &raw = *sample;
    MotCalExample ex;

    // commanded ("ask") velocities
    ex.ask_dx = raw.dx;
    ex.ask_dy = raw.dy;
    ex.ask_da = raw.da;

    // headings are used unchanged
    ex.pre_angle  = raw.pre_angle;
    ex.post_angle = raw.post_angle;

    // positions: remove the heading-rotated neck offset
    neck_offset.x = raw.neck_offset_x;
    ex.pre_loc_body  = raw.pre_loc  - neck_offset.rotate(ex.pre_angle );
    ex.post_loc_body = raw.post_loc - neck_offset.rotate(ex.post_angle);

    calc_delta(ex.real_dx,ex.real_dy,ex.real_da,
               ex.ask_dx,ex.ask_dy,ex.ask_da,
               MovementTime,
               ex.pre_loc_body,ex.pre_angle,ex.post_loc_body,ex.post_angle);

    printf("example:ask dx=%8.3f dy=%8.3f da=%8.4f, real dx=%8.3f dy=%8.3f da=%8.4f, pre_body (%8.3f,%8.3f) <%8.4f post_body (%9.3f,%9.3f) <%8.4f\n",
           ex.ask_dx,ex.ask_dy,ex.ask_da,
           ex.real_dx,ex.real_dy,ex.real_da,
           ex.pre_loc_body.x,ex.pre_loc_body.y,
           ex.pre_angle,
           ex.post_loc_body.x,ex.post_loc_body.y,
           ex.post_angle
           );

    mot_cal_examples.push_back(ex);
  }
}

void dump_data(vector<MotCalData> &mot_cal_data)
{
  MotCalData mot_cal_data_item;

  for(uint i=0; i<mot_cal_data.size(); i++){
    mot_cal_data_item = mot_cal_data[i];
    printf("MT: motion: dx=%f dy=%f da=%f pre (%f,%f) <%f post (%f,%f) <%f\n",
           mot_cal_data_item.dx,mot_cal_data_item.dy,mot_cal_data_item.da,
           mot_cal_data_item.pre_loc.x,mot_cal_data_item.pre_loc.y,
           mot_cal_data_item.pre_angle,
           mot_cal_data_item.post_loc.x,mot_cal_data_item.post_loc.y,
           mot_cal_data_item.post_angle);
  }
}

void dump_example_debug(vector<MotCalData> &mot_cal_data,vector<MotCalExample> &mot_cal_examples)
{
  assert(mot_cal_data.size() == mot_cal_examples.size());

  MotCalData *mot_cal_datum;
  MotCalExample *mot_cal_example;

  for(uint i=0; i<mot_cal_examples.size(); i++){
    mot_cal_datum   = &mot_cal_data[i];
    mot_cal_example = &mot_cal_examples[i];
    //printf("example:ask dx=%f dy=%f da=%f pre (%f,%f) <%f post (%f,%f) <%f pre_body (%f,%f) <%f post_body (%f,%f) <%f\n",
    //       mot_cal_example->ask_dx,mot_cal_example->ask_dy,mot_cal_example->ask_da,
    //       mot_cal_datum->pre_loc.x,mot_cal_datum->pre_loc.y,
    //       mot_cal_datum->pre_angle,
    //       mot_cal_datum->post_loc.x,mot_cal_datum->post_loc.y,
    //       mot_cal_datum->post_angle,
    //       mot_cal_example->pre_loc_body.x,mot_cal_example->pre_loc_body.y,
    //       mot_cal_example->pre_angle,
    //       mot_cal_example->post_loc_body.x,mot_cal_example->post_loc_body.y,
    //       mot_cal_example->post_angle
    //       );
    printf("example:ask dx=%8.3f dy=%8.3f da=%8.4f, real dx=%8.3f dy=%8.3f da=%8.4f, pre_body (%8.3f,%8.3f) <%8.4f post_body (%9.3f,%9.3f) <%8.4f\n",
           mot_cal_example->ask_dx,mot_cal_example->ask_dy,mot_cal_example->ask_da,
           mot_cal_example->real_dx,mot_cal_example->real_dy,mot_cal_example->real_da,
           mot_cal_example->pre_loc_body.x,mot_cal_example->pre_loc_body.y,
           mot_cal_example->pre_angle,
           mot_cal_example->post_loc_body.x,mot_cal_example->post_loc_body.y,
           mot_cal_example->post_angle
           );
  }
}

/* Fits one weighted linear least-squares model (GSL gsl_multifit_wlinear)
   per output over the NumBases features from get_basis():
     output 0 -> real_dx, output 1 -> real_dy, output 2 -> real_da.

   Weights: the dx/dy fits down-weight examples with large commanded turn
   rate. The weight falls linearly from 1 at ask_da == 0 and is clamped to a
   floor of 0.001 once |ask_da| >= 0.25*MaxDA. The da fit weights all
   examples equally.

   Memory: allocates mot_cal.{dx,dy,da}_coeffs with new[] (NumBases entries
   each, zero-initialized); the caller owns them. If there are fewer
   examples than NumBases, or a fit reports failure, the affected
   coefficients stay 0 and a message goes to stderr.

   Prints chisq/total_weight and every fitted coefficient to stdout. */
void calc_regression_params(MotCal &mot_cal,vector<MotCalExample> &mot_cal_examples)
{
  const int num_examples = (int)mot_cal_examples.size();

  // Allocate up front, zero-filled, so later readers of mot_cal never see
  // uninitialized memory even if the fit below is skipped.
  mot_cal.dx_coeffs = new double[NumBases]();
  mot_cal.dy_coeffs = new double[NumBases]();
  mot_cal.da_coeffs = new double[NumBases]();

  // Least squares needs at least as many rows as unknowns. With zero rows
  // gsl_matrix_alloc fails, and GSL's default error handler aborts.
  if(num_examples < NumBases){
    fprintf(stderr,"regression needs at least %d examples, have %d\n",
            NumBases,num_examples);
    return;
  }

  gsl_multifit_linear_workspace *workspace = gsl_multifit_linear_alloc(num_examples,NumBases);
  gsl_matrix *example_bases = gsl_matrix_alloc(num_examples,NumBases);
  gsl_vector *weights = gsl_vector_alloc(num_examples);
  gsl_vector *obs     = gsl_vector_alloc(num_examples);
  gsl_vector *coeffs  = gsl_vector_alloc(NumBases);
  gsl_matrix *cov     = gsl_matrix_alloc(NumBases,NumBases);

  // Design matrix: shared by all three fits.
  for(int ex_idx=0; ex_idx<num_examples; ex_idx++){
    const MotCalExample *ex = &mot_cal_examples[ex_idx];
    for(int basis_idx=0; basis_idx<NumBases; basis_idx++){
      gsl_matrix_set(example_bases,ex_idx,basis_idx,get_basis(ex,basis_idx));
    }
  }

  double *out_coeffs[3] = { mot_cal.dx_coeffs, mot_cal.dy_coeffs, mot_cal.da_coeffs };

  for(int output_idx=0; output_idx<3; output_idx++){
    double total_weight = 0.0;

    for(int ex_idx=0; ex_idx<num_examples; ex_idx++){
      const MotCalExample *ex = &mot_cal_examples[ex_idx];

      double ex_val;
      double weight;
      switch(output_idx){
        case 0:
          ex_val = ex->real_dx;
          weight = 1.0-fabs(ex->ask_da/(.25*MaxDA));
          break;
        case 1:
          ex_val = ex->real_dy;
          weight = 1.0-fabs(ex->ask_da/(.25*MaxDA));
          break;
        default:
          ex_val = ex->real_da;
          weight = 1.0;
          break;
      }
      if(weight < 0.001)
        weight = 0.001;

      total_weight += weight;
      gsl_vector_set(obs,ex_idx,ex_val);
      gsl_vector_set(weights,ex_idx,weight);
    }

    double chisq = 0.0;
    // A non-zero status (GSL_SUCCESS is 0) is only seen here if the GSL error
    // handler has been turned off; otherwise GSL aborts inside the call.
    int status = gsl_multifit_wlinear(example_bases,weights,obs,coeffs,cov,&chisq,workspace);
    if(status != 0){
      fprintf(stderr,"regression for output %d failed (gsl status %d)\n",output_idx,status);
      continue;
    }
    printf("chisq/total_weight=%g\n",chisq/total_weight);

    for(int basis_idx=0; basis_idx<NumBases; basis_idx++){
      double c = gsl_vector_get(coeffs,basis_idx);
      printf("output %d: basis %d has coeff %g\n",output_idx,basis_idx,c);
      out_coeffs[output_idx][basis_idx] = c;
    }
  }

  gsl_matrix_free(cov);
  gsl_vector_free(coeffs);
  gsl_vector_free(obs);
  gsl_vector_free(weights);
  gsl_matrix_free(example_bases);
  gsl_multifit_linear_free(workspace);
}

/* Evaluates the fitted model on every training example. For each example it
   prints the commanded values, the estimated real values, the model
   prediction (sum over bases of feature * coefficient), and the squared
   residual of each output. */
void test_regression_params(MotCal &mot_cal,vector<MotCalExample> &mot_cal_examples)
{
  const int count = mot_cal_examples.size();

  for(int i=0; i<count; i++){
    const MotCalExample &sample = mot_cal_examples[i];

    // prediction[0..2] = predicted dx, dy, da
    double prediction[3] = { 0.0, 0.0, 0.0 };
    for(int b=0; b<NumBases; b++){
      const double feature = get_basis(&sample,b);
      prediction[0] += feature * mot_cal.dx_coeffs[b];
      prediction[1] += feature * mot_cal.dy_coeffs[b];
      prediction[2] += feature * mot_cal.da_coeffs[b];
    }

    const double residual_dx = sample.real_dx - prediction[0];
    const double residual_dy = sample.real_dy - prediction[1];
    const double residual_da = sample.real_da - prediction[2];

    printf("example:ask dx=%8.3f dy=%8.3f da=%8.4f, real dx=%8.3f dy=%8.3f da=%8.4f, "
           "predicted dx=%8.3f dy=%8.3f da=%8.4f, sq err dx=%9.3f dy=%9.3f da=%9.4f\n",
           sample.ask_dx,sample.ask_dy,sample.ask_da,
           sample.real_dx,sample.real_dy,sample.real_da,
           prediction[0],prediction[1],prediction[2],
           sq(residual_dx),sq(residual_dy),sq(residual_da)
           );
  }
}

/* Outputs the fitted coefficients in two forms:
   1. On stdout, one C-initializer-style line per output: "0={...};" for dx,
      "1={...};" for dy, "2={...};" for da.
   2. In the binary file "odom.prm" (current directory): a raw fwrite of a
      Motion::OdometryParam whose dx/dy/da coefficient arrays are filled from
      mot_cal.
   NOTE(review): this assumes OdometryParam's coefficient arrays hold at least
   NumBases entries; confirm against Motion.h.
   Failures to open or write the file are reported on stderr. */
void dump_regression_params(MotCal &mot_cal)
{
  // Choose the table once per output rather than on every coefficient.
  const double *tables[3] = { mot_cal.dx_coeffs, mot_cal.dy_coeffs, mot_cal.da_coeffs };

  for(int output_idx=0; output_idx<3; output_idx++){
    const double *data = tables[output_idx];
    printf("%d={",output_idx);
    for(int basis_idx=0; basis_idx<NumBases; basis_idx++){
      printf("%20.13e",data[basis_idx]);
      if(basis_idx!=NumBases-1)
        printf(",");
    }
    printf("};\n");
  }

  Motion::OdometryParam odom_param;
  for(int basis_idx=0; basis_idx<NumBases; basis_idx++){
    odom_param.dx_coeffs[basis_idx] = mot_cal.dx_coeffs[basis_idx];
    odom_param.dy_coeffs[basis_idx] = mot_cal.dy_coeffs[basis_idx];
    odom_param.da_coeffs[basis_idx] = mot_cal.da_coeffs[basis_idx];
  }

  FILE *file = fopen("odom.prm","wb");
  if(file == NULL){
    fprintf(stderr,"error writing odometry parameters\n");
    return;
  }
  // Both a short write and a failed close (buffered data not flushed) leave
  // a bad odom.prm behind; report either instead of ignoring it.
  bool ok = (fwrite(&odom_param,sizeof(odom_param),1,file) == 1);
  if(fclose(file) != 0)
    ok = false;
  if(!ok)
    fprintf(stderr,"error writing odometry parameters\n");
}

// Prints the command-line synopsis to stdout.
void usage()
{
  fputs("mot_cal <data_file>\n",stdout);
}

int main(int argc, char *argv[])
{
  char *data_filename=NULL;
  FILE *data_file=NULL;

  if(argc < 2){
    usage();
    exit(1);
  }

  data_filename = argv[1];
  data_file = fopen(data_filename,"r");
  if(data_file==NULL){
    fprintf(stderr,"unable to open file '%s' for reading\n",data_filename);
    exit(2);
  }

  vector<MotCalData> mot_cal_data;
  vector<MotCalExample> mot_cal_examples;
  MotCal mot_cal;

  load_data(mot_cal_data,data_file);
  //dump_data(mot_cal_data);
  calc_examples(mot_cal_examples,mot_cal_data);
  dump_example_debug(mot_cal_data,mot_cal_examples);
  calc_regression_params(mot_cal,mot_cal_examples);
  test_regression_params(mot_cal,mot_cal_examples);
  dump_regression_params(mot_cal);

  fclose(data_file);

  return 0;
}
