#include "LuaNetwork.h"
// Target Windows 10 by default.
#define _WIN32_WINNT 0x0A00
#include <winsock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib, "ws2_32.lib")
#endif
#include <climits>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <new>
#include <set>
#include <string>
#include <thread>
#include <typeinfo>
#include <vector>
#include "LuaCommon.h"
using namespace std;

// Raises a Lua error describing the calling thread's last Winsock error
// (WSAGetLastError()), formatted as "prefix: message" when `prefix` is
// non-empty. luaL_error() does not return, so neither does this function;
// the int return type only lets callers write `return LuaWSError(L, ...);`.
int LuaWSError(lua_State* L, const string& prefix = "") {
    // Read the error code first, before any other call can overwrite it.
    const int err = WSAGetLastError();
    char msgbuf[1024] = { 0 };
    const DWORD flags = FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS;
    DWORD len = FormatMessageA(flags, NULL, err, MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT),
                               msgbuf, (DWORD)sizeof(msgbuf), NULL);
    if (len == 0) {
        // The English message table may not be installed; retry with the default language.
        len = FormatMessageA(flags, NULL, err, 0, msgbuf, (DWORD)sizeof(msgbuf), NULL);
    }
    if (len == 0) {
        snprintf(msgbuf, sizeof(msgbuf), "Winsock error %d", err);
    } else {
        // System messages end with "\r\n"; strip the trailing line break and spaces.
        while (len > 0 && (msgbuf[len - 1] == '\r' || msgbuf[len - 1] == '\n' || msgbuf[len - 1] == ' ')) {
            msgbuf[--len] = '\0';
        }
    }
    if (!prefix.empty()) {
        return luaL_error(L, "%s: %s", prefix.c_str(), msgbuf);
    }
    // Never use the system text as the format string: it may contain '%'.
    return luaL_error(L, "%s", msgbuf);
}

// Resolves HostName with getaddrinfo() (any family, TCP) and stores each
// returned address as a numeric string (dotted IPv4 or IPv6 text form) in
// _out_IPStrVec, in the order getaddrinfo() returned them.
// Returns -1 if getaddrinfo() fails (the output vector is left untouched);
// otherwise the number of addrinfo entries walked. If that count exceeds
// _out_IPStrVec.size(), the difference is the number of entries that could not
// be converted (unknown family or inet_ntop() failure).
// Borrowed from GSock.
int DNSResolve(const std::string& HostName, std::vector<std::string>& _out_IPStrVec) {
    std::vector<std::string> vec;

    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;

    struct addrinfo* result = nullptr;
    if (getaddrinfo(HostName.c_str(), nullptr, &hints, &result) != 0) {
        return -1;
    }

    int cnt = 0;
    for (struct addrinfo* node = result; node != nullptr; node = node->ai_next) {
        cnt++;
        switch (node->ai_family) {
        case AF_INET: {
            sockaddr_in* paddr = reinterpret_cast<sockaddr_in*>(node->ai_addr);
            char ip_buff[64] = { 0 };
            if (inet_ntop(AF_INET, &(paddr->sin_addr), ip_buff, sizeof(ip_buff)) != nullptr) {
                vec.push_back(ip_buff);
            }
            break;
        }
        case AF_INET6: {
            sockaddr_in6* paddr = reinterpret_cast<sockaddr_in6*>(node->ai_addr);
            char ip_buff[128] = { 0 };
            if (inet_ntop(AF_INET6, &(paddr->sin6_addr), ip_buff, sizeof(ip_buff)) != nullptr) {
                vec.push_back(ip_buff);
            }
            break;
        }
        default:
            break;
        }
    }
    freeaddrinfo(result);

    _out_IPStrVec = std::move(vec);
    return cnt;
}

// Resolves HostName and stores the first address getaddrinfo() returned in _out_IPStr.
// Returns 0 on success, -1 if resolution failed, -2 if it produced no usable address.
int DNSResolve(const std::string& HostName, std::string& _out_IPStr) {
    std::vector<std::string> addresses;
    if (DNSResolve(HostName, addresses) < 0) {
        return -1;
    }
    if (addresses.empty()) {
        return -2;
    }
    _out_IPStr = addresses.front();
    return 0;
}

