#include "gpuCacheVramQuery.h"
#include "gpuCacheGLFT.h"
#include <maya/MGlobal.h>
#include <maya/MHardwareRenderer.h>
#include <maya/MString.h>
#include <maya/MStringArray.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <cassert>
#if defined(_WIN32)
    #define INITGUID
    #include <windows.h>
    #include <oleauto.h>
    #include <initguid.h>
    #include <wbemidl.h>
    #include <dxgi.h>
#elif defined(__APPLE__) || defined(__MACH__)
    #include <ApplicationServices/ApplicationServices.h>
#else
    #include <iostream>
    #include <fstream>
#endif
#include <maya/MGL.h>
using namespace GPUCache;
namespace {
#if defined(_WIN32)
    // Signature of ole32!CoSetProxyBlanket, resolved at run time through
    // GetProcAddress. The function returns an HRESULT.
    typedef HRESULT (WINAPI *PfnCoSetProxyBlanket)(
        IUnknown* pProxy, DWORD dwAuthnSvc, DWORD dwAuthzSvc,
        OLECHAR* pServerPrincName, DWORD dwAuthnLevel, DWORD dwImpLevel,
        RPC_AUTH_IDENTITY_HANDLE pAuthInfo, DWORD dwCapabilities);

    // Scoped COM initialization for the calling thread.
    class CoInitializeHelper
    {
    public:
        CoInitializeHelper() {
            fResult = CoInitialize(0);
        }
        ~CoInitializeHelper() {
            // Only a successful CoInitialize (S_OK or S_FALSE) may be balanced
            // by CoUninitialize; a failed call (e.g. RPC_E_CHANGED_MODE) must
            // not be, or COM would be torn down underneath its real owner.
            if (SUCCEEDED(fResult)) {
                CoUninitialize();
            }
        }
        // True when COM is usable on this thread.
        operator bool() const {
            return (SUCCEEDED(fResult));
        }
    private:
        // Non-copyable (declared, never defined): a copy would uninitialize twice.
        CoInitializeHelper(const CoInitializeHelper&);
        CoInitializeHelper& operator=(const CoInitializeHelper&);

        HRESULT fResult;
    };

    // Keeps a DLL loaded for the lifetime of the object. Converts to the
    // module handle, which is NULL when LoadLibraryW failed.
    class Win32LibraryHelper
    {
    public:
        explicit Win32LibraryHelper(const wchar_t* library) {
            fModule = LoadLibraryW(library);
        }
        ~Win32LibraryHelper() {
            if (fModule) {
                FreeLibrary(fModule);
            }
        }
        operator HINSTANCE() const {
            return fModule;
        }
    private:
        // Non-copyable: a copy would call FreeLibrary twice.
        Win32LibraryHelper(const Win32LibraryHelper&);
        Win32LibraryHelper& operator=(const Win32LibraryHelper&);

        HINSTANCE fModule;
    };

    // Factory interface: creates a COM object and returns it (NULL on
    // failure). The caller receives ownership of one reference.
    template<class CoObjectType>
    class CoObjectCreator
    {
    public:
        virtual ~CoObjectCreator() {}
        virtual CoObjectType* operator() () const = 0;
    };

    // Owns one reference to a COM object and Release()s it on destruction.
    template<class CoObjectType>
    class CoObjectHelper
    {
    public:
        explicit CoObjectHelper(const CoObjectCreator<CoObjectType>& creator) {
            fObject = creator();
        }
        // Adopts `object` (which may be NULL) without calling AddRef().
        explicit CoObjectHelper(CoObjectType* object) {
            fObject = object;
        }
        ~CoObjectHelper() {
            if (fObject) {
                fObject->Release();
            }
        }
        CoObjectType* operator-> () const {
            return fObject;
        }
        operator bool() const {
            return (fObject != NULL);
        }
        operator CoObjectType*() const {
            return fObject;
        }
    private:
        // Non-copyable: a copy would Release() the same reference twice.
        CoObjectHelper(const CoObjectHelper&);
        CoObjectHelper& operator=(const CoObjectHelper&);

