summaryrefslogtreecommitdiff
path: root/lib/server
diff options
context:
space:
mode:
authorChris Wilson <chris+github@qwirx.com>2014-03-01 10:42:48 +0000
committerChris Wilson <chris+github@qwirx.com>2014-03-01 10:42:48 +0000
commit2b6ac135fa7071290289741c9e35747bb9f1012f (patch)
tree037507c43ac097fb2cd5850bbd57f25a553e20e8 /lib/server
parent2ef0a9aa8cd3cd4dcfa0cd9d2014051832c52e8a (diff)
Make Protocol take control of the socket object passed in.
We pass a std::auto_ptr<SocketStream> to every Protocol subclass when we construct it, and it takes control of this object. This reduces the risk of: * accidentally reusing the same SocketStream for multiple Protocols (it happened to me in testbackupstore); * holding onto a reference to the SocketStream; * allowing a locally-scoped SocketStream to go out of scope and be released while still being referenced by a live Protocol.
Diffstat (limited to 'lib/server')
-rw-r--r--lib/server/Protocol.cpp46
-rw-r--r--lib/server/Protocol.h7
-rw-r--r--lib/server/ServerStream.h10
-rw-r--r--lib/server/ServerTLS.h6
-rw-r--r--lib/server/TcpNice.h6
-rwxr-xr-xlib/server/makeprotocol.pl.in13
6 files changed, 52 insertions, 36 deletions
diff --git a/lib/server/Protocol.cpp b/lib/server/Protocol.cpp
index 382f1c37..3c661154 100644
--- a/lib/server/Protocol.cpp
+++ b/lib/server/Protocol.cpp
@@ -19,7 +19,7 @@
#include "Protocol.h"
#include "ProtocolWire.h"
-#include "IOStream.h"
+#include "SocketStream.h"
#include "ServerException.h"
#include "PartialReadStream.h"
#include "ProtocolUncertainStream.h"
@@ -44,8 +44,8 @@
// Created: 2003/08/19
//
// --------------------------------------------------------------------------
-Protocol::Protocol(IOStream &rStream)
-: mrStream(rStream),
+Protocol::Protocol(std::auto_ptr<SocketStream> apConn)
+: mapConn(apConn),
mHandshakeDone(false),
mMaxObjectSize(PROTOCOL_DEFAULT_MAXOBJSIZE),
mTimeout(PROTOCOL_DEFAULT_TIMEOUT),
@@ -103,8 +103,8 @@ void Protocol::Handshake()
::strncpy(hsSend.mIdent, GetProtocolIdentString(), sizeof(hsSend.mIdent));
// Send it
- mrStream.Write(&hsSend, sizeof(hsSend));
- mrStream.WriteAllBuffered();
+ mapConn->Write(&hsSend, sizeof(hsSend));
+ mapConn->WriteAllBuffered();
// Receive a handshake from the peer
PW_Handshake hsReceive;
@@ -114,7 +114,7 @@ void Protocol::Handshake()
while(bytesToRead > 0)
{
// Get some data from the stream
- int bytesRead = mrStream.Read(readInto, bytesToRead, mTimeout);
+ int bytesRead = mapConn->Read(readInto, bytesToRead, mTimeout);
if(bytesRead == 0)
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
@@ -158,7 +158,8 @@ void Protocol::CheckAndReadHdr(void *hdr)
}
// Get some data into this header
- if(!mrStream.ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), 0 /* not interested in bytes read if this fails */, mTimeout))
+ if(!mapConn->ReadFullBuffer(hdr, sizeof(PW_ObjectHeader),
+ 0 /* not interested in bytes read if this fails */, mTimeout))
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
}
@@ -199,7 +200,8 @@ std::auto_ptr<Message> Protocol::ReceiveInternal()
EnsureBufferAllocated(objSize);
// Read data
- if(!mrStream.ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), 0 /* not interested in bytes read if this fails */, mTimeout))
+ if(!mapConn->ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader),
+ 0 /* not interested in bytes read if this fails */, mTimeout))
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
}
@@ -292,8 +294,8 @@ void Protocol::SendInternal(const Message &rObject)
pobjHeader->mObjType = htonl(rObject.GetType());
// Write data
- mrStream.Write(mpBuffer, writtenSize);
- mrStream.WriteAllBuffered();
+ mapConn->Write(mpBuffer, writtenSize);
+ mapConn->WriteAllBuffered();
}
// --------------------------------------------------------------------------
@@ -647,13 +649,13 @@ std::auto_ptr<IOStream> Protocol::ReceiveStream()
{
BOX_TRACE("Receiving stream, size uncertain");
return std::auto_ptr<IOStream>(
- new ProtocolUncertainStream(mrStream));
+ new ProtocolUncertainStream(*mapConn));
}
else
{
BOX_TRACE("Receiving stream, size " << streamSize << " bytes");
return std::auto_ptr<IOStream>(
- new PartialReadStream(mrStream, streamSize));
+ new PartialReadStream(*mapConn, streamSize));
}
}
@@ -709,7 +711,7 @@ void Protocol::SendStream(IOStream &rStream)
objHeader.mObjType = htonl(SPECIAL_STREAM_OBJECT_TYPE);
// Write header
- mrStream.Write(&objHeader, sizeof(objHeader));
+ mapConn->Write(&objHeader, sizeof(objHeader));
// Could be sent in one of two ways
if(uncertainSize)
{
@@ -744,7 +746,7 @@ void Protocol::SendStream(IOStream &rStream)
// Send final byte to finish the stream
BOX_TRACE("Sending end of stream byte");
uint8_t endOfStream = ProtocolStreamHeader_EndOfStream;
- mrStream.Write(&endOfStream, 1);
+ mapConn->Write(&endOfStream, 1);
BOX_TRACE("Sent end of stream byte");
}
catch(...)
@@ -759,13 +761,13 @@ void Protocol::SendStream(IOStream &rStream)
else
{
// Fixed size stream, send it all in one go
- if(!rStream.CopyStreamTo(mrStream, mTimeout, 4096 /* slightly larger buffer */))
+ if(!rStream.CopyStreamTo(*mapConn, mTimeout, 4096 /* slightly larger buffer */))
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_TimeOutWhenSendingStream)
}
}
// Make sure everything is written
- mrStream.WriteAllBuffered();
+ mapConn->WriteAllBuffered();
}
@@ -816,7 +818,7 @@ int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock)
Block[-1] = header;
// Write everything out
- mrStream.Write(Block - 1, writeSize + 1);
+ mapConn->Write(Block - 1, writeSize + 1);
BOX_TRACE("Sent " << (writeSize+1) << " bytes to stream");
// move the remainer to the beginning of the block for the next time round
@@ -1177,6 +1179,12 @@ const uint16_t Protocol::sProtocolStreamHeaderLengths[256] =
0 // 255 = special (reserved)
};
+int64_t Protocol::GetBytesRead() const
+{
+ return mapConn->GetBytesRead();
+}
-
-
+int64_t Protocol::GetBytesWritten() const
+{
+ return mapConn->GetBytesWritten();
+}
diff --git a/lib/server/Protocol.h b/lib/server/Protocol.h
index 42cb0ff8..0995393d 100644
--- a/lib/server/Protocol.h
+++ b/lib/server/Protocol.h
@@ -19,6 +19,7 @@
#include "Message.h"
class IOStream;
+class SocketStream;
// default timeout is 15 minutes
#define PROTOCOL_DEFAULT_TIMEOUT (15*60*1000)
@@ -36,7 +37,7 @@ class IOStream;
class Protocol
{
public:
- Protocol(IOStream &rStream);
+ Protocol(std::auto_ptr<SocketStream> apConn);
virtual ~Protocol();
private:
@@ -175,6 +176,8 @@ public:
FILE *GetLogToFile() { return mLogToFile; }
void SetLogToSysLog(bool Log = false) {mLogToSysLog = Log;}
void SetLogToFile(FILE *File = 0) {mLogToFile = File;}
+ int64_t GetBytesRead() const;
+ int64_t GetBytesWritten() const;
protected:
virtual std::auto_ptr<Message> MakeMessage(int ObjType) = 0;
@@ -190,7 +193,7 @@ private:
void EnsureBufferAllocated(int Size);
int SendStreamSendBlock(uint8_t *Block, int BytesInBlock);
- IOStream &mrStream;
+ std::auto_ptr<SocketStream> mapConn;
bool mHandshakeDone;
unsigned int mMaxObjectSize;
int mTimeout;
diff --git a/lib/server/ServerStream.h b/lib/server/ServerStream.h
index a9b56169..8bb52b5b 100644
--- a/lib/server/ServerStream.h
+++ b/lib/server/ServerStream.h
@@ -286,7 +286,7 @@ public:
#endif
// The derived class does some server magic with the connection
- HandleConnection(*connection);
+ HandleConnection(connection);
// Since rChildExit == true, the forked process will call _exit() on return from this fn
return;
@@ -305,7 +305,7 @@ public:
#endif // !WIN32
// Just handle in this process
SetProcessTitle("handling");
- HandleConnection(*connection);
+ HandleConnection(connection);
SetProcessTitle("idle");
#ifndef WIN32
}
@@ -377,12 +377,12 @@ public:
}
#endif
- virtual void HandleConnection(StreamType &rStream)
+ virtual void HandleConnection(std::auto_ptr<StreamType> apStream)
{
- Connection(rStream);
+ Connection(apStream);
}
- virtual void Connection(StreamType &rStream) = 0;
+ virtual void Connection(std::auto_ptr<StreamType> apStream) = 0;
protected:
// For checking code in derived classes -- use if you have an algorithm which
diff --git a/lib/server/ServerTLS.h b/lib/server/ServerTLS.h
index a74a671e..20e55964 100644
--- a/lib/server/ServerTLS.h
+++ b/lib/server/ServerTLS.h
@@ -59,11 +59,11 @@ public:
ForkToHandleRequests>::Run2(rChildExit);
}
- virtual void HandleConnection(SocketStreamTLS &rStream)
+ virtual void HandleConnection(std::auto_ptr<SocketStreamTLS> apStream)
{
- rStream.Handshake(mContext, true /* is server */);
+ apStream->Handshake(mContext, true /* is server */);
// this-> in next line required to build under some gcc versions
- this->Connection(rStream);
+ this->Connection(apStream);
}
private:
diff --git a/lib/server/TcpNice.h b/lib/server/TcpNice.h
index e2027749..4381df42 100644
--- a/lib/server/TcpNice.h
+++ b/lib/server/TcpNice.h
@@ -83,7 +83,7 @@ private:
//
// --------------------------------------------------------------------------
-class NiceSocketStream : public IOStream
+class NiceSocketStream : public SocketStream
{
private:
std::auto_ptr<SocketStream> mapSocket;
@@ -166,6 +166,10 @@ public:
}
virtual void SetEnabled(bool enabled);
+ off_t GetBytesRead() const { return mapSocket->GetBytesRead(); }
+ off_t GetBytesWritten() const { return mapSocket->GetBytesWritten(); }
+ void ResetCounters() { mapSocket->ResetCounters(); }
+
private:
NiceSocketStream(const NiceSocketStream &rToCopy)
{ /* do not call */ }
diff --git a/lib/server/makeprotocol.pl.in b/lib/server/makeprotocol.pl.in
index 35814d1d..00dc58d4 100755
--- a/lib/server/makeprotocol.pl.in
+++ b/lib/server/makeprotocol.pl.in
@@ -158,7 +158,9 @@ print CPP <<__E;
#include <sstream>
#include "$filename_base.h"
-#include "IOStream.h"
+#include "CollectInBufferStream.h"
+#include "SocketStream.h"
+#include "MemBlockStream.h"
__E
print H <<__E;
@@ -179,8 +181,7 @@ print H <<__E;
#include "ServerException.h"
class IOStream;
-
-
+class SocketStream;
__E
# extra headers
@@ -741,7 +742,7 @@ __E
else
{
print H <<__E;
- $server_or_client_class(IOStream &rStream);
+ $server_or_client_class(std::auto_ptr<SocketStream> apConn);
std::auto_ptr<$message_base_class> Receive();
void Send(const $message_base_class &rObject);
__E
@@ -885,8 +886,8 @@ __E
else
{
print CPP <<__E;
-$server_or_client_class\::$server_or_client_class(IOStream &rStream)
-: Protocol(rStream)
+$server_or_client_class\::$server_or_client_class(std::auto_ptr<SocketStream> apConn)
+: Protocol(apConn)
{ }
__E
}