LuaYard/LuaNetwork.cpp

663 lines
12 KiB
C++
Raw Permalink Normal View History

2019-08-18 02:00:57 +08:00
#include "LuaNetwork.h"
// Using Win10 by default
#define _WIN32_WINNT 0x0A00
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib,"ws2_32.lib")
#endif
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <new>
#include <set>
#include <string>
#include <thread>
#include <typeinfo>
#include <utility>
#include <vector>
#include "LuaCommon.h"
using namespace std;
/// Raises a Lua error describing the calling thread's most recent Winsock failure.
///
/// The error code is read with WSAGetLastError() and translated to English text
/// by FormatMessageA. The text is always passed to luaL_error as an argument to
/// a "%s" format, never as the format itself, so a '%' in the system message
/// cannot be misinterpreted as a conversion specifier.
///
/// @param L      Lua state in which the error is raised.
/// @param prefix Optional context; when non-empty the message is "<prefix>: <text>".
/// @return Whatever luaL_error returns, so callers can write `return LuaWSError(L, ...)`.
int LuaWSError(lua_State* L, const std::string& prefix="")
{
    char msgbuf[1024] = { 0 };
    const int err = WSAGetLastError();
    DWORD len = FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT), msgbuf, static_cast<DWORD>(sizeof(msgbuf)), NULL);

    // System messages end with "\r\n"; strip it so the Lua error stays on one line.
    while (len > 0 && (msgbuf[len - 1] == '\r' || msgbuf[len - 1] == '\n' || msgbuf[len - 1] == ' '))
    {
        msgbuf[--len] = '\0';
    }

    // FormatMessageA returns 0 on failure; still report something useful.
    if (len == 0)
    {
        snprintf(msgbuf, sizeof(msgbuf), "Unknown Winsock error %d", err);
    }

    if (!prefix.empty())
    {
        return luaL_error(L, "%s: %s", prefix.c_str(), msgbuf);
    }
    return luaL_error(L, "%s", msgbuf);
}
// Adapted from GSock.
/// Resolves HostName with getaddrinfo and collects the textual form of every
/// IPv4 / IPv6 address found.
///
/// @param HostName       Host name or numeric address to look up.
/// @param _out_IPStrVec  Replaced with the resolved address strings on success;
///                       left untouched when the lookup itself fails.
/// @return -1 when getaddrinfo fails; otherwise the number of addrinfo entries
///         visited. If that count exceeds _out_IPStrVec.size(), the difference
///         is the number of entries that were skipped (family other than
///         AF_INET/AF_INET6, or inet_ntop failed).
int DNSResolve(const std::string& HostName, std::vector<std::string>& _out_IPStrVec)
{
    struct addrinfo query;
    memset(&query, 0, sizeof(query));
    query.ai_family = AF_UNSPEC;
    query.ai_socktype = SOCK_STREAM;
    query.ai_protocol = IPPROTO_TCP;

    struct addrinfo* head = nullptr;
    if (getaddrinfo(HostName.c_str(), NULL, &query, &head) != 0)
    {
        return -1;
    }

    std::vector<std::string> addresses;
    int visited = 0;
    for (struct addrinfo* node = head; node != nullptr; node = node->ai_next)
    {
        ++visited;
        if (node->ai_family == AF_INET)
        {
            sockaddr_in* v4 = (struct sockaddr_in*) (node->ai_addr);
            char text[64] = { 0 };
            const char* printed = inet_ntop(AF_INET, &(v4->sin_addr), text, 64);
            if (printed != NULL)
            {
                addresses.push_back(printed);
            }
        }
        else if (node->ai_family == AF_INET6)
        {
            sockaddr_in6* v6 = (struct sockaddr_in6*) (node->ai_addr);
            char text[128] = { 0 };
            const char* printed = inet_ntop(AF_INET6, &(v6->sin6_addr), text, 128);
            if (printed != NULL)
            {
                addresses.push_back(printed);
            }
        }
    }

    freeaddrinfo(head);
    _out_IPStrVec = std::move(addresses);
    return visited;
}
/// Convenience overload: resolves HostName and yields only the first address.
///
/// @param HostName    Host name or numeric address to look up.
/// @param _out_IPStr  Receives the first resolved address string on success.
/// @return 0 on success, -1 when the lookup failed, -2 when the lookup
///         succeeded but produced no printable address.
int DNSResolve(const std::string& HostName, std::string& _out_IPStr)
{
    std::vector<std::string> resolved;
    if (DNSResolve(HostName, resolved) < 0)
    {
        return -1;
    }
    if (resolved.empty())
    {
        return -2;
    }
    _out_IPStr = resolved.front();
    return 0;
}
/// Fixed-size pool in which every worker thread owns exactly one task slot.
///
/// addTask() drops a (runner, data) pair into the first idle slot and wakes
/// that slot's worker. There is no queue: when every slot is busy addTask()
/// returns -1 and the caller decides what to do.
///
/// Each slot's `status` is only read or written while holding that slot's
/// mutex. The worker waits with a predicate, so a notification sent before the
/// worker starts waiting is not lost and spurious wakeups do not re-run a
/// stale task. The mutex is released while the task itself runs, so addTask()
/// and the destructor never block behind a long-running task just to inspect
/// a slot.
class ThreadPool
{
public:
    int maxSize;
    using TaskRunner = void (*)(void*);

