
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <vector>
#include <string>
#include <algorithm>

#include <Windows.h>
#include <Shlwapi.h>

#define DIRECTINPUT_VERSION 0x0800
#include <dinput.h>

#include <boost/version.hpp>
#include <boost/static_assert.hpp>
#include <boost/filesystem.hpp>
#include <boost/foreach.hpp>

#include "dll_hack_lib.h"
#include "../DxLibKeyConfig/Consts.h"
#include "main.h"

/// Returns the MS-DOS ("MZ") header found at the base address of a mapped module.
/// @param hmod Module whose handle value is used as the image base
///             (defaults to the executable of the current process).
/// @return Reference to the IMAGE_DOS_HEADER located at the image base.
const IMAGE_DOS_HEADER &getDosHeader(const HMODULE hmod = ::GetModuleHandle(NULL)) {
  // The handle value is treated directly as the address of the first byte of the image.
  const IMAGE_DOS_HEADER * const header = reinterpret_cast<const IMAGE_DOS_HEADER *>(hmod);
  return *header;
}

/// Returns the PE ("NT") headers of a mapped module.
/// @param hmod Module whose handle value is used as the image base
///             (defaults to the executable of the current process).
/// @return Reference to the IMAGE_NT_HEADERS located at base + e_lfanew.
const IMAGE_NT_HEADERS &getNtHeader(const HMODULE hmod = ::GetModuleHandle(NULL)) {
  const unsigned char * const imageBase = reinterpret_cast<const unsigned char *>(hmod);
  // e_lfanew in the DOS header is the offset of the NT headers from the image base.
  const LONG ntHeaderOffset = getDosHeader(hmod).e_lfanew;
  return *reinterpret_cast<const IMAGE_NT_HEADERS *>(imageBase + ntHeaderOffset);
}

/// Collects pointers to every entry of a module's section table.
/// @param out  Resized to NumberOfSections; entry i points at the i-th section header.
/// @param hmod Module to inspect (defaults to the executable of the current process).
void getSectionHeaderList(std::vector<const IMAGE_SECTION_HEADER *> &out, const HMODULE hmod = ::GetModuleHandle(NULL)) {
  const IMAGE_NT_HEADERS &nt = getNtHeader(hmod);
  // The section table starts right after the optional header, whose real size
  // is given by FileHeader.SizeOfOptionalHeader.
  const unsigned char * const optionalHeader = reinterpret_cast<const unsigned char *>(&nt.OptionalHeader);
  const IMAGE_SECTION_HEADER * const firstSection =
      reinterpret_cast<const IMAGE_SECTION_HEADER *>(optionalHeader + nt.FileHeader.SizeOfOptionalHeader);
  const unsigned int sectionCount = nt.FileHeader.NumberOfSections;
  out.resize(sectionCount);
  for (unsigned int index = 0; index < sectionCount; ++index) {
    out[index] = firstSection + index;
  }
}

/// Finds the section whose name equals @p targetName in a list of section headers.
///
/// IMAGE_SECTION_HEADER::Name is a fixed-size byte field. A name that fills
/// the whole field has no terminating NUL, so the name's length is found by
/// searching for a NUL only within that field, never past it.
///
/// @param list       Section header pointers (e.g. from getSectionHeaderList); entries must be non-NULL.
/// @param targetName Exact (case-sensitive) name to look for, e.g. ".text".
/// @return The first matching header, or NULL if none matches.
const IMAGE_SECTION_HEADER *getSection(const std::vector<const IMAGE_SECTION_HEADER *> &list, const std::string &targetName) {
  BOOST_FOREACH(const IMAGE_SECTION_HEADER *section, list) {
    const char * const nameBegin = reinterpret_cast<const char *>(section->Name);
    const char * const nameFieldEnd = nameBegin + _countof(section->Name);
    // Stops at the first NUL, or at the end of the field for full-length names.
    const char * const nameEnd = std::find(nameBegin, nameFieldEnd, '\0');
    const std::string sectionName(nameBegin, nameEnd);
    if (targetName == sectionName) {
      return section;
    }
  }
  return NULL;
}

