#ifndef __BSP_TREE_H__
#define __BSP_TREE_H__

#include <stdio.h>

#include <queue>

#include "agent/headers/Util.h"

#include "util.h"
#include "nvector.h"

// Shorthands for the template header and the fully-qualified class name,
// used by the out-of-class member definitions that follow the class.
#define BSP_TEM template <class state_t,class num_t,int dim>
#define BSP_FUN BSPTree<state_t,num_t,dim>

BSP_TEM
// Binary space partitioning tree over externally-owned states.
// state_t must provide val(i) (coordinate i, for i in [0,dim)) and a
// `state_t *next` member, which the tree uses as an intrusive list link.
class BSPTree{
public:
  typedef Vec::nvector<num_t,dim> vec_t;

  // A tree node.  Before build(), `mean` holds the running SUM of the
  // coordinates of the states added to it; build() turns it into the
  // centroid and, unless the node stays a leaf, moves its states into the
  // front/back children.
  struct Node{
    vec_t mean;            // centroid of the states under this node (a sum until build())
    num_t radius;          // distance from mean to the furthest state under this node
    vec_t split_dir;       // splitting-plane normal (normalized by build())
    num_t split_threshold; // plane offset: dot(split_dir, mean)

    state_t *states;       // intrusive list of states held directly by this node
    int num_states;        // length of that list

    Node *front,*back;     // children: front = strictly positive side of the plane
  };

  // Search-queue entry: either a node (state==NULL) keyed by a lower bound
  // on its distance, or a state (node==NULL) keyed by its exact distance.
  struct QueryNode{
    double dist;   // double to match distFromQuery() and getNextNearest()
    Node *node;
    state_t *state;

    // Inverted so std::priority_queue (a max-heap) pops the smallest
    // distance first.
    bool operator <(const QueryNode &qn) const
      {return(dist > qn.dist);}
  };

protected:
  Node *root;
  int leaf_size,max_depth;

  vec_t query_point;

  // std::priority_queue has no clear(); expose one through its protected
  // container.  The base class depends on the template parameters, so the
  // member must be named through this-> to be found.
  class PriorityQueue : public std::priority_queue<QueryNode>{
  public:
    void clear() {this->c.clear();}
  };
  PriorityQueue queue;

protected:
  double distFromQuery(Node *p);
  double distFromQuery(state_t *s);

  void freeTree(Node *p);
  void add(Node **q,state_t *s);
  void build(Node *p,int level);

  void addToSearchQueue(Node *p);
  void addToSearchQueue(state_t *s);

public:
  BSPTree() : root(NULL), leaf_size(16), max_depth(20) {}
  // NOTE(review): the tree owns its nodes but has no user-defined copy
  // operations; copying a populated tree and destroying both copies
  // double-frees the nodes.  Do not copy non-empty trees.
  ~BSPTree() {freeTree(root); root=NULL;}

  // Insert a state (not copied; its `next` link is overwritten).  build()
  // divides the accumulated sums, so it presumably should run once, after
  // all states have been added.
  void add(state_t *s) {add(&root,s);}
  void build() {build(root,0);}
  // Free every node (the states themselves are not freed) and abandon any
  // query in progress, whose queue would otherwise point at freed nodes.
  void clear() {endQuery(); freeTree(root); root=NULL;}

  // Incremental nearest-neighbour query: startQuery() sets the query point,
  // each getNextNearest() returns the next closest state (writing its
  // distance), or NULL when none remain; endQuery() empties the queue.
  void startQuery(state_t &query_point_);
  state_t *getNextNearest(double &dist);
  void endQuery() {queue.clear();}
};

/*
BSP_TEM
inline bool operator <(const typename BSP_FUN::QueryNode &a,
                       const typename BSP_FUN::QueryNode &b)
{
  return(a.dist > b.dist);
}
*/

BSP_TEM
double BSP_FUN::distFromQuery(Node *p)
{
  // Lower bound on the distance from the query point to any state under p:
  // distance to the node's mean minus its radius, clamped at zero for a
  // query that lies inside the node's bounding sphere.
  double to_center = dist(query_point,p->mean);
  return(max(to_center - p->radius, 0.0));
}

BSP_TEM
double BSP_FUN::distFromQuery(state_t *s)
{
  // Euclidean distance between the query point and state s, reading the
  // state's coordinates through val().
  double sum_sq = 0.0;
  for(int k=0; k<dim; ++k){
    sum_sq += sq(query_point[k] - s->val(k));
  }
  return(sqrt(sum_sq));
}

BSP_TEM
void BSP_FUN::freeTree(Node *p)
{
  // Post-order deletion of the subtree rooted at p; NULL is a no-op.
  // Only nodes are deleted -- the states they reference are left alone.
  if(!p) return;

  freeTree(p->front);
  freeTree(p->back);
  p->front = NULL;
  p->back  = NULL;
  delete p;
}

BSP_TEM
void BSP_FUN::add(Node **q,state_t *s)
{
  // Create the node at *q on first use (cleared with mzero, assumed to
  // zero every field), then fold s into it: its coordinates are added to
  // the running sum in `mean` (build() later divides by num_states) and s
  // is pushed onto the front of the node's intrusive state list.
  if(*q == NULL){
    *q = new Node;
    mzero(**q);
  }
  Node *node = *q;

  for(int k=0; k<dim; ++k){
    node->mean[k] += s->val(k);
  }

  s->next = node->states;
  node->states = s;
  ++node->num_states;
}

