start hacky tls self-roll
Some checks failed
Docker Build & Publish / build (push) Failing after 1s

This commit is contained in:
Nicholas Orlowsky 2025-02-05 13:31:20 -05:00
parent 10ca7f9f51
commit 7b63f846d7
Signed by: nickorlow
GPG key ID: 838827D8C4611687
4 changed files with 355 additions and 25 deletions

View file

@ -1,3 +1,5 @@
#pragma once
#include <arpa/inet.h>
#include <malloc.h>
#include <netinet/in.h>
@ -10,8 +12,8 @@ namespace anthracite::socket {
class anthracite_socket {
protected:
static const int MAX_QUEUE_LENGTH = 100;
private:
int server_socket;
int client_socket {};
std::string client_ip;
@ -22,11 +24,11 @@ private:
public:
anthracite_socket(int port, int max_queue = MAX_QUEUE_LENGTH);
void wait_for_conn();
const std::string& get_client_ip();
void close_conn();
void send_message(std::string& msg);
std::string recv_message(int buffer_size);
virtual void wait_for_conn();
virtual const std::string& get_client_ip();
virtual void close_conn();
virtual void send_message(std::string& msg);
virtual std::string recv_message(int buffer_size);
};
};

106
lib/socket/tls_socket.cpp Normal file
View file

@ -0,0 +1,106 @@
#include "./tls_socket.hpp"
#include <arpa/inet.h>
#include <array>
#include <malloc.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include <vector>
#include "../log/log.hpp"
namespace anthracite::socket {
tls_socket::tls_socket(int port, int max_queue) : anthracite_socket(port, max_queue), _handshakeDone(false) {
}
void tls_socket::wait_for_conn()
{
client_ip = "";
client_socket = accept(server_socket, reinterpret_cast<struct sockaddr*>(&client_addr), &client_addr_len);
std::array<char, INET_ADDRSTRLEN> ip_str { 0 };
inet_ntop(AF_INET, &client_addr.sin_addr, ip_str.data(), INET_ADDRSTRLEN);
client_ip = std::string(ip_str.data());
}
void tls_socket::close_conn()
{
close(client_socket);
client_socket = -1;
}
void tls_socket::send_message(std::string& msg)
{
if (client_socket == -1) {
return;
}
send(client_socket, &msg[0], msg.length(), 0);
}
void tls_socket::perform_handshake() {
struct tls_msg_hdr hdr{};
ssize_t result = recv(client_socket, &hdr, sizeof(hdr), 0);
if (result < 1) {
return;
}
log::info << "MsgType " << unsigned(hdr.msg_type);
log::info << " MsgLen " << hdr.length << std::endl;
char hhdr[4];
result = recv(client_socket, &hhdr, sizeof(hhdr), 0);
if (result < 1) {
return;
}
uint16_t msg_size = ClientHello::deserialize_uint16(hhdr + 2);
log::debug << "TLS ClientHello Size: " << msg_size << std::endl;
char* client_hello_data = (char*) malloc(msg_size);
result = recv(client_socket, client_hello_data, msg_size, 0);
std::cout << result << " Bytes rxd" << std::endl;
ClientHello hello_msg(client_hello_data, result);
char *ptr;
ServerHello hello_retmsg(hello_msg.session_id);
int size = hello_retmsg.get_buf(&ptr);
log::debug << "Sending message of length " << size << std::endl;
send(client_socket, ptr , size, 0);
for(;;){}
_handshakeDone = true;
}
std::string tls_socket::recv_message(int buffer_size)
{
if (client_socket == -1) {
return "";
}
setsockopt(client_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout_tv, sizeof timeout_tv);
if (!_handshakeDone) {
perform_handshake();
return "";
}
std::vector<char> response(buffer_size + 1);
ssize_t result = recv(client_socket, response.data(), buffer_size + 1, 0);
if (result < 1) {
return "";
}
response[buffer_size] = '\0';
return { response.data() };
}
};

235
lib/socket/tls_socket.hpp Normal file
View file

