/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#ifndef INCLUDED_spatial_tree_h
#define INCLUDED_spatial_tree_h

#include "../headers/fast_alloc.h"
#include "../headers/Geometry.h"
#include "../headers/Util.h"

//struct loc_t {
//  vector2f pos;
//  struct loc_t *next;
//};

// Shorthands for the out-of-class member definitions of SpatialTree below.
// PROC_TEMP introduces a Processor type, which must provide
// add_loc(loc_t*) (see query_process()/process_locs()).
// All three are #undef'd at the end of this header.
#define ST_TEMP template<class loc_t>
#define ST_FUN  SpatialTree<loc_t>
#define PROC_TEMP template<class Processor>

ST_TEMP
class SpatialTree{
  // Region tree over 2D locations.  Every node covers an axis-aligned box.
  // Locations are stored only in leaves (split() moves a leaf's list into
  // its two new children); an internal node halves its box, and the split
  // dimension alternates from level to level.
  //
  // Requirements on loc_t:
  //   pos   - position with .x/.y members, used for all spatial tests
  //   next  - loc_t* link field; add() and split() overwrite it, so a loc_t
  //           can be linked into only one tree (or other list) at a time
  // The tree never allocates, frees or copies loc_t objects.
  //
  // tree assumes that if child[0]!=NULL then child[1]!=NULL

  struct node{
    vector2f minv,maxv; // bounding box of subtree
    loc_t *locs;        // list of locations stored at this node (leaves only)
    int num_locs;       // number of locations in this subtree

    union{
      node *child[2]; // children of this node (both NULL for a leaf)
      node *next;     // used by fast allocator for free nodes
    };
  };

  // interface for Processor class, these are the methods that need
  // to be present
  //class Processor {
  //public:
  //  void add_loc(loc_t *loc) {}
  //};

  node *root;               // top node; NULL until setdim() succeeds
  int leaf_size,max_depth;  // a leaf holding > leaf_size locs is split if its level < max_depth
  int tree_depth;           // depth bookkeeping maintained by setdim()/add()
  int tests;                // distance evaluations made by the last nearest() call
  int root_split_dim;       // a leaf at level L is split along (L + root_split_dim) % 2 (0=x, 1=y)
  fast_allocator<node> node_allocator;

protected:
  inline bool inside(vector2f &minv,vector2f &maxv,loc_t &s);
  inline float box_distance(vector2f &minv,vector2f &maxv,vector2f &p);

  void split(node *t,int split_dim);
  loc_t *nearest(node *t,loc_t *best,float &best_dist,vector2f &x);
  void clear(node *t);

  struct QueryInfo {
    vector2f minv,maxv;             // axis-aligned bounds used for pruning
    vector2f basis;                 // x_dir of the query rectangle
    vector2f basis_minv,basis_maxv; // rectangle extent in the basis frame
  };

  PROC_TEMP
  void query_process(node *t,QueryInfo &query_info,Processor &proc);

public:
  // Every scalar member starts at a defined value, so the object is in a
  // consistent (empty, rootless) state before setdim() is called.
  SpatialTree() {root=NULL; leaf_size=max_depth=0; tree_depth=0; tests=0; root_split_dim=0;}
  // NOTE(review): clear() keeps the root node allocated; presumably it is
  // reclaimed when node_allocator is destroyed -- confirm in fast_alloc.h.
  ~SpatialTree() {clear();}

  // (Re)initialize as an empty tree covering [minv,maxv]; false if the
  // root node cannot be allocated.
  bool setdim(const vector2f &minv,const vector2f &maxv,int nleaf_size,int nmax_depth);
  // Insert s (s->next is overwritten); false if it could not be inserted.
  bool add(loc_t *s);
  // Remove all locations; the root node and its box are kept.
  void clear();
  // Closest stored location to x (searching within distance 4000), or NULL;
  // dist receives the distance to the result.
  loc_t *nearest(float &dist,vector2f &x);
  