// Fixed-size pool with one worker thread per slot. Each slot holds at most one
// task; addTask() hands a task to the first idle slot. There is no queue: when
// every slot is busy the task is rejected (-1).
class ThreadPool {
public:
    int maxSize;
    using TaskRunner = void (*)(void*);

    // State of one worker slot. Every field is read and written under mLock.
    class TaskData {
    public:
        TaskRunner runner = nullptr;
        void* data = nullptr;
        mutex mLock;
        condition_variable cond;
        int status = 2; // -1 exit, 0 new task, 1 started, 2 finished (idle)
    };

    vector<TaskData> tasks;
    vector<thread> workers;

    // Worker loop for one slot: waits for a new task or an exit request.
    static void task_runner(TaskData& td) {
        unique_lock<mutex> ulk(td.mLock);
        for (;;) {
            // The predicate prevents lost wake-ups (notify issued before this thread
            // reached wait()) and re-running the previous task on a spurious wake-up.
            td.cond.wait(ulk, [&td] { return td.status == 0 || td.status == -1; });
            if (td.status == -1) {
                break;
            }
            td.status = 1;
            TaskRunner runner = td.runner;
            void* data = td.data;

            // Run the task without holding the lock so addTask()/~ThreadPool()
            // are not blocked by a long-running task.
            ulk.unlock();
            if (runner) {
                runner(data);
            }
            ulk.lock();

            if (td.status == -1) {
                break; // shutdown was requested while the task was running
            }
            td.status = 2;
        }
    }

    // Starts `size` workers (a non-positive size yields an empty pool).
    ThreadPool(int size)
        : maxSize(size > 0 ? size : 0), tasks(static_cast<size_t>(size > 0 ? size : 0)) {
        workers.reserve(tasks.size());
        try {
            for (int i = 0; i < maxSize; i++) {
                // std::ref: std::thread copies its arguments and TaskData is not copyable.
                workers.emplace_back(task_runner, std::ref(tasks[i]));
            }
        } catch (...) {
            // Destroying a joinable std::thread terminates the program, so stop
            // the workers that did start before propagating the failure.
            shutdown();
            throw;
        }
    }

    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    // Asks every worker to exit and joins them. A running task finishes first;
    // a task that was accepted but not yet started is dropped.
    ~ThreadPool() {
        shutdown();
    }

    // Assigns (runner, data) to the first idle worker and returns its index,
    // or -1 if every worker is busy.
    int addTask(TaskRunner runner, void* data) {
        for (int i = 0; i < maxSize; i++) {
            TaskData& td = tasks[i];
            unique_lock<mutex> ulk(td.mLock);
            if (td.status != 2) {
                continue;
            }
            td.runner = runner;
            td.data = data;
            td.status = 0;
            ulk.unlock();
            td.cond.notify_one();
            return i;
        }
        return -1;
    }

private:
    // Signals every slot to exit, then joins all started workers.
    void shutdown() {
        for (TaskData& td : tasks) {
            {
                lock_guard<mutex> lk(td.mLock);
                td.status = -1;
            }
            td.cond.notify_all();
        }
        for (thread& worker : workers) {
            if (worker.joinable()) {
                worker.join();
            }
        }
    }
};

