/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#include "../headers/Config.h"
#include "../headers/Reporting.h"
#include "../Main/MainObject.h"
#include "StuckDetector.h"

// Allocates the decision tree, the accelerometer ring buffer, the feature
// vector and the classification history, and puts every running statistic
// into a consistent all-zero / STANDING starting state.
StuckDetector::StuckDetector() {
  tree = new DTree();

  // Ring buffer of the last max_history accelerometer samples.  It is
  // zero-filled so that the running sums below (also zero) agree with its
  // contents from the very first update().
  data_index = 0;
  data = new vector3d[max_history];
  for(int i=0; i<max_history; i++) {
    data[i].set(0.0, 0.0, 0.0);
  }

  features = new double[num_features];
  for(int i=0; i<num_features; i++)
    features[i] = 0.0;

  // update() increments test_timer and runs the decision tree whenever it
  // is a multiple of test_interval, so it must start from a known value.
  test_timer = 0;

  // Circular history of tree classifications; num_stuck_in_hist counts the
  // entries that are WALL or HOOKED, which is zero while all are STANDING.
  last_class = StuckInfo::STANDING;
  result_hist_index = 0;
  num_stuck_in_hist = 0;
  result_history = new int[max_result_history];
  for(int i=0; i<max_result_history; i++)
    result_history[i] = StuckInfo::STANDING;

  history_mean.set(0.0, 0.0, 0.0);
  history_sum.set(0.0, 0.0, 0.0);
  history_xx_sum = history_yy_sum = history_zz_sum = 0.0;
  history_xy_sum = history_xz_sum = history_yz_sum = 0.0;
}

// Releases every buffer allocated by the constructor.  Applying delete /
// delete[] to a NULL pointer is a no-op, so no explicit checks are needed.
// NOTE(review): this class owns raw pointers; unless the header declares a
// private/disabled copy constructor and assignment operator, copying a
// StuckDetector would free these buffers twice — confirm in StuckDetector.h.
StuckDetector::~StuckDetector() {
  delete tree;
  tree = NULL;

  delete[] features;
  features = NULL;

  delete[] data;
  data = NULL;

  // Allocated with new[] in the constructor; previously never released.
  delete[] result_history;
  result_history = NULL;
}

// Loads the decision-tree parameters for the robot model in use: the
// ERS-210 file when `model` is exactly "ERS210", otherwise the ERS-7 file.
// (`model` is presumably a global robot-model string from a project
// header — confirm.)
void StuckDetector::init() {
  const bool isErs210 = (strcmp(model,"ERS210") == 0);

  if(!isErs210) {
    tree->load("/MS/config/ers7/stree.prm");
    return;
  }

  tree->load("/MS/config/ers210/stree.prm");
  // tree->print();  // handy when debugging a freshly trained tree
}

/* Sliding-window statistics over the three acceleration axes.

   We keep running sums of each axis, of each axis squared, and of the
   pairwise products (x*y, x*z, y*z) over the last max_history samples held
   in the ring buffer `data`.  Every update removes the contribution of the
   sample being overwritten and adds the new one.  This lets the per-axis
   means, sample variances and pairwise correlations be recomputed in
   constant time per sensor frame.
*/

// Pearson correlation of two channels a and b over a window of n samples,
// computed from running sums:
//   numerator   = sum(a*b) - mean_b*sum(a) - mean_a*sum(b) + n*mean_a*mean_b
//               = sum((a - mean_a)(b - mean_b))
//   denominator = sqrt(var_a*var_b) * (n - 1)
// Returns 0 when the denominator is not greater than 1e-6.  That includes
// the NaN case, where round-off makes var_a*var_b negative.
static double windowCorrelation(double sum_ab,
                                double sum_a, double sum_b,
                                double mean_a, double mean_b,
                                double var_a, double var_b,
                                int n) {
  double denom = sqrt(var_a * var_b) * (n - 1);
  // Written as "> eps" (not "<= eps" early-out) so a NaN denominator
  // yields 0 rather than propagating NaN.
  if(denom > 0.000001) {
    double cor = sum_ab;
    cor -= mean_b*sum_a;
    cor -= mean_a*sum_b;
    cor += n*mean_a*mean_b;
    return cor / denom;
  }
  return 0.0;
}

