// LuaYard/LuaNetwork.cpp
// Last modified: 2019-08-18 02:01:00 +08:00 (663 lines, 12 KiB, C++)
//
// Note: some comments in this file were originally written in Chinese (GBK encoding) and may
// appear as garbled or "ambiguous" Unicode characters when viewed as UTF-8.

#include "LuaNetwork.h"
// Using Win10 by default
#define _WIN32_WINNT 0x0A00
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib,"ws2_32.lib")
#endif
#include <atomic>
#include <chrono>
#include <climits>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <new>
#include <set>
#include <string>
#include <thread>
#include <typeinfo>
#include <vector>
#include "LuaCommon.h"
using namespace std;
// Raises a Lua error describing the calling thread's last Winsock error (WSAGetLastError()).
// The message is "<prefix>: <system message>" when prefix is non-empty, otherwise just the
// system message. luaL_error() does not return normally, so neither does this function.
int LuaWSError(lua_State* L, const string& prefix="")
{
    // Capture the error code before anything else can overwrite it.
    int err = WSAGetLastError();
    char msgbuf[1024] = { 0 };
    const DWORD flags = FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS;
    DWORD len = FormatMessageA(flags, NULL, err, MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT), msgbuf, (DWORD)sizeof(msgbuf), NULL);
    if (len == 0)
    {
        // English message tables may be absent on localized Windows; retry with the default language.
        len = FormatMessageA(flags, NULL, err, 0, msgbuf, (DWORD)sizeof(msgbuf), NULL);
    }
    // System messages end with "\r\n"; strip trailing whitespace so it does not leak into the Lua error.
    while (len > 0 && (msgbuf[len - 1] == '\r' || msgbuf[len - 1] == '\n' || msgbuf[len - 1] == ' '))
    {
        msgbuf[--len] = '\0';
    }
    if (len == 0)
    {
        snprintf(msgbuf, sizeof(msgbuf), "Winsock error %d", err);
    }
    if (!prefix.empty())
    {
        return luaL_error(L, "%s: %s", prefix.c_str(), msgbuf);
    }
    // Never use system-provided text as the format string: it may contain '%'.
    return luaL_error(L, "%s", msgbuf);
}
// Adapted from GSock.
// Resolves HostName with getaddrinfo() (any address family, TCP) and stores every IPv4/IPv6
// address it yields, in textual form, into _out_IPStrVec (replacing its previous contents).
// Returns -1 if getaddrinfo() fails (then _out_IPStrVec is left untouched). Otherwise it
// returns the number of addrinfo entries. If that number exceeds _out_IPStrVec.size(), the
// difference is the count of entries that could not be converted to text.
int DNSResolve(const std::string& HostName, std::vector<std::string>& _out_IPStrVec)
{
    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;
    struct addrinfo* rawList = nullptr;
    if (getaddrinfo(HostName.c_str(), nullptr, &hints, &rawList) != 0)
    {
        return -1;
    }
    // Own the list so it is released even if push_back() throws.
    auto releaseList = [](struct addrinfo* list) { freeaddrinfo(list); };
    std::unique_ptr<struct addrinfo, decltype(releaseList)> list(rawList, releaseList);

    std::vector<std::string> vec;
    int cnt = 0;
    for (const struct addrinfo* node = list.get(); node != nullptr; node = node->ai_next)
    {
        cnt++;
        char ipBuff[INET6_ADDRSTRLEN] = { 0 }; // large enough for both IPv4 and IPv6 text
        const char* text = nullptr;
        switch (node->ai_family)
        {
        case AF_INET:
            text = inet_ntop(AF_INET, &reinterpret_cast<const struct sockaddr_in*>(node->ai_addr)->sin_addr, ipBuff, sizeof(ipBuff));
            break;
        case AF_INET6:
            text = inet_ntop(AF_INET6, &reinterpret_cast<const struct sockaddr_in6*>(node->ai_addr)->sin6_addr, ipBuff, sizeof(ipBuff));
            break;
        default:
            break; // unsupported family: counted in cnt but not converted
        }
        if (text != nullptr)
        {
            vec.push_back(text);
        }
    }
    _out_IPStrVec = std::move(vec);
    return cnt;
}
// Convenience overload: resolves HostName and keeps only the first textual address.
// Returns 0 on success (address written to _out_IPStr), -1 if resolution failed, or -2 if
// resolution succeeded but produced no printable address. _out_IPStr is only written on success.
int DNSResolve(const std::string& HostName, std::string& _out_IPStr)
{
    std::vector<std::string> addresses;
    if (DNSResolve(HostName, addresses) < 0)
        return -1;
    if (addresses.empty())
        return -2;
    _out_IPStr = addresses.front();
    return 0;
}
// Fixed-size pool with one worker thread per task slot.
// addTask() places a task into the first idle slot and returns its index, or -1 if every slot
// is busy. There is no queue. Tasks still pending when the pool is destroyed may be skipped.
class ThreadPool
{
public:
    int maxSize;
    using TaskRunner = void (*)(void*);
    // Per-worker slot. All fields are guarded by mLock.
    class TaskData
    {
    public:
        TaskRunner runner;
        void* data;
        std::mutex mLock;
        std::condition_variable cond;
        int status; // -1 exit, 0 new task, 1 started, 2 finished/idle
        TaskData() : runner(nullptr), data(nullptr), status(2) {}
    };
    std::vector<TaskData> tasks;
    std::vector<std::thread> workers;