// need_copy: false -> this function takes ownership of `data`.
//            true  -> this function does not own `data` and must copy it.
void PushReadEvent(int fd, const char* data, int read_ret, bool need_copy = false) { SDL_Event e; e.type = SDL_USEREVENT; e.user.code = 1005; SocketEventData* p = new SocketEventData; p->fd = fd; p->ret = read_ret; if (read_ret <= 0) { p->errcode = WSAGetLastError(); p->data = nullptr; if (!need_copy) delete[] data; } else { p->errcode = 0; if (need_copy) { char* s = new char[read_ret + 32]; memcpy(s, data, read_ret); p->data = s; } else { p->data = data; } } SDL_PushEvent(&e); } void PushWriteEvent(int fd) { SDL_Event e; e.type = SDL_USEREVENT; e.user.code = 1006; SocketEventData* p = new SocketEventData; p->data = nullptr; p->fd = fd; p->ret = 0; p->errcode = 0; SDL_PushEvent(&e); } class SocketSelector { public: mutex mLock; condition_variable cond; volatile int status; set wait_read; set wait_write; int control_usec; // 1ms int control_usec_low_bound; // 0.5ms int control_usec_high_bound; // 1.5ms SocketSelector() { control_usec = 1000; control_usec_low_bound = 500; control_usec_high_bound = 1500; status = 1; // 0 Require stop, 1 Running } int run_once() { fd_set* pRead; fd_set* pWrite; fd_set fs_read; fd_set fs_write; timeval tm; tm.tv_sec = 0; tm.tv_usec = control_usec; { lock_guard ulk(mLock); if (!wait_read.empty()) { FD_ZERO(&fs_read); for (const int& fd : wait_read) { FD_SET(fd, &fs_read); } pRead = &fs_read; } if (!wait_write.empty()) { FD_ZERO(&fs_write); for (const int& fd : wait_write) { FD_SET(fd, &fs_write); } pWrite = &fs_write; } } int ret = select(0, pRead, pWrite, nullptr, &tm); // 都没有就空等一个tick. 
if (ret <= 0) { return ret; } { lock_guard ulk(mLock); for (const int& fd : wait_read) { if (FD_ISSET(fd, &fs_read)) { char* buffer = new char[1024]; int ret = recv(fd, buffer, 1024, 0); PushReadEvent(fd, buffer, ret); wait_read.erase(fd); } } for (const int& fd : wait_write) { if (FD_ISSET(fd, &fs_write)) { PushWriteEvent(fd); wait_write.erase(fd); } } } return 1; } void run() { while (status) { run_once(); } } void add_wait_read(int fd) { lock_guard lk(mLock); wait_read.insert(fd); } void add_wait_write(int fd) { lock_guard lk(mLock); wait_write.insert(fd); } }; #define setfn(func, name) lua_pushcfunction(L, func);lua_setfield(L, -2, name) struct ClientSocket { int fd; int family; int socketType; static int close(lua_State* L) { check(L, { LUA_TUSERDATA }, { "clientsocket" }); ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1); if (p->fd != -1) { closesocket(p->fd); p->fd = -1; } return 0; } static int connect_ipv4(lua_State* L) { check(L, { LUA_TUSERDATA, LUA_TSTRING, LUA_TNUMBER }, { "clientsocket" }); ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1); const char* addr = lua_tostring(L, 2); int port = lua_tointeger(L, 3); string IPStr; if (DNSResolve(addr, IPStr) < 0) { return LuaWSError(L, "DNS resolve failed ("s + addr + ")"); } struct sockaddr_in saddr; memset(&saddr, 0, sizeof(saddr)); if (inet_pton(AF_INET, IPStr.c_str(), &(saddr.sin_addr.s_addr)) != 1) { return 0; } saddr.sin_port = htons(port); saddr.sin_family = AF_INET; int ret = connect(p->fd, (const sockaddr*)& saddr, sizeof(saddr)); if (ret < 0) { return LuaWSError(L, "Connect failed. 
("s + IPStr + ")"); } return 0; } static int send(lua_State* L) { check(L, { LUA_TUSERDATA, LUA_TSTRING }, { "clientsocket" }); ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1); size_t dataLen = 0; const char* data = lua_tolstring(L, 2, &dataLen); size_t done = 0; while (done < dataLen) { int ret = ::send(p->fd, data + done, dataLen - done, 0); if (ret < 0) { return LuaWSError(L, "Send failed"); } done += ret; } return 0; } static int recv(lua_State* L) { check(L, { LUA_TUSERDATA }, { "clientsocket" }); ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1); int size; if (lua_gettop(L) >= 2) { printf("%s\n", lua_typename(L, lua_type(L, 2))); size = lua_tointeger(L, 2); } else { size = 1024; } if (size > 0) { unique_ptr data(new char[size]); int ret = ::recv(p->fd, data.get(), size, 0); if (ret < 0) { return LuaWSError(L, "Recv failed"); } else if (ret == 0) { lua_pushnil(L); } else { lua_pushlstring(L, data.get(), ret); } return 1; } else { return luaL_error(L, "Bad argument #%d. integer above 0 expected, got %s", 2, lua_tostring(L, 2)); } } static int create(lua_State* L) { int socketType = SOCK_STREAM; if (lua_gettop(L) >= 1) { if (lua_type(L, 1) == LUA_TSTRING) { if (strcmp(lua_tostring(L, 1), "tcp") == 0) { socketType = SOCK_STREAM; } else if (strcmp(lua_tostring(L, 1), "udp") == 0) { socketType = SOCK_DGRAM; } else { return luaL_error(L, "Invalid socket type: %s", lua_tostring(L, 1)); } } lua_settop(L, 0); } int fd = socket(AF_INET, socketType, 0); if (fd == -1) { return LuaWSError(L); } ClientSocket* p = (ClientSocket*)lua_newuserdata(L, sizeof(ClientSocket)); p->fd = fd; p->family = AF_INET; p->socketType = socketType; if (lua_getfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt") != LUA_TTABLE) { lua_pop(L, 1); lua_newtable(L); // GC setfn(close, "__gc"); // Fields lua_newtable(L); lua_pushstring(L, "clientsocket"); lua_setfield(L, -2, "type"); setfn(close, "close"); setfn(connect_ipv4, "connect"); setfn(send, "send"); setfn(recv, "recv"); // Set __index 
of metatable. lua_setfield(L, -2, "__index"); lua_pushvalue(L, -1); lua_setfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt"); } lua_setmetatable(L, -2); return 1; } }; template class ClassWrapper { public: bool __destroyed; const char* type; struct Fuck { using TMemFn = int (T::*)(lua_State*); TMemFn f; }; static int destroy(lua_State* L) { T* t = static_cast(lua_touserdata(L, 1)); if (!t->__destroyed) { t->__destroyed = true; t->~T(); } return 0; } static int create(lua_State* L) { T* t = new (lua_newuserdata(L, sizeof(T))) T(L); t->__destroyed = false; char mtname[64] = { 0 }; sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code); if (lua_getfield(L, LUA_REGISTRYINDEX, mtname) != LUA_TTABLE) { lua_pop(L, 1); lua_newtable(L); lua_pushstring(L, t->type); lua_setfield(L, -2, "type"); lua_pushcfunction(L, destroy); lua_setfield(L, -2, "__gc"); lua_newtable(L); int __chk_before = lua_gettop(L); t->prepare(L); // T::prepare(lua_State*) must keep stack balance! int __chk_after = lua_gettop(L); if (__chk_before != __chk_after) { return luaL_error(L, "Stack unbalance detected while initializing: %s", type); } lua_setfield(L, -2, "__index"); lua_pushvalue(L, -1); lua_setfield(L, LUA_REGISTRYINDEX, mtname); } lua_setmetatable(L, -1); return 1; } static int LuaCallGate(lua_State* L) { if (lua_type(L, 1) != LUA_TUSERDATA) { return luaL_error(L, "Bad argument #1. userdata expected. Got %s.", lua_typename(L, lua_type(L, 1))); } if (!lua_getmetatable(L, 1)) { return luaL_error(L, "Bad argument #1. cannot get metatable."); } char mtname[64] = { 0 }; sprintf(mtname, "__clsw_%u_mt", typeid(T).hash_code); lua_getfield(L, LUA_REGISTRYINDEX, mtname); if (!lua_rawequal(L, -1, -2)) { lua_pop(L, 2); return luaL_error(L, "Bad argument #1. 
metatable mismatch."); } lua_pop(L, 2); T* p = (T*)(lua_touserdata(L, 1)); Fuck* f = (Fuck*)(lua_touserdata(L, lua_upvalueindex(1))); return std::invoke(f->f, *p, L); } void AddFunction(lua_State* L, typename Fuck::TMemFn callable) { Fuck* f = new Fuck; f->f = callable; lua_pushlightuserdata(L, f); lua_pushcclosure(L, LuaCallGate, 1); } }; class MyClass : public ClassWrapper { public: string welcome; MyClass(lua_State* L) { printf("MyClass ctor.\n"); type = "MyClass"; welcome = lua_tostring(L, 1); } int showit(lua_State* L) { printf("In MyClass::showit, %s", welcome.c_str()); return 0; } ~MyClass() { printf("MyClass dtor.\n"); } void prepare(lua_State* L) { AddFunction(L, &MyClass::showit); } }; int InitNetwork(lua_State* L) { WORD wd; WSAData wdt; wd = MAKEWORD(2, 2); int ret = WSAStartup(wd, &wdt); if (ret < 0) { return LuaWSError(L); } lua_pushcfunction(L, ClientSocket::create); lua_setglobal(L, "ClientSocket"); return 0; }