    /// One task slot, shared between the pool and exactly one worker thread.
    class TaskData
    {
    public:
        TaskRunner runner;
        void* data;
        std::mutex mLock;
        std::condition_variable cond;
        int status; // guarded by mLock: -1 exit, 0 new task, 1 started, 2 finished

        TaskData() : runner(nullptr), data(nullptr), status(2)
        {
        }
    };

    std::vector<TaskData> tasks;
    std::vector<std::thread> workers;

    /// Worker loop: sleep until the slot holds a new task or an exit request.
    static void task_runner(TaskData& td)
    {
        std::unique_lock<std::mutex> ulk(td.mLock);
        while (true)
        {
            td.cond.wait(ulk, [&td] { return td.status == 0 || td.status == -1; });
            if (td.status == -1)
            {
                break;
            }
            td.status = 1;
            const TaskRunner runner = td.runner;
            void* const data = td.data;

            // Run the task without holding the slot lock.
            ulk.unlock();
            if (runner)
            {
                runner(data);
            }
            ulk.lock();

            // The destructor may have requested exit while the task ran;
            // do not overwrite that request.
            if (td.status == 1)
            {
                td.status = 2;
            }
        }
    }

    /// Starts `maxSize` workers. A negative size is treated as 0.
    ThreadPool(int maxSize) : maxSize(maxSize > 0 ? maxSize : 0), tasks(this->maxSize)
    {
        workers.reserve(tasks.size());
        for (int i = 0; i < this->maxSize; i++)
        {
            // std::ref: the worker must operate on the slot itself, and
            // TaskData is not copyable.
            workers.emplace_back(task_runner, std::ref(tasks[i]));
        }
    }

    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    /// Asks every worker to exit and joins them. A task that is already
    /// running is allowed to finish; a task that was queued but not yet
    /// started is dropped.
    ~ThreadPool()
    {
        for (TaskData& td : tasks)
        {
            std::lock_guard<std::mutex> lk(td.mLock);
            td.status = -1;
            td.cond.notify_all();
        }
        for (std::thread& worker : workers)
        {
            if (worker.joinable())
            {
                worker.join();
            }
        }
    }

    /// Hands a task to the first idle worker.
    ///
    /// @param runner Function to run on the worker thread (may be null: the
    ///               slot is then simply cycled).
    /// @param data   Opaque pointer passed to `runner`; the pool never frees it.
    /// @return Index of the slot that accepted the task, or -1 if all are busy.
    int addTask(TaskRunner runner, void* data)
    {
        for (int i = 0; i < maxSize; i++)
        {
            TaskData& td = tasks[i];
            std::lock_guard<std::mutex> lk(td.mLock);
            if (td.status != 2)
            {
                continue;
            }
            td.runner = runner;
            td.data = data;
            td.status = 0;
            td.cond.notify_all();
            return i;
        }
        return -1;
    }
};
// need_copy: false <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>data<74>й<EFBFBD>ϽȨ. true <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>dataû<61>й<EFBFBD>ϽȨ, <20><>Ҫ<EFBFBD><D2AA><EFBFBD><EFBFBD>.
void PushReadEvent(int fd, const char* data, int read_ret, bool need_copy = false)
{
SDL_Event e;
e.type = SDL_USEREVENT;
e.user.code = 1005;
SocketEventData* p = new SocketEventData;
p->fd = fd;
p->ret = read_ret;
if (read_ret <= 0)
{
p->errcode = WSAGetLastError();
p->data = nullptr;
if (!need_copy) delete[] data;
}
else
{
p->errcode = 0;
if (need_copy)
{
char* s = new char[read_ret + 32];
memcpy(s, data, read_ret);
p->data = s;
}
else
{
p->data = data;
}
}
SDL_PushEvent(&e);
}
void PushWriteEvent(int fd)
{
SDL_Event e;
e.type = SDL_USEREVENT;
e.user.code = 1006;
SocketEventData* p = new SocketEventData;
p->data = nullptr;
p->fd = fd;
p->ret = 0;
p->errcode = 0;
SDL_PushEvent(&e);
}
/// Polls a set of sockets with select() and turns readiness into SDL events.
///
/// Sockets are registered one-shot through add_wait_read() / add_wait_write().
/// When a registered socket becomes ready it is removed from its wait set and
/// PushReadEvent() (after a recv of up to 1024 bytes) or PushWriteEvent() is
/// called for it. The wait sets are protected by mLock.
///
/// NOTE(review): `status` is a plain volatile flag read by run() without
/// synchronization; `cond` and the control_usec_*_bound fields are not used
/// inside this class.
class SocketSelector
{
public:
    std::mutex mLock;
    std::condition_variable cond;
    volatile int status;
    std::set<int> wait_read;
    std::set<int> wait_write;
    int control_usec; // 1ms
    int control_usec_low_bound; // 0.5ms
    int control_usec_high_bound; // 1.5ms