  // -----------------------------------------------------------------
  // Runs proc.add_loc on all of the points within the specified query rectangle.
  // INPUT
  //   x_dir  - direction vector of the major axis of rectangle
  //   center - center of rectangle
  //   range  - size of major,minor axis of rectangle
  // -----------------------------------------------------------------
  PROC_TEMP
  void query(vector2f x_dir,vector2f center,vector2f range,Processor &proc);

protected:
  PROC_TEMP
  void process_locs(node *t,Processor &proc);
public:  
  // Runs proc.add_loc on every location stored in the tree.
  PROC_TEMP
  void process_locs(Processor &proc);
};

ST_TEMP
inline bool ST_FUN::inside(vector2f &minv,vector2f &maxv,loc_t &s)
{
  // Strict containment test: a location lying exactly on any edge of the
  // box [minv,maxv] is reported as outside.
  const bool within_x = (s.pos.x > minv.x) && (s.pos.x < maxv.x);
  const bool within_y = (s.pos.y > minv.y) && (s.pos.y < maxv.y);
  return(within_x && within_y);
}

ST_TEMP
inline float ST_FUN::box_distance(vector2f &minv,vector2f &maxv,vector2f &p)
{
  // Euclidean distance from p to the box [minv,maxv]; used as a lower bound
  // on the distance to anything stored in a subtree.  Assumes bound(v,lo,hi)
  // clamps v into [lo,hi], so each offset is zero when p is within that
  // coordinate range.
  float off_x = p.x - bound(p.x,minv.x,maxv.x);
  float off_y = p.y - bound(p.y,minv.y,maxv.y);
  return(sqrt(off_x*off_x + off_y*off_y));
}

ST_TEMP
void ST_FUN::split(node *t,int split_dim)
{
  // Converts leaf t into an internal node whose two children halve t's box
  // along split_dim (0 = x, otherwise y).  Locations in t's list whose
  // coordinate is below the split value move to child[0], all others to
  // child[1].  t->num_locs is left as is since it counts the whole subtree.
  // If either child cannot be allocated, t is left unchanged as a leaf.
  node *a = node_allocator.alloc();
  node *b = node_allocator.alloc();
  if(!a || !b){
    // hand back whichever node we did get so it is not leaked
    if(a) node_allocator.free(a);
    if(b) node_allocator.free(b);
    return;
  }

  a->child[0] = b->child[0] = NULL;
  a->child[1] = b->child[1] = NULL;
  a->locs = b->locs = NULL;
  a->num_locs = b->num_locs = 0;

  // both children start with t's box, then share the split value as an edge
  a->minv = b->minv = t->minv;
  a->maxv = b->maxv = t->maxv;

  float split_val;
  if(split_dim == 0){
    split_val = (t->minv.x + t->maxv.x) / 2;
    a->maxv.x = b->minv.x = split_val;
  }else{
    split_val = (t->minv.y + t->maxv.y) / 2;
    a->maxv.y = b->minv.y = split_val;
  }

  // move every location from t's list to the front of one child's list
  loc_t *n = t->locs;
  while(n != NULL){
    loc_t *p = n;
    n = n->next;   // read before p->next is overwritten below

    float coord = (split_dim == 0) ? p->pos.x : p->pos.y;
    node *dst = (coord < split_val) ? a : b;
    p->next = dst->locs;
    dst->locs = p;
    dst->num_locs++;
  }

  // insert into tree
  t->locs = NULL;
  t->child[0] = a;
  t->child[1] = b;
}

ST_TEMP
bool ST_FUN::setdim(const vector2f &minv,const vector2f &maxv,int nleaf_size,int nmax_depth)
{
  // Resets the tree to a single empty leaf covering [minv,maxv] and stores
  // the leaf-size / depth limits.  The existing root node is reused when
  // there is one; otherwise a new one is allocated.  Returns false (and
  // leaves the limits untouched) when no root node can be obtained.
  clear();

  if(root == NULL){
    root = node_allocator.alloc();
    if(root == NULL) return(false);
  }

  mzero(*root);   // presumably zero-fills the node (children, list, count)
  root->minv = minv;
  root->maxv = maxv;

  leaf_size      = nleaf_size;
  max_depth      = nmax_depth;
  tree_depth     = 0;
  root_split_dim = 0;
  return(true);
}

