start hacky tls self-roll
Some checks failed
Docker Build & Publish / build (push) Failing after 1s

This commit is contained in:
Nicholas Orlowsky 2025-02-05 13:31:20 -05:00
parent 10ca7f9f51
commit 7b63f846d7
Signed by: nickorlow
GPG key ID: 838827D8C4611687
4 changed files with 355 additions and 25 deletions

View file

@ -1,6 +1,7 @@
#include "./anthracite.hpp"
#include "./log/log.hpp"
#include "./socket/socket.hpp"
#include "./socket/tls_socket.hpp"
#include "backends/file_backend.hpp"
#include <chrono>
#include <condition_variable>
@ -24,7 +25,7 @@ using std::chrono::duration_cast;
using std::chrono::duration;
using std::chrono::milliseconds;
void handle_client(socket::anthracite_socket s, backends::backend& b, backends::file_backend& fb, std::mutex& thread_wait_mutex, std::condition_variable& thread_wait_condvar, int& active_threads)
void handle_client(socket::tls_socket s, backends::backend& b, backends::file_backend& fb, std::mutex& thread_wait_mutex, std::condition_variable& thread_wait_condvar, int& active_threads)
{
while (true) {
std::string raw_request = s.recv_message(http::HEADER_BYTES);
@ -38,20 +39,7 @@ void handle_client(socket::anthracite_socket s, backends::backend& b, backends::
break;
}
http::request req(raw_request, s.get_client_ip());
std::unique_ptr<http::response> resp = req.is_supported_version() ? b.handle_request(req) : fb.handle_error(http::status_codes::HTTP_VERSION_NOT_SUPPORTED);
std::string header = resp->header_to_string();
s.send_message(header);
s.send_message(resp->content());
auto end = high_resolution_clock::now();
auto ms_int = duration_cast<std::chrono::microseconds>(end-start);
log_request_and_response(req, resp , ms_int.count());
resp.reset();
if (req.close_connection()) {
break;
}
continue;
}
s.close_conn();
{
@ -63,17 +51,16 @@ void handle_client(socket::anthracite_socket s, backends::backend& b, backends::
int anthracite_main(int argc, char** argv, backends::backend& be)
{
log::logger.initialize(log::LOG_LEVEL_INFO);
log::logger.initialize(log::LOG_LEVEL_DEBUG);
auto args = std::span(argv, size_t(argc));
int port_number = default_port;
if (argc > 1) {
port_number = atoi(args[1]);
}
log::verbose << "Initializing Anthracite" << std::endl;
socket::anthracite_socket s(port_number);
socket::tls_socket s(port_number);
backends::file_backend fb(argc > 2 ? args[2] : "./www");
log::verbose << "Initialization Complete" << std::endl;
log::info << "Listening for HTTP connections on port " << port_number << std::endl;
int active_threads = 0;

View file

@ -1,3 +1,5 @@
#pragma once
#include <arpa/inet.h>
#include <malloc.h>
#include <netinet/in.h>
@ -10,8 +12,8 @@ namespace anthracite::socket {
class anthracite_socket {
protected:
static const int MAX_QUEUE_LENGTH = 100;
private:
int server_socket;
int client_socket {};
std::string client_ip;
@ -22,11 +24,11 @@ private:
public:
anthracite_socket(int port, int max_queue = MAX_QUEUE_LENGTH);
void wait_for_conn();
const std::string& get_client_ip();
void close_conn();
void send_message(std::string& msg);
std::string recv_message(int buffer_size);
virtual void wait_for_conn();
virtual const std::string& get_client_ip();
virtual void close_conn();
virtual void send_message(std::string& msg);
virtual std::string recv_message(int buffer_size);
};
};

106
lib/socket/tls_socket.cpp Normal file
View file

@ -0,0 +1,106 @@
#include "./tls_socket.hpp"
#include <arpa/inet.h>
#include <array>
#include <malloc.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include <vector>
#include "../log/log.hpp"
namespace anthracite::socket {
tls_socket::tls_socket(int port, int max_queue) : anthracite_socket(port, max_queue), _handshakeDone(false) {
}
void tls_socket::wait_for_conn()
{
client_ip = "";
client_socket = accept(server_socket, reinterpret_cast<struct sockaddr*>(&client_addr), &client_addr_len);
std::array<char, INET_ADDRSTRLEN> ip_str { 0 };
inet_ntop(AF_INET, &client_addr.sin_addr, ip_str.data(), INET_ADDRSTRLEN);
client_ip = std::string(ip_str.data());
}
void tls_socket::close_conn()
{
close(client_socket);
client_socket = -1;
}
void tls_socket::send_message(std::string& msg)
{
if (client_socket == -1) {
return;
}
send(client_socket, &msg[0], msg.length(), 0);
}
void tls_socket::perform_handshake() {
struct tls_msg_hdr hdr{};
ssize_t result = recv(client_socket, &hdr, sizeof(hdr), 0);
if (result < 1) {
return;
}
log::info << "MsgType " << unsigned(hdr.msg_type);
log::info << " MsgLen " << hdr.length << std::endl;
char hhdr[4];
result = recv(client_socket, &hhdr, sizeof(hhdr), 0);
if (result < 1) {
return;
}
uint16_t msg_size = ClientHello::deserialize_uint16(hhdr + 2);
log::debug << "TLS ClientHello Size: " << msg_size << std::endl;
char* client_hello_data = (char*) malloc(msg_size);
result = recv(client_socket, client_hello_data, msg_size, 0);
std::cout << result << " Bytes rxd" << std::endl;
ClientHello hello_msg(client_hello_data, result);
char *ptr;
ServerHello hello_retmsg(hello_msg.session_id);
int size = hello_retmsg.get_buf(&ptr);
log::debug << "Sending message of length " << size << std::endl;
send(client_socket, ptr , size, 0);
for(;;){}
_handshakeDone = true;
}
std::string tls_socket::recv_message(int buffer_size)
{
if (client_socket == -1) {
return "";
}
setsockopt(client_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout_tv, sizeof timeout_tv);
if (!_handshakeDone) {
perform_handshake();
return "";
}
std::vector<char> response(buffer_size + 1);
ssize_t result = recv(client_socket, response.data(), buffer_size + 1, 0);
if (result < 1) {
return "";
}
response[buffer_size] = '\0';
return { response.data() };
}
};

235
lib/socket/tls_socket.hpp Normal file
View file

@ -0,0 +1,235 @@
#pragma once
#include <arpa/inet.h>
#include <cstdlib>
#include <malloc.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include "./socket.hpp"
#include <vector>
#include <array>
#include <iostream>
#include "../log/log.hpp"
namespace anthracite::socket {
constexpr uint32_t TLS_MSGHDR_RXSIZE = 4;
struct __attribute__((packed)) tls_version {
uint8_t major;
uint8_t minor;
};
struct __attribute__((packed)) tls_msg_hdr {
uint8_t msg_type;
tls_version version;
uint16_t length;
};
struct tls_extension {
uint16_t extension_type;
std::vector<char> data;
};
class ServerHello {
public:
struct tls_version server_version;
std::array<uint8_t, 32> random_bytes;
std::array<uint8_t, 32> _session_id;
uint16_t cipher;
uint8_t compression;
ServerHello(std::array<uint8_t, 32> session_id) {
srand(time(nullptr));
server_version.major = 3;
server_version.minor = 3;
for (int i = 0; i < 32; i++) {
random_bytes[i] = rand() % 256;
}
_session_id = session_id;
// TLS_RSA_WITH_NULL_MD5
cipher = 1;
// None
compression = 0;
}
int get_buf(char** bufptr) {
constexpr int msgsize = 2 + 32 + 1 + 32 + 2 + 1 + 7;
constexpr int mmsgsize = msgsize + 4;
constexpr int bufsize = mmsgsize + 5;
*bufptr = (char*) malloc(bufsize);
char* buf = *bufptr;
buf[0] = 0x16;
buf[1] = 3;
buf[2] = 1;
buf[3] = (mmsgsize >> 8) & 0xFF;
buf[4] = (mmsgsize) & 0xFF;
buf[5] = 0x02;
buf[6] = (msgsize >> 16) & 0xFF;
buf[7] = (msgsize >> 8) & 0xFF;
buf[8] = (msgsize) & 0xFF;
buf[9] = server_version.major;
buf[10] = server_version.minor;
for(int i = 0; i < 32; i++) {
buf[i+11] = random_bytes[i];
}
buf[43] = 32;
for(int i = 0; i < 32; i++) {
buf[i+44] = _session_id[i];
}
// Cipher
buf[76] = 00;
buf[77] = 0x33;
// Compression
buf[78] = 00;
// Extensions Length
buf[79] = 00;
buf[80] = 01;
// Renegotiation
buf[81] = 0xFF;
buf[82] = 01;
// Disabled
buf[83] = 00;
buf[84] = 01;
buf[85] = 00;
return bufsize;
}
};
class ClientHello {
public:
struct tls_version client_version;
std::array<uint8_t, 32> random_bytes;
std::array<uint8_t, 32> session_id;
std::vector<uint16_t> cipher_suites;
std::vector<uint8_t> compression_methods;
std::vector<tls_extension> extensions;
static uint32_t deserialize_uint32(char *buffer)
{
uint32_t value = 0;
value |= buffer[0] << 24;
value |= buffer[1] << 16;
value |= buffer[2] << 8;
value |= buffer[3];
return value;
}
static uint32_t deserialize_uint24(char *buffer)
{
uint32_t value = 0;
value |= buffer[0] << 16;
value |= buffer[1] << 8;
value |= buffer[2];
return value;
}
static uint16_t deserialize_uint16(char *buffer)
{
uint32_t value = 0;
value |= buffer[0] << 8;
value |= buffer[1];
return value;
}
// TODO: Note that the security of this funciton is terrible and absolutely
// can cause nasty things to happen with a malformed message
ClientHello(char* buffer, ssize_t size) {
int bufptr = 0;
// Get version data
client_version.major = (uint8_t) buffer[bufptr++];
client_version.minor = (uint8_t) buffer[bufptr++];
log::debug << "TLS Version : maj " << unsigned(client_version.major) << " min " << unsigned(client_version.minor) << std::endl;
log::debug << "TLS Random Data: ";
for(int i = 0; i < 32; i++) {
random_bytes[i] = buffer[bufptr++];
log::debug << std::hex << unsigned(random_bytes[i]) << std::hex << " ";
}
log::debug << std::endl;
// Get session id
int session_id_length = (uint8_t) buffer[bufptr++];
log::debug << "TLS SesId Data : ";
for(int i = 0; i < session_id_length; i++) {
session_id[i] = buffer[bufptr++];
log::debug << std::hex << unsigned(session_id[i]) << " ";
}
log::debug << std::dec;
log::debug << std::endl;
// Get cipher suites
uint16_t cipher_suites_length = deserialize_uint16(&buffer[bufptr]);
bufptr += 2;
log::debug << cipher_suites_length << " cipher suites supported" << std::endl;
for(uint16_t i = 0; i < cipher_suites_length; i++) {
cipher_suites.push_back(deserialize_uint16(&buffer[bufptr]));
bufptr += 2;
}
// Get compression methods
uint16_t compression_methods_length = buffer[bufptr++];
log::debug << compression_methods_length << " compression methods supported" << std::endl;
for(uint16_t i = 0; i < compression_methods_length; i++) {
cipher_suites.push_back(buffer[bufptr]);
bufptr ++;
}
}
};
class tls_socket : anthracite_socket {
private:
bool _handshakeDone;
void perform_handshake();
public:
tls_socket(int port, int max_queue = MAX_QUEUE_LENGTH);
void wait_for_conn() override;
void close_conn() override;
void send_message(std::string& msg) override;
std::string recv_message(int buffer_size) override;
};
};