summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Wilson <chris+github@qwirx.com>2014-03-02 08:59:39 +0000
committerChris Wilson <chris+github@qwirx.com>2014-03-02 08:59:39 +0000
commitc30a30907d26fb25f449fb3f03274418f935b0a7 (patch)
treeef9b22698b5ca5e035bfe234722a8854e4b357c5
parentb88db70703097ae8e5894e6dd5af2c5b672799a9 (diff)
Always flush any incoming stream on server side.
Otherwise the protocol might be broken and can't be used any more, even if we made an effort to return an Error reply instead of throwing an exception. This used to not be a problem because an Error reply would terminate the connection anyway, but it no longer does. So if the client also didn't terminate, but tried to handle the exception and keep using the connection, then it might find that its next command fails because the protocol is broken.
-rw-r--r--lib/backupstore/BackupCommands.cpp38
-rw-r--r--lib/common/SelfFlushingStream.h6
-rwxr-xr-xlib/server/makeprotocol.pl.in85
-rw-r--r--test/backupstore/testbackupstore.cpp13
4 files changed, 99 insertions, 43 deletions
diff --git a/lib/backupstore/BackupCommands.cpp b/lib/backupstore/BackupCommands.cpp
index a0788f32..2d927358 100644
--- a/lib/backupstore/BackupCommands.cpp
+++ b/lib/backupstore/BackupCommands.cpp
@@ -227,7 +227,9 @@ std::auto_ptr<BackupProtocolMessage> BackupProtocolListDirectory::DoCommand(Back
// Created: 2003/09/02
//
// --------------------------------------------------------------------------
-std::auto_ptr<BackupProtocolMessage> BackupProtocolStoreFile::DoCommand(BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext) const
+std::auto_ptr<BackupProtocolMessage> BackupProtocolStoreFile::DoCommand(
+ BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext,
+ IOStream& rDataStream) const
{
CHECK_PHASE(Phase_Commands)
CHECK_WRITEABLE_SESSION
@@ -249,14 +251,11 @@ std::auto_ptr<BackupProtocolMessage> BackupProtocolStoreFile::DoCommand(BackupPr
}
}
- // A stream follows, which contains the file
- std::auto_ptr<IOStream> filestream(rProtocol.ReceiveStream());
-
// Ask the context to store it
int64_t id = 0;
try
{
- id = rContext.AddFile(*filestream, mDirectoryObjectID,
+ id = rContext.AddFile(rDataStream, mDirectoryObjectID,
mModificationTime, mAttributesHash, mDiffFromFileID,
mFilename,
true /* mark files with same name as old versions */);
@@ -469,11 +468,12 @@ std::auto_ptr<BackupProtocolMessage> BackupProtocolGetFile::DoCommand(BackupProt
//
// --------------------------------------------------------------------------
std::auto_ptr<BackupProtocolMessage> BackupProtocolCreateDirectory::DoCommand(
- BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext) const
+ BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext,
+ IOStream& rDataStream) const
{
return BackupProtocolCreateDirectory2(mContainingDirectoryID,
mAttributesModTime, 0 /* ModificationTime */,
- mDirectoryName).DoCommand(rProtocol, rContext);
+ mDirectoryName).DoCommand(rProtocol, rContext, rDataStream);
}
@@ -488,17 +488,16 @@ std::auto_ptr<BackupProtocolMessage> BackupProtocolCreateDirectory::DoCommand(
//
// --------------------------------------------------------------------------
std::auto_ptr<BackupProtocolMessage> BackupProtocolCreateDirectory2::DoCommand(
- BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext) const
+ BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext,
+ IOStream& rDataStream) const
{
CHECK_PHASE(Phase_Commands)
CHECK_WRITEABLE_SESSION
- // Get the stream containing the attributes
- std::auto_ptr<IOStream> attrstream(rProtocol.ReceiveStream());
// Collect the attributes -- do this now so no matter what the outcome,
// the data has been absorbed.
StreamableMemBlock attr;
- attr.Set(*attrstream, rProtocol.GetTimeout());
+ attr.Set(rDataStream, rProtocol.GetTimeout());
// Check to see if the hard limit has been exceeded
if(rContext.HardLimitExceeded())
@@ -547,17 +546,17 @@ std::auto_ptr<BackupProtocolMessage> BackupProtocolCreateDirectory2::DoCommand(
// Created: 2003/09/06
//
// --------------------------------------------------------------------------
-std::auto_ptr<BackupProtocolMessage> BackupProtocolChangeDirAttributes::DoCommand(BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext) const
+std::auto_ptr<BackupProtocolMessage> BackupProtocolChangeDirAttributes::DoCommand(
+ BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext,
+ IOStream& rDataStream) const
{
CHECK_PHASE(Phase_Commands)
CHECK_WRITEABLE_SESSION
- // Get the stream containing the attributes
- std::auto_ptr<IOStream> attrstream(rProtocol.ReceiveStream());
// Collect the attributes -- do this now so no matter what the outcome,
// the data has been absorbed.
StreamableMemBlock attr;
- attr.Set(*attrstream, rProtocol.GetTimeout());
+ attr.Set(rDataStream, rProtocol.GetTimeout());
// Get the context to do it's magic
rContext.ChangeDirAttributes(mObjectID, attr, mAttributesModTime);
@@ -575,17 +574,18 @@ std::auto_ptr<BackupProtocolMessage> BackupProtocolChangeDirAttributes::DoComman
// Created: 2003/09/06
//
// --------------------------------------------------------------------------
-std::auto_ptr<BackupProtocolMessage> BackupProtocolSetReplacementFileAttributes::DoCommand(BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext) const
+std::auto_ptr<BackupProtocolMessage>
+BackupProtocolSetReplacementFileAttributes::DoCommand(
+ BackupProtocolReplyable &rProtocol, BackupStoreContext &rContext,
+ IOStream& rDataStream) const
{
CHECK_PHASE(Phase_Commands)
CHECK_WRITEABLE_SESSION
- // Get the stream containing the attributes
- std::auto_ptr<IOStream> attrstream(rProtocol.ReceiveStream());
// Collect the attributes -- do this now so no matter what the outcome,
// the data has been absorbed.
StreamableMemBlock attr;
- attr.Set(*attrstream, rProtocol.GetTimeout());
+ attr.Set(rDataStream, rProtocol.GetTimeout());
// Get the context to do it's magic
int64_t objectID = 0;
diff --git a/lib/common/SelfFlushingStream.h b/lib/common/SelfFlushingStream.h
index 36e9a4d3..6865ab96 100644
--- a/lib/common/SelfFlushingStream.h
+++ b/lib/common/SelfFlushingStream.h
@@ -33,6 +33,12 @@ public:
~SelfFlushingStream()
{
+ if(StreamDataLeft())
+ {
+ BOX_WARNING("Not all data was read from stream, "
+ "discarding the rest");
+ }
+
Flush();
}
diff --git a/lib/server/makeprotocol.pl.in b/lib/server/makeprotocol.pl.in
index 78ef57a1..1c8f6081 100755
--- a/lib/server/makeprotocol.pl.in
+++ b/lib/server/makeprotocol.pl.in
@@ -159,8 +159,9 @@ print CPP <<__E;
#include "$filename_base.h"
#include "CollectInBufferStream.h"
-#include "SocketStream.h"
#include "MemBlockStream.h"
+#include "SelfFlushingStream.h"
+#include "SocketStream.h"
__E
print H <<__E;
@@ -227,6 +228,9 @@ class $message_base_class : public Message
public:
virtual std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
$context_class &rContext) const;
+ virtual std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const;
+ virtual bool HasStreamWithCommand() const = 0;
};
class $reply_base_class
@@ -259,6 +263,12 @@ std::auto_ptr<$message_base_class> $message_base_class\::DoCommand($replyable_ba
{
THROW_EXCEPTION(ConnectionException, Conn_Protocol_TriedToExecuteReplyCommand)
}
+
+std::auto_ptr<$message_base_class> $message_base_class\::DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const
+{
+ THROW_EXCEPTION(ConnectionException, Conn_Protocol_TriedToExecuteReplyCommand)
+}
__E
my %cmd_classes;
@@ -317,14 +327,39 @@ __E
print H "\tstd::string GetMessage() const;\n";
}
- if(obj_is_type($cmd, 'Command'))
+ my $has_stream = obj_is_type($cmd, 'StreamWithCommand');
+
+ if(obj_is_type($cmd, 'Command') && $has_stream)
+ {
+ print H <<__E;
+ std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const; // IMPLEMENT THIS\n
+ std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext) const
+ {
+ THROW_EXCEPTION_MESSAGE(CommonException, Internal,
+ "This command requires a stream parameter");
+ }
+__E
+ }
+ elsif(obj_is_type($cmd, 'Command') && !$has_stream)
{
print H <<__E;
std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
$context_class &rContext) const; // IMPLEMENT THIS\n
+ std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const
+ {
+ THROW_EXCEPTION_MESSAGE(CommonException, NotSupported,
+ "This command requires no stream parameter");
+ }
__E
}
+ print H <<__E;
+ bool HasStreamWithCommand() const { return $has_stream; }
+__E
+
# want to be able to read from streams?
print H "\tvoid SetPropertiesFromStreamData(Protocol &rProtocol);\n";
@@ -1002,10 +1037,20 @@ void $server_or_client_class\::DoServer($context_class &rContext)
{
// Get an object from the conversation
std::auto_ptr<$message_base_class> pobj = Receive();
+ std::auto_ptr<$message_base_class> preply;
// Run the command
- std::auto_ptr<$message_base_class> preply = pobj->DoCommand(*this, rContext);
-
+ if(pobj->HasStreamWithCommand())
+ {
+ std::auto_ptr<IOStream> apDataStream = ReceiveStream();
+ SelfFlushingStream autoflush(*apDataStream);
+ preply = pobj->DoCommand(*this, rContext, *apDataStream);
+ }
+ else
+ {
+ preply = pobj->DoCommand(*this, rContext);
+ }
+
// Send the reply
Send(*preply);
@@ -1052,7 +1097,7 @@ __E
my $reply_class = $cmd_classes{$reply_msg};
my $reply_id = $cmd_id{$reply_msg};
my $has_stream = obj_is_type($cmd,'StreamWithCommand');
- my $argextra = $has_stream?', std::auto_ptr<IOStream> apStream':'';
+ my $argextra = $has_stream?', std::auto_ptr<IOStream> apDataStream':'';
my $send_stream_extra = '';
my $send_stream_method = $writing_client ? "SendStream"
: "SendStreamAfterCommand";
@@ -1068,7 +1113,7 @@ __E
{
$send_stream_extra = <<__E;
// Send stream after the command
- SendStream(*apStream);
+ SendStream(*apDataStream);
__E
}
@@ -1078,37 +1123,29 @@ __E
$send_stream_extra
// Wait for the reply
- std::auto_ptr<$message_base_class> preply = Receive();
+ std::auto_ptr<$message_base_class> apReply = Receive();
__E
}
elsif($writing_local)
{
if($has_stream)
{
- $send_stream_extra = <<__E;
- // Send stream after the command
- SendStreamAfterCommand(apStream);
+ print CPP <<__E;
+ std::auto_ptr<$message_base_class> apReply = rQuery.DoCommand(*this,
+ mrContext, *apDataStream);
__E
}
-
- print CPP <<__E;
- // Push streams to send, if any, into queue for retrieval by DoCommand.
- $send_stream_extra
-
- // Execute the command and get the reply message
- std::auto_ptr<$message_base_class> preply = rQuery.DoCommand(*this, mrContext);
-
- if(!mStreamsToSend.empty())
- {
- THROW_EXCEPTION_MESSAGE(ConnectionException,
- Protocol_StreamsNotConsumed, rQuery.ToString());
- }
+ else
+ {
+ print CPP <<__E;
+ std::auto_ptr<$message_base_class> apReply = rQuery.DoCommand(*this, mrContext);
__E
+ }
}
# Common to both client and local
print CPP <<__E;
- CheckReply("$cmd", rQuery, *preply, $reply_id);
+ CheckReply("$cmd", rQuery, *apReply, $reply_id);
// Correct response, if no exception thrown by CheckReply
return std::auto_ptr<$reply_class>(
diff --git a/test/backupstore/testbackupstore.cpp b/test/backupstore/testbackupstore.cpp
index 300a666b..8b185f1e 100644
--- a/test/backupstore/testbackupstore.cpp
+++ b/test/backupstore/testbackupstore.cpp
@@ -45,6 +45,7 @@
#include "StoreTestUtils.h"
#include "TLSContext.h"
#include "Test.h"
+#include "ZeroStream.h"
#include "MemLeakFindOn.h"
@@ -1077,6 +1078,18 @@ bool test_server_housekeeping()
TEST_THAT(check_num_files(0, 0, 0, 1));
TEST_THAT(check_num_blocks(protocol, 0, 0, 0, root_dir_blocks, root_dir_blocks));
+ // Used to not consume the stream
+ std::auto_ptr<IOStream> upload(new ZeroStream(1000));
+ TEST_COMMAND_RETURNS_ERROR(protocol.QueryStoreFile(
+ BACKUPSTORE_ROOT_DIRECTORY_ID,
+ 0,
+ 0, /* use for attr hash too */
+ 99999, /* diff from ID */
+ uploads[0].name,
+ upload),
+ Err_DiffFromFileDoesNotExist);
+ // TODO FIXME replace all other TEST_CHECK_THROWS with TEST_COMMAND_RETURNS_ERROR
+
// TODO FIXME These tests should not be here, but in
// test_server_commands. But make sure you use a network protocol,
// not a local one, when you move them.