Skip to content

Commit

Permalink
Avoid fixed-length paths in the router
Browse files Browse the repository at this point in the history
  • Loading branch information
kcat committed May 2, 2024
1 parent c05fc02 commit 7e6074c
Showing 1 changed file with 87 additions and 67 deletions.
154 changes: 87 additions & 67 deletions router/router.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <string_view>

#include "AL/alc.h"
#include "AL/al.h"

#include "almalloc.h"
#include "alstring.h"
#include "strutils.h"

#include "version.h"
Expand All @@ -20,7 +23,9 @@
enum LogLevel LogLevel = LogLevel_Error;
FILE *LogFile;

static void LoadDriverList();
namespace {
void LoadDriverList();
} // namespace


BOOL APIENTRY DllMain(HINSTANCE, DWORD reason, void*)
Expand Down Expand Up @@ -70,8 +75,9 @@ BOOL APIENTRY DllMain(HINSTANCE, DWORD reason, void*)
return TRUE;
}

namespace {

static void AddModule(HMODULE module, const WCHAR *name)
void AddModule(HMODULE module, const std::wstring_view name)
{
for(auto &drv : DriverList)
{
Expand All @@ -83,7 +89,7 @@ static void AddModule(HMODULE module, const WCHAR *name)
}
if(drv->Name == name)
{
TRACE("Skipping similarly-named module %ls\n", name);
TRACE("Skipping similarly-named module %.*ls\n", al::sizei(name), name.data());
FreeLibrary(module);
return;
}
Expand All @@ -99,7 +105,7 @@ static void AddModule(HMODULE module, const WCHAR *name)
GetProcAddress(module, #x))); \
if(!newdrv.x) \
{ \
ERR("Failed to find entry point for %s in %ls\n", #x, name); \
ERR("Failed to find entry point for %s in %.*ls\n", #x, al::sizei(name), name.data()); \
err = 1; \
} \
} while(0)
Expand Down Expand Up @@ -194,7 +200,8 @@ static void AddModule(HMODULE module, const WCHAR *name)
newdrv.ALCVer = MAKE_ALC_VER(alc_ver[0], alc_ver[1]);
else
{
WARN("Failed to query ALC version for %ls, assuming 1.0\n", name);
WARN("Failed to query ALC version for %.*ls, assuming 1.0\n", al::sizei(name),
name.data());
newdrv.ALCVer = MAKE_ALC_VER(1, 0);
}

Expand All @@ -204,7 +211,8 @@ static void AddModule(HMODULE module, const WCHAR *name)
GetProcAddress(module, #x))); \
if(!newdrv.x) \
{ \
WARN("Failed to find optional entry point for %s in %ls\n", #x, name); \
WARN("Failed to find optional entry point for %s in %.*ls\n", #x, al::sizei(name), \
name.data()); \
} \
} while(0)
LOAD_PROC(alBufferf);
Expand All @@ -226,7 +234,7 @@ static void AddModule(HMODULE module, const WCHAR *name)
newdrv.alcGetProcAddress(nullptr, #x)); \
if(!newdrv.x) \
{ \
ERR("Failed to find entry point for %s in %ls\n", #x, name); \
ERR("Failed to find entry point for %s in %.*ls\n", #x, al::sizei(name), name.data()); \
err = 1; \
} \
} while(0)
Expand All @@ -242,20 +250,19 @@ static void AddModule(HMODULE module, const WCHAR *name)
DriverList.pop_back();
return;
}
TRACE("Loaded module %p, %ls, ALC %d.%d\n", decltype(std::declval<void*>()){module}, name,
newdrv.ALCVer>>8, newdrv.ALCVer&255);
TRACE("Loaded module %p, %.*ls, ALC %d.%d\n", decltype(std::declval<void*>()){module},
al::sizei(name), name.data(), newdrv.ALCVer>>8, newdrv.ALCVer&255);
#undef LOAD_PROC
}

static void SearchDrivers(WCHAR *path)
void SearchDrivers(const std::wstring_view path)
{
WIN32_FIND_DATAW fdata;

TRACE("Searching for drivers in %ls...\n", path);
std::wstring srchPath = path;
TRACE("Searching for drivers in %.*ls...\n", al::sizei(path), path.data());
std::wstring srchPath{path};
srchPath += L"\\*oal.dll";

HANDLE srchHdl = FindFirstFileW(srchPath.c_str(), &fdata);
WIN32_FIND_DATAW fdata{};
HANDLE srchHdl{FindFirstFileW(srchPath.c_str(), &fdata)};
if(srchHdl != INVALID_HANDLE_VALUE)
{
do {
Expand All @@ -276,87 +283,100 @@ static void SearchDrivers(WCHAR *path)
}
}

static WCHAR *strrchrW(WCHAR *str, WCHAR ch)
{
WCHAR *res = nullptr;
while(str && *str != '\0')
{
if(*str == ch)
res = str;
++str;
}
return res;
}

static int GetLoadedModuleDirectory(const WCHAR *name, WCHAR *moddir, DWORD length)
bool GetLoadedModuleDirectory(const WCHAR *name, std::wstring *moddir)
{
HMODULE module = nullptr;
WCHAR *sep0, *sep1;
HMODULE module{nullptr};

if(name)
{
module = GetModuleHandleW(name);
if(!module) return 0;
}

if(GetModuleFileNameW(module, moddir, length) == 0)
return 0;

sep0 = strrchrW(moddir, '/');
if(sep0) sep1 = strrchrW(sep0+1, '\\');
else sep1 = strrchrW(moddir, '\\');

if(sep1) *sep1 = '\0';
else if(sep0) *sep0 = '\0';
else *moddir = '\0';

return 1;
moddir->assign(256, '\0');
DWORD res{GetModuleFileNameW(module, moddir->data(), static_cast<DWORD>(moddir->size()))};
if(res >= moddir->size())
{
do {
moddir->append(256, '\0');
res = GetModuleFileNameW(module, moddir->data(), static_cast<DWORD>(moddir->size()));
} while(res >= moddir->size());
}
moddir->resize(res);

auto sep0 = moddir->rfind('/');
auto sep1 = moddir->rfind('\\');
if(sep0 < moddir->size() && sep1 < moddir->size())
moddir->resize(std::max(sep0, sep1));
else if(sep0 < moddir->size())
moddir->resize(sep0);
else if(sep1 < moddir->size())
moddir->resize(sep1);
else
moddir->resize(0);

return !moddir->empty();
}

void LoadDriverList()
{
WCHAR dll_path[MAX_PATH+1] = L"";
WCHAR cwd_path[MAX_PATH+1] = L"";
WCHAR proc_path[MAX_PATH+1] = L"";
WCHAR sys_path[MAX_PATH+1] = L"";

if(GetLoadedModuleDirectory(L"OpenAL32.dll", dll_path, MAX_PATH))
TRACE("Got DLL path %ls\n", dll_path);
std::wstring dll_path;
if(GetLoadedModuleDirectory(L"OpenAL32.dll", &dll_path))
TRACE("Got DLL path %ls\n", dll_path.c_str());

GetCurrentDirectoryW(MAX_PATH, cwd_path);
auto len = wcslen(cwd_path);
if(len > 0 && (cwd_path[len-1] == '\\' || cwd_path[len-1] == '/'))
cwd_path[len-1] = '\0';
TRACE("Got current working directory %ls\n", cwd_path);
std::wstring cwd_path;
if(DWORD pathlen{GetCurrentDirectoryW(0, nullptr)})
{
do {
cwd_path.resize(pathlen);
pathlen = GetCurrentDirectoryW(pathlen, cwd_path.data());
} while(pathlen >= cwd_path.size());
cwd_path.resize(pathlen);
}
if(!cwd_path.empty() && (cwd_path.back() == '\\' || cwd_path.back() == '/'))
cwd_path.pop_back();
if(!cwd_path.empty())
TRACE("Got current working directory %ls\n", cwd_path.c_str());

if(GetLoadedModuleDirectory(nullptr, proc_path, MAX_PATH))
TRACE("Got proc path %ls\n", proc_path);
std::wstring proc_path;
if(GetLoadedModuleDirectory(nullptr, &proc_path))
TRACE("Got proc path %ls\n", proc_path.c_str());

GetSystemDirectoryW(sys_path, MAX_PATH);
len = wcslen(sys_path);
if(len > 0 && (sys_path[len-1] == '\\' || sys_path[len-1] == '/'))
sys_path[len-1] = '\0';
TRACE("Got system path %ls\n", sys_path);
std::wstring sys_path;
if(UINT pathlen{GetSystemDirectoryW(nullptr, 0)})
{
do {
sys_path.resize(pathlen);
pathlen = GetSystemDirectoryW(sys_path.data(), pathlen);
} while(pathlen >= sys_path.size());
sys_path.resize(pathlen);
}
if(!sys_path.empty() && (sys_path.back() == '\\' || sys_path.back() == '/'))
sys_path.pop_back();
if(!sys_path.empty())
TRACE("Got system path %ls\n", sys_path.c_str());

/* Don't search the DLL's path if it is the same as the current working
* directory, app's path, or system path (don't want to do duplicate
* searches, or increase the priority of the app or system path).
*/
if(dll_path[0] &&
(!cwd_path[0] || wcscmp(dll_path, cwd_path) != 0) &&
(!proc_path[0] || wcscmp(dll_path, proc_path) != 0) &&
(!sys_path[0] || wcscmp(dll_path, sys_path) != 0))
(!cwd_path[0] || dll_path != cwd_path) &&
(!proc_path[0] || dll_path != proc_path) &&
(!sys_path[0] || dll_path != sys_path))
SearchDrivers(dll_path);
if(cwd_path[0] &&
(!proc_path[0] || wcscmp(cwd_path, proc_path) != 0) &&
(!sys_path[0] || wcscmp(cwd_path, sys_path) != 0))
(!proc_path[0] || cwd_path != proc_path) &&
(!sys_path[0] || cwd_path != sys_path))
SearchDrivers(cwd_path);
if(proc_path[0] && (!sys_path[0] || wcscmp(proc_path, sys_path) != 0))
if(proc_path[0] && (!sys_path[0] || proc_path != sys_path))
SearchDrivers(proc_path);
if(sys_path[0])
SearchDrivers(sys_path);
}

} // namespace


PtrIntMap::~PtrIntMap()
{
Expand Down

0 comments on commit 7e6074c

Please sign in to comment.