ST_TEMP
bool ST_FUN::add(loc_t *loc)
{
  // Inserts loc into the tree, linking it through loc->next (any previous
  // value is overwritten).  When loc is not strictly inside the root box,
  // new root levels are added, each doubling the root box towards loc, and
  // the insert is then retried from the new root.
  // Returns false (loc is not linked) if setdim() has not been called, a
  // coordinate is NaN, the root box has no area, growing the root would
  // take more than max_depth new levels, or a node allocation fails.
  node *p;
  int c,level;

  level = 0;
  p = root;
  if(!p) return(false);

  if(inside(p->minv,p->maxv,*loc)){
    // go down tree to see where new loc should go
    while(p->child[0]){ // implies p->child[1] also
      p->num_locs++;
      // Route on this node's split coordinate with the same rule split()
      // uses (below the split value -> child[0]); child[1]->minv holds the
      // split value on that dimension.  A full inside() test here would
      // send points lying exactly on an ancestor's boundary to child[1]
      // regardless of the actual split, into a leaf whose box does not
      // contain them, where nearest()/query() prune them away.
      int dim = (level + root_split_dim) % 2;
      node *upper = p->child[1];
      if(dim == 0)
        c = (loc->pos.x >= upper->minv.x);
      else
        c = (loc->pos.y >= upper->minv.y);
      p = p->child[c];
      level++;
    }
    
    // add it to leaf
    loc->next = p->locs;
    p->locs = loc;
    p->num_locs++;
    
    // split leaf if not too deep and too many children for one node
    if(level<max_depth && p->num_locs>leaf_size){
      split(p,(level + root_split_dim) % 2);
      tree_depth = max(tree_depth,level+1);
    }
    return(true);
  }

  // ---- loc is outside (or on the edge of) the root box: grow the root ----
  int new_split_dim;
  int est_new_levels;
  float outside_frac;
  vector2f new_minv,new_maxv;
  node *new_root,*new_child;

  // NaN compares false against everything, so no box could ever contain
  // it and the growth loop below would never terminate.
  if(loc->pos.x != loc->pos.x || loc->pos.y != loc->pos.y)
    return(false);

  new_split_dim = (root_split_dim + 1) % 2;
  new_minv = p->minv;
  new_maxv = p->maxv;

  // Doubling a box with no (or negative) extent never makes it contain
  // anything, and the estimate below would divide by zero.
  if(!(new_maxv.x > new_minv.x) || !(new_maxv.y > new_minv.y))
    return(false);
  
  // estimated number of new levels that will be needed so we can abort if
  //   there will be too many levels required.  The halving loops stop once
  //   the estimate exceeds max_depth (the outcome is then "abort" anyway),
  //   which keeps an infinite coordinate from spinning forever.
  est_new_levels = 0;
  outside_frac = max3(loc->pos.x - new_maxv.x, new_minv.x - loc->pos.x, 0.0f);
  outside_frac /= (new_maxv.x - new_minv.x);
  while(outside_frac > 1.0 && est_new_levels <= max_depth){
    outside_frac /= 2.0;
    est_new_levels++;
  }
  if(outside_frac > 0.0)
    est_new_levels++;
  outside_frac = max3(loc->pos.y - new_maxv.y, new_minv.y - loc->pos.y, 0.0f);
  outside_frac /= (new_maxv.y - new_minv.y);
  while(outside_frac > 1.0 && est_new_levels <= max_depth){
    outside_frac /= 2.0;
    est_new_levels++;
  }
  if(outside_frac > 0.0)
    est_new_levels++;
  if(est_new_levels > max_depth)
    return(false);

  // determine range for new root node
  while(!inside(new_minv,new_maxv,*loc)){
    vector2f child_minv,child_maxv;
    int child_num=0; // 0=smaller, 1=larger

    child_minv = new_minv;
    child_maxv = new_maxv;

    // double the range along new_split_dim, expanding towards the point;
    // if point already in this range expand arbitrarily towards greater values
    if(new_split_dim==0){
      if(loc->pos.x > new_minv.x){
        child_minv.x = new_maxv.x;
        new_maxv.x = new_minv.x + (new_maxv.x - new_minv.x)*2;
        child_maxv.x = new_maxv.x;
        child_num = 1;
      }else{
        child_maxv.x = new_minv.x;
        new_minv.x = new_maxv.x + (new_minv.x - new_maxv.x)*2;
        child_minv.x = new_minv.x;
        child_num = 0;
      }
    }else{
      if(loc->pos.y > new_minv.y){
        child_minv.y = new_maxv.y;
        new_maxv.y = new_minv.y + (new_maxv.y - new_minv.y)*2;
        child_maxv.y = new_maxv.y;
        child_num = 1;
      }else{
        child_maxv.y = new_minv.y;
        new_minv.y = new_maxv.y + (new_minv.y - new_maxv.y)*2;
        child_minv.y = new_minv.y; // new child spans [new min, old min]
        child_num = 0;
      }
    }

    // create new root node (and new child node for balance)
    new_root  = node_allocator.alloc();
    new_child = node_allocator.alloc();
    if(!new_root || !new_child){
      if(new_root)
        node_allocator.free(new_root);
      if(new_child)
        node_allocator.free(new_child);
      return(false);
    }

    new_child->child[0] = NULL;
    new_child->child[1] = NULL;
    new_child->locs = NULL;
    new_child->num_locs = 0;
    new_child->minv = child_minv;
    new_child->maxv = child_maxv;

    new_root->child[1 - child_num] = p;
    new_root->child[child_num    ] = new_child;
    new_root->locs = NULL;
    new_root->num_locs = p->num_locs;
    new_root->minv = new_minv;
    new_root->maxv = new_maxv;

    root = new_root;
    // the new root splits along new_split_dim; recording it keeps
    // (level + root_split_dim) % 2 equal to each node's real split dimension
    root_split_dim = new_split_dim;
    tree_depth++;
    p = root;

    // update
    new_split_dim = (new_split_dim + 1) % 2;
  }

  // loc is now strictly inside the root box; insert from the top
  return(add(loc));
}

