summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bin/bbackupd/BackupClientContext.cpp50
-rw-r--r--bin/bbackupd/BackupClientContext.h5
-rw-r--r--bin/bbackupquery/bbackupquery.cpp7
-rw-r--r--bin/bbstored/BackupStoreDaemon.cpp25
-rw-r--r--bin/bbstored/BackupStoreDaemon.h6
-rw-r--r--lib/backupstore/StoreTestUtils.cpp16
-rw-r--r--lib/backupstore/StoreTestUtils.h6
-rw-r--r--lib/httpserver/HTTPServer.cpp6
-rw-r--r--lib/httpserver/HTTPServer.h2
-rw-r--r--lib/server/Protocol.cpp46
-rw-r--r--lib/server/Protocol.h7
-rw-r--r--lib/server/ServerStream.h10
-rw-r--r--lib/server/ServerTLS.h6
-rw-r--r--lib/server/TcpNice.h6
-rwxr-xr-xlib/server/makeprotocol.pl.in13
-rw-r--r--test/backupstore/testbackupstore.cpp75
-rw-r--r--test/backupstorepatch/testbackupstorepatch.cpp16
-rw-r--r--test/basicserver/testbasicserver.cpp26
-rw-r--r--test/bbackupd/testbbackupd.cpp51
19 files changed, 167 insertions, 212 deletions
diff --git a/bin/bbackupd/BackupClientContext.cpp b/bin/bbackupd/BackupClientContext.cpp
index 26df04be..45c48a67 100644
--- a/bin/bbackupd/BackupClientContext.cpp
+++ b/bin/bbackupd/BackupClientContext.cpp
@@ -73,7 +73,8 @@ BackupClientContext::BackupClientContext
mKeepAliveTimer(0, "KeepAliveTime"),
mbIsManaged(false),
mrProgressNotifier(rProgressNotifier),
- mTcpNiceMode(TcpNiceMode)
+ mTcpNiceMode(TcpNiceMode),
+ mpNice(NULL)
{
}
@@ -113,13 +114,12 @@ BackupProtocolClient &BackupClientContext::GetConnection()
{
return *mapConnection;
}
-
- // there shouldn't be a connection open
- ASSERT(mapSocket.get() == 0);
+
// Defensive. Must close connection before releasing any old socket.
mapConnection.reset();
- mapSocket.reset(new SocketStreamTLS);
-
+
+ std::auto_ptr<SocketStream> apSocket(new SocketStreamTLS);
+
try
{
// Defensive.
@@ -130,21 +130,22 @@ BackupProtocolClient &BackupClientContext::GetConnection()
mHostname << "'...");
// Connect!
- ((SocketStreamTLS *)(mapSocket.get()))->Open(mrTLSContext,
+ ((SocketStreamTLS *)(apSocket.get()))->Open(mrTLSContext,
Socket::TypeINET, mHostname, mPort);
if(mTcpNiceMode)
{
- // Pass control of mapSocket to NiceSocketStream,
+ // Pass control of apSocket to NiceSocketStream,
// which will take care of destroying it for us.
- mapNice.reset(new NiceSocketStream(mapSocket));
- mapConnection.reset(new BackupProtocolClient(*mapNice));
+ // But we need to hang onto a pointer to the nice
+ // socket, so we can enable and disable nice mode.
+ // This is scary, it could be deallocated under us.
+ mpNice = new NiceSocketStream(apSocket);
+ apSocket.reset(mpNice);
}
- else
- {
- mapConnection.reset(new BackupProtocolClient(*mapSocket));
- }
-
+
+ mapConnection.reset(new BackupProtocolClient(apSocket));
+
// Set logging option
mapConnection->SetLogToSysLog(mExtendedLogging);
@@ -165,10 +166,10 @@ BackupProtocolClient &BackupClientContext::GetConnection()
mapConnection->SetLogToFile(mpExtendedLogFileHandle);
}
}
-
+
// Handshake
mapConnection->Handshake();
-
+
// Check the version of the server
{
std::auto_ptr<BackupProtocolVersion> serverVersion(
@@ -192,8 +193,6 @@ BackupProtocolClient &BackupClientContext::GetConnection()
try
{
mapConnection->QueryFinished();
- mapNice.reset();
- mapSocket.reset();
}
catch(...)
{
@@ -222,8 +221,6 @@ BackupProtocolClient &BackupClientContext::GetConnection()
{
// Clean up.
mapConnection.reset();
- mapNice.reset();
- mapSocket.reset();
throw;
}
@@ -269,17 +266,6 @@ void BackupClientContext::CloseAnyOpenConnection()
mapConnection.reset();
}
- try
- {
- // Be nice about closing the socket
- mapNice.reset();
- mapSocket.reset();
- }
- catch(...)
- {
- // Ignore errors
- }
-
// Delete any pending list
if(mpDeleteList != 0)
{
diff --git a/bin/bbackupd/BackupClientContext.h b/bin/bbackupd/BackupClientContext.h
index 1fcc6ede..7e081e2d 100644
--- a/bin/bbackupd/BackupClientContext.h
+++ b/bin/bbackupd/BackupClientContext.h
@@ -214,7 +214,7 @@ public:
{
if(mTcpNiceMode)
{
- mapNice->SetEnabled(enabled);
+ mpNice->SetEnabled(enabled);
}
}
@@ -226,8 +226,6 @@ private:
std::string mHostname;
int mPort;
uint32_t mAccountNumber;
- std::auto_ptr<SocketStream> mapSocket;
- std::auto_ptr<NiceSocketStream> mapNice;
std::auto_ptr<BackupProtocolClient> mapConnection;
bool mExtendedLogging;
bool mExtendedLogToFile;
@@ -246,6 +244,7 @@ private:
int mMaximumDiffingTime;
ProgressNotifier &mrProgressNotifier;
bool mTcpNiceMode;
+ NiceSocketStream *mpNice;
};
#endif // BACKUPCLIENTCONTEXT__H
diff --git a/bin/bbackupquery/bbackupquery.cpp b/bin/bbackupquery/bbackupquery.cpp
index dbac3f27..1ff3b28a 100644
--- a/bin/bbackupquery/bbackupquery.cpp
+++ b/bin/bbackupquery/bbackupquery.cpp
@@ -429,15 +429,16 @@ int main(int argc, const char *argv[])
// 2. Connect to server
if(!quiet) BOX_INFO("Connecting to store...");
- SocketStreamTLS socket;
- socket.Open(tlsContext, Socket::TypeINET,
+ SocketStreamTLS *socket = new SocketStreamTLS;
+ std::auto_ptr<SocketStream> apSocket(socket);
+ socket->Open(tlsContext, Socket::TypeINET,
conf.GetKeyValue("StoreHostname").c_str(),
conf.GetKeyValueInt("StorePort"));
// 3. Make a protocol, and handshake
if(!quiet) BOX_INFO("Handshake with store...");
std::auto_ptr<BackupProtocolClient>
- apConnection(new BackupProtocolClient(socket));
+ apConnection(new BackupProtocolClient(apSocket));
BackupProtocolClient& connection(*(apConnection.get()));
connection.Handshake();
diff --git a/bin/bbstored/BackupStoreDaemon.cpp b/bin/bbstored/BackupStoreDaemon.cpp
index 2649e0a2..8fddf125 100644
--- a/bin/bbstored/BackupStoreDaemon.cpp
+++ b/bin/bbstored/BackupStoreDaemon.cpp
@@ -272,11 +272,11 @@ void BackupStoreDaemon::Run()
// Created: 2003/08/20
//
// --------------------------------------------------------------------------
-void BackupStoreDaemon::Connection(SocketStreamTLS &rStream)
+void BackupStoreDaemon::Connection(std::auto_ptr<SocketStreamTLS> apStream)
{
try
{
- Connection2(rStream);
+ Connection2(apStream);
}
catch(BoxException &e)
{
@@ -304,10 +304,10 @@ void BackupStoreDaemon::Connection(SocketStreamTLS &rStream)
// Created: 2006/11/12
//
// --------------------------------------------------------------------------
-void BackupStoreDaemon::Connection2(SocketStreamTLS &rStream)
+void BackupStoreDaemon::Connection2(std::auto_ptr<SocketStreamTLS> apStream)
{
// Get the common name from the certificate
- std::string clientCommonName(rStream.GetPeerCommonName());
+ std::string clientCommonName(apStream->GetPeerCommonName());
// Log the name
BOX_INFO("Client certificate CN: " << clientCommonName);
@@ -346,7 +346,8 @@ void BackupStoreDaemon::Connection2(SocketStreamTLS &rStream)
}
// Handle a connection with the backup protocol
- BackupProtocolServer server(rStream);
+ std::auto_ptr<SocketStream> apPlainStream(apStream);
+ BackupProtocolServer server(apPlainStream);
server.SetLogToSysLog(mExtendedLogging);
server.SetTimeout(BACKUP_STORE_TIMEOUT);
try
@@ -355,22 +356,22 @@ void BackupStoreDaemon::Connection2(SocketStreamTLS &rStream)
}
catch(...)
{
- LogConnectionStats(id, context.GetAccountName(), rStream);
+ LogConnectionStats(id, context.GetAccountName(), server);
throw;
}
- LogConnectionStats(id, context.GetAccountName(), rStream);
+ LogConnectionStats(id, context.GetAccountName(), server);
context.CleanUp();
}
void BackupStoreDaemon::LogConnectionStats(uint32_t accountId,
- const std::string& accountName, const SocketStreamTLS &s)
+ const std::string& accountName, const BackupProtocolServer &server)
{
// Log the amount of data transferred
BOX_NOTICE("Connection statistics for " <<
BOX_FORMAT_ACCOUNT(accountId) << " "
"(name=" << accountName << "):"
- " IN=" << s.GetBytesRead() <<
- " OUT=" << s.GetBytesWritten() <<
- " NET_IN=" << (s.GetBytesRead() - s.GetBytesWritten()) <<
- " TOTAL=" << (s.GetBytesRead() + s.GetBytesWritten()));
+ " IN=" << server.GetBytesRead() <<
+ " OUT=" << server.GetBytesWritten() <<
+ " NET_IN=" << (server.GetBytesRead() - server.GetBytesWritten()) <<
+ " TOTAL=" << (server.GetBytesRead() + server.GetBytesWritten()));
}
diff --git a/bin/bbstored/BackupStoreDaemon.h b/bin/bbstored/BackupStoreDaemon.h
index ce538477..a2dab5e5 100644
--- a/bin/bbstored/BackupStoreDaemon.h
+++ b/bin/bbstored/BackupStoreDaemon.h
@@ -52,8 +52,8 @@ protected:
virtual void Run();
- virtual void Connection(SocketStreamTLS &rStream);
- void Connection2(SocketStreamTLS &rStream);
+ virtual void Connection(std::auto_ptr<SocketStreamTLS> apStream);
+ void Connection2(std::auto_ptr<SocketStreamTLS> apStream);
virtual const char *DaemonName() const;
virtual std::string DaemonBanner() const;
@@ -64,7 +64,7 @@ protected:
void HousekeepingProcess();
void LogConnectionStats(uint32_t accountId,
- const std::string& accountName, const SocketStreamTLS &s);
+ const std::string& accountName, const BackupProtocolServer &server);
public:
// HousekeepingInterface implementation
diff --git a/lib/backupstore/StoreTestUtils.cpp b/lib/backupstore/StoreTestUtils.cpp
index 99ebb9df..e5b05e9c 100644
--- a/lib/backupstore/StoreTestUtils.cpp
+++ b/lib/backupstore/StoreTestUtils.cpp
@@ -132,24 +132,22 @@ void init_context(TLSContext& rContext)
"testfiles/clientTrustedCAs.pem");
}
-std::auto_ptr<SocketStreamTLS> open_conn(const char *hostname,
+std::auto_ptr<SocketStream> open_conn(const char *hostname,
TLSContext& rContext)
{
init_context(rContext);
std::auto_ptr<SocketStreamTLS> conn(new SocketStreamTLS);
conn->Open(rContext, Socket::TypeINET, hostname,
BOX_PORT_BBSTORED_TEST);
- return conn;
+ return static_cast<std::auto_ptr<SocketStream> >(conn);
}
-std::auto_ptr<BackupProtocolClient> test_server_login(const char *hostname,
- TLSContext& rContext, std::auto_ptr<SocketStreamTLS>& rapConn)
+std::auto_ptr<BackupProtocolCallable> test_server_login(const char *hostname,
+ TLSContext& rContext, int flags)
{
- rapConn = open_conn(hostname, rContext);
-
// Make a protocol
- std::auto_ptr<BackupProtocolClient> protocol(new
- BackupProtocolClient(*rapConn));
+ std::auto_ptr<BackupProtocolCallable> protocol(new
+ BackupProtocolClient(open_conn(hostname, rContext)));
// Check the version
std::auto_ptr<BackupProtocolVersion> serverVersion(
@@ -158,7 +156,7 @@ std::auto_ptr<BackupProtocolClient> test_server_login(const char *hostname,
// Login
std::auto_ptr<BackupProtocolLoginConfirmed> loginConf(
- protocol->QueryLogin(0x01234567, 0));
+ protocol->QueryLogin(0x01234567, flags));
return protocol;
}
diff --git a/lib/backupstore/StoreTestUtils.h b/lib/backupstore/StoreTestUtils.h
index d15404fa..3365a7f4 100644
--- a/lib/backupstore/StoreTestUtils.h
+++ b/lib/backupstore/StoreTestUtils.h
@@ -37,12 +37,12 @@ void set_refcount(int64_t ObjectID, uint32_t RefCount = 1);
void init_context(TLSContext& rContext);
//! Opens a connection to the server (bbstored).
-std::auto_ptr<SocketStreamTLS> open_conn(const char *hostname,
+std::auto_ptr<SocketStream> open_conn(const char *hostname,
TLSContext& rContext);
//! Opens a connection to the server (bbstored) and logs in.
-std::auto_ptr<BackupProtocolClient> test_server_login(const char *hostname,
- TLSContext& rContext, std::auto_ptr<SocketStreamTLS>& rapConn);
+std::auto_ptr<BackupProtocolCallable> test_server_login(const char *hostname,
+ TLSContext& rContext, int flags = 0);
//! Checks the number of files of each type in the store against expectations.
bool check_num_files(int files, int old, int deleted, int dirs);
diff --git a/lib/httpserver/HTTPServer.cpp b/lib/httpserver/HTTPServer.cpp
index be1db687..d36be473 100644
--- a/lib/httpserver/HTTPServer.cpp
+++ b/lib/httpserver/HTTPServer.cpp
@@ -132,10 +132,10 @@ void HTTPServer::Run()
// Created: 26/3/04
//
// --------------------------------------------------------------------------
-void HTTPServer::Connection(SocketStream &rStream)
+void HTTPServer::Connection(std::auto_ptr<SocketStream> apConn)
{
// Create a get line object to use
- IOStreamGetLine getLine(rStream);
+ IOStreamGetLine getLine(*apConn);
// Notify dervived claases
HTTPConnectionOpening();
@@ -152,7 +152,7 @@ void HTTPServer::Connection(SocketStream &rStream)
}
// Generate a response
- HTTPResponse response(&rStream);
+ HTTPResponse response(apConn.get());
try
{
diff --git a/lib/httpserver/HTTPServer.h b/lib/httpserver/HTTPServer.h
index d9f74949..91f4e96c 100644
--- a/lib/httpserver/HTTPServer.h
+++ b/lib/httpserver/HTTPServer.h
@@ -62,7 +62,7 @@ private:
const char *DaemonName() const;
const ConfigurationVerify *GetConfigVerify() const;
void Run();
- void Connection(SocketStream &rStream);
+ void Connection(std::auto_ptr<SocketStream> apStream);
};
// Root level
diff --git a/lib/server/Protocol.cpp b/lib/server/Protocol.cpp
index 382f1c37..3c661154 100644
--- a/lib/server/Protocol.cpp
+++ b/lib/server/Protocol.cpp
@@ -19,7 +19,7 @@
#include "Protocol.h"
#include "ProtocolWire.h"
-#include "IOStream.h"
+#include "SocketStream.h"
#include "ServerException.h"
#include "PartialReadStream.h"
#include "ProtocolUncertainStream.h"
@@ -44,8 +44,8 @@
// Created: 2003/08/19
//
// --------------------------------------------------------------------------
-Protocol::Protocol(IOStream &rStream)
-: mrStream(rStream),
+Protocol::Protocol(std::auto_ptr<SocketStream> apConn)
+: mapConn(apConn),
mHandshakeDone(false),
mMaxObjectSize(PROTOCOL_DEFAULT_MAXOBJSIZE),
mTimeout(PROTOCOL_DEFAULT_TIMEOUT),
@@ -103,8 +103,8 @@ void Protocol::Handshake()
::strncpy(hsSend.mIdent, GetProtocolIdentString(), sizeof(hsSend.mIdent));
// Send it
- mrStream.Write(&hsSend, sizeof(hsSend));
- mrStream.WriteAllBuffered();
+ mapConn->Write(&hsSend, sizeof(hsSend));
+ mapConn->WriteAllBuffered();
// Receive a handshake from the peer
PW_Handshake hsReceive;
@@ -114,7 +114,7 @@ void Protocol::Handshake()
while(bytesToRead > 0)
{
// Get some data from the stream
- int bytesRead = mrStream.Read(readInto, bytesToRead, mTimeout);
+ int bytesRead = mapConn->Read(readInto, bytesToRead, mTimeout);
if(bytesRead == 0)
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
@@ -158,7 +158,8 @@ void Protocol::CheckAndReadHdr(void *hdr)
}
// Get some data into this header
- if(!mrStream.ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), 0 /* not interested in bytes read if this fails */, mTimeout))
+ if(!mapConn->ReadFullBuffer(hdr, sizeof(PW_ObjectHeader),
+ 0 /* not interested in bytes read if this fails */, mTimeout))
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
}
@@ -199,7 +200,8 @@ std::auto_ptr<Message> Protocol::ReceiveInternal()
EnsureBufferAllocated(objSize);
// Read data
- if(!mrStream.ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), 0 /* not interested in bytes read if this fails */, mTimeout))
+ if(!mapConn->ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader),
+ 0 /* not interested in bytes read if this fails */, mTimeout))
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
}
@@ -292,8 +294,8 @@ void Protocol::SendInternal(const Message &rObject)
pobjHeader->mObjType = htonl(rObject.GetType());
// Write data
- mrStream.Write(mpBuffer, writtenSize);
- mrStream.WriteAllBuffered();
+ mapConn->Write(mpBuffer, writtenSize);
+ mapConn->WriteAllBuffered();
}
// --------------------------------------------------------------------------
@@ -647,13 +649,13 @@ std::auto_ptr<IOStream> Protocol::ReceiveStream()
{
BOX_TRACE("Receiving stream, size uncertain");
return std::auto_ptr<IOStream>(
- new ProtocolUncertainStream(mrStream));
+ new ProtocolUncertainStream(*mapConn));
}
else
{
BOX_TRACE("Receiving stream, size " << streamSize << " bytes");
return std::auto_ptr<IOStream>(
- new PartialReadStream(mrStream, streamSize));
+ new PartialReadStream(*mapConn, streamSize));
}
}
@@ -709,7 +711,7 @@ void Protocol::SendStream(IOStream &rStream)
objHeader.mObjType = htonl(SPECIAL_STREAM_OBJECT_TYPE);
// Write header
- mrStream.Write(&objHeader, sizeof(objHeader));
+ mapConn->Write(&objHeader, sizeof(objHeader));
// Could be sent in one of two ways
if(uncertainSize)
{
@@ -744,7 +746,7 @@ void Protocol::SendStream(IOStream &rStream)
// Send final byte to finish the stream
BOX_TRACE("Sending end of stream byte");
uint8_t endOfStream = ProtocolStreamHeader_EndOfStream;
- mrStream.Write(&endOfStream, 1);
+ mapConn->Write(&endOfStream, 1);
BOX_TRACE("Sent end of stream byte");
}
catch(...)
@@ -759,13 +761,13 @@ void Protocol::SendStream(IOStream &rStream)
else
{
// Fixed size stream, send it all in one go
- if(!rStream.CopyStreamTo(mrStream, mTimeout, 4096 /* slightly larger buffer */))
+ if(!rStream.CopyStreamTo(*mapConn, mTimeout, 4096 /* slightly larger buffer */))
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_TimeOutWhenSendingStream)
}
}
// Make sure everything is written
- mrStream.WriteAllBuffered();
+ mapConn->WriteAllBuffered();
}
@@ -816,7 +818,7 @@ int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock)
Block[-1] = header;
// Write everything out
- mrStream.Write(Block - 1, writeSize + 1);
+ mapConn->Write(Block - 1, writeSize + 1);
BOX_TRACE("Sent " << (writeSize+1) << " bytes to stream");
// move the remainer to the beginning of the block for the next time round
@@ -1177,6 +1179,12 @@ const uint16_t Protocol::sProtocolStreamHeaderLengths[256] =
0 // 255 = special (reserved)
};
+int64_t Protocol::GetBytesRead() const
+{
+ return mapConn->GetBytesRead();
+}
-
-
+int64_t Protocol::GetBytesWritten() const
+{
+ return mapConn->GetBytesWritten();
+}
diff --git a/lib/server/Protocol.h b/lib/server/Protocol.h
index 42cb0ff8..0995393d 100644
--- a/lib/server/Protocol.h
+++ b/lib/server/Protocol.h
@@ -19,6 +19,7 @@
#include "Message.h"
class IOStream;
+class SocketStream;
// default timeout is 15 minutes
#define PROTOCOL_DEFAULT_TIMEOUT (15*60*1000)
@@ -36,7 +37,7 @@ class IOStream;
class Protocol
{
public:
- Protocol(IOStream &rStream);
+ Protocol(std::auto_ptr<SocketStream> apConn);
virtual ~Protocol();
private:
@@ -175,6 +176,8 @@ public:
FILE *GetLogToFile() { return mLogToFile; }
void SetLogToSysLog(bool Log = false) {mLogToSysLog = Log;}
void SetLogToFile(FILE *File = 0) {mLogToFile = File;}
+ int64_t GetBytesRead() const;
+ int64_t GetBytesWritten() const;
protected:
virtual std::auto_ptr<Message> MakeMessage(int ObjType) = 0;
@@ -190,7 +193,7 @@ private:
void EnsureBufferAllocated(int Size);
int SendStreamSendBlock(uint8_t *Block, int BytesInBlock);
- IOStream &mrStream;
+ std::auto_ptr<SocketStream> mapConn;
bool mHandshakeDone;
unsigned int mMaxObjectSize;
int mTimeout;
diff --git a/lib/server/ServerStream.h b/lib/server/ServerStream.h
index a9b56169..8bb52b5b 100644
--- a/lib/server/ServerStream.h
+++ b/lib/server/ServerStream.h
@@ -286,7 +286,7 @@ public:
#endif
// The derived class does some server magic with the connection
- HandleConnection(*connection);
+ HandleConnection(connection);
// Since rChildExit == true, the forked process will call _exit() on return from this fn
return;
@@ -305,7 +305,7 @@ public:
#endif // !WIN32
// Just handle in this process
SetProcessTitle("handling");
- HandleConnection(*connection);
+ HandleConnection(connection);
SetProcessTitle("idle");
#ifndef WIN32
}
@@ -377,12 +377,12 @@ public:
}
#endif
- virtual void HandleConnection(StreamType &rStream)
+ virtual void HandleConnection(std::auto_ptr<StreamType> apStream)
{
- Connection(rStream);
+ Connection(apStream);
}
- virtual void Connection(StreamType &rStream) = 0;
+ virtual void Connection(std::auto_ptr<StreamType> apStream) = 0;
protected:
// For checking code in derived classes -- use if you have an algorithm which
diff --git a/lib/server/ServerTLS.h b/lib/server/ServerTLS.h
index a74a671e..20e55964 100644
--- a/lib/server/ServerTLS.h
+++ b/lib/server/ServerTLS.h
@@ -59,11 +59,11 @@ public:
ForkToHandleRequests>::Run2(rChildExit);
}
- virtual void HandleConnection(SocketStreamTLS &rStream)
+ virtual void HandleConnection(std::auto_ptr<SocketStreamTLS> apStream)
{
- rStream.Handshake(mContext, true /* is server */);
+ apStream->Handshake(mContext, true /* is server */);
// this-> in next line required to build under some gcc versions
- this->Connection(rStream);
+ this->Connection(apStream);
}
private:
diff --git a/lib/server/TcpNice.h b/lib/server/TcpNice.h
index e2027749..4381df42 100644
--- a/lib/server/TcpNice.h
+++ b/lib/server/TcpNice.h
@@ -83,7 +83,7 @@ private:
//
// --------------------------------------------------------------------------
-class NiceSocketStream : public IOStream
+class NiceSocketStream : public SocketStream
{
private:
std::auto_ptr<SocketStream> mapSocket;
@@ -166,6 +166,10 @@ public:
}
virtual void SetEnabled(bool enabled);
+ off_t GetBytesRead() const { return mapSocket->GetBytesRead(); }
+ off_t GetBytesWritten() const { return mapSocket->GetBytesWritten(); }
+ void ResetCounters() { mapSocket->ResetCounters(); }
+
private:
NiceSocketStream(const NiceSocketStream &rToCopy)
{ /* do not call */ }
diff --git a/lib/server/makeprotocol.pl.in b/lib/server/makeprotocol.pl.in
index 35814d1d..00dc58d4 100755
--- a/lib/server/makeprotocol.pl.in
+++ b/lib/server/makeprotocol.pl.in
@@ -158,7 +158,9 @@ print CPP <<__E;
#include <sstream>
#include "$filename_base.h"
-#include "IOStream.h"
+#include "CollectInBufferStream.h"
+#include "SocketStream.h"
+#include "MemBlockStream.h"
__E
print H <<__E;
@@ -179,8 +181,7 @@ print H <<__E;
#include "ServerException.h"
class IOStream;
-
-
+class SocketStream;
__E
# extra headers
@@ -741,7 +742,7 @@ __E
else
{
print H <<__E;
- $server_or_client_class(IOStream &rStream);
+ $server_or_client_class(std::auto_ptr<SocketStream> apConn);
std::auto_ptr<$message_base_class> Receive();
void Send(const $message_base_class &rObject);
__E
@@ -885,8 +886,8 @@ __E
else
{
print CPP <<__E;
-$server_or_client_class\::$server_or_client_class(IOStream &rStream)
-: Protocol(rStream)
+$server_or_client_class\::$server_or_client_class(std::auto_ptr<SocketStream> apConn)
+: Protocol(apConn)
{ }
__E
}
diff --git a/test/backupstore/testbackupstore.cpp b/test/backupstore/testbackupstore.cpp
index 901cf31a..24144d65 100644
--- a/test/backupstore/testbackupstore.cpp
+++ b/test/backupstore/testbackupstore.cpp
@@ -661,18 +661,9 @@ void recursive_count_objects(const char *hostname, int64_t id, recursive_count_o
"testfiles/clientTrustedCAs.pem");
// Get a connection
- // TODO FIXME replace with test_server_login
- SocketStreamTLS connReadOnly;
- connReadOnly.Open(context, Socket::TypeINET, hostname,
- BOX_PORT_BBSTORED_TEST);
- BackupProtocolClient protocolReadOnly(connReadOnly);
+ BackupProtocolLocal2 protocolReadOnly(0x01234567, "test",
+ "backup/01234567/", 0, false);
- {
- std::auto_ptr<BackupProtocolVersion> serverVersion(protocolReadOnly.QueryVersion(BACKUP_STORE_SERVER_VERSION));
- TEST_THAT(serverVersion->GetVersion() == BACKUP_STORE_SERVER_VERSION);
- std::auto_ptr<BackupProtocolLoginConfirmed> loginConf(protocolReadOnly.QueryLogin(0x01234567, BackupProtocolLogin::Flags_ReadOnly));
- }
-
// Count objects
recursive_count_objects_r(protocolReadOnly, id, results);
@@ -1171,23 +1162,16 @@ bool test_multiple_uploads()
SETUP();
TEST_THAT_THROWONFAIL(StartServer());
- std::auto_ptr<SocketStreamTLS> conn;
std::auto_ptr<BackupProtocolCallable> apProtocol(
- test_server_login("localhost", context, conn).release());
+ test_server_login("localhost", context).release());
#ifndef WIN32
// Open a new connection which is read only
- std::auto_ptr<SocketStreamTLS> conn2(new SocketStreamTLS);
- // TODO FIXME replace with test_server_login
- conn2->Open(context, Socket::TypeINET, "localhost",
- BOX_PORT_BBSTORED_TEST);
- BackupProtocolClient protocolReadOnly(*conn2);
-
- {
- std::auto_ptr<BackupProtocolVersion> serverVersion(protocolReadOnly.QueryVersion(BACKUP_STORE_SERVER_VERSION));
- TEST_THAT(serverVersion->GetVersion() == BACKUP_STORE_SERVER_VERSION);
- std::auto_ptr<BackupProtocolLoginConfirmed> loginConf(protocolReadOnly.QueryLogin(0x01234567, BackupProtocolLogin::Flags_ReadOnly));
- }
+ // TODO FIXME replace protocolReadOnly with apProtocolReadOnly.
+ std::auto_ptr<BackupProtocolCallable> apProtocolReadOnly =
+ test_server_login("localhost", context,
+ BackupProtocolLogin::Flags_ReadOnly);
+ BackupProtocolCallable& protocolReadOnly(*apProtocolReadOnly);
#else // WIN32
BackupProtocolCallable& protocolReadOnly(*apProtocol);
#endif
@@ -1266,7 +1250,7 @@ bool test_multiple_uploads()
apProtocol->QueryFinished();
TEST_THAT(run_housekeeping_and_check_account());
- apProtocol = test_server_login("localhost", context, conn);
+ apProtocol = test_server_login("localhost", context);
TEST_THAT(check_num_files(expected_num_current_files,
expected_num_old_files, 0, 1));
@@ -1371,7 +1355,7 @@ bool test_multiple_uploads()
" -c testfiles/bbstored.conf check 01234567 fix") == 0);
TestRemoteProcessMemLeaks("bbstoreaccounts.memleaks");
- apProtocol = test_server_login("localhost", context, conn);
+ apProtocol = test_server_login("localhost", context);
TEST_THAT(check_num_files(UPLOAD_NUM - 4, 3, 2, 1));
@@ -1672,7 +1656,7 @@ bool test_multiple_uploads()
apProtocol->QueryFinished();
TEST_THAT(run_housekeeping_and_check_account());
- apProtocol = test_server_login("localhost", context, conn);
+ apProtocol = test_server_login("localhost", context);
// Query names -- test that invalid stuff returns not found OK
{
@@ -2053,12 +2037,9 @@ bool test_login_without_account()
// BLOCK
{
// Open a connection to the server
- SocketStreamTLS conn;
- conn.Open(context, Socket::TypeINET, "localhost",
- BOX_PORT_BBSTORED_TEST);
-
- // Make a protocol
- BackupProtocolClient protocol(conn);
+ std::auto_ptr<BackupProtocolCallable> apProtocol(new
+ BackupProtocolClient(open_conn("localhost", context)));
+ BackupProtocolCallable& protocol(*apProtocol);
// Check the version
std::auto_ptr<BackupProtocolVersion> serverVersion(protocol.QueryVersion(BACKUP_STORE_SERVER_VERSION));
@@ -2140,12 +2121,9 @@ bool test_login_with_disabled_account()
// BLOCK
{
// Open a connection to the server
- SocketStreamTLS conn;
- conn.Open(context, Socket::TypeINET, "localhost",
- BOX_PORT_BBSTORED_TEST);
-
- // Make a protocol
- BackupProtocolClient protocol(conn);
+ std::auto_ptr<BackupProtocolCallable> apProtocol(new
+ BackupProtocolClient(open_conn("localhost", context)));
+ BackupProtocolCallable& protocol(*apProtocol);
// Check the version
std::auto_ptr<BackupProtocolVersion> serverVersion(protocol.QueryVersion(BACKUP_STORE_SERVER_VERSION));
@@ -2189,8 +2167,7 @@ bool test_login_with_no_refcount_db()
// the refcount db.
TEST_EQUAL(0, ::unlink("testfiles/0_0/backup/01234567/refcount.rdb.rfw"));
- std::auto_ptr<SocketStreamTLS> conn;
- TEST_CHECK_THROWS(test_server_login("localhost", context, conn),
+ TEST_CHECK_THROWS(test_server_login("localhost", context),
ConnectionException, Conn_TLSReadFailed);
TEST_THAT(ServerIsAlive(bbstored_pid));
@@ -2328,20 +2305,10 @@ bool test_account_limits_respected()
// Try to upload a file and create a directory, and check an error is generated
{
// Open a connection to the server
- SocketStreamTLS conn;
- conn.Open(context, Socket::TypeINET, "localhost",
- BOX_PORT_BBSTORED_TEST);
-
- // Make a protocol
- BackupProtocolClient protocol(conn);
-
- // Check the version
- std::auto_ptr<BackupProtocolVersion> serverVersion(protocol.QueryVersion(BACKUP_STORE_SERVER_VERSION));
- TEST_THAT(serverVersion->GetVersion() == BACKUP_STORE_SERVER_VERSION);
+ std::auto_ptr<BackupProtocolCallable> apProtocol =
+ test_server_login("localhost", context);
+ BackupProtocolCallable& protocol(*apProtocol);
- // Login
- std::auto_ptr<BackupProtocolLoginConfirmed> loginConf(protocol.QueryLogin(0x01234567, 0));
-
int64_t modtime = 0;
write_test_file(3);
diff --git a/test/backupstorepatch/testbackupstorepatch.cpp b/test/backupstorepatch/testbackupstorepatch.cpp
index 42594652..96b93308 100644
--- a/test/backupstorepatch/testbackupstorepatch.cpp
+++ b/test/backupstorepatch/testbackupstorepatch.cpp
@@ -345,12 +345,13 @@ int test(int argc, const char *argv[])
{
// Open a connection to the server
- SocketStreamTLS conn;
- conn.Open(context, Socket::TypeINET, "localhost",
+ SocketStreamTLS *pConn = new SocketStreamTLS;
+ std::auto_ptr<SocketStream> apConn(pConn);
+ pConn->Open(context, Socket::TypeINET, "localhost",
BOX_PORT_BBSTORED_TEST);
// Make a protocol
- BackupProtocolClient protocol(conn);
+ BackupProtocolClient protocol(apConn);
// Login
{
@@ -454,7 +455,6 @@ int test(int argc, const char *argv[])
// Finish the connection
protocol.QueryFinished();
- conn.Close();
}
// Fill in initial dependency information
@@ -528,10 +528,11 @@ int test(int argc, const char *argv[])
}
// Open a connection to the server (need to do this each time, otherwise housekeeping won't delete files)
- SocketStreamTLS conn;
- conn.Open(context, Socket::TypeINET, "localhost",
+ SocketStreamTLS *pConn = new SocketStreamTLS;
+ std::auto_ptr<SocketStream> apConn(pConn);
+ pConn->Open(context, Socket::TypeINET, "localhost",
BOX_PORT_BBSTORED_TEST);
- BackupProtocolClient protocol(conn);
+ BackupProtocolClient protocol(apConn);
{
std::auto_ptr<BackupProtocolVersion> serverVersion(protocol.QueryVersion(BACKUP_STORE_SERVER_VERSION));
TEST_THAT(serverVersion->GetVersion() == BACKUP_STORE_SERVER_VERSION);
@@ -583,7 +584,6 @@ int test(int argc, const char *argv[])
// Close the connection
protocol.QueryFinished();
- conn.Close();
// Mark one of the elements as deleted
if(test_file_remove_order[deleteIndex] == -1)
diff --git a/test/basicserver/testbasicserver.cpp b/test/basicserver/testbasicserver.cpp
index 5a13cb45..fe06e8c2 100644
--- a/test/basicserver/testbasicserver.cpp
+++ b/test/basicserver/testbasicserver.cpp
@@ -170,7 +170,7 @@ public:
testserver() {}
~testserver() {}
- void Connection(SocketStream &rStream);
+ void Connection(std::auto_ptr<SocketStream> apStream);
virtual const char *DaemonName() const
{
@@ -210,9 +210,9 @@ const ConfigurationVerify *testserver::GetConfigVerify() const
return &verify;
}
-void testserver::Connection(SocketStream &rStream)
+void testserver::Connection(std::auto_ptr<SocketStream> apStream)
{
- testservers_connection(rStream);
+ testservers_connection(*apStream);
}
class testProtocolServer : public testserver
@@ -221,7 +221,7 @@ public:
testProtocolServer() {}
~testProtocolServer() {}
- void Connection(SocketStream &rStream);
+ void Connection(std::auto_ptr<SocketStream> apStream);
virtual const char *DaemonName() const
{
@@ -229,9 +229,9 @@ public:
}
};
-void testProtocolServer::Connection(SocketStream &rStream)
+void testProtocolServer::Connection(std::auto_ptr<SocketStream> apStream)
{
- TestProtocolServer server(rStream);
+ TestProtocolServer server(apStream);
TestContext context;
server.DoServer(context);
}
@@ -243,7 +243,7 @@ public:
testTLSserver() {}
~testTLSserver() {}
- void Connection(SocketStreamTLS &rStream);
+ void Connection(std::auto_ptr<SocketStreamTLS> apStream);
virtual const char *DaemonName() const
{
@@ -283,9 +283,9 @@ const ConfigurationVerify *testTLSserver::GetConfigVerify() const
return &verify;
}
-void testTLSserver::Connection(SocketStreamTLS &rStream)
+void testTLSserver::Connection(std::auto_ptr<SocketStreamTLS> apStream)
{
- testservers_connection(rStream);
+ testservers_connection(*apStream);
}
@@ -691,15 +691,15 @@ int test(int argc, const char *argv[])
TEST_THAT(ServerIsAlive(pid));
// Open a connection to it
- SocketStream conn;
+ std::auto_ptr<SocketStream> apConn(new SocketStream);
#ifdef WIN32
- conn.Open(Socket::TypeINET, "localhost", 2003);
+ apConn->Open(Socket::TypeINET, "localhost", 2003);
#else
- conn.Open(Socket::TypeUNIX, "testfiles/srv4.sock");
+ apConn->Open(Socket::TypeUNIX, "testfiles/srv4.sock");
#endif
// Create a protocol
- TestProtocolClient protocol(conn);
+ TestProtocolClient protocol(apConn);
// Simple query
{
diff --git a/test/bbackupd/testbbackupd.cpp b/test/bbackupd/testbbackupd.cpp
index 4fdd12e9..28651233 100644
--- a/test/bbackupd/testbbackupd.cpp
+++ b/test/bbackupd/testbbackupd.cpp
@@ -468,10 +468,11 @@ void do_interrupted_restore(const TLSContext &context, int64_t restoredirid)
// child process
{
// connect and log in
- SocketStreamTLS conn;
- conn.Open(context, Socket::TypeINET, "localhost",
- 22011);
- BackupProtocolClient protocol(conn);
+ SocketStreamTLS* pConn = new SocketStreamTLS;
+ std::auto_ptr<SocketStream> apConn(pConn);
+ pConn->Open(context, Socket::TypeINET, "localhost", 22011);
+ BackupProtocolClient protocol(apConn);
+
protocol.QueryVersion(BACKUP_STORE_SERVER_VERSION);
std::auto_ptr<BackupProtocolLoginConfirmed>
loginConf(protocol.QueryLogin(0x01234567,
@@ -567,17 +568,17 @@ int64_t SearchDir(BackupStoreDirectory& rDir,
return id;
}
-SocketStreamTLS sSocket;
-
std::auto_ptr<BackupProtocolClient> Connect(TLSContext& rContext)
{
- sSocket.Open(rContext, Socket::TypeINET,
- "localhost", 22011);
- std::auto_ptr<BackupProtocolClient> connection;
- connection.reset(new BackupProtocolClient(sSocket));
- connection->Handshake();
+ SocketStreamTLS* pConn = new SocketStreamTLS;
+ std::auto_ptr<SocketStream> apConn(pConn);
+ pConn->Open(rContext, Socket::TypeINET, "localhost", 22011);
+
+ std::auto_ptr<BackupProtocolClient> client;
+ client.reset(new BackupProtocolClient(apConn));
+ client->Handshake();
std::auto_ptr<BackupProtocolVersion>
- serverVersion(connection->QueryVersion(
+ serverVersion(client->QueryVersion(
BACKUP_STORE_SERVER_VERSION));
if(serverVersion->GetVersion() !=
BACKUP_STORE_SERVER_VERSION)
@@ -585,15 +586,15 @@ std::auto_ptr<BackupProtocolClient> Connect(TLSContext& rContext)
THROW_EXCEPTION(BackupStoreException,
WrongServerVersion);
}
- return connection;
+ return client;
}
std::auto_ptr<BackupProtocolClient> ConnectAndLogin(TLSContext& rContext,
int flags)
{
- std::auto_ptr<BackupProtocolClient> connection(Connect(rContext));
- connection->QueryLogin(0x01234567, flags);
- return connection;
+ std::auto_ptr<BackupProtocolClient> client(Connect(rContext));
+ client->QueryLogin(0x01234567, flags);
+ return client;
}
std::auto_ptr<BackupStoreDirectory> ReadDirectory
@@ -885,7 +886,6 @@ int test_bbackupd()
}
client->QueryFinished();
- sSocket.Close();
}
// unpack the files for the initial test
@@ -1274,7 +1274,6 @@ int test_bbackupd()
TEST_THAT(check_num_blocks(*client, 10, expected_blocks_old,
0, 18, 28 + expected_blocks_old));
client->QueryFinished();
- sSocket.Close();
}
std::string cmd = BBACKUPD " " + bbackupd_args +
@@ -1381,7 +1380,6 @@ int test_bbackupd()
TEST_THAT(check_num_files(4, 0, 0, 8));
TEST_THAT(check_num_blocks(*client, 8, 0, 0, 16, 24));
client->QueryFinished();
- sSocket.Close();
}
if (failures) return 1;
@@ -1515,7 +1513,6 @@ int test_bbackupd()
// Log out.
client->QueryFinished();
- sSocket.Close();
}
BOX_TRACE("done.");
@@ -1564,7 +1561,6 @@ int test_bbackupd()
// Log out.
client->QueryFinished();
- sSocket.Close();
}
if (failures) return 1;
@@ -1602,7 +1598,6 @@ int test_bbackupd()
// i.e. 2 new files, 1 new directory
client->QueryFinished();
- sSocket.Close();
}
if (failures) return 1;
@@ -1978,7 +1973,6 @@ int test_bbackupd()
int64_t testDirId = SearchDir(*dir, "Test2");
TEST_THAT(testDirId == 0);
client->QueryFinished();
- sSocket.Close();
}
// create the location directory and unpack some files into it
@@ -2018,7 +2012,6 @@ int test_bbackupd()
TEST_THAT(testDirId != 0);
client->QueryFinished();
- sSocket.Close();
}
printf("\n==== Testing that redundant locations are deleted on time\n");
@@ -2060,7 +2053,6 @@ int test_bbackupd()
TEST_THAT(testDirId != 0);
client->QueryFinished();
- sSocket.Close();
}
wait_for_sync_end();
@@ -2078,7 +2070,6 @@ int test_bbackupd()
TEST_THAT(test_entry_deleted(*root_dir, "Test2"));
client->QueryFinished();
- sSocket.Close();
}
}
@@ -2325,7 +2316,6 @@ int test_bbackupd()
TEST_THAT(SearchDir(*dir, filename.c_str()) != 0);
// Log out
client->QueryFinished();
- sSocket.Close();
}
// Check that bbackupquery shows the dir in console encoding
@@ -3239,7 +3229,6 @@ int test_bbackupd()
TEST_THAT(!SearchDir(*dir, "xx_not_this_dir_22"));
TEST_THAT(!SearchDir(*dir, "somefile.excludethis"));
client->QueryFinished();
- sSocket.Close();
}
TEST_THAT(ServerIsAlive(bbackupd_pid));
@@ -3449,7 +3438,6 @@ int test_bbackupd()
// Log out
client->QueryFinished();
- sSocket.Close();
}
// Compare the restored files
@@ -3644,7 +3632,6 @@ int test_bbackupd()
// Log out
protocol->QueryFinished();
- sSocket.Close();
}
catch(...)
{
@@ -3734,7 +3721,7 @@ int test_bbackupd()
== Restore_Complete);
client->QueryFinished();
- sSocket.Close();
+ client.reset();
// Then check it has restored the correct stuff
TEST_COMPARE(Compare_Same);
@@ -3764,7 +3751,7 @@ int test_bbackupd()
== Restore_Complete);
client->QueryFinished();
- sSocket.Close();
+ client.reset();
// Do a compare with the now undeleted files
compareReturnValue = ::system(BBACKUPQUERY " "