// Pushes one accelerometer sample into the sliding window, refreshes the
// variance/correlation features, and every test_interval calls runs the
// decision tree and records its classification in result_history.
void StuckDetector::update(vector3d accel) {
#ifdef PLATFORM_APERIOS
  static EventTimeReporter reporter("StuckDetector::update",SecToTime(5.0),SecToTime(100.0),1000UL,&TextOutputStream);
  EventTimeReporter::EventTimer timer(&reporter,config.spoutConfig.dumpProfile);
#endif

  // The sample about to be evicted from the ring buffer.
  const vector3d old = data[data_index];

  // Iteratively update the sums and means: drop the old sample, add the new.
  history_sum -= old;
  history_sum += accel;

  history_mean = history_sum / max_history;

  history_xx_sum -= old.x*old.x;
  history_yy_sum -= old.y*old.y;
  history_zz_sum -= old.z*old.z;
  history_xy_sum -= old.x*old.y;
  history_xz_sum -= old.x*old.z;
  history_yz_sum -= old.y*old.z;

  history_xx_sum += accel.x*accel.x;
  history_yy_sum += accel.y*accel.y;
  history_zz_sum += accel.z*accel.z;
  history_xy_sum += accel.x*accel.y;
  history_xz_sum += accel.x*accel.z;
  history_yz_sum += accel.y*accel.z;

  // Per-axis sample variance:
  //   var = sum((x - mean_x)^2) / (n - 1)
  //       = (sum(x^2) - 2*mean_x*sum(x) + n*mean_x^2) / (n - 1)
  // (assumes vector3d * vector3d multiplies component-wise — confirm in
  // the vector header)
  vector3d accel_var;
  accel_var.set(history_xx_sum,
		history_yy_sum,
		history_zz_sum);
  accel_var -= history_mean*history_sum*2;
  accel_var += history_mean*history_mean*max_history;
  accel_var /= max_history - 1;

  // Pairwise correlations between the axes.
  double xy_cor = windowCorrelation(history_xy_sum,
                                    history_sum.x, history_sum.y,
                                    history_mean.x, history_mean.y,
                                    accel_var.x, accel_var.y,
                                    max_history);
  double xz_cor = windowCorrelation(history_xz_sum,
                                    history_sum.x, history_sum.z,
                                    history_mean.x, history_mean.z,
                                    accel_var.x, accel_var.z,
                                    max_history);
  double yz_cor = windowCorrelation(history_yz_sum,
                                    history_sum.y, history_sum.z,
                                    history_mean.y, history_mean.z,
                                    accel_var.y, accel_var.z,
                                    max_history);

  // Actually write the new sample into the ring buffer.
  data[data_index] = accel;
  data_index = (data_index + 1) % max_history;

  // Only consult the decision tree every test_interval sensor frames:
  // sensor frames arrive faster than vision/behavior frames, so testing on
  // every one is pointless.
  test_timer++;

  if(test_timer % test_interval==0) {
    features[0] = accel_var.x;
    features[1] = accel_var.y;
    features[2] = accel_var.z;
    features[3] = xy_cor;
    features[4] = xz_cor;
    features[5] = yz_cor;

    if(tree!=NULL) {
      last_class = tree->getClass(features, num_features);
      // If the tree reports an error (-1), ignore it and say we're playing.
      if(last_class==-1)
	last_class = StuckInfo::PLAYING;
    } else {
      last_class = StuckInfo::PLAYING;
    }

    // Keep num_stuck_in_hist equal to the number of WALL/HOOKED entries in
    // result_history: count the incoming result, uncount the evicted one.
    if(last_class==StuckInfo::WALL ||
       last_class==StuckInfo::HOOKED)
      num_stuck_in_hist++;

    if(result_history[result_hist_index]==StuckInfo::WALL ||
       result_history[result_hist_index]==StuckInfo::HOOKED)
      num_stuck_in_hist--;

    result_history[result_hist_index] = last_class;

    result_hist_index = (result_hist_index + 1) % max_result_history;
  }
}

// Packages the detector's current state: the latest decision-tree
// classification plus the fraction of recent classifications that were
// WALL or HOOKED.
StuckInfo StuckDetector::getStuckInfo() {
  StuckInfo info;
  info.fraction_stuck = getStuckFraction();
  info.last_state = last_class;
  return info;
}

// Fraction of the last max_result_history decision-tree classifications
// (one is recorded every test_interval sensor frames by update()) that
// were WALL or HOOKED.
double StuckDetector::getStuckFraction() {
  const double window = max_result_history;
  return num_stuck_in_hist / window;
}