        CoObjectType* fObject;
    };

    // Owns a BSTR allocated from a wide string and frees it on destruction.
    class CoStringHelper
    {
    public:
        explicit CoStringHelper(const wchar_t* str) {
            fString = SysAllocString(str);
        }
        ~CoStringHelper() {
            if (fString) {
                SysFreeString(fString);
            }
        }
        operator BSTR() const {
            return fString;
        }
    private:
        // Non-copyable: a copy would free the same BSTR twice.
        CoStringHelper(const CoStringHelper&);
        CoStringHelper& operator=(const CoStringHelper&);

        BSTR fString;
    };

    // Creates the in-process WMI locator.
    class WbemLocatorHelper : public CoObjectCreator<IWbemLocator>
    {
    public:
        virtual IWbemLocator* operator() () const {
            IWbemLocator* wbemLocator = NULL;
            HRESULT hres = CoCreateInstance(CLSID_WbemLocator, NULL, CLSCTX_INPROC_SERVER,
                IID_IWbemLocator, (LPVOID*)&wbemLocator);
            return SUCCEEDED(hres) ? wbemLocator : NULL;
        }
    };

    // Connects the given locator to the local root\cimv2 namespace.
    class WbemServicesHelper : public CoObjectCreator<IWbemServices>
    {
    public:
        explicit WbemServicesHelper(IWbemLocator* wbemLocator) : fWbemLocator(wbemLocator) {}
        virtual IWbemServices* operator() () const {
            CoStringHelper ns(L"\\\\.\\root\\cimv2");
            IWbemServices* wbemServices = NULL;
            HRESULT hres = fWbemLocator->ConnectServer(ns, NULL, NULL, 0L, 0L, NULL, NULL, &wbemServices);
            return SUCCEEDED(hres) ? wbemServices : NULL;
        }
    private:
        IWbemLocator* fWbemLocator;
    };

    // Enumerates the instances of the Win32_VideoController class.
    class EnumVideoCtrlHelper : public CoObjectCreator<IEnumWbemClassObject>
    {
    public:
        explicit EnumVideoCtrlHelper(IWbemServices* wbemServices) : fWbemServices(wbemServices) {}
        virtual IEnumWbemClassObject* operator() () const {
            CoStringHelper className(L"Win32_VideoController");
            IEnumWbemClassObject* enumVideoControllers = NULL;
            HRESULT hres = fWbemServices->CreateInstanceEnum(className, 0, NULL, &enumVideoControllers);
            return SUCCEEDED(hres) ? enumVideoControllers : NULL;
        }
    private:
        IWbemServices* fWbemServices;
    };

    typedef HRESULT ( WINAPI* LPCREATEDXGIFACTORY )( REFIID, void** );

    // Creates an IDXGIFactory via CreateDXGIFactory looked up in the given
    // (already loaded) dxgi.dll module.
    class DXGIFactoryHelper : public CoObjectCreator<IDXGIFactory>
    {
    public:
        explicit DXGIFactoryHelper(HINSTANCE dxgiModule) : fDXGIModule(dxgiModule) {}
        virtual IDXGIFactory* operator() () const {
            LPCREATEDXGIFACTORY createDXGIFactory = (LPCREATEDXGIFACTORY)
                GetProcAddress(fDXGIModule, "CreateDXGIFactory");
            if (!createDXGIFactory) {
                return NULL;
            }
            IDXGIFactory* dxgiFactory = NULL;
            HRESULT hres = createDXGIFactory(__uuidof(IDXGIFactory), (LPVOID*)&dxgiFactory);
            // Do not trust the out-pointer when the call reports failure.
            return SUCCEEDED(hres) ? dxgiFactory : NULL;
        }
    private:
        HINSTANCE fDXGIModule;
    };

