#include <palacios/vm_guest_mem.h>
 #include <palacios/vm_guest.h>
 
+// Reference: AMD Software Developer Manual Vol.2 Ch.5 "Page Translation and Protection"
+
+static int check_large_page_ok() {
+
+    // Need to fix this....
+    return 0; 
+
+
+#if 0
+   struct v3_mem_region * base_reg = &(info->vm_info->mem_map.base_region);
+
+   /* If the guest has been configured for 2MiB pages, then we must check for hooked regions of
+     * memory which may overlap with the 2MiB page containing the faulting address (due to
+     * potentially differing access policies in place for e.g. i/o devices and APIC). A 2MiB page
+     * can be used if a) no region overlaps the page [or b) a region does overlap but fully contains
+     * the page]. The [bracketed] text pertains to the #if 0'd code below, state D. TODO modify this
+     * note if someone decides to enable this optimization. It can be tested with the SeaStar
+     * mapping.
+     *
+     * Examples: (CAPS regions are returned by v3_get_next_mem_region; state A returns the base reg)
+     *
+     *    |region| |region|                               2MiB mapped (state A)
+     *                   |reg|          |REG|             2MiB mapped (state B)
+     *   |region|     |reg|   |REG| |region|   |reg|      4KiB mapped (state C)
+     *        |reg|  |reg|   |--REGION---|                [2MiB mapped (state D)]
+     * |--------------------------------------------|     RAM
+     *                             ^                      fault addr
+     * |----|----|----|----|----|page|----|----|----|     2MB pages
+     *                           >>>>>>>>>>>>>>>>>>>>     search space
+     */
+    addr_t pg_start = 0UL, pg_end = 0UL; // 2MiB page containing the faulting address
+    struct v3_mem_region * pg_next_reg = NULL; // next immediate mem reg after page start addr
+    bool use_large_page = false;
 
+    if (region == NULL) {
+       PrintError("%s: invalid region, addr=%p\n", __FUNCTION__, (void *)fault_addr);
+       return -1;
+    }
 
-static inline int handle_passthrough_pagefault_64(struct guest_info * info, 
-                                                 addr_t fault_addr, 
-                                                 pf_error_t error_code) {
-    pml4e64_t * pml = NULL;
-    pdpe64_t * pdpe = NULL;
-    pde64_t * pde = NULL;
-    pte64_t * pte = NULL;
-    addr_t host_addr = 0;
+    // set use_large_page here
+    if (info->vm_info->paging_size == PAGING_2MB) {
+
+       // guest page maps to a host page + offset (so when we shift, it aligns with a host page)
+       pg_start = PAGE_ADDR_2MB(fault_addr);
+       pg_end = (pg_start + PAGE_SIZE_2MB);
+
+       PrintDebug("%s: page   [%p,%p) contains address\n", __FUNCTION__, (void *)pg_start, (void *)pg_end);
+
+       pg_next_reg = v3_get_next_mem_region(info->vm_info, info->cpu_id, pg_start);
 
-    int pml_index = PML4E64_INDEX(fault_addr);
+       if (pg_next_reg == NULL) {
+           PrintError("%s: Error: address not in base region, %p\n", __FUNCTION__, (void *)fault_addr);
+           return -1;
+       }
+
+       if ((pg_next_reg->base == 1) { // next region == base region
+           use_large_page = 1; // State A
+       } else {
+#if 0       // State B/C and D optimization
+           use_large_page = (pg_next_reg->guest_end >= pg_end) &&
+               ((pg_next_reg->guest_start >= pg_end) || (pg_next_reg->guest_start <= pg_start));
+           PrintDebug("%s: region [%p,%p) %s partial overlap with page\n", __FUNCTION__,
+                   (void *)pg_next_reg->guest_start, (void *)pg_next_reg->guest_end,
+                   (use_large_page ? "does not have" : "has"));
+#else       // State B/C
+           use_large_page = (pg_next_reg->guest_start >= pg_end);
+           PrintDebug("%s: region [%p,%p) %s overlap with page\n", __FUNCTION__,
+                   (void *)pg_next_reg->guest_start, (void *)pg_next_reg->guest_end,
+                   (use_large_page ? "does not have" : "has"));
+#endif
+       }
+    }
+
+    PrintDebug("%s: Address gets a 2MiB page? %s\n", __FUNCTION__, (use_large_page ? "yes" : "no"));
+#endif
+}
+
+
+static inline int handle_passthrough_pagefault_64(struct guest_info * core, addr_t fault_addr, pf_error_t error_code) {
+    pml4e64_t * pml      = NULL;
+    pdpe64_t * pdpe      = NULL;
+    pde64_t * pde        = NULL;
+    pde64_2MB_t * pde2mb = NULL;
+    pte64_t * pte        = NULL;
+    addr_t host_addr     = 0;
+
+    int pml_index  = PML4E64_INDEX(fault_addr);
     int pdpe_index = PDPE64_INDEX(fault_addr);
-    int pde_index = PDE64_INDEX(fault_addr);
-    int pte_index = PTE64_INDEX(fault_addr);
+    int pde_index  = PDE64_INDEX(fault_addr);
+    int pte_index  = PTE64_INDEX(fault_addr);
 
+    struct v3_mem_region * region =  v3_get_mem_region(core->vm_info, core->cpu_id, fault_addr);
+    int use_large_page = 0;
 
-    
 
-    struct v3_mem_region * region =  v3_get_mem_region(info->vm_info, info->cpu_id, fault_addr);
-  
-    if (region == NULL) {
-       PrintError("Invalid region in passthrough page fault 64, addr=%p\n", 
-                  (void *)fault_addr);
-       return -1;
+    /*  Check if:
+     *  1. the guest is configured to use large pages and 
+     *         2. the memory regions can be referenced by a large page
+     */
+    if ((core->use_large_pages == 1) && (check_large_page_ok() == 1)) {
+       use_large_page = 1;
     }
 
+ 
     // Lookup the correct PML address based on the PAGING MODE
-    if (info->shdw_pg_mode == SHADOW_PAGING) {
-       pml = CR3_TO_PML4E64_VA(info->ctrl_regs.cr3);
+    if (core->shdw_pg_mode == SHADOW_PAGING) {
+       pml = CR3_TO_PML4E64_VA(core->ctrl_regs.cr3);
     } else {
-       pml = CR3_TO_PML4E64_VA(info->direct_map_pt);
+       pml = CR3_TO_PML4E64_VA(core->direct_map_pt);
     }
 
     //Fix up the PML entry
         pml[pml_index].writable = 1;
         pml[pml_index].user_page = 1;
 
-       pml[pml_index].pdp_base_addr = PAGE_BASE_ADDR((addr_t)V3_PAddr(pdpe));    
+       pml[pml_index].pdp_base_addr = PAGE_BASE_ADDR_4KB((addr_t)V3_PAddr(pdpe));    
     } else {
-       pdpe = V3_VAddr((void*)BASE_TO_PAGE_ADDR(pml[pml_index].pdp_base_addr));
+       pdpe = V3_VAddr((void*)BASE_TO_PAGE_ADDR_4KB(pml[pml_index].pdp_base_addr));
     }
 
     // Fix up the PDPE entry
        pdpe[pdpe_index].writable = 1;
        pdpe[pdpe_index].user_page = 1;
 
-       pdpe[pdpe_index].pd_base_addr = PAGE_BASE_ADDR((addr_t)V3_PAddr(pde));    
+       pdpe[pdpe_index].pd_base_addr = PAGE_BASE_ADDR_4KB((addr_t)V3_PAddr(pde));    
     } else {
-       pde = V3_VAddr((void*)BASE_TO_PAGE_ADDR(pdpe[pdpe_index].pd_base_addr));
+       pde = V3_VAddr((void*)BASE_TO_PAGE_ADDR_4KB(pdpe[pdpe_index].pd_base_addr));
+    }
+
+    // Fix up the 2MiB PDE and exit here
+    if (use_large_page == 1) {
+       pde2mb = (pde64_2MB_t *)pde; // all but these two lines are the same for PTE
+       pde2mb[pde_index].large_page = 1;
+
+       if (pde2mb[pde_index].present == 0) {
+           pde2mb[pde_index].user_page = 1;
+
+           if ( (region->flags.alloced == 1) && 
+                (region->flags.read == 1)) {
+               // Full access
+               pde2mb[pde_index].present = 1;
+
+               if (region->flags.write == 1) {
+                   pde2mb[pde_index].writable = 1;
+               } else {
+                   pde2mb[pde_index].writable = 0;
+               }
+
+               if (v3_gpa_to_hpa(core, fault_addr, &host_addr) == -1) {
+                   PrintError("Error Could not translate fault addr (%p)\n", (void *)fault_addr);
+                   return -1;
+               }
+
+               pde2mb[pde_index].page_base_addr = PAGE_BASE_ADDR_2MB(host_addr);
+           } else {
+               return region->unhandled(core, fault_addr, fault_addr, region, error_code);
+           }
+       } else {
+           // We fix all permissions on the first pass, 
+           // so we only get here if its an unhandled exception
+
+           return region->unhandled(core, fault_addr, fault_addr, region, error_code);
+       }
     }
 
+    // Continue with the 4KiB page heirarchy
 
     // Fix up the PDE entry
     if (pde[pde_index].present == 0) {
        pde[pde_index].writable = 1;
        pde[pde_index].user_page = 1;
        
-       pde[pde_index].pt_base_addr = PAGE_BASE_ADDR((addr_t)V3_PAddr(pte));
+       pde[pde_index].pt_base_addr = PAGE_BASE_ADDR_4KB((addr_t)V3_PAddr(pte));
     } else {
-       pte = V3_VAddr((void*)BASE_TO_PAGE_ADDR(pde[pde_index].pt_base_addr));
+       pte = V3_VAddr((void*)BASE_TO_PAGE_ADDR_4KB(pde[pde_index].pt_base_addr));
     }
 
 
                pte[pte_index].writable = 0;
            }
 
-           if (v3_gpa_to_hpa(info, fault_addr, &host_addr) == -1) {
+           if (v3_gpa_to_hpa(core, fault_addr, &host_addr) == -1) {
                PrintError("Error Could not translate fault addr (%p)\n", (void *)fault_addr);
                return -1;
            }
 
-           pte[pte_index].page_base_addr = PAGE_BASE_ADDR(host_addr);
+           pte[pte_index].page_base_addr = PAGE_BASE_ADDR_4KB(host_addr);
        } else {
-           return region->unhandled(info, fault_addr, fault_addr, region, error_code);
+           return region->unhandled(core, fault_addr, fault_addr, region, error_code);
        }
     } else {
        // We fix all permissions on the first pass, 
        // so we only get here if its an unhandled exception
 
-       return region->unhandled(info, fault_addr, fault_addr, region, error_code);
+       return region->unhandled(core, fault_addr, fault_addr, region, error_code);
     }
 
     return 0;
 }
 
