LibWS/websocket.cpp

238 lines
4.7 KiB
C++

#include "websocket.h"
#include "gsock_helper.h" // GSock
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <new>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "base64.hpp"
#include "sha1.h"
#include <WinSock2.h> // htonl
using namespace std;
/// Computes the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key.
///
/// The key is concatenated with the fixed WebSocket GUID, hashed with SHA-1,
/// and the five 32-bit digest words are base64-encoded.
///
/// @param key  Sec-WebSocket-Key value received from the client.
/// @return     Base64 string to send back in the Sec-WebSocket-Accept header.
string GetResponseKey(const string& key)
{
    const char kWebSocketGuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    string material = key + kWebSocketGuid;

    SHA1 hasher;
    hasher.update(material);

    uint32_t digest[5];
    hasher.final(digest);

    // Each digest word is converted to network (big-endian) byte order so
    // that the encoded bytes come out in the order the protocol expects.
    for (uint32_t& word : digest)
    {
        word = htonl(word);
    }

    return base64_encode_std(reinterpret_cast<const unsigned char*>(digest), sizeof(digest));
}
int Handshake(sock& s)
{
sock_helper sp(s);
vector<string> lines;
while (true)
{
string str;
int ret = sp.recvline(str);
if (ret <= 0) return -1;
if (str.empty()) break;
lines.push_back(str);
// cout << str << endl;
}
string target("Sec-WebSocket-Key");
for (auto& str : lines)
{
if (str.find(target) != string::npos)
{
string key = str.substr(str.find(target) + 19, 24);
string response_key = GetResponseKey(key);
string response_header = string("HTTP/1.1 101 Switching Protocols\r\n") +
"Connection: upgrade\r\n" +
"Sec-WebSocket-Accept: " + response_key + "\r\n" +
"Upgrade: websocket\r\n\r\n";
if (sp.sendall(response_header) <= 0)
{
// Network error.
return -1;
}
else
{
// Handshake finished successfully.
return 0;
}
}
}
// Not websocket protocol
return -2;
}
/// Reads a single WebSocket frame from s into f.
///
/// Decodes the FIN/RSV/opcode bits, the 7-bit / 16-bit / 64-bit payload
/// length, the optional masking key, and the payload. A masked payload is
/// unmasked before being stored in f.data. Zero-length payloads (e.g. an
/// empty ping or close frame) are valid and yield an empty f.data.
///
/// @return 0 on success, -1 if the socket fails or closes mid-frame,
///         -2 if a buffer for the (peer-announced) payload length cannot be
///         allocated.
int ReadFrame(sock& s, WSFrame& f)
{
    sock_helper sp(s);
    unsigned char c;
    if (s.recv(&c, 1) <= 0) return -1;
    f.fin = c & 0x80;  // 1000 0000
    f.rsv1 = c & 0x40; // 0100 0000
    f.rsv2 = c & 0x20; // 0010 0000
    f.rsv3 = c & 0x10; // 0001 0000
    f.opcode = c & 0xF; // 0000 1111
    if (s.recv(&c, 1) <= 0) return -1;
    f.ismask = c & 0x80; // 1000 0000
    int payload_head = c & 0x7F; // 0111 1111
    if (payload_head < 126)
    {
        f.len = payload_head;
    }
    else if (payload_head == 126)
    {
        uint16_t x;
        if (sp.recvall(&x, sizeof(x)) <= 0) return -1;
        f.len = ntohs(x);
    }
    else // payload_head == 127
    {
        uint64_t x;
        if (sp.recvall(&x, sizeof(x)) <= 0) return -1;
        f.len = ntohll(x);
    }
    if (f.ismask)
    {
        if (sp.recvall(f.mask, sizeof(f.mask)) <= 0) return -1;
    }

    string payload;
    if (f.len > 0) // recvall of 0 bytes would report failure, so skip it
    {
        // The length comes from the peer (up to 2^63): refuse sizes a string
        // cannot hold and turn allocation failure into -2 instead of letting
        // std::bad_alloc escape.
        if (static_cast<unsigned long long>(f.len) > payload.max_size()) return -2;
        try
        {
            payload.resize(static_cast<size_t>(f.len));
        }
        catch (const std::bad_alloc&)
        {
            return -2;
        }
        if (sp.recvall(&payload[0], f.len) <= 0) return -1;
        if (f.ismask)
        {
            // Undo the masking applied by the sender.
            for (size_t i = 0; i < payload.size(); i++)
            {
                payload[i] ^= f.mask[i % 4];
            }
        }
    }
    f.data = std::move(payload);
    return 0;
}
/// Serializes f as one WebSocket frame and sends it on s.
///
/// The payload length written into the header is f.data.size(), the number
/// of bytes actually sent; f.len is not consulted, so a stale or mismatched
/// f.len cannot corrupt the frame. Extended 16-bit and 64-bit lengths are
/// written in network (big-endian) byte order, matching what ReadFrame
/// decodes with ntohs/ntohll. If f.ismask is set, f.mask is written after the
/// length and the payload is XOR-masked with it.
///
/// The whole frame is assembled in one buffer and sent with a single sendall,
/// so the header and the payload are not split into separate small writes.
///
/// @return 0 on success, -1 if sending fails.
int SendFrame(sock& s, const WSFrame& f)
{
    sock_helper sp(s);
    const uint64_t payload_len = f.data.size();

    string frame;
    frame.reserve(14 + f.data.size()); // max header: 2 + 8 (length) + 4 (mask)

    // Appends the low byte_count bytes of value, most significant first.
    auto append_big_endian = [&frame](uint64_t value, int byte_count) {
        for (int shift = (byte_count - 1) * 8; shift >= 0; shift -= 8)
        {
            frame.push_back(static_cast<char>((value >> shift) & 0xFF));
        }
    };

    unsigned char first = 0;
    if (f.fin) first |= 0x80;  // 1000 0000
    if (f.rsv1) first |= 0x40; // 0100 0000
    if (f.rsv2) first |= 0x20; // 0010 0000
    if (f.rsv3) first |= 0x10; // 0001 0000
    first |= (f.opcode & 0xF); // 0000 1111
    frame.push_back(static_cast<char>(first));

    const unsigned char mask_bit = f.ismask ? 0x80 : 0x00;
    if (payload_len < 126)
    {
        frame.push_back(static_cast<char>(mask_bit | payload_len));
    }
    else if (payload_len <= 0xFFFF) // fits the 16-bit extended length
    {
        frame.push_back(static_cast<char>(mask_bit | 126));
        append_big_endian(payload_len, 2);
    }
    else
    {
        frame.push_back(static_cast<char>(mask_bit | 127));
        append_big_endian(payload_len, 8);
    }

    if (f.ismask)
    {
        frame.append(reinterpret_cast<const char*>(&f.mask), sizeof(f.mask));
        const size_t header_size = frame.size();
        frame += f.data;
        for (size_t i = 0; i < f.data.size(); i++)
        {
            frame[header_size + i] ^= f.mask[i % 4];
        }
    }
    else
    {
        frame += f.data;
    }

    // frame always holds at least the 2-byte header, so an empty payload is
    // no longer mistaken for a send failure.
    if (sp.sendall(frame) <= 0) return -1;
    return 0;
}
/// Answers a ping frame with a pong (opcode 0xA) echoing the ping's payload.
///
/// The pong is sent as a single final, unmasked frame with all RSV bits clear.
/// @return whatever SendFrame returns for the pong frame (negative on failure).
int SendPong(sock& s, const WSFrame& ping)
{
    cout << "Sending pong to " << (&s) << endl;

    WSFrame pong;
    pong.opcode = 0xA;
    pong.fin = true;
    pong.rsv1 = false;
    pong.rsv2 = false;
    pong.rsv3 = false;
    pong.ismask = false;
    pong.data = ping.data;
    pong.len = ping.len;

    return SendFrame(s, pong);
}
/// Reads one complete application message from s, reassembling fragments.
///
/// Control frames are handled inline: a ping is answered with a pong, a pong
/// is ignored, and a close frame stops reading. Data frames are concatenated
/// until one with FIN set arrives.
///
/// @param out_data  Receives the message payload. Only written when the
///                  return value is positive.
/// @return number of data frames concatenated (> 0) on success, 0 if the peer
///         sent a close frame, -1 if reading a frame failed, -2 if sending
///         the pong reply failed.
int ReadMsg(sock& s, string& out_data)
{
    string data;
    WSFrame f;
    int pack_cnt = 0;
    while (true)
    {
        int ret = ReadFrame(s, f);
        if (ret < 0) return -1;
        if (f.opcode == 0x9) // ping: reply with a pong
        {
            cout << "Received ping from " << (&s) << endl;
            if (SendPong(s, f) < 0) return -2;
            continue;
        }
        else if (f.opcode == 0xA) // pong: nothing to do
        {
            cout << "Received pong from " << (&s) << endl;
            continue;
        }
        else if (f.opcode == 0x8) // peer is closing the connection
        {
            cout << "Websocket is closing " << (&s) << endl;
            return 0;
        }
        pack_cnt++;
        data.append(f.data);
        if (f.fin) break;
    }
    // Hand over the buffer instead of copying a possibly large message.
    out_data = std::move(data);
    return pack_cnt;
}
/// Sends data to the peer as a single, unfragmented, unmasked message frame.
///
/// @param is_text  true sends a text frame (opcode 0x1), false a binary frame (0x2).
/// @return 1 on success, -1 if SendFrame reports an error.
int SendMsg(sock& s, const string& data, bool is_text)
{
    WSFrame frame;
    frame.opcode = is_text ? 0x1 : 0x2;
    frame.fin = true;
    frame.rsv1 = false;
    frame.rsv2 = false;
    frame.rsv3 = false;
    frame.ismask = false;
    frame.data = data;
    frame.len = data.size();

    const int rc = SendFrame(s, frame);
    return (rc < 0) ? -1 : 1;
}