// kalman.cc
//
// Generic extended Kalman(-Bucy) filter code, propagated in fixed discrete time steps.
//
// Created by:  Michael Bowling (mhb@cs.cmu.edu)
//
// Modified by:  Paul E. Rybski (prybski@cs.cmu.edu)
//
/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#include <stdio.h>
#include <math.h>

#include "../headers/kalman.h"

// Construct a filter with a state of dimension _state_n, observations of
// dimension _obs_n, and a fixed propagation step of _stepsize time units.
//
// The prediction lists start with a single zero state / zero covariance
// entry; callers are expected to call initial() to set a real estimate.
// Every counter and clock member is given a defined value here so that a
// filter that is used before initial() (or that never enables lookahead
// prediction) does not read indeterminate values.
Kalman::Kalman(int _state_n, int _obs_n, double _stepsize) { 
  state_n = _state_n;
  obs_n = _obs_n;
  stepsize = _stepsize; 

  xs.clear(); xs.push_back(Matrix(state_n, 1));
  Ps.clear(); Ps.push_back(Matrix(state_n, state_n));
  Is.clear(); Is.push_back(Matrix());

  // Clocks start at zero until initial() provides a timestamp.
  time = 0.0;
  stepped_time = 0.0;

  // Lookahead-prediction error tracking is disabled (lookahead 0) and empty.
  prediction_lookahead = 0.0;
  prediction_time = 0.0;
  errors = Matrix(_state_n, 1);
  errors_n = 0;  // previously left uninitialized; update() increments it
}

// Reset the filter to a known estimate.
//
// All cached predictions are discarded; the prediction lists restart with the
// given state x and covariance P (plus an empty auxiliary/info matrix), and
// both the current time and the time of the furthest cached step become t.
void Kalman::initial(double t, Matrix &x, Matrix &P)
{
  Is.clear();
  Ps.clear();
  xs.clear();

  xs.push_back(x);
  Ps.push_back(P);
  Is.push_back(Matrix());

  time = t;
  stepped_time = time;
}

void Kalman::propagate()
{
  Matrix x = xs.back();
  Matrix P = Ps.back();
  Matrix &_A = A(x);
  Matrix &_W = W(x);
  Matrix &_Q = Q(x);
  Matrix I;
  
#ifdef KALMAN_DEBUG
  fprintf(stderr, "PROPAGATE:\n");
  fprintf(stderr, "x =\n");
  x.print();
  fprintf(stderr, "P =\n");
  P.print();
#endif
  
  x = f(x, I);
  P = _A * P * transpose(_A) + _W * _Q * transpose(_W);

  xs.push_back(x);
  Ps.push_back(P);
  Is.push_back(I);
  
#ifdef KALMAN_DEBUG
  fprintf(stderr, "=============>\nx =\n");
  x.print();
  fprintf(stderr, "P =\n");
  P.print();
  fprintf(stderr, "\n");
#endif

  stepped_time += stepsize;
}

void Kalman::update(const Matrix &z,int iterate)
{
  Matrix x = xs.front();
  Matrix P = Ps.front();
  Matrix I = Is.front();

  Matrix &_H = H(x);
  Matrix &_V = V(x); 
  Matrix &_R = R(x);

  Matrix K;

  // We clear the prediction list because we have a new observation.
  xs.clear(); Ps.clear(); Is.clear(); stepped_time = time;

  // Iterated EKF step
  for (int i=0;i<iterate;i++) {

    _H = H(x);
    _V = V(x); 
    _R = R(x);

    K = P * transpose(_H) * 
      inverse(_H * P * transpose(_H) + _V * _R * transpose(_V));

    Matrix error = K * (z - h(x));
    
#ifdef KALMAN_DEBUG
    fprintf(stderr, "UPDATE:\n");
    fprintf(stderr, "x =\n");
    x.print();
    fprintf(stderr, "z =\n");
    z.print();
    fprintf(stderr, "P =\n");
    P.print();
    fprintf(stderr, "K =\n");
    K.print();
#endif
    
    x = x + error;
  }
  P = (Matrix(P.nrows()) - K * _H) * P;

  // Add the current state back onto the prediction list.
  xs.push_back(x); Ps.push_back(P); Is.push_back(I);

  if (prediction_lookahead > 0.0) {
    if (time - prediction_time >= prediction_lookahead) {

      if (prediction_time > 0.0) {
	Matrix error = x - prediction_x;

	for(int i=0; i < error.nrows(); i++)
	  errors.e(i, 0) += fabs(error.e(i, 0));
	errors_n++;
      }

      prediction_x = predict(prediction_lookahead);
      prediction_time = time;
    }
  }

#ifdef KALMAN_DEBUG
  fprintf(stderr, "=============>\nx =\n");
  x.print();
  fprintf(stderr, "P =\n");
  P.print();
  fprintf(stderr, "\n");
#endif
}