ST_TEMP
void ST_FUN::clear(node *t)
{
  // Returns the whole subtree rooted at t (t included) to node_allocator,
  // children first.  Each node's location list head is dropped; the loc_t
  // objects themselves are neither freed nor modified.  NULL is a no-op.
  if(t == NULL) return;

  for(int i=0; i<2; i++){
    if(t->child[i]) clear(t->child[i]);
  }

  t->child[0] = NULL;
  t->child[1] = NULL;
  t->locs     = NULL;
  t->num_locs = 0;

  node_allocator.free(t);
}

ST_TEMP
void ST_FUN::clear()
{
  // Empties the tree: every node below the root is returned to the
  // allocator and the root becomes an empty leaf.  The root node itself and
  // its box (minv/maxv) are kept, so setdim() can reuse the node.
  if(root){
    clear(root->child[0]);
    clear(root->child[1]);

    root->child[0] = NULL;
    root->child[1] = NULL;
    root->locs     = NULL;
    root->num_locs = 0;
  }
}

ST_TEMP
loc_t *ST_FUN::nearest(node *t,loc_t *best,float &best_dist,vector2f &x)
{
  // Branch-and-bound search of subtree t (must be non-NULL) for the
  // location closest to x.  best/best_dist carry the best candidate found so
  // far; best_dist is tightened in place and the (possibly updated) best is
  // returned.  A child is visited only if its box is nearer than best_dist.
  for(loc_t *cur = t->locs; cur != NULL; cur = cur->next){
    float d = GVector::distance(cur->pos,x);
    if(d < best_dist){
      best_dist = d;
      best = cur;
    }
    tests++;
  }

  if(t->child[0] == NULL) return(best);  // leaf (child[1] is NULL too)

  float dc[2];
  dc[0] = box_distance(t->child[0]->minv,t->child[0]->maxv,x);
  dc[1] = box_distance(t->child[1]->minv,t->child[1]->maxv,x);

  // visit the closer child first so best_dist shrinks as early as possible
  int first  = (dc[1] < dc[0]) ? 1 : 0;
  int second = 1 - first;

  if(dc[first]  < best_dist) best = nearest(t->child[first], best,best_dist,x);
  if(dc[second] < best_dist) best = nearest(t->child[second],best,best_dist,x);

  return(best);
}

