/* LICENSE:
  =========================================================================
    CMPack'04 Source Code Release for OPEN-R SDK 1.1.5-r2 for ERS7
    Copyright (C) 2004 Multirobot Lab [Project Head: Manuela Veloso]
    School of Computer Science, Carnegie Mellon University
    All rights reserved.
  ========================================================================= */

#ifndef __STATE_MACHINE_H__
#define __STATE_MACHINE_H__

#include <stdlib.h>

#include "../headers/CircBufPacket.h"
#include "../Motion/MotionInterface.h"

// NOTE(review): a using-directive at header scope leaks every name of
// namespace Motion into each file that includes this header.  It is
// presumably here so that MotionCommand and the SOUND_* names used by
// handleErr() resolve unqualified -- confirm before removing.
using namespace Motion;

// template header repeated in front of every out-of-class member
// definition in the implementation section of this file
#define FSM_TEM \
  template <class state_t,class timestamp_t>
// the class name with its template arguments, used to qualify the
// out-of-class member definitions
#define FSM_FUN \
  FiniteStateMachine<state_t,timestamp_t>

// printf-style debug output: forwards the format and arguments to
// pprintf() on TextOutputStream.  Neither name is declared in this file;
// presumably they come from the project headers included above.
// Uses the GNU named-variadic-macro extension (args... / ##args).
#define PRINTF(format,args...) \
  pprintf(TextOutputStream,format,##args)

// Finite state machine with built-in bookkeeping: per-state counters and
// accumulated time, per-transition counters, and a per-loop event log.
//
//   state_t      type of a state label; it is used directly as an index
//                into the state tables, so it must convert to an integer
//                in [0,num_states)
//   timestamp_t  type of a timestamp; it must support assignment from 0,
//                subtraction, += and ==
//
// Typical use: init(), then once per decision cycle startLoop(t), any
// number of transition() calls, endLoop().  Misuse is not reported through
// return values; it is accumulated as a bit mask in `error`.
FSM_TEM
class FiniteStateMachine{
public:
  // statistics for one transition leaving one state
  struct TransitionInfo{
    int trans_id; // transition id (needs to be unique within parent state)
    state_t new_state; // state that this transition changes to
    int count; // number of times this transition has fired
    timestamp_t timestamp; // the last time this transition happened
    const char *trans_name; // string name of transition
    TransitionInfo *left,*right; // child pointers for id parity tree
  };

  // statistics for one state
  struct StateInfo{
    int enter_count; // number of times we've transitioned to this state
    int run_count; // number of times we've been in this state
    int cmd_count; // number of times this command has generated an action
    // note: run_count >= cmd_count, because this state may choose to
    // transition to another state rather than generating a command
    timestamp_t total_time; // total time spent in this state
    TransitionInfo *root; // root of transition info tree
  };

  // one entry of the per-loop event log; states and the transition id are
  // stored in a single unsigned char each, so larger values are truncated
  struct TransitionEvent{
    unsigned char old_state; // the state you are transitioning FROM
    unsigned char new_state; // the state you are transitioning TO
    unsigned char trans_id; // the transition's unique id label
    unsigned char reserved;
    timestamp_t timestamp; // when the transition happened
  };

  // bit flags that are OR-ed into `error`
  enum ErrorCode{
    ErrInfiniteLoop         = 1 << 0,
    ErrNestedStartLoop      = 1 << 1,
    ErrNestedEndLoop        = 1 << 2,
    ErrOutOfTransitions     = 1 << 3,
    ErrTransOutsideOfLoop   = 1 << 4,
    ErrSetStateInsideOfLoop = 1 << 5
  };

private:
  StateInfo *state_info; // array of num_states entries, owned
  TransitionInfo *trans_info; // pool of max_trans tree nodes, owned
  TransitionEvent *event_cache; // array of event_cache_size entries, owned
  state_t last_state; // state we were in before the latest transition
  state_t state; // current state
  int num_states;
  int num_trans,max_trans; // used / available nodes in trans_info
  int num_events,event_cache_size; // used / available event_cache entries

  timestamp_t state_start_time; // timestamp when current state was entered
  timestamp_t loop_start_time; // timestamp of decision loop start
  bool in_loop; // true if we are inside a startLoop/endLoop pair
  bool sleeping; // true if state machine's command is not being used
  bool is_new_state; // true if we switched into state this timestamp

  const char *const *state_name; // array of state name strings
public:
  int error; // mask of error codes

private:
  TransitionInfo *lookup(state_t state,int trans_id);

  // Not copyable: the three tables above are owned raw arrays that the
  // destructor frees, so a member-wise copy would free them twice.
  // Declared but intentionally never defined.
  FiniteStateMachine(const FiniteStateMachine &);
  FiniteStateMachine &operator=(const FiniteStateMachine &);

public:
  FiniteStateMachine();
  ~FiniteStateMachine() {reset();}

  // allocates the tables and enters initial_state; returns false on failure
  bool init(const char *const *state_names,
            int _num_states,state_t initial_state,
            int max_transitions,int _event_cache_size);
  // frees the tables and returns to the uninitialized condition
  void reset();

  // the following functions wrap the decision loop
  void startLoop(timestamp_t timestamp);
  void endLoop();

  // these can ONLY be called from within a decision loop
  void transition(state_t new_state,int trans_id,const char *trans_name);
  timestamp_t timeInState() const {return(loop_start_time - state_start_time);}
  bool isNewState() const {return(is_new_state);}

  // these can be called outside of the decision loop
  void setState(state_t new_state,int trans_id,const char *trans_name,
                timestamp_t timestamp);
  void sleep() {sleeping = true;}

  // these can be called anywhere
  state_t getState() const {return(state);}
  state_t getPrevState() const {return(last_state);}
  timestamp_t getStateStartTime() const {return(state_start_time);}
  void clearEvents() {num_events = 0;}

  // debug output through PRINTF
  void dumpLoopEvents();
  void dumpStats();

  // prints and clears pending errors, and fills in a sound command
  void handleErr(MotionCommand *command);
  
  // debug querying interface
  const StateInfo *lookupStateInfo(state_t state) const;
private:
  int countTransitions(const TransitionInfo *trans_info) const;
public:
  int countStateTransitions(state_t state) const;
  const TransitionInfo *lookupTransitionInfo(state_t state,int trans_id) const;
  int maxTransitions() const;
  int numEvents() const;
  const TransitionEvent *lookupEvent(int event_idx) const;
  state_t lookupState() const {return(state);}
  bool asleep() const {return(sleeping);}
};

// Records a transition on state_machine and then returns new_state from
// the enclosing function.  Caveats visible from the expansion: new_state
// is evaluated twice, and the expansion is a bare {...} block, so a ';'
// written after it in an unbraced if/else detaches the else branch.
#define TransReturn(state_machine, new_state, trans_id, trans_name) \
  {state_machine.transition(new_state,trans_id,trans_name); return(new_state);}

// Records a transition on state_machine and then `continue`s the loop
// that encloses the macro use.  The same bare-{...} caveat applies.
#define TRANS_CONT(state_machine, new_state, trans_id, trans_name) \
  {state_machine.transition(new_state,trans_id,trans_name); continue;}


//==== Implementation ================================================//

// Finds the record of transition `trans_id` leaving `state`.
// Each state owns a binary tree keyed on the bits of the id: at depth d
// the search goes right when bit d of trans_id is set, left otherwise.
// Returns NULL when the transition has never been recorded.
FSM_TEM
typename FSM_FUN::TransitionInfo *FSM_FUN::lookup(state_t state,int trans_id)
{
  TransitionInfo *node = state_info[state].root;

  for(int depth = 0; node != NULL; depth++){
    if(node->trans_id == trans_id) break;
    node = ((trans_id >> depth) & 1)? node->right : node->left;
  }

  return(node);
}

// Builds an empty, uninitialized machine; call init() before using it.
// The table pointers must be NULL before reset() runs because reset()
// deletes them.  The remaining members listed here are ones reset() does
// not assign, so they are given defined values up front: the flags match
// what init() sets (asleep, fresh state) and there is no name table yet.
FSM_TEM
FSM_FUN::FiniteStateMachine()
  : state_info(NULL),
    trans_info(NULL),
    event_cache(NULL),
    last_state((state_t)0),
    sleeping(true),
    is_new_state(true),
    state_name(NULL)
{
  reset();
}

// (Re)initializes the machine.
//   state_names        array of _num_states name strings; only the pointer
//                      is kept, so the array must outlive the machine
//   _num_states        number of states (must be > 0)
//   initial_state      state the machine starts in; counted as entered once
//   max_transitions    capacity of the transition statistics pool
//   _event_cache_size  capacity of the per-loop event log
// Returns true on success.  On invalid arguments or allocation failure the
// machine is left in the reset() condition and false is returned.
FSM_TEM
bool FSM_FUN::init(const char *const *state_names,
                   int _num_states,state_t initial_state,
                   int max_transitions,int _event_cache_size)
{
  // drop the tables of any previous init(); they come from new[] and
  // therefore have to be released with delete[]
  delete[] state_info;
  delete[] trans_info;
  delete[] event_cache;
  state_info = NULL;
  trans_info = NULL;
  event_cache = NULL;

  // refuse sizes that cannot be allocated and an initial state that would
  // index outside of the state table
  if(_num_states <= 0 || max_transitions < 0 || _event_cache_size < 0 ||
     (int)initial_state < 0 || (int)initial_state >= _num_states){
    reset();
    return(false);
  }

  state_info = new StateInfo[_num_states];
  trans_info = new TransitionInfo[max_transitions];
  event_cache = new TransitionEvent[_event_cache_size];

  // only reachable where operator new reports failure by returning NULL
  if(!state_info || !trans_info || !event_cache){
    delete[] state_info;
    delete[] trans_info;
    delete[] event_cache;
    state_info = NULL;
    trans_info = NULL;
    event_cache = NULL;
    reset();
    return(false);
  }

  state = initial_state;
  last_state = initial_state; // no previous state yet
  num_states = _num_states;
  num_trans = 0;
  max_trans = max_transitions;
  num_events = 0;
  event_cache_size = _event_cache_size;

  state_start_time = 0;
  loop_start_time = 0;
  in_loop = false;
  sleeping = true;
  is_new_state = true;

  state_name = state_names;
  error = 0;

  // mzero() is declared elsewhere; presumably it zero-fills n elements
  mzero(state_info,num_states);
  mzero(trans_info,max_trans);
  mzero(event_cache,event_cache_size);
  state_info[initial_state].enter_count = 1;

  return(true);
}

// Frees all tables and returns the machine to its uninitialized condition.
// Also run by the constructor and the destructor.
FSM_TEM
void FSM_FUN::reset()
{
  // the tables are allocated with new[] in init(), so delete[] is required
  delete[] state_info;
  delete[] trans_info;
  delete[] event_cache;

  state_info = NULL;
  trans_info = NULL;
  event_cache = NULL;
  state_name = NULL; // the name table belongs to the caller of init()

  state = (state_t)0;
  last_state = (state_t)0;
  num_states = 0;
  num_trans = max_trans = 0;
  num_events = event_cache_size = 0;

  state_start_time = loop_start_time = 0;
  in_loop = false;
  // asleep, so a stray startLoop() does not charge time to a freed table;
  // both flags match the values init() starts from
  sleeping = true;
  is_new_state = true;

  error = 0;
}

// Opens a decision loop at `timestamp`, waking the machine and clearing
// the event log.  Calling it while a loop is already open only sets
// ErrNestedStartLoop.
FSM_TEM
void FSM_FUN::startLoop(timestamp_t timestamp)
{
  if(in_loop){
    error |= ErrNestedStartLoop;
    return;
  }

  if(!sleeping){
    // the current state's command ran for the whole time since the
    // previous loop began, so charge that interval to the state
    state_info[state].total_time += timestamp - loop_start_time;
  }

  loop_start_time = timestamp;
  num_events = 0;
  in_loop = true;
  sleeping = false;
}

// Closes the decision loop: the state that is current at this point is
// counted as having run and as having produced a command.  Calling it
// without an open loop only sets ErrNestedEndLoop.
FSM_TEM
void FSM_FUN::endLoop()
{
  if(!in_loop){
    error |= ErrNestedEndLoop;
    return;
  }

  StateInfo &current = state_info[state];
  current.run_count++;
  current.cmd_count++;

  is_new_state = false;
  in_loop = false;
}

// Switches from the current state to new_state and records statistics for
// the transition identified by trans_id (unique within the current state).
// Only valid inside a decision loop; otherwise it does nothing except set
// ErrTransOutsideOfLoop.
FSM_TEM
void FSM_FUN::transition(state_t new_state,int trans_id,const char *trans_name)
{
  if(!in_loop){
    error |= ErrTransOutsideOfLoop;
    return;
  }

  // Walk the current state's parity tree.  `link` always addresses the
  // pointer that refers to the node under inspection, so when the walk
  // falls off the tree it addresses the slot a new node belongs in.
  TransitionInfo **link = &state_info[state].root;
  for(int depth = 0; *link != NULL && (*link)->trans_id != trans_id; depth++){
    TransitionInfo *parent = *link;
    link = ((trans_id >> depth) & 1)? &parent->right : &parent->left;
  }

  TransitionInfo *rec = *link;
  if(rec != NULL){
    // known transition: firing twice at the same loop timestamp is
    // reported as an infinite loop
    if(rec->timestamp == loop_start_time){
      error |= ErrInfiniteLoop;
    }
    rec->count++;
    rec->timestamp = loop_start_time;
  }else if(num_trans >= max_trans){
    // pool exhausted: the state still changes, but no stats are kept
    error |= ErrOutOfTransitions;
  }else{
    // first firing: take the next free node from the pool and hook it in
    rec = &trans_info[num_trans++];
    rec->trans_id = trans_id;
    rec->new_state = new_state;
    rec->count = 1;
    rec->timestamp = loop_start_time;
    rec->trans_name = trans_name;
    rec->left = rec->right = NULL;
    *link = rec;
  }

  // append to this loop's event log while there is room
  if(num_events < event_cache_size){
    TransitionEvent &ev = event_cache[num_events++];
    ev.old_state = state;
    ev.new_state = new_state;
    ev.trans_id  = trans_id;
    ev.reserved  = 0;
    ev.timestamp = loop_start_time;
  }

  // leaving a state implies its code ran this loop
  state_info[state].run_count++;
  state_info[new_state].enter_count++;

  // finally perform the switch
  last_state = state;
  state = new_state;
  state_start_time = loop_start_time;
  is_new_state = true;
}

// Forces a transition to new_state from outside of the decision loop,
// stamped with `timestamp`.  A request for the state we are already in is
// ignored.  Calling it inside a loop only sets ErrSetStateInsideOfLoop.
FSM_TEM
void FSM_FUN::setState(state_t new_state,int trans_id,const char *trans_name,
              timestamp_t timestamp)
{
  if(in_loop){
    error |= ErrSetStateInsideOfLoop;
    return;
  }

  // self transitions are not allowed
  if(state == new_state) return;

  if(!sleeping){
    // charge the time since the last loop start to the outgoing state
    state_info[state].total_time += timestamp - loop_start_time;
  }

  // transition() only works inside a loop, so open one just for the call
  loop_start_time = timestamp;
  in_loop = true;
  transition(new_state,trans_id,trans_name);
  in_loop = false;
}

// Prints the transitions logged since the event log was last cleared
// (startLoop() and clearEvents() clear it), one line per event.
FSM_TEM
void FSM_FUN::dumpLoopEvents()
{
  TransitionInfo *p;
  const char *trans_name;
  int i;

  // %f consumes a double, so the timestamp has to be converted explicitly;
  // passing a non-double timestamp_t through the varargs is undefined
  PRINTF("Events: num=%d t=%f\n",num_events,(double)loop_start_time);
  for(i=0; i<num_events; i++){
    // the log keeps only one byte of the transition id, so ids that do not
    // fit are not found again and are printed as "Unknown"
    p = lookup((state_t)event_cache[i].old_state,event_cache[i].trans_id);
    trans_name = (p != NULL)? p->trans_name : "Unknown";

    PRINTF("  %g %s::%s -> %s\n",
           (double)event_cache[i].timestamp,
           state_name[event_cache[i].old_state],
           trans_name,
           state_name[event_cache[i].new_state]);
  }
}

// Prints the counters of every state (the current one is marked with '*')
// followed by the transitions recorded for it.
// NOTE(review): transitions are found by probing ids 0..max_trans-1, so a
// transition whose id lies outside that range is not listed.
FSM_TEM
void FSM_FUN::dumpStats()
{
  TransitionInfo *p;
  int s,t,n;

  PRINTF("States: num=%d\n",num_states);
  for(s=0; s<num_states; s++){
    const StateInfo &info = state_info[s];
    const double total = (double)info.total_time;

    n = info.enter_count;
    // a state that was never entered has no average; print 0 rather than
    // dividing by zero
    PRINTF("%s%s: n=%d run=%d cmd=%d t=%g tavg=%g\n",
           (s == state)? " *" : "  ",
           state_name[s],n,
           info.run_count,
           info.cmd_count,
           total,
           (n != 0)? total / n : 0.0);

    // the current state has been entered once more than it has been left
    n -= (s == state);
    for(t=0; t<max_trans; t++){
      p = lookup((state_t)s,t);
      if(p){
        PRINTF("    %s -> %s : %d (%1.2f%%)\n",
               p->trans_name,state_name[p->new_state],
               p->count,(n != 0)? 100.0*p->count/n : 0.0);
      }
    }
  }
}


// Reports every error flag that is currently set, dumps this loop's
// events, fills `command` with a sound request and clears the error mask.
// `command` is written unconditionally and must not be NULL.
FSM_TEM
void FSM_FUN::handleErr(MotionCommand *command)
{
  // one message per ErrorCode flag, lowest bit first
  if(error & ErrInfiniteLoop){
    PRINTF("Infinite Loop Error\n");
  }
  if(error & ErrNestedStartLoop){
    PRINTF("Nested Start Loop Error\n");
  }
  if(error & ErrNestedEndLoop){
    PRINTF("Nested End Loop Error\n");
  }
  if(error & ErrOutOfTransitions){
    PRINTF("Out of Transitions Error\n");
  }
  if(error & ErrTransOutsideOfLoop){
    PRINTF("Transition Outside of Loop Error\n");
  }
  if(error & ErrSetStateInsideOfLoop){
    PRINTF("Set State Inside of Loop Error\n");
  }

  dumpLoopEvents();

  // audible signal that something went wrong
  //command->motion_cmd = MOTION_GOAL_HAPPY1;
  command->sound_cmd       = SOUND_NOTE;
  command->sound_frequency = 500;
  command->sound_duration  = 500000;

  // errors have been reported; start over with a clean mask
  error = 0;
}


// Returns the statistics record of `state`.  The index is not range
// checked; the pointer refers into the machine's own table.
FSM_TEM
const typename FSM_FUN::StateInfo *FSM_FUN::lookupStateInfo(state_t state) const
{
  return(state_info + state);
}

// Counts the nodes of the transition tree rooted at trans_info
// (0 for an empty tree).
FSM_TEM
int FSM_FUN::countTransitions(const TransitionInfo *trans_info) const
{
  int total = 0;

  // follow the right spine iteratively and recurse into each left subtree
  for(const TransitionInfo *node = trans_info; node != NULL; node = node->right){
    total += 1 + countTransitions(node->left);
  }

  return(total);
}

// Returns how many distinct transitions have been recorded out of `state`.
FSM_TEM
int FSM_FUN::countStateTransitions(state_t state) const
{
  return(countTransitions(lookupStateInfo(state)->root));
}

// Read-only variant of lookup(): returns the record of transition
// trans_id leaving `state`, or NULL if it has never been recorded.
FSM_TEM
const typename FSM_FUN::TransitionInfo *FSM_FUN::lookupTransitionInfo(state_t state,int trans_id) const
{
  // lookup() is not declared const although it does not modify the
  // machine, hence the const_cast
  return (const_cast<FSM_FUN *>(this))->lookup(state,trans_id);
}

// Returns the capacity of the transition statistics pool, i.e. the
// max_transitions value given to init() (0 after reset()).
FSM_TEM
int FSM_FUN::maxTransitions() const
{
  return(max_trans);
}

// Returns the number of entries currently held in the event log.
FSM_TEM
int FSM_FUN::numEvents() const
{
  return(num_events);
}

// Returns entry event_idx of the event log.  The index is not range
// checked; see numEvents() for the number of entries in use.
FSM_TEM
const typename FSM_FUN::TransitionEvent *FSM_FUN::lookupEvent(int event_idx) const
{
  return(event_cache + event_idx);
}

#endif