void Kalman::tick(double dt) 
{
  uint nsteps = (int) rint(dt / stepsize);

  while(xs.size() - 1 < nsteps) { propagate(); }

  xs.erase(xs.begin(), xs.begin() + nsteps);
  Ps.erase(Ps.begin(), Ps.begin() + nsteps);
  Is.erase(Is.begin(), Is.begin() + nsteps);
  
  time += dt;
}

// Return the predicted state dt time units after the current time.
//
// dt is rounded to the nearest whole number of steps; the prediction list is
// extended with propagate() as needed and cached for later calls. A negative
// or NaN dt yields the current estimate (zero steps) instead of wrapping to a
// huge unsigned step count.
Matrix Kalman::predict(double dt)
{
  const double steps = rint(dt / stepsize);
  const unsigned int nsteps =
      (steps > 0.0) ? static_cast<unsigned int>(steps) : 0u;

  while(xs.size() - 1 < nsteps) { propagate(); }

  return xs[nsteps];
}

// Return the predicted state covariance dt time units after the current time.
//
// Same step rounding and caching as predict(); a negative or NaN dt yields
// the current covariance (zero steps) instead of wrapping to a huge unsigned
// step count.
Matrix Kalman::predict_cov(double dt)
{
  const double steps = rint(dt / stepsize);
  const unsigned int nsteps =
      (steps > 0.0) ? static_cast<unsigned int>(steps) : 0u;

  while(xs.size() - 1 < nsteps) { propagate(); }

  return Ps[nsteps];
}

// Return the auxiliary matrix cached by propagate() (the matrix f() was
// handed) for the step dt time units after the current time.
//
// Same step rounding and caching as predict(); a negative or NaN dt yields
// the current entry (zero steps) instead of wrapping to a huge unsigned
// step count.
Matrix Kalman::predict_info(double dt)
{
  const double steps = rint(dt / stepsize);
  const unsigned int nsteps =
      (steps > 0.0) ? static_cast<unsigned int>(steps) : 0u;

  while(xs.size() - 1 < nsteps) { propagate(); }

  return Is[nsteps];
}

// Return a predicted state dt time units ahead without growing the cache.
//
// If the step is already cached, that entry is returned. Otherwise a single
// propagate() call is made from the furthest cached estimate, with stepsize
// temporarily set to the remaining time (dt - (stepped_time - time)) so the
// step lands at time + dt. That one step may be much larger than the normal
// stepsize. The temporary entry, stepped_time and stepsize are then restored,
// so the cache is left unchanged. A negative or NaN dt yields the current
// estimate instead of wrapping to a huge unsigned step count.
Matrix Kalman::predict_fast(double dt)
{
  const double steps = rint(dt / stepsize);
  const unsigned int nsteps =
      (steps > 0.0) ? static_cast<unsigned int>(steps) : 0u;

  if (xs.size() - 1 >= nsteps) return xs[nsteps];

  const double orig_stepsize = stepsize;
  stepsize = dt - (stepped_time - time);
  propagate();

  Matrix rv = xs.back();

  // Undo the temporary step.
  stepped_time -= stepsize;
  stepsize = orig_stepsize;
  xs.pop_back();
  Ps.pop_back();
  Is.pop_back();

  return rv;
}

double Kalman::obs_likelihood(double dt, Matrix &z)
{
  Matrix x = predict(dt);
  Matrix P = predict_cov(dt);
  Matrix _hx = h(x);
  Matrix &_H = H(x);

  Matrix C = _H * P * transpose(_H);

  Matrix D = z - _hx;
  
  double likelihood = 1.0;

  for(int i=0; i<D.nrows(); i++)
    likelihood *= exp( - (D.e(i, 0) * D.e(i, 0)) / (2 * C.e(i, i)) );

  return likelihood;
}

// Mean absolute per-component lookahead-prediction error accumulated by
// update() since the last error_reset().
//
// When no prediction has been scored yet (errors_n == 0) a zero vector of
// the same shape is returned instead of dividing by zero, which would have
// filled the result with NaN/inf.
Matrix Kalman::error_mean()
{
  if (errors_n <= 0) return errors.scale(0.0);
  return errors.scale(1.0 / (double) errors_n);
}

// Clear the accumulated lookahead-prediction error statistics: the scored
// count returns to zero and the error-sum vector is zeroed in place of shape.
void Kalman::error_reset()
{
  errors_n = 0;
  errors = errors.scale(0.0);
}

// Total lookahead time covered by the scored predictions: the number of
// scored predictions times the lookahead interval.
double Kalman::error_time_elapsed()
{
  const double lookahead = prediction_lookahead;
  return lookahead * errors_n;
}
