diff --git a/lib/socket/openssl_socket.cpp b/lib/socket/openssl_socket.cpp new file mode 100644 index 0000000..6127a00 --- /dev/null +++ b/lib/socket/openssl_socket.cpp @@ -0,0 +1,100 @@ +#include +#include +#include +#include "./openssl_socket.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../log/log.hpp" + +namespace anthracite::socket { + +SSL_CTX* openssl_socket::_context = nullptr; + +openssl_socket::openssl_socket(int port, int max_queue) + : anthracite_socket(port, max_queue) +{ + const SSL_METHOD *method = TLS_server_method(); + + if (_context == nullptr) { + _context = SSL_CTX_new(method); + } + + if (!_context) { + log::err << "Unable to initialize SSL" << std::endl; + throw std::exception(); + } + + if (SSL_CTX_use_certificate_file(_context, "cert.pem", SSL_FILETYPE_PEM) <= 0) { + log::err << "Unable to open cert.pem" << std::endl; + throw std::exception(); + } + + if (SSL_CTX_use_PrivateKey_file(_context, "key.pem", SSL_FILETYPE_PEM) <= 0 ) { + log::err << "Unable to open key.pem" << std::endl; + throw std::exception(); + } +} + +openssl_socket::~openssl_socket() = default; + +void openssl_socket::wait_for_conn() +{ + client_ip = ""; + client_socket = accept(server_socket, reinterpret_cast(&client_addr), &client_addr_len); + std::array ip_str { 0 }; + inet_ntop(AF_INET, &client_addr.sin_addr, ip_str.data(), INET_ADDRSTRLEN); + client_ip = std::string(ip_str.data()); + + _ssl = SSL_new(_context); + SSL_set_fd(_ssl, client_socket); + if (SSL_accept(_ssl) <= 0) { + log::warn << "Unable to open SSL connection with client" << std::endl; + client_ip = ""; + close(client_socket); + client_socket = -1; + } +} + +void openssl_socket::close_conn() +{ + SSL_shutdown(_ssl); + SSL_free(_ssl); + close(client_socket); + client_socket = -1; +} + +void openssl_socket::send_message(std::string& msg) +{ + if (client_socket == -1) { + return; + } + SSL_write(_ssl, &msg[0], msg.length()); +} + +std::string openssl_socket::recv_message(int buffer_size) +{ + if (client_socket == -1) { + return ""; + } + + setsockopt(client_socket, SOL_SOCKET, SO_RCVTIMEO, &timeout_tv, sizeof timeout_tv); + std::vector response(buffer_size + 1); + ssize_t result = SSL_read(_ssl, response.data(), buffer_size+1); + + if (result < 1) { + return ""; + } + + response[buffer_size] = '\0'; + return { response.data() }; +} + +}; diff --git a/lib/socket/openssl_socket.hpp b/lib/socket/openssl_socket.hpp new file mode 100644 index 0000000..2d79c9c --- /dev/null +++ b/lib/socket/openssl_socket.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "./socket.hpp" +#include +#include + +namespace anthracite::socket { +class openssl_socket : public anthracite_socket { + private: + static SSL_CTX* _context; + SSL* _ssl; + + public: + openssl_socket(int port, int max_queue = MAX_QUEUE_LENGTH); + ~openssl_socket(); + + void wait_for_conn() override; + void close_conn() override; + void send_message(std::string& msg) override; + std::string recv_message(int buffer_size) override; +}; +};