    // Worker loop: sleeps until a new task (status 0) or an exit request (status -1) arrives.
    // It runs the task with the slot lock released, so addTask()/~ThreadPool() never block on it.
    static void task_runner(TaskData& td)
    {
        std::unique_lock<std::mutex> ulk(td.mLock);
        for (;;)
        {
            // The predicate guards against lost and spurious wakeups.
            td.cond.wait(ulk, [&td] { return td.status == 0 || td.status == -1; });
            if (td.status == -1)
            {
                break;
            }
            td.status = 1;
            TaskRunner runner = td.runner;
            void* data = td.data;
            ulk.unlock();
            if (runner)
            {
                runner(data);
            }
            ulk.lock();
            // Do not overwrite an exit request that arrived while the task was running.
            if (td.status == 1)
            {
                td.status = 2;
            }
        }
    }

    ThreadPool(int maxSize) : maxSize(maxSize), tasks(maxSize)
    {
        workers.reserve(maxSize);
        for (int i = 0; i < maxSize; i++)
        {
            // std::ref: TaskData is non-copyable and task_runner takes it by reference.
            workers.emplace_back(task_runner, std::ref(tasks[i]));
        }
    }

    // Asks every worker to exit, then joins them all. A running task is allowed to finish.
    ~ThreadPool()
    {
        for (TaskData& td : tasks)
        {
            {
                std::lock_guard<std::mutex> lk(td.mLock);
                td.status = -1;
            }
            td.cond.notify_all();
        }
        for (std::thread& worker : workers)
        {
            if (worker.joinable())
            {
                worker.join();
            }
        }
    }

