summaryrefslogtreecommitdiff
path: root/lib/server
diff options
context:
space:
mode:
Diffstat (limited to 'lib/server')
-rw-r--r--lib/server/ConnectionException.txt1
-rw-r--r--lib/server/Daemon.cpp159
-rw-r--r--lib/server/Daemon.h13
-rw-r--r--lib/server/Message.h3
-rw-r--r--lib/server/Protocol.cpp91
-rw-r--r--lib/server/Protocol.h16
-rw-r--r--lib/server/ProtocolUncertainStream.cpp5
-rw-r--r--lib/server/ProtocolUncertainStream.h3
-rw-r--r--lib/server/ProtocolWire.h4
-rw-r--r--lib/server/SSLLib.cpp5
-rw-r--r--lib/server/ServerControl.cpp78
-rw-r--r--lib/server/ServerControl.h4
-rw-r--r--lib/server/ServerException.h46
-rw-r--r--lib/server/ServerStream.h27
-rw-r--r--lib/server/ServerTLS.h9
-rw-r--r--lib/server/Socket.cpp7
-rw-r--r--lib/server/SocketListen.h65
-rw-r--r--lib/server/SocketStream.cpp152
-rw-r--r--lib/server/SocketStream.h54
-rw-r--r--lib/server/SocketStreamTLS.cpp101
-rw-r--r--lib/server/SocketStreamTLS.h3
-rw-r--r--lib/server/TLSContext.cpp16
-rw-r--r--lib/server/TcpNice.cpp6
-rw-r--r--lib/server/TcpNice.h6
-rw-r--r--lib/server/WinNamedPipeListener.h18
-rw-r--r--lib/server/WinNamedPipeStream.cpp555
-rw-r--r--lib/server/WinNamedPipeStream.h47
-rwxr-xr-xlib/server/makeprotocol.pl.in497
28 files changed, 1115 insertions, 876 deletions
diff --git a/lib/server/ConnectionException.txt b/lib/server/ConnectionException.txt
index c3429116..7dcaadeb 100644
--- a/lib/server/ConnectionException.txt
+++ b/lib/server/ConnectionException.txt
@@ -25,3 +25,4 @@ Protocol_HandshakeFailed 48
Protocol_StreamWhenObjExpected 49
Protocol_ObjWhenStreamExpected 50
Protocol_TimeOutWhenSendingStream 52 Probably a network issue between client and server.
+Protocol_StreamsNotConsumed 53 The server command handler did not consume all streams that were sent.
diff --git a/lib/server/Daemon.cpp b/lib/server/Daemon.cpp
index 7419f973..d3c8441f 100644
--- a/lib/server/Daemon.cpp
+++ b/lib/server/Daemon.cpp
@@ -9,23 +9,27 @@
#include "Box.h"
-#ifdef HAVE_UNISTD_H
- #include <unistd.h>
-#endif
-
#include <errno.h>
#include <stdio.h>
#include <signal.h>
#include <string.h>
#include <stdarg.h>
+#ifdef HAVE_PROCESS_H
+# include <process.h>
+#endif
+
+#ifdef HAVE_UNISTD_H
+# include <unistd.h>
+#endif
+
#ifdef HAVE_BSD_UNISTD_H
#include <bsd/unistd.h>
#endif
#ifdef WIN32
+ #include <Strsafe.h>
#include <ws2tcpip.h>
- #include <process.h>
#endif
#include "depot.h"
@@ -36,12 +40,13 @@
# include "BoxVersion.h"
#endif
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "Configuration.h"
#include "Daemon.h"
#include "FileModificationTime.h"
#include "Guards.h"
#include "Logging.h"
-#include "ServerException.h"
#include "UnixUser.h"
#include "Utils.h"
@@ -106,11 +111,11 @@ Daemon::~Daemon()
// --------------------------------------------------------------------------
std::string Daemon::GetOptionString()
{
- return "c:"
+ return std::string("c:"
#ifndef WIN32
"DF"
#endif
- "hkKo:O:PqQt:TUvVW:";
+ "hkKo:O:") + Logging::OptionParser::GetOptionString();
}
void Daemon::Usage()
@@ -133,16 +138,7 @@ void Daemon::Usage()
" -K Stop writing log messages to console while daemon is running\n"
" -o <file> Log to a file, defaults to maximum verbosity\n"
" -O <level> Set file log verbosity to error/warning/notice/info/trace/everything\n"
- " -P Show process ID (PID) in console output\n"
- " -q Run more quietly, reduce verbosity level by one, can repeat\n"
- " -Q Run at minimum verbosity, log nothing to console and system\n"
- " -t <tag> Tag console output with specified marker\n"
- " -T Timestamp console output\n"
- " -U Timestamp console output with microseconds\n"
- " -v Run more verbosely, increase verbosity level by one, can repeat\n"
- " -V Run at maximum verbosity, log everything to console and system\n"
- " -W <level> Set verbosity to error/warning/notice/info/trace/everything\n"
- ;
+ << Logging::OptionParser::GetUsageString();
}
// --------------------------------------------------------------------------
@@ -218,94 +214,9 @@ int Daemon::ProcessOption(signed int option)
}
break;
- case 'P':
- {
- Console::SetShowPID(true);
- }
- break;
-
- case 'q':
- {
- if(mLogLevel == Log::NOTHING)
- {
- BOX_FATAL("Too many '-q': "
- "Cannot reduce logging "
- "level any more");
- return 2;
- }
- mLogLevel--;
- }
- break;
-
- case 'Q':
- {
- mLogLevel = Log::NOTHING;
- }
- break;
-
- case 't':
- {
- Logging::SetProgramName(optarg);
- Console::SetShowTag(true);
- }
- break;
-
- case 'T':
- {
- Console::SetShowTime(true);
- }
- break;
-
- case 'U':
- {
- Console::SetShowTime(true);
- Console::SetShowTimeMicros(true);
- }
- break;
-
- case 'v':
- {
- if(mLogLevel == Log::EVERYTHING)
- {
- BOX_FATAL("Too many '-v': "
- "Cannot increase logging "
- "level any more");
- return 2;
- }
- mLogLevel++;
- }
- break;
-
- case 'V':
- {
- mLogLevel = Log::EVERYTHING;
- }
- break;
-
- case 'W':
- {
- mLogLevel = Logging::GetNamedLevel(optarg);
- if (mLogLevel == Log::INVALID)
- {
- BOX_FATAL("Invalid logging level: " << optarg);
- return 2;
- }
- }
- break;
-
- case '?':
- {
- BOX_FATAL("Unknown option on command line: "
- << "'" << (char)optopt << "'");
- return 2;
- }
- break;
-
default:
{
- BOX_FATAL("Unknown error in getopt: returned "
- << "'" << option << "'");
- return 1;
+ return mLogLevel.ProcessOption(option);
}
}
@@ -351,12 +262,6 @@ int Daemon::Main(const std::string& rDefaultConfigFile, int argc,
int Daemon::ProcessOptions(int argc, const char *argv[])
{
- #ifdef BOX_RELEASE_BUILD
- mLogLevel = Log::NOTICE;
- #else
- mLogLevel = Log::INFO;
- #endif
-
if (argc == 2 && strcmp(argv[1], "/?") == 0)
{
Usage();
@@ -368,7 +273,7 @@ int Daemon::ProcessOptions(int argc, const char *argv[])
// reset getopt, just in case anybody used it before.
// unfortunately glibc and BSD differ on this point!
// http://www.ussg.iu.edu/hypermail/linux/kernel/0305.3/0262.html
- #if HAVE_DECL_OPTRESET == 1 || defined WIN32
+ #if HAVE_DECL_OPTRESET == 1 || defined BOX_BSD_GETOPT
optind = 1;
optreset = 1;
#elif defined __GLIBC__
@@ -406,13 +311,14 @@ int Daemon::ProcessOptions(int argc, const char *argv[])
return 2;
}
- Logging::FilterConsole((Log::Level)mLogLevel);
- Logging::FilterSyslog ((Log::Level)mLogLevel);
+ Logging::FilterConsole(mLogLevel.GetCurrentLevel());
+ Logging::FilterSyslog (mLogLevel.GetCurrentLevel());
if (mLogFileLevel != Log::INVALID)
{
mapLogFileLogger.reset(
- new FileLogger(mLogFile, mLogFileLevel));
+ new FileLogger(mLogFile, mLogFileLevel,
+ !mLogLevel.mTruncateLogFile));
}
return 0;
@@ -473,17 +379,17 @@ bool Daemon::Configure(const std::string& rConfigFileName)
BOX_ERROR("Failed to load or verify configuration file");
return false;
}
-
+
if(!Configure(*apConfig))
{
BOX_ERROR("Failed to verify configuration file");
- return false;
+ return false;
}
-
+
// Store configuration
mConfigFileName = rConfigFileName;
mLoadedConfigModifiedTime = GetConfigFileModifiedTime();
-
+
return true;
}
@@ -513,14 +419,14 @@ bool Daemon::Configure(const Configuration& rConfig)
BOX_ERROR("Configuration errors: " << errors);
return false;
}
-
+
// Store configuration
mapConfiguration = apConf;
-
+
// Let the derived class have a go at setting up stuff
// in the initial process
SetupInInitialProcess();
-
+
return true;
}
@@ -664,7 +570,7 @@ int Daemon::Main(const std::string &rConfigFileName)
// Write PID to file
char pid[32];
- int pidsize = sprintf(pid, "%d", (int)getpid());
+ int pidsize = snprintf(pid, sizeof(pid), "%d", (int)getpid());
if(::write(pidFile, pid, pidsize) != pidsize)
{
@@ -676,9 +582,8 @@ int Daemon::Main(const std::string &rConfigFileName)
// Set up memory leak reporting
#ifdef BOX_MEMORY_LEAK_TESTING
{
- char filename[256];
- sprintf(filename, "%s.memleaks", DaemonName());
- memleakfinder_setup_exit_report(filename, DaemonName());
+ memleakfinder_setup_exit_report(std::string(DaemonName()) +
+ ".memleaks", DaemonName());
}
#endif // BOX_MEMORY_LEAK_TESTING
@@ -986,7 +891,9 @@ const Configuration &Daemon::GetConfiguration() const
if(mapConfiguration.get() == 0)
{
// Shouldn't get anywhere near this if a configuration file can't be loaded
- THROW_EXCEPTION(ServerException, Internal)
+ THROW_EXCEPTION_MESSAGE(ServerException, Internal,
+ "The daemon has not been configured; no config file "
+ "has been loaded.");
}
return *mapConfiguration;
diff --git a/lib/server/Daemon.h b/lib/server/Daemon.h
index 2718c288..b5384918 100644
--- a/lib/server/Daemon.h
+++ b/lib/server/Daemon.h
@@ -85,7 +85,16 @@ protected:
bool IsSingleProcess() { return mSingleProcess; }
virtual std::string GetOptionString();
virtual int ProcessOption(signed int option);
-
+ void ResetLogFile()
+ {
+ if(mapLogFileLogger.get())
+ {
+ mapLogFileLogger.reset(
+ new FileLogger(mLogFile, mLogFileLevel,
+ !mLogLevel.mTruncateLogFile));
+ }
+ }
+
private:
static void SignalHandler(int sigraised);
box_time_t GetConfigFileModifiedTime() const;
@@ -99,7 +108,7 @@ private:
bool mRunInForeground;
bool mKeepConsoleOpenAfterFork;
bool mHaveConfigFile;
- int mLogLevel; // need an int to do math with
+ Logging::OptionParser mLogLevel;
std::string mLogFile;
Log::Level mLogFileLevel;
std::auto_ptr<FileLogger> mapLogFileLogger;
diff --git a/lib/server/Message.h b/lib/server/Message.h
index 0d073d49..9f2245ec 100644
--- a/lib/server/Message.h
+++ b/lib/server/Message.h
@@ -37,10 +37,11 @@ public:
// reading and writing with Protocol objects
virtual void SetPropertiesFromStreamData(Protocol &rProtocol);
- virtual void WritePropertiesToStreamData(Protocol &rProtocol) const;
+ virtual void WritePropertiesToStreamData(Protocol &rProtocol) const;
virtual void LogSysLog(const char *Action) const { }
virtual void LogFile(const char *Action, FILE *file) const { }
+ virtual std::string ToString() const = 0;
};
/*
diff --git a/lib/server/Protocol.cpp b/lib/server/Protocol.cpp
index 382f1c37..0adf9543 100644
--- a/lib/server/Protocol.cpp
+++ b/lib/server/Protocol.cpp
@@ -17,10 +17,11 @@
#include <new>
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "Protocol.h"
#include "ProtocolWire.h"
-#include "IOStream.h"
-#include "ServerException.h"
+#include "SocketStream.h"
#include "PartialReadStream.h"
#include "ProtocolUncertainStream.h"
#include "Logging.h"
@@ -44,8 +45,8 @@
// Created: 2003/08/19
//
// --------------------------------------------------------------------------
-Protocol::Protocol(IOStream &rStream)
-: mrStream(rStream),
+Protocol::Protocol(std::auto_ptr<SocketStream> apConn)
+: mapConn(apConn),
mHandshakeDone(false),
mMaxObjectSize(PROTOCOL_DEFAULT_MAXOBJSIZE),
mTimeout(PROTOCOL_DEFAULT_TIMEOUT),
@@ -103,8 +104,8 @@ void Protocol::Handshake()
::strncpy(hsSend.mIdent, GetProtocolIdentString(), sizeof(hsSend.mIdent));
// Send it
- mrStream.Write(&hsSend, sizeof(hsSend));
- mrStream.WriteAllBuffered();
+ mapConn->Write(&hsSend, sizeof(hsSend), GetTimeout());
+ mapConn->WriteAllBuffered();
// Receive a handshake from the peer
PW_Handshake hsReceive;
@@ -114,10 +115,10 @@ void Protocol::Handshake()
while(bytesToRead > 0)
{
// Get some data from the stream
- int bytesRead = mrStream.Read(readInto, bytesToRead, mTimeout);
+ int bytesRead = mapConn->Read(readInto, bytesToRead, GetTimeout());
if(bytesRead == 0)
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
+ THROW_EXCEPTION(ConnectionException, Protocol_Timeout)
}
readInto += bytesRead;
bytesToRead -= bytesRead;
@@ -127,7 +128,7 @@ void Protocol::Handshake()
// Are they the same?
if(::memcmp(&hsSend, &hsReceive, sizeof(hsSend)) != 0)
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_HandshakeFailed)
+ THROW_EXCEPTION(ConnectionException, Protocol_HandshakeFailed)
}
// Mark as done
@@ -158,9 +159,10 @@ void Protocol::CheckAndReadHdr(void *hdr)
}
// Get some data into this header
- if(!mrStream.ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), 0 /* not interested in bytes read if this fails */, mTimeout))
+ if(!mapConn->ReadFullBuffer(hdr, sizeof(PW_ObjectHeader),
+ 0 /* not interested in bytes read if this fails */, mTimeout))
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
+ THROW_EXCEPTION(ConnectionException, Protocol_Timeout)
}
}
@@ -168,8 +170,9 @@ void Protocol::CheckAndReadHdr(void *hdr)
// --------------------------------------------------------------------------
//
// Function
-// Name: Protocol::Recieve()
-// Purpose: Recieves an object from the stream, creating it from the factory object type
+// Name: Protocol::ReceiveInternal()
+// Purpose: Receives an object from the stream, creating it
+// from the factory object type
// Created: 2003/08/19
//
// --------------------------------------------------------------------------
@@ -182,14 +185,14 @@ std::auto_ptr<Message> Protocol::ReceiveInternal()
// Hope it's not a stream
if(ntohl(objHeader.mObjType) == SPECIAL_STREAM_OBJECT_TYPE)
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_StreamWhenObjExpected)
+ THROW_EXCEPTION(ConnectionException, Protocol_StreamWhenObjExpected)
}
// Check the object size
- u_int32_t objSize = ntohl(objHeader.mObjSize);
+ uint32_t objSize = ntohl(objHeader.mObjSize);
if(objSize < sizeof(objHeader) || objSize > mMaxObjectSize)
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_ObjTooBig)
+ THROW_EXCEPTION(ConnectionException, Protocol_ObjTooBig)
}
// Create a blank object
@@ -199,9 +202,10 @@ std::auto_ptr<Message> Protocol::ReceiveInternal()
EnsureBufferAllocated(objSize);
// Read data
- if(!mrStream.ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), 0 /* not interested in bytes read if this fails */, mTimeout))
+ if(!mapConn->ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader),
+ 0 /* not interested in bytes read if this fails */, mTimeout))
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout)
+ THROW_EXCEPTION(ConnectionException, Protocol_Timeout)
}
// Setup ready to read out data from the buffer
@@ -231,7 +235,7 @@ std::auto_ptr<Message> Protocol::ReceiveInternal()
// Exception if not all the data was consumed
if(dataLeftOver)
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_BadCommandRecieved)
+ THROW_EXCEPTION(ConnectionException, Protocol_BadCommandRecieved)
}
return obj;
@@ -240,7 +244,7 @@ std::auto_ptr<Message> Protocol::ReceiveInternal()
// --------------------------------------------------------------------------
//
// Function
-// Name: Protocol::Send()
+// Name: Protocol::SendInternal()
// Purpose: Send an object to the other side of the connection.
// Created: 2003/08/19
//
@@ -292,8 +296,8 @@ void Protocol::SendInternal(const Message &rObject)
pobjHeader->mObjType = htonl(rObject.GetType());
// Write data
- mrStream.Write(mpBuffer, writtenSize);
- mrStream.WriteAllBuffered();
+ mapConn->Write(mpBuffer, writtenSize, GetTimeout());
+ mapConn->WriteAllBuffered();
}
// --------------------------------------------------------------------------
@@ -346,7 +350,7 @@ void Protocol::EnsureBufferAllocated(int Size)
#define READ_CHECK_BYTES_AVAILABLE(bytesRequired) \
if((mReadOffset + (int)(bytesRequired)) > mValidDataSize) \
{ \
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_BadCommandRecieved) \
+ THROW_EXCEPTION(ConnectionException, Protocol_BadCommandRecieved) \
}
// --------------------------------------------------------------------------
@@ -619,7 +623,7 @@ void Protocol::Write(const std::string &rValue)
// --------------------------------------------------------------------------
//
// Function
-// Name: Protocol::ReceieveStream()
+// Name: Protocol::ReceiveStream()
// Purpose: Receive a stream from the remote side
// Created: 2003/08/26
//
@@ -633,11 +637,11 @@ std::auto_ptr<IOStream> Protocol::ReceiveStream()
// Hope it's not an object
if(ntohl(objHeader.mObjType) != SPECIAL_STREAM_OBJECT_TYPE)
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_ObjWhenStreamExpected)
+ THROW_EXCEPTION(ConnectionException, Protocol_ObjWhenStreamExpected)
}
// Get the stream size
- u_int32_t streamSize = ntohl(objHeader.mObjSize);
+ uint32_t streamSize = ntohl(objHeader.mObjSize);
// Inform sub class
InformStreamReceiving(streamSize);
@@ -647,13 +651,13 @@ std::auto_ptr<IOStream> Protocol::ReceiveStream()
{
BOX_TRACE("Receiving stream, size uncertain");
return std::auto_ptr<IOStream>(
- new ProtocolUncertainStream(mrStream));
+ new ProtocolUncertainStream(*mapConn));
}
else
{
BOX_TRACE("Receiving stream, size " << streamSize << " bytes");
return std::auto_ptr<IOStream>(
- new PartialReadStream(mrStream, streamSize));
+ new PartialReadStream(*mapConn, streamSize));
}
}
@@ -709,7 +713,7 @@ void Protocol::SendStream(IOStream &rStream)
objHeader.mObjType = htonl(SPECIAL_STREAM_OBJECT_TYPE);
// Write header
- mrStream.Write(&objHeader, sizeof(objHeader));
+ mapConn->Write(&objHeader, sizeof(objHeader), GetTimeout());
// Could be sent in one of two ways
if(uncertainSize)
{
@@ -744,7 +748,7 @@ void Protocol::SendStream(IOStream &rStream)
// Send final byte to finish the stream
BOX_TRACE("Sending end of stream byte");
uint8_t endOfStream = ProtocolStreamHeader_EndOfStream;
- mrStream.Write(&endOfStream, 1);
+ mapConn->Write(&endOfStream, 1, GetTimeout());
BOX_TRACE("Sent end of stream byte");
}
catch(...)
@@ -759,13 +763,14 @@ void Protocol::SendStream(IOStream &rStream)
else
{
// Fixed size stream, send it all in one go
- if(!rStream.CopyStreamTo(mrStream, mTimeout, 4096 /* slightly larger buffer */))
+ if(!rStream.CopyStreamTo(*mapConn, GetTimeout(),
+ 4096 /* slightly larger buffer */))
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_TimeOutWhenSendingStream)
+ THROW_EXCEPTION(ConnectionException, Protocol_TimeOutWhenSendingStream)
}
}
// Make sure everything is written
- mrStream.WriteAllBuffered();
+ mapConn->WriteAllBuffered();
}
@@ -816,7 +821,7 @@ int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock)
Block[-1] = header;
// Write everything out
- mrStream.Write(Block - 1, writeSize + 1);
+ mapConn->Write(Block - 1, writeSize + 1, GetTimeout());
BOX_TRACE("Sent " << (writeSize+1) << " bytes to stream");
// move the remainer to the beginning of the block for the next time round
@@ -831,12 +836,12 @@ int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock)
// --------------------------------------------------------------------------
//
// Function
-// Name: Protocol::InformStreamReceiving(u_int32_t)
+// Name: Protocol::InformStreamReceiving(uint32_t)
// Purpose: Informs sub classes about streams being received
// Created: 2003/10/27
//
// --------------------------------------------------------------------------
-void Protocol::InformStreamReceiving(u_int32_t Size)
+void Protocol::InformStreamReceiving(uint32_t Size)
{
if(GetLogToSysLog())
{
@@ -863,12 +868,12 @@ void Protocol::InformStreamReceiving(u_int32_t Size)
// --------------------------------------------------------------------------
//
// Function
-// Name: Protocol::InformStreamSending(u_int32_t)
+// Name: Protocol::InformStreamSending(uint32_t)
// Purpose: Informs sub classes about streams being sent
// Created: 2003/10/27
//
// --------------------------------------------------------------------------
-void Protocol::InformStreamSending(u_int32_t Size)
+void Protocol::InformStreamSending(uint32_t Size)
{
if(GetLogToSysLog())
{
@@ -1177,6 +1182,12 @@ const uint16_t Protocol::sProtocolStreamHeaderLengths[256] =
0 // 255 = special (reserved)
};
+int64_t Protocol::GetBytesRead() const
+{
+ return mapConn->GetBytesRead();
+}
-
-
+int64_t Protocol::GetBytesWritten() const
+{
+ return mapConn->GetBytesWritten();
+}
diff --git a/lib/server/Protocol.h b/lib/server/Protocol.h
index 42cb0ff8..fbe6461c 100644
--- a/lib/server/Protocol.h
+++ b/lib/server/Protocol.h
@@ -19,6 +19,7 @@
#include "Message.h"
class IOStream;
+class SocketStream;
// default timeout is 15 minutes
#define PROTOCOL_DEFAULT_TIMEOUT (15*60*1000)
@@ -36,7 +37,7 @@ class IOStream;
class Protocol
{
public:
- Protocol(IOStream &rStream);
+ Protocol(std::auto_ptr<SocketStream> apConn);
virtual ~Protocol();
private:
@@ -66,14 +67,14 @@ public:
// Purpose: Sets the timeout for sending and reciving
// Created: 2003/08/19
//
- // --------------------------------------------------------------------------
+ // --------------------------------------------------------------------------
void SetTimeout(int NewTimeout) {mTimeout = NewTimeout;}
// --------------------------------------------------------------------------
//
// Function
- // Name: Protocol::GetTimeout()
+ // Name: Protocol::GetTimeout()
// Purpose: Get current timeout for sending and receiving
// Created: 2003/09/06
//
@@ -175,6 +176,8 @@ public:
FILE *GetLogToFile() { return mLogToFile; }
void SetLogToSysLog(bool Log = false) {mLogToSysLog = Log;}
void SetLogToFile(FILE *File = 0) {mLogToFile = File;}
+ int64_t GetBytesRead() const;
+ int64_t GetBytesWritten() const;
protected:
virtual std::auto_ptr<Message> MakeMessage(int ObjType) = 0;
@@ -183,14 +186,14 @@ protected:
void CheckAndReadHdr(void *hdr); // don't use type here to avoid dependency
// Will be used for logging
- virtual void InformStreamReceiving(u_int32_t Size);
- virtual void InformStreamSending(u_int32_t Size);
+ virtual void InformStreamReceiving(uint32_t Size);
+ virtual void InformStreamSending(uint32_t Size);
private:
void EnsureBufferAllocated(int Size);
int SendStreamSendBlock(uint8_t *Block, int BytesInBlock);
- IOStream &mrStream;
+ std::auto_ptr<SocketStream> mapConn;
bool mHandshakeDone;
unsigned int mMaxObjectSize;
int mTimeout;
@@ -208,4 +211,3 @@ class ProtocolContext
};
#endif // PROTOCOL__H
-
diff --git a/lib/server/ProtocolUncertainStream.cpp b/lib/server/ProtocolUncertainStream.cpp
index 84a213a8..aeb15816 100644
--- a/lib/server/ProtocolUncertainStream.cpp
+++ b/lib/server/ProtocolUncertainStream.cpp
@@ -8,8 +8,9 @@
// --------------------------------------------------------------------------
#include "Box.h"
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "ProtocolUncertainStream.h"
-#include "ServerException.h"
#include "Protocol.h"
#include "MemLeakFindOn.h"
@@ -172,7 +173,7 @@ IOStream::pos_type ProtocolUncertainStream::BytesLeftToRead()
// Created: 2003/12/05
//
// --------------------------------------------------------------------------
-void ProtocolUncertainStream::Write(const void *pBuffer, int NBytes)
+void ProtocolUncertainStream::Write(const void *pBuffer, int NBytes, int Timeout)
{
THROW_EXCEPTION(ServerException, CantWriteToProtocolUncertainStream)
}
diff --git a/lib/server/ProtocolUncertainStream.h b/lib/server/ProtocolUncertainStream.h
index 4954cf88..2e97ba6a 100644
--- a/lib/server/ProtocolUncertainStream.h
+++ b/lib/server/ProtocolUncertainStream.h
@@ -33,7 +33,8 @@ private:
public:
virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite);
virtual pos_type BytesLeftToRead();
- virtual void Write(const void *pBuffer, int NBytes);
+ virtual void Write(const void *pBuffer, int NBytes,
+ int Timeout = IOStream::TimeOutInfinite);
virtual bool StreamDataLeft();
virtual bool StreamClosed();
diff --git a/lib/server/ProtocolWire.h b/lib/server/ProtocolWire.h
index ff62b66e..6dee445b 100644
--- a/lib/server/ProtocolWire.h
+++ b/lib/server/ProtocolWire.h
@@ -26,8 +26,8 @@ typedef struct
typedef struct
{
- u_int32_t mObjSize;
- u_int32_t mObjType;
+ uint32_t mObjSize;
+ uint32_t mObjType;
} PW_ObjectHeader;
#define SPECIAL_STREAM_OBJECT_TYPE 0xffffffff
diff --git a/lib/server/SSLLib.cpp b/lib/server/SSLLib.cpp
index 004d2d98..1bcadb0d 100644
--- a/lib/server/SSLLib.cpp
+++ b/lib/server/SSLLib.cpp
@@ -18,9 +18,10 @@
#include <wincrypt.h>
#endif
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "CryptoUtils.h"
#include "SSLLib.h"
-#include "ServerException.h"
#include "MemLeakFindOn.h"
@@ -79,7 +80,7 @@ void SSLLib::Initialise()
BOX_LOG_WIN_ERROR("Failed to release crypto context");
}
}
-#elif HAVE_RANDOM_DEVICE
+#elif defined HAVE_RANDOM_DEVICE
if(::RAND_load_file(RANDOM_DEVICE, 1024) != 1024)
{
THROW_EXCEPTION(ServerException, SSLRandomInitFailed)
diff --git a/lib/server/ServerControl.cpp b/lib/server/ServerControl.cpp
index b9650cee..f1a718df 100644
--- a/lib/server/ServerControl.cpp
+++ b/lib/server/ServerControl.cpp
@@ -15,13 +15,14 @@
#include <signal.h>
#endif
+#include "BoxTime.h"
+#include "IOStreamGetLine.h"
#include "ServerControl.h"
#include "Test.h"
#ifdef WIN32
#include "WinNamedPipeStream.h"
-#include "IOStreamGetLine.h"
#include "BoxPortsAndFiles.h"
static std::string sPipeName;
@@ -197,18 +198,18 @@ bool KillServer(int pid, bool WaitForProcess)
}
#endif
- for (int i = 0; i < 30; i++)
+ printf("Waiting for server to die (pid %d): ", pid);
+
+ for (int i = 0; i < 300; i++)
{
- if (i == 0)
+ if (i % 10 == 0)
{
- printf("Waiting for server to die (pid %d): ", pid);
+ printf(".");
+ fflush(stdout);
}
- printf(".");
- fflush(stdout);
-
if (!ServerIsAlive(pid)) break;
- ::sleep(1);
+ ShortSleep(MilliSecondsToBoxTime(100), false);
if (!ServerIsAlive(pid)) break;
}
@@ -226,3 +227,64 @@ bool KillServer(int pid, bool WaitForProcess)
return !ServerIsAlive(pid);
}
+bool KillServer(std::string pid_file, bool WaitForProcess)
+{
+ FileStream fs(pid_file);
+ IOStreamGetLine getline(fs);
+ std::string line = getline.GetLine();
+ int pid = atoi(line.c_str());
+ bool status = KillServer(pid, WaitForProcess);
+ TEST_EQUAL_LINE(true, status, std::string("kill(") + pid_file + ")");
+
+#ifdef WIN32
+ if(WaitForProcess)
+ {
+ int unlink_result = unlink(pid_file.c_str());
+ TEST_EQUAL_LINE(0, unlink_result, std::string("unlink ") + pid_file);
+ if(unlink_result != 0)
+ {
+ return false;
+ }
+ }
+#endif
+
+ return status;
+}
+
+int StartDaemon(int current_pid, const std::string& cmd_line, const char* pid_file)
+{
+ TEST_THAT_OR(current_pid == 0, return 0);
+
+ int new_pid = LaunchServer(cmd_line, pid_file);
+ TEST_THAT_OR(new_pid != -1 && new_pid != 0, return 0);
+
+ ::sleep(1);
+ TEST_THAT_OR(ServerIsAlive(new_pid), return 0);
+ return new_pid;
+}
+
+bool StopDaemon(int current_pid, const std::string& pid_file,
+ const std::string& memleaks_file, bool wait_for_process)
+{
+ TEST_THAT_OR(current_pid != 0, return false);
+ TEST_THAT_OR(ServerIsAlive(current_pid), return false);
+ TEST_THAT_OR(KillServer(current_pid, wait_for_process), return false);
+ ::sleep(1);
+
+ TEST_THAT_OR(!ServerIsAlive(current_pid), return false);
+
+ #ifdef WIN32
+ int unlink_result = unlink(pid_file.c_str());
+ TEST_EQUAL_LINE(0, unlink_result, std::string("unlink ") + pid_file);
+ if(unlink_result != 0)
+ {
+ return false;
+ }
+ #else
+ TestRemoteProcessMemLeaks(memleaks_file.c_str());
+ #endif
+
+ return true;
+}
+
+
diff --git a/lib/server/ServerControl.h b/lib/server/ServerControl.h
index b2e51864..be2464c1 100644
--- a/lib/server/ServerControl.h
+++ b/lib/server/ServerControl.h
@@ -5,6 +5,10 @@
bool HUPServer(int pid);
bool KillServer(int pid, bool WaitForProcess = false);
+bool KillServer(std::string pid_file, bool WaitForProcess = false);
+int StartDaemon(int current_pid, const std::string& cmd_line, const char* pid_file);
+bool StopDaemon(int current_pid, const std::string& pid_file,
+ const std::string& memleaks_file, bool wait_for_process);
#ifdef WIN32
#include "WinNamedPipeStream.h"
diff --git a/lib/server/ServerException.h b/lib/server/ServerException.h
deleted file mode 100644
index 8851b90a..00000000
--- a/lib/server/ServerException.h
+++ /dev/null
@@ -1,46 +0,0 @@
-// --------------------------------------------------------------------------
-//
-// File
-// Name: ServerException.h
-// Purpose: Exception
-// Created: 2003/07/08
-//
-// --------------------------------------------------------------------------
-
-#ifndef SERVEREXCEPTION__H
-#define SERVEREXCEPTION__H
-
-// Compatibility header
-#include "autogen_ServerException.h"
-#include "autogen_ConnectionException.h"
-
-// Rename old connection exception names to new names without Conn_ prefix
-// This is all because ConnectionException used to be derived from ServerException
-// with some funky magic with subtypes. Perhaps a little unreliable, and the
-// usefulness of it never really was used.
-#define Conn_SocketWriteError SocketWriteError
-#define Conn_SocketReadError SocketReadError
-#define Conn_SocketNameLookupError SocketNameLookupError
-#define Conn_SocketShutdownError SocketShutdownError
-#define Conn_SocketConnectError SocketConnectError
-#define Conn_TLSHandshakeFailed TLSHandshakeFailed
-#define Conn_TLSShutdownFailed TLSShutdownFailed
-#define Conn_TLSWriteFailed TLSWriteFailed
-#define Conn_TLSReadFailed TLSReadFailed
-#define Conn_TLSNoPeerCertificate TLSNoPeerCertificate
-#define Conn_TLSPeerCertificateInvalid TLSPeerCertificateInvalid
-#define Conn_TLSClosedWhenWriting TLSClosedWhenWriting
-#define Conn_TLSHandshakeTimedOut TLSHandshakeTimedOut
-#define Conn_Protocol_Timeout Protocol_Timeout
-#define Conn_Protocol_ObjTooBig Protocol_ObjTooBig
-#define Conn_Protocol_BadCommandRecieved Protocol_BadCommandRecieved
-#define Conn_Protocol_UnknownCommandRecieved Protocol_UnknownCommandRecieved
-#define Conn_Protocol_TriedToExecuteReplyCommand Protocol_TriedToExecuteReplyCommand
-#define Conn_Protocol_UnexpectedReply Protocol_UnexpectedReply
-#define Conn_Protocol_HandshakeFailed Protocol_HandshakeFailed
-#define Conn_Protocol_StreamWhenObjExpected Protocol_StreamWhenObjExpected
-#define Conn_Protocol_ObjWhenStreamExpected Protocol_ObjWhenStreamExpected
-#define Conn_Protocol_TimeOutWhenSendingStream Protocol_TimeOutWhenSendingStream
-
-#endif // SERVEREXCEPTION__H
-
diff --git a/lib/server/ServerStream.h b/lib/server/ServerStream.h
index a9b56169..3f6eed7e 100644
--- a/lib/server/ServerStream.h
+++ b/lib/server/ServerStream.h
@@ -17,6 +17,7 @@
#include <sys/wait.h>
#endif
+#include "autogen_ServerException.h"
#include "Daemon.h"
#include "SocketListen.h"
#include "Utils.h"
@@ -286,7 +287,7 @@ public:
#endif
// The derived class does some server magic with the connection
- HandleConnection(*connection);
+ HandleConnection(connection);
// Since rChildExit == true, the forked process will call _exit() on return from this fn
return;
@@ -305,7 +306,7 @@ public:
#endif // !WIN32
// Just handle in this process
SetProcessTitle("handling");
- HandleConnection(*connection);
+ HandleConnection(connection);
SetProcessTitle("idle");
#ifndef WIN32
}
@@ -344,10 +345,20 @@ public:
p = ::waitpid(0 /* any child in process group */,
&status, WNOHANG);
- if(p == -1 && errno != ECHILD && errno != EINTR)
+ if(p == -1)
{
- THROW_EXCEPTION(ServerException,
- ServerWaitOnChildError)
+ if (errno == ECHILD || errno == EINTR)
+ {
+ // Nothing actually happened, so there's no reason
+ // to wait again.
+ break;
+ }
+ else
+ {
+ THROW_SYS_ERROR("Failed to wait for daemon child "
+ "process", ServerException,
+ ServerWaitOnChildError);
+ }
}
else if(p == 0)
{
@@ -377,12 +388,12 @@ public:
}
#endif
- virtual void HandleConnection(StreamType &rStream)
+ virtual void HandleConnection(std::auto_ptr<StreamType> apStream)
{
- Connection(rStream);
+ Connection(apStream);
}
- virtual void Connection(StreamType &rStream) = 0;
+ virtual void Connection(std::auto_ptr<StreamType> apStream) = 0;
protected:
// For checking code in derived classes -- use if you have an algorithm which
diff --git a/lib/server/ServerTLS.h b/lib/server/ServerTLS.h
index a74a671e..f748f4b2 100644
--- a/lib/server/ServerTLS.h
+++ b/lib/server/ServerTLS.h
@@ -52,18 +52,19 @@ public:
std::string certFile(serverconf.GetKeyValue("CertificateFile"));
std::string keyFile(serverconf.GetKeyValue("PrivateKeyFile"));
std::string caFile(serverconf.GetKeyValue("TrustedCAsFile"));
- mContext.Initialise(true /* as server */, certFile.c_str(), keyFile.c_str(), caFile.c_str());
+ mContext.Initialise(true /* as server */, certFile.c_str(),
+ keyFile.c_str(), caFile.c_str());
// Then do normal stream server stuff
ServerStream<SocketStreamTLS, Port, ListenBacklog,
ForkToHandleRequests>::Run2(rChildExit);
}
- virtual void HandleConnection(SocketStreamTLS &rStream)
+ virtual void HandleConnection(std::auto_ptr<SocketStreamTLS> apStream)
{
- rStream.Handshake(mContext, true /* is server */);
+ apStream->Handshake(mContext, true /* is server */);
// this-> in next line required to build under some gcc versions
- this->Connection(rStream);
+ this->Connection(apStream);
}
private:
diff --git a/lib/server/Socket.cpp b/lib/server/Socket.cpp
index f2a4996b..c9c1773d 100644
--- a/lib/server/Socket.cpp
+++ b/lib/server/Socket.cpp
@@ -24,8 +24,9 @@
#include <string.h>
#include <stdio.h>
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "Socket.h"
-#include "ServerException.h"
#include "MemLeakFindOn.h"
@@ -69,12 +70,12 @@ void Socket::NameLookupToSockAddr(SocketAllAddr &addr, int &sockDomain,
}
else
{
- THROW_EXCEPTION(ConnectionException, Conn_SocketNameLookupError);
+ THROW_EXCEPTION(ConnectionException, SocketNameLookupError);
}
}
else
{
- THROW_EXCEPTION(ConnectionException, Conn_SocketNameLookupError);
+ THROW_EXCEPTION(ConnectionException, SocketNameLookupError);
}
}
break;
diff --git a/lib/server/SocketListen.h b/lib/server/SocketListen.h
index 39c60ba6..39fe7e24 100644
--- a/lib/server/SocketListen.h
+++ b/lib/server/SocketListen.h
@@ -29,8 +29,9 @@
#include <memory>
#include <string>
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "Socket.h"
-#include "ServerException.h"
#include "MemLeakFindOn.h"
@@ -73,7 +74,8 @@ private:
// Created: 2003/07/31
//
// --------------------------------------------------------------------------
-template<typename SocketType, int ListenBacklog = 128, typename SocketLockingType = _NoSocketLocking, int MaxMultiListenSockets = 16>
+template<typename SocketType, int ListenBacklog = 128,
+ typename SocketLockingType = _NoSocketLocking, int MaxMultiListenSockets = 16>
class SocketListen
{
public:
@@ -112,10 +114,9 @@ public:
if(::close(mSocketHandle) == -1)
#endif
{
- BOX_LOG_SOCKET_ERROR(mType, mName, mPort,
- "Failed to close network socket");
- THROW_EXCEPTION(ServerException,
- SocketCloseError)
+ THROW_EXCEPTION_MESSAGE(ServerException, SocketCloseError,
+ BOX_SOCKET_ERROR_MESSAGE(mType, mName, mPort,
+ "Failed to close network socket"));
}
}
mSocketHandle = -1;
@@ -152,9 +153,9 @@ public:
0 /* let OS choose protocol */);
if(mSocketHandle == -1)
{
- BOX_LOG_SOCKET_ERROR(Type, Name, Port,
- "Failed to create a network socket");
- THROW_EXCEPTION(ServerException, SocketOpenError)
+ THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError,
+ BOX_SOCKET_ERROR_MESSAGE(Type, Name, Port,
+ "Failed to create a network socket"));
}
// Set an option to allow reuse (useful for -HUP situations!)
@@ -167,28 +168,28 @@ public:
&option, sizeof(option)) == -1)
#endif
{
- BOX_LOG_SOCKET_ERROR(Type, Name, Port,
- "Failed to set socket options");
- THROW_EXCEPTION(ServerException, SocketOpenError)
+ THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError,
+ BOX_SOCKET_ERROR_MESSAGE(Type, Name, Port,
+ "Failed to set socket options"));
}
// Bind it to the right port, and start listening
if(::bind(mSocketHandle, &addr.sa_generic, addrLen) == -1
|| ::listen(mSocketHandle, ListenBacklog) == -1)
{
- int err_number = errno;
-
- BOX_LOG_SOCKET_ERROR(Type, Name, Port,
- "Failed to bind socket");
-
- // Dispose of the socket
- ::close(mSocketHandle);
- mSocketHandle = -1;
-
- THROW_SYS_FILE_ERRNO("Failed to bind or listen "
- "on socket", Name, err_number,
- ServerException, SocketBindError);
- }
+ try
+ {
+ THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError,
+ BOX_SOCKET_ERROR_MESSAGE(Type, Name, Port,
+ "Failed to bind socket to name/port"));
+ }
+ catch(ServerException &e) // finally
+ {
+ // Dispose of the socket
+ Close();
+ throw;
+ }
+ }
}
// ------------------------------------------------------------------
@@ -248,10 +249,10 @@ public:
}
else
{
- BOX_LOG_SOCKET_ERROR(mType, mName, mPort,
- "Failed to poll connection");
- THROW_EXCEPTION(ServerException,
- SocketPollError)
+ THROW_EXCEPTION_MESSAGE(ServerException,
+ SocketPollError,
+ BOX_SOCKET_ERROR_MESSAGE(mType, mName,
+ mPort, "Failed to poll connection"));
}
break;
case 0: // timed out
@@ -268,9 +269,9 @@ public:
// Got socket (or error), unlock (implicit in destruction)
if(sock == -1)
{
- BOX_LOG_SOCKET_ERROR(mType, mName, mPort,
- "Failed to accept connection");
- THROW_EXCEPTION(ServerException, SocketAcceptError)
+ THROW_EXCEPTION_MESSAGE(ServerException, SocketAcceptError,
+ BOX_SOCKET_ERROR_MESSAGE(mType, mName,
+ mPort, "Failed to accept connection"));
}
// Log it
diff --git a/lib/server/SocketStream.cpp b/lib/server/SocketStream.cpp
index 6ef4b8d1..edb5e5b8 100644
--- a/lib/server/SocketStream.cpp
+++ b/lib/server/SocketStream.cpp
@@ -25,10 +25,24 @@
#include <ucred.h>
#endif
+#ifdef HAVE_BSD_UNISTD_H
+ #include <bsd/unistd.h>
+#endif
+
+#ifdef HAVE_SYS_PARAM_H
+ #include <sys/param.h>
+#endif
+
+#ifdef HAVE_SYS_UCRED_H
+ #include <sys/ucred.h>
+#endif
+
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "SocketStream.h"
-#include "ServerException.h"
#include "CommonException.h"
#include "Socket.h"
+#include "Utils.h"
#include "MemLeakFindOn.h"
@@ -162,25 +176,31 @@ void SocketStream::Open(Socket::Type Type, const std::string& rName, int Port)
0 /* let OS choose protocol */);
if(mSocketHandle == INVALID_SOCKET_VALUE)
{
- BOX_LOG_SOCKET_ERROR(Type, rName, Port,
- "Failed to create a network socket");
- THROW_EXCEPTION(ServerException, SocketOpenError)
+ THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError,
+ BOX_SOCKET_ERROR_MESSAGE(Type, rName, Port,
+ "Failed to create a network socket"));
}
// Connect it
if(::connect(mSocketHandle, &addr.sa_generic, addrLen) == -1)
{
// Dispose of the socket
- BOX_LOG_SOCKET_ERROR(Type, rName, Port,
- "Failed to connect to socket");
+ try
+ {
+ THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError,
+ BOX_SOCKET_ERROR_MESSAGE(Type, rName, Port,
+ "Failed to connect to socket"));
+ }
+ catch(ServerException &e)
+ {
#ifdef WIN32
- ::closesocket(mSocketHandle);
+ ::closesocket(mSocketHandle);
#else // !WIN32
- ::close(mSocketHandle);
+ ::close(mSocketHandle);
#endif // WIN32
-
- mSocketHandle = INVALID_SOCKET_VALUE;
- THROW_EXCEPTION(ConnectionException, Conn_SocketConnectError)
+ mSocketHandle = INVALID_SOCKET_VALUE;
+ throw;
+ }
}
ResetCounters();
@@ -199,7 +219,9 @@ void SocketStream::Open(Socket::Type Type, const std::string& rName, int Port)
// --------------------------------------------------------------------------
int SocketStream::Read(void *pBuffer, int NBytes, int Timeout)
{
- if(mSocketHandle == INVALID_SOCKET_VALUE)
+ CheckForMissingTimeout(Timeout);
+
+ if(mSocketHandle == INVALID_SOCKET_VALUE)
{
THROW_EXCEPTION(ServerException, BadSocketHandle)
}
@@ -210,7 +232,7 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout)
p.fd = mSocketHandle;
p.events = POLLIN;
p.revents = 0;
- switch(::poll(&p, 1, (Timeout == IOStream::TimeOutInfinite)?INFTIM:Timeout))
+ switch(::poll(&p, 1, PollTimeout(Timeout, 0)))
{
case -1:
// error
@@ -256,7 +278,7 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout)
// Other error
BOX_LOG_SYS_ERROR("Failed to read from socket");
THROW_EXCEPTION(ConnectionException,
- Conn_SocketReadError);
+ SocketReadError);
}
}
@@ -270,6 +292,41 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout)
return r;
}
+bool SocketStream::Poll(short Events, int Timeout)
+{
+ // Wait for data to send.
+ struct pollfd p;
+ p.fd = GetSocketHandle();
+ p.events = Events;
+ p.revents = 0;
+
+ box_time_t start = GetCurrentBoxTime();
+ int result;
+
+ do
+ {
+ result = ::poll(&p, 1, PollTimeout(Timeout, start));
+ }
+ while(result == -1 && errno == EINTR);
+
+ switch(result)
+ {
+ case -1:
+ // error - Bad!
+ THROW_SYS_ERROR("Failed to poll socket", ServerException,
+ SocketPollError);
+ break;
+
+ case 0:
+ // Condition not met, timed out
+ return false;
+
+ default:
+ // good to go!
+ return true;
+ }
+}
+
// --------------------------------------------------------------------------
//
// Function
@@ -278,20 +335,21 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout)
// Created: 2003/07/31
//
// --------------------------------------------------------------------------
-void SocketStream::Write(const void *pBuffer, int NBytes)
+void SocketStream::Write(const void *pBuffer, int NBytes, int Timeout)
{
- if(mSocketHandle == INVALID_SOCKET_VALUE)
+ if(mSocketHandle == INVALID_SOCKET_VALUE)
{
THROW_EXCEPTION(ServerException, BadSocketHandle)
}
-
+
// Buffer in byte sized type.
ASSERT(sizeof(char) == 1);
const char *buffer = (char *)pBuffer;
-
+
// Bytes left to send
int bytesLeft = NBytes;
-
+ box_time_t start = GetCurrentBoxTime();
+
while(bytesLeft > 0)
{
// Try to send.
@@ -304,41 +362,30 @@ void SocketStream::Write(const void *pBuffer, int NBytes)
{
// Error.
mWriteClosed = true; // assume can't write again
- BOX_LOG_SYS_ERROR("Failed to write to socket");
- THROW_EXCEPTION(ConnectionException,
- Conn_SocketWriteError);
+ THROW_SYS_ERROR("Failed to write to socket",
+ ConnectionException, SocketWriteError);
}
-
+
// Knock off bytes sent
bytesLeft -= sent;
// Move buffer pointer
buffer += sent;
mBytesWritten += sent;
-
+
// Need to wait until it can send again?
if(bytesLeft > 0)
{
- BOX_TRACE("Waiting to send data on socket " <<
+ BOX_TRACE("Waiting to send data on socket " <<
mSocketHandle << " (" << bytesLeft <<
" of " << NBytes << " bytes left)");
-
- // Wait for data to send.
- struct pollfd p;
- p.fd = mSocketHandle;
- p.events = POLLOUT;
- p.revents = 0;
-
- if(::poll(&p, 1, 16000 /* 16 seconds */) == -1)
+
+ if(!Poll(POLLOUT, PollTimeout(Timeout, start)))
{
- // Don't exception if it's just a signal
- if(errno != EINTR)
- {
- BOX_LOG_SYS_ERROR("Failed to poll "
- "socket");
- THROW_EXCEPTION(ServerException,
- SocketPollError)
- }
+ THROW_EXCEPTION_MESSAGE(ConnectionException,
+ Protocol_Timeout, "Timed out waiting "
+ "to send " << bytesLeft << " of " <<
+ NBytes << " bytes");
}
}
}
@@ -354,7 +401,7 @@ void SocketStream::Write(const void *pBuffer, int NBytes)
// --------------------------------------------------------------------------
void SocketStream::Close()
{
- if(mSocketHandle == INVALID_SOCKET_VALUE)
+ if(mSocketHandle == INVALID_SOCKET_VALUE)
{
THROW_EXCEPTION(ServerException, BadSocketHandle)
}
@@ -385,19 +432,19 @@ void SocketStream::Shutdown(bool Read, bool Write)
{
THROW_EXCEPTION(ServerException, BadSocketHandle)
}
-
+
// Do anything?
if(!Read && !Write) return;
-
+
int how = SHUT_RDWR;
if(Read && !Write) how = SHUT_RD;
if(!Read && Write) how = SHUT_WR;
-
+
// Shut it down!
if(::shutdown(mSocketHandle, how) == -1)
{
BOX_LOG_SYS_ERROR("Failed to shutdown socket");
- THROW_EXCEPTION(ConnectionException, Conn_SocketShutdownError)
+ THROW_EXCEPTION(ConnectionException, SocketShutdownError)
}
}
@@ -478,8 +525,13 @@ bool SocketStream::GetPeerCredentials(uid_t &rUidOut, gid_t &rGidOut)
if(::getsockopt(mSocketHandle, SOL_SOCKET, SO_PEERCRED, &cred,
&credLen) == 0)
{
+#ifdef HAVE_STRUCT_UCRED_UID
rUidOut = cred.uid;
rGidOut = cred.gid;
+#else // HAVE_STRUCT_UCRED_CR_UID
+ rUidOut = cred.cr_uid;
+ rGidOut = cred.cr_gid;
+#endif
return true;
}
@@ -509,3 +561,11 @@ bool SocketStream::GetPeerCredentials(uid_t &rUidOut, gid_t &rGidOut)
return false;
}
+void SocketStream::CheckForMissingTimeout(int Timeout)
+{
+ if (Timeout == IOStream::TimeOutInfinite)
+ {
+ BOX_WARNING("Network operation started with no timeout!");
+ DumpStackBacktrace();
+ }
+}
diff --git a/lib/server/SocketStream.h b/lib/server/SocketStream.h
index 2fb5e391..fd57af8f 100644
--- a/lib/server/SocketStream.h
+++ b/lib/server/SocketStream.h
@@ -10,6 +10,13 @@
#ifndef SOCKETSTREAM__H
#define SOCKETSTREAM__H
+#include <climits>
+
+#ifdef HAVE_SYS_POLL_H
+# include <sys/poll.h>
+#endif
+
+#include "BoxTime.h"
#include "IOStream.h"
#include "Socket.h"
@@ -41,7 +48,16 @@ public:
void Attach(int socket);
virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite);
- virtual void Write(const void *pBuffer, int NBytes);
+ virtual void Write(const void *pBuffer, int NBytes,
+ int Timeout = IOStream::TimeOutInfinite);
+
+ // Why not inherited from IOStream? Never mind, we want to enforce
+ // supplying a timeout for network operations anyway.
+ virtual void Write(const std::string& rBuffer, int Timeout)
+ {
+ IOStream::Write(rBuffer, Timeout);
+ }
+
virtual void Close();
virtual bool StreamDataLeft();
virtual bool StreamClosed();
@@ -53,6 +69,42 @@ public:
protected:
void MarkAsReadClosed() {mReadClosed = true;}
void MarkAsWriteClosed() {mWriteClosed = true;}
+ void CheckForMissingTimeout(int Timeout);
+
+ // Converts a timeout in milliseconds (or IOStream::TimeOutInfinite)
+ // into one that can be passed to poll() (also in milliseconds), also
+ // compensating for time elapsed since the wait should have started,
+ // if known.
+ int PollTimeout(int timeout, box_time_t start_time)
+ {
+ if (timeout == IOStream::TimeOutInfinite)
+ {
+ return INFTIM;
+ }
+
+ if (start_time == 0)
+ {
+ return timeout; // no adjustment possible
+ }
+
+ box_time_t end_time = start_time + MilliSecondsToBoxTime(timeout);
+ box_time_t now = GetCurrentBoxTime();
+ box_time_t remaining = end_time - now;
+
+ if (remaining < 0)
+ {
+ return 0; // no delay
+ }
+ else if (BoxTimeToMilliSeconds(remaining) > INT_MAX)
+ {
+ return INT_MAX;
+ }
+ else
+ {
+ return (int) BoxTimeToMilliSeconds(remaining);
+ }
+ }
+ bool Poll(short Events, int Timeout);
private:
tOSSocketHandle mSocketHandle;
diff --git a/lib/server/SocketStreamTLS.cpp b/lib/server/SocketStreamTLS.cpp
index 576b53a2..e6299bfa 100644
--- a/lib/server/SocketStreamTLS.cpp
+++ b/lib/server/SocketStreamTLS.cpp
@@ -19,9 +19,10 @@
#include <poll.h>
#endif
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "BoxTime.h"
#include "CryptoUtils.h"
-#include "ServerException.h"
#include "SocketStreamTLS.h"
#include "SSLLib.h"
#include "TLSContext.h"
@@ -131,7 +132,7 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
tOSSocketHandle socket = GetSocketHandle();
BIO_set_fd(mpBIO, socket, BIO_NOCLOSE);
-
+
// Then the SSL object
mpSSL = ::SSL_new(rContext.GetRawContext());
if(mpSSL == 0)
@@ -154,7 +155,7 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed)
}
#endif
-
+
// FIXME: This is less portable than the above. However, it MAY be needed
// for cygwin, which has/had bugs with fcntl
//
@@ -196,7 +197,7 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
if(WaitWhenRetryRequired(se, TLS_HANDSHAKE_TIMEOUT) == false)
{
// timed out
- THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeTimedOut)
+ THROW_EXCEPTION(ConnectionException, TLSHandshakeTimedOut)
}
break;
@@ -205,12 +206,12 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
if(IsServer)
{
CryptoUtils::LogError("accepting connection");
- THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
+ THROW_EXCEPTION(ConnectionException, TLSHandshakeFailed)
}
else
{
CryptoUtils::LogError("connecting");
- THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
+ THROW_EXCEPTION(ConnectionException, TLSHandshakeFailed)
}
}
}
@@ -222,23 +223,25 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
//
// Function
// Name: WaitWhenRetryRequired(int, int)
-// Purpose: Waits until the condition required by the TLS layer is met.
-// Returns true if the condition is met, false if timed out.
+// Purpose: Waits until the condition required by the TLS layer
+// is met. Returns true if the condition is met, false
+// if timed out.
// Created: 2003/08/15
//
// --------------------------------------------------------------------------
bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout)
{
- struct pollfd p;
- p.fd = GetSocketHandle();
+ CheckForMissingTimeout(Timeout);
+
+ short events;
switch(SSLErrorCode)
{
case SSL_ERROR_WANT_READ:
- p.events = POLLIN;
+ events = POLLIN;
break;
case SSL_ERROR_WANT_WRITE:
- p.events = POLLOUT;
+ events = POLLOUT;
break;
default:
@@ -246,45 +249,8 @@ bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout)
THROW_EXCEPTION(ServerException, Internal)
break;
}
- p.revents = 0;
-
- int64_t start, end;
- start = BoxTimeToMilliSeconds(GetCurrentBoxTime());
- end = start + Timeout;
- int result;
-
- do
- {
- int64_t now = BoxTimeToMilliSeconds(GetCurrentBoxTime());
- int poll_timeout = (int)(end - now);
- if (poll_timeout < 0) poll_timeout = 0;
- if (Timeout == IOStream::TimeOutInfinite)
- {
- poll_timeout = INFTIM;
- }
- result = ::poll(&p, 1, poll_timeout);
- }
- while(result == -1 && errno == EINTR);
-
- switch(result)
- {
- case -1:
- // error - Bad!
- THROW_EXCEPTION(ServerException, SocketPollError)
- break;
-
- case 0:
- // Condition not met, timed out
- return false;
- break;
-
- default:
- // good to go!
- return true;
- break;
- }
- return true;
+ return Poll(events, Timeout);
}
// --------------------------------------------------------------------------
@@ -297,6 +263,7 @@ bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout)
// --------------------------------------------------------------------------
int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout)
{
+ CheckForMissingTimeout(Timeout);
if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
// Make sure zero byte reads work as expected
@@ -304,7 +271,7 @@ int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout)
{
return 0;
}
-
+
while(true)
{
int r = ::SSL_read(mpSSL, pBuffer, NBytes);
@@ -337,7 +304,7 @@ int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout)
default:
CryptoUtils::LogError("reading");
- THROW_EXCEPTION(ConnectionException, Conn_TLSReadFailed)
+ THROW_EXCEPTION(ConnectionException, TLSReadFailed)
break;
}
}
@@ -351,23 +318,23 @@ int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout)
// Created: 2003/08/06
//
// --------------------------------------------------------------------------
-void SocketStreamTLS::Write(const void *pBuffer, int NBytes)
+void SocketStreamTLS::Write(const void *pBuffer, int NBytes, int Timeout)
{
if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
-
+
// Make sure zero byte writes work as expected
if(NBytes == 0)
{
return;
}
-
+
// from man SSL_write
//
// SSL_write() will only return with success, when the
// complete contents of buf of length num has been written.
//
// So no worries about partial writes and moving the buffer around
-
+
while(true)
{
// try the write
@@ -385,24 +352,24 @@ void SocketStreamTLS::Write(const void *pBuffer, int NBytes)
case SSL_ERROR_ZERO_RETURN:
// Connection closed
MarkAsWriteClosed();
- THROW_EXCEPTION(ConnectionException, Conn_TLSClosedWhenWriting)
+ THROW_EXCEPTION(ConnectionException, TLSClosedWhenWriting)
break;
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
- // wait for the requried data
+ // wait for the required data
{
#ifndef BOX_RELEASE_BUILD
- bool conditionmet =
+ bool conditionmet =
#endif
- WaitWhenRetryRequired(se, IOStream::TimeOutInfinite);
+ WaitWhenRetryRequired(se, Timeout);
ASSERT(conditionmet);
}
break;
default:
CryptoUtils::LogError("writing");
- THROW_EXCEPTION(ConnectionException, Conn_TLSWriteFailed)
+ THROW_EXCEPTION(ConnectionException, TLSWriteFailed)
break;
}
}
@@ -444,7 +411,7 @@ void SocketStreamTLS::Shutdown(bool Read, bool Write)
if(::SSL_shutdown(mpSSL) < 0)
{
CryptoUtils::LogError("shutting down");
- THROW_EXCEPTION(ConnectionException, Conn_TLSShutdownFailed)
+ THROW_EXCEPTION(ConnectionException, TLSShutdownFailed)
}
// Don't ask the base class to shutdown -- BIO does this, apparently.
@@ -467,15 +434,15 @@ std::string SocketStreamTLS::GetPeerCommonName()
if(cert == 0)
{
::X509_free(cert);
- THROW_EXCEPTION(ConnectionException, Conn_TLSNoPeerCertificate)
+ THROW_EXCEPTION(ConnectionException, TLSNoPeerCertificate)
}
- // Subject details
- X509_NAME *subject = ::X509_get_subject_name(cert);
+ // Subject details
+ X509_NAME *subject = ::X509_get_subject_name(cert);
if(subject == 0)
{
::X509_free(cert);
- THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
+ THROW_EXCEPTION(ConnectionException, TLSPeerCertificateInvalid)
}
// Common name
@@ -483,7 +450,7 @@ std::string SocketStreamTLS::GetPeerCommonName()
if(::X509_NAME_get_text_by_NID(subject, NID_commonName, commonName, sizeof(commonName)) <= 0)
{
::X509_free(cert);
- THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
+ THROW_EXCEPTION(ConnectionException, TLSPeerCertificateInvalid)
}
// Terminate just in case
commonName[sizeof(commonName)-1] = '\0';
diff --git a/lib/server/SocketStreamTLS.h b/lib/server/SocketStreamTLS.h
index bb40ed10..3fda98c1 100644
--- a/lib/server/SocketStreamTLS.h
+++ b/lib/server/SocketStreamTLS.h
@@ -43,7 +43,8 @@ public:
void Handshake(const TLSContext &rContext, bool IsServer = false);
virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite);
- virtual void Write(const void *pBuffer, int NBytes);
+ virtual void Write(const void *pBuffer, int NBytes,
+ int Timeout = IOStream::TimeOutInfinite);
virtual void Close();
virtual void Shutdown(bool Read = true, bool Write = true);
diff --git a/lib/server/TLSContext.cpp b/lib/server/TLSContext.cpp
index 341043e9..1a6d4a53 100644
--- a/lib/server/TLSContext.cpp
+++ b/lib/server/TLSContext.cpp
@@ -12,8 +12,9 @@
#define TLS_CLASS_IMPLEMENTATION_CPP
#include <openssl/ssl.h>
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
#include "CryptoUtils.h"
-#include "ServerException.h"
#include "SSLLib.h"
#include "TLSContext.h"
@@ -22,6 +23,17 @@
#define MAX_VERIFICATION_DEPTH 2
#define CIPHER_LIST "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"
+// Macros to allow compatibility with OpenSSL 1.0 and 1.1 APIs. See
+// https://github.com/charybdis-ircd/charybdis/blob/release/3.5/libratbox/src/openssl_ratbox.h
+// for the gory details.
+#if defined(LIBRESSL_VERSION_NUMBER) || (OPENSSL_VERSION_NUMBER >= 0x10100000L) // OpenSSL >= 1.1
+# define BOX_TLS_SERVER_METHOD TLS_server_method
+# define BOX_TLS_CLIENT_METHOD TLS_client_method
+#else // OpenSSL < 1.1
+# define BOX_TLS_SERVER_METHOD TLSv1_server_method
+# define BOX_TLS_CLIENT_METHOD TLSv1_client_method
+#endif
+
// --------------------------------------------------------------------------
//
// Function
@@ -66,7 +78,7 @@ void TLSContext::Initialise(bool AsServer, const char *CertificatesFile, const c
::SSL_CTX_free(mpContext);
}
- mpContext = ::SSL_CTX_new(AsServer?TLSv1_server_method():TLSv1_client_method());
+ mpContext = ::SSL_CTX_new(AsServer ? BOX_TLS_SERVER_METHOD() : BOX_TLS_CLIENT_METHOD());
if(mpContext == NULL)
{
THROW_EXCEPTION(ServerException, TLSAllocationFailed)
diff --git a/lib/server/TcpNice.cpp b/lib/server/TcpNice.cpp
index 20619e49..79e91eeb 100644
--- a/lib/server/TcpNice.cpp
+++ b/lib/server/TcpNice.cpp
@@ -146,7 +146,7 @@ void NiceSocketStream::Write(const void *pBuffer, int NBytes)
int socket = mapSocket->GetSocketHandle();
int rtt = 50; // WAG
-# if HAVE_DECL_SOL_TCP && HAVE_DECL_TCP_INFO && HAVE_STRUCT_TCP_INFO_TCPI_RTT
+# if HAVE_DECL_SOL_TCP && defined HAVE_STRUCT_TCP_INFO_TCPI_RTT
struct tcp_info info;
socklen_t optlen = sizeof(info);
if(getsockopt(socket, SOL_TCP, TCP_INFO, &info, &optlen) == -1)
@@ -154,7 +154,7 @@ void NiceSocketStream::Write(const void *pBuffer, int NBytes)
BOX_LOG_SYS_WARNING("getsockopt(" << socket << ", SOL_TCP, "
"TCP_INFO) failed");
}
- else if(optlen != sizeof(info))
+ else if(optlen < sizeof(info))
{
BOX_WARNING("getsockopt(" << socket << ", SOL_TCP, "
"TCP_INFO) return structure size " << optlen << ", "
@@ -164,7 +164,7 @@ void NiceSocketStream::Write(const void *pBuffer, int NBytes)
{
rtt = info.tcpi_rtt;
}
-# endif
+# endif // HAVE_DECL_SOL_TCP && defined HAVE_STRUCT_TCP_INFO_TCPI_RTT
int newWindow = mTcpNice.GetNextWindowSize(mBytesWrittenThisPeriod,
elapsed, rtt);
diff --git a/lib/server/TcpNice.h b/lib/server/TcpNice.h
index e2027749..4381df42 100644
--- a/lib/server/TcpNice.h
+++ b/lib/server/TcpNice.h
@@ -83,7 +83,7 @@ private:
//
// --------------------------------------------------------------------------
-class NiceSocketStream : public IOStream
+class NiceSocketStream : public SocketStream
{
private:
std::auto_ptr<SocketStream> mapSocket;
@@ -166,6 +166,10 @@ public:
}
virtual void SetEnabled(bool enabled);
+ off_t GetBytesRead() const { return mapSocket->GetBytesRead(); }
+ off_t GetBytesWritten() const { return mapSocket->GetBytesWritten(); }
+ void ResetCounters() { mapSocket->ResetCounters(); }
+
private:
NiceSocketStream(const NiceSocketStream &rToCopy)
{ /* do not call */ }
diff --git a/lib/server/WinNamedPipeListener.h b/lib/server/WinNamedPipeListener.h
index 26e76e3d..956a7b5a 100644
--- a/lib/server/WinNamedPipeListener.h
+++ b/lib/server/WinNamedPipeListener.h
@@ -11,10 +11,10 @@
#ifndef WINNAMEDPIPELISTENER__H
#define WINNAMEDPIPELISTENER__H
-#include <OverlappedIO.h>
-#include <WinNamedPipeStream.h>
-
-#include "ServerException.h"
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
+#include "OverlappedIO.h"
+#include "WinNamedPipeStream.h"
#include "MemLeakFindOn.h"
@@ -53,8 +53,8 @@ private:
socket.c_str(), // pipe name
PIPE_ACCESS_DUPLEX | // read/write access
FILE_FLAG_OVERLAPPED, // enabled overlapped I/O
- PIPE_TYPE_BYTE | // message type pipe
- PIPE_READMODE_BYTE | // message-read mode
+ PIPE_TYPE_BYTE |
+ PIPE_READMODE_BYTE |
PIPE_WAIT, // blocking mode
ListenBacklog + 1, // max. instances
4096, // output buffer size
@@ -64,9 +64,9 @@ private:
if (handle == INVALID_HANDLE_VALUE)
{
- BOX_LOG_WIN_ERROR("Failed to create named pipe " <<
- socket);
- THROW_EXCEPTION(ServerException, SocketOpenError)
+ THROW_WIN_FILE_ERRNO("Failed to create named pipe",
+ socket, GetLastError(), ServerException,
+ SocketOpenError);
}
return handle;
diff --git a/lib/server/WinNamedPipeStream.cpp b/lib/server/WinNamedPipeStream.cpp
index 1179516e..448a3c9d 100644
--- a/lib/server/WinNamedPipeStream.cpp
+++ b/lib/server/WinNamedPipeStream.cpp
@@ -19,10 +19,12 @@
#include <errno.h>
#include <windows.h>
-#include "WinNamedPipeStream.h"
-#include "ServerException.h"
+#include "autogen_ConnectionException.h"
+#include "autogen_ServerException.h"
+#include "BoxTime.h"
#include "CommonException.h"
#include "Socket.h"
+#include "WinNamedPipeStream.h"
#include "MemLeakFindOn.h"
@@ -37,13 +39,14 @@ std::string WinNamedPipeStream::sPipeNamePrefix = "\\\\.\\pipe\\";
//
// --------------------------------------------------------------------------
WinNamedPipeStream::WinNamedPipeStream()
- : mSocketHandle(INVALID_HANDLE_VALUE),
- mReadableEvent(INVALID_HANDLE_VALUE),
- mBytesInBuffer(0),
- mReadClosed(false),
- mWriteClosed(false),
- mIsServer(false),
- mIsConnected(false)
+: mSocketHandle(INVALID_HANDLE_VALUE),
+ mReadableEvent(INVALID_HANDLE_VALUE),
+ mBytesInBuffer(0),
+ mReadClosed(false),
+ mWriteClosed(false),
+ mIsServer(false),
+ mIsConnected(false),
+ mNeedAnotherRead(false)
{ }
// --------------------------------------------------------------------------
@@ -55,14 +58,21 @@ WinNamedPipeStream::WinNamedPipeStream()
//
// --------------------------------------------------------------------------
WinNamedPipeStream::WinNamedPipeStream(HANDLE hNamedPipe)
- : mSocketHandle(hNamedPipe),
- mReadableEvent(INVALID_HANDLE_VALUE),
- mBytesInBuffer(0),
- mReadClosed(false),
- mWriteClosed(false),
- mIsServer(true),
- mIsConnected(true)
+: mSocketHandle(hNamedPipe),
+ mReadableEvent(INVALID_HANDLE_VALUE),
+ mBytesInBuffer(0),
+ mReadClosed(false),
+ mWriteClosed(false),
+ mIsServer(true),
+ mIsConnected(true),
+ mNeedAnotherRead(false)
{
+ StartFirstRead();
+}
+
+// Start the first overlapped read
+void WinNamedPipeStream::StartFirstRead()
+{
// create the Readable event
mReadableEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
@@ -74,23 +84,50 @@ WinNamedPipeStream::WinNamedPipeStream(HANDLE hNamedPipe)
THROW_EXCEPTION(CommonException, Internal)
}
- // initialise the OVERLAPPED structure
+ StartOverlappedRead();
+}
+
+void WinNamedPipeStream::StartOverlappedRead()
+{
+ // We should only do this when the buffer is empty. We don't want
+ // to start an overlapped read anywhere else than the start of the
+ // buffer, because it could complete at any time and we don't want
+ // to mess about with interrupting the read already in progress.
+ ASSERT(mBytesInBuffer == 0);
+
+ // Initialise the OVERLAPPED structure
memset(&mReadOverlap, 0, sizeof(mReadOverlap));
mReadOverlap.hEvent = mReadableEvent;
- // start the first overlapped read
if (!ReadFile(mSocketHandle, mReadBuffer, sizeof(mReadBuffer),
NULL, &mReadOverlap))
{
DWORD err = GetLastError();
-
- if (err != ERROR_IO_PENDING)
+ if (err == ERROR_IO_PENDING)
+ {
+ // Don't reset yet, there might be data
+ // in the buffer waiting to be read,
+ // will check below.
+ // ResetEvent(mReadableEvent);
+ }
+ else if (err == ERROR_HANDLE_EOF)
+ {
+ BOX_INFO("Control client disconnected");
+ mReadClosed = true;
+ }
+ else if (err == ERROR_BROKEN_PIPE ||
+ err == ERROR_PIPE_NOT_CONNECTED)
+ {
+ BOX_NOTICE("Control client disconnected");
+ mReadClosed = true;
+ mIsConnected = false;
+ }
+ else
{
- BOX_ERROR("Failed to start overlapped read: " <<
- GetErrorMessage(err));
Close();
- THROW_EXCEPTION(ConnectionException,
- Conn_SocketReadError)
+ THROW_WIN_ERROR_NUMBER("Failed to start overlapped "
+ "read", err, ConnectionException,
+ SocketReadError)
}
}
}
@@ -105,6 +142,12 @@ WinNamedPipeStream::WinNamedPipeStream(HANDLE hNamedPipe)
// --------------------------------------------------------------------------
WinNamedPipeStream::~WinNamedPipeStream()
{
+ for(std::list<WriteInProgress*>::iterator i = mWritesInProgress.begin();
+ i != mWritesInProgress.end(); i++)
+ {
+ delete *i;
+ }
+
if (mSocketHandle != INVALID_HANDLE_VALUE)
{
try
@@ -157,36 +200,7 @@ void WinNamedPipeStream::Accept()
mIsServer = true; // must flush and disconnect before closing
mIsConnected = true;
- // create the Readable event
- mReadableEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
-
- if (mReadableEvent == INVALID_HANDLE_VALUE)
- {
- BOX_ERROR("Failed to create the Readable event: " <<
- GetErrorMessage(GetLastError()));
- Close();
- THROW_EXCEPTION(CommonException, Internal)
- }
-
- // initialise the OVERLAPPED structure
- memset(&mReadOverlap, 0, sizeof(mReadOverlap));
- mReadOverlap.hEvent = mReadableEvent;
-
- // start the first overlapped read
- if (!ReadFile(mSocketHandle, mReadBuffer, sizeof(mReadBuffer),
- NULL, &mReadOverlap))
- {
- DWORD err = GetLastError();
-
- if (err != ERROR_IO_PENDING)
- {
- BOX_ERROR("Failed to start overlapped read: " <<
- GetErrorMessage(err));
- Close();
- THROW_EXCEPTION(ConnectionException,
- Conn_SocketReadError)
- }
- }
+ StartFirstRead();
}
*/
@@ -214,7 +228,7 @@ void WinNamedPipeStream::Connect(const std::string& rName)
0, // no sharing
NULL, // default security attributes
OPEN_EXISTING,
- 0, // default attributes
+ 0, // FILE_FLAG_OVERLAPPED, // dwFlagsAndAttributes
NULL); // no template file
if (mSocketHandle == INVALID_HANDLE_VALUE)
@@ -237,6 +251,86 @@ void WinNamedPipeStream::Connect(const std::string& rName)
mWriteClosed = false;
mIsServer = false; // just close the socket
mIsConnected = true;
+
+ StartFirstRead();
+}
+
+// Returns true if the operation is complete (and you will need to start
+// another one), or false otherwise (you can wait again).
+bool WinNamedPipeStream::WaitForOverlappedOperation(OVERLAPPED& Overlapped,
+ int Timeout, int64_t* pBytesTransferred)
+{
+ if (Timeout == IOStream::TimeOutInfinite)
+ {
+ Timeout = INFINITE;
+ }
+
+ // overlapped I/O completed successfully? (wait if needed)
+ DWORD waitResult = WaitForSingleObject(Overlapped.hEvent, Timeout);
+ DWORD NumBytesTransferred = -1;
+
+ if (waitResult == WAIT_FAILED)
+ {
+ THROW_WIN_ERROR_NUMBER("Failed to wait for overlapped I/O",
+ GetLastError(), ServerException, Internal);
+ }
+
+ if (waitResult == WAIT_ABANDONED)
+ {
+ THROW_EXCEPTION_MESSAGE(ServerException, Internal,
+ "Wait for overlapped I/O abandoned by system");
+ }
+
+ if (waitResult == WAIT_TIMEOUT)
+ {
+ // wait timed out, nothing to read
+ *pBytesTransferred = 0;
+ return false;
+ }
+
+ if (waitResult != WAIT_OBJECT_0)
+ {
+ THROW_EXCEPTION_MESSAGE(ServerException, BadSocketHandle,
+ "Failed to wait for overlapped I/O: unknown "
+ "result code: " << waitResult);
+ }
+
+ // Overlapped operation completed successfully. Return the number
+ // of bytes transferred.
+ if (GetOverlappedResult(mSocketHandle, &Overlapped,
+ &NumBytesTransferred, TRUE))
+ {
+ *pBytesTransferred = NumBytesTransferred;
+ return true;
+ }
+
+ // We are here because GetOverlappedResult() informed us that the
+ // overlapped operation encountered an error, so what was it?
+ DWORD err = GetLastError();
+
+ if (err == ERROR_HANDLE_EOF)
+ {
+ Close();
+ *pBytesTransferred = 0;
+ return true;
+ }
+
+ // ERROR_NO_DATA is a strange name for
+ // "The pipe is being closed". No exception wanted.
+
+ if (err == ERROR_NO_DATA ||
+ err == ERROR_PIPE_NOT_CONNECTED ||
+ err == ERROR_BROKEN_PIPE)
+ {
+ BOX_INFO(BOX_WIN_ERRNO_MESSAGE(err,
+ "Named pipe peer disconnected"));
+ Close();
+ *pBytesTransferred = 0;
+ return true;
+ }
+
+ THROW_WIN_ERROR_NUMBER("Failed to wait for overlapped I/O "
+ "to complete", err, ConnectionException, SocketReadError);
}
// --------------------------------------------------------------------------
@@ -249,192 +343,61 @@ void WinNamedPipeStream::Connect(const std::string& rName)
// --------------------------------------------------------------------------
int WinNamedPipeStream::Read(void *pBuffer, int NBytes, int Timeout)
{
- // TODO no support for timeouts yet
- if (!mIsServer && Timeout != IOStream::TimeOutInfinite)
- {
- THROW_EXCEPTION(CommonException, AssertFailed)
- }
-
if (mSocketHandle == INVALID_HANDLE_VALUE || !mIsConnected)
{
- THROW_EXCEPTION(ServerException, BadSocketHandle)
+ THROW_EXCEPTION_MESSAGE(ServerException, BadSocketHandle,
+ "Tried to read from closed pipe");
}
if (mReadClosed)
{
- THROW_EXCEPTION(ConnectionException, SocketShutdownError)
+ THROW_EXCEPTION_MESSAGE(ConnectionException,
+ SocketShutdownError, "Tried to read from closing pipe");
}
// ensure safe to cast NBytes to unsigned
if (NBytes < 0)
{
- THROW_EXCEPTION(CommonException, AssertFailed)
+ THROW_EXCEPTION(CommonException, AssertFailed);
}
- DWORD NumBytesRead;
+ int64_t NumBytesRead;
- if (mIsServer)
+ // Satisfy from buffer if possible, to avoid blocking on read.
+ if (mBytesInBuffer == 0)
{
- // satisfy from buffer if possible, to avoid
- // blocking on read.
- bool needAnotherRead = false;
- if (mBytesInBuffer == 0)
- {
- // overlapped I/O completed successfully?
- // (wait if needed)
- DWORD waitResult = WaitForSingleObject(
- mReadOverlap.hEvent, Timeout);
-
- if (waitResult == WAIT_ABANDONED)
- {
- BOX_ERROR("Wait for command socket read "
- "abandoned by system");
- THROW_EXCEPTION(ServerException,
- BadSocketHandle);
- }
- else if (waitResult == WAIT_TIMEOUT)
- {
- // wait timed out, nothing to read
- NumBytesRead = 0;
- }
- else if (waitResult != WAIT_OBJECT_0)
- {
- BOX_ERROR("Failed to wait for command "
- "socket read: unknown result " <<
- waitResult);
- }
- // object is ready to read from
- else if (GetOverlappedResult(mSocketHandle,
- &mReadOverlap, &NumBytesRead, TRUE))
- {
- needAnotherRead = true;
- }
- else
- {
- DWORD err = GetLastError();
-
- if (err == ERROR_HANDLE_EOF)
- {
- mReadClosed = true;
- }
- else
- {
- if (err == ERROR_BROKEN_PIPE)
- {
- BOX_NOTICE("Control client "
- "disconnected");
- }
- else
- {
- BOX_ERROR("Failed to wait for "
- "ReadFile to complete: "
- << GetErrorMessage(err));
- }
-
- Close();
- THROW_EXCEPTION(ConnectionException,
- Conn_SocketReadError)
- }
- }
- }
- else
- {
- NumBytesRead = 0;
- }
-
- size_t BytesToCopy = NumBytesRead + mBytesInBuffer;
- size_t BytesRemaining = 0;
-
- if (BytesToCopy > (size_t)NBytes)
- {
- BytesRemaining = BytesToCopy - NBytes;
- BytesToCopy = NBytes;
- }
-
- memcpy(pBuffer, mReadBuffer, BytesToCopy);
- memmove(mReadBuffer, mReadBuffer + BytesToCopy, BytesRemaining);
-
- mBytesInBuffer = BytesRemaining;
- NumBytesRead = BytesToCopy;
-
- if (needAnotherRead)
+ if (mNeedAnotherRead)
{
- // reinitialise the OVERLAPPED structure
- memset(&mReadOverlap, 0, sizeof(mReadOverlap));
- mReadOverlap.hEvent = mReadableEvent;
+ // Start the next overlapped read
+ StartOverlappedRead();
}
- // start the next overlapped read
- if (needAnotherRead && !ReadFile(mSocketHandle,
- mReadBuffer + mBytesInBuffer,
- sizeof(mReadBuffer) - mBytesInBuffer,
- NULL, &mReadOverlap))
- {
- DWORD err = GetLastError();
- if (err == ERROR_IO_PENDING)
- {
- // Don't reset yet, there might be data
- // in the buffer waiting to be read,
- // will check below.
- // ResetEvent(mReadableEvent);
- }
- else if (err == ERROR_HANDLE_EOF)
- {
- mReadClosed = true;
- }
- else if (err == ERROR_BROKEN_PIPE)
- {
- BOX_ERROR("Control client disconnected");
- mReadClosed = true;
- }
- else
- {
- BOX_ERROR("Failed to start overlapped read: "
- << GetErrorMessage(err));
- Close();
- THROW_EXCEPTION(ConnectionException,
- Conn_SocketReadError)
- }
- }
+ mNeedAnotherRead = WaitForOverlappedOperation(mReadOverlap,
+ Timeout, &NumBytesRead);
}
else
{
- if (!ReadFile(
- mSocketHandle, // pipe handle
- pBuffer, // buffer to receive reply
- NBytes, // size of buffer
- &NumBytesRead, // number of bytes read
- NULL)) // not overlapped
- {
- DWORD err = GetLastError();
-
- Close();
+ // Just return the existing data from the buffer
+ // this time around. The caller should call again,
+ // and then the buffer will be empty.
+ NumBytesRead = 0;
+ }
- // ERROR_NO_DATA is a strange name for
- // "The pipe is being closed". No exception wanted.
-
- if (err == ERROR_NO_DATA ||
- err == ERROR_PIPE_NOT_CONNECTED)
- {
- NumBytesRead = 0;
- }
- else
- {
- BOX_ERROR("Failed to read from control socket: "
- << GetErrorMessage(err));
- THROW_EXCEPTION(ConnectionException,
- Conn_SocketReadError)
- }
- }
-
- // Closed for reading at EOF?
- if (NumBytesRead == 0)
- {
- mReadClosed = true;
- }
+ int BytesToCopy = NumBytesRead + mBytesInBuffer;
+
+ if (NBytes < BytesToCopy)
+ {
+ BytesToCopy = NBytes;
}
-
- return NumBytesRead;
+
+ memcpy(pBuffer, mReadBuffer, BytesToCopy);
+
+ size_t BytesRemaining = mBytesInBuffer + NumBytesRead - BytesToCopy;
+ ASSERT(BytesToCopy + BytesRemaining <= sizeof(mReadBuffer));
+ memmove(mReadBuffer, mReadBuffer + BytesToCopy, BytesRemaining);
+ mBytesInBuffer = BytesRemaining;
+
+ return BytesToCopy;
}
// --------------------------------------------------------------------------
@@ -445,8 +408,15 @@ int WinNamedPipeStream::Read(void *pBuffer, int NBytes, int Timeout)
// Created: 2003/07/31
//
// --------------------------------------------------------------------------
-void WinNamedPipeStream::Write(const void *pBuffer, int NBytes)
+void WinNamedPipeStream::Write(const void *pBuffer, int NBytes, int Timeout)
{
+ // Calculate the deadline at the beginning. Not valid if Timeout is
+ // IOStream::TimeOutInfinite!
+ ASSERT(Timeout != IOStream::TimeOutInfinite);
+
+ box_time_t deadline = GetCurrentBoxTime() +
+ MilliSecondsToBoxTime(Timeout);
+
if (mSocketHandle == INVALID_HANDLE_VALUE || !mIsConnected)
{
THROW_EXCEPTION(ServerException, BadSocketHandle)
@@ -454,41 +424,62 @@ void WinNamedPipeStream::Write(const void *pBuffer, int NBytes)
// Buffer in byte sized type.
ASSERT(sizeof(char) == 1);
- const char *pByteBuffer = (char *)pBuffer;
-
- int NumBytesWrittenTotal = 0;
-
- while (NumBytesWrittenTotal < NBytes)
+ WriteInProgress* new_write = new WriteInProgress(
+ std::string((char *)pBuffer, NBytes));
+
+ // Start the WriteFile operation, and add to queue if pending.
+ BOOL Success = WriteFile(
+ mSocketHandle, // pipe handle
+ new_write->mBuffer.c_str(), // message
+ NBytes, // message length
+ NULL, // bytes written this time
+ &(new_write->mOverlap));
+
+ if (Success == TRUE)
{
- DWORD NumBytesWrittenThisTime = 0;
-
- bool Success = WriteFile(
- mSocketHandle, // pipe handle
- pByteBuffer + NumBytesWrittenTotal, // message
- NBytes - NumBytesWrittenTotal, // message length
- &NumBytesWrittenThisTime, // bytes written this time
- NULL); // not overlapped
+ // Unfortunately this does happen. We should still call
+ // GetOverlappedResult() to get the number of bytes written,
+ // so we can treat it just the same.
+ // BOX_NOTICE("Write claimed success while overlapped?");
+ mWritesInProgress.push_back(new_write);
+ }
+ else
+ {
+ DWORD err = GetLastError();
- if (!Success)
+ if (err == ERROR_IO_PENDING)
{
- // ERROR_NO_DATA is a strange name for
- // "The pipe is being closed".
-
- DWORD err = GetLastError();
-
- if (err != ERROR_NO_DATA)
- {
- BOX_ERROR("Failed to write to control "
- "socket: " << GetErrorMessage(err));
- }
-
+ BOX_TRACE("WriteFile is pending, adding to queue");
+ mWritesInProgress.push_back(new_write);
+ }
+ else
+ {
+ // Not in progress any more, pop it
Close();
-
- THROW_EXCEPTION(ConnectionException,
- Conn_SocketWriteError)
+ THROW_WIN_ERROR_NUMBER("Failed to start overlapped "
+ "write", err, ConnectionException,
+ SocketWriteError);
}
+ }
+
+ // Wait for previous WriteFile operations to complete, one at a time,
+ // until the deadline expires or the pipe becomes disconnected.
+ for(box_time_t remaining = deadline - GetCurrentBoxTime();
+ remaining > 0 && !mWritesInProgress.empty() && mIsConnected;
+ remaining = deadline - GetCurrentBoxTime())
+ {
+ int new_timeout = BoxTimeToMilliSeconds(remaining);
+ WriteInProgress* oldest_write =
+ *(mWritesInProgress.begin());
- NumBytesWrittenTotal += NumBytesWrittenThisTime;
+ int64_t bytes_written = 0;
+ if(WaitForOverlappedOperation(oldest_write->mOverlap,
+ new_timeout, &bytes_written))
+ {
+ // This one is complete, pop it and start a new one
+ delete oldest_write;
+ mWritesInProgress.pop_front();
+ }
}
}
@@ -513,59 +504,47 @@ void WinNamedPipeStream::Close()
THROW_EXCEPTION(ServerException, BadSocketHandle)
}
- if (mIsServer)
+ if (!CancelIo(mSocketHandle))
{
- if (!CancelIo(mSocketHandle))
- {
- BOX_ERROR("Failed to cancel outstanding I/O: " <<
- GetErrorMessage(GetLastError()));
- }
+ BOX_LOG_WIN_ERROR("Failed to cancel outstanding I/O");
+ }
- if (mReadableEvent == INVALID_HANDLE_VALUE)
- {
- BOX_ERROR("Failed to destroy Readable event: "
- "invalid handle");
- }
- else if (!CloseHandle(mReadableEvent))
- {
- BOX_ERROR("Failed to destroy Readable event: " <<
- GetErrorMessage(GetLastError()));
- }
+ if (mReadableEvent == INVALID_HANDLE_VALUE)
+ {
+ BOX_ERROR("Failed to destroy Readable event: "
+ "invalid handle");
+ }
+ else if (!CloseHandle(mReadableEvent))
+ {
+ BOX_LOG_WIN_ERROR("Failed to destroy Readable event");
+ }
- mReadableEvent = INVALID_HANDLE_VALUE;
+ mReadableEvent = INVALID_HANDLE_VALUE;
- if (!FlushFileBuffers(mSocketHandle))
- {
- BOX_ERROR("Failed to FlushFileBuffers: " <<
- GetErrorMessage(GetLastError()));
- }
-
- if (!DisconnectNamedPipe(mSocketHandle))
+ if (mIsConnected && !FlushFileBuffers(mSocketHandle))
+ {
+ BOX_LOG_WIN_ERROR("Failed to FlushFileBuffers");
+ }
+
+ if (mIsServer && mIsConnected && !DisconnectNamedPipe(mSocketHandle))
+ {
+ DWORD err = GetLastError();
+ if (err != ERROR_PIPE_NOT_CONNECTED)
{
- DWORD err = GetLastError();
- if (err != ERROR_PIPE_NOT_CONNECTED)
- {
- BOX_ERROR("Failed to DisconnectNamedPipe: " <<
- GetErrorMessage(err));
- }
+ BOX_LOG_WIN_ERROR("Failed to DisconnectNamedPipe");
}
-
- mIsServer = false;
}
- bool result = CloseHandle(mSocketHandle);
+ if (!CloseHandle(mSocketHandle))
+ {
+ THROW_WIN_ERROR_NUMBER("Failed to CloseHandle",
+ GetLastError(), ServerException, SocketCloseError);
+ }
mSocketHandle = INVALID_HANDLE_VALUE;
mIsConnected = false;
mReadClosed = true;
mWriteClosed = true;
-
- if (!result)
- {
- BOX_ERROR("Failed to CloseHandle: " <<
- GetErrorMessage(GetLastError()));
- THROW_EXCEPTION(ServerException, SocketCloseError)
- }
}
// --------------------------------------------------------------------------
diff --git a/lib/server/WinNamedPipeStream.h b/lib/server/WinNamedPipeStream.h
index 386ff7e3..5473c690 100644
--- a/lib/server/WinNamedPipeStream.h
+++ b/lib/server/WinNamedPipeStream.h
@@ -10,6 +10,8 @@
#if ! defined WINNAMEDPIPESTREAM__H && defined WIN32
#define WINNAMEDPIPESTREAM__H
+#include <list>
+
#include "IOStream.h"
// --------------------------------------------------------------------------
@@ -36,15 +38,27 @@ public:
// both sides
virtual int Read(void *pBuffer, int NBytes,
int Timeout = IOStream::TimeOutInfinite);
- virtual void Write(const void *pBuffer, int NBytes);
+ virtual void Write(const void *pBuffer, int NBytes,
+ int Timeout = IOStream::TimeOutInfinite);
virtual void WriteAllBuffered();
virtual void Close();
virtual bool StreamDataLeft();
virtual bool StreamClosed();
+ // Why not inherited from IOStream? Never mind, we want to enforce
+ // supplying a timeout for network operations anyway.
+ virtual void Write(const std::string& rBuffer, int Timeout)
+ {
+ IOStream::Write(rBuffer, Timeout);
+ }
+
protected:
void MarkAsReadClosed() {mReadClosed = true;}
void MarkAsWriteClosed() {mWriteClosed = true;}
+ bool WaitForOverlappedOperation(OVERLAPPED& Overlapped,
+ int Timeout, int64_t* pBytesTransferred);
+ void StartFirstRead();
+ void StartOverlappedRead();
private:
WinNamedPipeStream(const WinNamedPipeStream &rToCopy)
@@ -59,6 +73,37 @@ private:
bool mWriteClosed;
bool mIsServer;
bool mIsConnected;
+ bool mNeedAnotherRead;
+
+ class WriteInProgress {
+ private:
+ friend class WinNamedPipeStream;
+ std::string mBuffer;
+ OVERLAPPED mOverlap;
+ WriteInProgress(const WriteInProgress& other); // do not call
+ public:
+ WriteInProgress(const std::string& dataToWrite)
+ : mBuffer(dataToWrite)
+ {
+ // create the Writable event
+ HANDLE writable_event = CreateEvent(NULL, TRUE, FALSE,
+ NULL);
+ if (writable_event == INVALID_HANDLE_VALUE)
+ {
+ BOX_LOG_WIN_ERROR("Failed to create the "
+ "Writable event");
+ THROW_EXCEPTION(CommonException, Internal)
+ }
+
+ memset(&mOverlap, 0, sizeof(mOverlap));
+ mOverlap.hEvent = writable_event;
+ }
+ ~WriteInProgress()
+ {
+ CloseHandle(mOverlap.hEvent);
+ }
+ };
+ std::list<WriteInProgress*> mWritesInProgress;
public:
static std::string sPipeNamePrefix;
diff --git a/lib/server/makeprotocol.pl.in b/lib/server/makeprotocol.pl.in
index a074b435..d6c0e216 100755
--- a/lib/server/makeprotocol.pl.in
+++ b/lib/server/makeprotocol.pl.in
@@ -78,7 +78,7 @@ sub add_type
my ($protocol_name, $cpp_name, $header_file) = split /\s+/,$_[0];
$translate_type_info{$protocol_name} = [0, $cpp_name];
- push @extra_header_files, $header_file;
+ push @extra_header_files, $header_file if $header_file;
}
# check attributes
@@ -158,7 +158,10 @@ print CPP <<__E;
#include <sstream>
#include "$filename_base.h"
-#include "IOStream.h"
+#include "CollectInBufferStream.h"
+#include "MemBlockStream.h"
+#include "SelfFlushingStream.h"
+#include "SocketStream.h"
__E
print H <<__E;
@@ -174,12 +177,10 @@ print H <<__E;
#include <syslog.h>
#endif
+#include "autogen_ConnectionException.h"
#include "Protocol.h"
#include "Message.h"
-#include "ServerException.h"
-
-class IOStream;
-
+#include "SocketStream.h"
__E
@@ -210,19 +211,26 @@ __E
my $request_base_class = "${protocol_name}ProtocolRequest";
my $reply_base_class = "${protocol_name}ProtocolReply";
# the abstract protocol interface
-my $protocol_base_class = $protocol_name."ProtocolBase";
+my $custom_protocol_subclass = $protocol_name."Protocol";
+my $client_server_base_class = $protocol_name."ProtocolClientServer";
my $replyable_base_class = $protocol_name."ProtocolReplyable";
+my $callable_base_class = $protocol_name."ProtocolCallable";
+my $send_receive_class = $protocol_name."ProtocolSendReceive";
print H <<__E;
-class $protocol_base_class;
+class $custom_protocol_subclass;
+class $client_server_base_class;
+class $callable_base_class;
class $replyable_base_class;
-class $reply_base_class;
class $message_base_class : public Message
{
public:
virtual std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
$context_class &rContext) const;
+ virtual std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const;
+ virtual bool HasStreamWithCommand() const = 0;
};
class $reply_base_class
@@ -233,17 +241,44 @@ class $request_base_class
{
};
+class $send_receive_class {
+public:
+ virtual void Send(const $message_base_class &rObject) = 0;
+ virtual std::auto_ptr<$message_base_class> Receive() = 0;
+};
+
+class $custom_protocol_subclass : public Protocol
+{
+public:
+ $custom_protocol_subclass(std::auto_ptr<SocketStream> apConn)
+ : Protocol(apConn)
+ { }
+ virtual ~$custom_protocol_subclass() { }
+ virtual std::auto_ptr<Message> MakeMessage(int ObjType);
+ virtual const char *GetProtocolIdentString();
+
+private:
+ $custom_protocol_subclass(const $custom_protocol_subclass &rToCopy);
+};
+
__E
print CPP <<__E;
std::auto_ptr<$message_base_class> $message_base_class\::DoCommand($replyable_base_class &rProtocol,
$context_class &rContext) const
{
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_TriedToExecuteReplyCommand)
+ THROW_EXCEPTION(ConnectionException, Protocol_TriedToExecuteReplyCommand)
+}
+
+std::auto_ptr<$message_base_class> $message_base_class\::DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const
+{
+ THROW_EXCEPTION(ConnectionException, Protocol_TriedToExecuteReplyCommand)
}
__E
-my %cmd_class;
+my %cmd_classes;
+my $error_message = undef;
# output the classes
foreach my $cmd (@cmd_list)
@@ -262,7 +297,7 @@ foreach my $cmd (@cmd_list)
my $cmd_base_class = join(", ", map {"public $_"} @cmd_base_classes);
my $cmd_class = $protocol_name."Protocol".$cmd;
- $cmd_class{$cmd} = $cmd_class;
+ $cmd_classes{$cmd} = $cmd_class;
print H <<__E;
class $cmd_class : $cmd_base_class
@@ -294,18 +329,50 @@ __E
if(obj_is_type($cmd,'IsError'))
{
- print H "\tbool IsError(int &rTypeOut, int &rSubTypeOut) const;\n";
- print H "\tstd::string GetMessage() const;\n";
+ $error_message = $cmd;
+ my ($mem_type,$mem_subtype) = split /,/,obj_get_type_params($cmd,'IsError');
+ my $error_type = $cmd_constants{"ErrorType"};
+ print H <<__E;
+ $cmd_class(int SubType) : m$mem_type($error_type), m$mem_subtype(SubType) { }
+ bool IsError(int &rTypeOut, int &rSubTypeOut) const;
+ std::string GetMessage() const { return GetMessage(m$mem_subtype); };
+ static std::string GetMessage(int subtype);
+__E
}
- if(obj_is_type($cmd, 'Command'))
+ my $has_stream = obj_is_type($cmd, 'StreamWithCommand');
+
+ if(obj_is_type($cmd, 'Command') && $has_stream)
+ {
+ print H <<__E;
+ std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const; // IMPLEMENT THIS\n
+ std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext) const
+ {
+ THROW_EXCEPTION_MESSAGE(CommonException, Internal,
+ "This command requires a stream parameter");
+ }
+__E
+ }
+ elsif(obj_is_type($cmd, 'Command') && !$has_stream)
{
print H <<__E;
std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
$context_class &rContext) const; // IMPLEMENT THIS\n
+ std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol,
+ $context_class &rContext, IOStream& rDataStream) const
+ {
+ THROW_EXCEPTION_MESSAGE(CommonException, NotSupported,
+ "This command requires no stream parameter");
+ }
__E
}
+ print H <<__E;
+ bool HasStreamWithCommand() const { return $has_stream; }
+__E
+
# want to be able to read from streams?
print H "\tvoid SetPropertiesFromStreamData(Protocol &rProtocol);\n";
@@ -442,9 +509,9 @@ bool $cmd_class\::IsError(int &rTypeOut, int &rSubTypeOut) const
rSubTypeOut = m$mem_subtype;
return true;
}
-std::string $cmd_class\::GetMessage() const
+std::string $cmd_class\::GetMessage(int subtype)
{
- switch(m$mem_subtype)
+ switch(subtype)
{
__E
foreach my $const (@{$cmd_constants{$cmd}})
@@ -459,7 +526,7 @@ __E
print CPP <<__E;
default:
std::ostringstream out;
- out << "Unknown subtype " << m$mem_subtype;
+ out << "Unknown subtype " << subtype;
return out.str();
}
}
@@ -505,47 +572,44 @@ my $error_class = $protocol_name."ProtocolError";
# the abstract protocol interface
print H <<__E;
-class $protocol_base_class
+
+class $client_server_base_class
{
public:
- $protocol_base_class();
- virtual ~$protocol_base_class();
- virtual const char *GetIdentString();
+ $client_server_base_class();
+ virtual ~$client_server_base_class();
+ virtual std::auto_ptr<IOStream> ReceiveStream() = 0;
bool GetLastError(int &rTypeOut, int &rSubTypeOut);
+ int GetLastErrorType() { return mLastErrorSubType; }
protected:
- void CheckReply(const std::string& requestCommand,
- const $message_base_class &rReply, int expectedType);
void SetLastError(int Type, int SubType)
{
mLastErrorType = Type;
mLastErrorSubType = SubType;
}
+ std::string mPreviousCommand;
+ std::string mPreviousReply;
private:
- $protocol_base_class(const $protocol_base_class &rToCopy); /* do not call */
+ $client_server_base_class(const $client_server_base_class &rToCopy); /* do not call */
int mLastErrorType;
int mLastErrorSubType;
};
-class $replyable_base_class : public virtual $protocol_base_class
+class $replyable_base_class : public virtual $client_server_base_class
{
public:
- $replyable_base_class();
+ $replyable_base_class() { }
virtual ~$replyable_base_class();
- /*
- virtual std::auto_ptr<$message_base_class> Receive() = 0;
- virtual void Send(const ${message_base_class} &rObject) = 0;
- */
-
- virtual std::auto_ptr<IOStream> ReceiveStream() = 0;
virtual int GetTimeout() = 0;
void SendStreamAfterCommand(std::auto_ptr<IOStream> apStream);
-
+
protected:
std::list<IOStream*> mStreamsToSend;
void DeleteStreamsToSend();
+ virtual std::auto_ptr<$message_base_class> HandleException(BoxException& e) const;
private:
$replyable_base_class(const $replyable_base_class &rToCopy); /* do not call */
@@ -554,24 +618,47 @@ private:
__E
print CPP <<__E;
-$protocol_base_class\::$protocol_base_class()
+$client_server_base_class\::$client_server_base_class()
: mLastErrorType(Protocol::NoError),
mLastErrorSubType(Protocol::NoError)
{ }
-$protocol_base_class\::~$protocol_base_class()
+$client_server_base_class\::~$client_server_base_class()
{ }
-const char *$protocol_base_class\::GetIdentString()
+const char *$custom_protocol_subclass\::GetProtocolIdentString()
{
return "$ident_string";
}
-$replyable_base_class\::$replyable_base_class()
-{ }
+std::auto_ptr<Message> $custom_protocol_subclass\::MakeMessage(int ObjType)
+{
+ switch(ObjType)
+ {
+__E
+
+# do objects within this
+for my $cmd (@cmd_list)
+{
+ print CPP <<__E;
+ case $cmd_id{$cmd}:
+ return std::auto_ptr<Message>(new $cmd_classes{$cmd}());
+ break;
+__E
+}
+
+print CPP <<__E;
+ default:
+ THROW_EXCEPTION(ConnectionException, Protocol_UnknownCommandRecieved)
+ }
+}
$replyable_base_class\::~$replyable_base_class()
-{ }
+{
+ // If there were any streams left over, there's no longer any way to
+ // access them, and we're responsible for them, so we'd better delete them.
+ DeleteStreamsToSend();
+}
void $replyable_base_class\::SendStreamAfterCommand(std::auto_ptr<IOStream> apStream)
{
@@ -589,12 +676,14 @@ void $replyable_base_class\::DeleteStreamsToSend()
mStreamsToSend.clear();
}
-void $protocol_base_class\::CheckReply(const std::string& requestCommand,
- const $message_base_class &rReply, int expectedType)
+void $callable_base_class\::CheckReply(const std::string& requestCommandName,
+ const $message_base_class &rCommand, const $message_base_class &rReply,
+ int expectedType)
{
if(rReply.GetType() == expectedType)
{
// Correct response, do nothing
+ SetLastError(Protocol::NoError, Protocol::NoError);
}
else
{
@@ -605,8 +694,8 @@ void $protocol_base_class\::CheckReply(const std::string& requestCommand,
{
SetLastError(type, subType);
THROW_EXCEPTION_MESSAGE(ConnectionException,
- Conn_Protocol_UnexpectedReply,
- requestCommand << " command failed: "
+ Protocol_UnexpectedReply,
+ requestCommandName << " command failed: "
"received error " <<
(($error_class&)rReply).GetMessage());
}
@@ -614,12 +703,18 @@ void $protocol_base_class\::CheckReply(const std::string& requestCommand,
{
SetLastError(Protocol::UnknownError, Protocol::UnknownError);
THROW_EXCEPTION_MESSAGE(ConnectionException,
- Conn_Protocol_UnexpectedReply,
- requestCommand << " command failed: "
+ Protocol_UnexpectedReply,
+ requestCommandName << " command failed: "
"received unexpected response type " <<
rReply.GetType());
}
}
+
+ // As a client, if we get an unexpected reply later, we'll want to know
+ // the last command that we executed, and the reply, to help debug the
+ // server.
+ mPreviousCommand = rCommand.ToString();
+ mPreviousReply = rReply.ToString();
}
// --------------------------------------------------------------------------
@@ -630,7 +725,7 @@ void $protocol_base_class\::CheckReply(const std::string& requestCommand,
// Created: 2003/08/19
//
// --------------------------------------------------------------------------
-bool $protocol_base_class\::GetLastError(int &rTypeOut, int &rSubTypeOut)
+bool $client_server_base_class\::GetLastError(int &rTypeOut, int &rSubTypeOut)
{
if(mLastErrorType == Protocol::NoError)
{
@@ -653,13 +748,19 @@ __E
# the callable protocol interface (implemented by Client and Local classes)
# with Query methods that don't take a context parameter
-my $callable_base_class = $protocol_name."ProtocolCallable";
print H <<__E;
-class $callable_base_class : public virtual $protocol_base_class
+class $callable_base_class : public virtual $client_server_base_class,
+ public $send_receive_class
{
public:
- virtual std::auto_ptr<IOStream> ReceiveStream() = 0;
virtual int GetTimeout() = 0;
+
+protected:
+ void CheckReply(const std::string& requestCommandName,
+ const $message_base_class &rCommand,
+ const $message_base_class &rReply, int expectedType);
+
+public:
__E
# add plain object taking query functions
@@ -671,8 +772,8 @@ for my $cmd (@cmd_list)
my $has_stream = obj_is_type($cmd,'StreamWithCommand');
my $argextra = $has_stream?', std::auto_ptr<IOStream> apStream':'';
my $queryextra = $has_stream?', apStream':'';
- my $request_class = $cmd_class{$cmd};
- my $reply_class = $cmd_class{obj_get_type_params($cmd,'Command')};
+ my $request_class = $cmd_classes{$cmd};
+ my $reply_class = $cmd_classes{obj_get_type_params($cmd,'Command')};
print H "\tvirtual std::auto_ptr<$reply_class> Query(const $request_class &rQuery$argextra) = 0;\n";
my @a;
@@ -720,13 +821,15 @@ foreach my $type ('Client', 'Server', 'Local')
{
push @base_classes, $replyable_base_class;
}
+
if (not $writing_server)
{
push @base_classes, $callable_base_class;
}
+
if (not $writing_local)
{
- push @base_classes, "Protocol";
+ push @base_classes, $custom_protocol_subclass;
}
my $base_classes_str = join(", ", map {"public $_"} @base_classes);
@@ -735,6 +838,7 @@ foreach my $type ('Client', 'Server', 'Local')
class $server_or_client_class : $base_classes_str
{
public:
+ virtual ~$server_or_client_class();
__E
if($writing_local)
@@ -743,18 +847,12 @@ __E
$server_or_client_class($context_class &rContext);
__E
}
- else
- {
- print H <<__E;
- $server_or_client_class(IOStream &rStream);
+
+ print H <<__E;
+ $server_or_client_class(std::auto_ptr<SocketStream> apConn);
std::auto_ptr<$message_base_class> Receive();
void Send(const $message_base_class &rObject);
__E
- }
-
- print H <<__E;
- virtual ~$server_or_client_class();
-__E
if($writing_server)
{
@@ -775,37 +873,29 @@ __E
my $has_stream = obj_is_type($cmd,'StreamWithCommand');
my $argextra = $has_stream?', std::auto_ptr<IOStream> apStream':'';
my $queryextra = $has_stream?', apStream':'';
- my $request_class = $cmd_class{$cmd};
- my $reply_class = $cmd_class{obj_get_type_params($cmd,'Command')};
+ my $request_class = $cmd_classes{$cmd};
+ my $reply_class = $cmd_classes{obj_get_type_params($cmd,'Command')};
print H "\tstd::auto_ptr<$reply_class> Query(const $request_class &rQuery$argextra);\n";
}
}
}
-
+
if($writing_local)
{
print H <<__E;
private:
$context_class &mrContext;
-__E
- }
-
- print H <<__E;
-
-protected:
- virtual std::auto_ptr<Message> MakeMessage(int ObjType);
-
-__E
-
- if($writing_local)
- {
- print H <<__E;
- virtual void InformStreamReceiving(u_int32_t Size) { }
- virtual void InformStreamSending(u_int32_t Size) { }
-
+ std::auto_ptr<$message_base_class> mapLastReply;
public:
virtual std::auto_ptr<IOStream> ReceiveStream()
{
+ if(mStreamsToSend.empty())
+ {
+ THROW_EXCEPTION_MESSAGE(CommonException, Internal,
+ "Tried to ReceiveStream when none was sent or "
+ "made available");
+ }
+
std::auto_ptr<IOStream> apStream(mStreamsToSend.front());
mStreamsToSend.pop_front();
return apStream;
@@ -815,29 +905,33 @@ __E
else
{
print H <<__E;
- virtual void InformStreamReceiving(u_int32_t Size)
- {
- this->Protocol::InformStreamReceiving(Size);
- }
- virtual void InformStreamSending(u_int32_t Size)
- {
- this->Protocol::InformStreamSending(Size);
- }
+ virtual std::auto_ptr<IOStream> ReceiveStream();
+__E
-public:
- virtual std::auto_ptr<IOStream> ReceiveStream()
+ print CPP <<__E;
+std::auto_ptr<IOStream> $server_or_client_class\::ReceiveStream()
+{
+ try
{
- return this->Protocol::ReceiveStream();
- }
-__E
+ return $custom_protocol_subclass\::ReceiveStream();
}
-
- print H <<__E;
- virtual const char *GetProtocolIdentString()
+ catch(ConnectionException &e)
{
- return GetIdentString();
+ if(e.GetSubType() == ConnectionException::Protocol_ObjWhenStreamExpected)
+ {
+ THROW_EXCEPTION_MESSAGE(ConnectionException,
+ Protocol_ObjWhenStreamExpected,
+ "Last exchange was " << mPreviousCommand <<
+ " => " << mPreviousReply);
+ }
+ else
+ {
+ throw;
+ }
}
+}
__E
+ }
if($writing_local)
{
@@ -853,23 +947,13 @@ __E
print H <<__E;
virtual int GetTimeout()
{
- return this->Protocol::GetTimeout();
+ return $custom_protocol_subclass\::GetTimeout();
}
__E
}
-
+
print H <<__E;
- /*
- virtual void Handshake()
- {
- this->Protocol::Handshake();
- }
- virtual bool GetLastError(int &rTypeOut, int &rSubTypeOut)
- {
- return this->Protocol::GetLastError(rTypeOut, rSubTypeOut);
- }
- */
-
+
private:
$server_or_client_class(const $server_or_client_class &rToCopy); /* no copies */
};
@@ -890,8 +974,8 @@ __E
else
{
print CPP <<__E;
-$server_or_client_class\::$server_or_client_class(IOStream &rStream)
-: Protocol(rStream)
+$server_or_client_class\::$server_or_client_class(std::auto_ptr<SocketStream> apConn)
+: $custom_protocol_subclass(apConn)
{ }
__E
}
@@ -903,49 +987,58 @@ $server_or_client_class\::~$server_or_client_class()
__E
# write receive and send functions
- print CPP <<__E;
-std::auto_ptr<Message> $server_or_client_class\::MakeMessage(int ObjType)
-{
- switch(ObjType)
- {
-__E
-
- # do objects within this
- for my $cmd (@cmd_list)
+ if($writing_local)
{
print CPP <<__E;
- case $cmd_id{$cmd}:
- return std::auto_ptr<Message>(new $cmd_class{$cmd}());
- break;
-__E
- }
-
- print CPP <<__E;
- default:
- THROW_EXCEPTION(ConnectionException, Conn_Protocol_UnknownCommandRecieved)
- }
+std::auto_ptr<$message_base_class> $server_or_client_class\::Receive()
+{
+ return mapLastReply;
+}
+void $server_or_client_class\::Send(const $message_base_class &rObject)
+{
+ mapLastReply = rObject.DoCommand(*this, mrContext);
}
__E
-
- if(not $writing_local)
+ }
+ else
{
print CPP <<__E;
std::auto_ptr<$message_base_class> $server_or_client_class\::Receive()
{
- std::auto_ptr<$message_base_class> preply(($message_base_class *)
- Protocol::ReceiveInternal().release());
+ std::auto_ptr<$message_base_class> apReply;
+
+ try
+ {
+ apReply = std::auto_ptr<$message_base_class>(
+ static_cast<$message_base_class *>
+ ($custom_protocol_subclass\::ReceiveInternal().release()));
+ }
+ catch(ConnectionException &e)
+ {
+ if(e.GetSubType() == ConnectionException::Protocol_StreamWhenObjExpected)
+ {
+ THROW_EXCEPTION_MESSAGE(ConnectionException,
+ Protocol_StreamWhenObjExpected,
+ "Last exchange was " << mPreviousCommand <<
+ " => " << mPreviousReply);
+ }
+ else
+ {
+ throw;
+ }
+ }
if(GetLogToSysLog())
{
- preply->LogSysLog("Receive");
+ apReply->LogSysLog("Receive");
}
if(GetLogToFile() != 0)
{
- preply->LogFile("Receive", GetLogToFile());
+ apReply->LogFile("Receive", GetLogToFile());
}
- return preply;
+ return apReply;
}
void $server_or_client_class\::Send(const $message_base_class &rObject)
@@ -981,10 +1074,40 @@ void $server_or_client_class\::DoServer($context_class &rContext)
{
// Get an object from the conversation
std::auto_ptr<$message_base_class> pobj = Receive();
+ std::auto_ptr<$message_base_class> preply;
// Run the command
- std::auto_ptr<$message_base_class> preply = pobj->DoCommand(*this, rContext);
-
+ try
+ {
+ try
+ {
+ if(pobj->HasStreamWithCommand())
+ {
+ std::auto_ptr<IOStream> apDataStream = ReceiveStream();
+ SelfFlushingStream autoflush(*apDataStream);
+ preply = pobj->DoCommand(*this, rContext, *apDataStream);
+ }
+ else
+ {
+ preply = pobj->DoCommand(*this, rContext);
+ }
+ }
+ catch(BoxException &e)
+ {
+ // First try a the built-in exception handler
+ preply = HandleException(e);
+ }
+ }
+ catch (...)
+ {
+ // Fallback in case the exception isn't a BoxException
+ // or the exception handler fails as well. This path
+ // throws the exception upwards, killing the process
+ // that handles the current client.
+ Send($cmd_classes{$error_message}(-1));
+ throw;
+ }
+
// Send the reply
Send(*preply);
@@ -995,7 +1118,16 @@ void $server_or_client_class\::DoServer($context_class &rContext)
{
SendStream(**i);
}
-
+
+ // As a server, if we get an unexpected message later, we'll
+ // want to know the last command that we received, and the
+ // reply, to help debug our response to it.
+ mPreviousCommand = pobj->ToString();
+ std::ostringstream reply;
+ reply << preply->ToString() << " and " <<
+ mStreamsToSend.size() << " streams";
+ mPreviousReply = reply.str();
+
// Delete these streams
DeleteStreamsToSend();
@@ -1004,7 +1136,7 @@ void $server_or_client_class\::DoServer($context_class &rContext)
{
inProgress = false;
}
- }
+ }
}
__E
@@ -1017,67 +1149,86 @@ __E
{
if(obj_is_type($cmd,'Command'))
{
- my $request_class = $cmd_class{$cmd};
+ my $request_class = $cmd_classes{$cmd};
my $reply_msg = obj_get_type_params($cmd,'Command');
- my $reply_class = $cmd_class{$reply_msg};
+ my $reply_class = $cmd_classes{$reply_msg};
my $reply_id = $cmd_id{$reply_msg};
my $has_stream = obj_is_type($cmd,'StreamWithCommand');
- my $argextra = $has_stream?', std::auto_ptr<IOStream> apStream':'';
+ my $argextra = $has_stream?', std::auto_ptr<IOStream> apDataStream':'';
my $send_stream_extra = '';
- my $send_stream_method = $writing_client ? "SendStream"
- : "SendStreamAfterCommand";
-
+
+ print CPP <<__E;
+std::auto_ptr<$reply_class> $server_or_client_class\::Query(const $request_class &rQuery$argextra)
+{
+__E
+
if($writing_client)
{
if($has_stream)
{
$send_stream_extra = <<__E;
// Send stream after the command
- SendStream(*apStream);
+ try
+ {
+ SendStream(*apDataStream);
+ }
+ catch (BoxException &e)
+ {
+ BOX_WARNING("Failed to send stream after command: " <<
+ rQuery.ToString() << ": " << e.what());
+ throw;
+ }
__E
}
print CPP <<__E;
-std::auto_ptr<$reply_class> $server_or_client_class\::Query(const $request_class &rQuery$argextra)
-{
// Send query
Send(rQuery);
- $send_stream_extra
+$send_stream_extra
// Wait for the reply
- std::auto_ptr<$message_base_class> preply = Receive();
-
- CheckReply("$cmd", *preply, $reply_id);
-
- // Correct response, if no exception thrown by CheckReply
- return std::auto_ptr<$reply_class>(($reply_class *)preply.release());
-}
+ std::auto_ptr<$message_base_class> apReply = Receive();
__E
}
elsif($writing_local)
{
+ print CPP <<__E;
+ std::auto_ptr<$message_base_class> apReply;
+ try
+ {
+__E
if($has_stream)
{
- $send_stream_extra = <<__E;
- // Send stream after the command
- SendStreamAfterCommand(apStream);
+ print CPP <<__E;
+ apReply = rQuery.DoCommand(*this, mrContext, *apDataStream);
+__E
+ }
+ else
+ {
+ print CPP <<__E;
+ apReply = rQuery.DoCommand(*this, mrContext);
__E
}
print CPP <<__E;
-std::auto_ptr<$reply_class> $server_or_client_class\::Query(const $request_class &rQuery$argextra)
-{
- // Send query
- $send_stream_extra
- std::auto_ptr<$message_base_class> preply = rQuery.DoCommand(*this, mrContext);
-
- CheckReply("$cmd", *preply, $reply_id);
+ }
+ catch(BoxException &e)
+ {
+ // First try a the built-in exception handler
+ apReply = HandleException(e);
+ }
+__E
+ }
+
+ # Common to both client and local
+ print CPP <<__E;
+ CheckReply("$cmd", rQuery, *apReply, $reply_id);
// Correct response, if no exception thrown by CheckReply
- return std::auto_ptr<$reply_class>(($reply_class *)preply.release());
+ return std::auto_ptr<$reply_class>(
+ static_cast<$reply_class *>(apReply.release()));
}
__E
- }
}
}
}
@@ -1110,7 +1261,7 @@ sub obj_get_type_params
{
return $1 if $_ =~ m/\A$ty\((.+?)\)\Z/;
}
- die "Can't find attribute $ty\n"
+ die "Can't find attribute $ty on command $c\n"
}
# returns (is basic type, typename)