diff --git a/lib/anthracite.cpp b/lib/anthracite.cpp index 51a646e..20c2373 100644 --- a/lib/anthracite.cpp +++ b/lib/anthracite.cpp @@ -1,6 +1,7 @@ #include "./anthracite.hpp" #include "./log/log.hpp" #include "./socket/socket.hpp" +#include "./socket/tls_socket.hpp" #include "backends/file_backend.hpp" #include #include @@ -24,7 +25,7 @@ using std::chrono::duration_cast; using std::chrono::duration; using std::chrono::milliseconds; -void handle_client(socket::anthracite_socket s, backends::backend& b, backends::file_backend& fb, std::mutex& thread_wait_mutex, std::condition_variable& thread_wait_condvar, int& active_threads) +void handle_client(socket::tls_socket s, backends::backend& b, backends::file_backend& fb, std::mutex& thread_wait_mutex, std::condition_variable& thread_wait_condvar, int& active_threads) { while (true) { std::string raw_request = s.recv_message(http::HEADER_BYTES); @@ -38,20 +39,7 @@ void handle_client(socket::anthracite_socket s, backends::backend& b, backends:: break; } - http::request req(raw_request, s.get_client_ip()); - std::unique_ptr resp = req.is_supported_version() ? b.handle_request(req) : fb.handle_error(http::status_codes::HTTP_VERSION_NOT_SUPPORTED); - std::string header = resp->header_to_string(); - s.send_message(header); - s.send_message(resp->content()); - - auto end = high_resolution_clock::now(); - auto ms_int = duration_cast(end-start); - log_request_and_response(req, resp , ms_int.count()); - - resp.reset(); - if (req.close_connection()) { - break; - } + continue; } s.close_conn(); { @@ -63,17 +51,16 @@ void handle_client(socket::anthracite_socket s, backends::backend& b, backends:: int anthracite_main(int argc, char** argv, backends::backend& be) { - log::logger.initialize(log::LOG_LEVEL_INFO); + log::logger.initialize(log::LOG_LEVEL_DEBUG); auto args = std::span(argv, size_t(argc)); int port_number = default_port; if (argc > 1) { port_number = atoi(args[1]); } - log::verbose << "Initializing Anthracite" << std::endl; - socket::anthracite_socket s(port_number); + + socket::tls_socket s(port_number); backends::file_backend fb(argc > 2 ? args[2] : "./www"); - log::verbose << "Initialization Complete" << std::endl; log::info << "Listening for HTTP connections on port " << port_number << std::endl; int active_threads = 0; diff --git a/lib/socket/socket.hpp b/lib/socket/socket.hpp index 7a57de7..dc924cc 100644 --- a/lib/socket/socket.hpp +++ b/lib/socket/socket.hpp @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -10,8 +12,8 @@ namespace anthracite::socket { class anthracite_socket { +protected: static const int MAX_QUEUE_LENGTH = 100; -private: int server_socket; int client_socket {}; std::string client_ip; @@ -22,11 +24,11 @@ private: public: anthracite_socket(int port, int max_queue = MAX_QUEUE_LENGTH); - void wait_for_conn(); - const std::string& get_client_ip(); - void close_conn(); - void send_message(std::string& msg); - std::string recv_message(int buffer_size); + virtual void wait_for_conn(); + virtual const std::string& get_client_ip(); + virtual void close_conn(); + virtual void send_message(std::string& msg); + virtual std::string recv_message(int buffer_size); }; }; diff --git a/lib/socket/tls_socket.cpp b/lib/socket/tls_socket.cpp new file mode 100644 index 0000000..9ae13e2 --- /dev/null +++ b/lib/socket/tls_socket.cpp @@ -0,0 +1,106 @@ +#include "./tls_socket.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../log/log.hpp" + + +namespace anthracite::socket { + +tls_socket::tls_socket(int port, int max_queue) : anthracite_socket(port, max_queue), _handshakeDone(false) { +} + +void tls_socket::wait_for_conn() +{ + client_ip = ""; + client_socket = accept(server_socket, reinterpret_cast(&client_addr), &client_addr_len); + std::array ip_str { 0 }; + inet_ntop(AF_INET, &client_addr.sin_addr, ip_str.data(), INET_ADDRSTRLEN); + client_ip = std::string(ip_str.data()); +} + +void tls_socket::close_conn() +{ + close(client_socket); + client_socket = -1; +} + +void tls_socket::send_message(std::string& msg) +{ + if (client_socket == -1) { + return; + } + send(client_socket, &msg[0], msg.length(), 0); +} + +void tls_socket::perform_handshake() { + struct tls_msg_hdr hdr{}; + ssize_t result = recv(client_socket, &hdr, sizeof(hdr), 0); + + if (result < 1) { + return; + } + + log::info << "MsgType " << unsigned(hdr.msg_type); + log::info << " MsgLen " << hdr.length << std::endl; + + char hhdr[4]; + result = recv(client_socket, &hhdr, sizeof(hhdr), 0); + + if (result < 1) { + return; + } + + uint16_t msg_size = ClientHello::deserialize_uint16(hhdr + 2); + + log::debug << "TLS ClientHello Size: " << msg_size << std::endl; + + + char* client_hello_data = (char*) malloc(msg_size); + + result = recv(client_socket, client_hello_data, msg_size, 0); + + std::cout << result << " Bytes rxd" << std::endl; + + ClientHello hello_msg(client_hello_data, result); + + char *ptr; + ServerHello hello_retmsg(hello_msg.session_id); + int size = hello_retmsg.get_buf(&ptr); + log::debug << "Sending message of length " << size << std::endl; + send(client_socket, ptr , size, 0); + for(;;){} + _handshakeDone = true; +} + +std::string tls_socket::recv_message(int buffer_size) +{ + if (client_socket == -1) { + return ""; + } + + setsockopt(client_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout_tv, sizeof timeout_tv); + + if (!_handshakeDone) { + perform_handshake(); + return ""; + } + + std::vector response(buffer_size + 1); + ssize_t result = recv(client_socket, response.data(), buffer_size + 1, 0); + + if (result < 1) { + return ""; + } + + response[buffer_size] = '\0'; + return { response.data() }; +} + +}; diff --git a/lib/socket/tls_socket.hpp b/lib/socket/tls_socket.hpp new file mode 100644 index 0000000..3eb09c7 --- /dev/null +++ b/lib/socket/tls_socket.hpp @@ -0,0 +1,235 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./socket.hpp" +#include +#include +#include +#include "../log/log.hpp" + +namespace anthracite::socket { + +constexpr uint32_t TLS_MSGHDR_RXSIZE = 4; + + +struct __attribute__((packed)) tls_version { + uint8_t major; + uint8_t minor; +}; + +struct __attribute__((packed)) tls_msg_hdr { + uint8_t msg_type; + tls_version version; + uint16_t length; +}; + +struct tls_extension { + uint16_t extension_type; + std::vector data; +}; + +class ServerHello { + public: + struct tls_version server_version; + std::array random_bytes; + std::array _session_id; + uint16_t cipher; + uint8_t compression; + + ServerHello(std::array session_id) { + srand(time(nullptr)); + server_version.major = 3; + server_version.minor = 3; + + for (int i = 0; i < 32; i++) { + random_bytes[i] = rand() % 256; + } + + _session_id = session_id; + + // TLS_RSA_WITH_NULL_MD5 + cipher = 1; + + // None + compression = 0; + } + + int get_buf(char** bufptr) { + constexpr int msgsize = 2 + 32 + 1 + 32 + 2 + 1 + 7; + constexpr int mmsgsize = msgsize + 4; + constexpr int bufsize = mmsgsize + 5; + + *bufptr = (char*) malloc(bufsize); + + char* buf = *bufptr; + + buf[0] = 0x16; + buf[1] = 3; + buf[2] = 1; + buf[3] = (mmsgsize >> 8) & 0xFF; + buf[4] = (mmsgsize) & 0xFF; + + + buf[5] = 0x02; + buf[6] = (msgsize >> 16) & 0xFF; + buf[7] = (msgsize >> 8) & 0xFF; + buf[8] = (msgsize) & 0xFF; + + buf[9] = server_version.major; + buf[10] = server_version.minor; + + for(int i = 0; i < 32; i++) { + buf[i+11] = random_bytes[i]; + } + + buf[43] = 32; + + for(int i = 0; i < 32; i++) { + buf[i+44] = _session_id[i]; + } + + // Cipher + buf[76] = 00; + buf[77] = 0x33; + + // Compression + buf[78] = 00; + + // Extensions Length + buf[79] = 00; + buf[80] = 01; + + // Renegotiation + buf[81] = 0xFF; + buf[82] = 01; + + // Disabled + buf[83] = 00; + buf[84] = 01; + buf[85] = 00; + + + return bufsize; + } + + + + +}; + + +class ClientHello { + public: + struct tls_version client_version; + std::array random_bytes; + std::array session_id; + std::vector cipher_suites; + std::vector compression_methods; + std::vector extensions; + + static uint32_t deserialize_uint32(char *buffer) + { + uint32_t value = 0; + + value |= buffer[0] << 24; + value |= buffer[1] << 16; + value |= buffer[2] << 8; + value |= buffer[3]; + return value; + } + + static uint32_t deserialize_uint24(char *buffer) + { + uint32_t value = 0; + + value |= buffer[0] << 16; + value |= buffer[1] << 8; + value |= buffer[2]; + return value; + } + + static uint16_t deserialize_uint16(char *buffer) + { + uint32_t value = 0; + + value |= buffer[0] << 8; + value |= buffer[1]; + + return value; + } + + + // TODO: Note that the security of this funciton is terrible and absolutely + // can cause nasty things to happen with a malformed message + ClientHello(char* buffer, ssize_t size) { + int bufptr = 0; + // Get version data + client_version.major = (uint8_t) buffer[bufptr++]; + client_version.minor = (uint8_t) buffer[bufptr++]; + log::debug << "TLS Version : maj " << unsigned(client_version.major) << " min " << unsigned(client_version.minor) << std::endl; + + log::debug << "TLS Random Data: "; + for(int i = 0; i < 32; i++) { + random_bytes[i] = buffer[bufptr++]; + log::debug << std::hex << unsigned(random_bytes[i]) << std::hex << " "; + } + log::debug << std::endl; + + // Get session id + int session_id_length = (uint8_t) buffer[bufptr++]; + log::debug << "TLS SesId Data : "; + for(int i = 0; i < session_id_length; i++) { + session_id[i] = buffer[bufptr++]; + log::debug << std::hex << unsigned(session_id[i]) << " "; + } + log::debug << std::dec; + log::debug << std::endl; + + // Get cipher suites + uint16_t cipher_suites_length = deserialize_uint16(&buffer[bufptr]); + bufptr += 2; + + log::debug << cipher_suites_length << " cipher suites supported" << std::endl; + + for(uint16_t i = 0; i < cipher_suites_length; i++) { + cipher_suites.push_back(deserialize_uint16(&buffer[bufptr])); + bufptr += 2; + } + + // Get compression methods + uint16_t compression_methods_length = buffer[bufptr++]; + + log::debug << compression_methods_length << " compression methods supported" << std::endl; + + for(uint16_t i = 0; i < compression_methods_length; i++) { + cipher_suites.push_back(buffer[bufptr]); + bufptr ++; + } + } +}; + +class tls_socket : anthracite_socket { + + + +private: + bool _handshakeDone; + + void perform_handshake(); + +public: + tls_socket(int port, int max_queue = MAX_QUEUE_LENGTH); + void wait_for_conn() override; + void close_conn() override; + void send_message(std::string& msg) override; + std::string recv_message(int buffer_size) override; +}; + +};