-static inline int invalidate_addr_64(struct guest_info * info, addr_t inv_addr) {
+static inline int invalidate_addr_64(struct guest_info * core, addr_t inv_addr) {
     pml4e64_t * pml = NULL;
     pdpe64_t * pdpe = NULL;
     pde64_t * pde = NULL;
 
     
     // Lookup the correct PDE address based on the PAGING MODE
-    if (info->shdw_pg_mode == SHADOW_PAGING) {
-       pml = CR3_TO_PML4E64_VA(info->ctrl_regs.cr3);
+    if (core->shdw_pg_mode == SHADOW_PAGING) {
+       pml = CR3_TO_PML4E64_VA(core->ctrl_regs.cr3);
     } else {
-       pml = CR3_TO_PML4E64_VA(info->direct_map_pt);
+       pml = CR3_TO_PML4E64_VA(core->direct_map_pt);
     }
 
     if (pml[pml_index].present == 0) {
 
     if (pdpe[pdpe_index].present == 0) {
        return 0;
-    } else if (pdpe[pdpe_index].large_page == 1) {
+    } else if (pdpe[pdpe_index].large_page == 1) { // 1GiB
        pdpe[pdpe_index].present = 0;
        return 0;
     }
 
     if (pde[pde_index].present == 0) {
        return 0;
-    } else if (pde[pde_index].large_page == 1) {
+    } else if (pde[pde_index].large_page == 1) { // 2MiB
        pde[pde_index].present = 0;
        return 0;
     }
 
     pte = V3_VAddr((void*)BASE_TO_PAGE_ADDR(pde[pde_index].pt_base_addr));
 
-    pte[pte_index].present = 0;
+    pte[pte_index].present = 0; // 4KiB
 
     return 0;
 }