/// Looks up a section of a module by name.
/// @param name Exact section name, e.g. ".text".
/// @param hmod Module to inspect (defaults to the executable of the current process).
/// @return The matching section header, or NULL if the module has no such section.
const IMAGE_SECTION_HEADER *getSection(const std::string &name, const HMODULE hmod = ::GetModuleHandle(NULL)) {
  std::vector<const IMAGE_SECTION_HEADER *> headers;
  getSectionHeaderList(headers, hmod);
  const IMAGE_SECTION_HEADER * const match = getSection(headers, name);
  return match;
}

/// Locates the code that initializes the game's key-configuration table.
///
/// Scans the .text section of @p hmod for a run of consecutive x86
/// `mov dword ptr [addr32], imm32` instructions (bytes C7 05 <addr32> <imm32>,
/// 10 bytes each) whose immediates are, in order, the values in `expected`.
/// Each row of `expected` holds one instruction's 4-byte immediate followed by
/// the C7 05 opcode of the next instruction. The immediates look like DirectInput
/// DIK_* scan codes (e.g. 0x2C = DIK_Z, 0x39 = DIK_SPACE); presumably this is the
/// game's default key table. TODO confirm against the target binary.
///
/// @param hmod Module to scan (defaults to the executable of the current process).
/// @return Address of the first matching `C7 05` instruction, or NULL if the
///         section is missing or no match exists.
const unsigned char *getKeyConfigInitAddr(const HMODULE hmod = ::GetModuleHandle(NULL)) {
  static const unsigned char expected[][6] = {
    {0x50, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0xD0, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x4B, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0xCB, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x4D, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0xCD, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x48, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0xC8, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x2C, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x2D, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x2E, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x1E, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x1F, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x20, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x10, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x11, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x01, 0x00, 0x00, 0x00, 0xC7, 0x05},
    {0x39, 0x00, 0x00, 0x00, 0xC7, 0x05},
  };
  const unsigned int instructionSize = 10;   // C7 05 + addr32 + imm32
  const unsigned int immediateOffset = 6;    // offset of imm32 inside an instruction
  // Number of bytes that must be readable from a candidate start: the last
  // row is compared at immediateOffset + (rows - 1) * instructionSize.
  const unsigned int patternSize =
      immediateOffset + (_countof(expected) - 1) * instructionSize + sizeof(expected[0]);

  // Use the caller's module for the section lookup too, not the default exe.
  const IMAGE_SECTION_HEADER * const section = getSection(".text", hmod);
  if (section == NULL || section->SizeOfRawData < patternSize) {
    return NULL;
  }
  const unsigned char * const begin = reinterpret_cast<const unsigned char *>(hmod) + section->VirtualAddress;
  // Last start position whose full pattern still lies inside the section.
  const unsigned char * const last = begin + (section->SizeOfRawData - patternSize);
  for (const unsigned char *ptr = begin; ptr <= last; ++ptr) {
    if (ptr[0] != 0xC7 || ptr[1] != 0x05) {
      continue;
    }
    bool matched = true;
    for (unsigned int i = 0; matched && i < _countof(expected); i++) {
      matched = (::memcmp(&ptr[immediateOffset + i * instructionSize], expected[i], sizeof(expected[i])) == 0);
    }
    if (matched) {
      return ptr;
    }
  }
  return NULL;
}

/// Returns the address written by the first instruction of the key-config
/// initialization code found by getKeyConfigInitAddr.
///
/// That instruction is `C7 05 <addr32> <imm32>`, so bytes 2..5 hold the absolute
/// destination address. Presumably this is the start of the game's int key table;
/// TODO confirm.
///
/// @param hmod Module to scan (defaults to the executable of the current process).
/// @return The decoded address, or NULL if the initialization code was not found.
int *getKeyMap(const HMODULE hmod = ::GetModuleHandle(NULL)) {
  const unsigned char * const codePtr = getKeyConfigInitAddr(hmod);
  if (codePtr == NULL) {
    return NULL;
  }
  // memcpy avoids an unaligned / type-punned load of the 4-byte operand.
  unsigned int keyMapAddr = 0;
  ::memcpy(&keyMapAddr, &codePtr[2], sizeof(keyMapAddr));
  return reinterpret_cast<int *>(keyMapAddr);
}