ST_TEMP
loc_t *ST_FUN::nearest(float &dist,vector2f &x)
{
  // Returns the stored location closest to x, or NULL when nothing lies
  // strictly within 4000 (the initial search radius, in pos units) or the
  // tree has no root yet.  dist receives the distance to the result
  // (4000 when NULL is returned).  tests is reset and then counts the
  // distance evaluations made by this search.
  loc_t *best = NULL;
  dist = 4000;
  tests = 0;

  // setdim() not called yet: the recursive search dereferences its node
  // unconditionally, so bail out like query()/process_locs() do.
  if(root == NULL)
    return(NULL);

  best = nearest(root,best,dist,x);
  // printf("tests=%d dist=%f\n\n",tests,dist);

  return(best);
}

ST_TEMP
PROC_TEMP
void ST_FUN::query_process(node *t,QueryInfo &query_info,Processor &proc)
{
  // Recursive worker for query(): a subtree whose box does not overlap the
  // query's axis-aligned bounds (query_info.minv/maxv) is skipped.  At a
  // leaf, each location is re-expressed via rebase() in the query basis and
  // handed to proc.add_loc() if it falls within basis_minv..basis_maxv.
  //printf("querying node %p, (%g,%g)-(%g,%g)\n",t,V2COMP(t->minv),V2COMP(t->maxv));

  bool disjoint =
    t->minv.x > query_info.maxv.x || t->maxv.x < query_info.minv.x ||
    t->minv.y > query_info.maxv.y || t->maxv.y < query_info.minv.y;
  if(disjoint){
    //printf("pruned\n");
    return;
  }

  if(t->child[0] == NULL){
    // leaf: test every stored location against the rectangle itself
    for(loc_t *loc = t->locs; loc != NULL; loc = loc->next){
      vector2f in_basis = loc->pos.rebase(query_info.basis);
      if(in_basis >= query_info.basis_minv && in_basis <= query_info.basis_maxv)
        proc.add_loc(loc);
    }
    return;
  }

  // internal node: child[1] is non-NULL whenever child[0] is
  query_process(t->child[0],query_info,proc);
  query_process(t->child[1],query_info,proc);
}

ST_TEMP
PROC_TEMP
void ST_FUN::query(vector2f x_dir,vector2f center,vector2f range,Processor &proc)
{
  // Sets up a QueryInfo for the rectangle centred on `center` with extent
  // `range` in the basis given by x_dir, then walks the tree.  The pruning
  // bounds are the axis-aligned box around the rectangle's four corners,
  // each mapped out of the basis with project() (presumably the inverse of
  // rebase()).  Does nothing when the tree has no root.
  if(root == NULL) return;

  QueryInfo info;
  vector2f lo,hi;
  vector2f corner;

  vector2f basis_center = center.rebase(x_dir);

  info.basis      = x_dir;
  info.basis_minv = basis_center - range/2.0;
  info.basis_maxv = basis_center + range/2.0;

  // grow [lo,hi] to cover all four corners of the rectangle
  lo = hi = info.basis_minv.project(x_dir);
  corner = info.basis_maxv.project(x_dir);
  corner.add_to_bound(lo,hi);
  corner = vector2f(info.basis_minv.x,info.basis_maxv.y).project(x_dir);
  corner.add_to_bound(lo,hi);
  corner = vector2f(info.basis_maxv.x,info.basis_minv.y).project(x_dir);
  corner.add_to_bound(lo,hi);

  info.minv = lo;
  info.maxv = hi;

  query_process(root,info,proc);
}

ST_TEMP
PROC_TEMP
void ST_FUN::process_locs(node *n,Processor &proc)
{
  // Depth-first walk of subtree n (must be non-NULL): recurse into whichever
  // children exist; only a node with no children passes its stored
  // locations, in list order, to proc.add_loc().
  if(n->child[0]) process_locs(n->child[0],proc);
  if(n->child[1]) process_locs(n->child[1],proc);
  if(n->child[0] || n->child[1]) return;

  for(loc_t *loc = n->locs; loc != NULL; loc = loc->next)
    proc.add_loc(loc);
}

ST_TEMP
PROC_TEMP
void ST_FUN::process_locs(Processor &proc)
{
  // Passes every location in the tree to proc.add_loc(); a tree without a
  // root (setdim() never called) produces no calls.
  if(root == NULL) return;
  process_locs(root,proc);
}

// Drop the helper macros so they do not leak into files including this header.
#undef ST_TEMP
#undef ST_FUN 
#undef PROC_TEMP

#endif
