663 lines
12 KiB
C++
663 lines
12 KiB
C++
#include "LuaNetwork.h"

// Using Win10 by default
#define _WIN32_WINNT 0x0A00
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib,"ws2_32.lib")
#endif

#include <atomic>
#include <chrono>
#include <climits>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <new>
#include <set>
#include <string>
#include <thread>
#include <vector>
#include "LuaCommon.h"

using namespace std;
|
||
|
||
// Raises a Lua error describing the calling thread's most recent Winsock error
// (WSAGetLastError()). The message is "<prefix>: <system text>", or just
// "<system text>" when `prefix` is empty.
//
// Does not return normally: luaL_error hands control back to Lua. The int
// return type only lets lua_CFunctions write `return LuaWSError(L, ...);`.
// NOTE(review): if Lua is built as C, luaL_error uses longjmp, so C++ objects
// alive in callers (e.g. a temporary std::string passed as `prefix`) are not
// destroyed.
int LuaWSError(lua_State* L, const string& prefix="")
{
    // Read the error code before any other call can replace it.
    const int err = WSAGetLastError();

    char msgbuf[1024] = { 0 };
    const DWORD flags = FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS;
    DWORD len = FormatMessageA(flags, NULL, err, MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT),
                               msgbuf, (DWORD)sizeof(msgbuf), NULL);
    if (len == 0)
    {
        // English text may not be installed; let Windows choose the language (dwLanguageId = 0).
        len = FormatMessageA(flags, NULL, err, 0, msgbuf, (DWORD)sizeof(msgbuf), NULL);
    }

    // System message text typically ends with "\r\n"; trim trailing whitespace
    // so the Lua error reads as a single line.
    while (len > 0 && (msgbuf[len - 1] == '\r' || msgbuf[len - 1] == '\n' || msgbuf[len - 1] == ' '))
    {
        msgbuf[--len] = '\0';
    }
    if (len == 0)
    {
        // No system text for this code: report the number rather than an empty message.
        snprintf(msgbuf, sizeof(msgbuf), "Winsock error %d", err);
    }

    if (!prefix.empty())
    {
        return luaL_error(L, "%s: %s", prefix.c_str(), msgbuf);
    }
    // Pass the system text as an argument, never as the format string: it may contain '%'.
    return luaL_error(L, "%s", msgbuf);
}
|
||
|
||
// Adapted from GSock.
//
// Resolves `HostName` with getaddrinfo (any address family, TCP stream hints)
// and stores the textual form of every IPv4/IPv6 result in `_out_IPStrVec`.
//
// Returns -1 if getaddrinfo fails (`_out_IPStrVec` is left untouched);
// otherwise returns the number of addrinfo entries received. If that count is
// larger than `_out_IPStrVec.size()`, some entries were skipped: either their
// family was neither AF_INET nor AF_INET6, or inet_ntop failed on them.
int DNSResolve(const std::string& HostName, std::vector<std::string>& _out_IPStrVec)
{
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;

    struct addrinfo* result = nullptr;
    if (getaddrinfo(HostName.c_str(), nullptr, &hints, &result) != 0)
    {
        return -1;
    }

    // Own the list so it is released on every exit path, including a
    // std::bad_alloc thrown by push_back below.
    auto release = [](struct addrinfo* list) { freeaddrinfo(list); };
    std::unique_ptr<struct addrinfo, decltype(release)> guard(result, release);

    std::vector<std::string> vec;
    int cnt = 0;
    for (struct addrinfo* node = result; node != nullptr; node = node->ai_next)
    {
        cnt++;

        // INET6_ADDRSTRLEN is large enough for either family's textual form.
        char ip_buff[INET6_ADDRSTRLEN] = { 0 };
        const char* text = nullptr;
        if (node->ai_family == AF_INET)
        {
            sockaddr_in* paddr = reinterpret_cast<sockaddr_in*>(node->ai_addr);
            text = inet_ntop(AF_INET, &paddr->sin_addr, ip_buff, sizeof(ip_buff));
        }
        else if (node->ai_family == AF_INET6)
        {
            sockaddr_in6* paddr = reinterpret_cast<sockaddr_in6*>(node->ai_addr);
            text = inet_ntop(AF_INET6, &paddr->sin6_addr, ip_buff, sizeof(ip_buff));
        }

        if (text != nullptr)
        {
            vec.push_back(text);
        }
    }

    _out_IPStrVec = std::move(vec);
    return cnt;
}
|
||
|
||
int DNSResolve(const std::string& HostName, std::string& _out_IPStr)
|
||
{
|
||
std::vector<std::string> vec;
|
||
int ret = DNSResolve(HostName, vec);
|
||
if (ret < 0)
|
||
{
|
||
return -1;
|
||
}
|
||
if (vec.empty())
|
||
{
|
||
return -2;
|
||
}
|
||
_out_IPStr = vec[0];
|
||
return 0;
|
||
}
|
||
|
||
// Fixed-size pool of worker threads, one single-task slot per worker.
//
// addTask() hands a (function pointer, void*) pair to the first idle slot and
// returns that slot's index, or -1 when every slot is busy; nothing is queued.
// The destructor asks every worker to stop, waits for tasks that are already
// running to return, and joins the threads. A task that was handed out but had
// not started when the pool is destroyed is dropped.
//
// Every field of a slot is read and written under that slot's mutex. The task
// itself runs with the mutex released, so addTask() and the destructor never
// block behind a long-running task.
class ThreadPool
{
public:
    int maxSize;

