summaryrefslogtreecommitdiff
path: root/lib/server/SocketStreamTLS.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/server/SocketStreamTLS.cpp')
-rw-r--r--lib/server/SocketStreamTLS.cpp492
1 files changed, 492 insertions, 0 deletions
diff --git a/lib/server/SocketStreamTLS.cpp b/lib/server/SocketStreamTLS.cpp
new file mode 100644
index 00000000..19fdadd4
--- /dev/null
+++ b/lib/server/SocketStreamTLS.cpp
@@ -0,0 +1,492 @@
+// --------------------------------------------------------------------------
+//
+// File
+// Name: SocketStreamTLS.cpp
+// Purpose: Socket stream encrpyted and authenticated by TLS
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+
+#include "Box.h"
+
+#define TLS_CLASS_IMPLEMENTATION_CPP
+#include <openssl/ssl.h>
+#include <openssl/bio.h>
+#include <errno.h>
+#include <fcntl.h>
+
+#ifndef WIN32
+#include <poll.h>
+#endif
+
+#include "SocketStreamTLS.h"
+#include "SSLLib.h"
+#include "ServerException.h"
+#include "TLSContext.h"
+#include "BoxTime.h"
+
+#include "MemLeakFindOn.h"
+
+// Allow 5 minutes to handshake (in milliseconds)
+#define TLS_HANDSHAKE_TIMEOUT (5*60*1000)
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::SocketStreamTLS()
+// Purpose: Constructor
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+SocketStreamTLS::SocketStreamTLS()
+ : mpSSL(0), mpBIO(0)
+{
+ ResetCounters();
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::SocketStreamTLS(int)
+// Purpose: Constructor, taking previously connected socket
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+SocketStreamTLS::SocketStreamTLS(int socket)
+ : SocketStream(socket),
+ mpSSL(0), mpBIO(0)
+{
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::~SocketStreamTLS()
+// Purpose: Destructor
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+SocketStreamTLS::~SocketStreamTLS()
+{
+ if(mpSSL)
+ {
+ // Attempt to close to avoid problems
+ Close();
+
+ // And if that didn't work...
+ if(mpSSL)
+ {
+ ::SSL_free(mpSSL);
+ mpSSL = 0;
+ mpBIO = 0; // implicity freed by the SSL_free call
+ }
+ }
+
+ // If we only got to creating that BIO.
+ if(mpBIO)
+ {
+ ::BIO_free(mpBIO);
+ mpBIO = 0;
+ }
+}
+
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::Open(const TLSContext &, int, const char *, int)
+// Purpose: Open connection, and perform TLS handshake
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+void SocketStreamTLS::Open(const TLSContext &rContext, Socket::Type Type,
+ const std::string& rName, int Port)
+{
+ SocketStream::Open(Type, rName, Port);
+ Handshake(rContext);
+ ResetCounters();
+}
+
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::Handshake(const TLSContext &, bool)
+// Purpose: Perform TLS handshake
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
+{
+ if(mpBIO || mpSSL) {THROW_EXCEPTION(ServerException, TLSAlreadyHandshaked)}
+
+ // Create a BIO for this socket
+ mpBIO = ::BIO_new(::BIO_s_socket());
+ if(mpBIO == 0)
+ {
+ SSLLib::LogError("creating socket bio");
+ THROW_EXCEPTION(ServerException, TLSAllocationFailed)
+ }
+
+ tOSSocketHandle socket = GetSocketHandle();
+ BIO_set_fd(mpBIO, socket, BIO_NOCLOSE);
+
+ // Then the SSL object
+ mpSSL = ::SSL_new(rContext.GetRawContext());
+ if(mpSSL == 0)
+ {
+ SSLLib::LogError("creating SSL object");
+ THROW_EXCEPTION(ServerException, TLSAllocationFailed)
+ }
+
+ // Make the socket non-blocking so timeouts on Read work
+
+#ifdef WIN32
+ u_long nonblocking = 1;
+ ioctlsocket(socket, FIONBIO, &nonblocking);
+#else // !WIN32
+ // This is more portable than using ioctl with FIONBIO
+ int statusFlags = 0;
+ if(::fcntl(socket, F_GETFL, &statusFlags) < 0
+ || ::fcntl(socket, F_SETFL, statusFlags | O_NONBLOCK) == -1)
+ {
+ THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed)
+ }
+#endif
+
+ // FIXME: This is less portable than the above. However, it MAY be needed
+ // for cygwin, which has/had bugs with fcntl
+ //
+ // int nonblocking = true;
+ // if(::ioctl(socket, FIONBIO, &nonblocking) == -1)
+ // {
+ // THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed)
+ // }
+
+ // Set the two to know about each other
+ ::SSL_set_bio(mpSSL, mpBIO, mpBIO);
+
+ bool waitingForHandshake = true;
+ while(waitingForHandshake)
+ {
+ // Attempt to do the handshake
+ int r = 0;
+ if(IsServer)
+ {
+ r = ::SSL_accept(mpSSL);
+ }
+ else
+ {
+ r = ::SSL_connect(mpSSL);
+ }
+
+ // check return code
+ int se;
+ switch((se = ::SSL_get_error(mpSSL, r)))
+ {
+ case SSL_ERROR_NONE:
+ // No error, handshake succeeded
+ waitingForHandshake = false;
+ break;
+
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ // wait for the requried data
+ if(WaitWhenRetryRequired(se, TLS_HANDSHAKE_TIMEOUT) == false)
+ {
+ // timed out
+ THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeTimedOut)
+ }
+ break;
+
+ default: // (and SSL_ERROR_ZERO_RETURN)
+ // Error occured
+ if(IsServer)
+ {
+ SSLLib::LogError("accepting connection");
+ THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
+ }
+ else
+ {
+ SSLLib::LogError("connecting");
+ THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
+ }
+ }
+ }
+
+ // And that's it
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: WaitWhenRetryRequired(int, int)
+// Purpose: Waits until the condition required by the TLS layer is met.
+// Returns true if the condition is met, false if timed out.
+// Created: 2003/08/15
+//
+// --------------------------------------------------------------------------
+bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout)
+{
+ struct pollfd p;
+ p.fd = GetSocketHandle();
+ switch(SSLErrorCode)
+ {
+ case SSL_ERROR_WANT_READ:
+ p.events = POLLIN;
+ break;
+
+ case SSL_ERROR_WANT_WRITE:
+ p.events = POLLOUT;
+ break;
+
+ default:
+ // Not good!
+ THROW_EXCEPTION(ServerException, Internal)
+ break;
+ }
+ p.revents = 0;
+
+ int64_t start, end;
+ start = BoxTimeToMilliSeconds(GetCurrentBoxTime());
+ end = start + Timeout;
+ int result;
+
+ do
+ {
+ int64_t now = BoxTimeToMilliSeconds(GetCurrentBoxTime());
+ int poll_timeout = (int)(end - now);
+ if (poll_timeout < 0) poll_timeout = 0;
+ if (Timeout == IOStream::TimeOutInfinite)
+ {
+ poll_timeout = INFTIM;
+ }
+ result = ::poll(&p, 1, poll_timeout);
+ }
+ while(result == -1 && errno == EINTR);
+
+ switch(result)
+ {
+ case -1:
+ // error - Bad!
+ THROW_EXCEPTION(ServerException, SocketPollError)
+ break;
+
+ case 0:
+ // Condition not met, timed out
+ return false;
+ break;
+
+ default:
+ // good to go!
+ return true;
+ break;
+ }
+
+ return true;
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::Read(void *, int, int Timeout)
+// Purpose: See base class
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout)
+{
+ if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
+
+ // Make sure zero byte reads work as expected
+ if(NBytes == 0)
+ {
+ return 0;
+ }
+
+ while(true)
+ {
+ int r = ::SSL_read(mpSSL, pBuffer, NBytes);
+
+ int se;
+ switch((se = ::SSL_get_error(mpSSL, r)))
+ {
+ case SSL_ERROR_NONE:
+ // No error, return number of bytes read
+ mBytesRead += r;
+ return r;
+ break;
+
+ case SSL_ERROR_ZERO_RETURN:
+ // Connection closed
+ MarkAsReadClosed();
+ return 0;
+ break;
+
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ // wait for the required data
+ // Will only get once around this loop, so don't need to calculate timeout values
+ if(WaitWhenRetryRequired(se, Timeout) == false)
+ {
+ // timed out
+ return 0;
+ }
+ break;
+
+ default:
+ SSLLib::LogError("reading");
+ THROW_EXCEPTION(ConnectionException, Conn_TLSReadFailed)
+ break;
+ }
+ }
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::Write(const void *, int)
+// Purpose: See base class
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+void SocketStreamTLS::Write(const void *pBuffer, int NBytes)
+{
+ if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
+
+ // Make sure zero byte writes work as expected
+ if(NBytes == 0)
+ {
+ return;
+ }
+
+ // from man SSL_write
+ //
+ // SSL_write() will only return with success, when the
+ // complete contents of buf of length num has been written.
+ //
+ // So no worries about partial writes and moving the buffer around
+
+ while(true)
+ {
+ // try the write
+ int r = ::SSL_write(mpSSL, pBuffer, NBytes);
+
+ int se;
+ switch((se = ::SSL_get_error(mpSSL, r)))
+ {
+ case SSL_ERROR_NONE:
+ // No error, data sent, return success
+ mBytesWritten += r;
+ return;
+ break;
+
+ case SSL_ERROR_ZERO_RETURN:
+ // Connection closed
+ MarkAsWriteClosed();
+ THROW_EXCEPTION(ConnectionException, Conn_TLSClosedWhenWriting)
+ break;
+
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ // wait for the requried data
+ {
+ #ifndef BOX_RELEASE_BUILD
+ bool conditionmet =
+ #endif
+ WaitWhenRetryRequired(se, IOStream::TimeOutInfinite);
+ ASSERT(conditionmet);
+ }
+ break;
+
+ default:
+ SSLLib::LogError("writing");
+ THROW_EXCEPTION(ConnectionException, Conn_TLSWriteFailed)
+ break;
+ }
+ }
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::Close()
+// Purpose: See base class
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+void SocketStreamTLS::Close()
+{
+ if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
+
+ // Base class to close
+ SocketStream::Close();
+
+ // Free resources
+ ::SSL_free(mpSSL);
+ mpSSL = 0;
+ mpBIO = 0; // implicitly freed by SSL_free
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::Shutdown()
+// Purpose: See base class
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+void SocketStreamTLS::Shutdown(bool Read, bool Write)
+{
+ if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
+
+ if(::SSL_shutdown(mpSSL) < 0)
+ {
+ SSLLib::LogError("shutting down");
+ THROW_EXCEPTION(ConnectionException, Conn_TLSShutdownFailed)
+ }
+
+ // Don't ask the base class to shutdown -- BIO does this, apparently.
+}
+
+// --------------------------------------------------------------------------
+//
+// Function
+// Name: SocketStreamTLS::GetPeerCommonName()
+// Purpose: Returns the common name of the other end of the connection
+// Created: 2003/08/06
+//
+// --------------------------------------------------------------------------
+std::string SocketStreamTLS::GetPeerCommonName()
+{
+ if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
+
+ // Get certificate
+ X509 *cert = ::SSL_get_peer_certificate(mpSSL);
+ if(cert == 0)
+ {
+ ::X509_free(cert);
+ THROW_EXCEPTION(ConnectionException, Conn_TLSNoPeerCertificate)
+ }
+
+ // Subject details
+ X509_NAME *subject = ::X509_get_subject_name(cert);
+ if(subject == 0)
+ {
+ ::X509_free(cert);
+ THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
+ }
+
+ // Common name
+ char commonName[256];
+ if(::X509_NAME_get_text_by_NID(subject, NID_commonName, commonName, sizeof(commonName)) <= 0)
+ {
+ ::X509_free(cert);
+ THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
+ }
+ // Terminate just in case
+ commonName[sizeof(commonName)-1] = '\0';
+
+ // Done.
+ return std::string(commonName);
+}