    // Returns the adapter at `index`, or NULL once the index is past the last
    // adapter (or on any other EnumAdapters failure).
    class DXGIAdapterHelper : public CoObjectCreator<IDXGIAdapter>
    {
    public:
        DXGIAdapterHelper(IDXGIFactory* dxgiFactory, UINT index)
            : fDXGIFactory(dxgiFactory), fIndex(index) {}
        virtual IDXGIAdapter* operator() () const {
            IDXGIAdapter* dxgiAdapter = NULL;
            HRESULT result = fDXGIFactory->EnumAdapters(fIndex, &dxgiAdapter);
            if (FAILED(result)) {
                return NULL;
            }
            return dxgiAdapter;
        }
    private:
        IDXGIFactory* fDXGIFactory;
        UINT          fIndex;
    };
#endif
}
namespace GPUCache {
MUint64 VramQuery::queryVram()
{
    // Video memory size (bytes) cached by the VramQuery singleton.
    const VramQuery& instance = VramQuery::getInstance();
    return instance.fVram;
}
bool VramQuery::isGeforce()
{
    
    return VramQuery::getInstance().fIsGeforce;
}
bool VramQuery::isQuadro()
{
    
    return VramQuery::getInstance().fIsQuadro;
}
const MString& VramQuery::manufacturer()
{
    // Graphics vendor string recorded when the singleton was constructed.
    return VramQuery::getInstance().fManufacturer;
}
const MString& VramQuery::model()
{
    // Graphics adapter model string recorded when the singleton was constructed.
    return VramQuery::getInstance().fModel;
}
void VramQuery::driverVersion(int version[3])
{
    
    const VramQuery& query = VramQuery::getInstance();
    version[0] = query.fDriverVersion[0];
    version[1] = query.fDriverVersion[1];
    version[2] = query.fDriverVersion[2];
    
}
#if defined(_WIN32)
void VramQuery::queryVramAndDriverWMI(MUint64& vram, 
int driverVersion[3], 
MString& manufacturer, 
MString& model)
 
{
    vram = driverVersion[0] = driverVersion[1] = driverVersion[2] = 0;
    
    CoInitializeHelper coInit;
    if (!coInit) {
        return;
    }
    
    WbemLocatorHelper wbemLocatorCreator;
    CoObjectHelper<IWbemLocator> wbemLocator(wbemLocatorCreator);
    if (!wbemLocator) {
        return;
    }
    
    WbemServicesHelper wbemServiceCreator(wbemLocator);
    CoObjectHelper<IWbemServices> wbemServices(wbemServiceCreator);
    if (!wbemServices) {
        return;
    }
    
    Win32LibraryHelper ole32Library(L"ole32.dll");
    if (ole32Library) {
        PfnCoSetProxyBlanket pfnCoSetProxyBlanket = 
            (PfnCoSetProxyBlanket)GetProcAddress(ole32Library, "CoSetProxyBlanket");
        if (pfnCoSetProxyBlanket) {
            pfnCoSetProxyBlanket(wbemServices, RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE,
                NULL, RPC_C_AUTHN_LEVEL_CALL, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, 0);
        }
    }
    
    EnumVideoCtrlHelper enumVideoCtrlCreator(wbemServices);
    CoObjectHelper<IEnumWbemClassObject> enumVideoCtrls(enumVideoCtrlCreator);
    if (!enumVideoCtrls) {
        return;
    }
    
    IWbemClassObject* videoCtrls[10] = {0};
    DWORD returned = 0;
    enumVideoCtrls->Reset();
    HRESULT hres = enumVideoCtrls->Next(5000, 10, videoCtrls, &returned);
    if (FAILED(hres) || returned == 0) {
        return;
    }
    
    VARIANT var;
    VariantClear(&var);
    CoStringHelper vramPropName(L"AdapterRAM");
    CoStringHelper compatPropName(L"AdapterCompatibility");
    CoStringHelper driverVersionPropName(L"DriverVersion");
    CoStringHelper modelPropName(L"Name");
    MUint64 maxVidMem = 0;
    for (UINT ctrlIndex = 0; ctrlIndex < returned; ctrlIndex++) {
        CoObjectHelper<IWbemClassObject> videoCtrl(videoCtrls[ctrlIndex]);
        hres = videoCtrl->Get(vramPropName, 0L, &var, NULL, NULL);
        if (SUCCEEDED(hres)) {
            if (var.ulVal > maxVidMem) {
                maxVidMem = var.ulVal;
                VariantClear(&var);
                videoCtrl->Get(compatPropName, 0L, &var, NULL, NULL);
                manufacturer = var.bstrVal;
                VariantClear(&var);
                videoCtrl->Get(driverVersionPropName, 0L, &var, NULL, NULL);
                driverVersionStr = var.bstrVal;
                VariantClear(&var);
                videoCtrl->Get(modelPropName, 0L, &var, NULL, NULL);
                model = var.bstrVal;
            }
        }
        VariantClear(&var);
    }
    vram = maxVidMem;
    if (manufacturer == "NVIDIA"
        || manufacturer == "NVIDIA " 
        ) {
        
        driverVersionStr.
split(
'.', versions);
            unsigned int numChars2 = versions[2].numChars();
            unsigned int numChars3 = versions[3].numChars();
            if (numChars2 >= 1 && numChars3 >= 2) {
                while (numChars3 < 4) {
                    
                    
                    versions[3] = 
MString(
"0") + versions[3];
                    numChars3++;
                }
                MString major1 = versions[2].substringW(numChars2-1, numChars2-1);
 
                MString major2 = versions[3].substringW(0, 1);
 
                MString minor = versions[3].substringW(2, numChars3-1);
 
                }
            }
        }
    }
    else if (manufacturer == "ATI Technologies Inc." || 
        manufacturer == "Advanced Micro Devices, Inc.") {
        
        int version[4], ret;
        ret = sscanf_s(driverVersionStr.
asChar(), 
"%d.%d.%d.%d",
                &version[0], &version[1], &version[2], &version[3]);
        if (ret == 4) {
            driverVersion[0] = version[0];
            driverVersion[1] = version[1];
        }
    }
}
MUint64 VramQuery::queryVramDXGI()
{
    // Largest DedicatedVideoMemory (bytes) across all DXGI adapters; 0 when
    // COM, dxgi.dll or the DXGI factory is unavailable.
    MUint64 largestDedicated = 0;

    CoInitializeHelper comScope;
    if (comScope) {
        Win32LibraryHelper dxgiDll(L"dxgi.dll");
        if (dxgiDll) {
            DXGIFactoryHelper factoryCreator(dxgiDll);
            CoObjectHelper<IDXGIFactory> factory(factoryCreator);
            if (factory) {
                // EnumAdapters fails (helper yields NULL) past the last adapter.
                UINT adapterIndex = 0;
                while (true) {
                    DXGIAdapterHelper adapterCreator(factory, adapterIndex);
                    ++adapterIndex;
                    CoObjectHelper<IDXGIAdapter> adapter(adapterCreator);
                    if (!adapter) {
                        break;
                    }
                    DXGI_ADAPTER_DESC desc;
                    ZeroMemory(&desc, sizeof(DXGI_ADAPTER_DESC));
                    if (SUCCEEDED(adapter->GetDesc(&desc)) &&
                        desc.DedicatedVideoMemory > largestDedicated) {
                        largestDedicated = desc.DedicatedVideoMemory;
                    }
                }
            }
        }
    }
    return largestDedicated;
}
#elif defined(__APPLE__) || defined(__MACH__)
void VramQuery::queryVramAndDriverMAC(MUint64& vram, 
int driverVersion[3], 
MString& manufacturer, 
MString& model)
 