/// Fills a key table laid out as [16 pads][32 keys][4 slots] of int with -1.
/// @param keyMap Start of the table; must hold at least 16 * 32 * 4 ints.
void InitializeKeyMap(int * const keyMap) {
  const unsigned int padCount = 16;
  const unsigned int keysPerPad = 32;
  const unsigned int slotsPerKey = 4;
  std::fill_n(keyMap, padCount * keysPerPad * slotsPerKey, -1);
}

bool ChangeIAT(const std::string &targetDllName, const std::string &targetFuncName, const FARPROC function) {
  std::string targetDllNameLow;
  targetDllNameLow.resize(targetDllName.size());
  std::transform(targetDllName.begin(), targetDllName.end(), targetDllNameLow.begin(), ::tolower);

  const HMODULE hmod = ::GetModuleHandle(NULL);
  const IMAGE_NT_HEADERS &nt = getNtHeader(hmod);
  const IMAGE_DATA_DIRECTORY &dir = nt.OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
  const IMAGE_IMPORT_DESCRIPTOR *desc = reinterpret_cast<const IMAGE_IMPORT_DESCRIPTOR *>(reinterpret_cast<unsigned int>(hmod) + dir.VirtualAddress);
  for (; desc->OriginalFirstThunk != 0; desc++) {
    const char * const dllName = reinterpret_cast<const char *>(reinterpret_cast<unsigned int>(hmod) + desc->Name);
    const unsigned int dllNameLength = ::strlen(dllName);
    std::string dllNameLow;
    dllNameLow.resize(dllNameLength);
    std::transform(&dllName[0], &dllName[dllNameLength], dllNameLow.begin(), ::tolower);
    if (dllNameLow != targetDllNameLow) {
      continue;
    }
    const IMAGE_THUNK_DATA *thunk = reinterpret_cast<const IMAGE_THUNK_DATA *>(reinterpret_cast<unsigned int>(hmod) + desc->OriginalFirstThunk);
    const IMAGE_THUNK_DATA *outThunk = reinterpret_cast<const IMAGE_THUNK_DATA *>(reinterpret_cast<unsigned int>(hmod) + desc->FirstThunk);
    for (;thunk->u1.AddressOfData != NULL; thunk++, outThunk++) {
      if (!IMAGE_SNAP_BY_ORDINAL32(thunk->u1.Ordinal)) {
        const char * const funcName = reinterpret_cast<const char *>(reinterpret_cast<unsigned int>(hmod) + thunk->u1.AddressOfData + 2);
        if (targetFuncName == funcName) {
          return org::click3::DllHackLib::WriteCode(NULL, reinterpret_cast<unsigned int>(outThunk), reinterpret_cast<const unsigned char *>(&function), sizeof(function));
        }
      }
    }
  }
  return false;
}

// Class IDs compared against the CLSID passed to CoCreateInstance (see d_CoCreateInstance).
// The values appear to match CLSID_DirectInput / CLSID_DirectInput8 from dinput.h; they are
// presumably defined locally so dxguid.lib need not be linked. TODO confirm.
GUID MY_CLSID_DirectInput = {
  0x25E609E0, 0xB259, 0x11CF,
  {0xBF, 0xC7, 0x44, 0x45, 0x53, 0x54, 0x00, 0x00}
};
GUID MY_CLSID_DirectInput8 = {
  0x25E609E4, 0xB259, 0x11CF,
  {0xBF, 0xC7, 0x44, 0x45, 0x53, 0x54, 0x00, 0x00}
};

struct D_DIJOYSTATE {
	LONG lX;
	LONG lY;
	LONG lZ;
	LONG lRx;
	LONG lRy;
	LONG lRz;
	LONG rglSlider[2];
	DWORD rgdwPOV[4];
	BYTE rgbButtons[32];
};

