#include <devices/8254.h>
 #include <palacios/vmm.h>
 #include <palacios/vmm_time.h>
+#include <palacios/vmm_util.h>
+
+
+
+
+
+
 
 // constants
 #define OSC_HZ 1193182
 /* The order of these typedefs is important because the numerical values correspond to the 
  * values coming from the io ports
  */
-typedef enum {NOT_RUNNING, WAITING_LOBYTE, WAITING_HIBYTE, RUNNING} channel_access_state_t;
+typedef enum {NOT_RUNNING, PENDING, RUNNING} channel_run_state_t;
+typedef enum {NOT_WAITING, WAITING_LOBYTE, WAITING_HIBYTE} channel_access_state_t;
 typedef enum {LATCH_COUNT, LOBYTE_ONLY, HIBYTE_ONLY, LOBYTE_HIBYTE} channel_access_mode_t;
 typedef enum {IRQ_ON_TERM_CNT, ONE_SHOT, RATE_GEN, SQR_WAVE, SW_STROBE, HW_STROBE} channel_op_mode_t;
 
 struct channel {
   channel_access_mode_t access_mode;
   channel_access_state_t access_state;
+  channel_run_state_t run_state;
 
   channel_op_mode_t op_mode;
 
 
-
   // Time til interrupt trigger 
 
   ushort_t counter;
 
 
 
+/* 
+ * This should call out to handle_SQR_WAVE_tics, etc... 
+ */
+// Returns true if the the output signal changed state
+static int handle_crystal_tics(struct vm_device * dev, struct channel * ch, uint_t oscillations) {
+  uint_t channel_cycles = 0;
+  uint_t output_changed = 0;
+  
+  PrintDebug("8254 PIT: %d crystal tics\n", oscillations);
+  if (ch->run_state == PENDING) {
+    oscillations--;
+    ch->counter = ch->reload_value;
+    ch->run_state = RUNNING;
+  } else if (ch->run_state != RUNNING) {
+    return output_changed;
+  }
+
+
+  PrintDebug("8254 PIT: Channel Run State = %d, counter=", ch->run_state);
+  PrintTraceLL(ch->counter);
+  PrintDebug("\n");
+
+
+  if (ch->counter > oscillations) {
+    ch->counter -= oscillations;
+    return output_changed;
+  } else {
+    oscillations -= ch->counter;
+    ch->counter = 0;
+    channel_cycles = 1;
+
+    
+    channel_cycles += oscillations / ch->reload_value;
+    oscillations = oscillations % ch->reload_value;
+
+    ch->counter = ch->reload_value - oscillations;
+  }
+
+  PrintDebug("8254 PIT: Channel Cycles: %d\n", channel_cycles);
+  
+
+
+  switch (ch->op_mode) {
+  case IRQ_ON_TERM_CNT:
+    if ((channel_cycles > 0) && (ch->output_pin == 0)) {
+      ch->output_pin = 1; 
+      output_changed = 1;
+    }
+    break;
+  case ONE_SHOT:
+    if ((channel_cycles > 0) && (ch->output_pin == 0)) {
+      ch->output_pin = 1; 
+      output_changed = 1;
+    }
+    break;
+  case RATE_GEN:
+    // See the data sheet: we ignore the output pin cycle...
+    if (channel_cycles > 0) {
+      output_changed = 1;
+    }
+    break;
+  case SQR_WAVE:
+    break;
+  case SW_STROBE:
+    break;
+  case HW_STROBE:
+    break;
+  default:
+    break;
+  }
+
+  return output_changed;
+}
+                               
 
 
 static void pit_update_time(ullong_t cpu_cycles, ullong_t cpu_freq, void * private_data) {
-  PrintDebug("Adding %d cycles\n", cpu_cycles);
+  struct vm_device * dev = (struct vm_device *)private_data;
+  struct pit * state = (struct pit *)dev->private_data;
+  //  ullong_t tmp_ctr = state->pit_counter;
+  ullong_t tmp_cycles;
+  uint_t oscillations = 0;
+
+
+  /*
+    PrintDebug("updating cpu_cycles=");
+    PrintTraceLL(cpu_cycles);
+    PrintDebug("\n");
+    
+    PrintDebug("pit_counter=");
+    PrintTraceLL(state->pit_counter);
+    PrintDebug("\n");
+    
+    PrintDebug("pit_reload=");
+    PrintTraceLL(state->pit_reload);
+    PrintDebug("\n");
+  */
+
+  if (state->pit_counter > cpu_cycles) {
+    // Easy...
+    state->pit_counter -= cpu_cycles;
+  } else {
+    
+    // Take off the first part
+    cpu_cycles -= state->pit_counter;
+    state->pit_counter = 0;
+    oscillations = 1;
+    
+    if (cpu_cycles > state->pit_reload) {
+      // how many full oscillations
+      tmp_cycles = cpu_cycles;
+
+      cpu_cycles = do_divll(tmp_cycles, state->pit_reload);
+
+      oscillations += tmp_cycles;
+    }
+
+    // update counter with remainder (mod reload)
+    state->pit_counter = state->pit_reload - cpu_cycles;    
+
+    //PrintDebug("8254 PIT: Handling %d crystal tics\n", oscillations);
+    if (handle_crystal_tics(dev, &(state->ch_0), oscillations) == 1) {
+      // raise interrupt
+      PrintDebug("8254 PIT: Injecting Timer interrupt to guest\n");
+      dev->vm->vm_ops.raise_irq(dev->vm, 0);
+    }
+
+    //handle_crystal_tics(dev, &(state->ch_1), oscillations);
+    //handle_crystal_tics(dev, &(state->ch_2), oscillations);
+  }
   
+
+
+ 
   return;
 }
 
 
+
+/* This should call out to handle_SQR_WAVE_write, etc...
+ */
 static int handle_channel_write(struct channel * ch, char val) {
-  //  switch (ch->access_mode) {
+
+    switch (ch->access_state) {      
+    case WAITING_HIBYTE:
+      {
+       ushort_t tmp_val = ((ushort_t)val) << 8;
+       ch->reload_value &= 0x00ff;
+       ch->reload_value |= tmp_val;
+       
+
+       if ((ch->op_mode != RATE_GEN) || (ch->run_state != RUNNING)){
+         ch->run_state = PENDING;  
+       }
+       
+       if (ch->access_mode == LOBYTE_HIBYTE) {
+         ch->access_state = WAITING_LOBYTE;
+       }
+
+       PrintDebug("8254 PIT: updated channel counter: %d\n", ch->reload_value);        
+       PrintDebug("8254 PIT: Channel Run State=%d\n", ch->run_state);
+       break;
+      }
+    case WAITING_LOBYTE:
+      ch->reload_value &= 0xff00;
+      ch->reload_value |= val;
+
+      if (ch->access_mode == LOBYTE_HIBYTE) {
+       ch->access_state = WAITING_HIBYTE;
+      } else if ((ch->op_mode != RATE_GEN) || (ch->run_state != RUNNING)) {
+       ch->run_state = PENDING;
+      }
+
+      PrintDebug("8254 PIT: updated channel counter: %d\n", ch->reload_value);
+      PrintDebug("8254 PIT: Channel Run State=%d\n", ch->run_state);
+      break;
+    default:
+      return -1;
+  }
 
 
-  //}
+    switch (ch->op_mode) {
+    case IRQ_ON_TERM_CNT:
+      ch->output_pin = 0;
+      break;
+    case ONE_SHOT:
+      ch->output_pin = 1;
+      break;
+    case RATE_GEN:
+      ch->output_pin = 1;
+      break;
 
 
-  return -1;
+    default:
+      return -1;
+      break;
+    }
+
+
+  return 0;
 }
 
 
   ch->op_mode = cmd.op_mode;
   ch->access_mode = cmd.access_mode;
 
+
+
+
   switch (cmd.access_mode) {
   case LATCH_COUNT:
     return -1;
   case RATE_GEN: 
     ch->output_pin = 1;
     break;
+  case SQR_WAVE:
+    ch->output_pin = 1;
+    break;
   default:
     return -1;
   }
     return -1;
   }
 
