From 040709bdb4c4b14410e1e87bb73bf9cf14d23057 Mon Sep 17 00:00:00 2001 From: Daniel Karbach Date: Sun, 13 Oct 2024 17:47:58 +0200 Subject: [PATCH] basic twitch chat connection --- .gitignore | 1 + src/app/Application.h | 27 ++- src/sys/Promise.h | 56 ++++++ src/twitch/IRCMessage.cpp | 174 +++++++++++++++++ src/twitch/IRCMessage.h | 76 ++++++++ src/twitch/LoginToken.cpp | 59 ++++++ src/twitch/LoginToken.h | 88 +++++++++ src/ws/Connection.cpp | 199 +++++++++++++++++++- src/ws/Context.h | 140 ++++++++++++-- src/ws/HttpsConnection.h | 109 +++++++++++ src/ws/{Connection.h => PusherConnection.h} | 21 +-- src/ws/TwitchConnection.h | 160 ++++++++++++++++ src/ws/io.h | 5 +- 13 files changed, 1071 insertions(+), 44 deletions(-) create mode 100644 src/sys/Promise.h create mode 100644 src/twitch/IRCMessage.cpp create mode 100644 src/twitch/IRCMessage.h create mode 100644 src/twitch/LoginToken.cpp create mode 100644 src/twitch/LoginToken.h create mode 100644 src/ws/HttpsConnection.h rename src/ws/{Connection.h => PusherConnection.h} (83%) create mode 100644 src/ws/TwitchConnection.h diff --git a/.gitignore b/.gitignore index d5615fc..ae1bbc3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ main out.flv test.mkv test.mp4 +twitch-token.json diff --git a/src/app/Application.h b/src/app/Application.h index 2fa7ccd..84014cd 100644 --- a/src/app/Application.h +++ b/src/app/Application.h @@ -10,7 +10,8 @@ #include "Stream.h" #include "../ffmpeg/Network.h" #include "../uv/Loop.h" -#include "../ws/Connection.h" +#include "../ws/PusherConnection.h" +#include "../ws/TwitchConnection.h" #include "../ws/Context.h" namespace app { @@ -22,7 +23,8 @@ public: : net() , loop() , ws_ctx(loop) - , ws_conn(ws_ctx.GetContext()) + , pusher_conn(ws_ctx) + , twitch_conn(ws_ctx) , stream(url, width, height, fps) , mixer(stream.GetAudioPlane(), stream.GetAudioChannels(), stream.GetAudioFrameSize()) , renderer(stream.GetVideoPlane(), stream.GetVideoLineSize(), width, height) @@ -36,7 +38,8 @@ public: public: void Start() { - ws_conn.Subscribe("ChatBotLog", &WsHandler, this); + pusher_conn.Subscribe("ChatBotLog", &PusherHandler, this); + twitch_conn.Join("#horstiebot", &TwitchHandler, this); stream.Start(); //Media &media = state.AddMedia("test.mp4"); @@ -91,12 +94,12 @@ public: } private: - static void WsHandler(void *user, const Json::Value &json) { + static void PusherHandler(void *user, const Json::Value &json) { Application *app = static_cast(user); - app->HandleWebSocket(json); + app->HandlePusher(json); } - void HandleWebSocket(const Json::Value &json) { + void HandlePusher(const Json::Value &json) { const std::string data_string = json["data"].asString(); Json::Value data; Json::Reader json_reader; @@ -117,11 +120,21 @@ private: msg.Update(renderer.GetContext()); } + static void TwitchHandler(void *user, const twitch::IRCMessage &msg) { + Application *app = static_cast(user); + app->HandleTwitch(msg); + } + + void HandleTwitch(const twitch::IRCMessage &msg) { + std::cout << "got message: " << msg.GetText() << std::endl; + } + private: ffmpeg::Network net; uv::Loop loop; ws::Context ws_ctx; - ws::Connection ws_conn; + ws::PusherConnection pusher_conn; + ws::TwitchConnection twitch_conn; Stream stream; Mixer mixer; Renderer renderer; diff --git a/src/sys/Promise.h b/src/sys/Promise.h new file mode 100644 index 0000000..3a2a6ad --- /dev/null +++ b/src/sys/Promise.h @@ -0,0 +1,56 @@ +#ifndef TEST_SYS_PROMISE_H_ +#define TEST_SYS_PROMISE_H_ + +#include +#include +#include + +namespace sys { + +template +class Promise { + +public: + typedef std::function Callback; + +public: + Promise &Then(Callback callback) { + success.push_back(callback); + return *this; + } + + Promise &Catch(Callback callback) { + error.push_back(callback); + return *this; + } + +public: + void Resolve(Args... args) { + for (Callback &callback : success) { + try { + callback(args...); + } catch (...) { + std::cerr << "exception in promise resolution" << std::endl; + } + } + } + + void Reject(Args... args) { + for (Callback &callback : error) { + try { + callback(args...); + } catch (...) { + std::cerr << "exception in promise rejection" << std::endl; + } + } + } + +private: + std::vector success; + std::vector error; + +}; + +} + +#endif diff --git a/src/twitch/IRCMessage.cpp b/src/twitch/IRCMessage.cpp new file mode 100644 index 0000000..90ea381 --- /dev/null +++ b/src/twitch/IRCMessage.cpp @@ -0,0 +1,174 @@ +#include "IRCMessage.h" +#include + +namespace twitch { + +void IRCMessage::Decode(std::string::const_iterator begin, std::string::const_iterator input_end) { + command.clear(); + params.clear(); + nick.clear(); + user.clear(); + host.clear(); + server.clear(); + tags.clear(); + + auto i = begin; + auto end = input_end; + + // strip end of newline, if present + if (i != end && *(end - 1) == '\n') --end; + if (i != end && *(end - 1) == '\r') --end; + if (i == end) return; + + if (*i == '@') { + ++i; + i = ParseTags(i, end); + } + if (i == end) return; + + if (*i == ':') { + ++i; + i = ParseAuthority(i, end); + } + if (i == end) return; + + i = ParseCommand(i, end); + while (i != end) { + i = ParseParam(i, end); + } +} + +std::string::const_iterator IRCMessage::ParseTags(std::string::const_iterator begin, std::string::const_iterator end) { + auto tags_end = std::find(begin, end, ' '); + auto tag_start = begin; + while (tag_start != tags_end) { + auto tag_end = std::find(tag_start, tags_end, ';'); + auto separator = std::find(tag_start, tag_end, '='); + std::string name(tag_start, separator); + // skip separator + if (separator != tag_end) { + ++separator; + } + tags[name] = std::string(separator, tag_end); + // skip semicolon + if (tag_end != tags_end) { + ++tag_end; + } + tag_start = tag_end; + } + // skip space + if (tags_end != end) ++tags_end; + return tags_end; +} + +std::string::const_iterator IRCMessage::ParseAuthority(std::string::const_iterator begin, std::string::const_iterator end) { + auto authority_end = std::find(begin, end, ' '); + auto exclamation = std::find(begin, authority_end, '!'); + auto at_symbol = std::find(begin, authority_end, '@'); + bool has_user = exclamation != authority_end; + bool has_host = at_symbol != authority_end; + if (has_user && has_host) { + nick.assign(begin, exclamation); + // skip exclamation mark + if (exclamation != authority_end) { + ++exclamation; + } + user.assign(exclamation, at_symbol); + // skip at symbol + if (at_symbol != authority_end) { + ++at_symbol; + } + host.assign(at_symbol, authority_end); + } else if (has_user) { + nick.assign(begin, exclamation); + // skip exclamation mark + if (exclamation != authority_end) { + ++exclamation; + } + user.assign(exclamation, authority_end); + } else if (has_host) { + nick.assign(begin, at_symbol); + // skip at symbol + if (at_symbol != authority_end) { + ++at_symbol; + } + host.assign(at_symbol, authority_end); + } else { + server.assign(begin, authority_end); + } + // skip space + if (authority_end != end) ++authority_end; + return authority_end; +} + +std::string::const_iterator IRCMessage::ParseCommand(std::string::const_iterator begin, std::string::const_iterator end) { + auto space = std::find(begin, end, ' '); + command.assign(begin, space); + // skip space + if (space != end) ++space; + return space; +} + +std::string::const_iterator IRCMessage::ParseParam(std::string::const_iterator begin, std::string::const_iterator end) { + if (begin == end) return end; + if (*begin == ':') { + params.emplace_back(begin + 1, end); + return end; + } + auto space = std::find(begin, end, ' '); + params.emplace_back(begin, space); + // skip space + if (space != end) ++space; + return space; +} + +void IRCMessage::Encode(std::string &out) const { + if (tags.size() > 0) { + out.push_back('@'); + bool first = true; + for (const auto &tag : tags) { + if (first) { + first = false; + } else { + out.push_back(';'); + } + // TODO: this may need some kind of encoding? + out.append(tag.first); + out.push_back('='); + out.append(tag.second); + } + out.push_back(' '); + } + + if (server.size() > 0) { + out.push_back(':'); + out.append(server); + out.push_back(' '); + } else if (nick.size() > 0) { + out.push_back(':'); + out.append(nick); + if (user.size() > 0) { + out.push_back('!'); + out.append(user); + } + if (host.size() > 0) { + out.push_back('@'); + out.append(host); + } + out.push_back(' '); + } + + out.append(command); + + if (params.size() == 0) return; + + for (int i = 0; i < params.size(); ++i) { + out.push_back(' '); + if (i == params.size() - 1) { + out.push_back(':'); + } + out.append(params[i]); + } +} + +} diff --git a/src/twitch/IRCMessage.h b/src/twitch/IRCMessage.h new file mode 100644 index 0000000..db25ae5 --- /dev/null +++ b/src/twitch/IRCMessage.h @@ -0,0 +1,76 @@ +#ifndef TEST_TWITCH_IRCMESSAGE_H_ +#define TEST_TWITCH_IRCMESSAGE_H_ + +#include +#include +#include +#include + +namespace twitch { + +class IRCMessage { + +public: + IRCMessage() { + } + +public: + void Decode(const std::string &in) { + Decode(in.begin(), in.end()); + } + void Decode(std::string::const_iterator begin, std::string::const_iterator end); + void Encode(std::string &out) const; + + std::string GetText() const { + return params.empty() ? "" : params.back(); + } + + bool IsLoginSuccess() const { + return command == "001"; + } + + bool IsPing() const { + return command == "PING"; + } + + bool IsPong() const { + return command == "PONG"; + } + + bool IsPrivMsg() const { + return command == "PRIVMSG"; + } + + IRCMessage MakePong() const { + IRCMessage pong; + pong.command = "PONG"; + pong.params = params; + return pong; + } + +private: + std::string::const_iterator ParseTags(std::string::const_iterator begin, std::string::const_iterator end); + std::string::const_iterator ParseAuthority(std::string::const_iterator begin, std::string::const_iterator end); + std::string::const_iterator ParseCommand(std::string::const_iterator begin, std::string::const_iterator end); + std::string::const_iterator ParseParam(std::string::const_iterator begin, std::string::const_iterator end); + +public: + std::string command; + std::vector params; + std::string nick; + std::string user; + std::string host; + std::string server; + std::map tags; + +}; + +inline std::ostream &operator <<(std::ostream &out, const IRCMessage &msg) { + std::string msg_str; + msg.Encode(msg_str); + return out << msg_str; +} + +} + +#endif diff --git a/src/twitch/LoginToken.cpp b/src/twitch/LoginToken.cpp new file mode 100644 index 0000000..6feb6fd --- /dev/null +++ b/src/twitch/LoginToken.cpp @@ -0,0 +1,59 @@ +#include "LoginToken.h" + +#include +#include + +#include "../ws/Context.h" +#include "../ws/HttpsConnection.h" + +namespace twitch { + +LoginToken::PromiseType &LoginToken::Refresh(ws::Context &ws) { + if (is_refreshing) return promise; + is_refreshing = true; + + ws::HttpsConnection &req = ws.HttpsRequest("POST", "id.twitch.tv", "/oauth2/token"); + req.SetHeader("Content-Type", "application/x-www-form-urlencoded"); + req.AddFormUrlenc("client_id", client_id); + req.AddFormUrlenc("client_secret", client_secret); + req.AddFormUrlenc("grant_type", "refresh_token"); + req.AddFormUrlenc("refresh_token", refresh_token); + req.SetContentLength(); + req.GetPromise() + .Then([this](ws::HttpsConnection &rsp) -> void { + HandleRefreshComplete(rsp); + }) + .Catch([this](ws::HttpsConnection &rsp) -> void { + HandleRefreshError(rsp); + }); + return promise; +} + +void LoginToken::HandleRefreshComplete(ws::HttpsConnection &rsp) { + is_refreshing = false; + std::cout << "completed https request with status " << rsp.GetStatus() << std::endl; + std::cout << "body: " << rsp.GetBody() << std::endl; + if (rsp.GetStatus() != 200) return; + // access_token + // refresh_token + // expires_in (seconds) + Json::Value json; + json_reader.parse(rsp.GetBody(), json); + access_token = json["access_token"].asString(); + refresh_token = json["refresh_token"].asString(); + int expires_in = json["expires_in"].asInt(); + time_t now; + std::time(&now); + expires = now + expires_in; + Save(); + promise.Resolve(*this); +} + +void LoginToken::HandleRefreshError(ws::HttpsConnection &rsp) { + is_refreshing = false; + std::cout << "errored https request with status " << rsp.GetStatus() << std::endl; + std::cout << "body: " << rsp.GetBody() << std::endl; + promise.Reject(*this); +} + +} diff --git a/src/twitch/LoginToken.h b/src/twitch/LoginToken.h new file mode 100644 index 0000000..6b76617 --- /dev/null +++ b/src/twitch/LoginToken.h @@ -0,0 +1,88 @@ +#ifndef TEST_TWITCH_LOGINTOKEN_H_ +#define TEST_TWITCH_LOGINTOKEN_H_ + +#include +#include +#include + +#include + +#include "../sys/Promise.h" + +namespace ws { + class Context; + class HttpsConnection; +} + +namespace twitch { + +class LoginToken { + +public: + typedef sys::Promise PromiseType; + +public: + LoginToken(): expires(0), is_refreshing(false) { + } + +public: + void Load() { + std::ifstream in("twitch-token.json"); + Json::Value json; + in >> json; + + client_id = json["client_id"].asString(); + client_secret = json["client_secret"].asString(); + access_token = json["access_token"].asString(); + refresh_token = json["refresh_token"].asString(); + expires = json["expires"].asInt64(); + } + + void Save() { + Json::Value json; + json["client_id"] = client_id; + json["client_secret"] = client_secret; + json["access_token"] = access_token; + json["refresh_token"] = refresh_token; + json["expires"] = expires; + + std::ofstream out; + out.open("twitch-token.json"); + out << json << std::endl; + } + + bool HasExpired() { + time_t now; + std::time(&now); + return expires < now; + } + + const std::string &GetAccessToken() const { + return access_token; + } + + PromiseType &Refresh(ws::Context &ws); + +private: + void HandleRefreshComplete(ws::HttpsConnection &rsp); + + void HandleRefreshError(ws::HttpsConnection &rsp); + +private: + std::string client_id; + std::string client_secret; + std::string access_token; + std::string refresh_token; + std::time_t expires; + + Json::Reader json_reader; + Json::FastWriter json_writer; + + PromiseType promise; + bool is_refreshing; + +}; + +} + +#endif diff --git a/src/ws/Connection.cpp b/src/ws/Connection.cpp index a111318..854245d 100644 --- a/src/ws/Connection.cpp +++ b/src/ws/Connection.cpp @@ -1,5 +1,9 @@ -#include "Connection.h" +#include "Context.h" +#include "HttpsConnection.h" +#include "PusherConnection.h" +#include "TwitchConnection.h" +#include #include #include #include @@ -8,9 +12,101 @@ namespace ws { -Connection::Connection(lws_context *ctx) +HttpsConnection::HttpsConnection(Context &ctx, const char *method, const char *host, const char *path) +: info{0}, wsi(nullptr), read_buffer{0}, status(0) { + info.context = ctx.GetContext(); + info.opaque_user_data = this; + info.address = host; + info.port = 443; + info.ssl_connection = 1; + info.path = path; + info.host = host; + info.origin = "test"; + info.method = method; + info.protocol = "https"; + info.ietf_version_or_minus_one = -1; + info.userdata = &ctx; + info.pwsi = &wsi; + wsi = lws_client_connect_via_info(&info); + if (!wsi) { + throw std::runtime_error("failed to connect client"); + } + out_buffer.insert(0, LWS_PRE, '\0'); +} + +int HttpsConnection::ProtoCallback(lws_callback_reasons reason, void *in, size_t len) { + switch (reason) { + case LWS_CALLBACK_CLIENT_CONNECTION_ERROR: + promise.Reject(*this); + break; + case LWS_CALLBACK_ESTABLISHED_CLIENT_HTTP: + status = lws_http_client_http_response(wsi); + break; + case LWS_CALLBACK_RECEIVE_CLIENT_HTTP_READ: + in_buffer.append(static_cast(in), len); + break; + case LWS_CALLBACK_RECEIVE_CLIENT_HTTP: { + char *p = &read_buffer[LWS_PRE]; + int l = sizeof(read_buffer) - LWS_PRE; + if (lws_http_client_read(wsi, &p, &l) < 0) { + return -1; + } + } + break; + case LWS_CALLBACK_COMPLETED_CLIENT_HTTP: + promise.Resolve(*this); + break; + case LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER: + if (!lws_http_is_redirected_to_get(wsi)) { + unsigned char **p = reinterpret_cast(in); + for (const auto &header : headers) { + const unsigned char *name = reinterpret_cast(header.first.c_str()); + const unsigned char *value = reinterpret_cast(header.second.c_str()); + if (lws_add_http_header_by_name(wsi, name, value, header.second.length(), p, (*p) + len) != 0) { + return -1; + } + } + if (out_buffer.length() > LWS_PRE) { + lws_client_http_body_pending(wsi, 1); + lws_callback_on_writable(wsi); + } else { + lws_client_http_body_pending(wsi, 0); + } + } + break; + case LWS_CALLBACK_CLIENT_HTTP_WRITEABLE: + if (!lws_http_is_redirected_to_get(wsi) && out_buffer.length() > LWS_PRE) { + int len = std::min(int(out_buffer.length() - LWS_PRE), BUFSIZ); + lws_write_protocol proto = out_buffer.length() - LWS_PRE > BUFSIZ ? LWS_WRITE_HTTP : LWS_WRITE_HTTP_FINAL; + int res = lws_write(wsi, reinterpret_cast(&out_buffer[LWS_PRE]), len, proto); + if (res > 0) { + out_buffer.erase(LWS_PRE, res); + } + if (out_buffer.length() > LWS_PRE) { + lws_callback_on_writable(wsi); + } else { + lws_client_http_body_pending(wsi, 0); + } + } + break; + case LWS_CALLBACK_WSI_CREATE: + case LWS_CALLBACK_OPENSSL_PERFORM_SERVER_CERT_VERIFICATION: + case LWS_CALLBACK_HTTP_DROP_PROTOCOL: + case LWS_CALLBACK_CLOSED_CLIENT_HTTP: + break; + default: + std::cout << "unhandled https connection proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; + if (in && len) { + std::cout << " DATA: \"" << std::string(static_cast(in), len) << '"' << std::endl; + } + break; + } + return 0; +} + +PusherConnection::PusherConnection(Context &ctx) : info{0}, wsi(nullptr), connected(false) { - info.context = ctx; + info.context = ctx.GetContext(); info.opaque_user_data = this; // wss://alttp.localhorst.tv/app/nkmbiabdrtqnd8t19txs?protocol=7&client=js&version=8.3.0&flash=false info.address = "alttp.localhorst.tv"; @@ -21,7 +117,7 @@ Connection::Connection(lws_context *ctx) info.origin = "test"; info.protocol = "pusher"; info.ietf_version_or_minus_one = -1; - info.userdata = this; + info.userdata = &ctx; info.pwsi = &wsi; wsi = lws_client_connect_via_info(&info); if (!wsi) { @@ -31,7 +127,7 @@ Connection::Connection(lws_context *ctx) out_buffer.insert(0, LWS_PRE, '\0'); } -int Connection::ProtoCallback(lws_callback_reasons reason, void *in, size_t len) { +int PusherConnection::ProtoCallback(lws_callback_reasons reason, void *in, size_t len) { switch (reason) { case LWS_CALLBACK_CLIENT_ESTABLISHED: connected = true; @@ -53,12 +149,12 @@ int Connection::ProtoCallback(lws_callback_reasons reason, void *in, size_t len) break; case LWS_CALLBACK_CLIENT_WRITEABLE: if (out_buffer.length() > LWS_PRE) { - int res = lws_write(wsi, reinterpret_cast(&out_buffer[0]) + LWS_PRE, out_buffer.length() - LWS_PRE, LWS_WRITE_TEXT); + int res = lws_write(wsi, reinterpret_cast(&out_buffer[LWS_PRE]), out_buffer.length() - LWS_PRE, LWS_WRITE_TEXT); if (res > 0) { out_buffer.erase(LWS_PRE, res); } - break; } + break; case LWS_CALLBACK_TIMER: Ping(); lws_set_timer_usecs(wsi, 30000000); @@ -77,7 +173,94 @@ int Connection::ProtoCallback(lws_callback_reasons reason, void *in, size_t len) case LWS_CALLBACK_WSI_CREATE: break; default: - std::cout << "unhandled connection proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; + std::cout << "unhandled pusher connection proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; + if (in && len) { + std::cout << " DATA: \"" << std::string(static_cast(in), len) << '"' << std::endl; + } + break; + } + return 0; +} + +TwitchConnection::TwitchConnection(Context &ctx) +: ctx(ctx), info{0}, wsi(nullptr), connected(false), authenticated(false) { + info.context = ctx.GetContext(); + info.opaque_user_data = this; + // wss://irc-ws.chat.twitch.tv:443 + info.address = "irc-ws.chat.twitch.tv"; + info.port = 443; + info.ssl_connection = 1; + info.path = "/"; + info.host = "irc-ws.chat.twitch.tv"; + info.origin = "test"; + info.protocol = "twitch"; + info.ietf_version_or_minus_one = -1; + info.userdata = &ctx; + info.pwsi = &wsi; + wsi = lws_client_connect_via_info(&info); + if (!wsi) { + throw std::runtime_error("failed to connect client"); + } + lws_set_timer_usecs(wsi, 30000000); + out_buffer.insert(0, LWS_PRE, '\0'); + token.Load(); +} + +int TwitchConnection::ProtoCallback(lws_callback_reasons reason, void *in, size_t len) { + switch (reason) { + case LWS_CALLBACK_CLIENT_ESTABLISHED: + connected = true; + OnConnect(); + if (out_buffer.length() > LWS_PRE) { + lws_callback_on_writable(wsi); + } + break; + case LWS_CALLBACK_CLIENT_CLOSED: + connected = false; + break; + case LWS_CALLBACK_CLIENT_RECEIVE: + if (lws_is_first_fragment(wsi)) { + in_buffer.clear(); + } + in_buffer.append(static_cast(in), len); + if (lws_is_final_fragment(wsi)) { + HandleMessage(in_buffer); + } + // reset ping timer + lws_set_timer_usecs(wsi, 30000000); + break; + case LWS_CALLBACK_CLIENT_WRITEABLE: + if (out_buffer.length() > LWS_PRE) { + size_t pos = out_buffer.find('\n', LWS_PRE); + size_t len = pos == std::string::npos ? out_buffer.length() : pos + 1; + int res = lws_write(wsi, reinterpret_cast(&out_buffer[LWS_PRE]), len - LWS_PRE, LWS_WRITE_TEXT); + if (res > 0) { + out_buffer.erase(LWS_PRE, res); + } + if (out_buffer.length() > LWS_PRE) { + lws_callback_on_writable(wsi); + } + break; + } + case LWS_CALLBACK_TIMER: + Ping(); + lws_set_timer_usecs(wsi, 60000000); + break; + case LWS_CALLBACK_CLIENT_RECEIVE_PONG: + case LWS_CALLBACK_CLIENT_HTTP_BIND_PROTOCOL: + case LWS_CALLBACK_CLIENT_HTTP_DROP_PROTOCOL: + case LWS_CALLBACK_WS_CLIENT_BIND_PROTOCOL: + case LWS_CALLBACK_WS_CLIENT_DROP_PROTOCOL: + case LWS_CALLBACK_OPENSSL_PERFORM_SERVER_CERT_VERIFICATION: + case LWS_CALLBACK_CLIENT_APPEND_HANDSHAKE_HEADER: + case LWS_CALLBACK_ESTABLISHED_CLIENT_HTTP: + case LWS_CALLBACK_CLOSED_CLIENT_HTTP: + case LWS_CALLBACK_SERVER_NEW_CLIENT_INSTANTIATED: + case LWS_CALLBACK_CLIENT_FILTER_PRE_ESTABLISH: + case LWS_CALLBACK_WSI_CREATE: + break; + default: + std::cout << "unhandled twitch connection proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; if (in && len) { std::cout << " DATA: \"" << std::string(static_cast(in), len) << '"' << std::endl; } diff --git a/src/ws/Context.h b/src/ws/Context.h index 7d69981..c9abe2c 100644 --- a/src/ws/Context.h +++ b/src/ws/Context.h @@ -3,13 +3,18 @@ #include #include +#include #include #include #include +#include +#include -#include "Connection.h" +#include "HttpsConnection.h" #include "io.h" +#include "PusherConnection.h" +#include "TwitchConnection.h" #include "../uv/Loop.h" namespace ws { @@ -17,14 +22,27 @@ namespace ws { class Context { public: - explicit Context(uv::Loop &loop): info{0}, ctx(nullptr), proto{0}, protos{0}, loops{0} { + explicit Context(uv::Loop &loop) + : info{0}, ctx(nullptr), https_proto{0}, pusher_proto{0}, twitch_proto{0}, protos{0}, loops{0} { //lws_set_log_level(LLL_USER|LLL_ERR|LLL_WARN|LLL_NOTICE|LLL_INFO|LLL_DEBUG, nullptr); - proto.name = "pusher"; - proto.callback = &proto_callback; - proto.user = this; - proto.rx_buffer_size = BUFSIZ; - proto.tx_packet_size = BUFSIZ; - protos[0] = &proto; + https_proto.name = "https"; + https_proto.callback = &https_callback; + https_proto.user = this; + https_proto.rx_buffer_size = BUFSIZ; + https_proto.tx_packet_size = BUFSIZ; + protos[0] = &https_proto; + pusher_proto.name = "pusher"; + pusher_proto.callback = &pusher_callback; + pusher_proto.user = this; + pusher_proto.rx_buffer_size = BUFSIZ; + pusher_proto.tx_packet_size = BUFSIZ; + protos[1] = &pusher_proto; + twitch_proto.name = "twitch"; + twitch_proto.callback = &twitch_callback; + twitch_proto.user = this; + twitch_proto.rx_buffer_size = BUFSIZ; + twitch_proto.tx_packet_size = BUFSIZ; + protos[2] = &twitch_proto; info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT | LWS_SERVER_OPTION_LIBUV; info.port = CONTEXT_PORT_NO_LISTEN; info.pprotocols = protos; @@ -47,29 +65,113 @@ public: return ctx; } + HttpsConnection &HttpsRequest(const char *method, const char *host, const char *path) { + std::unique_ptr con = std::make_unique(*this, method, host, path); + https_connections.emplace_back(std::move(con)); + return *https_connections.back(); + } + void Shutdown() { lws_context_deprecate(ctx, nullptr); } private: - static int proto_callback(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { - void *user_data = lws_wsi_user(wsi); + static int https_callback(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { + void *user_data = lws_get_opaque_user_data(wsi); + Context *c = static_cast(user); if (user_data) { - Connection *conn = static_cast(user_data); + HttpsConnection *conn = static_cast(user_data); + if (reason == LWS_CALLBACK_WSI_DESTROY) { + c->RemoveHttpConnection(conn); + return 0; + } return conn->ProtoCallback(reason, in, len); } - Context *c = static_cast(user); - return c->ProtoCallback(reason, in, len); + if (c) { + return c->HttpsCallback(reason, in, len); + } + return 0; + } + + int HttpsCallback(lws_callback_reasons reason, void *in, size_t len) { + switch (reason) { + case LWS_CALLBACK_CLIENT_HTTP_BIND_PROTOCOL: + case LWS_CALLBACK_PROTOCOL_INIT: + case LWS_CALLBACK_PROTOCOL_DESTROY: + case LWS_CALLBACK_OPENSSL_LOAD_EXTRA_CLIENT_VERIFY_CERTS: + break; + default: + std::cout << "unhandled generic https proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; + if (in && len) { + std::cout << " DATA: \"" << std::string(static_cast(in), len) << '"' << std::endl; + } + break; + } + return 0; + } + + void RemoveHttpConnection(HttpsConnection *conn) { + for (auto i = https_connections.begin(); i != https_connections.end();) { + if (i->get() == conn) { + i = https_connections.erase(i); + } else { + ++i; + } + } + } + + static int pusher_callback(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { + void *user_data = lws_get_opaque_user_data(wsi); + if (user_data) { + PusherConnection *conn = static_cast(user_data); + return conn->ProtoCallback(reason, in, len); + } + if (user) { + Context *c = static_cast(user); + return c->PusherCallback(reason, in, len); + } + return 0; + } + + int PusherCallback(lws_callback_reasons reason, void *in, size_t len) { + switch (reason) { + case LWS_CALLBACK_CLIENT_HTTP_BIND_PROTOCOL: + case LWS_CALLBACK_PROTOCOL_INIT: + case LWS_CALLBACK_PROTOCOL_DESTROY: + case LWS_CALLBACK_OPENSSL_LOAD_EXTRA_CLIENT_VERIFY_CERTS: + break; + default: + std::cout << "unhandled generic pusher proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; + if (in && len) { + std::cout << " DATA: \"" << std::string(static_cast(in), len) << '"' << std::endl; + } + break; + } + return 0; } - int ProtoCallback(lws_callback_reasons reason, void *in, size_t len) { + static int twitch_callback(lws *wsi, lws_callback_reasons reason, void *user, void *in, size_t len) { + void *user_data = lws_get_opaque_user_data(wsi); + if (user_data) { + TwitchConnection *conn = static_cast(user_data); + return conn->ProtoCallback(reason, in, len); + } + if (user) { + Context *c = static_cast(user); + return c->TwitchCallback(reason, in, len); + } + return 0; + } + + int TwitchCallback(lws_callback_reasons reason, void *in, size_t len) { switch (reason) { + case LWS_CALLBACK_CLIENT_HTTP_BIND_PROTOCOL: case LWS_CALLBACK_PROTOCOL_INIT: case LWS_CALLBACK_PROTOCOL_DESTROY: case LWS_CALLBACK_OPENSSL_LOAD_EXTRA_CLIENT_VERIFY_CERTS: break; default: - std::cout << "unhandled generic proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; + std::cout << "unhandled generic twitch proto callback, reason: " << reason << ", in: " << in << ", len: " << len << std::endl; if (in && len) { std::cout << " DATA: \"" << std::string(static_cast(in), len) << '"' << std::endl; } @@ -80,11 +182,15 @@ private: private: lws_context_creation_info info; - lws_protocols proto; - const lws_protocols *protos[2]; + lws_protocols https_proto; + lws_protocols pusher_proto; + lws_protocols twitch_proto; + const lws_protocols *protos[4]; void *loops[2]; lws_context *ctx; + std::vector> https_connections; + }; } diff --git a/src/ws/HttpsConnection.h b/src/ws/HttpsConnection.h new file mode 100644 index 0000000..bf07a4e --- /dev/null +++ b/src/ws/HttpsConnection.h @@ -0,0 +1,109 @@ +#ifndef TEST_WS_HTTPSCONNECTION_H_ +#define TEST_WS_HTTPSCONNECTION_H_ + +#include +#include +#include +#include + +#include "../sys/Promise.h" + +namespace ws { + +class Context; + +class HttpsConnection { + +public: + typedef sys::Promise PromiseType; + +public: + HttpsConnection(Context &ctx, const char *method, const char *host, const char *path); + ~HttpsConnection() { + } + + HttpsConnection(const HttpsConnection &) = delete; + HttpsConnection &operator =(const HttpsConnection &) = delete; + +private: + struct Callback { + void *user; + void (*callback)(void *, HttpsConnection &); + void Call(HttpsConnection &val) const { + (*callback)(user, val); + } + }; + +public: + void SetHeader(const std::string &name, const std::string &value) { + headers[name + ":"] = value; + } + + void SetContentLength() { + headers["Content-Length:"] = std::to_string(out_buffer.size() - LWS_PRE); + } + + void AddBody(const std::string &body) { + out_buffer.append(body); + } + + void AddFormUrlenc(const std::string &name, const std::string &value) { + out_buffer.reserve(out_buffer.size() + name.size() + value.size() + 3); + if (out_buffer.size() > LWS_PRE) { + out_buffer.push_back('&'); + } + AddFormUrlencPart(name); + out_buffer.push_back('='); + AddFormUrlencPart(value); + } + + void AddFormUrlencPart(const std::string &s) { + for (const char c : s) { + if (c == ' ') { + out_buffer.push_back('+'); + } else if (c < 32 || c > 127 || c == ':' || c == '/' || c == '?' || c == '#' || c == '[' || c == ']' || c == '@' || c == '!' || c == '$' || c == '&' || c == '\'' || c == '(' || c == ')' || c == '*' || c == '+' || c == ',' || c == ';' || c == '=' || c == '%') { + out_buffer.push_back('%'); + out_buffer.push_back(HexDigit(c / 16)); + out_buffer.push_back(HexDigit(c % 16)); + } else { + out_buffer.push_back(c); + } + } + } + + static char HexDigit(int i) { + return (i < 10) ? '0' + i : 'A' + (i - 10); + } + +public: + PromiseType &GetPromise() { + return promise; + } + + int GetStatus() const { + return status; + } + + const std::string &GetBody() const { + return in_buffer; + } + + int ProtoCallback(lws_callback_reasons reason, void *in, size_t len); + +private: + lws_client_connect_info info; + lws *wsi; + + std::string out_buffer; + std::map headers; + + char read_buffer[BUFSIZ]; + int status; + std::string in_buffer; + + PromiseType promise; +}; + +} + +#endif diff --git a/src/ws/Connection.h b/src/ws/PusherConnection.h similarity index 83% rename from src/ws/Connection.h rename to src/ws/PusherConnection.h index 76ec576..dfbc1f4 100644 --- a/src/ws/Connection.h +++ b/src/ws/PusherConnection.h @@ -1,28 +1,27 @@ -#ifndef TEST_WS_CONNECTION_H_ -#define TEST_WS_CONNECTION_H_ +#ifndef TEST_WS_PUSHERCONNECTION_H_ +#define TEST_WS_PUSHERCONNECTION_H_ -#include "json/reader.h" -#include "json/value.h" -#include "json/writer.h" #include #include #include +#include #include #include -#include namespace ws { -class Connection { +class Context; + +class PusherConnection { public: - explicit Connection(lws_context *ctx); - ~Connection() { + explicit PusherConnection(Context &ctx); + ~PusherConnection() { } - Connection(const Connection &) = delete; - Connection &operator =(const Connection &) = delete; + PusherConnection(const PusherConnection &) = delete; + PusherConnection &operator =(const PusherConnection &) = delete; private: struct Callback { diff --git a/src/ws/TwitchConnection.h b/src/ws/TwitchConnection.h new file mode 100644 index 0000000..7e67e34 --- /dev/null +++ b/src/ws/TwitchConnection.h @@ -0,0 +1,160 @@ +#ifndef TEST_WS_TWITCHCONNECTION_H_ +#define TEST_WS_TWITCHCONNECTION_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include "../twitch/IRCMessage.h" +#include "../twitch/LoginToken.h" + +namespace ws { + +class Context; + +class TwitchConnection { + +public: + explicit TwitchConnection(Context &ctx); + ~TwitchConnection() { + } + + TwitchConnection(const TwitchConnection &) = delete; + TwitchConnection &operator =(const TwitchConnection &) = delete; + +private: + struct Callback { + void *user; + void (*callback)(void *, const twitch::IRCMessage &); + void Call(const twitch::IRCMessage &val) const { + (*callback)(user, val); + } + }; + +public: + void OnConnect() { + SendMessage("CAP REQ :twitch.tv/tags twitch.tv/commands"); + if (token.HasExpired()) { + token.Refresh(ctx) + .Then([this](twitch::LoginToken &) -> void { + Login(); + }) + .Catch([this](twitch::LoginToken &) -> void { + std::cerr << "unable to refresh login token" << std::endl; + }); + } else { + Login(); + } + } + + void Ping() { + SendMessage("PING localhorst.tv"); + } + + void Login() { + SendMessage("Pass oauth:" + token.GetAccessToken()); + SendMessage("NICK HorstieBot"); + } + + void Join(const std::string &chan, void (*callback)(void *, const twitch::IRCMessage &), void *user = nullptr) { + callbacks[chan].push_back({ user, callback }); + if (authenticated && callbacks[chan].size() == 1) { + SendMessage("JOIN " + chan); + } + } + + void SendMessage(const twitch::IRCMessage &msg) { + msg.Encode(out_buffer); + out_buffer.append("\r\n"); + lws_callback_on_writable(wsi); + } + + void SendMessage(const std::string &msg) { + out_buffer.append(msg); + out_buffer.append("\r\n"); + lws_callback_on_writable(wsi); + } + + void SendMessage(const char *msg) { + out_buffer.append(msg); + out_buffer.append("\r\n"); + lws_callback_on_writable(wsi); + } + +public: + int ProtoCallback(lws_callback_reasons reason, void *in, size_t len); + + void HandleMessage(const std::string &msg) { + auto begin = msg.begin(); + auto end = msg.end(); + while (begin != end) { + auto part_end = std::find(begin, end, '\n'); + // skip newline character + if (part_end != end) { + ++part_end; + } + in_msg.Decode(begin, part_end); + HandleMessage(in_msg); + begin = part_end; + } + } + + void HandleMessage(const twitch::IRCMessage &msg) { + if (msg.IsPing()) { + SendMessage(msg.MakePong()); + return; + } + if (msg.IsPong()) { + return; + } + if (msg.IsLoginSuccess()) { + JoinChannels(); + return; + } + if (msg.IsPrivMsg()) { + HandlePrivMsg(msg); + return; + } + } + + void JoinChannels() { + for (const auto &entry : callbacks) { + SendMessage("JOIN " + entry.first); + } + } + + void HandlePrivMsg(const twitch::IRCMessage &msg) { + if (msg.params.empty()) return; + auto it = callbacks.find(msg.params[0]); + if (it != callbacks.end()) { + for (const Callback &callback : it->second) { + callback.Call(msg); + } + } + } + +private: + ws::Context &ctx; + lws_client_connect_info info; + lws *wsi; + bool connected; + bool authenticated; + + std::string in_buffer; + std::string out_buffer; + + std::map> callbacks; + + twitch::LoginToken token; + twitch::IRCMessage in_msg; + +}; + +} + +#endif diff --git a/src/ws/io.h b/src/ws/io.h index b60d852..d67bab9 100644 --- a/src/ws/io.h +++ b/src/ws/io.h @@ -94,9 +94,12 @@ inline std::ostream &operator <<(std::ostream &out, lws_callback_reasons r) { case LWS_CALLBACK_CLOSED_CLIENT_HTTP: out << "http client connection closed"; break; - case LWS_CALLBACK_RECEIVE_CLIENT_HTTP: + case LWS_CALLBACK_RECEIVE_CLIENT_HTTP_READ: out << "http client read"; break; + case LWS_CALLBACK_RECEIVE_CLIENT_HTTP: + out << "http client"; + break; case LWS_CALLBACK_COMPLETED_CLIENT_HTTP: out << "http client completed"; break; -- 2.39.2