添加套接字创建等网络能力

This commit is contained in:
Kirigaya Kazuto 2019-10-13 17:44:59 +08:00
parent bc189b4951
commit b16cc65c5e
3 changed files with 362 additions and 517 deletions

View File

@ -12,6 +12,7 @@
#include "LuaEnum.h"
#include "LuaCommon.h"
#include "LuaHelper.h"
#include "LuaNetwork.h"
#include "EventDef.h"
#include "Window.h"
#include "Renderer.h"
@ -418,6 +419,10 @@ int InitLibs(lua_State* L)
InitKeys(L);
lua_setfield(L, -2, "keys");
// Network
InitLuaNetwork(L);
lua_setfield(L, -2, "socket");
// Stack balance
lua_pop(L, 2);
return 0;

View File

@ -1,540 +1,391 @@
#include "LuaNetwork.h"
// Using Win10 by default
#define _WIN32_WINNT 0x0A00
#include <winsock2.h>
#include <WinSock2.h>
#include <ws2tcpip.h>
#ifdef _MSC_VER
#pragma comment(lib,"ws2_32.lib")
#endif
#include <Windows.h>
#include <vector>
#include <string>
#include <set>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
#include "LuaCommon.h"
#include <iostream>
using namespace std;
int LuaWSError(lua_State* L, const string& prefix="")
#pragma comment(lib, "ws2_32.lib")
void push_message(lua_State* L, int errcode)
{
char msgbuf[1024] = { 0 };
int err = WSAGetLastError();
FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT), msgbuf, 1024, NULL);
if (!prefix.empty())
FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, errcode, MAKELANGID(LANG_ENGLISH, SUBLANG_DEFAULT), msgbuf, 1024, NULL);
lua_pushstring(L, msgbuf);
}
void push_wserror(lua_State* L)
{
int errcode = WSAGetLastError();
lua_pushinteger(L, errcode);
push_message(L, errcode);
}
int socket_create(lua_State* L)
{
int mode = lua_tointeger(L, 1);
int option = lua_tointeger(L, 2);
int socket_type;
switch (mode)
{
return luaL_error(L, "%s: %s", prefix.c_str(), msgbuf);
case 1:
socket_type = SOCK_STREAM;
break;
case 2:
socket_type = SOCK_DGRAM;
break;
default:
socket_type = SOCK_STREAM;
}
int sfd = socket(AF_INET, socket_type, 0);
if (sfd < 0)
{
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
if (option & 0x1)
{
unsigned long op = 1;
int ret = ioctlsocket(sfd, FIONBIO, &op);
if (ret < 0)
{
closesocket(sfd);
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
}
lua_pushboolean(L, true);
lua_pushinteger(L, sfd);
return 2;
}
int socket_recv(lua_State* L)
{
int sfd = lua_tointeger(L, 1);
int nsz = lua_tointeger(L, 2);
static char buffer[1024 * 1024]; // 1M
int ret = recv(sfd, buffer, max(min(nsz, 1024 * 1024), 0), 0);
if (ret < 0)
{
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
else
{
return luaL_error(L, msgbuf);
lua_pushboolean(L, true);
lua_pushinteger(L, ret);
if (ret > 0)
{
lua_pushlstring(L, buffer, ret);
}
else
{
lua_pushstring(L, "");
}
return 3;
}
}
// Bought from GSock.
int DNSResolve(const std::string& HostName, std::vector<std::string>& _out_IPStrVec)
int socket_send(lua_State* L)
{
std::vector<std::string> vec;
int sfd = lua_tointeger(L, 1);
size_t len;
const char* data = lua_tolstring(L, 2, &len);
int ret = send(sfd, data, len, 0);
if (ret < 0)
{
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
else
{
lua_pushboolean(L, true);
lua_pushinteger(L, ret);
lua_pushboolean(L, ret == len);
return 3;
}
}
/// Use getaddrinfo instead
int socket_select(lua_State* L)
{
timeval tval;
tval.tv_sec = lua_tointeger(L, 4);
tval.tv_usec = lua_tointeger(L, 5);
vector<int> readvec;
lua_pushnil(L);
while (lua_next(L, 1))
{
int fd = lua_tointeger(L, -1);
readvec.push_back(fd);
lua_pop(L, 1);
}
lua_pop(L, 1);
vector<int> writevec;
lua_pushnil(L);
while (lua_next(L, 2))
{
int fd = lua_tointeger(L, -1);
writevec.push_back(fd);
lua_pop(L, 1);
}
lua_pop(L, 1);
vector<int> errorvec;
lua_pushnil(L);
while (lua_next(L, 3))
{
int fd = lua_tointeger(L, -1);
errorvec.push_back(fd);
lua_pop(L, 1);
}
lua_pop(L, 1);
FD_SET readset, writeset, errorset, * preadset, * pwriteset, * perrorset;
FD_ZERO(&readset);
FD_ZERO(&writeset);
FD_ZERO(&errorset);
if (!readvec.empty())
{
for (const int& i : readvec)
{
FD_SET(i, &readset);
}
preadset = &readset;
}
else
{
preadset = NULL;
}
if (!writevec.empty())
{
for (const int& i : writevec)
{
FD_SET(i, &writeset);
}
pwriteset = &writeset;
}
else
{
pwriteset = NULL;
}
if (!errorvec.empty())
{
for (const int& i : errorvec)
{
FD_SET(i, &errorset);
}
perrorset = &errorset;
}
else
{
perrorset = NULL;
}
// cout << preadset << " " << pwriteset << " " << perrorset << " " << tval.tv_sec << " " << tval.tv_usec << endl;
int ret = select(0, preadset, pwriteset, perrorset, &tval);
// cout << "select() return: " << ret << endl;
// Sleep(1000);
if (ret < 0)
{
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
else
{
lua_pushboolean(L, true);
lua_pushinteger(L, ret);
if (ret > 0)
{
lua_newtable(L);
if (preadset)
{
int cnt = 1;
for (const int& fd : readvec)
{
if (FD_ISSET(fd, preadset))
{
lua_pushinteger(L, fd);
lua_seti(L, -2, cnt++);
}
}
}
lua_newtable(L);
if (pwriteset)
{
int cnt = 1;
for (const int& fd : writevec)
{
if (FD_ISSET(fd, pwriteset))
{
lua_pushinteger(L, fd);
lua_seti(L, -2, cnt++);
}
}
}
lua_newtable(L);
if (perrorset)
{
int cnt = 1;
for (const int& fd : errorvec)
{
if (FD_ISSET(fd, perrorset))
{
lua_pushinteger(L, fd);
lua_seti(L, -2, cnt++);
}
}
}
return 5;
}
else
{
return 2;
}
}
}
int socket_close(lua_State* L)
{
int sfd = lua_tointeger(L, 1);
closesocket(sfd);
return 0;
}
int socket_connect(lua_State* L)
{
int sfd = lua_tointeger(L, 1);
const char* ip = lua_tostring(L, 2);
int port = lua_tointeger(L, 3);
sockaddr_in saddr;
memset(&saddr, 0, sizeof(saddr));
saddr.sin_family = AF_INET;
saddr.sin_addr.s_addr = inet_addr(ip);
saddr.sin_port = htons(port);
int ret = connect(sfd, (const sockaddr*)&saddr, sizeof(saddr));
if (ret < 0)
{
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
else
{
lua_pushboolean(L, true);
return 1;
}
}
int socket_connect_finished(lua_State* L)
{
int sfd = lua_tointeger(L, 1);
int result;
int result_len = sizeof(result);
int ret = getsockopt(sfd, SOL_SOCKET, SO_ERROR, (char*)&result, &result_len);
cout << "getsockopt " << ret << " " << result << endl;
if (ret < 0)
{
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
else
{
lua_pushboolean(L, true);
return 1;
}
}
int network_dnsresolve(lua_State* L)
{
const char* ip = lua_tostring(L, 1);
struct addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_UNSPEC;
hints.ai_family = AF_INET; // Specified to IPv4
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
struct addrinfo* result = nullptr;
int ret = getaddrinfo(HostName.c_str(), NULL, &hints, &result);
if (ret != 0)
struct addrinfo* result = NULL;
if (getaddrinfo(ip, NULL, &hints, &result) != 0)
{
return -1;
lua_pushboolean(L, false);
push_wserror(L);
return 3;
}
struct addrinfo* p = result;
vector<string> vec;
while (p)
{
if (p->ai_family == AF_INET)
{
struct sockaddr_in* psaddr = (sockaddr_in*)(p->ai_addr);
vec.push_back(inet_ntoa(psaddr->sin_addr));
}
int cnt = 0;
for (struct addrinfo* ptr = result; ptr != nullptr; ptr = ptr->ai_next)
{
cnt++;
switch (ptr->ai_family)
{
case AF_INET:
{
sockaddr_in* paddr = (struct sockaddr_in*) (ptr->ai_addr);
char ip_buff[64] = { 0 };
const char* ptr = inet_ntop(AF_INET, &(paddr->sin_addr), ip_buff, 64);
if (ptr != NULL)
{
vec.push_back(ptr);
}
break;
}
case AF_INET6:
{
sockaddr_in6* paddr = (struct sockaddr_in6*) (ptr->ai_addr);
char ip_buff[128] = { 0 };
const char* ptr = inet_ntop(AF_INET6, &(paddr->sin6_addr), ip_buff, 128);
if (ptr != NULL)
{
vec.push_back(ptr);
}
break;
}
}// End of switch
p = p->ai_next;
}
freeaddrinfo(result);
_out_IPStrVec = std::move(vec);
// if(cnt!=(int)_out_IPStrVec.size()),
// then (cnt-(int)_out_IPStrVec.size()) errors happend while calling inet_ntop().
return cnt;
lua_pushboolean(L, true);
lua_newtable(L);
for (size_t i = 0; i < vec.size(); i++)
{
lua_pushstring(L, vec[i].c_str());
lua_seti(L, -2, (lua_Integer)i + 1);
}
return 2;
}
int DNSResolve(const std::string& HostName, std::string& _out_IPStr)
// Act like require(...), returns a table.
int InitLuaNetwork(lua_State* L)
{
std::vector<std::string> vec;
int ret = DNSResolve(HostName, vec);
if (ret < 0)
{
return -1;
WORD wVersionRequested = MAKEWORD(2, 2);
WSADATA wsaData;
int err = WSAStartup(wVersionRequested, &wsaData);
if (err != 0) {
throw exception("WSAStartup failed with error: %d\n", err);
}
if (vec.empty())
luaL_Reg funcs[] =
{
return -2;
}
_out_IPStr = vec[0];
return 0;
}
class ThreadPool
{
public:
int maxSize;
using TaskRunner = void (*)(void*);
class TaskData
{
public:
TaskRunner runner;
void* data;
mutex mLock;
condition_variable cond;
volatile int status; // -1 exit 0 new task 1 started 2 finished
TaskData()
{
runner = nullptr;
data = nullptr;
status = 2;
}
{"socket", socket_create},
{"send", socket_send},
{"recv", socket_recv},
{"connect", socket_connect},
{"connect_finished", socket_connect_finished},
{"select", socket_select},
{"close", socket_close},
{"resolve", network_dnsresolve},
{NULL, NULL}
};
vector<TaskData> tasks;
vector<thread> workers;
static void task_runner(TaskData& td)
{
unique_lock<mutex> ulk(td.mLock);
while (1)
{
td.cond.wait(ulk);
if (td.status == -1)
{
break;
}
td.status = 1;
if (td.runner)
{
td.runner(td.data);
}
td.status = 2;
}
}
ThreadPool(int maxSize) : maxSize(maxSize), tasks(maxSize)
{
for (int i = 0; i < maxSize; i++)
{
workers.emplace_back(task_runner, ref(tasks[i]));
}
}
~ThreadPool()
{
for (int i = 0; i < maxSize; i++)
{
TaskData& td = tasks[i];
{
unique_lock<mutex> ulk(td.mLock);
td.status = -1;
td.cond.notify_all();
}
workers[i].join();
}
}
int addTask(TaskRunner runner, void* data)
{
for (int i = 0; i < maxSize; i++)
{
TaskData& td = tasks[i];
if (td.status == 2)
{
unique_lock<mutex> ulk(td.mLock);
td.runner = runner;
td.data = data;
td.status = 0;
td.cond.notify_all();
return i;
}
}
return -1;
}
};
// need_copy: false 本方法对data有管辖权. true 本方法对data没有管辖权, 需要拷贝.
void PushReadEvent(int fd, const char* data, int read_ret, bool need_copy = false)
{
SDL_Event e;
e.type = SDL_USEREVENT;
e.user.code = 1005;
SocketEventData* p = new SocketEventData;
p->fd = fd;
p->ret = read_ret;
if (read_ret <= 0)
{
p->errcode = WSAGetLastError();
p->data = nullptr;
if (!need_copy) delete[] data;
}
else
{
p->errcode = 0;
if (need_copy)
{
char* s = new char[read_ret + 32];
memcpy(s, data, read_ret);
p->data = s;
}
else
{
p->data = data;
}
}
SDL_PushEvent(&e);
}
void PushWriteEvent(int fd)
{
SDL_Event e;
e.type = SDL_USEREVENT;
e.user.code = 1006;
SocketEventData* p = new SocketEventData;
p->data = nullptr;
p->fd = fd;
p->ret = 0;
p->errcode = 0;
SDL_PushEvent(&e);
}
class SocketSelector
{
public:
mutex mLock;
condition_variable cond;
volatile int status;
set<int> wait_read;
set<int> wait_write;
int control_usec; // 1ms
int control_usec_low_bound; // 0.5ms
int control_usec_high_bound; // 1.5ms
SocketSelector()
{
control_usec = 1000;
control_usec_low_bound = 500;
control_usec_high_bound = 1500;
status = 1; // 0 Require stop, 1 Running
}
int run_once()
{
fd_set* pRead;
fd_set* pWrite;
fd_set fs_read;
fd_set fs_write;
timeval tm;
tm.tv_sec = 0;
tm.tv_usec = control_usec;
{
lock_guard<mutex> ulk(mLock);
if (!wait_read.empty())
{
FD_ZERO(&fs_read);
for (const int& fd : wait_read)
{
FD_SET(fd, &fs_read);
}
pRead = &fs_read;
}
if (!wait_write.empty())
{
FD_ZERO(&fs_write);
for (const int& fd : wait_write)
{
FD_SET(fd, &fs_write);
}
pWrite = &fs_write;
}
}
int ret = select(0, pRead, pWrite, nullptr, &tm); // 都没有就空等一个tick.
if (ret <= 0)
{
return ret;
}
{
lock_guard<mutex> ulk(mLock);
for (const int& fd : wait_read)
{
if (FD_ISSET(fd, &fs_read))
{
char* buffer = new char[1024];
int ret = recv(fd, buffer, 1024, 0);
PushReadEvent(fd, buffer, ret);
wait_read.erase(fd);
}
}
for (const int& fd : wait_write)
{
if (FD_ISSET(fd, &fs_write))
{
PushWriteEvent(fd);
wait_write.erase(fd);
}
}
}
luaL_newlib(L, funcs);
lua_pushstring(L, "LuaNetwork v0.1");
lua_setfield(L, -2, "version");
return 1;
}
void run()
{
while (status)
{
run_once();
}
}
void add_wait_read(int fd)
{
lock_guard<mutex> lk(mLock);
wait_read.insert(fd);
}
void add_wait_write(int fd)
{
lock_guard<mutex> lk(mLock);
wait_write.insert(fd);
}
};
#define setfn(func, name) lua_pushcfunction(L, func);lua_setfield(L, -2, name)
struct ClientSocket
{
int fd;
int family;
int socketType;
static int close(lua_State* L)
{
check(L, { LUA_TUSERDATA }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
if (p->fd != -1)
{
closesocket(p->fd);
p->fd = -1;
}
return 0;
}
static int connect_ipv4(lua_State* L)
{
check(L, { LUA_TUSERDATA, LUA_TSTRING, LUA_TNUMBER }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
const char* addr = lua_tostring(L, 2);
int port = lua_tointeger(L, 3);
string IPStr;
if (DNSResolve(addr, IPStr) < 0)
{
return LuaWSError(L, "DNS resolve failed ("s + addr + ")");
}
struct sockaddr_in saddr;
memset(&saddr, 0, sizeof(saddr));
if (inet_pton(AF_INET, IPStr.c_str(), &(saddr.sin_addr.s_addr)) != 1)
{
return 0;
}
saddr.sin_port = htons(port);
saddr.sin_family = AF_INET;
int ret = connect(p->fd, (const sockaddr*)& saddr, sizeof(saddr));
if (ret < 0)
{
return LuaWSError(L, "Connect failed. ("s + IPStr + ")");
}
return 0;
}
static int send(lua_State* L)
{
check(L, { LUA_TUSERDATA, LUA_TSTRING }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
size_t dataLen = 0;
const char* data = lua_tolstring(L, 2, &dataLen);
size_t done = 0;
while (done < dataLen)
{
int ret = ::send(p->fd, data + done, dataLen - done, 0);
if (ret < 0)
{
return LuaWSError(L, "Send failed");
}
done += ret;
}
return 0;
}
static int recv(lua_State* L)
{
check(L, { LUA_TUSERDATA }, { "clientsocket" });
ClientSocket* p = (ClientSocket*)lua_touserdata(L, 1);
int size;
if (lua_gettop(L) >= 2)
{
printf("%s\n", lua_typename(L, lua_type(L, 2)));
size = lua_tointeger(L, 2);
}
else
{
size = 1024;
}
if (size > 0)
{
unique_ptr<char> data(new char[size]);
int ret = ::recv(p->fd, data.get(), size, 0);
if (ret < 0)
{
return LuaWSError(L, "Recv failed");
}
else if (ret == 0)
{
lua_pushnil(L);
}
else
{
lua_pushlstring(L, data.get(), ret);
}
return 1;
}
else
{
return luaL_error(L, "Bad argument #%d. integer above 0 expected, got %s", 2, lua_tostring(L, 2));
}
}
static int create(lua_State* L)
{
int socketType = SOCK_STREAM;
if (lua_gettop(L) >= 1)
{
if (lua_type(L, 1) == LUA_TSTRING)
{
if (strcmp(lua_tostring(L, 1), "tcp") == 0)
{
socketType = SOCK_STREAM;
}
else if (strcmp(lua_tostring(L, 1), "udp") == 0)
{
socketType = SOCK_DGRAM;
}
else
{
return luaL_error(L, "Invalid socket type: %s", lua_tostring(L, 1));
}
}
lua_settop(L, 0);
}
int fd = socket(AF_INET, socketType, 0);
if (fd == -1)
{
return LuaWSError(L);
}
ClientSocket* p = (ClientSocket*)lua_newuserdata(L, sizeof(ClientSocket));
p->fd = fd;
p->family = AF_INET;
p->socketType = socketType;
if (lua_getfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt") != LUA_TTABLE)
{
lua_pop(L, 1);
lua_newtable(L);
// GC
setfn(close, "__gc");
// Fields
lua_newtable(L);
lua_pushstring(L, "clientsocket"); lua_setfield(L, -2, "type");
setfn(close, "close");
setfn(connect_ipv4, "connect");
setfn(send, "send");
setfn(recv, "recv");
// Set __index of metatable.
lua_setfield(L, -2, "__index");
lua_pushvalue(L, -1);
lua_setfield(L, LUA_REGISTRYINDEX, "__clientsocket_mt");
}
lua_setmetatable(L, -2);
return 1;
}
};
int InitNetwork(lua_State* L)
{
WORD wd;
WSAData wdt;
wd = MAKEWORD(2, 2);
int ret = WSAStartup(wd, &wdt);
if (ret < 0)
{
return LuaWSError(L);
}
lua_pushcfunction(L, ClientSocket::create);
lua_setglobal(L, "ClientSocket");
return 0;
}
void HandleNetworkEvent(lua_State* L, const SDL_Event& e)
{
}

View File

@ -1,15 +1,4 @@
#pragma once
#include "LuaVM.h"
#include "SDL2/include/SDL.h"
int InitNetwork(lua_State* L);
struct SocketEventData
{
int fd;
const char* data;
int ret;
int errcode;
};
void HandleNetworkEvent(lua_State* L, const SDL_Event& e);
int InitLuaNetwork(lua_State* L);