-  PrintDebug("8254 PIT: Write to PIT Channel %d\n", port - CHANNEL0_PORT);
+  PrintDebug("8254 PIT: Write to PIT Channel %d (%x)\n", port - CHANNEL0_PORT, *(char*)src);
 
 
   switch (port) {
   struct pit_cmd_word * cmd = (struct pit_cmd_word *)src;
 
   PrintDebug("8254 PIT: Write to PIT Command port\n");
-
+  PrintDebug("8254 PIT: Writing to channel %d (access_mode = %d, op_mode = %d)\n", cmd->channel, cmd->access_mode, cmd->op_mode);
   if (length != 1) {
     PrintDebug("8254 PIT: Write of Invalid length to command port\n");
     return -1;
 };
 
 
+static void init_channel(struct channel * ch) {
+  ch->run_state = NOT_RUNNING;
+  ch->access_state = NOT_WAITING;
+  ch->access_mode = 0;
+  ch->op_mode = 0;
+
+  ch->counter = 0;
+  ch->reload_value = 0;
+  ch->output_pin = 0;
+  ch->gate_input_pin = 0;
+
+  return;
+}
+
+
 static int pit_init(struct vm_device * dev) {
+  struct pit * state = (struct pit *)dev->private_data;
+  uint_t cpu_khz = V3_CPU_KHZ();
+  ullong_t reload_val = (ullong_t)cpu_khz * 1000;
+
   dev_hook_io(dev, CHANNEL0_PORT, &pit_read_channel, &pit_write_channel);
   dev_hook_io(dev, CHANNEL1_PORT, &pit_read_channel, &pit_write_channel);
   dev_hook_io(dev, CHANNEL2_PORT, &pit_read_channel, &pit_write_channel);
   dev_hook_io(dev, COMMAND_PORT, NULL, &pit_write_command);
 
+  PrintDebug("8254 PIT: OSC_HZ=%d, reload_val=", OSC_HZ);
+  PrintTraceLL(reload_val);
+  PrintDebug("\n");
+
 
   v3_add_timer(dev->vm, &timer_ops, dev);
 
   // Get cpu frequency and calculate the global pit oscilattor counter/cycle
 
+  do_divll(reload_val, OSC_HZ);
+  state->pit_counter = reload_val;
+  state->pit_reload = reload_val;
+
+
+
+  init_channel(&(state->ch_0));
+  init_channel(&(state->ch_1));
+  init_channel(&(state->ch_2));
+
+  PrintDebug("8254 PIT: CPU MHZ=%d -- pit count=", cpu_khz / 1000);
+  PrintTraceLL(state->pit_counter);
+  PrintDebug("\n");
 
   return 0;
 }