    using TaskRunner = void (*)(void*);

    // One worker's mailbox.
    class TaskData
    {
    public:
        TaskRunner runner;
        void* data;
        std::mutex mLock;
        std::condition_variable cond;

        // -1 exit requested, 0 new task pending, 1 task running, 2 idle/finished.
        // Guarded by mLock.
        int status;

        TaskData() : runner(nullptr), data(nullptr), status(2) {}
    };

    std::vector<TaskData> tasks;
    std::vector<std::thread> workers;

    // Worker loop for one slot.
    static void task_runner(TaskData& td)
    {
        std::unique_lock<std::mutex> ulk(td.mLock);
        while (true)
        {
            // The predicate makes the wait immune to spurious wakeups and to a
            // notification sent before this thread began waiting.
            td.cond.wait(ulk, [&td] { return td.status == 0 || td.status == -1; });
            if (td.status == -1)
            {
                break;
            }

            td.status = 1;
            TaskRunner runner = td.runner;
            void* data = td.data;

            ulk.unlock();
            if (runner)
            {
                runner(data);
            }
            ulk.lock();

            // Shutdown may have been requested while the task ran; do not overwrite it.
            if (td.status == 1)
            {
                td.status = 2;
            }
        }
    }

    // Starts `maxSize` worker threads (none when maxSize <= 0). If starting a
    // thread throws, the workers already started are stopped and joined before
    // the exception propagates.
    ThreadPool(int maxSize)
        : maxSize(maxSize > 0 ? maxSize : 0),
          tasks(maxSize > 0 ? static_cast<std::size_t>(maxSize) : 0)
    {
        workers.reserve(tasks.size());
        try
        {
            for (TaskData& td : tasks)
            {
                // std::ref: std::thread copies its arguments, and TaskData
                // (mutex + condition variable) is not copyable.
                workers.emplace_back(task_runner, std::ref(td));
            }
        }
        catch (...)
        {
            shutdown();
            throw;
        }
    }

    // Workers hold references into `tasks`; copying or moving the pool would break them.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    ~ThreadPool()
    {
        shutdown();
    }