IDirectInputDevice *joy[16];
unsigned int joyNum = 0;
unsigned int keyMap[4];

// Original IDirectInputDevice::GetDeviceState, saved when the vtable entry is first hooked.
HRESULT (__stdcall * p_GetDeviceState)(IDirectInputDevice *, DWORD, LPVOID) = NULL;

/// Hook for IDirectInputDevice::GetDeviceState.
///
/// Calls the original first. If that succeeds, the buffer is a D_DIJOYSTATE, and
/// @p self is the first registered joystick (joy[0]), the button array is rewritten:
/// for each mapped entry i, the state of physical button keyMap[i] is reported on
/// button INI_DEFAULT_VALUE_LIST[i], and every other button is reported released.
/// Assumes INI_DEFAULT_VALUE_LIST has at least as many entries as INI_KEY_LIST. TODO confirm in Consts.h.
///
/// @return The original call's HRESULT, unchanged.
HRESULT __stdcall d_GetDeviceState(IDirectInputDevice *self, DWORD cbData, LPVOID lpvData) {
  const HRESULT result = p_GetDeviceState(self, cbData, lpvData);
  if (result != DI_OK || cbData != sizeof(D_DIJOYSTATE) || lpvData == NULL) {
    return result;
  }
  if (joy[0] != self) {
    return result;
  }
  D_DIJOYSTATE &data = *reinterpret_cast<D_DIJOYSTATE *>(lpvData);
  const unsigned int buttonCount = _countof(data.rgbButtons);
  // Remap only as many entries as both tables actually hold.
  const unsigned int mapCount = (std::min)(_countof(INI_KEY_LIST), _countof(keyMap));

  bool pressed[_countof(INI_KEY_LIST)] = {};
  for (unsigned int i = 0; i < mapCount; i++) {
    const unsigned int source = keyMap[i];
    // DirectInput sets the high-order bit of a button byte while the button is down.
    pressed[i] = (source < buttonCount) && ((data.rgbButtons[source] & 0x80) != 0);
  }
  ::ZeroMemory(data.rgbButtons, sizeof(data.rgbButtons));
  for (unsigned int i = 0; i < mapCount; i++) {
    const unsigned int target = INI_DEFAULT_VALUE_LIST[i];
    if (target < buttonCount) {
      data.rgbButtons[target] = (pressed[i] ? 0x80 : 0);
    }
  }
  return result;
}

// Original IDirectInputDevice::SetDataFormat, saved when the vtable entry is first hooked.
HRESULT (__stdcall * p_SetDataFormat)(IDirectInputDevice *, LPCDIDATAFORMAT) = NULL;

/// Hook for IDirectInputDevice::SetDataFormat.
///
/// Records @p self in joy[] when the requested format has 44 objects, then forwards
/// to the original. 44 = 6 axes + 2 sliders + 4 POVs + 32 buttons, one object per
/// element of D_DIJOYSTATE (presumably DirectInput's predefined c_dfDIJoystick).
/// A device is recorded at most once, and recording stops when joy[] is full.
///
/// @return The original SetDataFormat result.
HRESULT __stdcall d_SetDataFormat(IDirectInputDevice *self, LPCDIDATAFORMAT lpdf) {
  if (lpdf != NULL && lpdf->dwNumObjs == 44) {
    bool alreadyKnown = false;
    for (unsigned int i = 0; i < joyNum; i++) {
      if (joy[i] == self) {
        alreadyKnown = true;
        break;
      }
    }
    // joy has fixed capacity; extra devices are left untracked rather than overflowing it.
    if (!alreadyKnown && joyNum < _countof(joy)) {
      joy[joyNum] = self;
      joyNum++;
    }
  }
  return p_SetDataFormat(self, lpdf);
}

// Original IDirectInput7::CreateDeviceEx, saved when the vtable entry is first hooked.
HRESULT (__stdcall * p_CreateDeviceEx)(IDirectInput7 *, REFGUID, REFIID, LPVOID *, LPUNKNOWN) = NULL;