// Turn the flat list of states held by p into a subtree.  On entry p->mean
// holds the SUM of the node's state coordinates (accumulated by add()); it is
// converted to the mean here.  The node stays a leaf if it holds fewer than
// leaf_size states, is at max_depth, or all its states coincide.  Otherwise
// its states are partitioned by the plane through the mean whose normal
// points at the furthest state, and both halves are built recursively.
BSP_TEM
void BSP_FUN::build(Node *p,int level)
{
  int i;
  double d,md;

  if(!p || p->num_states==0) return;

  // normalize sum to get mean
  mul(p->mean, 1.0/p->num_states);

  zero(p->split_dir);
  md = 0.0;

  // find the state furthest from the mean (squared distance in md)
  state_t *s,*sn;
  for(s=p->states; s; s=s->next){
    d = 0.0;
    for(i=0; i<dim; i++) d += sq(s->val(i) - p->mean[i]);

    if(d > md){
      for(i=0; i<dim; i++) p->split_dir[i] = s->val(i);
      md = d;
    }
  }

  // radius bounds the distance from the mean to every state in the node
  p->radius = sqrt(md);

  // Every state coincides with the mean: there is no split direction
  // (normalizing a zero offset would divide by zero), and any plane would
  // put all states on one side, recursing uselessly until max_depth.
  // Keep the node as a leaf with a zero split plane.
  if(md <= 0.0){
    p->split_threshold = 0;
    return;
  }

  // turn the furthest point into a normalized direction away from the mean
  sub(p->split_dir,p->mean);
  norm(p->split_dir);

  // the plane passes through the mean: threshold = dot(split_dir, mean)
  p->split_threshold = dot(p->split_dir, p->mean);

  if(false){  // debug trace; prints only the first 3 components (needs dim >= 3)
    for(i=0; i<level; i++) printf("    ");
    printf("mean=[%5.1f,%5.1f,%5.1f] dir=[%6.3f,%6.3f,%6.3f] thresh=%f radius=%f num=%d\n",
           p->mean[0],p->mean[1],p->mean[2],
           p->split_dir[0],p->split_dir[1],p->split_dir[2],
           p->split_threshold,p->radius,
           p->num_states);
  }

  // stop if we are small enough to be a leaf, or at the depth limit
  if(p->num_states<leaf_size || level>=max_depth) return;

  // split the states by the plane: strictly positive side goes front
  s = p->states;
  while(s){
    sn = s->next;  // add() overwrites s->next, so remember the successor

    d = -p->split_threshold;
    for(i=0; i<dim; i++) d += p->split_dir[i] * s->val(i);

    if(d > 0){
      add(&(p->front),s);
    }else{
      add(&(p->back),s);
    }

    s = sn;
  }

  // this node no longer stores states directly
  p->states = NULL;
  p->num_states = 0;

  // recursively build children
  build(p->front,level+1);
  build(p->back ,level+1);
}

// Begin a new nearest-neighbour query centred on query_point_, discarding
// any previous query.  Results are then pulled with getNextNearest().
BSP_TEM
void BSP_FUN::startQuery(state_t &query_point_)
{
  int i;
  endQuery();
  for(i=0; i<dim; i++) query_point[i] = query_point_.val(i);
  // An empty tree has no root: leave the queue empty so getNextNearest()
  // returns NULL instead of dereferencing a null node.
  if(root) addToSearchQueue(root);
}

BSP_TEM
void BSP_FUN::addToSearchQueue(Node *p)
{
  // Enqueue node p keyed by its lower-bound distance.  state stays NULL so
  // getNextNearest() expands the entry instead of returning it.
  QueryNode entry;
  entry.node  = p;
  entry.state = NULL;
  entry.dist  = distFromQuery(p);
  queue.push(entry);
}

BSP_TEM
void BSP_FUN::addToSearchQueue(state_t *s)
{
  // Enqueue state s keyed by its exact distance from the query point.
  // node stays NULL, marking the entry as a result to return.
  QueryNode entry;
  entry.node  = NULL;
  entry.state = s;
  entry.dist  = distFromQuery(s);
  queue.push(entry);
}

BSP_TEM
state_t *BSP_FUN::getNextNearest(double &dist)
{
  // Best-first search.  Pop the closest entry.  A state entry is returned,
  // with its distance written to `dist`.  A node entry is expanded by
  // enqueueing its children and the states it holds.  Returns NULL once
  // the queue is exhausted; `dist` is not written in that case.
  while(!queue.empty()){
    QueryNode best = queue.top();
    queue.pop();

    if(best.state != NULL){
      dist = best.dist;
      return(best.state);
    }

    Node *n = best.node;
    if(n->front != NULL) addToSearchQueue(n->front);
    if(n->back  != NULL) addToSearchQueue(n->back);

    for(state_t *cur = n->states; cur != NULL; cur = cur->next){
      addToSearchQueue(cur);
    }
  }

  return(NULL);
}

/*
BSP_TEM
void BSP_FUN::endQuery()
{
  int i,n;
  n = queue.size();
  for(i=0; i<n; i++) queue.pop();
}
*/

#endif