    SocketSelector()
    {
        control_usec = 1000;
        control_usec_low_bound = 500;
        control_usec_high_bound = 1500;
        status = 1; // 0 Require stop, 1 Running
    }

    /// Performs one polling tick of at most control_usec microseconds.
    ///
    /// @return 1 if at least one socket was ready and events were pushed,
    ///         0 on timeout or when nothing is registered, and the (negative)
    ///         select() result on failure.
    int run_once()
    {
        // Always start from cleared sets so FD_ISSET below never reads
        // uninitialized memory, even if a socket is registered between the
        // snapshot and the dispatch loop.
        fd_set fs_read;
        fd_set fs_write;
        FD_ZERO(&fs_read);
        FD_ZERO(&fs_write);
        fd_set* pRead = nullptr;
        fd_set* pWrite = nullptr;

        timeval tm;
        tm.tv_sec = 0;
        tm.tv_usec = control_usec;

        {
            std::lock_guard<std::mutex> lk(mLock);
            if (!wait_read.empty())
            {
                for (const int fd : wait_read)
                {
                    FD_SET(fd, &fs_read);
                }
                pRead = &fs_read;
            }
            if (!wait_write.empty())
            {
                for (const int fd : wait_write)
                {
                    FD_SET(fd, &fs_write);
                }
                pWrite = &fs_write;
            }
        }

        // Nothing registered: idle for one tick. Winsock's select() rejects a
        // call with no descriptor sets instead of sleeping, which would turn
        // run() into a busy loop.
        if (pRead == nullptr && pWrite == nullptr)
        {
            std::this_thread::sleep_for(std::chrono::microseconds(control_usec));
            return 0;
        }

        const int ready = select(0, pRead, pWrite, nullptr, &tm);
        if (ready <= 0)
        {
            return ready;
        }

        {
            std::lock_guard<std::mutex> lk(mLock);
            // Iterator-based loops: erasing inside a range-for would
            // invalidate the iterator the loop is advancing.
            for (auto it = wait_read.begin(); it != wait_read.end(); )
            {
                const int fd = *it;
                if (FD_ISSET(fd, &fs_read))
                {
                    // Ownership of the buffer passes to PushReadEvent
                    // (need_copy defaults to false).
                    char* buffer = new char[1024];
                    const int received = recv(fd, buffer, 1024, 0);
                    PushReadEvent(fd, buffer, received);
                    it = wait_read.erase(it);
                }
                else
                {
                    ++it;
                }
            }
            for (auto it = wait_write.begin(); it != wait_write.end(); )
            {
                const int fd = *it;
                if (FD_ISSET(fd, &fs_write))
                {
                    PushWriteEvent(fd);
                    it = wait_write.erase(it);
                }
                else
                {
                    ++it;
                }
            }
        }
        return 1;
    }

    /// Polls until `status` is set to 0.
    void run()
    {
        while (status)
        {
            run_once();
        }
    }

    /// Registers `fd` for a single read-readiness notification.
    void add_wait_read(int fd)
    {
        std::lock_guard<std::mutex> lk(mLock);
        wait_read.insert(fd);
    }

