diff options
| author | yum <yum.food.vr@gmail.com> | 2023-07-03 18:44:43 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-07-03 19:36:13 -0700 |
| commit | 76ae7c28ea6224b2c919122d5dc71bcc00a0ecaa (patch) | |
| tree | 9723fd02715d747cfc439d1f66d36821a56069e9 /BrowserSource/Proxy | |
| parent | 7888ccc96d001512dd3bdc01f299856e86c876f5 (diff) | |
Begin work on proxy server
Create a simple server with 3 endpoints:
* /create_session: Create a session and return its identifier.
* /set_transcript: Update a session's transcript.
* /get_transcript: Fetch a session's transcript.
Right now the session ID provides authentication *and* authorization.
There is no public/private ID so you have to trust whoever you share
your ID with.
IDs are long and generated by the server, so it should be somewhat
secure against low-effort hacking.
Other updates:
* Drop whisper_requirements.txt - no longer needed.
* Vendor curl to make it easier to interact with the server.
TODO:
* Fuzz test the server.
Diffstat (limited to 'BrowserSource/Proxy')
| -rw-r--r-- | BrowserSource/Proxy/.gitignore | 2 | ||||
| -rw-r--r-- | BrowserSource/Proxy/HTTPMapper.cpp | 69 | ||||
| -rw-r--r-- | BrowserSource/Proxy/HTTPMapper.h | 35 | ||||
| -rw-r--r-- | BrowserSource/Proxy/HTTPParser.cpp | 229 | ||||
| -rw-r--r-- | BrowserSource/Proxy/HTTPParser.h | 52 | ||||
| -rw-r--r-- | BrowserSource/Proxy/Logging.h | 19 | ||||
| -rw-r--r-- | BrowserSource/Proxy/Makefile | 37 | ||||
| -rw-r--r-- | BrowserSource/Proxy/README.md | 12 | ||||
| -rw-r--r-- | BrowserSource/Proxy/ScopeGuard.h | 32 | ||||
| -rw-r--r-- | BrowserSource/Proxy/Utils.h | 31 | ||||
| -rw-r--r-- | BrowserSource/Proxy/WebCommon.h | 8 | ||||
| -rw-r--r-- | BrowserSource/Proxy/WebServer.cpp | 226 | ||||
| -rw-r--r-- | BrowserSource/Proxy/WebServer.h | 48 | ||||
| -rw-r--r-- | BrowserSource/Proxy/build-foss.sh | 10 | ||||
| m--------- | BrowserSource/Proxy/fmt | 0 | ||||
| -rw-r--r-- | BrowserSource/Proxy/integration_test.sh | 46 | ||||
| -rw-r--r-- | BrowserSource/Proxy/server.cpp | 204 |
17 files changed, 1060 insertions, 0 deletions
diff --git a/BrowserSource/Proxy/.gitignore b/BrowserSource/Proxy/.gitignore new file mode 100644 index 0000000..6c46439 --- /dev/null +++ b/BrowserSource/Proxy/.gitignore @@ -0,0 +1,2 @@ +*.o +server diff --git a/BrowserSource/Proxy/HTTPMapper.cpp b/BrowserSource/Proxy/HTTPMapper.cpp new file mode 100644 index 0000000..af4f6d0 --- /dev/null +++ b/BrowserSource/Proxy/HTTPMapper.cpp @@ -0,0 +1,69 @@ +#include "HTTPMapper.h" + +#include <sstream> +#include <map> + +namespace { + // Source: RFC 2616 section 6.1.1 + const std::map<int, std::string> kStatusCodeToString{ + {100, "Continue" }, + {101, "Switching Protocols"}, + {200, "OK"}, + {201, "Created"}, + {202, "Accepted"}, + {203, "Non-Authoritative Information"}, + {204, "No Content"}, + {205, "Reset Content"}, + {206, "Partial Content"}, + {300, "Multiple Choices"}, + {301, "Moved Permanently"}, + {302, "Found"}, + {303, "See Other"}, + {304, "Not Modified"}, + {305, "Use Proxy"}, + {307, "Temporary Redirect"}, + {400, "Bad Request"}, + {401, "Unauthorized"}, + {402, "Payment Required"}, + {403, "Forbidden"}, + {404, "Not Found"}, + {405, "Method Not Allowed"}, + {406, "Not Acceptable"}, + }; +} + +namespace WebServer { + std::string HTTPMapper::Map(const int status_code, + const std::string& payload, const ContentType type) { + switch (type) { + case HTML: + return HTTPMapperHTML().Map(status_code, payload); + case JSON: + return HTTPMapperJSON().Map(status_code, payload); + } + } + + std::string HTTPMapperHTML::Map(const int status_code, + const std::string& payload) { + std::ostringstream oss; + // This might throw and crash the app, but that's ok, just don't use an unsupported code. + oss << "HTTP/1.1 " << status_code << " " << kStatusCodeToString.at(status_code) << "\r\n"; + oss << "Content-Type: text/html\r\n"; + oss << "Content-Length: " << std::to_string(payload.size()) << "\r\n"; + oss << "\r\n"; + oss << payload; + return oss.str(); + } + + std::string HTTPMapperJSON::Map(const int status_code, + const std::string& payload) { + std::ostringstream oss; + // This might throw and crash the app, but that's ok, just don't use an unsupported code. + oss << "HTTP/1.1 " << status_code << " " << kStatusCodeToString.at(status_code) << "\r\n"; + oss << "Content-Type: application/json\r\n"; + oss << "Content-Length: " << std::to_string(payload.size()) << "\r\n"; + oss << "\r\n"; + oss << payload; + return oss.str(); + } +}
\ No newline at end of file diff --git a/BrowserSource/Proxy/HTTPMapper.h b/BrowserSource/Proxy/HTTPMapper.h new file mode 100644 index 0000000..e349f6e --- /dev/null +++ b/BrowserSource/Proxy/HTTPMapper.h @@ -0,0 +1,35 @@ +#pragma once + +#include "WebCommon.h" + +#include <string> + +namespace WebServer { + + class HTTPMapper { + public: + HTTPMapper() {} + virtual ~HTTPMapper() {} + + std::string Map(int status_code, + const std::string& payload, ContentType type); + }; + + class HTTPMapperHTML : public HTTPMapper { + public: + HTTPMapperHTML() {} + virtual ~HTTPMapperHTML() {} + + std::string Map(int status_code, + const std::string& payload); + }; + + class HTTPMapperJSON : public HTTPMapper { + public: + HTTPMapperJSON() {} + virtual ~HTTPMapperJSON() {} + + std::string Map(int status_code, + const std::string& payload); + }; +} diff --git a/BrowserSource/Proxy/HTTPParser.cpp b/BrowserSource/Proxy/HTTPParser.cpp new file mode 100644 index 0000000..4f6c850 --- /dev/null +++ b/BrowserSource/Proxy/HTTPParser.cpp @@ -0,0 +1,229 @@ +#include "HTTPParser.h" +#include "Logging.h" +#include "ScopeGuard.h" + +#include <sstream> +#include <string.h> +#include <string_view> + +using ::Logging::Log; + +namespace WebServer { + HTTPParser::HTTPParser() {} + + namespace { + constexpr const char kLineDelim[] = "\r\n"; + constexpr const char kHeadersDelim[] = "\r\n\r\n"; + constexpr const char kRfcLWS[] = " \t\r\n"; + }; + + bool HTTPParser::Parse(const std::string& raw_http, std::string& err) { + std::ostringstream err_oss; + ScopeGuard err_oss_flush([&]() { err += err_oss.str(); }); + + ParserState state = PARSER_STATE_START_LINE; + size_t pos = 0; + while (pos < raw_http.length()) { + size_t end; + switch (state) { + case PARSER_STATE_START_LINE: + end = raw_http.find(kLineDelim, pos); + break; + case PARSER_STATE_HEADERS: + end = raw_http.find(kHeadersDelim, pos); + break; + case PARSER_STATE_PAYLOAD: + end = raw_http.length(); + break; + } + ScopeGuard advance_pos([&]() { pos = end + 1; }); + if (end == std::string::npos) { + err_oss << "Failed to parse HTTP in state " << state << ": No delimiter!" << std::endl; + return false; + } + std::string_view segment(raw_http.data() + pos, end - pos); + if (!ParseSegment(segment, state, err)) { + return false; + } + } + return true; + } + + const std::string& HTTPParser::GetMethod() const { + return method_; + } + + const std::string& HTTPParser::GetPath() const { + return path_; + } + + bool HTTPParser::GetHeader(const std::string& header, std::string& value) const { + auto iter = headers_.find(header); + if (iter == headers_.end()) { + return false; + } + value = iter->second; + return true; + } + + const std::map<std::string, std::string>& HTTPParser::GetHeaders() const { + return headers_; + } + + const std::string& HTTPParser::GetPayload() const { + return payload_; + } + + bool HTTPParser::ParseSegment( + const std::string_view segment, + ParserState& state, + std::string& err) { + std::ostringstream err_oss; + ScopeGuard err_oss_flush([&]() { err += err_oss.str(); }); + switch (state) { + case PARSER_STATE_START_LINE: + return ParseStartLine(segment, state, err); + case PARSER_STATE_HEADERS: + return ParseHeaders(segment, state, err); + case PARSER_STATE_PAYLOAD: + return ParsePayload(segment, state, err); + } + } + + enum StartLineParserState { + START_LINE_PARSER_STATE_METHOD, + START_LINE_PARSER_STATE_PATH, + START_LINE_PARSER_STATE_VERSION, + START_LINE_PARSER_STATE_END, + }; + // Source: RFC 2616 section 5.1.1. + bool HTTPParser::ParseStartLine( + const std::string_view segment, + ParserState& state, + std::string& err) { + std::ostringstream err_oss; + ScopeGuard err_oss_flush([&]() { err += err_oss.str(); }); + + // Request-Line = Method SP Request-URI SP HTTP-Version CRLF + // SP == space. + // Thus we expect to see exactly three space-delimited chunks. + StartLineParserState cur_state = START_LINE_PARSER_STATE_METHOD; + size_t pos = 0; + while (pos < segment.length()) { + size_t end = segment.find(' ', pos); + if (end == std::string::npos) { + end = segment.length(); + } + ScopeGuard advance_pos([&]() { pos = end + 1; }); + + std::string_view cur_segment(segment.data() + pos, end - pos); + switch (cur_state) { + case START_LINE_PARSER_STATE_METHOD: + method_ = cur_segment; + cur_state = START_LINE_PARSER_STATE_PATH; + continue; + case START_LINE_PARSER_STATE_PATH: + path_ = cur_segment; + cur_state = START_LINE_PARSER_STATE_VERSION; + continue; + case START_LINE_PARSER_STATE_VERSION: + // TODO(yum) check this + cur_state = START_LINE_PARSER_STATE_END; + continue; + case START_LINE_PARSER_STATE_END: + err_oss << "Invalid start line: has too many parts: " << segment << std::endl; + return false; + } + } + if (cur_state != START_LINE_PARSER_STATE_END) { + err_oss << "Invalid start line: missing parts: " << segment << std::endl; + return false; + } + + state = PARSER_STATE_HEADERS; + return true; + } + + // Source: RFC 2616 section 4.2. + bool HTTPParser::ParseHeaders( + const std::string_view segment, + ParserState& state, + std::string& err) { + std::ostringstream err_oss; + ScopeGuard err_oss_flush([&]() { err += err_oss.str(); }); + + // From the RFC: + // message-header = field-name ":" [ field-value ] + // field-name = token + // field-value = *(field-content | LWS) + // field-content = <the OCTETs making up the field - value + // and consisting of either * TEXT or combinations + // of token, separators, and quoted-string> + // Takewaways: + // * field-name is guaranteed to not be preceded by whitespace + // * field-name is guaranteed to be followed by ":" + // * field-value may be preceded by LWS + // * multi-line field-values are guaranteed to start with either ' ' + // or '\t' + size_t pos = 0; + std::string key, value; + while (pos < segment.length()) { + // Divide into lines. + size_t end = segment.find(kLineDelim, pos); + if (end == std::string::npos) { + end = segment.length(); + } + ScopeGuard advance_pos([&]() { pos = end + 1; }); + + std::string_view line = segment.substr(pos, end - pos); + if (line.empty()) { + continue; + } + + // Lengthen the current line to cover multi-line header. + while (end + 1 < segment.length() && + (segment[end + 1] == ' ' || segment[end + 1] == '\t')) { + end = segment.find("\r\n", end + 1); + } + + size_t sep = line.find(':'); + if (sep == std::string::npos) { + err_oss << "Invalid header: No ':' delimiter: " << segment << std::endl; + return false; + } + + std::string_view key = line.substr(0, sep); + size_t key_start = key.find_first_not_of(kRfcLWS); + size_t key_end = key.find_last_not_of(kRfcLWS); + key = key.substr(key_start, (key_end - key_start) + 1); + // Value may contain interspersed LWS (linear whitespace). + // Could scrub it out, but not necessary for our purposes. + std::string_view value = line.substr(sep + 1); + size_t value_start = value.find_first_not_of(kRfcLWS); + size_t value_end = value.find_last_not_of(kRfcLWS); + value = value.substr(value_start, (value_end - value_start) + 1); + + headers_[std::string(key)] = value; + } + + state = PARSER_STATE_PAYLOAD; + return true; + } + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-parameter" + bool HTTPParser::ParsePayload( + const std::string_view segment, + ParserState& state, + std::string& err) { + + const char kScuffedHeadersDelim[] = "\n\r\n"; + if (!segment.starts_with(kScuffedHeadersDelim)) { + return true; + } + + payload_ = segment.substr(strlen(kScuffedHeadersDelim)); + return true; + } +#pragma clang diagnostic pop +} diff --git a/BrowserSource/Proxy/HTTPParser.h b/BrowserSource/Proxy/HTTPParser.h new file mode 100644 index 0000000..e97f896 --- /dev/null +++ b/BrowserSource/Proxy/HTTPParser.h @@ -0,0 +1,52 @@ +#pragma once + +#include <string> +#include <string_view> +#include <map> + +namespace WebServer { + + // A simple HTTP/1.1 message parser based on RFC 2616. + class HTTPParser + { + public: + HTTPParser(); + + bool Parse(const std::string& raw_http, std::string& err); + + const std::string& GetMethod() const; + const std::string& GetPath() const; + bool GetHeader(const std::string& header, std::string& value) const; + const std::map<std::string, std::string>& GetHeaders() const; + const std::string& GetPayload() const; + + private: + enum ParserState { + PARSER_STATE_START_LINE, + PARSER_STATE_HEADERS, + PARSER_STATE_PAYLOAD, + }; + + bool ParseSegment( + const std::string_view segment, + ParserState& state, + std::string& err); + bool ParseStartLine( + const std::string_view segment, + ParserState& state, + std::string& err); + bool ParseHeaders( + const std::string_view segment, + ParserState& state, + std::string& err); + bool ParsePayload( + const std::string_view segment, + ParserState& state, + std::string& err); + + std::string method_; + std::string path_; + std::map<std::string, std::string> headers_; + std::string payload_; + }; +} diff --git a/BrowserSource/Proxy/Logging.h b/BrowserSource/Proxy/Logging.h new file mode 100644 index 0000000..767821f --- /dev/null +++ b/BrowserSource/Proxy/Logging.h @@ -0,0 +1,19 @@ +#pragma once + +#pragma once + +#include <fmt/core.h> +#include <iostream> +#include <string> +#include <string_view> + +namespace Logging { + // Usage: Log("{}\n", "Hello, world!"); + template<typename... Args> + void Log(std::string_view format, Args&&... args) { + const std::string raw = fmt::vformat(format, fmt::make_format_args(args...)); + + std::cout << raw; + } +} + diff --git a/BrowserSource/Proxy/Makefile b/BrowserSource/Proxy/Makefile new file mode 100644 index 0000000..81eb814 --- /dev/null +++ b/BrowserSource/Proxy/Makefile @@ -0,0 +1,37 @@ +CC := clang++ + +MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) + +DEFINES := +CFLAGS := -Wall -Wextra -std=c++20 $(DEFINES) -I$(MAKEFILE_DIR)/fmt/include +LDFLAGS := -L$(MAKEFILE_DIR)/fmt/build -lfmt -static + +SRCS := $(wildcard *.cpp) + +HDRS := $(wildcard *.h) + +OBJS := $(SRCS:.cpp=.o) + +EXE := server + +.PHONY: all +all: $(EXE) + +$(EXE): $(OBJS) + $(CC) -o $@ $^ $(LDFLAGS) + +# Hack: any header change causes a full recompilation of everything. +%.o: %.cpp $(HDRS) + $(CC) $(CFLAGS) -c -o $@ $< + +.PHONY: clean +clean: + @rm -f $(OBJS) $(EXE) + +.PHONY: debug +debug: + @echo "CC: $(CC)" + @echo "MAKEFILE_DIR: $(MAKEFILE_DIR)" + @echo "STT_TOP: $(STT_TOP)" + @echo "OBJS: $(OBJS)" + diff --git a/BrowserSource/Proxy/README.md b/BrowserSource/Proxy/README.md new file mode 100644 index 0000000..6c986b2 --- /dev/null +++ b/BrowserSource/Proxy/README.md @@ -0,0 +1,12 @@ +This is a Linux server. It receives transcripts from TaSTT and serves them to +peers using a session identifier. It serves the use case where a mute player +wishes to show their transcript on a friend's stream. + +Dependencies: +* clang-15 +* cmake +* fmtlib/fmt + +To build: +./build-foss.sh +make diff --git a/BrowserSource/Proxy/ScopeGuard.h b/BrowserSource/Proxy/ScopeGuard.h new file mode 100644 index 0000000..61bb64d --- /dev/null +++ b/BrowserSource/Proxy/ScopeGuard.h @@ -0,0 +1,32 @@ +#pragma once + +#include <functional> +#include <utility> + +class ScopeGuard { +public: + ScopeGuard(std::function<void()>&& cb) : cb_(std::move(cb)), active_(true) {} + ~ScopeGuard() { + Invoke(); + } + + ScopeGuard() = delete; + ScopeGuard(ScopeGuard&) = delete; + ScopeGuard(const ScopeGuard&) = delete; + ScopeGuard(ScopeGuard&&) = delete; + ScopeGuard& operator=(ScopeGuard&) = delete; + ScopeGuard& operator=(const ScopeGuard&) = delete; + + void Cancel() { active_ = false; } + + void Invoke() { + if (active_) { + cb_(); + active_ = false; + } + } + +private: + const std::function<void()> cb_; + bool active_; +}; diff --git a/BrowserSource/Proxy/Utils.h b/BrowserSource/Proxy/Utils.h new file mode 100644 index 0000000..af5bc65 --- /dev/null +++ b/BrowserSource/Proxy/Utils.h @@ -0,0 +1,31 @@ +#pragma once + +#include <random> +#include <string> +#include <array> + +std::string RandomString(std::size_t length) { + static const std::array<char, 62> characters{ + { + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', + 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', + 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', + 'y', 'z' + } + }; + + std::random_device rd; + std::default_random_engine generator(rd()); + std::uniform_int_distribution<> distribution(0, characters.size() - 1); + + std::string random_string; + + for (std::size_t i = 0; i < length; ++i) { + random_string += characters[distribution(generator)]; + } + + return random_string; +} diff --git a/BrowserSource/Proxy/WebCommon.h b/BrowserSource/Proxy/WebCommon.h new file mode 100644 index 0000000..506d13c --- /dev/null +++ b/BrowserSource/Proxy/WebCommon.h @@ -0,0 +1,8 @@ +#pragma once + +namespace WebServer { + enum ContentType { + HTML, + JSON, + }; +}; diff --git a/BrowserSource/Proxy/WebServer.cpp b/BrowserSource/Proxy/WebServer.cpp new file mode 100644 index 0000000..fab23d9 --- /dev/null +++ b/BrowserSource/Proxy/WebServer.cpp @@ -0,0 +1,226 @@ +#include "HTTPMapper.h" +#include "HTTPParser.h" +#include "Logging.h" +#include "ScopeGuard.h" +#include "WebServer.h" + +#include <arpa/inet.h> +#include <errno.h> +#include <fcntl.h> +#include <netinet/in.h> +#include <stdint.h> +#include <string.h> +#include <sys/socket.h> +#include <sys/types.h> + +using ::Logging::Log; + +namespace WebServer { + WebServer::WebServer(uint16_t port) + : port_(port) + { + default_handler_ = + [](int& status_code, std::string& payload, + ContentType& type) -> void { + status_code = 404; + payload = "404: No route to URI"; + type = HTML; + }; + } + + bool WebServer::RegisterPathHandler(const std::string& method, + const std::string& path, handler_t&& handler) { + dispatch_key_t key = GetDispatchKey(method, path); + if (dispatch_map_.contains(key)) { + Log("Failed to register path handler at {} {}: " + "Handler already exists!\n", method, path); + return false; + } + + dispatch_map_[key] = std::move(handler); + return true; + } + + void WebServer::RegisterDefaultHandler(handler_t&& handler) { + default_handler_ = std::move(handler); + } + + bool WebServer::Run(volatile bool* run) { + int sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (sock == -1) { + Log("Failed to create socket: {}\n", strerror(errno)); + return false; + } + ScopeGuard sock_cleanup([sock]() { close(sock); }); + + sockaddr_in saddr; + saddr.sin_family = AF_INET; + saddr.sin_addr.s_addr = INADDR_ANY; + saddr.sin_port = htons(port_); + if (bind(sock, (sockaddr*)&saddr, sizeof(saddr)) == -1) { + Log("Failed to bind to port {}: {}\n", port_, strerror(errno)); + return false; + } + + // enable non-blocking mode + int flags = fcntl(sock, F_GETFL, 0); + if (flags == -1) { + Log("Failed to get socket flags for port {}: {}\n", port_, strerror(errno)); + return false; + } + if (fcntl(sock, F_SETFL, flags | O_NONBLOCK) == -1) { + Log("Failed to enable non-blocking mode on socket {}: {}\n", port_, strerror(errno)); + return false; + } + + if (listen(sock, SOMAXCONN) == -1) { + Log("Failed to listen on port {}: {}\n", port_, strerror(errno)); + return false; + } + + struct sockaddr_in bound_addr; + socklen_t len = sizeof(bound_addr); + if (getsockname(sock, (struct sockaddr *)&bound_addr, &len) == -1) { + Log("Failed to get socket name: {}\n", strerror(errno)); + return false; + } + char ipstr[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &bound_addr.sin_addr, ipstr, sizeof(ipstr)); + + Log("Server running on IP {} port {}\n", ipstr, port_); + + sockaddr_in peer_addr; + int accept_cnt = 0; + while (*run) { + socklen_t peer_addr_sz = sizeof(peer_addr); + int csock = accept(sock, (sockaddr*)&peer_addr, &peer_addr_sz); + if (csock == -1) { + if (errno == EWOULDBLOCK) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + continue; + } + Log("Accept failed: {}\n", strerror(errno)); + return false; + } + + // enable non-blocking mode + int flags = fcntl(csock, F_GETFL, 0); + if (flags == -1) { + Log("Failed to get socket flags for client on port {}: {}\n", port_, strerror(errno)); + return false; + } + if (fcntl(csock, F_SETFL, flags | O_NONBLOCK) == -1) { + Log("Failed to enable non-blocking mode on client socket {}: {}\n", port_, strerror(errno)); + return false; + } + + // Periodically cull dead connections to prevent runaway memory usage. + ++accept_cnt; + if (accept_cnt % 10 == 0) { + std::vector<std::future<void>> alive_conn; + for (size_t i = 0; i < connections_.size(); i++) { + if (connections_[i].valid()) { + continue; + } + alive_conn.push_back(std::move(connections_[i])); + } + //Log("Culled {} dead connections\n", connections_.size() - alive_conn.size()); + connections_ = std::move(alive_conn); + accept_cnt = 0; // Prevent overflow + } + + const auto& dispatch_map = dispatch_map_; + const auto& default_handler = default_handler_; + connections_.push_back(std::async(std::launch::async, + [csock, peer_addr, run, dispatch_map, default_handler]() -> void { + ScopeGuard csock_cleanup([csock]() { close(csock); }); + char peer_ip_str[INET_ADDRSTRLEN]{}; + inet_ntop(AF_INET, &peer_addr.sin_addr, peer_ip_str, sizeof(peer_ip_str)); + + std::string buf(4096 * 16, 0); + int cur_bytes_read = 0; + int sum_bytes_read = 0; + + // Drain socket until we see a valid HTTP message. + while (*run) { + cur_bytes_read = recv(csock, buf.data() + sum_bytes_read, + buf.size() - (1 + sum_bytes_read), /*flags=*/0); + if (cur_bytes_read == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + // Client may try to keep the connection open, + // so see if there's a complete request in the + // buffer. If so, terminate the recv loop. + HTTPParser p; + std::string err; + if (p.Parse(buf, err)) { + // In general we should verify that we got a + // full message, but since we only need to + // support GET, this is unnecessary. + cur_bytes_read = 0; + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + continue; + } + break; + } + sum_bytes_read += cur_bytes_read; + if (cur_bytes_read == 0) { + break; + } + } + if (cur_bytes_read == -1) { + Log("Failed to read client socket: {}\n", strerror(errno)); + return; + } + // Edge case: Server was stopped in the middle of serving a request. + if (!*run) { + Log("Server stop requested, bail out!\n"); + return; + } + buf.resize(sum_bytes_read); + + // Parse HTTP. Expect this to succeed, since we only exit the loop once the + // request parses. + // TODO(yum) this repeats work! The loop already parsed the request. + HTTPParser p; + std::string err; + if (!p.Parse(buf, err)) { + Log("Failed to parse client request: {}\n", err); + Log("Offending request:\n{}\n", buf); + return; + } + + // Find the dispatch handler for the requested method and path. + dispatch_key_t dispatch_key = GetDispatchKey(p.GetMethod(), p.GetPath()); + auto iter = dispatch_map.find(dispatch_key); + handler_t handler; + if (iter == dispatch_map.end()) { + handler = default_handler; + } else { + handler = iter->second; + } + + // Generate a response. + int status_code; + std::string payload = p.GetPayload(); + ContentType type; + handler(status_code, payload, type); + std::string response = HTTPMapper().Map(status_code, payload, type); + + // Send the response. + if (send(csock, response.data(), response.size(), /*flags=*/0) == -1) { + Log("Failed to send response to client: {}\n", strerror(errno)); + return; + } + + // Implicitly close the connection by exiting scope. We + // completely ignore keep-alive requests for now. Browsers + // should handle this well, there are many reasons why + // keep-alive requests may be ignored, such as transient + // network failures. + })); + } + return true; + } +} diff --git a/BrowserSource/Proxy/WebServer.h b/BrowserSource/Proxy/WebServer.h new file mode 100644 index 0000000..7815e89 --- /dev/null +++ b/BrowserSource/Proxy/WebServer.h @@ -0,0 +1,48 @@ +#pragma once + +#include <stdint.h> + +#include <functional> +#include <future> +#include <map> +#include <mutex> +#include <string> +#include <vector> + +#include "WebCommon.h" + +namespace WebServer { + class WebServer { + public: + WebServer(std::uint16_t port); + + typedef std::function<void( + int& status_code, + std::string& payload, + ContentType& type)> handler_t; + + bool RegisterPathHandler(const std::string& method, + const std::string& path, handler_t&& handler); + void RegisterDefaultHandler(handler_t&& handler); + + bool Run(volatile bool* run); + + private: + // Dispatch requests by mapping from (method, path) to handler. + // Dispatch key is (method, path) in that order. + typedef std::tuple<std::string, std::string> dispatch_key_t; + static inline dispatch_key_t GetDispatchKey(const std::string& method, const std::string& path) + { + return dispatch_key_t(method, path); + } + + typedef std::map<dispatch_key_t, handler_t> dispatch_map_t; + dispatch_map_t dispatch_map_; + handler_t default_handler_; + + const uint16_t port_; + + std::vector<std::future<void>> connections_; + }; +} + diff --git a/BrowserSource/Proxy/build-foss.sh b/BrowserSource/Proxy/build-foss.sh new file mode 100644 index 0000000..38cf5af --- /dev/null +++ b/BrowserSource/Proxy/build-foss.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +pushd fmt >/dev/null + +mkdir build +cd build +cmake .. +cmake --build . + +popd >/dev/null diff --git a/BrowserSource/Proxy/fmt b/BrowserSource/Proxy/fmt new file mode 160000 +Subproject de4705f84d16324c9496d0833e09656bc5d0e5f diff --git a/BrowserSource/Proxy/integration_test.sh b/BrowserSource/Proxy/integration_test.sh new file mode 100644 index 0000000..5aabc9e --- /dev/null +++ b/BrowserSource/Proxy/integration_test.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +SERVER_IP="$1" +if [ -z "$SERVER_IP" ]; then + echo "Usage: $0 server_ip" + exit 1 +fi +echo "Testing server at $SERVER_IP" + +SESSION_ID_0=$(curl -s -X GET $SERVER_IP:8080/api/v0/create_session | awk -F':' '{print $2}' | tr -d '}') +echo "Got session id $SESSION_ID_0" +if [ -z "$SESSION_ID_0" ]; then + echo "Expected session ID to be non-empty" + exit 1 +fi + +echo "Initial transcript should be empty" +T0=$(curl -s -X GET -d "$SESSION_ID_0" $SERVER_IP:8080/api/v0/get_transcript) +if [ "$T0" != "" ]; then + echo "Expected initial transcript to be empty, but got $T0" + exit 1 +fi +echo "Pass!" + +echo "Should be able to update transcript once" +T1_EXP="foo bar" +curl -s -X POST -d "$SESSION_ID_0 $T1_EXP" $SERVER_IP:8080/api/v0/set_transcript + +T1=$(curl -s -X GET -d "$SESSION_ID_0" $SERVER_IP:8080/api/v0/get_transcript) +if [ "$T1" != "$T1_EXP" ]; then + echo "Expected transcript to be $T1_EXP, but got $T1" + exit 1 +fi +echo "Pass!" + +echo "Subsequent update should overwrite" +T2_EXP="baz qux" +curl -s -X POST -d "$SESSION_ID_0 $T2_EXP" $SERVER_IP:8080/api/v0/set_transcript + +T2=$(curl -s -X GET -d "$SESSION_ID_0" $SERVER_IP:8080/api/v0/get_transcript) +if [ "$T2" != "$T2_EXP" ]; then + echo "Expected transcript to be $T2_EXP, but got $T2" + exit 1 +fi +echo "Pass!" + diff --git a/BrowserSource/Proxy/server.cpp b/BrowserSource/Proxy/server.cpp new file mode 100644 index 0000000..9e34fdf --- /dev/null +++ b/BrowserSource/Proxy/server.cpp @@ -0,0 +1,204 @@ +#include "Logging.h" +#include "Utils.h" +#include "WebServer.h" + +#include <chrono> +#include <iostream> +#include <mutex> +#include <set> +#include <sstream> +#include <unordered_map> + +using ::Logging::Log; + +class Sessions { + public: + struct SessionInfo { + std::chrono::time_point<std::chrono::steady_clock> creation_time; + std::string transcript; + }; + + // Create a new session and return its identifier. + std::string CreateSession() + { + // Each char in RandomString has 5.9 bits of entropy, so 16 chars gives + // us 94 bits of entropy - unlikely to ever collide or be guessed. + std::string id = RandomString(16); + + SessionInfo info; + info.creation_time = std::chrono::steady_clock::now(); + + std::scoped_lock l(mu_); + sessions_[id] = info; + + return id; + } + + // Look up a session by ID. The session info is copied into `info`. + bool GetSession(const std::string& id, SessionInfo& info) + { + std::scoped_lock l(mu_); + auto session_iter = sessions_.find(id); + if (session_iter == sessions_.end()) { + return false; + } + info = session_iter->second; + return true; + } + + void SetSession(const std::string& id, SessionInfo&& info) + { + std::scoped_lock l(mu_); + sessions_[id] = std::move(info); + } + + void PruneSessions(std::chrono::duration<double> max_age) + { + auto now = std::chrono::steady_clock::now(); + std::set<std::string> pruned_sessions; + { + std::scoped_lock l(mu_); + auto session_iter = sessions_.begin(); + while (session_iter != sessions_.end()) { + std::chrono::duration<double> age = now - session_iter->second.creation_time; + if (age > max_age) { + pruned_sessions.insert(session_iter->first); + sessions_.erase(session_iter++); + } else { + ++session_iter; + } + } + } + if (!pruned_sessions.empty()) { + std::ostringstream sessions_oss; + for (auto& session : pruned_sessions) { + sessions_oss << ' ' << session; + } + Log("Pruned sessions{}\n", sessions_oss.str()); + } + } + + private: + + std::mutex mu_; + std::unordered_map<std::string, SessionInfo> sessions_; +}; + +int main () { + WebServer::WebServer ws(8080); + Sessions s; + + // TODO rm + { + Sessions::SessionInfo info; + info.creation_time = std::chrono::steady_clock::now(); + s.SetSession("test_session", std::move(info)); + } + + ws.RegisterDefaultHandler( + [&](int& status_code, std::string& payload, + WebServer::ContentType& type) -> void { + + std::string resp = "Hello, world!\n"; + + status_code = 200; + payload = resp; + type = WebServer::HTML; + }); + + ws.RegisterPathHandler("GET", "/api/v0/create_session", + [&](int& status_code, std::string& payload, + WebServer::ContentType& type) -> void { + + std::string id = s.CreateSession(); + // Each char in RandomString has 5.9 bits of entropy, so 16 chars gives + // us 94 bits of entropy - unlikely to ever collide or be guessed. + std::string resp = "{session_id:" + id + "}"; + + Log("Created session {}\n", id); + + status_code = 200; + payload = resp; + type = WebServer::JSON; + }); + + ws.RegisterPathHandler("POST", "/api/v0/set_transcript", + [&](int& status_code, std::string& payload, + WebServer::ContentType& type) -> void { + + // Payload must look like "$session_id $transcript" + size_t space_pos = payload.find(' '); + std::string session_id; + std::string transcript; + if (space_pos == std::string::npos) { + session_id = payload; + } else { + session_id = payload.substr(0, space_pos); + transcript = payload.substr(space_pos + 1); + } + + Log("Updating session {}\n", session_id); + + Sessions::SessionInfo info; + if (!s.GetSession(session_id, info)) { + status_code = 404; + payload = "Failed to find session " + session_id; + type = WebServer::HTML; + return; + } + + info.transcript = transcript; + s.SetSession(session_id, std::move(info)); + + Log("Updated transcript of session {}: {}\n", session_id, transcript); + + status_code = 200; + payload.clear(); + type = WebServer::HTML; + }); + + ws.RegisterPathHandler("GET", "/api/v0/get_transcript", + [&](int& status_code, std::string& payload, + WebServer::ContentType& type) -> void { + + // Payload must look like "$session_id $transcript" + std::string session_id = payload; + + Sessions::SessionInfo info; + if (!s.GetSession(session_id, info)) { + status_code = 404; + payload = "Failed to find session " + session_id; + type = WebServer::HTML; + return; + } + + status_code = 200; + payload = info.transcript; + type = WebServer::HTML; + }); + + bool run = true; + auto server_thd = std::async(std::launch::async, [&]() -> void { + ws.Run(&run); + }); + auto prune_thd = std::async(std::launch::async, [&]() -> void { + while (run) { + std::this_thread::sleep_for(std::chrono::seconds(1)); + s.PruneSessions(std::chrono::days(1)); + } + }); + + Log("Started webserver. Press enter to exit.\n"); + std::string line; + while (std::getline(std::cin, line)) { + break; + } + run = false; + + // Wait for server to exit. + Log("Joining server thread...\n"); + server_thd.get(); + Log("Done!\n"); + + return 0; +} |