/// Hook for IDirectInput7::CreateDeviceEx.
///
/// Creates the device through the original, then patches the new device's
/// SetDataFormat and GetDeviceState vtable entries with d_SetDataFormat and
/// d_GetDeviceState. The original pointers are stored only while p_* is still NULL.
/// Later calls pass NULL instead, presumably because devices share a vtable and a
/// later call would otherwise record the hook itself as the "original".
/// TODO confirm ChangeVartualProcAddress semantics.
///
/// @return The original result; DIERR_NOTINITIALIZED if hooking failed (the new
///         device is released and *pvOut is set to NULL in that case).
HRESULT __stdcall d_CreateDeviceEx(IDirectInput7 *self, REFGUID rguid, REFIID riid, LPVOID *pvOut, LPUNKNOWN pUnkOuter) {
  const HRESULT result = p_CreateDeviceEx(self, rguid, riid, pvOut, pUnkOuter);
  if (result != S_OK) {
    return result;
  }
  const bool hooked =
      org::click3::DllHackLib::ChangeVartualProcAddress(*pvOut, reinterpret_cast<void **>(p_SetDataFormat == NULL ? &p_SetDataFormat : NULL), &IDirectInputDevice::SetDataFormat, d_SetDataFormat)
      && org::click3::DllHackLib::ChangeVartualProcAddress(*pvOut, reinterpret_cast<void **>(p_GetDeviceState == NULL ? &p_GetDeviceState : NULL), &IDirectInputDevice::GetDeviceState, d_GetDeviceState);
  if (!hooked) {
    // Do not report failure while still holding a reference to the new device.
    static_cast<IUnknown *>(*pvOut)->Release();
    *pvOut = NULL;
    return DIERR_NOTINITIALIZED;
  }
  return result;
}

// Original IDirectInput8::CreateDevice, saved when the vtable entry is first hooked.
HRESULT (__stdcall * p_CreateDevice)(IDirectInput8 *, REFGUID, LPDIRECTINPUTDEVICE *, LPUNKNOWN) = NULL;

/// Hook for IDirectInput8::CreateDevice.
///
/// Creates the device through the original, then patches both SetDataFormat and
/// GetDeviceState on it, the same as d_CreateDeviceEx. Without the GetDeviceState hook,
/// devices created through DirectInput8 would be registered in joy[] but never remapped.
/// Originals are stored only while p_* is still NULL.
///
/// @return The original result; DIERR_NOTINITIALIZED if hooking failed (the new
///         device is released and *lplpDirectInputDevice is set to NULL in that case).
HRESULT __stdcall d_CreateDevice(IDirectInput8 *self, REFGUID rguid, LPDIRECTINPUTDEVICE *lplpDirectInputDevice, LPUNKNOWN pUnkOuter) {
  const HRESULT result = p_CreateDevice(self, rguid, lplpDirectInputDevice, pUnkOuter);
  if (result != S_OK) {
    return result;
  }
  const LPDIRECTINPUTDEVICE device = *lplpDirectInputDevice;
  const bool hooked =
      org::click3::DllHackLib::ChangeVartualProcAddress(device, reinterpret_cast<void **>(p_SetDataFormat == NULL ? &p_SetDataFormat : NULL), &IDirectInputDevice::SetDataFormat, d_SetDataFormat)
      && org::click3::DllHackLib::ChangeVartualProcAddress(device, reinterpret_cast<void **>(p_GetDeviceState == NULL ? &p_GetDeviceState : NULL), &IDirectInputDevice::GetDeviceState, d_GetDeviceState);
  if (!hooked) {
    // Do not report failure while still holding a reference to the new device.
    device->Release();
    *lplpDirectInputDevice = NULL;
    return DIERR_NOTINITIALIZED;
  }
  return result;
}