    /// Registers `fd` for a single write-readiness notification.
    void add_wait_write(int fd)
    {
        std::lock_guard<std::mutex> lk(mLock);
        wait_write.insert(fd);
    }
};
// Stores C function `func` under key `name` in the table at the top of the Lua
// stack. Expects a `lua_State* L` in scope. Wrapped in do/while(0) so the two
// calls behave as one statement (safe after an unbraced `if`/`else`).
#define setfn(func, name) do { lua_pushcfunction(L, (func)); lua_setfield(L, -2, (name)); } while (0)
struct ClientSocket
{
int fd;
int family;
int socketType;
static int close(lua_State* L)
{
check(L, { LUA_TUSERDATA }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
if (p->fd != -1)
{
closesocket(p->fd);
p->fd = -1;
}
return 0;
}
static int connect_ipv4(lua_State* L)
{
check(L, { LUA_TUSERDATA, LUA_TSTRING, LUA_TNUMBER }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
const char* addr = lua_tostring(L, 2);
int port = lua_tointeger(L, 3);
string IPStr;
if (DNSResolve(addr, IPStr) < 0)
{
return LuaWSError(L, "DNS resolve failed ("s + addr + ")");
}
struct sockaddr_in saddr;
memset(&saddr, 0, sizeof(saddr));
if (inet_pton(AF_INET, IPStr.c_str(), &(saddr.sin_addr.s_addr)) != 1)
{
return 0;
}
saddr.sin_port = htons(port);
saddr.sin_family = AF_INET;
int ret = connect(p->fd, (const sockaddr*)& saddr, sizeof(saddr));
if (ret < 0)
{
return LuaWSError(L, "Connect failed. ("s + IPStr + ")");
}
return 0;
}
static int send(lua_State* L)
{
check(L, { LUA_TUSERDATA, LUA_TSTRING }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
size_t dataLen = 0;
const char* data = lua_tolstring(L, 2, &dataLen);
size_t done = 0;
while (done < dataLen)
{
int ret = ::send(p->fd, data + done, dataLen - done, 0);
if (ret < 0)
{
return LuaWSError(L, "Send failed");
}
done += ret;
}
return 0;
}
static int recv(lua_State* L)
{
check(L, { LUA_TUSERDATA }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
int size;
if (lua_gettop(L) >= 2)
{
printf("%s\n", lua_typename(L, lua_type(L, 2)));
size = lua_tointeger(L, 2);
}
else
{
size = 1024;
}
if (size > 0)
{
unique_ptr<char> data(new char[size]);
int ret = ::recv(p->fd, data.get(), size, 0);
if (ret < 0)
{
return LuaWSError(L, "Recv failed");
}
else if (ret == 0)
{
lua_pushnil(L);
}
else
{
lua_pushlstring(L, data.get(), ret);
}
return 1;
}
else
{
return luaL_error(L, "Bad argument #%d. integer above 0 expected, got %s", 2, lua_tostring(L, 2));
}
}
static int create(lua_State* L)
{
int socketType = SOCK_STREAM;
if (lua_gettop(L) >= 1)
{
if (lua_type(L, 1) == LUA_TSTRING)
{
if (strcmp(lua_tostring(L, 1), "tcp") == 0)
{
socketType = SOCK_STREAM;
}
else if (strcmp(lua_tostring(L, 1), "udp") == 0)
{
socketType = SOCK_DGRAM;
}
else
{
return luaL_error(L, "Invalid socket type: %s", lua_tostring(L, 1));
}
}
lua_settop(L, 0);
}
int fd = socket(AF_INET, socketType, 0);
if (fd == -1)
{
return LuaWSError(L);
}
ClientSocket* p = (ClientSocket*)lua_newuserdata(L, sizeof(ClientSocket));
p->fd = fd;
p->family = AF_INET;
p->socketType = socketType;
if (lua_getfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt") != LUA_TTABLE)
{
lua_pop(L, 1);
lua_newtable(L);
// GC
setfn(close, "__gc");
// Fields
lua_newtable(L);
lua_pushstring(L, "clientsocket"); lua_setfield(L, -2, "type");
setfn(close, "close");
setfn(connect_ipv4, "connect");
setfn(send, "send");
setfn(recv, "recv");
// Set __index of metatable.
lua_setfield(L, -2, "__index");
lua_pushvalue(L, -1);
lua_setfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt");
}
lua_setmetatable(L, -2);
return 1;
}
};
template<typename T>
class ClassWrapper
{
public:
bool __destroyed;
const char* type;
struct Fuck
{
using TMemFn = int (T::*)(lua_State*);
TMemFn f;
};
static int destroy(lua_State* L)
{
T* t = static_cast<T*>(lua_touserdata(L, 1));
if (!t->__destroyed)
{
t->__destroyed = true;
t->~T();
}
return 0;
}
static int create(lua_State* L)
{
T* t = new (lua_newuserdata(L, sizeof(T))) T(L);
t->__destroyed = false;
char mtname[64] = { 0 };
sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code);
if (lua_getfield(L, LUA_REGISTRYINDEX, mtname) != LUA_TTABLE)
{
lua_pop(L, 1);
lua_newtable(L);
lua_pushstring(L, t->type);
lua_setfield(L, -2, "type");
lua_pushcfunction(L, destroy);
lua_setfield(L, -2, "__gc");
lua_newtable(L);
int __chk_before = lua_gettop(L);
t->prepare(L); // T::prepare(lua_State*) must keep stack balance!
int __chk_after = lua_gettop(L);
if (__chk_before != __chk_after)
{
return luaL_error(L, "Stack unbalance detected while initializing: %s", type);
}
lua_setfield(L, -2, "__index");
lua_pushvalue(L, -1);
lua_setfield(L, LUA_REGISTRYINDEX, mtname);
}
lua_setmetatable(L, -1);
return 1;
}
static int LuaCallGate(lua_State* L)
{
if (lua_type(L, 1) != LUA_TUSERDATA)
{
return luaL_error(L, "Bad argument #1. userdata expected. Got %s.", lua_typename(L, lua_type(L, 1)));
}
if (!lua_getmetatable(L, 1))
{
return luaL_error(L, "Bad argument #1. cannot get metatable.");
}
char mtname[64] = { 0 };
sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code);
lua_getfield(L, LUA_REGISTRYINDEX, mtname);
if (!lua_rawequal(L, -1, -2))
{
lua_pop(L, 2);
return luaL_error(L, "Bad argument #1. metatable mismatch.");
}
lua_pop(L, 2);
T* p = (T*)(lua_touserdata(L, 1));
Fuck* f = (Fuck*)(lua_touserdata(L, lua_upvalueindex(1)));
return std::invoke(f->f, *p, L);
}
void AddFunction(lua_State* L, typename Fuck::TMemFn callable)
{
Fuck* f = new Fuck;
f->f = callable;
lua_pushlightuserdata(L, f);
lua_pushcclosure(L, LuaCallGate, 1);
}
};
/// Minimal example of a class exported to Lua through ClassWrapper.
/// Stores a greeting taken from the first constructor argument and exposes
/// one method, `showit`, that prints it.
class MyClass : public ClassWrapper<MyClass>
{
public:
    std::string welcome;

    MyClass(lua_State* L)
    {
        printf("MyClass ctor.\n");
        type = "MyClass";
        // lua_tostring yields NULL when argument 1 is not a string/number
        // (e.g. when no argument was passed and index 1 is the new userdata);
        // std::string must not be assigned from NULL.
        const char* text = lua_tostring(L, 1);
        welcome = text ? text : "";
    }

    /// Lua method: prints the stored greeting. Returns nothing to Lua.
    int showit(lua_State* L)
    {
        printf("In MyClass::showit, %s", welcome.c_str());
        return 0;
    }

    ~MyClass()
    {
        printf("MyClass dtor.\n");
    }

    /// Fills the method table that is on top of the stack. Must leave the
    /// stack height unchanged.
    void prepare(lua_State* L)
    {
        // AddFunction pushes the closure; store it in the method table under
        // its Lua name, which also restores the stack height.
        AddFunction(L, &MyClass::showit);
        lua_setfield(L, -2, "showit");
    }
};
/// Initializes Winsock 2.2 and registers the global Lua constructor
/// `ClientSocket`.
///
/// @param L Lua state to register into.
/// @return 0 (no Lua return values). Raises a Lua error if WSAStartup fails.
int InitNetwork(lua_State* L)
{
    WSAData wdt;
    const WORD wd = MAKEWORD(2, 2);
    const int ret = WSAStartup(wd, &wdt);
    // WSAStartup reports failure by returning a non-zero (positive) error
    // code directly; WSAGetLastError must not be consulted for it.
    if (ret != 0)
    {
        return luaL_error(L, "WSAStartup failed with error code %d", ret);
    }
    lua_pushcfunction(L, ClientSocket::create);
    lua_setglobal(L, "ClientSocket");
    return 0;
}