#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

#include "jitty.h"

/*
 * ATerm error handler (installed in main via ATsetErrorHandler).
 *
 * Prints "Rewriter demo: " followed by the formatted message and a newline
 * to stderr, then terminates the process.  A failure status is used so that
 * scripts driving the demo can tell a fatal error from a normal run.
 *
 * format/args: printf-style ATerm format string and its arguments, as
 *              forwarded by ATerror.  args is owned by the caller and is
 *              not va_end'ed here.
 */
static void ErrorHandler(const char *format, va_list args)
{
  fprintf(stderr, "%s ", "Rewriter demo:");
  ATvfprintf(stderr, format, args);
  fprintf(stderr, "\n");
  exit(EXIT_FAILURE);
}

void read_trs(FILE* trs,ATermList *funs, ATermList *eqs, ATermList *strat, char* def) {
  ATerm nextterm;
  ATermList sorts, rules, strategies;
  
  nextterm = ATreadFromFile(trs);
  
  if (nextterm != ATmake("signature"))
    ATerror("Keyword 'signature' expected\n");
  
  sorts=ATmakeList0();
  nextterm = ATreadFromFile(trs);
  if (!nextterm) exit(0);
  while (ATgetArity(ATgetSymbol(nextterm))==1 &&
	 ATmatch(ATgetArgument(nextterm,0),"<int>",NULL)) {
    sorts = ATinsert(sorts,nextterm);
    nextterm = ATreadFromFile(trs);
    if (!nextterm) exit(0);
  }
  
  if (nextterm!=ATmake("rules"))
    ATerror("Keyword 'rules' expected (%t)\n",nextterm);
  
  rules=ATmakeList0();
  nextterm=ATreadFromFile(trs);
  if (!nextterm) exit(0);
  while (ATgetArity(ATgetSymbol(nextterm))==3 &&
	 ATmatch(ATgetArgument(nextterm,0),"[<list>]",NULL)) {
    rules = ATinsert(rules,nextterm);
    nextterm = ATreadFromFile(trs);
    if (!nextterm) exit(0);
  }

  if (!ATmatch(nextterm,"default"))
    ATerror("Keyword 'default' expected\n");

  nextterm=ATreadFromFile(trs);
  
  if (nextterm == ATmake("innermost"))
    *def=1;
  else if (nextterm == ATmake("justintime"))
    *def=2;
  else
    ATerror("default strategy: 'innermost' or 'justintime'");
  
  nextterm = ATreadFromFile(trs);
  if (nextterm != ATmake("strategies"))
    ATerror("Keyword 'strategies' expected\n");
  
  strategies = ATmakeList0();
  nextterm=ATreadFromFile(trs);
  while (nextterm &&
	 ATgetArity(ATgetSymbol(nextterm))==1 &&
	 ATmatch(ATgetArgument(nextterm,0),"[<list>]",NULL)) {
    strategies = ATinsert(strategies,nextterm);
    nextterm = ATreadFromFile(trs);
  }

  if (nextterm != ATmake("end"))
    ATerror("Keyword 'end' expected\n");
  
  *funs = sorts;
  *eqs  = rules;
  *strat = strategies;
}

int main(int argc, char** argv) {
  ATermList funs,eqs,strat;
  char def_strat;
  ATinit(argc,argv,(ATerm*)&argc);
  ATsetErrorHandler(ErrorHandler);

  read_trs(stdin,&funs,&eqs,&strat,&def_strat);

  JIT_init(funs,eqs,strat,def_strat,0); /* or 1 for hashing mode */
  while (1) {
    ATerm command,var,term;
    command = ATreadFromFile(stdin);
    if (!command) exit(0);
    if (ATmatch(command,"rewrite(<term>)",&term))
      ATfprintf(stderr,"normal form: %t\n",JIT_normalize(term));
    else if (ATmatch(command,"assign(<term>,<term>)",&var,&term)) {
      JIT_assign(ATgetSymbol(var),term);
      ATfprintf(stderr,"variable %t set to %t in level %d\n",var,term,JIT_level());
    }
    else if (ATmatch(command,"leave")) {
      JIT_leave();
      ATfprintf(stderr,"Return back to level %d\n",JIT_level());
    }
    else if (ATmatch(command,"enter")) {
      JIT_enter();
      ATfprintf(stderr,"Enter level %d\n",JIT_level());
    }
    else if (ATmatch(command,"clear")) {
      JIT_clear();
      ATfprintf(stderr,"Clearing bindings at level %d\n",JIT_level());
    }

    else if (ATmatch(command,"stop")) break;
    else
      ATfprintf(stderr,"command %t not understood\n",command);
  }
  exit(1);
}