    // Hands runner(data) to the first idle worker.
    // Returns the slot index used, or -1 if every worker is busy.
    int addTask(TaskRunner runner, void* data)
    {
        for (std::size_t i = 0; i < tasks.size(); i++)
        {
            TaskData& td = tasks[i];
            std::lock_guard<std::mutex> lk(td.mLock);
            if (td.status == 2)
            {
                td.runner = runner;
                td.data = data;
                td.status = 0;
                td.cond.notify_one();
                return static_cast<int>(i);
            }
        }
        return -1;
    }

private:
    // Asks every started worker to exit, then joins them all.
    void shutdown()
    {
        for (std::size_t i = 0; i < workers.size(); i++)
        {
            TaskData& td = tasks[i];
            std::lock_guard<std::mutex> lk(td.mLock);
            td.status = -1;
            td.cond.notify_one();
        }
        for (std::thread& worker : workers)
        {
            if (worker.joinable())
            {
                worker.join();
            }
        }
    }
};
|
||
|
||
|
||
// need_copy: false <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>data<74>й<EFBFBD>ϽȨ. true <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>dataû<61>й<EFBFBD>ϽȨ, <20><>Ҫ<EFBFBD><D2AA><EFBFBD><EFBFBD>.
|
||
void PushReadEvent(int fd, const char* data, int read_ret, bool need_copy = false)
|
||
{
|
||
SDL_Event e;
|
||
e.type = SDL_USEREVENT;
|
||
e.user.code = 1005;
|
||
SocketEventData* p = new SocketEventData;
|
||
p->fd = fd;
|
||
p->ret = read_ret;
|
||
if (read_ret <= 0)
|
||
{
|
||
p->errcode = WSAGetLastError();
|
||
p->data = nullptr;
|
||
if (!need_copy) delete[] data;
|
||
}
|
||
else
|
||
{
|
||
p->errcode = 0;
|
||
if (need_copy)
|
||
{
|
||
char* s = new char[read_ret + 32];
|
||
memcpy(s, data, read_ret);
|
||
p->data = s;
|
||
}
|
||
else
|
||
{
|
||
p->data = data;
|
||
}
|
||
}
|
||
SDL_PushEvent(&e);
|
||
}
|
||
|
||
void PushWriteEvent(int fd)
|
||
{
|
||
SDL_Event e;
|
||
e.type = SDL_USEREVENT;
|
||
e.user.code = 1006;
|
||
SocketEventData* p = new SocketEventData;
|
||
p->data = nullptr;
|
||
p->fd = fd;
|
||
p->ret = 0;
|
||
p->errcode = 0;
|
||
SDL_PushEvent(&e);
|
||
}
|
||
|
||
|
||
class SocketSelector
|
||
{
|
||
public:
|
||
mutex mLock;
|
||
condition_variable cond;
|
||
volatile int status;
|
||
set<int> wait_read;
|
||
set<int> wait_write;
|
||
|
||
int control_usec; // 1ms
|
||
int control_usec_low_bound; // 0.5ms
|
||
int control_usec_high_bound; // 1.5ms
|
||
|
||
SocketSelector()
|
||
{
|
||
control_usec = 1000;
|
||
control_usec_low_bound = 500;
|
||
control_usec_high_bound = 1500;
|
||
status = 1; // 0 Require stop, 1 Running
|
||
}
|
||
|
||
int run_once()
|
||
{
|
||
fd_set* pRead;
|
||
fd_set* pWrite;
|
||
fd_set fs_read;
|
||
fd_set fs_write;
|
||
timeval tm;
|
||
tm.tv_sec = 0;
|
||
tm.tv_usec = control_usec;
|
||
|
||
{
|
||
lock_guard<mutex> ulk(mLock);
|
||
|
||
if (!wait_read.empty())
|
||
{
|
||
FD_ZERO(&fs_read);
|
||
for (const int& fd : wait_read)
|
||
{
|
||
FD_SET(fd, &fs_read);
|
||
}
|
||
pRead = &fs_read;
|
||
}
|
||
if (!wait_write.empty())
|
||
{
|
||
FD_ZERO(&fs_write);
|
||
for (const int& fd : wait_write)
|
||
{
|
||
FD_SET(fd, &fs_write);
|
||
}
|
||
pWrite = &fs_write;
|
||
}
|
||
}
|
||
|
||
int ret = select(0, pRead, pWrite, nullptr, &tm); // <20><>û<EFBFBD>оͿյ<CDBF>һ<EFBFBD><D2BB>tick.
|
||
if (ret <= 0)
|
||
{
|
||
return ret;
|
||
}
|
||
|
||
{
|
||
lock_guard<mutex> ulk(mLock);
|
||
|
||
for (const int& fd : wait_read)
|
||
{
|
||
if (FD_ISSET(fd, &fs_read))
|
||
{
|
||
char* buffer = new char[1024];
|
||
int ret = recv(fd, buffer, 1024, 0);
|
||
PushReadEvent(fd, buffer, ret);
|
||
wait_read.erase(fd);
|
||
}
|
||
}
|
||
|
||
for (const int& fd : wait_write)
|
||
{
|
||
if (FD_ISSET(fd, &fs_write))
|
||
{
|
||
PushWriteEvent(fd);
|
||
wait_write.erase(fd);
|
||
}
|
||
}
|
||
}
|
||
return 1;
|
||
}
|
||
|
||
void run()
|
||
{
|
||
while (status)
|
||
{
|
||
run_once();
|
||
}
|
||
}
|
||
|
||
void add_wait_read(int fd)
|
||
{
|
||
lock_guard<mutex> lk(mLock);
|
||
wait_read.insert(fd);
|
||
}
|
||
|
||
void add_wait_write(int fd)
|
||
{
|
||
lock_guard<mutex> lk(mLock);
|
||
wait_write.insert(fd);
|
||
}
|
||
};
|
||
|
||
// Stores the C function `func` under key `name` in the table on top of the Lua
// stack. Expects a `lua_State* L` in scope where it is expanded. The
// do { } while (0) wrapper makes both calls act as one statement, so the macro
// is safe in an unbraced if/else.
#define setfn(func, name) do { lua_pushcfunction(L, (func)); lua_setfield(L, -2, (name)); } while (0)
|
||
|
||
struct ClientSocket
|
||
{
|
||
int fd;
|
||
int family;
|
||
int socketType;
|
||
|
||
static int close(lua_State* L)
|
||
{
|
||
check(L, { LUA_TUSERDATA }, { "clientsocket" });
|
||
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
|
||
if (p->fd != -1)
|
||
{
|
||
closesocket(p->fd);
|
||
p->fd = -1;
|
||
}
|
||
return 0;
|
||
}
|
||
|
||
static int connect_ipv4(lua_State* L)
|
||
{
|
||
check(L, { LUA_TUSERDATA, LUA_TSTRING, LUA_TNUMBER }, { "clientsocket" });
|
||
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
|
||
const char* addr = lua_tostring(L, 2);
|
||
int port = lua_tointeger(L, 3);
|
||
|
||
string IPStr;
|
||
if (DNSResolve(addr, IPStr) < 0)
|
||
{
|
||
return LuaWSError(L, "DNS resolve failed ("s + addr + ")");
|
||
}
|
||
|
||
struct sockaddr_in saddr;
|
||
memset(&saddr, 0, sizeof(saddr));
|
||
if (inet_pton(AF_INET, IPStr.c_str(), &(saddr.sin_addr.s_addr)) != 1)
|
||
{
|
||
return 0;
|
||
}
|
||
saddr.sin_port = htons(port);
|
||
saddr.sin_family = AF_INET;
|
||
|
||
int ret = connect(p->fd, (const sockaddr*)& saddr, sizeof(saddr));
|
||
if (ret < 0)
|
||
{
|
||
return LuaWSError(L, "Connect failed. ("s + IPStr + ")");
|
||
}
|
||
|
||
return 0;
|
||
}
|
||
|
||
static int send(lua_State* L)
|
||
{
|
||
check(L, { LUA_TUSERDATA, LUA_TSTRING }, { "clientsocket" });
|
||
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
|
||
size_t dataLen = 0;
|
||
const char* data = lua_tolstring(L, 2, &dataLen);
|
||
size_t done = 0;
|
||
while (done < dataLen)
|
||
{
|
||
int ret = ::send(p->fd, data + done, dataLen - done, 0);
|
||
if (ret < 0)
|
||
{
|
||
return LuaWSError(L, "Send failed");
|
||
}
|
||
done += ret;
|
||
}
|
||
return 0;
|
||
}
|
||
|
||
static int recv(lua_State* L)
|
||
{
|
||
check(L, { LUA_TUSERDATA }, { "clientsocket" });
|
||
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
|
||
int size;
|
||
if (lua_gettop(L) >= 2)
|
||
{
|
||
printf("%s\n", lua_typename(L, lua_type(L, 2)));
|
||
size = lua_tointeger(L, 2);
|
||
}
|
||
else
|
||
{
|
||
size = 1024;
|
||
}
|
||
if (size > 0)
|
||
{
|
||
unique_ptr<char> data(new char[size]);
|
||
int ret = ::recv(p->fd, data.get(), size, 0);
|
||
if (ret < 0)
|
||
{
|
||
return LuaWSError(L, "Recv failed");
|
||
}
|
||
else if (ret == 0)
|
||
{
|
||
lua_pushnil(L);
|
||
}
|
||
else
|
||
{
|
||
lua_pushlstring(L, data.get(), ret);
|
||
}
|
||
return 1;
|
||
}
|
||
else
|
||
{
|
||
return luaL_error(L, "Bad argument #%d. integer above 0 expected, got %s", 2, lua_tostring(L, 2));
|
||
}
|
||
}
|
||
|
||
static int create(lua_State* L)
|
||
{
|
||
int socketType = SOCK_STREAM;
|
||
if (lua_gettop(L) >= 1)
|
||
{
|
||
if (lua_type(L, 1) == LUA_TSTRING)
|
||
{
|
||
if (strcmp(lua_tostring(L, 1), "tcp") == 0)
|
||
{
|
||
socketType = SOCK_STREAM;
|
||
}
|
||
else if (strcmp(lua_tostring(L, 1), "udp") == 0)
|
||
{
|
||
socketType = SOCK_DGRAM;
|
||
}
|
||
else
|
||
{
|
||
return luaL_error(L, "Invalid socket type: %s", lua_tostring(L, 1));
|
||
}
|
||
}
|
||
lua_settop(L, 0);
|
||
}
|
||
int fd = socket(AF_INET, socketType, 0);
|
||
if (fd == -1)
|
||
{
|
||
return LuaWSError(L);
|
||
}
|
||
ClientSocket* p = (ClientSocket*)lua_newuserdata(L, sizeof(ClientSocket));
|
||
p->fd = fd;
|
||
p->family = AF_INET;
|
||
p->socketType = socketType;
|
||
if (lua_getfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt") != LUA_TTABLE)
|
||
{
|
||
lua_pop(L, 1);
|
||
lua_newtable(L);
|
||
// GC
|
||
setfn(close, "__gc");
|
||
|
||
// Fields
|
||
lua_newtable(L);
|
||
lua_pushstring(L, "clientsocket"); lua_setfield(L, -2, "type");
|
||
setfn(close, "close");
|
||
setfn(connect_ipv4, "connect");
|
||
setfn(send, "send");
|
||
setfn(recv, "recv");
|
||
|
||
// Set __index of metatable.
|
||
lua_setfield(L, -2, "__index");
|
||
|
||
lua_pushvalue(L, -1);
|
||
lua_setfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt");
|
||
}
|
||
lua_setmetatable(L, -2);
|
||
|
||
return 1;
|
||
}
|
||
};
|
||
|
||
template<typename T>
|
||
class ClassWrapper
|
||
{
|
||
public:
|
||
bool __destroyed;
|
||
const char* type;
|
||
|
||
struct Fuck
|
||
{
|
||
using TMemFn = int (T::*)(lua_State*);
|
||
TMemFn f;
|
||
};
|
||
|
||
static int destroy(lua_State* L)
|
||
{
|
||
T* t = static_cast<T*>(lua_touserdata(L, 1));
|
||
if (!t->__destroyed)
|
||
{
|
||
t->__destroyed = true;
|
||
t->~T();
|
||
}
|
||
|
||
return 0;
|
||
}
|
||
|
||
static int create(lua_State* L)
|
||
{
|
||
T* t = new (lua_newuserdata(L, sizeof(T))) T(L);
|
||
t->__destroyed = false;
|
||
|
||
char mtname[64] = { 0 };
|
||
sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code);
|
||
|
||
if (lua_getfield(L, LUA_REGISTRYINDEX, mtname) != LUA_TTABLE)
|
||
{
|
||
lua_pop(L, 1);
|
||
lua_newtable(L);
|
||
lua_pushstring(L, t->type);
|
||
lua_setfield(L, -2, "type");
|
||
lua_pushcfunction(L, destroy);
|
||
lua_setfield(L, -2, "__gc");
|
||
|
||
lua_newtable(L);
|
||
int __chk_before = lua_gettop(L);
|
||
t->prepare(L); // T::prepare(lua_State*) must keep stack balance!
|
||
int __chk_after = lua_gettop(L);
|
||
if (__chk_before != __chk_after)
|
||
{
|
||
return luaL_error(L, "Stack unbalance detected while initializing: %s", type);
|
||
}
|
||
|
||
lua_setfield(L, -2, "__index");
|
||
|
||
lua_pushvalue(L, -1);
|
||
lua_setfield(L, LUA_REGISTRYINDEX, mtname);
|
||
}
|
||
|
||
lua_setmetatable(L, -1);
|
||
return 1;
|
||
}
|
||
|
||
static int LuaCallGate(lua_State* L)
|
||
{
|
||
if (lua_type(L, 1) != LUA_TUSERDATA)
|
||
{
|
||
return luaL_error(L, "Bad argument #1. userdata expected. Got %s.", lua_typename(L, lua_type(L, 1)));
|
||
}
|
||
if (!lua_getmetatable(L, 1))
|
||
{
|
||
return luaL_error(L, "Bad argument #1. cannot get metatable.");
|
||
}
|
||
char mtname[64] = { 0 };
|
||
sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code);
|
||
lua_getfield(L, LUA_REGISTRYINDEX, mtname);
|
||
if (!lua_rawequal(L, -1, -2))
|
||
{
|
||
lua_pop(L, 2);
|
||
return luaL_error(L, "Bad argument #1. metatable mismatch.");
|
||
}
|
||
lua_pop(L, 2);
|
||
|
||
T* p = (T*)(lua_touserdata(L, 1));
|
||
|
||
Fuck* f = (Fuck*)(lua_touserdata(L, lua_upvalueindex(1)));
|
||
return std::invoke(f->f, *p, L);
|
||
}
|
||
|
||
void AddFunction(lua_State* L, typename Fuck::TMemFn callable)
|
||
{
|
||
Fuck* f = new Fuck;
|
||
f->f = callable;
|
||
lua_pushlightuserdata(L, f);
|
||
lua_pushcclosure(L, LuaCallGate, 1);
|
||
}
|
||
};
|
||
|
||
class MyClass : public ClassWrapper<MyClass>
|
||
{
|
||
public:
|
||
string welcome;
|
||
|
||
MyClass(lua_State* L)
|
||
{
|
||
printf("MyClass ctor.\n");
|
||
type = "MyClass";
|
||
welcome = lua_tostring(L, 1);
|
||
}
|
||
|
||
int showit(lua_State* L)
|
||
{
|
||
printf("In MyClass::showit, %s", welcome.c_str());
|
||
return 0;
|
||
}
|
||
|
||
~MyClass()
|
||
{
|
||
printf("MyClass dtor.\n");
|
||
}
|
||
|
||
void prepare(lua_State* L)
|
||
{
|
||
AddFunction(L, &MyClass::showit);
|
||
}
|
||
};
|
||
|
||
|
||
// Starts Winsock 2.2 and registers the global `ClientSocket` constructor in `L`.
// Returns 0 (no Lua results). Raises a Lua error if Winsock cannot be started
// or does not provide version 2.2.
// NOTE(review): no matching WSACleanup() for the success path is visible here.
int InitNetwork(lua_State* L)
{
    WSADATA wsaData;
    // WSAStartup reports failure through its nonzero return value;
    // WSAGetLastError() cannot be used when startup itself failed.
    const int ret = WSAStartup(MAKEWORD(2, 2), &wsaData);
    if (ret != 0)
    {
        return luaL_error(L, "WSAStartup failed: error %d", ret);
    }
    if (LOBYTE(wsaData.wVersion) != 2 || HIBYTE(wsaData.wVersion) != 2)
    {
        WSACleanup();
        return luaL_error(L, "WSAStartup failed: Winsock 2.2 is not available");
    }

    lua_pushcfunction(L, ClientSocket::create);
    lua_setglobal(L, "ClientSocket");
    return 0;
}
|