/// Replacement for CoCreateInstance handed out by d_GetProcAddress.
///
/// Forwards to ::CoCreateInstance. For DirectInput8 objects it hooks
/// IDirectInput8::CreateDevice; for DirectInput objects it hooks
/// IDirectInput7::CreateDeviceEx. Like the device-level hooks, the original
/// pointer is stored only while p_* is still NULL. On a repeated creation the
/// vtable already points at our hook, and saving it again would make the hook
/// call itself forever.
///
/// @return The CoCreateInstance result; E_POINTER if hooking failed (the object
///         is released and *ppv is set to NULL in that case).
HRESULT __stdcall d_CoCreateInstance(REFCLSID rclsid, LPUNKNOWN pUnkOuter, DWORD dwClsContext, REFIID riid, LPVOID *ppv) {
  const HRESULT result = ::CoCreateInstance(rclsid, pUnkOuter, dwClsContext, riid, ppv);
  if (result != S_OK) {
    return result;
  }
  bool hooked = true;
  if (rclsid == MY_CLSID_DirectInput8) {
    hooked = org::click3::DllHackLib::ChangeVartualProcAddress(*ppv, reinterpret_cast<void **>(p_CreateDevice == NULL ? &p_CreateDevice : NULL), &IDirectInput8::CreateDevice, d_CreateDevice);
  } else if (rclsid == MY_CLSID_DirectInput) {
    hooked = org::click3::DllHackLib::ChangeVartualProcAddress(*ppv, reinterpret_cast<void **>(p_CreateDeviceEx == NULL ? &p_CreateDeviceEx : NULL), &IDirectInput7::CreateDeviceEx, d_CreateDeviceEx);
  }
  if (!hooked) {
    // Do not report failure while still holding a reference to the new object.
    static_cast<IUnknown *>(*ppv)->Release();
    *ppv = NULL;
    return E_POINTER;
  }
  return result;
}

/// Replacement for GetProcAddress installed in the exe's IAT by main().
///
/// Forwards to ::GetProcAddress. When ole32.dll's "CoCreateInstance" is requested
/// by name, it returns d_CoCreateInstance instead.
///
/// @param hmod Module to search.
/// @param name Export name, or an ordinal in the low-order word (high word zero),
///             as GetProcAddress allows; ordinals are never dereferenced.
/// @return The hook address for ole32!CoCreateInstance, otherwise the real result.
FARPROC __stdcall d_GetProcAddress(const HMODULE hmod, const char * const name) {
  static HMODULE ole32 = NULL;
  if (ole32 == NULL) {
    ole32 = ::LoadLibraryW(L"ole32.dll");
  }
  const FARPROC result = ::GetProcAddress(hmod, name);
  // ole32 != NULL: if loading failed, a NULL hmod must not be mistaken for ole32.
  // IS_INTRESOURCE: an ordinal (or NULL) is not a string and must not be read as one.
  if (ole32 != NULL && hmod == ole32 && !IS_INTRESOURCE(name) && ::strcmp(name, "CoCreateInstance") == 0) {
    return reinterpret_cast<FARPROC>(d_CoCreateInstance);
  }
  return result;
}

const char *GetIniFullPath() {
  static char result[MAX_PATH] = "";
  if (result[0] == '\0') {
    ::GetCurrentDirectoryA(_countof(result), result);
    ::PathAppendA(result, INI_FILENAME);
  }
  return result;
}

unsigned int GetIniInt(const char * const key, const unsigned int defaultValue) {
  char defaultString[32];
  ::sprintf(defaultString, "%d", defaultValue);
  char result[32];
  ::GetPrivateProfileString(INI_SECTION_NAME, key, defaultString, result, _countof(result), GetIniFullPath());
  return ::atoi(result);
}

void ReadKeyMap() {
  for (unsigned int i = 0; i < _countof(keyMap); i++) {
    keyMap[i] = (std::min)(GetIniInt(INI_KEY_LIST[i], INI_DEFAULT_VALUE_LIST[i]), static_cast<unsigned int>(31));
  }
}

void main() {
  ReadKeyMap();
  if (!ChangeIAT("Kernel32.dll", "GetProcAddress", reinterpret_cast<FARPROC>(d_GetProcAddress))) {
    return;
  }
}