{
    vram = 0;
    driverVersion[0] = driverVersion[1] = driverVersion[2] = 0;
    CGError res = CGDisplayNoErr;
    
    CGDisplayCount dspCount = 0;
    res = CGGetActiveDisplayList(0, NULL, &dspCount);
    if (res || dspCount == 0) {
        return;
    }
    
    CGDirectDisplayID* displays = (CGDirectDisplayID*)calloc((size_t)dspCount, sizeof(CGDirectDisplayID));
    res = CGGetActiveDisplayList(dspCount, displays, &dspCount);
    if (res || dspCount == 0) {
        return;
    }
    SInt64 maxVramTotal = 0;
    for (int i = 0; i < dspCount; i++) {
        
        io_service_t dspPort = CGDisplayIOServicePort(displays[i]);
        
        
        SInt64 vramScale = 1;
        CFTypeRef typeCode = IORegistryEntrySearchCFProperty(dspPort,
            kIOServicePlane,
            CFSTR("VRAM,totalsize"),
            kCFAllocatorDefault,
            kIORegistryIterateRecursively | kIORegistryIterateParents);
        if (!typeCode) {
            
            typeCode = IORegistryEntrySearchCFProperty(dspPort,
                kIOServicePlane,
                CFSTR("VRAM,totalMB"),
                kCFAllocatorDefault,
                kIORegistryIterateRecursively | kIORegistryIterateParents);
            if (typeCode) {
                vramScale = 1024 * 1024;
            }
        }
        
        if (typeCode) {
            SInt64 vramTotal = 0;
            if (CFGetTypeID(typeCode) == CFNumberGetTypeID()) {
                
                CFNumberGetValue((const __CFNumber*)typeCode, kCFNumberSInt64Type, &vramTotal);
            }
            else if (CFGetTypeID(typeCode) == CFDataGetTypeID()) {
                
                CFIndex      length = CFDataGetLength((const __CFData*)typeCode);
                const UInt8* data   = CFDataGetBytePtr((const __CFData*)typeCode);
                if (length == 4) {
                    vramTotal = *(const unsigned int*)data;
                }
                else if (length == 8) {
                    vramTotal = *(const SInt64*)data;
                }
            }
            vramTotal *= vramScale;
            CFRelease(typeCode);
            
            if (vramTotal > maxVramTotal) {
                maxVramTotal = vramTotal;
                typeCode = IORegistryEntrySearchCFProperty(dspPort,
                            kIOServicePlane,
                            CFSTR("NVDA,Features"),
                            kCFAllocatorDefault,
                            kIORegistryIterateRecursively | kIORegistryIterateParents);
                if (typeCode) {
                    CFRelease(typeCode);
                }
                typeCode = IORegistryEntrySearchCFProperty(dspPort,
                            kIOServicePlane,
                            CFSTR("ATY,Copyright"),
                            kCFAllocatorDefault,
                            kIORegistryIterateRecursively | kIORegistryIterateParents);
                if (typeCode) {
                    manufacturer = 
MString(
"Advanced Micro Devices, Inc.");
                    CFRelease(typeCode);
                }
                
                typeCode = IORegistryEntrySearchCFProperty(dspPort,
                            kIOServicePlane,
                            CFSTR("model"),
                            kCFAllocatorDefault,
                            kIORegistryIterateRecursively | kIORegistryIterateParents);
                if (typeCode) {
                    if (CFGetTypeID(typeCode) == CFDataGetTypeID()) {
                        model = 
MString((
const char*)CFDataGetBytePtr((
const __CFData*)typeCode));
                    }
                    CFRelease(typeCode);
                }
            }
        }
    }
    vram = (MUint64)maxVramTotal;
    
    
    
    const char* glVersion = (const char*)gGLFT->glGetString(MGL_VERSION);
    if (glVersion) {
        const char* implVersion = strstr(glVersion, "-");
        if (implVersion) {
            int version[3], ret;
            ret = sscanf(implVersion+1, "%d.%d.%d",
                &version[0], &version[1], &version[2]);
            if (ret == 3) {
                driverVersion[0] = version[0];
                driverVersion[1] = version[1];
                driverVersion[2] = version[2];
            }
        }
    }
}
#else
// Scans /var/log/Xorg.0.log for the GPU model, vendor, the largest reported
// video memory size and the driver version. All outputs stay 0 / unchanged
// when the log cannot be opened or the relevant lines are absent.
void VramQuery::queryVramAndDriverXORG(MUint64& vram,
                                       int driverVersion[3],
                                       MString& manufacturer,
                                       MString& model)
{
    vram = 0;
    driverVersion[0] = driverVersion[1] = driverVersion[2] = 0;

    std::ifstream xorgLog("/var/log/Xorg.0.log");
    if (!xorgLog.is_open()) {
        return;
    }

    const size_t npos = std::string::npos;
    int maxVidMemKb = 0;
    int version[3] = {0, 0, 0};
    int versionSize = 0;
    std::string line;
    while (std::getline(xorgLog, line)) {
        // GPU model from the PCI probe line (it mentions both "Mem @" and "I/O @").
        const size_t initPos = line.find("(--) PCI:");
        if (initPos != npos &&
            line.find("Mem @") != npos &&
            line.find("I/O @") != npos) {
            size_t start = line.find("[", initPos);
            if (start != npos) {
                // Model name enclosed in square brackets.
                size_t end = line.find("]", start);
                model = MString(line.substr(start + 1, end - start - 1).c_str());
            }
            else {
                start = line.find("ATI Technologies Inc", initPos);
                if (start != npos) {
                    // Model follows the 20-char vendor name and ends at " (" or ", Mem @".
                    size_t end0 = line.find(" (", start);
                    size_t end1 = line.find(", Mem @", start);
                    size_t end = (end0 != npos && end0 < end1) ? end0 : end1;
                    model = MString(line.substr(start + 20, end - start - 20).c_str());
                    manufacturer = MString("Advanced Micro Devices, Inc.");
                }
            }
        }

        // NVIDIA driver lines: "NVIDIA GPU <model>" ending at " (" or " at".
        const size_t gpuPos = line.find("NVIDIA GPU ");
        if (gpuPos != npos && line.find("NVIDIA(") != npos) {
            const size_t nameStart = gpuPos + 11;  // length of "NVIDIA GPU "
            size_t end0 = line.find(" (", nameStart);
            size_t end1 = line.find(" at", nameStart);
            size_t end = (end0 != npos && end0 < end1) ? end0 : end1;
            model = MString(line.substr(nameStart, end - nameStart).c_str());
        }

        // Video memory in kB: "Memory: N kBytes" on NVIDIA lines, otherwise
        // "Video RAM: N kByte...". Label lengths are only added to successful
        // find() results; npos + k would wrap to a small, valid-looking index.
        size_t startOffset = npos;
        size_t endOffset   = npos;
        if (line.find("NVIDIA") != npos) {
            const size_t label = line.find("Memory:");
            if (label != npos) {
                startOffset = label + 7;
                endOffset   = line.find("kBytes", startOffset);
            }
        }
        if (startOffset == npos || endOffset == npos) {
            const size_t label = line.find("Video RAM:");
            startOffset = (label != npos) ? label + 10 : npos;
            endOffset   = (label != npos) ? line.find("kByte", startOffset) : npos;
        }
        if (startOffset != npos && endOffset != npos) {
            std::string strVidMem = line.substr(startOffset, endOffset - startOffset);
            int vidMemKb = atoi(strVidMem.c_str());
            if (vidMemKb > maxVidMemKb) {
                maxVidMemKb = vidMemKb;
            }
        }

        // Driver version, identified by the driver's banner text.
        const char* driver = NULL;
        if ((driver = strstr(line.c_str(), "NVIDIA dlloader X Driver"))) {
            versionSize = sscanf(driver+24, "%d.%d.%d",
                &version[0], &version[1], &version[2]);
        }
        else if ((driver = strstr(line.c_str(), "ATI Proprietary Linux Driver Release Identifier:"))) {
            versionSize = sscanf(driver+48, "%d.%d",
                &version[0], &version[1]);
            manufacturer = MString("Advanced Micro Devices, Inc.");
        }
    }

    vram = MUint64(maxVidMemKb) * 1024;
    // versionSize is the sscanf count (possibly EOF/-1, which copies nothing).
    for (int i = 0; i < versionSize; i++) {
        driverVersion[i] = version[i];
    }
}
#endif
MUint64 VramQuery::queryVramOGL()
{
    // Fallback VRAM query through vendor OpenGL extensions. The reported value
    // is scaled by 1024 (treated as kB). For ATI this is the free VBO memory.
    // Returns 0 when no function table or supported extension is available.
    if (!gGLFT) {
        return 0;
    }
    if (gGLFT->extensionExists(kMGLext_NVX_gpu_memory_info)) {
        MGLint dedicatedKb = 0;
        gGLFT->glGetIntegerv(MGL_GPU_MEMORY_INFO_DEDICATED_VIDMEM_NVX, &dedicatedKb);
        return MUint64(dedicatedKb) * 1024;
    }
    if (gGLFT->extensionExists(kMGLext_ATI_meminfo)) {
        MGLint freeVboKb[4] = {0, 0, 0, 0};
        gGLFT->glGetIntegerv(MGL_VBO_FREE_MEMORY_ATI, freeVboKb);
        return MUint64(freeVboKb[0]) * 1024;
    }
    return 0;
}
bool VramQuery::isGeforceOGL()
{
    // True when the OpenGL renderer string contains "GeForce". Returns false
    // when the OpenGL function table is unavailable (gGLFT is null).
    if (!gGLFT) {
        return false;
    }
    const char* renderer = (const char*)gGLFT->glGetString(MGL_RENDERER);
    return renderer != NULL && strstr(renderer, "GeForce") != NULL;
}
bool VramQuery::isQuadroOGL()
{
    // True when the OpenGL renderer string contains "Quadro". Returns false
    // when the OpenGL function table is unavailable (gGLFT is null).
    if (!gGLFT) {
        return false;
    }
    const char* renderer = (const char*)gGLFT->glGetString(MGL_RENDERER);
    return renderer != NULL && strstr(renderer, "Quadro") != NULL;
}
const VramQuery& VramQuery::getInstance()
{
    // Function-local static: constructed (and the hardware queried) only on
    // the first call; later calls return the same object.
    static VramQuery sInstance;
    return sInstance;
}
VramQuery::VramQuery()
    : fVram(0),
      fIsGeforce(false),
      fIsQuadro(false)
{
    fDriverVersion[0] = fDriverVersion[1] = fDriverVersion[2] = 0;

    // NOTE(review): the GPU is queried unconditionally here; confirm whether
    // this should be skipped in some Maya states (e.g. batch mode without an
    // OpenGL context), since the *OGL() queries below go through gGLFT.
    InitializeGLFT();

    // Filled in by the platform-specific query below.
    MUint64 vram = 0;
    int     driverVersion[3] = {0, 0, 0};
    MString manufacturer;
    MString model;

#if defined(_WIN32)
    // WMI provides VRAM, driver version, vendor and model; a non-zero DXGI
    // dedicated-video-memory figure replaces the WMI VRAM value.
    const MUint64 vramDXGI = VramQuery::queryVramDXGI();
    VramQuery::queryVramAndDriverWMI(vram, driverVersion, manufacturer, model);
    if (vramDXGI != 0) {
        vram = vramDXGI;
    }
#elif defined(__APPLE__) || defined(__MACH__)
    VramQuery::queryVramAndDriverMAC(vram, driverVersion, manufacturer, model);
#else
    VramQuery::queryVramAndDriverXORG(vram, driverVersion, manufacturer, model);
#endif

    // Fall back to the OpenGL memory extensions, then to 1 GiB (2^30 bytes).
    if (vram == 0) {
        vram = VramQuery::queryVramOGL();
    }
    if (vram == 0) {
        vram = MUint64(1) << 30;
    }

    fVram             = vram;
    fDriverVersion[0] = driverVersion[0];
    fDriverVersion[1] = driverVersion[1];
    fDriverVersion[2] = driverVersion[2];
    fIsGeforce        = VramQuery::isGeforceOGL();
    fIsQuadro         = VramQuery::isQuadroOGL();
    fManufacturer     = manufacturer;
    fModel            = model;
}
}