    // Hands (runner, data) to the first idle worker. Returns its slot index, or -1 if none is idle.
    int addTask(TaskRunner runner, void* data)
    {
        for (int i = 0; i < maxSize; i++)
        {
            TaskData& td = tasks[i];
            std::unique_lock<std::mutex> ulk(td.mLock);
            if (td.status != 2)
            {
                continue;
            }
            td.runner = runner;
            td.data = data;
            td.status = 0;
            ulk.unlock();
            td.cond.notify_one();
            return i;
        }
        return -1;
    }
};
// need_copy: false <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>data<74>й<EFBFBD>ϽȨ. true <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>dataû<61>й<EFBFBD>ϽȨ, <20><>Ҫ<EFBFBD><D2AA><EFBFBD><EFBFBD>.
void PushReadEvent(int fd, const char* data, int read_ret, bool need_copy = false)
{
SDL_Event e;
e.type = SDL_USEREVENT;
e.user.code = 1005;
SocketEventData* p = new SocketEventData;
p->fd = fd;
p->ret = read_ret;
if (read_ret <= 0)
{
p->errcode = WSAGetLastError();
p->data = nullptr;
if (!need_copy) delete[] data;
}
else
{
p->errcode = 0;
if (need_copy)
{
char* s = new char[read_ret + 32];
memcpy(s, data, read_ret);
p->data = s;
}
else
{
p->data = data;
}
}
SDL_PushEvent(&e);
}
void PushWriteEvent(int fd)
{
SDL_Event e;
e.type = SDL_USEREVENT;
e.user.code = 1006;
SocketEventData* p = new SocketEventData;
p->data = nullptr;
p->fd = fd;
p->ret = 0;
p->errcode = 0;
SDL_PushEvent(&e);
}
class SocketSelector
{
public:
mutex mLock;
condition_variable cond;
volatile int status;
set<int> wait_read;
set<int> wait_write;
int control_usec; // 1ms
int control_usec_low_bound; // 0.5ms
int control_usec_high_bound; // 1.5ms
SocketSelector()
{
control_usec = 1000;
control_usec_low_bound = 500;
control_usec_high_bound = 1500;
status = 1; // 0 Require stop, 1 Running
}
int run_once()
{
fd_set* pRead;
fd_set* pWrite;
fd_set fs_read;
fd_set fs_write;
timeval tm;
tm.tv_sec = 0;
tm.tv_usec = control_usec;
{
lock_guard<mutex> ulk(mLock);
if (!wait_read.empty())
{
FD_ZERO(&fs_read);
for (const int& fd : wait_read)
{
FD_SET(fd, &fs_read);
}
pRead = &fs_read;
}
if (!wait_write.empty())
{
FD_ZERO(&fs_write);
for (const int& fd : wait_write)
{
FD_SET(fd, &fs_write);
}
pWrite = &fs_write;
}
}
int ret = select(0, pRead, pWrite, nullptr, &tm); // <20><>û<EFBFBD>оͿյ<CDBF>һ<EFBFBD><D2BB>tick.
if (ret <= 0)
{
return ret;
}
{
lock_guard<mutex> ulk(mLock);
for (const int& fd : wait_read)
{
if (FD_ISSET(fd, &fs_read))
{
char* buffer = new char[1024];
int ret = recv(fd, buffer, 1024, 0);
PushReadEvent(fd, buffer, ret);
wait_read.erase(fd);
}
}
for (const int& fd : wait_write)
{
if (FD_ISSET(fd, &fs_write))
{
PushWriteEvent(fd);
wait_write.erase(fd);
}
}
}
return 1;
}
void run()
{
while (status)
{
run_once();
}
}
void add_wait_read(int fd)
{
lock_guard<mutex> lk(mLock);
wait_read.insert(fd);
}
void add_wait_write(int fd)
{
lock_guard<mutex> lk(mLock);
wait_write.insert(fd);
}
};
// Stores C function `func` under key `name` in the table at the top of the Lua stack.
// It expects a `lua_State* L` in scope. It is wrapped in do/while(0) so it behaves as one
// statement, e.g. under an unbraced `if`.
#define setfn(func, name) do { lua_pushcfunction(L, func); lua_setfield(L, -2, name); } while (0)
struct ClientSocket
{
int fd;
int family;
int socketType;
static int close(lua_State* L)
{
check(L, { LUA_TUSERDATA }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
if (p->fd != -1)
{
closesocket(p->fd);
p->fd = -1;
}
return 0;
}
static int connect_ipv4(lua_State* L)
{
check(L, { LUA_TUSERDATA, LUA_TSTRING, LUA_TNUMBER }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
const char* addr = lua_tostring(L, 2);
int port = lua_tointeger(L, 3);
string IPStr;
if (DNSResolve(addr, IPStr) < 0)
{
return LuaWSError(L, "DNS resolve failed ("s + addr + ")");
}
struct sockaddr_in saddr;
memset(&saddr, 0, sizeof(saddr));
if (inet_pton(AF_INET, IPStr.c_str(), &(saddr.sin_addr.s_addr)) != 1)
{
return 0;
}
saddr.sin_port = htons(port);
saddr.sin_family = AF_INET;
int ret = connect(p->fd, (const sockaddr*)& saddr, sizeof(saddr));
if (ret < 0)
{
return LuaWSError(L, "Connect failed. ("s + IPStr + ")");
}
return 0;
}
static int send(lua_State* L)
{
check(L, { LUA_TUSERDATA, LUA_TSTRING }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
size_t dataLen = 0;
const char* data = lua_tolstring(L, 2, &dataLen);
size_t done = 0;
while (done < dataLen)
{
int ret = ::send(p->fd, data + done, dataLen - done, 0);
if (ret < 0)
{
return LuaWSError(L, "Send failed");
}
done += ret;
}
return 0;
}
static int recv(lua_State* L)
{
check(L, { LUA_TUSERDATA }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
int size;
if (lua_gettop(L) >= 2)
{
printf("%s\n", lua_typename(L, lua_type(L, 2)));
size = lua_tointeger(L, 2);
}
else
{
size = 1024;
}
if (size > 0)
{
unique_ptr<char> data(new char[size]);
int ret = ::recv(p->fd, data.get(), size, 0);
if (ret < 0)
{
return LuaWSError(L, "Recv failed");
}
else if (ret == 0)
{
lua_pushnil(L);
}
else
{
lua_pushlstring(L, data.get(), ret);
}
return 1;
}
else
{
return luaL_error(L, "Bad argument #%d. integer above 0 expected, got %s", 2, lua_tostring(L, 2));
}
}
static int create(lua_State* L)
{
int socketType = SOCK_STREAM;
if (lua_gettop(L) >= 1)
{
if (lua_type(L, 1) == LUA_TSTRING)
{
if (strcmp(lua_tostring(L, 1), "tcp") == 0)
{
socketType = SOCK_STREAM;
}
else if (strcmp(lua_tostring(L, 1), "udp") == 0)
{
socketType = SOCK_DGRAM;
}
else
{
return luaL_error(L, "Invalid socket type: %s", lua_tostring(L, 1));
}
}
lua_settop(L, 0);
}
int fd = socket(AF_INET, socketType, 0);
if (fd == -1)
{
return LuaWSError(L);
}
ClientSocket* p = (ClientSocket*)lua_newuserdata(L, sizeof(ClientSocket));
p->fd = fd;
p->family = AF_INET;
p->socketType = socketType;
if (lua_getfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt") != LUA_TTABLE)
{
lua_pop(L, 1);
lua_newtable(L);
// GC
setfn(close, "__gc");
// Fields
lua_newtable(L);
lua_pushstring(L, "clientsocket"); lua_setfield(L, -2, "type");
setfn(close, "close");
setfn(connect_ipv4, "connect");
setfn(send, "send");
setfn(recv, "recv");
// Set __index of metatable.
lua_setfield(L, -2, "__index");
lua_pushvalue(L, -1);
lua_setfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt");
}
lua_setmetatable(L, -2);
return 1;
}
};
template<typename T>
class ClassWrapper
{
public:
bool __destroyed;
const char* type;
struct Fuck
{
using TMemFn = int (T::*)(lua_State*);
TMemFn f;
};
static int destroy(lua_State* L)
{
T* t = static_cast<T*>(lua_touserdata(L, 1));
if (!t->__destroyed)
{
t->__destroyed = true;
t->~T();
}
return 0;
}
static int create(lua_State* L)
{
T* t = new (lua_newuserdata(L, sizeof(T))) T(L);
t->__destroyed = false;
char mtname[64] = { 0 };
sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code);
if (lua_getfield(L, LUA_REGISTRYINDEX, mtname) != LUA_TTABLE)
{
lua_pop(L, 1);
lua_newtable(L);
lua_pushstring(L, t->type);
lua_setfield(L, -2, "type");
lua_pushcfunction(L, destroy);
lua_setfield(L, -2, "__gc");
lua_newtable(L);
int __chk_before = lua_gettop(L);
t->prepare(L); // T::prepare(lua_State*) must keep stack balance!
int __chk_after = lua_gettop(L);
if (__chk_before != __chk_after)
{
return luaL_error(L, "Stack unbalance detected while initializing: %s", type);
}
lua_setfield(L, -2, "__index");
lua_pushvalue(L, -1);
lua_setfield(L, LUA_REGISTRYINDEX, mtname);
}
lua_setmetatable(L, -1);
return 1;
}
static int LuaCallGate(lua_State* L)
{
if (lua_type(L, 1) != LUA_TUSERDATA)
{
return luaL_error(L, "Bad argument #1. userdata expected. Got %s.", lua_typename(L, lua_type(L, 1)));
}
if (!lua_getmetatable(L, 1))
{
return luaL_error(L, "Bad argument #1. cannot get metatable.");
}
char mtname[64] = { 0 };
sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code);
lua_getfield(L, LUA_REGISTRYINDEX, mtname);
if (!lua_rawequal(L, -1, -2))
{
lua_pop(L, 2);
return luaL_error(L, "Bad argument #1. metatable mismatch.");
}
lua_pop(L, 2);
T* p = (T*)(lua_touserdata(L, 1));
Fuck* f = (Fuck*)(lua_touserdata(L, lua_upvalueindex(1)));
return std::invoke(f->f, *p, L);
}
void AddFunction(lua_State* L, typename Fuck::TMemFn callable)
{
Fuck* f = new Fuck;
f->f = callable;
lua_pushlightuserdata(L, f);
lua_pushcclosure(L, LuaCallGate, 1);
}
};
class MyClass : public ClassWrapper<MyClass>
{
public:
string welcome;
MyClass(lua_State* L)
{
printf("MyClass ctor.\n");
type = "MyClass";
welcome = lua_tostring(L, 1);
}
int showit(lua_State* L)
{
printf("In MyClass::showit, %s", welcome.c_str());
return 0;
}
~MyClass()
{
printf("MyClass dtor.\n");
}
void prepare(lua_State* L)
{
AddFunction(L, &MyClass::showit);
}
};
// Initializes Winsock 2.2 and registers the global Lua constructor `ClientSocket`.
// Raises a Lua error if WSAStartup() fails.
int InitNetwork(lua_State* L)
{
    WSAData wdt;
    // WSAStartup() returns its error code directly (0 on success). WSAGetLastError() must not
    // be used for it, so LuaWSError() is not applicable here.
    int ret = WSAStartup(MAKEWORD(2, 2), &wdt);
    if (ret != 0)
    {
        return luaL_error(L, "WSAStartup failed with error code %d", ret);
    }
    lua_pushcfunction(L, ClientSocket::create);
    lua_setglobal(L, "ClientSocket");
    return 0;
}