@ -0,0 +1,235 @@
#pragma once
#include <arpa/inet.h>
#include <cstdlib>
#include <malloc.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include "./socket.hpp"
#include <vector>
#include <array>
#include <iostream>
#include "../log/log.hpp"
namespace anthracite::socket {
constexpr uint32_t TLS_MSGHDR_RXSIZE = 4;
struct __attribute__((packed)) tls_version {
uint8_t major;
uint8_t minor;
};
struct __attribute__((packed)) tls_msg_hdr {
uint8_t msg_type;
tls_version version;
uint16_t length;
};
struct tls_extension {
uint16_t extension_type;
std::vector<char> data;
};
class ServerHello {
public:
struct tls_version server_version;
std::array<uint8_t, 32> random_bytes;
std::array<uint8_t, 32> _session_id;
uint16_t cipher;
uint8_t compression;
ServerHello(std::array<uint8_t, 32> session_id) {
srand(time(nullptr));
server_version.major = 3;
server_version.minor = 3;
for (int i = 0; i < 32; i++) {
random_bytes[i] = rand() % 256;
}
_session_id = session_id;
// TLS_RSA_WITH_NULL_MD5
cipher = 1;
// None
compression = 0;
}
int get_buf(char** bufptr) {
constexpr int msgsize = 2 + 32 + 1 + 32 + 2 + 1 + 7;
constexpr int mmsgsize = msgsize + 4;
constexpr int bufsize = mmsgsize + 5;
*bufptr = (char*) malloc(bufsize);
char* buf = *bufptr;
buf[0] = 0x16;
buf[1] = 3;
buf[2] = 1;
buf[3] = (mmsgsize >> 8) & 0xFF;
buf[4] = (mmsgsize) & 0xFF;
buf[5] = 0x02;
buf[6] = (msgsize >> 16) & 0xFF;
buf[7] = (msgsize >> 8) & 0xFF;
buf[8] = (msgsize) & 0xFF;
buf[9] = server_version.major;
buf[10] = server_version.minor;
for(int i = 0; i < 32; i++) {
buf[i+11] = random_bytes[i];
}
buf[43] = 32;
for(int i = 0; i < 32; i++) {
buf[i+44] = _session_id[i];
}
// Cipher
buf[76] = 00;
buf[77] = 0x33;
// Compression
buf[78] = 00;
// Extensions Length
buf[79] = 00;
buf[80] = 01;
// Renegotiation
buf[81] = 0xFF;
buf[82] = 01;
// Disabled
buf[83] = 00;
buf[84] = 01;
buf[85] = 00;
return bufsize;
}
};
class ClientHello {
public:
struct tls_version client_version;
std::array<uint8_t, 32> random_bytes;
std::array<uint8_t, 32> session_id;
std::vector<uint16_t> cipher_suites;
std::vector<uint8_t> compression_methods;
std::vector<tls_extension> extensions;
static uint32_t deserialize_uint32(char *buffer)
{
uint32_t value = 0;
value |= buffer[0] << 24;
value |= buffer[1] << 16;
value |= buffer[2] << 8;
value |= buffer[3];
return value;
}
static uint32_t deserialize_uint24(char *buffer)
{
uint32_t value = 0;
value |= buffer[0] << 16;
value |= buffer[1] << 8;
value |= buffer[2];
return value;
}
static uint16_t deserialize_uint16(char *buffer)
{
uint32_t value = 0;
value |= buffer[0] << 8;
value |= buffer[1];
return value;
}
// TODO: Note that the security of this funciton is terrible and absolutely
// can cause nasty things to happen with a malformed message
ClientHello(char* buffer, ssize_t size) {
int bufptr = 0;
// Get version data
client_version.major = (uint8_t) buffer[bufptr++];
client_version.minor = (uint8_t) buffer[bufptr++];
log::debug << "TLS Version : maj " << unsigned(client_version.major) << " min " << unsigned(client_version.minor) << std::endl;
log::debug << "TLS Random Data: ";
for(int i = 0; i < 32; i++) {
random_bytes[i] = buffer[bufptr++];
log::debug << std::hex << unsigned(random_bytes[i]) << std::hex << " ";
}
log::debug << std::endl;
// Get session id
int session_id_length = (uint8_t) buffer[bufptr++];
log::debug << "TLS SesId Data : ";
for(int i = 0; i < session_id_length; i++) {
session_id[i] = buffer[bufptr++];
log::debug << std::hex << unsigned(session_id[i]) << " ";
}
log::debug << std::dec;
log::debug << std::endl;
// Get cipher suites
uint16_t cipher_suites_length = deserialize_uint16(&buffer[bufptr]);
bufptr += 2;
log::debug << cipher_suites_length << " cipher suites supported" << std::endl;
for(uint16_t i = 0; i < cipher_suites_length; i++) {
cipher_suites.push_back(deserialize_uint16(&buffer[bufptr]));
bufptr += 2;
}
// Get compression methods
uint16_t compression_methods_length = buffer[bufptr++];
log::debug << compression_methods_length << " compression methods supported" << std::endl;
for(uint16_t i = 0; i < compression_methods_length; i++) {
cipher_suites.push_back(buffer[bufptr]);
bufptr ++;
}
}
};
class tls_socket : anthracite_socket {
private:
bool _handshakeDone;
void perform_handshake();
public:
tls_socket(int port, int max_queue = MAX_QUEUE_LENGTH);
void wait_for_conn() override;
void close_conn() override;
void send_message(std::string& msg) override;
std::string recv_message(int buffer_size) override;
};
};