From 2b6ac135fa7071290289741c9e35747bb9f1012f Mon Sep 17 00:00:00 2001 From: Chris Wilson Date: Sat, 1 Mar 2014 10:42:48 +0000 Subject: Make Protocol take control of the socket object passed in. We pass a std::auto_ptr to every Protocol subclass when we construct it, and it takes control of this object. This reduces the risk of: * accidentally reusing the same SocketStream for multiple Protocols (it happened to me in testbackupstore); * holding onto a reference to the SocketStream; * allowing a locally-scoped SocketStream to go out of scope and be released while still being referenced by a live Protocol. --- lib/server/Protocol.cpp | 46 +++++++++++++++++++++++++------------------ lib/server/Protocol.h | 7 +++++-- lib/server/ServerStream.h | 10 +++++----- lib/server/ServerTLS.h | 6 +++--- lib/server/TcpNice.h | 6 +++++- lib/server/makeprotocol.pl.in | 13 ++++++------ 6 files changed, 52 insertions(+), 36 deletions(-) (limited to 'lib/server') diff --git a/lib/server/Protocol.cpp b/lib/server/Protocol.cpp index 382f1c37..3c661154 100644 --- a/lib/server/Protocol.cpp +++ b/lib/server/Protocol.cpp @@ -19,7 +19,7 @@ #include "Protocol.h" #include "ProtocolWire.h" -#include "IOStream.h" +#include "SocketStream.h" #include "ServerException.h" #include "PartialReadStream.h" #include "ProtocolUncertainStream.h" @@ -44,8 +44,8 @@ // Created: 2003/08/19 // // -------------------------------------------------------------------------- -Protocol::Protocol(IOStream &rStream) -: mrStream(rStream), +Protocol::Protocol(std::auto_ptr apConn) +: mapConn(apConn), mHandshakeDone(false), mMaxObjectSize(PROTOCOL_DEFAULT_MAXOBJSIZE), mTimeout(PROTOCOL_DEFAULT_TIMEOUT), @@ -103,8 +103,8 @@ void Protocol::Handshake() ::strncpy(hsSend.mIdent, GetProtocolIdentString(), sizeof(hsSend.mIdent)); // Send it - mrStream.Write(&hsSend, sizeof(hsSend)); - mrStream.WriteAllBuffered(); + mapConn->Write(&hsSend, sizeof(hsSend)); + mapConn->WriteAllBuffered(); // Receive a handshake from the peer PW_Handshake hsReceive; @@ -114,7 +114,7 @@ void Protocol::Handshake() while(bytesToRead > 0) { // Get some data from the stream - int bytesRead = mrStream.Read(readInto, bytesToRead, mTimeout); + int bytesRead = mapConn->Read(readInto, bytesToRead, mTimeout); if(bytesRead == 0) { THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) @@ -158,7 +158,8 @@ void Protocol::CheckAndReadHdr(void *hdr) } // Get some data into this header - if(!mrStream.ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), 0 /* not interested in bytes read if this fails */, mTimeout)) + if(!mapConn->ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), + 0 /* not interested in bytes read if this fails */, mTimeout)) { THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) } @@ -199,7 +200,8 @@ std::auto_ptr Protocol::ReceiveInternal() EnsureBufferAllocated(objSize); // Read data - if(!mrStream.ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), 0 /* not interested in bytes read if this fails */, mTimeout)) + if(!mapConn->ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), + 0 /* not interested in bytes read if this fails */, mTimeout)) { THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) } @@ -292,8 +294,8 @@ void Protocol::SendInternal(const Message &rObject) pobjHeader->mObjType = htonl(rObject.GetType()); // Write data - mrStream.Write(mpBuffer, writtenSize); - mrStream.WriteAllBuffered(); + mapConn->Write(mpBuffer, writtenSize); + mapConn->WriteAllBuffered(); } // -------------------------------------------------------------------------- @@ -647,13 +649,13 @@ std::auto_ptr Protocol::ReceiveStream() { BOX_TRACE("Receiving stream, size uncertain"); return std::auto_ptr( - new ProtocolUncertainStream(mrStream)); + new ProtocolUncertainStream(*mapConn)); } else { BOX_TRACE("Receiving stream, size " << streamSize << " bytes"); return std::auto_ptr( - new PartialReadStream(mrStream, streamSize)); + new PartialReadStream(*mapConn, streamSize)); } } @@ -709,7 +711,7 @@ void Protocol::SendStream(IOStream &rStream) objHeader.mObjType = htonl(SPECIAL_STREAM_OBJECT_TYPE); // Write header - mrStream.Write(&objHeader, sizeof(objHeader)); + mapConn->Write(&objHeader, sizeof(objHeader)); // Could be sent in one of two ways if(uncertainSize) { @@ -744,7 +746,7 @@ void Protocol::SendStream(IOStream &rStream) // Send final byte to finish the stream BOX_TRACE("Sending end of stream byte"); uint8_t endOfStream = ProtocolStreamHeader_EndOfStream; - mrStream.Write(&endOfStream, 1); + mapConn->Write(&endOfStream, 1); BOX_TRACE("Sent end of stream byte"); } catch(...) @@ -759,13 +761,13 @@ void Protocol::SendStream(IOStream &rStream) else { // Fixed size stream, send it all in one go - if(!rStream.CopyStreamTo(mrStream, mTimeout, 4096 /* slightly larger buffer */)) + if(!rStream.CopyStreamTo(*mapConn, mTimeout, 4096 /* slightly larger buffer */)) { THROW_EXCEPTION(ConnectionException, Conn_Protocol_TimeOutWhenSendingStream) } } // Make sure everything is written - mrStream.WriteAllBuffered(); + mapConn->WriteAllBuffered(); } @@ -816,7 +818,7 @@ int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock) Block[-1] = header; // Write everything out - mrStream.Write(Block - 1, writeSize + 1); + mapConn->Write(Block - 1, writeSize + 1); BOX_TRACE("Sent " << (writeSize+1) << " bytes to stream"); // move the remainer to the beginning of the block for the next time round @@ -1177,6 +1179,12 @@ const uint16_t Protocol::sProtocolStreamHeaderLengths[256] = 0 // 255 = special (reserved) }; +int64_t Protocol::GetBytesRead() const +{ + return mapConn->GetBytesRead(); +} - - +int64_t Protocol::GetBytesWritten() const +{ + return mapConn->GetBytesWritten(); +} diff --git a/lib/server/Protocol.h b/lib/server/Protocol.h index 42cb0ff8..0995393d 100644 --- a/lib/server/Protocol.h +++ b/lib/server/Protocol.h @@ -19,6 +19,7 @@ #include "Message.h" class IOStream; +class SocketStream; // default timeout is 15 minutes #define PROTOCOL_DEFAULT_TIMEOUT (15*60*1000) @@ -36,7 +37,7 @@ class IOStream; class Protocol { public: - Protocol(IOStream &rStream); + Protocol(std::auto_ptr apConn); virtual ~Protocol(); private: @@ -175,6 +176,8 @@ public: FILE *GetLogToFile() { return mLogToFile; } void SetLogToSysLog(bool Log = false) {mLogToSysLog = Log;} void SetLogToFile(FILE *File = 0) {mLogToFile = File;} + int64_t GetBytesRead() const; + int64_t GetBytesWritten() const; protected: virtual std::auto_ptr MakeMessage(int ObjType) = 0; @@ -190,7 +193,7 @@ private: void EnsureBufferAllocated(int Size); int SendStreamSendBlock(uint8_t *Block, int BytesInBlock); - IOStream &mrStream; + std::auto_ptr mapConn; bool mHandshakeDone; unsigned int mMaxObjectSize; int mTimeout; diff --git a/lib/server/ServerStream.h b/lib/server/ServerStream.h index a9b56169..8bb52b5b 100644 --- a/lib/server/ServerStream.h +++ b/lib/server/ServerStream.h @@ -286,7 +286,7 @@ public: #endif // The derived class does some server magic with the connection - HandleConnection(*connection); + HandleConnection(connection); // Since rChildExit == true, the forked process will call _exit() on return from this fn return; @@ -305,7 +305,7 @@ public: #endif // !WIN32 // Just handle in this process SetProcessTitle("handling"); - HandleConnection(*connection); + HandleConnection(connection); SetProcessTitle("idle"); #ifndef WIN32 } @@ -377,12 +377,12 @@ public: } #endif - virtual void HandleConnection(StreamType &rStream) + virtual void HandleConnection(std::auto_ptr apStream) { - Connection(rStream); + Connection(apStream); } - virtual void Connection(StreamType &rStream) = 0; + virtual void Connection(std::auto_ptr apStream) = 0; protected: // For checking code in derived classes -- use if you have an algorithm which diff --git a/lib/server/ServerTLS.h b/lib/server/ServerTLS.h index a74a671e..20e55964 100644 --- a/lib/server/ServerTLS.h +++ b/lib/server/ServerTLS.h @@ -59,11 +59,11 @@ public: ForkToHandleRequests>::Run2(rChildExit); } - virtual void HandleConnection(SocketStreamTLS &rStream) + virtual void HandleConnection(std::auto_ptr apStream) { - rStream.Handshake(mContext, true /* is server */); + apStream->Handshake(mContext, true /* is server */); // this-> in next line required to build under some gcc versions - this->Connection(rStream); + this->Connection(apStream); } private: diff --git a/lib/server/TcpNice.h b/lib/server/TcpNice.h index e2027749..4381df42 100644 --- a/lib/server/TcpNice.h +++ b/lib/server/TcpNice.h @@ -83,7 +83,7 @@ private: // // -------------------------------------------------------------------------- -class NiceSocketStream : public IOStream +class NiceSocketStream : public SocketStream { private: std::auto_ptr mapSocket; @@ -166,6 +166,10 @@ public: } virtual void SetEnabled(bool enabled); + off_t GetBytesRead() const { return mapSocket->GetBytesRead(); } + off_t GetBytesWritten() const { return mapSocket->GetBytesWritten(); } + void ResetCounters() { mapSocket->ResetCounters(); } + private: NiceSocketStream(const NiceSocketStream &rToCopy) { /* do not call */ } diff --git a/lib/server/makeprotocol.pl.in b/lib/server/makeprotocol.pl.in index 35814d1d..00dc58d4 100755 --- a/lib/server/makeprotocol.pl.in +++ b/lib/server/makeprotocol.pl.in @@ -158,7 +158,9 @@ print CPP <<__E; #include #include "$filename_base.h" -#include "IOStream.h" +#include "CollectInBufferStream.h" +#include "SocketStream.h" +#include "MemBlockStream.h" __E print H <<__E; @@ -179,8 +181,7 @@ print H <<__E; #include "ServerException.h" class IOStream; - - +class SocketStream; __E # extra headers @@ -741,7 +742,7 @@ __E else { print H <<__E; - $server_or_client_class(IOStream &rStream); + $server_or_client_class(std::auto_ptr apConn); std::auto_ptr<$message_base_class> Receive(); void Send(const $message_base_class &rObject); __E @@ -885,8 +886,8 @@ __E else { print CPP <<__E; -$server_or_client_class\::$server_or_client_class(IOStream &rStream) -: Protocol(rStream) +$server_or_client_class\::$server_or_client_class(std::auto_ptr apConn) +: Protocol(apConn) { } __E } -- cgit v1.2.3