diff options
Diffstat (limited to 'lib/server')
28 files changed, 1115 insertions, 876 deletions
diff --git a/lib/server/ConnectionException.txt b/lib/server/ConnectionException.txt index c3429116..7dcaadeb 100644 --- a/lib/server/ConnectionException.txt +++ b/lib/server/ConnectionException.txt @@ -25,3 +25,4 @@ Protocol_HandshakeFailed 48 Protocol_StreamWhenObjExpected 49 Protocol_ObjWhenStreamExpected 50 Protocol_TimeOutWhenSendingStream 52 Probably a network issue between client and server. +Protocol_StreamsNotConsumed 53 The server command handler did not consume all streams that were sent. diff --git a/lib/server/Daemon.cpp b/lib/server/Daemon.cpp index 7419f973..d3c8441f 100644 --- a/lib/server/Daemon.cpp +++ b/lib/server/Daemon.cpp @@ -9,23 +9,27 @@ #include "Box.h" -#ifdef HAVE_UNISTD_H - #include <unistd.h> -#endif - #include <errno.h> #include <stdio.h> #include <signal.h> #include <string.h> #include <stdarg.h> +#ifdef HAVE_PROCESS_H +# include <process.h> +#endif + +#ifdef HAVE_UNISTD_H +# include <unistd.h> +#endif + #ifdef HAVE_BSD_UNISTD_H #include <bsd/unistd.h> #endif #ifdef WIN32 + #include <Strsafe.h> #include <ws2tcpip.h> - #include <process.h> #endif #include "depot.h" @@ -36,12 +40,13 @@ # include "BoxVersion.h" #endif +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "Configuration.h" #include "Daemon.h" #include "FileModificationTime.h" #include "Guards.h" #include "Logging.h" -#include "ServerException.h" #include "UnixUser.h" #include "Utils.h" @@ -106,11 +111,11 @@ Daemon::~Daemon() // -------------------------------------------------------------------------- std::string Daemon::GetOptionString() { - return "c:" + return std::string("c:" #ifndef WIN32 "DF" #endif - "hkKo:O:PqQt:TUvVW:"; + "hkKo:O:") + Logging::OptionParser::GetOptionString(); } void Daemon::Usage() @@ -133,16 +138,7 @@ void Daemon::Usage() " -K Stop writing log messages to console while daemon is running\n" " -o <file> Log to a file, defaults to maximum verbosity\n" " -O <level> Set file log verbosity to error/warning/notice/info/trace/everything\n" - " -P Show process ID (PID) in console output\n" - " -q Run more quietly, reduce verbosity level by one, can repeat\n" - " -Q Run at minimum verbosity, log nothing to console and system\n" - " -t <tag> Tag console output with specified marker\n" - " -T Timestamp console output\n" - " -U Timestamp console output with microseconds\n" - " -v Run more verbosely, increase verbosity level by one, can repeat\n" - " -V Run at maximum verbosity, log everything to console and system\n" - " -W <level> Set verbosity to error/warning/notice/info/trace/everything\n" - ; + << Logging::OptionParser::GetUsageString(); } // -------------------------------------------------------------------------- @@ -218,94 +214,9 @@ int Daemon::ProcessOption(signed int option) } break; - case 'P': - { - Console::SetShowPID(true); - } - break; - - case 'q': - { - if(mLogLevel == Log::NOTHING) - { - BOX_FATAL("Too many '-q': " - "Cannot reduce logging " - "level any more"); - return 2; - } - mLogLevel--; - } - break; - - case 'Q': - { - mLogLevel = Log::NOTHING; - } - break; - - case 't': - { - Logging::SetProgramName(optarg); - Console::SetShowTag(true); - } - break; - - case 'T': - { - Console::SetShowTime(true); - } - break; - - case 'U': - { - Console::SetShowTime(true); - Console::SetShowTimeMicros(true); - } - break; - - case 'v': - { - if(mLogLevel == Log::EVERYTHING) - { - BOX_FATAL("Too many '-v': " - "Cannot increase logging " - "level any more"); - return 2; - } - mLogLevel++; - } - break; - - case 'V': - { - mLogLevel = Log::EVERYTHING; - } - break; - - case 'W': - { - mLogLevel = Logging::GetNamedLevel(optarg); - if (mLogLevel == Log::INVALID) - { - BOX_FATAL("Invalid logging level: " << optarg); - return 2; - } - } - break; - - case '?': - { - BOX_FATAL("Unknown option on command line: " - << "'" << (char)optopt << "'"); - return 2; - } - break; - default: { - BOX_FATAL("Unknown error in getopt: returned " - << "'" << option << "'"); - return 1; + return mLogLevel.ProcessOption(option); } } @@ -351,12 +262,6 @@ int Daemon::Main(const std::string& rDefaultConfigFile, int argc, int Daemon::ProcessOptions(int argc, const char *argv[]) { - #ifdef BOX_RELEASE_BUILD - mLogLevel = Log::NOTICE; - #else - mLogLevel = Log::INFO; - #endif - if (argc == 2 && strcmp(argv[1], "/?") == 0) { Usage(); @@ -368,7 +273,7 @@ int Daemon::ProcessOptions(int argc, const char *argv[]) // reset getopt, just in case anybody used it before. // unfortunately glibc and BSD differ on this point! // http://www.ussg.iu.edu/hypermail/linux/kernel/0305.3/0262.html - #if HAVE_DECL_OPTRESET == 1 || defined WIN32 + #if HAVE_DECL_OPTRESET == 1 || defined BOX_BSD_GETOPT optind = 1; optreset = 1; #elif defined __GLIBC__ @@ -406,13 +311,14 @@ int Daemon::ProcessOptions(int argc, const char *argv[]) return 2; } - Logging::FilterConsole((Log::Level)mLogLevel); - Logging::FilterSyslog ((Log::Level)mLogLevel); + Logging::FilterConsole(mLogLevel.GetCurrentLevel()); + Logging::FilterSyslog (mLogLevel.GetCurrentLevel()); if (mLogFileLevel != Log::INVALID) { mapLogFileLogger.reset( - new FileLogger(mLogFile, mLogFileLevel)); + new FileLogger(mLogFile, mLogFileLevel, + !mLogLevel.mTruncateLogFile)); } return 0; @@ -473,17 +379,17 @@ bool Daemon::Configure(const std::string& rConfigFileName) BOX_ERROR("Failed to load or verify configuration file"); return false; } - + if(!Configure(*apConfig)) { BOX_ERROR("Failed to verify configuration file"); - return false; + return false; } - + // Store configuration mConfigFileName = rConfigFileName; mLoadedConfigModifiedTime = GetConfigFileModifiedTime(); - + return true; } @@ -513,14 +419,14 @@ bool Daemon::Configure(const Configuration& rConfig) BOX_ERROR("Configuration errors: " << errors); return false; } - + // Store configuration mapConfiguration = apConf; - + // Let the derived class have a go at setting up stuff // in the initial process SetupInInitialProcess(); - + return true; } @@ -664,7 +570,7 @@ int Daemon::Main(const std::string &rConfigFileName) // Write PID to file char pid[32]; - int pidsize = sprintf(pid, "%d", (int)getpid()); + int pidsize = snprintf(pid, sizeof(pid), "%d", (int)getpid()); if(::write(pidFile, pid, pidsize) != pidsize) { @@ -676,9 +582,8 @@ int Daemon::Main(const std::string &rConfigFileName) // Set up memory leak reporting #ifdef BOX_MEMORY_LEAK_TESTING { - char filename[256]; - sprintf(filename, "%s.memleaks", DaemonName()); - memleakfinder_setup_exit_report(filename, DaemonName()); + memleakfinder_setup_exit_report(std::string(DaemonName()) + + ".memleaks", DaemonName()); } #endif // BOX_MEMORY_LEAK_TESTING @@ -986,7 +891,9 @@ const Configuration &Daemon::GetConfiguration() const if(mapConfiguration.get() == 0) { // Shouldn't get anywhere near this if a configuration file can't be loaded - THROW_EXCEPTION(ServerException, Internal) + THROW_EXCEPTION_MESSAGE(ServerException, Internal, + "The daemon has not been configured; no config file " + "has been loaded."); } return *mapConfiguration; diff --git a/lib/server/Daemon.h b/lib/server/Daemon.h index 2718c288..b5384918 100644 --- a/lib/server/Daemon.h +++ b/lib/server/Daemon.h @@ -85,7 +85,16 @@ protected: bool IsSingleProcess() { return mSingleProcess; } virtual std::string GetOptionString(); virtual int ProcessOption(signed int option); - + void ResetLogFile() + { + if(mapLogFileLogger.get()) + { + mapLogFileLogger.reset( + new FileLogger(mLogFile, mLogFileLevel, + !mLogLevel.mTruncateLogFile)); + } + } + private: static void SignalHandler(int sigraised); box_time_t GetConfigFileModifiedTime() const; @@ -99,7 +108,7 @@ private: bool mRunInForeground; bool mKeepConsoleOpenAfterFork; bool mHaveConfigFile; - int mLogLevel; // need an int to do math with + Logging::OptionParser mLogLevel; std::string mLogFile; Log::Level mLogFileLevel; std::auto_ptr<FileLogger> mapLogFileLogger; diff --git a/lib/server/Message.h b/lib/server/Message.h index 0d073d49..9f2245ec 100644 --- a/lib/server/Message.h +++ b/lib/server/Message.h @@ -37,10 +37,11 @@ public: // reading and writing with Protocol objects virtual void SetPropertiesFromStreamData(Protocol &rProtocol); - virtual void WritePropertiesToStreamData(Protocol &rProtocol) const; + virtual void WritePropertiesToStreamData(Protocol &rProtocol) const; virtual void LogSysLog(const char *Action) const { } virtual void LogFile(const char *Action, FILE *file) const { } + virtual std::string ToString() const = 0; }; /* diff --git a/lib/server/Protocol.cpp b/lib/server/Protocol.cpp index 382f1c37..0adf9543 100644 --- a/lib/server/Protocol.cpp +++ b/lib/server/Protocol.cpp @@ -17,10 +17,11 @@ #include <new> +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "Protocol.h" #include "ProtocolWire.h" -#include "IOStream.h" -#include "ServerException.h" +#include "SocketStream.h" #include "PartialReadStream.h" #include "ProtocolUncertainStream.h" #include "Logging.h" @@ -44,8 +45,8 @@ // Created: 2003/08/19 // // -------------------------------------------------------------------------- -Protocol::Protocol(IOStream &rStream) -: mrStream(rStream), +Protocol::Protocol(std::auto_ptr<SocketStream> apConn) +: mapConn(apConn), mHandshakeDone(false), mMaxObjectSize(PROTOCOL_DEFAULT_MAXOBJSIZE), mTimeout(PROTOCOL_DEFAULT_TIMEOUT), @@ -103,8 +104,8 @@ void Protocol::Handshake() ::strncpy(hsSend.mIdent, GetProtocolIdentString(), sizeof(hsSend.mIdent)); // Send it - mrStream.Write(&hsSend, sizeof(hsSend)); - mrStream.WriteAllBuffered(); + mapConn->Write(&hsSend, sizeof(hsSend), GetTimeout()); + mapConn->WriteAllBuffered(); // Receive a handshake from the peer PW_Handshake hsReceive; @@ -114,10 +115,10 @@ void Protocol::Handshake() while(bytesToRead > 0) { // Get some data from the stream - int bytesRead = mrStream.Read(readInto, bytesToRead, mTimeout); + int bytesRead = mapConn->Read(readInto, bytesToRead, GetTimeout()); if(bytesRead == 0) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) + THROW_EXCEPTION(ConnectionException, Protocol_Timeout) } readInto += bytesRead; bytesToRead -= bytesRead; @@ -127,7 +128,7 @@ void Protocol::Handshake() // Are they the same? if(::memcmp(&hsSend, &hsReceive, sizeof(hsSend)) != 0) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_HandshakeFailed) + THROW_EXCEPTION(ConnectionException, Protocol_HandshakeFailed) } // Mark as done @@ -158,9 +159,10 @@ void Protocol::CheckAndReadHdr(void *hdr) } // Get some data into this header - if(!mrStream.ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), 0 /* not interested in bytes read if this fails */, mTimeout)) + if(!mapConn->ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), + 0 /* not interested in bytes read if this fails */, mTimeout)) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) + THROW_EXCEPTION(ConnectionException, Protocol_Timeout) } } @@ -168,8 +170,9 @@ void Protocol::CheckAndReadHdr(void *hdr) // -------------------------------------------------------------------------- // // Function -// Name: Protocol::Recieve() -// Purpose: Recieves an object from the stream, creating it from the factory object type +// Name: Protocol::ReceiveInternal() +// Purpose: Receives an object from the stream, creating it +// from the factory object type // Created: 2003/08/19 // // -------------------------------------------------------------------------- @@ -182,14 +185,14 @@ std::auto_ptr<Message> Protocol::ReceiveInternal() // Hope it's not a stream if(ntohl(objHeader.mObjType) == SPECIAL_STREAM_OBJECT_TYPE) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_StreamWhenObjExpected) + THROW_EXCEPTION(ConnectionException, Protocol_StreamWhenObjExpected) } // Check the object size - u_int32_t objSize = ntohl(objHeader.mObjSize); + uint32_t objSize = ntohl(objHeader.mObjSize); if(objSize < sizeof(objHeader) || objSize > mMaxObjectSize) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_ObjTooBig) + THROW_EXCEPTION(ConnectionException, Protocol_ObjTooBig) } // Create a blank object @@ -199,9 +202,10 @@ std::auto_ptr<Message> Protocol::ReceiveInternal() EnsureBufferAllocated(objSize); // Read data - if(!mrStream.ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), 0 /* not interested in bytes read if this fails */, mTimeout)) + if(!mapConn->ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), + 0 /* not interested in bytes read if this fails */, mTimeout)) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) + THROW_EXCEPTION(ConnectionException, Protocol_Timeout) } // Setup ready to read out data from the buffer @@ -231,7 +235,7 @@ std::auto_ptr<Message> Protocol::ReceiveInternal() // Exception if not all the data was consumed if(dataLeftOver) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_BadCommandRecieved) + THROW_EXCEPTION(ConnectionException, Protocol_BadCommandRecieved) } return obj; @@ -240,7 +244,7 @@ std::auto_ptr<Message> Protocol::ReceiveInternal() // -------------------------------------------------------------------------- // // Function -// Name: Protocol::Send() +// Name: Protocol::SendInternal() // Purpose: Send an object to the other side of the connection. // Created: 2003/08/19 // @@ -292,8 +296,8 @@ void Protocol::SendInternal(const Message &rObject) pobjHeader->mObjType = htonl(rObject.GetType()); // Write data - mrStream.Write(mpBuffer, writtenSize); - mrStream.WriteAllBuffered(); + mapConn->Write(mpBuffer, writtenSize, GetTimeout()); + mapConn->WriteAllBuffered(); } // -------------------------------------------------------------------------- @@ -346,7 +350,7 @@ void Protocol::EnsureBufferAllocated(int Size) #define READ_CHECK_BYTES_AVAILABLE(bytesRequired) \ if((mReadOffset + (int)(bytesRequired)) > mValidDataSize) \ { \ - THROW_EXCEPTION(ConnectionException, Conn_Protocol_BadCommandRecieved) \ + THROW_EXCEPTION(ConnectionException, Protocol_BadCommandRecieved) \ } // -------------------------------------------------------------------------- @@ -619,7 +623,7 @@ void Protocol::Write(const std::string &rValue) // -------------------------------------------------------------------------- // // Function -// Name: Protocol::ReceieveStream() +// Name: Protocol::ReceiveStream() // Purpose: Receive a stream from the remote side // Created: 2003/08/26 // @@ -633,11 +637,11 @@ std::auto_ptr<IOStream> Protocol::ReceiveStream() // Hope it's not an object if(ntohl(objHeader.mObjType) != SPECIAL_STREAM_OBJECT_TYPE) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_ObjWhenStreamExpected) + THROW_EXCEPTION(ConnectionException, Protocol_ObjWhenStreamExpected) } // Get the stream size - u_int32_t streamSize = ntohl(objHeader.mObjSize); + uint32_t streamSize = ntohl(objHeader.mObjSize); // Inform sub class InformStreamReceiving(streamSize); @@ -647,13 +651,13 @@ std::auto_ptr<IOStream> Protocol::ReceiveStream() { BOX_TRACE("Receiving stream, size uncertain"); return std::auto_ptr<IOStream>( - new ProtocolUncertainStream(mrStream)); + new ProtocolUncertainStream(*mapConn)); } else { BOX_TRACE("Receiving stream, size " << streamSize << " bytes"); return std::auto_ptr<IOStream>( - new PartialReadStream(mrStream, streamSize)); + new PartialReadStream(*mapConn, streamSize)); } } @@ -709,7 +713,7 @@ void Protocol::SendStream(IOStream &rStream) objHeader.mObjType = htonl(SPECIAL_STREAM_OBJECT_TYPE); // Write header - mrStream.Write(&objHeader, sizeof(objHeader)); + mapConn->Write(&objHeader, sizeof(objHeader), GetTimeout()); // Could be sent in one of two ways if(uncertainSize) { @@ -744,7 +748,7 @@ void Protocol::SendStream(IOStream &rStream) // Send final byte to finish the stream BOX_TRACE("Sending end of stream byte"); uint8_t endOfStream = ProtocolStreamHeader_EndOfStream; - mrStream.Write(&endOfStream, 1); + mapConn->Write(&endOfStream, 1, GetTimeout()); BOX_TRACE("Sent end of stream byte"); } catch(...) @@ -759,13 +763,14 @@ void Protocol::SendStream(IOStream &rStream) else { // Fixed size stream, send it all in one go - if(!rStream.CopyStreamTo(mrStream, mTimeout, 4096 /* slightly larger buffer */)) + if(!rStream.CopyStreamTo(*mapConn, GetTimeout(), + 4096 /* slightly larger buffer */)) { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_TimeOutWhenSendingStream) + THROW_EXCEPTION(ConnectionException, Protocol_TimeOutWhenSendingStream) } } // Make sure everything is written - mrStream.WriteAllBuffered(); + mapConn->WriteAllBuffered(); } @@ -816,7 +821,7 @@ int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock) Block[-1] = header; // Write everything out - mrStream.Write(Block - 1, writeSize + 1); + mapConn->Write(Block - 1, writeSize + 1, GetTimeout()); BOX_TRACE("Sent " << (writeSize+1) << " bytes to stream"); // move the remainer to the beginning of the block for the next time round @@ -831,12 +836,12 @@ int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock) // -------------------------------------------------------------------------- // // Function -// Name: Protocol::InformStreamReceiving(u_int32_t) +// Name: Protocol::InformStreamReceiving(uint32_t) // Purpose: Informs sub classes about streams being received // Created: 2003/10/27 // // -------------------------------------------------------------------------- -void Protocol::InformStreamReceiving(u_int32_t Size) +void Protocol::InformStreamReceiving(uint32_t Size) { if(GetLogToSysLog()) { @@ -863,12 +868,12 @@ void Protocol::InformStreamReceiving(u_int32_t Size) // -------------------------------------------------------------------------- // // Function -// Name: Protocol::InformStreamSending(u_int32_t) +// Name: Protocol::InformStreamSending(uint32_t) // Purpose: Informs sub classes about streams being sent // Created: 2003/10/27 // // -------------------------------------------------------------------------- -void Protocol::InformStreamSending(u_int32_t Size) +void Protocol::InformStreamSending(uint32_t Size) { if(GetLogToSysLog()) { @@ -1177,6 +1182,12 @@ const uint16_t Protocol::sProtocolStreamHeaderLengths[256] = 0 // 255 = special (reserved) }; +int64_t Protocol::GetBytesRead() const +{ + return mapConn->GetBytesRead(); +} - - +int64_t Protocol::GetBytesWritten() const +{ + return mapConn->GetBytesWritten(); +} diff --git a/lib/server/Protocol.h b/lib/server/Protocol.h index 42cb0ff8..fbe6461c 100644 --- a/lib/server/Protocol.h +++ b/lib/server/Protocol.h @@ -19,6 +19,7 @@ #include "Message.h" class IOStream; +class SocketStream; // default timeout is 15 minutes #define PROTOCOL_DEFAULT_TIMEOUT (15*60*1000) @@ -36,7 +37,7 @@ class IOStream; class Protocol { public: - Protocol(IOStream &rStream); + Protocol(std::auto_ptr<SocketStream> apConn); virtual ~Protocol(); private: @@ -66,14 +67,14 @@ public: // Purpose: Sets the timeout for sending and reciving // Created: 2003/08/19 // - // -------------------------------------------------------------------------- + // -------------------------------------------------------------------------- void SetTimeout(int NewTimeout) {mTimeout = NewTimeout;} // -------------------------------------------------------------------------- // // Function - // Name: Protocol::GetTimeout() + // Name: Protocol::GetTimeout() // Purpose: Get current timeout for sending and receiving // Created: 2003/09/06 // @@ -175,6 +176,8 @@ public: FILE *GetLogToFile() { return mLogToFile; } void SetLogToSysLog(bool Log = false) {mLogToSysLog = Log;} void SetLogToFile(FILE *File = 0) {mLogToFile = File;} + int64_t GetBytesRead() const; + int64_t GetBytesWritten() const; protected: virtual std::auto_ptr<Message> MakeMessage(int ObjType) = 0; @@ -183,14 +186,14 @@ protected: void CheckAndReadHdr(void *hdr); // don't use type here to avoid dependency // Will be used for logging - virtual void InformStreamReceiving(u_int32_t Size); - virtual void InformStreamSending(u_int32_t Size); + virtual void InformStreamReceiving(uint32_t Size); + virtual void InformStreamSending(uint32_t Size); private: void EnsureBufferAllocated(int Size); int SendStreamSendBlock(uint8_t *Block, int BytesInBlock); - IOStream &mrStream; + std::auto_ptr<SocketStream> mapConn; bool mHandshakeDone; unsigned int mMaxObjectSize; int mTimeout; @@ -208,4 +211,3 @@ class ProtocolContext }; #endif // PROTOCOL__H - diff --git a/lib/server/ProtocolUncertainStream.cpp b/lib/server/ProtocolUncertainStream.cpp index 84a213a8..aeb15816 100644 --- a/lib/server/ProtocolUncertainStream.cpp +++ b/lib/server/ProtocolUncertainStream.cpp @@ -8,8 +8,9 @@ // -------------------------------------------------------------------------- #include "Box.h" +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "ProtocolUncertainStream.h" -#include "ServerException.h" #include "Protocol.h" #include "MemLeakFindOn.h" @@ -172,7 +173,7 @@ IOStream::pos_type ProtocolUncertainStream::BytesLeftToRead() // Created: 2003/12/05 // // -------------------------------------------------------------------------- -void ProtocolUncertainStream::Write(const void *pBuffer, int NBytes) +void ProtocolUncertainStream::Write(const void *pBuffer, int NBytes, int Timeout) { THROW_EXCEPTION(ServerException, CantWriteToProtocolUncertainStream) } diff --git a/lib/server/ProtocolUncertainStream.h b/lib/server/ProtocolUncertainStream.h index 4954cf88..2e97ba6a 100644 --- a/lib/server/ProtocolUncertainStream.h +++ b/lib/server/ProtocolUncertainStream.h @@ -33,7 +33,8 @@ private: public: virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite); virtual pos_type BytesLeftToRead(); - virtual void Write(const void *pBuffer, int NBytes); + virtual void Write(const void *pBuffer, int NBytes, + int Timeout = IOStream::TimeOutInfinite); virtual bool StreamDataLeft(); virtual bool StreamClosed(); diff --git a/lib/server/ProtocolWire.h b/lib/server/ProtocolWire.h index ff62b66e..6dee445b 100644 --- a/lib/server/ProtocolWire.h +++ b/lib/server/ProtocolWire.h @@ -26,8 +26,8 @@ typedef struct typedef struct { - u_int32_t mObjSize; - u_int32_t mObjType; + uint32_t mObjSize; + uint32_t mObjType; } PW_ObjectHeader; #define SPECIAL_STREAM_OBJECT_TYPE 0xffffffff diff --git a/lib/server/SSLLib.cpp b/lib/server/SSLLib.cpp index 004d2d98..1bcadb0d 100644 --- a/lib/server/SSLLib.cpp +++ b/lib/server/SSLLib.cpp @@ -18,9 +18,10 @@ #include <wincrypt.h> #endif +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "CryptoUtils.h" #include "SSLLib.h" -#include "ServerException.h" #include "MemLeakFindOn.h" @@ -79,7 +80,7 @@ void SSLLib::Initialise() BOX_LOG_WIN_ERROR("Failed to release crypto context"); } } -#elif HAVE_RANDOM_DEVICE +#elif defined HAVE_RANDOM_DEVICE if(::RAND_load_file(RANDOM_DEVICE, 1024) != 1024) { THROW_EXCEPTION(ServerException, SSLRandomInitFailed) diff --git a/lib/server/ServerControl.cpp b/lib/server/ServerControl.cpp index b9650cee..f1a718df 100644 --- a/lib/server/ServerControl.cpp +++ b/lib/server/ServerControl.cpp @@ -15,13 +15,14 @@ #include <signal.h> #endif +#include "BoxTime.h" +#include "IOStreamGetLine.h" #include "ServerControl.h" #include "Test.h" #ifdef WIN32 #include "WinNamedPipeStream.h" -#include "IOStreamGetLine.h" #include "BoxPortsAndFiles.h" static std::string sPipeName; @@ -197,18 +198,18 @@ bool KillServer(int pid, bool WaitForProcess) } #endif - for (int i = 0; i < 30; i++) + printf("Waiting for server to die (pid %d): ", pid); + + for (int i = 0; i < 300; i++) { - if (i == 0) + if (i % 10 == 0) { - printf("Waiting for server to die (pid %d): ", pid); + printf("."); + fflush(stdout); } - printf("."); - fflush(stdout); - if (!ServerIsAlive(pid)) break; - ::sleep(1); + ShortSleep(MilliSecondsToBoxTime(100), false); if (!ServerIsAlive(pid)) break; } @@ -226,3 +227,64 @@ bool KillServer(int pid, bool WaitForProcess) return !ServerIsAlive(pid); } +bool KillServer(std::string pid_file, bool WaitForProcess) +{ + FileStream fs(pid_file); + IOStreamGetLine getline(fs); + std::string line = getline.GetLine(); + int pid = atoi(line.c_str()); + bool status = KillServer(pid, WaitForProcess); + TEST_EQUAL_LINE(true, status, std::string("kill(") + pid_file + ")"); + +#ifdef WIN32 + if(WaitForProcess) + { + int unlink_result = unlink(pid_file.c_str()); + TEST_EQUAL_LINE(0, unlink_result, std::string("unlink ") + pid_file); + if(unlink_result != 0) + { + return false; + } + } +#endif + + return status; +} + +int StartDaemon(int current_pid, const std::string& cmd_line, const char* pid_file) +{ + TEST_THAT_OR(current_pid == 0, return 0); + + int new_pid = LaunchServer(cmd_line, pid_file); + TEST_THAT_OR(new_pid != -1 && new_pid != 0, return 0); + + ::sleep(1); + TEST_THAT_OR(ServerIsAlive(new_pid), return 0); + return new_pid; +} + +bool StopDaemon(int current_pid, const std::string& pid_file, + const std::string& memleaks_file, bool wait_for_process) +{ + TEST_THAT_OR(current_pid != 0, return false); + TEST_THAT_OR(ServerIsAlive(current_pid), return false); + TEST_THAT_OR(KillServer(current_pid, wait_for_process), return false); + ::sleep(1); + + TEST_THAT_OR(!ServerIsAlive(current_pid), return false); + + #ifdef WIN32 + int unlink_result = unlink(pid_file.c_str()); + TEST_EQUAL_LINE(0, unlink_result, std::string("unlink ") + pid_file); + if(unlink_result != 0) + { + return false; + } + #else + TestRemoteProcessMemLeaks(memleaks_file.c_str()); + #endif + + return true; +} + + diff --git a/lib/server/ServerControl.h b/lib/server/ServerControl.h index b2e51864..be2464c1 100644 --- a/lib/server/ServerControl.h +++ b/lib/server/ServerControl.h @@ -5,6 +5,10 @@ bool HUPServer(int pid); bool KillServer(int pid, bool WaitForProcess = false); +bool KillServer(std::string pid_file, bool WaitForProcess = false); +int StartDaemon(int current_pid, const std::string& cmd_line, const char* pid_file); +bool StopDaemon(int current_pid, const std::string& pid_file, + const std::string& memleaks_file, bool wait_for_process); #ifdef WIN32 #include "WinNamedPipeStream.h" diff --git a/lib/server/ServerException.h b/lib/server/ServerException.h deleted file mode 100644 index 8851b90a..00000000 --- a/lib/server/ServerException.h +++ /dev/null @@ -1,46 +0,0 @@ -// -------------------------------------------------------------------------- -// -// File -// Name: ServerException.h -// Purpose: Exception -// Created: 2003/07/08 -// -// -------------------------------------------------------------------------- - -#ifndef SERVEREXCEPTION__H -#define SERVEREXCEPTION__H - -// Compatibility header -#include "autogen_ServerException.h" -#include "autogen_ConnectionException.h" - -// Rename old connection exception names to new names without Conn_ prefix -// This is all because ConnectionException used to be derived from ServerException -// with some funky magic with subtypes. Perhaps a little unreliable, and the -// usefulness of it never really was used. -#define Conn_SocketWriteError SocketWriteError -#define Conn_SocketReadError SocketReadError -#define Conn_SocketNameLookupError SocketNameLookupError -#define Conn_SocketShutdownError SocketShutdownError -#define Conn_SocketConnectError SocketConnectError -#define Conn_TLSHandshakeFailed TLSHandshakeFailed -#define Conn_TLSShutdownFailed TLSShutdownFailed -#define Conn_TLSWriteFailed TLSWriteFailed -#define Conn_TLSReadFailed TLSReadFailed -#define Conn_TLSNoPeerCertificate TLSNoPeerCertificate -#define Conn_TLSPeerCertificateInvalid TLSPeerCertificateInvalid -#define Conn_TLSClosedWhenWriting TLSClosedWhenWriting -#define Conn_TLSHandshakeTimedOut TLSHandshakeTimedOut -#define Conn_Protocol_Timeout Protocol_Timeout -#define Conn_Protocol_ObjTooBig Protocol_ObjTooBig -#define Conn_Protocol_BadCommandRecieved Protocol_BadCommandRecieved -#define Conn_Protocol_UnknownCommandRecieved Protocol_UnknownCommandRecieved -#define Conn_Protocol_TriedToExecuteReplyCommand Protocol_TriedToExecuteReplyCommand -#define Conn_Protocol_UnexpectedReply Protocol_UnexpectedReply -#define Conn_Protocol_HandshakeFailed Protocol_HandshakeFailed -#define Conn_Protocol_StreamWhenObjExpected Protocol_StreamWhenObjExpected -#define Conn_Protocol_ObjWhenStreamExpected Protocol_ObjWhenStreamExpected -#define Conn_Protocol_TimeOutWhenSendingStream Protocol_TimeOutWhenSendingStream - -#endif // SERVEREXCEPTION__H - diff --git a/lib/server/ServerStream.h b/lib/server/ServerStream.h index a9b56169..3f6eed7e 100644 --- a/lib/server/ServerStream.h +++ b/lib/server/ServerStream.h @@ -17,6 +17,7 @@ #include <sys/wait.h> #endif +#include "autogen_ServerException.h" #include "Daemon.h" #include "SocketListen.h" #include "Utils.h" @@ -286,7 +287,7 @@ public: #endif // The derived class does some server magic with the connection - HandleConnection(*connection); + HandleConnection(connection); // Since rChildExit == true, the forked process will call _exit() on return from this fn return; @@ -305,7 +306,7 @@ public: #endif // !WIN32 // Just handle in this process SetProcessTitle("handling"); - HandleConnection(*connection); + HandleConnection(connection); SetProcessTitle("idle"); #ifndef WIN32 } @@ -344,10 +345,20 @@ public: p = ::waitpid(0 /* any child in process group */, &status, WNOHANG); - if(p == -1 && errno != ECHILD && errno != EINTR) + if(p == -1) { - THROW_EXCEPTION(ServerException, - ServerWaitOnChildError) + if (errno == ECHILD || errno == EINTR) + { + // Nothing actually happened, so there's no reason + // to wait again. + break; + } + else + { + THROW_SYS_ERROR("Failed to wait for daemon child " + "process", ServerException, + ServerWaitOnChildError); + } } else if(p == 0) { @@ -377,12 +388,12 @@ public: } #endif - virtual void HandleConnection(StreamType &rStream) + virtual void HandleConnection(std::auto_ptr<StreamType> apStream) { - Connection(rStream); + Connection(apStream); } - virtual void Connection(StreamType &rStream) = 0; + virtual void Connection(std::auto_ptr<StreamType> apStream) = 0; protected: // For checking code in derived classes -- use if you have an algorithm which diff --git a/lib/server/ServerTLS.h b/lib/server/ServerTLS.h index a74a671e..f748f4b2 100644 --- a/lib/server/ServerTLS.h +++ b/lib/server/ServerTLS.h @@ -52,18 +52,19 @@ public: std::string certFile(serverconf.GetKeyValue("CertificateFile")); std::string keyFile(serverconf.GetKeyValue("PrivateKeyFile")); std::string caFile(serverconf.GetKeyValue("TrustedCAsFile")); - mContext.Initialise(true /* as server */, certFile.c_str(), keyFile.c_str(), caFile.c_str()); + mContext.Initialise(true /* as server */, certFile.c_str(), + keyFile.c_str(), caFile.c_str()); // Then do normal stream server stuff ServerStream<SocketStreamTLS, Port, ListenBacklog, ForkToHandleRequests>::Run2(rChildExit); } - virtual void HandleConnection(SocketStreamTLS &rStream) + virtual void HandleConnection(std::auto_ptr<SocketStreamTLS> apStream) { - rStream.Handshake(mContext, true /* is server */); + apStream->Handshake(mContext, true /* is server */); // this-> in next line required to build under some gcc versions - this->Connection(rStream); + this->Connection(apStream); } private: diff --git a/lib/server/Socket.cpp b/lib/server/Socket.cpp index f2a4996b..c9c1773d 100644 --- a/lib/server/Socket.cpp +++ b/lib/server/Socket.cpp @@ -24,8 +24,9 @@ #include <string.h> #include <stdio.h> +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "Socket.h" -#include "ServerException.h" #include "MemLeakFindOn.h" @@ -69,12 +70,12 @@ void Socket::NameLookupToSockAddr(SocketAllAddr &addr, int &sockDomain, } else { - THROW_EXCEPTION(ConnectionException, Conn_SocketNameLookupError); + THROW_EXCEPTION(ConnectionException, SocketNameLookupError); } } else { - THROW_EXCEPTION(ConnectionException, Conn_SocketNameLookupError); + THROW_EXCEPTION(ConnectionException, SocketNameLookupError); } } break; diff --git a/lib/server/SocketListen.h b/lib/server/SocketListen.h index 39c60ba6..39fe7e24 100644 --- a/lib/server/SocketListen.h +++ b/lib/server/SocketListen.h @@ -29,8 +29,9 @@ #include <memory> #include <string> +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "Socket.h" -#include "ServerException.h" #include "MemLeakFindOn.h" @@ -73,7 +74,8 @@ private: // Created: 2003/07/31 // // -------------------------------------------------------------------------- -template<typename SocketType, int ListenBacklog = 128, typename SocketLockingType = _NoSocketLocking, int MaxMultiListenSockets = 16> +template<typename SocketType, int ListenBacklog = 128, + typename SocketLockingType = _NoSocketLocking, int MaxMultiListenSockets = 16> class SocketListen { public: @@ -112,10 +114,9 @@ public: if(::close(mSocketHandle) == -1) #endif { - BOX_LOG_SOCKET_ERROR(mType, mName, mPort, - "Failed to close network socket"); - THROW_EXCEPTION(ServerException, - SocketCloseError) + THROW_EXCEPTION_MESSAGE(ServerException, SocketCloseError, + BOX_SOCKET_ERROR_MESSAGE(mType, mName, mPort, + "Failed to close network socket")); } } mSocketHandle = -1; @@ -152,9 +153,9 @@ public: 0 /* let OS choose protocol */); if(mSocketHandle == -1) { - BOX_LOG_SOCKET_ERROR(Type, Name, Port, - "Failed to create a network socket"); - THROW_EXCEPTION(ServerException, SocketOpenError) + THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError, + BOX_SOCKET_ERROR_MESSAGE(Type, Name, Port, + "Failed to create a network socket")); } // Set an option to allow reuse (useful for -HUP situations!) @@ -167,28 +168,28 @@ public: &option, sizeof(option)) == -1) #endif { - BOX_LOG_SOCKET_ERROR(Type, Name, Port, - "Failed to set socket options"); - THROW_EXCEPTION(ServerException, SocketOpenError) + THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError, + BOX_SOCKET_ERROR_MESSAGE(Type, Name, Port, + "Failed to set socket options")); } // Bind it to the right port, and start listening if(::bind(mSocketHandle, &addr.sa_generic, addrLen) == -1 || ::listen(mSocketHandle, ListenBacklog) == -1) { - int err_number = errno; - - BOX_LOG_SOCKET_ERROR(Type, Name, Port, - "Failed to bind socket"); - - // Dispose of the socket - ::close(mSocketHandle); - mSocketHandle = -1; - - THROW_SYS_FILE_ERRNO("Failed to bind or listen " - "on socket", Name, err_number, - ServerException, SocketBindError); - } + try + { + THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError, + BOX_SOCKET_ERROR_MESSAGE(Type, Name, Port, + "Failed to bind socket to name/port")); + } + catch(ServerException &e) // finally + { + // Dispose of the socket + Close(); + throw; + } + } } // ------------------------------------------------------------------ @@ -248,10 +249,10 @@ public: } else { - BOX_LOG_SOCKET_ERROR(mType, mName, mPort, - "Failed to poll connection"); - THROW_EXCEPTION(ServerException, - SocketPollError) + THROW_EXCEPTION_MESSAGE(ServerException, + SocketPollError, + BOX_SOCKET_ERROR_MESSAGE(mType, mName, + mPort, "Failed to poll connection")); } break; case 0: // timed out @@ -268,9 +269,9 @@ public: // Got socket (or error), unlock (implicit in destruction) if(sock == -1) { - BOX_LOG_SOCKET_ERROR(mType, mName, mPort, - "Failed to accept connection"); - THROW_EXCEPTION(ServerException, SocketAcceptError) + THROW_EXCEPTION_MESSAGE(ServerException, SocketAcceptError, + BOX_SOCKET_ERROR_MESSAGE(mType, mName, + mPort, "Failed to accept connection")); } // Log it diff --git a/lib/server/SocketStream.cpp b/lib/server/SocketStream.cpp index 6ef4b8d1..edb5e5b8 100644 --- a/lib/server/SocketStream.cpp +++ b/lib/server/SocketStream.cpp @@ -25,10 +25,24 @@ #include <ucred.h> #endif +#ifdef HAVE_BSD_UNISTD_H + #include <bsd/unistd.h> +#endif + +#ifdef HAVE_SYS_PARAM_H + #include <sys/param.h> +#endif + +#ifdef HAVE_SYS_UCRED_H + #include <sys/ucred.h> +#endif + +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "SocketStream.h" -#include "ServerException.h" #include "CommonException.h" #include "Socket.h" +#include "Utils.h" #include "MemLeakFindOn.h" @@ -162,25 +176,31 @@ void SocketStream::Open(Socket::Type Type, const std::string& rName, int Port) 0 /* let OS choose protocol */); if(mSocketHandle == INVALID_SOCKET_VALUE) { - BOX_LOG_SOCKET_ERROR(Type, rName, Port, - "Failed to create a network socket"); - THROW_EXCEPTION(ServerException, SocketOpenError) + THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError, + BOX_SOCKET_ERROR_MESSAGE(Type, rName, Port, + "Failed to create a network socket")); } // Connect it if(::connect(mSocketHandle, &addr.sa_generic, addrLen) == -1) { // Dispose of the socket - BOX_LOG_SOCKET_ERROR(Type, rName, Port, - "Failed to connect to socket"); + try + { + THROW_EXCEPTION_MESSAGE(ServerException, SocketOpenError, + BOX_SOCKET_ERROR_MESSAGE(Type, rName, Port, + "Failed to connect to socket")); + } + catch(ServerException &e) + { #ifdef WIN32 - ::closesocket(mSocketHandle); + ::closesocket(mSocketHandle); #else // !WIN32 - ::close(mSocketHandle); + ::close(mSocketHandle); #endif // WIN32 - - mSocketHandle = INVALID_SOCKET_VALUE; - THROW_EXCEPTION(ConnectionException, Conn_SocketConnectError) + mSocketHandle = INVALID_SOCKET_VALUE; + throw; + } } ResetCounters(); @@ -199,7 +219,9 @@ void SocketStream::Open(Socket::Type Type, const std::string& rName, int Port) // -------------------------------------------------------------------------- int SocketStream::Read(void *pBuffer, int NBytes, int Timeout) { - if(mSocketHandle == INVALID_SOCKET_VALUE) + CheckForMissingTimeout(Timeout); + + if(mSocketHandle == INVALID_SOCKET_VALUE) { THROW_EXCEPTION(ServerException, BadSocketHandle) } @@ -210,7 +232,7 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout) p.fd = mSocketHandle; p.events = POLLIN; p.revents = 0; - switch(::poll(&p, 1, (Timeout == IOStream::TimeOutInfinite)?INFTIM:Timeout)) + switch(::poll(&p, 1, PollTimeout(Timeout, 0))) { case -1: // error @@ -256,7 +278,7 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout) // Other error BOX_LOG_SYS_ERROR("Failed to read from socket"); THROW_EXCEPTION(ConnectionException, - Conn_SocketReadError); + SocketReadError); } } @@ -270,6 +292,41 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout) return r; } +bool SocketStream::Poll(short Events, int Timeout) +{ + // Wait for data to send. + struct pollfd p; + p.fd = GetSocketHandle(); + p.events = Events; + p.revents = 0; + + box_time_t start = GetCurrentBoxTime(); + int result; + + do + { + result = ::poll(&p, 1, PollTimeout(Timeout, start)); + } + while(result == -1 && errno == EINTR); + + switch(result) + { + case -1: + // error - Bad! + THROW_SYS_ERROR("Failed to poll socket", ServerException, + SocketPollError); + break; + + case 0: + // Condition not met, timed out + return false; + + default: + // good to go! + return true; + } +} + // -------------------------------------------------------------------------- // // Function @@ -278,20 +335,21 @@ int SocketStream::Read(void *pBuffer, int NBytes, int Timeout) // Created: 2003/07/31 // // -------------------------------------------------------------------------- -void SocketStream::Write(const void *pBuffer, int NBytes) +void SocketStream::Write(const void *pBuffer, int NBytes, int Timeout) { - if(mSocketHandle == INVALID_SOCKET_VALUE) + if(mSocketHandle == INVALID_SOCKET_VALUE) { THROW_EXCEPTION(ServerException, BadSocketHandle) } - + // Buffer in byte sized type. ASSERT(sizeof(char) == 1); const char *buffer = (char *)pBuffer; - + // Bytes left to send int bytesLeft = NBytes; - + box_time_t start = GetCurrentBoxTime(); + while(bytesLeft > 0) { // Try to send. @@ -304,41 +362,30 @@ void SocketStream::Write(const void *pBuffer, int NBytes) { // Error. mWriteClosed = true; // assume can't write again - BOX_LOG_SYS_ERROR("Failed to write to socket"); - THROW_EXCEPTION(ConnectionException, - Conn_SocketWriteError); + THROW_SYS_ERROR("Failed to write to socket", + ConnectionException, SocketWriteError); } - + // Knock off bytes sent bytesLeft -= sent; // Move buffer pointer buffer += sent; mBytesWritten += sent; - + // Need to wait until it can send again? if(bytesLeft > 0) { - BOX_TRACE("Waiting to send data on socket " << + BOX_TRACE("Waiting to send data on socket " << mSocketHandle << " (" << bytesLeft << " of " << NBytes << " bytes left)"); - - // Wait for data to send. - struct pollfd p; - p.fd = mSocketHandle; - p.events = POLLOUT; - p.revents = 0; - - if(::poll(&p, 1, 16000 /* 16 seconds */) == -1) + + if(!Poll(POLLOUT, PollTimeout(Timeout, start))) { - // Don't exception if it's just a signal - if(errno != EINTR) - { - BOX_LOG_SYS_ERROR("Failed to poll " - "socket"); - THROW_EXCEPTION(ServerException, - SocketPollError) - } + THROW_EXCEPTION_MESSAGE(ConnectionException, + Protocol_Timeout, "Timed out waiting " + "to send " << bytesLeft << " of " << + NBytes << " bytes"); } } } @@ -354,7 +401,7 @@ void SocketStream::Write(const void *pBuffer, int NBytes) // -------------------------------------------------------------------------- void SocketStream::Close() { - if(mSocketHandle == INVALID_SOCKET_VALUE) + if(mSocketHandle == INVALID_SOCKET_VALUE) { THROW_EXCEPTION(ServerException, BadSocketHandle) } @@ -385,19 +432,19 @@ void SocketStream::Shutdown(bool Read, bool Write) { THROW_EXCEPTION(ServerException, BadSocketHandle) } - + // Do anything? if(!Read && !Write) return; - + int how = SHUT_RDWR; if(Read && !Write) how = SHUT_RD; if(!Read && Write) how = SHUT_WR; - + // Shut it down! if(::shutdown(mSocketHandle, how) == -1) { BOX_LOG_SYS_ERROR("Failed to shutdown socket"); - THROW_EXCEPTION(ConnectionException, Conn_SocketShutdownError) + THROW_EXCEPTION(ConnectionException, SocketShutdownError) } } @@ -478,8 +525,13 @@ bool SocketStream::GetPeerCredentials(uid_t &rUidOut, gid_t &rGidOut) if(::getsockopt(mSocketHandle, SOL_SOCKET, SO_PEERCRED, &cred, &credLen) == 0) { +#ifdef HAVE_STRUCT_UCRED_UID rUidOut = cred.uid; rGidOut = cred.gid; +#else // HAVE_STRUCT_UCRED_CR_UID + rUidOut = cred.cr_uid; + rGidOut = cred.cr_gid; +#endif return true; } @@ -509,3 +561,11 @@ bool SocketStream::GetPeerCredentials(uid_t &rUidOut, gid_t &rGidOut) return false; } +void SocketStream::CheckForMissingTimeout(int Timeout) +{ + if (Timeout == IOStream::TimeOutInfinite) + { + BOX_WARNING("Network operation started with no timeout!"); + DumpStackBacktrace(); + } +} diff --git a/lib/server/SocketStream.h b/lib/server/SocketStream.h index 2fb5e391..fd57af8f 100644 --- a/lib/server/SocketStream.h +++ b/lib/server/SocketStream.h @@ -10,6 +10,13 @@ #ifndef SOCKETSTREAM__H #define SOCKETSTREAM__H +#include <climits> + +#ifdef HAVE_SYS_POLL_H +# include <sys/poll.h> +#endif + +#include "BoxTime.h" #include "IOStream.h" #include "Socket.h" @@ -41,7 +48,16 @@ public: void Attach(int socket); virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite); - virtual void Write(const void *pBuffer, int NBytes); + virtual void Write(const void *pBuffer, int NBytes, + int Timeout = IOStream::TimeOutInfinite); + + // Why not inherited from IOStream? Never mind, we want to enforce + // supplying a timeout for network operations anyway. + virtual void Write(const std::string& rBuffer, int Timeout) + { + IOStream::Write(rBuffer, Timeout); + } + virtual void Close(); virtual bool StreamDataLeft(); virtual bool StreamClosed(); @@ -53,6 +69,42 @@ public: protected: void MarkAsReadClosed() {mReadClosed = true;} void MarkAsWriteClosed() {mWriteClosed = true;} + void CheckForMissingTimeout(int Timeout); + + // Converts a timeout in milliseconds (or IOStream::TimeOutInfinite) + // into one that can be passed to poll() (also in milliseconds), also + // compensating for time elapsed since the wait should have started, + // if known. + int PollTimeout(int timeout, box_time_t start_time) + { + if (timeout == IOStream::TimeOutInfinite) + { + return INFTIM; + } + + if (start_time == 0) + { + return timeout; // no adjustment possible + } + + box_time_t end_time = start_time + MilliSecondsToBoxTime(timeout); + box_time_t now = GetCurrentBoxTime(); + box_time_t remaining = end_time - now; + + if (remaining < 0) + { + return 0; // no delay + } + else if (BoxTimeToMilliSeconds(remaining) > INT_MAX) + { + return INT_MAX; + } + else + { + return (int) BoxTimeToMilliSeconds(remaining); + } + } + bool Poll(short Events, int Timeout); private: tOSSocketHandle mSocketHandle; diff --git a/lib/server/SocketStreamTLS.cpp b/lib/server/SocketStreamTLS.cpp index 576b53a2..e6299bfa 100644 --- a/lib/server/SocketStreamTLS.cpp +++ b/lib/server/SocketStreamTLS.cpp @@ -19,9 +19,10 @@ #include <poll.h> #endif +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "BoxTime.h" #include "CryptoUtils.h" -#include "ServerException.h" #include "SocketStreamTLS.h" #include "SSLLib.h" #include "TLSContext.h" @@ -131,7 +132,7 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer) tOSSocketHandle socket = GetSocketHandle(); BIO_set_fd(mpBIO, socket, BIO_NOCLOSE); - + // Then the SSL object mpSSL = ::SSL_new(rContext.GetRawContext()); if(mpSSL == 0) @@ -154,7 +155,7 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer) THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed) } #endif - + // FIXME: This is less portable than the above. However, it MAY be needed // for cygwin, which has/had bugs with fcntl // @@ -196,7 +197,7 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer) if(WaitWhenRetryRequired(se, TLS_HANDSHAKE_TIMEOUT) == false) { // timed out - THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeTimedOut) + THROW_EXCEPTION(ConnectionException, TLSHandshakeTimedOut) } break; @@ -205,12 +206,12 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer) if(IsServer) { CryptoUtils::LogError("accepting connection"); - THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed) + THROW_EXCEPTION(ConnectionException, TLSHandshakeFailed) } else { CryptoUtils::LogError("connecting"); - THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed) + THROW_EXCEPTION(ConnectionException, TLSHandshakeFailed) } } } @@ -222,23 +223,25 @@ void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer) // // Function // Name: WaitWhenRetryRequired(int, int) -// Purpose: Waits until the condition required by the TLS layer is met. -// Returns true if the condition is met, false if timed out. +// Purpose: Waits until the condition required by the TLS layer +// is met. Returns true if the condition is met, false +// if timed out. // Created: 2003/08/15 // // -------------------------------------------------------------------------- bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout) { - struct pollfd p; - p.fd = GetSocketHandle(); + CheckForMissingTimeout(Timeout); + + short events; switch(SSLErrorCode) { case SSL_ERROR_WANT_READ: - p.events = POLLIN; + events = POLLIN; break; case SSL_ERROR_WANT_WRITE: - p.events = POLLOUT; + events = POLLOUT; break; default: @@ -246,45 +249,8 @@ bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout) THROW_EXCEPTION(ServerException, Internal) break; } - p.revents = 0; - - int64_t start, end; - start = BoxTimeToMilliSeconds(GetCurrentBoxTime()); - end = start + Timeout; - int result; - - do - { - int64_t now = BoxTimeToMilliSeconds(GetCurrentBoxTime()); - int poll_timeout = (int)(end - now); - if (poll_timeout < 0) poll_timeout = 0; - if (Timeout == IOStream::TimeOutInfinite) - { - poll_timeout = INFTIM; - } - result = ::poll(&p, 1, poll_timeout); - } - while(result == -1 && errno == EINTR); - - switch(result) - { - case -1: - // error - Bad! - THROW_EXCEPTION(ServerException, SocketPollError) - break; - - case 0: - // Condition not met, timed out - return false; - break; - - default: - // good to go! - return true; - break; - } - return true; + return Poll(events, Timeout); } // -------------------------------------------------------------------------- @@ -297,6 +263,7 @@ bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout) // -------------------------------------------------------------------------- int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout) { + CheckForMissingTimeout(Timeout); if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)} // Make sure zero byte reads work as expected @@ -304,7 +271,7 @@ int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout) { return 0; } - + while(true) { int r = ::SSL_read(mpSSL, pBuffer, NBytes); @@ -337,7 +304,7 @@ int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout) default: CryptoUtils::LogError("reading"); - THROW_EXCEPTION(ConnectionException, Conn_TLSReadFailed) + THROW_EXCEPTION(ConnectionException, TLSReadFailed) break; } } @@ -351,23 +318,23 @@ int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout) // Created: 2003/08/06 // // -------------------------------------------------------------------------- -void SocketStreamTLS::Write(const void *pBuffer, int NBytes) +void SocketStreamTLS::Write(const void *pBuffer, int NBytes, int Timeout) { if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)} - + // Make sure zero byte writes work as expected if(NBytes == 0) { return; } - + // from man SSL_write // // SSL_write() will only return with success, when the // complete contents of buf of length num has been written. // // So no worries about partial writes and moving the buffer around - + while(true) { // try the write @@ -385,24 +352,24 @@ void SocketStreamTLS::Write(const void *pBuffer, int NBytes) case SSL_ERROR_ZERO_RETURN: // Connection closed MarkAsWriteClosed(); - THROW_EXCEPTION(ConnectionException, Conn_TLSClosedWhenWriting) + THROW_EXCEPTION(ConnectionException, TLSClosedWhenWriting) break; case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: - // wait for the requried data + // wait for the required data { #ifndef BOX_RELEASE_BUILD - bool conditionmet = + bool conditionmet = #endif - WaitWhenRetryRequired(se, IOStream::TimeOutInfinite); + WaitWhenRetryRequired(se, Timeout); ASSERT(conditionmet); } break; default: CryptoUtils::LogError("writing"); - THROW_EXCEPTION(ConnectionException, Conn_TLSWriteFailed) + THROW_EXCEPTION(ConnectionException, TLSWriteFailed) break; } } @@ -444,7 +411,7 @@ void SocketStreamTLS::Shutdown(bool Read, bool Write) if(::SSL_shutdown(mpSSL) < 0) { CryptoUtils::LogError("shutting down"); - THROW_EXCEPTION(ConnectionException, Conn_TLSShutdownFailed) + THROW_EXCEPTION(ConnectionException, TLSShutdownFailed) } // Don't ask the base class to shutdown -- BIO does this, apparently. @@ -467,15 +434,15 @@ std::string SocketStreamTLS::GetPeerCommonName() if(cert == 0) { ::X509_free(cert); - THROW_EXCEPTION(ConnectionException, Conn_TLSNoPeerCertificate) + THROW_EXCEPTION(ConnectionException, TLSNoPeerCertificate) } - // Subject details - X509_NAME *subject = ::X509_get_subject_name(cert); + // Subject details + X509_NAME *subject = ::X509_get_subject_name(cert); if(subject == 0) { ::X509_free(cert); - THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid) + THROW_EXCEPTION(ConnectionException, TLSPeerCertificateInvalid) } // Common name @@ -483,7 +450,7 @@ std::string SocketStreamTLS::GetPeerCommonName() if(::X509_NAME_get_text_by_NID(subject, NID_commonName, commonName, sizeof(commonName)) <= 0) { ::X509_free(cert); - THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid) + THROW_EXCEPTION(ConnectionException, TLSPeerCertificateInvalid) } // Terminate just in case commonName[sizeof(commonName)-1] = '\0'; diff --git a/lib/server/SocketStreamTLS.h b/lib/server/SocketStreamTLS.h index bb40ed10..3fda98c1 100644 --- a/lib/server/SocketStreamTLS.h +++ b/lib/server/SocketStreamTLS.h @@ -43,7 +43,8 @@ public: void Handshake(const TLSContext &rContext, bool IsServer = false); virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite); - virtual void Write(const void *pBuffer, int NBytes); + virtual void Write(const void *pBuffer, int NBytes, + int Timeout = IOStream::TimeOutInfinite); virtual void Close(); virtual void Shutdown(bool Read = true, bool Write = true); diff --git a/lib/server/TLSContext.cpp b/lib/server/TLSContext.cpp index 341043e9..1a6d4a53 100644 --- a/lib/server/TLSContext.cpp +++ b/lib/server/TLSContext.cpp @@ -12,8 +12,9 @@ #define TLS_CLASS_IMPLEMENTATION_CPP #include <openssl/ssl.h> +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" #include "CryptoUtils.h" -#include "ServerException.h" #include "SSLLib.h" #include "TLSContext.h" @@ -22,6 +23,17 @@ #define MAX_VERIFICATION_DEPTH 2 #define CIPHER_LIST "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH" +// Macros to allow compatibility with OpenSSL 1.0 and 1.1 APIs. See +// https://github.com/charybdis-ircd/charybdis/blob/release/3.5/libratbox/src/openssl_ratbox.h +// for the gory details. +#if defined(LIBRESSL_VERSION_NUMBER) || (OPENSSL_VERSION_NUMBER >= 0x10100000L) // OpenSSL >= 1.1 +# define BOX_TLS_SERVER_METHOD TLS_server_method +# define BOX_TLS_CLIENT_METHOD TLS_client_method +#else // OpenSSL < 1.1 +# define BOX_TLS_SERVER_METHOD TLSv1_server_method +# define BOX_TLS_CLIENT_METHOD TLSv1_client_method +#endif + // -------------------------------------------------------------------------- // // Function @@ -66,7 +78,7 @@ void TLSContext::Initialise(bool AsServer, const char *CertificatesFile, const c ::SSL_CTX_free(mpContext); } - mpContext = ::SSL_CTX_new(AsServer?TLSv1_server_method():TLSv1_client_method()); + mpContext = ::SSL_CTX_new(AsServer ? BOX_TLS_SERVER_METHOD() : BOX_TLS_CLIENT_METHOD()); if(mpContext == NULL) { THROW_EXCEPTION(ServerException, TLSAllocationFailed) diff --git a/lib/server/TcpNice.cpp b/lib/server/TcpNice.cpp index 20619e49..79e91eeb 100644 --- a/lib/server/TcpNice.cpp +++ b/lib/server/TcpNice.cpp @@ -146,7 +146,7 @@ void NiceSocketStream::Write(const void *pBuffer, int NBytes) int socket = mapSocket->GetSocketHandle(); int rtt = 50; // WAG -# if HAVE_DECL_SOL_TCP && HAVE_DECL_TCP_INFO && HAVE_STRUCT_TCP_INFO_TCPI_RTT +# if HAVE_DECL_SOL_TCP && defined HAVE_STRUCT_TCP_INFO_TCPI_RTT struct tcp_info info; socklen_t optlen = sizeof(info); if(getsockopt(socket, SOL_TCP, TCP_INFO, &info, &optlen) == -1) @@ -154,7 +154,7 @@ void NiceSocketStream::Write(const void *pBuffer, int NBytes) BOX_LOG_SYS_WARNING("getsockopt(" << socket << ", SOL_TCP, " "TCP_INFO) failed"); } - else if(optlen != sizeof(info)) + else if(optlen < sizeof(info)) { BOX_WARNING("getsockopt(" << socket << ", SOL_TCP, " "TCP_INFO) return structure size " << optlen << ", " @@ -164,7 +164,7 @@ void NiceSocketStream::Write(const void *pBuffer, int NBytes) { rtt = info.tcpi_rtt; } -# endif +# endif // HAVE_DECL_SOL_TCP && defined HAVE_STRUCT_TCP_INFO_TCPI_RTT int newWindow = mTcpNice.GetNextWindowSize(mBytesWrittenThisPeriod, elapsed, rtt); diff --git a/lib/server/TcpNice.h b/lib/server/TcpNice.h index e2027749..4381df42 100644 --- a/lib/server/TcpNice.h +++ b/lib/server/TcpNice.h @@ -83,7 +83,7 @@ private: // // -------------------------------------------------------------------------- -class NiceSocketStream : public IOStream +class NiceSocketStream : public SocketStream { private: std::auto_ptr<SocketStream> mapSocket; @@ -166,6 +166,10 @@ public: } virtual void SetEnabled(bool enabled); + off_t GetBytesRead() const { return mapSocket->GetBytesRead(); } + off_t GetBytesWritten() const { return mapSocket->GetBytesWritten(); } + void ResetCounters() { mapSocket->ResetCounters(); } + private: NiceSocketStream(const NiceSocketStream &rToCopy) { /* do not call */ } diff --git a/lib/server/WinNamedPipeListener.h b/lib/server/WinNamedPipeListener.h index 26e76e3d..956a7b5a 100644 --- a/lib/server/WinNamedPipeListener.h +++ b/lib/server/WinNamedPipeListener.h @@ -11,10 +11,10 @@ #ifndef WINNAMEDPIPELISTENER__H #define WINNAMEDPIPELISTENER__H -#include <OverlappedIO.h> -#include <WinNamedPipeStream.h> - -#include "ServerException.h" +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" +#include "OverlappedIO.h" +#include "WinNamedPipeStream.h" #include "MemLeakFindOn.h" @@ -53,8 +53,8 @@ private: socket.c_str(), // pipe name PIPE_ACCESS_DUPLEX | // read/write access FILE_FLAG_OVERLAPPED, // enabled overlapped I/O - PIPE_TYPE_BYTE | // message type pipe - PIPE_READMODE_BYTE | // message-read mode + PIPE_TYPE_BYTE | + PIPE_READMODE_BYTE | PIPE_WAIT, // blocking mode ListenBacklog + 1, // max. instances 4096, // output buffer size @@ -64,9 +64,9 @@ private: if (handle == INVALID_HANDLE_VALUE) { - BOX_LOG_WIN_ERROR("Failed to create named pipe " << - socket); - THROW_EXCEPTION(ServerException, SocketOpenError) + THROW_WIN_FILE_ERRNO("Failed to create named pipe", + socket, GetLastError(), ServerException, + SocketOpenError); } return handle; diff --git a/lib/server/WinNamedPipeStream.cpp b/lib/server/WinNamedPipeStream.cpp index 1179516e..448a3c9d 100644 --- a/lib/server/WinNamedPipeStream.cpp +++ b/lib/server/WinNamedPipeStream.cpp @@ -19,10 +19,12 @@ #include <errno.h> #include <windows.h> -#include "WinNamedPipeStream.h" -#include "ServerException.h" +#include "autogen_ConnectionException.h" +#include "autogen_ServerException.h" +#include "BoxTime.h" #include "CommonException.h" #include "Socket.h" +#include "WinNamedPipeStream.h" #include "MemLeakFindOn.h" @@ -37,13 +39,14 @@ std::string WinNamedPipeStream::sPipeNamePrefix = "\\\\.\\pipe\\"; // // -------------------------------------------------------------------------- WinNamedPipeStream::WinNamedPipeStream() - : mSocketHandle(INVALID_HANDLE_VALUE), - mReadableEvent(INVALID_HANDLE_VALUE), - mBytesInBuffer(0), - mReadClosed(false), - mWriteClosed(false), - mIsServer(false), - mIsConnected(false) +: mSocketHandle(INVALID_HANDLE_VALUE), + mReadableEvent(INVALID_HANDLE_VALUE), + mBytesInBuffer(0), + mReadClosed(false), + mWriteClosed(false), + mIsServer(false), + mIsConnected(false), + mNeedAnotherRead(false) { } // -------------------------------------------------------------------------- @@ -55,14 +58,21 @@ WinNamedPipeStream::WinNamedPipeStream() // // -------------------------------------------------------------------------- WinNamedPipeStream::WinNamedPipeStream(HANDLE hNamedPipe) - : mSocketHandle(hNamedPipe), - mReadableEvent(INVALID_HANDLE_VALUE), - mBytesInBuffer(0), - mReadClosed(false), - mWriteClosed(false), - mIsServer(true), - mIsConnected(true) +: mSocketHandle(hNamedPipe), + mReadableEvent(INVALID_HANDLE_VALUE), + mBytesInBuffer(0), + mReadClosed(false), + mWriteClosed(false), + mIsServer(true), + mIsConnected(true), + mNeedAnotherRead(false) { + StartFirstRead(); +} + +// Start the first overlapped read +void WinNamedPipeStream::StartFirstRead() +{ // create the Readable event mReadableEvent = CreateEvent(NULL, TRUE, FALSE, NULL); @@ -74,23 +84,50 @@ WinNamedPipeStream::WinNamedPipeStream(HANDLE hNamedPipe) THROW_EXCEPTION(CommonException, Internal) } - // initialise the OVERLAPPED structure + StartOverlappedRead(); +} + +void WinNamedPipeStream::StartOverlappedRead() +{ + // We should only do this when the buffer is empty. We don't want + // to start an overlapped read anywhere else than the start of the + // buffer, because it could complete at any time and we don't want + // to mess about with interrupting the read already in progress. + ASSERT(mBytesInBuffer == 0); + + // Initialise the OVERLAPPED structure memset(&mReadOverlap, 0, sizeof(mReadOverlap)); mReadOverlap.hEvent = mReadableEvent; - // start the first overlapped read if (!ReadFile(mSocketHandle, mReadBuffer, sizeof(mReadBuffer), NULL, &mReadOverlap)) { DWORD err = GetLastError(); - - if (err != ERROR_IO_PENDING) + if (err == ERROR_IO_PENDING) + { + // Don't reset yet, there might be data + // in the buffer waiting to be read, + // will check below. + // ResetEvent(mReadableEvent); + } + else if (err == ERROR_HANDLE_EOF) + { + BOX_INFO("Control client disconnected"); + mReadClosed = true; + } + else if (err == ERROR_BROKEN_PIPE || + err == ERROR_PIPE_NOT_CONNECTED) + { + BOX_NOTICE("Control client disconnected"); + mReadClosed = true; + mIsConnected = false; + } + else { - BOX_ERROR("Failed to start overlapped read: " << - GetErrorMessage(err)); Close(); - THROW_EXCEPTION(ConnectionException, - Conn_SocketReadError) + THROW_WIN_ERROR_NUMBER("Failed to start overlapped " + "read", err, ConnectionException, + SocketReadError) } } } @@ -105,6 +142,12 @@ WinNamedPipeStream::WinNamedPipeStream(HANDLE hNamedPipe) // -------------------------------------------------------------------------- WinNamedPipeStream::~WinNamedPipeStream() { + for(std::list<WriteInProgress*>::iterator i = mWritesInProgress.begin(); + i != mWritesInProgress.end(); i++) + { + delete *i; + } + if (mSocketHandle != INVALID_HANDLE_VALUE) { try @@ -157,36 +200,7 @@ void WinNamedPipeStream::Accept() mIsServer = true; // must flush and disconnect before closing mIsConnected = true; - // create the Readable event - mReadableEvent = CreateEvent(NULL, TRUE, FALSE, NULL); - - if (mReadableEvent == INVALID_HANDLE_VALUE) - { - BOX_ERROR("Failed to create the Readable event: " << - GetErrorMessage(GetLastError())); - Close(); - THROW_EXCEPTION(CommonException, Internal) - } - - // initialise the OVERLAPPED structure - memset(&mReadOverlap, 0, sizeof(mReadOverlap)); - mReadOverlap.hEvent = mReadableEvent; - - // start the first overlapped read - if (!ReadFile(mSocketHandle, mReadBuffer, sizeof(mReadBuffer), - NULL, &mReadOverlap)) - { - DWORD err = GetLastError(); - - if (err != ERROR_IO_PENDING) - { - BOX_ERROR("Failed to start overlapped read: " << - GetErrorMessage(err)); - Close(); - THROW_EXCEPTION(ConnectionException, - Conn_SocketReadError) - } - } + StartFirstRead(); } */ @@ -214,7 +228,7 @@ void WinNamedPipeStream::Connect(const std::string& rName) 0, // no sharing NULL, // default security attributes OPEN_EXISTING, - 0, // default attributes + 0, // FILE_FLAG_OVERLAPPED, // dwFlagsAndAttributes NULL); // no template file if (mSocketHandle == INVALID_HANDLE_VALUE) @@ -237,6 +251,86 @@ void WinNamedPipeStream::Connect(const std::string& rName) mWriteClosed = false; mIsServer = false; // just close the socket mIsConnected = true; + + StartFirstRead(); +} + +// Returns true if the operation is complete (and you will need to start +// another one), or false otherwise (you can wait again). +bool WinNamedPipeStream::WaitForOverlappedOperation(OVERLAPPED& Overlapped, + int Timeout, int64_t* pBytesTransferred) +{ + if (Timeout == IOStream::TimeOutInfinite) + { + Timeout = INFINITE; + } + + // overlapped I/O completed successfully? (wait if needed) + DWORD waitResult = WaitForSingleObject(Overlapped.hEvent, Timeout); + DWORD NumBytesTransferred = -1; + + if (waitResult == WAIT_FAILED) + { + THROW_WIN_ERROR_NUMBER("Failed to wait for overlapped I/O", + GetLastError(), ServerException, Internal); + } + + if (waitResult == WAIT_ABANDONED) + { + THROW_EXCEPTION_MESSAGE(ServerException, Internal, + "Wait for overlapped I/O abandoned by system"); + } + + if (waitResult == WAIT_TIMEOUT) + { + // wait timed out, nothing to read + *pBytesTransferred = 0; + return false; + } + + if (waitResult != WAIT_OBJECT_0) + { + THROW_EXCEPTION_MESSAGE(ServerException, BadSocketHandle, + "Failed to wait for overlapped I/O: unknown " + "result code: " << waitResult); + } + + // Overlapped operation completed successfully. Return the number + // of bytes transferred. + if (GetOverlappedResult(mSocketHandle, &Overlapped, + &NumBytesTransferred, TRUE)) + { + *pBytesTransferred = NumBytesTransferred; + return true; + } + + // We are here because GetOverlappedResult() informed us that the + // overlapped operation encountered an error, so what was it? + DWORD err = GetLastError(); + + if (err == ERROR_HANDLE_EOF) + { + Close(); + *pBytesTransferred = 0; + return true; + } + + // ERROR_NO_DATA is a strange name for + // "The pipe is being closed". No exception wanted. + + if (err == ERROR_NO_DATA || + err == ERROR_PIPE_NOT_CONNECTED || + err == ERROR_BROKEN_PIPE) + { + BOX_INFO(BOX_WIN_ERRNO_MESSAGE(err, + "Named pipe peer disconnected")); + Close(); + *pBytesTransferred = 0; + return true; + } + + THROW_WIN_ERROR_NUMBER("Failed to wait for overlapped I/O " + "to complete", err, ConnectionException, SocketReadError); } // -------------------------------------------------------------------------- @@ -249,192 +343,61 @@ void WinNamedPipeStream::Connect(const std::string& rName) // -------------------------------------------------------------------------- int WinNamedPipeStream::Read(void *pBuffer, int NBytes, int Timeout) { - // TODO no support for timeouts yet - if (!mIsServer && Timeout != IOStream::TimeOutInfinite) - { - THROW_EXCEPTION(CommonException, AssertFailed) - } - if (mSocketHandle == INVALID_HANDLE_VALUE || !mIsConnected) { - THROW_EXCEPTION(ServerException, BadSocketHandle) + THROW_EXCEPTION_MESSAGE(ServerException, BadSocketHandle, + "Tried to read from closed pipe"); } if (mReadClosed) { - THROW_EXCEPTION(ConnectionException, SocketShutdownError) + THROW_EXCEPTION_MESSAGE(ConnectionException, + SocketShutdownError, "Tried to read from closing pipe"); } // ensure safe to cast NBytes to unsigned if (NBytes < 0) { - THROW_EXCEPTION(CommonException, AssertFailed) + THROW_EXCEPTION(CommonException, AssertFailed); } - DWORD NumBytesRead; + int64_t NumBytesRead; - if (mIsServer) + // Satisfy from buffer if possible, to avoid blocking on read. + if (mBytesInBuffer == 0) { - // satisfy from buffer if possible, to avoid - // blocking on read. - bool needAnotherRead = false; - if (mBytesInBuffer == 0) - { - // overlapped I/O completed successfully? - // (wait if needed) - DWORD waitResult = WaitForSingleObject( - mReadOverlap.hEvent, Timeout); - - if (waitResult == WAIT_ABANDONED) - { - BOX_ERROR("Wait for command socket read " - "abandoned by system"); - THROW_EXCEPTION(ServerException, - BadSocketHandle); - } - else if (waitResult == WAIT_TIMEOUT) - { - // wait timed out, nothing to read - NumBytesRead = 0; - } - else if (waitResult != WAIT_OBJECT_0) - { - BOX_ERROR("Failed to wait for command " - "socket read: unknown result " << - waitResult); - } - // object is ready to read from - else if (GetOverlappedResult(mSocketHandle, - &mReadOverlap, &NumBytesRead, TRUE)) - { - needAnotherRead = true; - } - else - { - DWORD err = GetLastError(); - - if (err == ERROR_HANDLE_EOF) - { - mReadClosed = true; - } - else - { - if (err == ERROR_BROKEN_PIPE) - { - BOX_NOTICE("Control client " - "disconnected"); - } - else - { - BOX_ERROR("Failed to wait for " - "ReadFile to complete: " - << GetErrorMessage(err)); - } - - Close(); - THROW_EXCEPTION(ConnectionException, - Conn_SocketReadError) - } - } - } - else - { - NumBytesRead = 0; - } - - size_t BytesToCopy = NumBytesRead + mBytesInBuffer; - size_t BytesRemaining = 0; - - if (BytesToCopy > (size_t)NBytes) - { - BytesRemaining = BytesToCopy - NBytes; - BytesToCopy = NBytes; - } - - memcpy(pBuffer, mReadBuffer, BytesToCopy); - memmove(mReadBuffer, mReadBuffer + BytesToCopy, BytesRemaining); - - mBytesInBuffer = BytesRemaining; - NumBytesRead = BytesToCopy; - - if (needAnotherRead) + if (mNeedAnotherRead) { - // reinitialise the OVERLAPPED structure - memset(&mReadOverlap, 0, sizeof(mReadOverlap)); - mReadOverlap.hEvent = mReadableEvent; + // Start the next overlapped read + StartOverlappedRead(); } - // start the next overlapped read - if (needAnotherRead && !ReadFile(mSocketHandle, - mReadBuffer + mBytesInBuffer, - sizeof(mReadBuffer) - mBytesInBuffer, - NULL, &mReadOverlap)) - { - DWORD err = GetLastError(); - if (err == ERROR_IO_PENDING) - { - // Don't reset yet, there might be data - // in the buffer waiting to be read, - // will check below. - // ResetEvent(mReadableEvent); - } - else if (err == ERROR_HANDLE_EOF) - { - mReadClosed = true; - } - else if (err == ERROR_BROKEN_PIPE) - { - BOX_ERROR("Control client disconnected"); - mReadClosed = true; - } - else - { - BOX_ERROR("Failed to start overlapped read: " - << GetErrorMessage(err)); - Close(); - THROW_EXCEPTION(ConnectionException, - Conn_SocketReadError) - } - } + mNeedAnotherRead = WaitForOverlappedOperation(mReadOverlap, + Timeout, &NumBytesRead); } else { - if (!ReadFile( - mSocketHandle, // pipe handle - pBuffer, // buffer to receive reply - NBytes, // size of buffer - &NumBytesRead, // number of bytes read - NULL)) // not overlapped - { - DWORD err = GetLastError(); - - Close(); + // Just return the existing data from the buffer + // this time around. The caller should call again, + // and then the buffer will be empty. + NumBytesRead = 0; + } - // ERROR_NO_DATA is a strange name for - // "The pipe is being closed". No exception wanted. - - if (err == ERROR_NO_DATA || - err == ERROR_PIPE_NOT_CONNECTED) - { - NumBytesRead = 0; - } - else - { - BOX_ERROR("Failed to read from control socket: " - << GetErrorMessage(err)); - THROW_EXCEPTION(ConnectionException, - Conn_SocketReadError) - } - } - - // Closed for reading at EOF? - if (NumBytesRead == 0) - { - mReadClosed = true; - } + int BytesToCopy = NumBytesRead + mBytesInBuffer; + + if (NBytes < BytesToCopy) + { + BytesToCopy = NBytes; } - - return NumBytesRead; + + memcpy(pBuffer, mReadBuffer, BytesToCopy); + + size_t BytesRemaining = mBytesInBuffer + NumBytesRead - BytesToCopy; + ASSERT(BytesToCopy + BytesRemaining <= sizeof(mReadBuffer)); + memmove(mReadBuffer, mReadBuffer + BytesToCopy, BytesRemaining); + mBytesInBuffer = BytesRemaining; + + return BytesToCopy; } // -------------------------------------------------------------------------- @@ -445,8 +408,15 @@ int WinNamedPipeStream::Read(void *pBuffer, int NBytes, int Timeout) // Created: 2003/07/31 // // -------------------------------------------------------------------------- -void WinNamedPipeStream::Write(const void *pBuffer, int NBytes) +void WinNamedPipeStream::Write(const void *pBuffer, int NBytes, int Timeout) { + // Calculate the deadline at the beginning. Not valid if Timeout is + // IOStream::TimeOutInfinite! + ASSERT(Timeout != IOStream::TimeOutInfinite); + + box_time_t deadline = GetCurrentBoxTime() + + MilliSecondsToBoxTime(Timeout); + if (mSocketHandle == INVALID_HANDLE_VALUE || !mIsConnected) { THROW_EXCEPTION(ServerException, BadSocketHandle) @@ -454,41 +424,62 @@ void WinNamedPipeStream::Write(const void *pBuffer, int NBytes) // Buffer in byte sized type. ASSERT(sizeof(char) == 1); - const char *pByteBuffer = (char *)pBuffer; - - int NumBytesWrittenTotal = 0; - - while (NumBytesWrittenTotal < NBytes) + WriteInProgress* new_write = new WriteInProgress( + std::string((char *)pBuffer, NBytes)); + + // Start the WriteFile operation, and add to queue if pending. + BOOL Success = WriteFile( + mSocketHandle, // pipe handle + new_write->mBuffer.c_str(), // message + NBytes, // message length + NULL, // bytes written this time + &(new_write->mOverlap)); + + if (Success == TRUE) { - DWORD NumBytesWrittenThisTime = 0; - - bool Success = WriteFile( - mSocketHandle, // pipe handle - pByteBuffer + NumBytesWrittenTotal, // message - NBytes - NumBytesWrittenTotal, // message length - &NumBytesWrittenThisTime, // bytes written this time - NULL); // not overlapped + // Unfortunately this does happen. We should still call + // GetOverlappedResult() to get the number of bytes written, + // so we can treat it just the same. + // BOX_NOTICE("Write claimed success while overlapped?"); + mWritesInProgress.push_back(new_write); + } + else + { + DWORD err = GetLastError(); - if (!Success) + if (err == ERROR_IO_PENDING) { - // ERROR_NO_DATA is a strange name for - // "The pipe is being closed". - - DWORD err = GetLastError(); - - if (err != ERROR_NO_DATA) - { - BOX_ERROR("Failed to write to control " - "socket: " << GetErrorMessage(err)); - } - + BOX_TRACE("WriteFile is pending, adding to queue"); + mWritesInProgress.push_back(new_write); + } + else + { + // Not in progress any more, pop it Close(); - - THROW_EXCEPTION(ConnectionException, - Conn_SocketWriteError) + THROW_WIN_ERROR_NUMBER("Failed to start overlapped " + "write", err, ConnectionException, + SocketWriteError); } + } + + // Wait for previous WriteFile operations to complete, one at a time, + // until the deadline expires or the pipe becomes disconnected. + for(box_time_t remaining = deadline - GetCurrentBoxTime(); + remaining > 0 && !mWritesInProgress.empty() && mIsConnected; + remaining = deadline - GetCurrentBoxTime()) + { + int new_timeout = BoxTimeToMilliSeconds(remaining); + WriteInProgress* oldest_write = + *(mWritesInProgress.begin()); - NumBytesWrittenTotal += NumBytesWrittenThisTime; + int64_t bytes_written = 0; + if(WaitForOverlappedOperation(oldest_write->mOverlap, + new_timeout, &bytes_written)) + { + // This one is complete, pop it and start a new one + delete oldest_write; + mWritesInProgress.pop_front(); + } } } @@ -513,59 +504,47 @@ void WinNamedPipeStream::Close() THROW_EXCEPTION(ServerException, BadSocketHandle) } - if (mIsServer) + if (!CancelIo(mSocketHandle)) { - if (!CancelIo(mSocketHandle)) - { - BOX_ERROR("Failed to cancel outstanding I/O: " << - GetErrorMessage(GetLastError())); - } + BOX_LOG_WIN_ERROR("Failed to cancel outstanding I/O"); + } - if (mReadableEvent == INVALID_HANDLE_VALUE) - { - BOX_ERROR("Failed to destroy Readable event: " - "invalid handle"); - } - else if (!CloseHandle(mReadableEvent)) - { - BOX_ERROR("Failed to destroy Readable event: " << - GetErrorMessage(GetLastError())); - } + if (mReadableEvent == INVALID_HANDLE_VALUE) + { + BOX_ERROR("Failed to destroy Readable event: " + "invalid handle"); + } + else if (!CloseHandle(mReadableEvent)) + { + BOX_LOG_WIN_ERROR("Failed to destroy Readable event"); + } - mReadableEvent = INVALID_HANDLE_VALUE; + mReadableEvent = INVALID_HANDLE_VALUE; - if (!FlushFileBuffers(mSocketHandle)) - { - BOX_ERROR("Failed to FlushFileBuffers: " << - GetErrorMessage(GetLastError())); - } - - if (!DisconnectNamedPipe(mSocketHandle)) + if (mIsConnected && !FlushFileBuffers(mSocketHandle)) + { + BOX_LOG_WIN_ERROR("Failed to FlushFileBuffers"); + } + + if (mIsServer && mIsConnected && !DisconnectNamedPipe(mSocketHandle)) + { + DWORD err = GetLastError(); + if (err != ERROR_PIPE_NOT_CONNECTED) { - DWORD err = GetLastError(); - if (err != ERROR_PIPE_NOT_CONNECTED) - { - BOX_ERROR("Failed to DisconnectNamedPipe: " << - GetErrorMessage(err)); - } + BOX_LOG_WIN_ERROR("Failed to DisconnectNamedPipe"); } - - mIsServer = false; } - bool result = CloseHandle(mSocketHandle); + if (!CloseHandle(mSocketHandle)) + { + THROW_WIN_ERROR_NUMBER("Failed to CloseHandle", + GetLastError(), ServerException, SocketCloseError); + } mSocketHandle = INVALID_HANDLE_VALUE; mIsConnected = false; mReadClosed = true; mWriteClosed = true; - - if (!result) - { - BOX_ERROR("Failed to CloseHandle: " << - GetErrorMessage(GetLastError())); - THROW_EXCEPTION(ServerException, SocketCloseError) - } } // -------------------------------------------------------------------------- diff --git a/lib/server/WinNamedPipeStream.h b/lib/server/WinNamedPipeStream.h index 386ff7e3..5473c690 100644 --- a/lib/server/WinNamedPipeStream.h +++ b/lib/server/WinNamedPipeStream.h @@ -10,6 +10,8 @@ #if ! defined WINNAMEDPIPESTREAM__H && defined WIN32 #define WINNAMEDPIPESTREAM__H +#include <list> + #include "IOStream.h" // -------------------------------------------------------------------------- @@ -36,15 +38,27 @@ public: // both sides virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite); - virtual void Write(const void *pBuffer, int NBytes); + virtual void Write(const void *pBuffer, int NBytes, + int Timeout = IOStream::TimeOutInfinite); virtual void WriteAllBuffered(); virtual void Close(); virtual bool StreamDataLeft(); virtual bool StreamClosed(); + // Why not inherited from IOStream? Never mind, we want to enforce + // supplying a timeout for network operations anyway. + virtual void Write(const std::string& rBuffer, int Timeout) + { + IOStream::Write(rBuffer, Timeout); + } + protected: void MarkAsReadClosed() {mReadClosed = true;} void MarkAsWriteClosed() {mWriteClosed = true;} + bool WaitForOverlappedOperation(OVERLAPPED& Overlapped, + int Timeout, int64_t* pBytesTransferred); + void StartFirstRead(); + void StartOverlappedRead(); private: WinNamedPipeStream(const WinNamedPipeStream &rToCopy) @@ -59,6 +73,37 @@ private: bool mWriteClosed; bool mIsServer; bool mIsConnected; + bool mNeedAnotherRead; + + class WriteInProgress { + private: + friend class WinNamedPipeStream; + std::string mBuffer; + OVERLAPPED mOverlap; + WriteInProgress(const WriteInProgress& other); // do not call + public: + WriteInProgress(const std::string& dataToWrite) + : mBuffer(dataToWrite) + { + // create the Writable event + HANDLE writable_event = CreateEvent(NULL, TRUE, FALSE, + NULL); + if (writable_event == INVALID_HANDLE_VALUE) + { + BOX_LOG_WIN_ERROR("Failed to create the " + "Writable event"); + THROW_EXCEPTION(CommonException, Internal) + } + + memset(&mOverlap, 0, sizeof(mOverlap)); + mOverlap.hEvent = writable_event; + } + ~WriteInProgress() + { + CloseHandle(mOverlap.hEvent); + } + }; + std::list<WriteInProgress*> mWritesInProgress; public: static std::string sPipeNamePrefix; diff --git a/lib/server/makeprotocol.pl.in b/lib/server/makeprotocol.pl.in index a074b435..d6c0e216 100755 --- a/lib/server/makeprotocol.pl.in +++ b/lib/server/makeprotocol.pl.in @@ -78,7 +78,7 @@ sub add_type my ($protocol_name, $cpp_name, $header_file) = split /\s+/,$_[0]; $translate_type_info{$protocol_name} = [0, $cpp_name]; - push @extra_header_files, $header_file; + push @extra_header_files, $header_file if $header_file; } # check attributes @@ -158,7 +158,10 @@ print CPP <<__E; #include <sstream> #include "$filename_base.h" -#include "IOStream.h" +#include "CollectInBufferStream.h" +#include "MemBlockStream.h" +#include "SelfFlushingStream.h" +#include "SocketStream.h" __E print H <<__E; @@ -174,12 +177,10 @@ print H <<__E; #include <syslog.h> #endif +#include "autogen_ConnectionException.h" #include "Protocol.h" #include "Message.h" -#include "ServerException.h" - -class IOStream; - +#include "SocketStream.h" __E @@ -210,19 +211,26 @@ __E my $request_base_class = "${protocol_name}ProtocolRequest"; my $reply_base_class = "${protocol_name}ProtocolReply"; # the abstract protocol interface -my $protocol_base_class = $protocol_name."ProtocolBase"; +my $custom_protocol_subclass = $protocol_name."Protocol"; +my $client_server_base_class = $protocol_name."ProtocolClientServer"; my $replyable_base_class = $protocol_name."ProtocolReplyable"; +my $callable_base_class = $protocol_name."ProtocolCallable"; +my $send_receive_class = $protocol_name."ProtocolSendReceive"; print H <<__E; -class $protocol_base_class; +class $custom_protocol_subclass; +class $client_server_base_class; +class $callable_base_class; class $replyable_base_class; -class $reply_base_class; class $message_base_class : public Message { public: virtual std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol, $context_class &rContext) const; + virtual std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol, + $context_class &rContext, IOStream& rDataStream) const; + virtual bool HasStreamWithCommand() const = 0; }; class $reply_base_class @@ -233,17 +241,44 @@ class $request_base_class { }; +class $send_receive_class { +public: + virtual void Send(const $message_base_class &rObject) = 0; + virtual std::auto_ptr<$message_base_class> Receive() = 0; +}; + +class $custom_protocol_subclass : public Protocol +{ +public: + $custom_protocol_subclass(std::auto_ptr<SocketStream> apConn) + : Protocol(apConn) + { } + virtual ~$custom_protocol_subclass() { } + virtual std::auto_ptr<Message> MakeMessage(int ObjType); + virtual const char *GetProtocolIdentString(); + +private: + $custom_protocol_subclass(const $custom_protocol_subclass &rToCopy); +}; + __E print CPP <<__E; std::auto_ptr<$message_base_class> $message_base_class\::DoCommand($replyable_base_class &rProtocol, $context_class &rContext) const { - THROW_EXCEPTION(ConnectionException, Conn_Protocol_TriedToExecuteReplyCommand) + THROW_EXCEPTION(ConnectionException, Protocol_TriedToExecuteReplyCommand) +} + +std::auto_ptr<$message_base_class> $message_base_class\::DoCommand($replyable_base_class &rProtocol, + $context_class &rContext, IOStream& rDataStream) const +{ + THROW_EXCEPTION(ConnectionException, Protocol_TriedToExecuteReplyCommand) } __E -my %cmd_class; +my %cmd_classes; +my $error_message = undef; # output the classes foreach my $cmd (@cmd_list) @@ -262,7 +297,7 @@ foreach my $cmd (@cmd_list) my $cmd_base_class = join(", ", map {"public $_"} @cmd_base_classes); my $cmd_class = $protocol_name."Protocol".$cmd; - $cmd_class{$cmd} = $cmd_class; + $cmd_classes{$cmd} = $cmd_class; print H <<__E; class $cmd_class : $cmd_base_class @@ -294,18 +329,50 @@ __E if(obj_is_type($cmd,'IsError')) { - print H "\tbool IsError(int &rTypeOut, int &rSubTypeOut) const;\n"; - print H "\tstd::string GetMessage() const;\n"; + $error_message = $cmd; + my ($mem_type,$mem_subtype) = split /,/,obj_get_type_params($cmd,'IsError'); + my $error_type = $cmd_constants{"ErrorType"}; + print H <<__E; + $cmd_class(int SubType) : m$mem_type($error_type), m$mem_subtype(SubType) { } + bool IsError(int &rTypeOut, int &rSubTypeOut) const; + std::string GetMessage() const { return GetMessage(m$mem_subtype); }; + static std::string GetMessage(int subtype); +__E } - if(obj_is_type($cmd, 'Command')) + my $has_stream = obj_is_type($cmd, 'StreamWithCommand'); + + if(obj_is_type($cmd, 'Command') && $has_stream) + { + print H <<__E; + std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol, + $context_class &rContext, IOStream& rDataStream) const; // IMPLEMENT THIS\n + std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol, + $context_class &rContext) const + { + THROW_EXCEPTION_MESSAGE(CommonException, Internal, + "This command requires a stream parameter"); + } +__E + } + elsif(obj_is_type($cmd, 'Command') && !$has_stream) { print H <<__E; std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol, $context_class &rContext) const; // IMPLEMENT THIS\n + std::auto_ptr<$message_base_class> DoCommand($replyable_base_class &rProtocol, + $context_class &rContext, IOStream& rDataStream) const + { + THROW_EXCEPTION_MESSAGE(CommonException, NotSupported, + "This command requires no stream parameter"); + } __E } + print H <<__E; + bool HasStreamWithCommand() const { return $has_stream; } +__E + # want to be able to read from streams? print H "\tvoid SetPropertiesFromStreamData(Protocol &rProtocol);\n"; @@ -442,9 +509,9 @@ bool $cmd_class\::IsError(int &rTypeOut, int &rSubTypeOut) const rSubTypeOut = m$mem_subtype; return true; } -std::string $cmd_class\::GetMessage() const +std::string $cmd_class\::GetMessage(int subtype) { - switch(m$mem_subtype) + switch(subtype) { __E foreach my $const (@{$cmd_constants{$cmd}}) @@ -459,7 +526,7 @@ __E print CPP <<__E; default: std::ostringstream out; - out << "Unknown subtype " << m$mem_subtype; + out << "Unknown subtype " << subtype; return out.str(); } } @@ -505,47 +572,44 @@ my $error_class = $protocol_name."ProtocolError"; # the abstract protocol interface print H <<__E; -class $protocol_base_class + +class $client_server_base_class { public: - $protocol_base_class(); - virtual ~$protocol_base_class(); - virtual const char *GetIdentString(); + $client_server_base_class(); + virtual ~$client_server_base_class(); + virtual std::auto_ptr<IOStream> ReceiveStream() = 0; bool GetLastError(int &rTypeOut, int &rSubTypeOut); + int GetLastErrorType() { return mLastErrorSubType; } protected: - void CheckReply(const std::string& requestCommand, - const $message_base_class &rReply, int expectedType); void SetLastError(int Type, int SubType) { mLastErrorType = Type; mLastErrorSubType = SubType; } + std::string mPreviousCommand; + std::string mPreviousReply; private: - $protocol_base_class(const $protocol_base_class &rToCopy); /* do not call */ + $client_server_base_class(const $client_server_base_class &rToCopy); /* do not call */ int mLastErrorType; int mLastErrorSubType; }; -class $replyable_base_class : public virtual $protocol_base_class +class $replyable_base_class : public virtual $client_server_base_class { public: - $replyable_base_class(); + $replyable_base_class() { } virtual ~$replyable_base_class(); - /* - virtual std::auto_ptr<$message_base_class> Receive() = 0; - virtual void Send(const ${message_base_class} &rObject) = 0; - */ - - virtual std::auto_ptr<IOStream> ReceiveStream() = 0; virtual int GetTimeout() = 0; void SendStreamAfterCommand(std::auto_ptr<IOStream> apStream); - + protected: std::list<IOStream*> mStreamsToSend; void DeleteStreamsToSend(); + virtual std::auto_ptr<$message_base_class> HandleException(BoxException& e) const; private: $replyable_base_class(const $replyable_base_class &rToCopy); /* do not call */ @@ -554,24 +618,47 @@ private: __E print CPP <<__E; -$protocol_base_class\::$protocol_base_class() +$client_server_base_class\::$client_server_base_class() : mLastErrorType(Protocol::NoError), mLastErrorSubType(Protocol::NoError) { } -$protocol_base_class\::~$protocol_base_class() +$client_server_base_class\::~$client_server_base_class() { } -const char *$protocol_base_class\::GetIdentString() +const char *$custom_protocol_subclass\::GetProtocolIdentString() { return "$ident_string"; } -$replyable_base_class\::$replyable_base_class() -{ } +std::auto_ptr<Message> $custom_protocol_subclass\::MakeMessage(int ObjType) +{ + switch(ObjType) + { +__E + +# do objects within this +for my $cmd (@cmd_list) +{ + print CPP <<__E; + case $cmd_id{$cmd}: + return std::auto_ptr<Message>(new $cmd_classes{$cmd}()); + break; +__E +} + +print CPP <<__E; + default: + THROW_EXCEPTION(ConnectionException, Protocol_UnknownCommandRecieved) + } +} $replyable_base_class\::~$replyable_base_class() -{ } +{ + // If there were any streams left over, there's no longer any way to + // access them, and we're responsible for them, so we'd better delete them. + DeleteStreamsToSend(); +} void $replyable_base_class\::SendStreamAfterCommand(std::auto_ptr<IOStream> apStream) { @@ -589,12 +676,14 @@ void $replyable_base_class\::DeleteStreamsToSend() mStreamsToSend.clear(); } -void $protocol_base_class\::CheckReply(const std::string& requestCommand, - const $message_base_class &rReply, int expectedType) +void $callable_base_class\::CheckReply(const std::string& requestCommandName, + const $message_base_class &rCommand, const $message_base_class &rReply, + int expectedType) { if(rReply.GetType() == expectedType) { // Correct response, do nothing + SetLastError(Protocol::NoError, Protocol::NoError); } else { @@ -605,8 +694,8 @@ void $protocol_base_class\::CheckReply(const std::string& requestCommand, { SetLastError(type, subType); THROW_EXCEPTION_MESSAGE(ConnectionException, - Conn_Protocol_UnexpectedReply, - requestCommand << " command failed: " + Protocol_UnexpectedReply, + requestCommandName << " command failed: " "received error " << (($error_class&)rReply).GetMessage()); } @@ -614,12 +703,18 @@ void $protocol_base_class\::CheckReply(const std::string& requestCommand, { SetLastError(Protocol::UnknownError, Protocol::UnknownError); THROW_EXCEPTION_MESSAGE(ConnectionException, - Conn_Protocol_UnexpectedReply, - requestCommand << " command failed: " + Protocol_UnexpectedReply, + requestCommandName << " command failed: " "received unexpected response type " << rReply.GetType()); } } + + // As a client, if we get an unexpected reply later, we'll want to know + // the last command that we executed, and the reply, to help debug the + // server. + mPreviousCommand = rCommand.ToString(); + mPreviousReply = rReply.ToString(); } // -------------------------------------------------------------------------- @@ -630,7 +725,7 @@ void $protocol_base_class\::CheckReply(const std::string& requestCommand, // Created: 2003/08/19 // // -------------------------------------------------------------------------- -bool $protocol_base_class\::GetLastError(int &rTypeOut, int &rSubTypeOut) +bool $client_server_base_class\::GetLastError(int &rTypeOut, int &rSubTypeOut) { if(mLastErrorType == Protocol::NoError) { @@ -653,13 +748,19 @@ __E # the callable protocol interface (implemented by Client and Local classes) # with Query methods that don't take a context parameter -my $callable_base_class = $protocol_name."ProtocolCallable"; print H <<__E; -class $callable_base_class : public virtual $protocol_base_class +class $callable_base_class : public virtual $client_server_base_class, + public $send_receive_class { public: - virtual std::auto_ptr<IOStream> ReceiveStream() = 0; virtual int GetTimeout() = 0; + +protected: + void CheckReply(const std::string& requestCommandName, + const $message_base_class &rCommand, + const $message_base_class &rReply, int expectedType); + +public: __E # add plain object taking query functions @@ -671,8 +772,8 @@ for my $cmd (@cmd_list) my $has_stream = obj_is_type($cmd,'StreamWithCommand'); my $argextra = $has_stream?', std::auto_ptr<IOStream> apStream':''; my $queryextra = $has_stream?', apStream':''; - my $request_class = $cmd_class{$cmd}; - my $reply_class = $cmd_class{obj_get_type_params($cmd,'Command')}; + my $request_class = $cmd_classes{$cmd}; + my $reply_class = $cmd_classes{obj_get_type_params($cmd,'Command')}; print H "\tvirtual std::auto_ptr<$reply_class> Query(const $request_class &rQuery$argextra) = 0;\n"; my @a; @@ -720,13 +821,15 @@ foreach my $type ('Client', 'Server', 'Local') { push @base_classes, $replyable_base_class; } + if (not $writing_server) { push @base_classes, $callable_base_class; } + if (not $writing_local) { - push @base_classes, "Protocol"; + push @base_classes, $custom_protocol_subclass; } my $base_classes_str = join(", ", map {"public $_"} @base_classes); @@ -735,6 +838,7 @@ foreach my $type ('Client', 'Server', 'Local') class $server_or_client_class : $base_classes_str { public: + virtual ~$server_or_client_class(); __E if($writing_local) @@ -743,18 +847,12 @@ __E $server_or_client_class($context_class &rContext); __E } - else - { - print H <<__E; - $server_or_client_class(IOStream &rStream); + + print H <<__E; + $server_or_client_class(std::auto_ptr<SocketStream> apConn); std::auto_ptr<$message_base_class> Receive(); void Send(const $message_base_class &rObject); __E - } - - print H <<__E; - virtual ~$server_or_client_class(); -__E if($writing_server) { @@ -775,37 +873,29 @@ __E my $has_stream = obj_is_type($cmd,'StreamWithCommand'); my $argextra = $has_stream?', std::auto_ptr<IOStream> apStream':''; my $queryextra = $has_stream?', apStream':''; - my $request_class = $cmd_class{$cmd}; - my $reply_class = $cmd_class{obj_get_type_params($cmd,'Command')}; + my $request_class = $cmd_classes{$cmd}; + my $reply_class = $cmd_classes{obj_get_type_params($cmd,'Command')}; print H "\tstd::auto_ptr<$reply_class> Query(const $request_class &rQuery$argextra);\n"; } } } - + if($writing_local) { print H <<__E; private: $context_class &mrContext; -__E - } - - print H <<__E; - -protected: - virtual std::auto_ptr<Message> MakeMessage(int ObjType); - -__E - - if($writing_local) - { - print H <<__E; - virtual void InformStreamReceiving(u_int32_t Size) { } - virtual void InformStreamSending(u_int32_t Size) { } - + std::auto_ptr<$message_base_class> mapLastReply; public: virtual std::auto_ptr<IOStream> ReceiveStream() { + if(mStreamsToSend.empty()) + { + THROW_EXCEPTION_MESSAGE(CommonException, Internal, + "Tried to ReceiveStream when none was sent or " + "made available"); + } + std::auto_ptr<IOStream> apStream(mStreamsToSend.front()); mStreamsToSend.pop_front(); return apStream; @@ -815,29 +905,33 @@ __E else { print H <<__E; - virtual void InformStreamReceiving(u_int32_t Size) - { - this->Protocol::InformStreamReceiving(Size); - } - virtual void InformStreamSending(u_int32_t Size) - { - this->Protocol::InformStreamSending(Size); - } + virtual std::auto_ptr<IOStream> ReceiveStream(); +__E -public: - virtual std::auto_ptr<IOStream> ReceiveStream() + print CPP <<__E; +std::auto_ptr<IOStream> $server_or_client_class\::ReceiveStream() +{ + try { - return this->Protocol::ReceiveStream(); - } -__E + return $custom_protocol_subclass\::ReceiveStream(); } - - print H <<__E; - virtual const char *GetProtocolIdentString() + catch(ConnectionException &e) { - return GetIdentString(); + if(e.GetSubType() == ConnectionException::Protocol_ObjWhenStreamExpected) + { + THROW_EXCEPTION_MESSAGE(ConnectionException, + Protocol_ObjWhenStreamExpected, + "Last exchange was " << mPreviousCommand << + " => " << mPreviousReply); + } + else + { + throw; + } } +} __E + } if($writing_local) { @@ -853,23 +947,13 @@ __E print H <<__E; virtual int GetTimeout() { - return this->Protocol::GetTimeout(); + return $custom_protocol_subclass\::GetTimeout(); } __E } - + print H <<__E; - /* - virtual void Handshake() - { - this->Protocol::Handshake(); - } - virtual bool GetLastError(int &rTypeOut, int &rSubTypeOut) - { - return this->Protocol::GetLastError(rTypeOut, rSubTypeOut); - } - */ - + private: $server_or_client_class(const $server_or_client_class &rToCopy); /* no copies */ }; @@ -890,8 +974,8 @@ __E else { print CPP <<__E; -$server_or_client_class\::$server_or_client_class(IOStream &rStream) -: Protocol(rStream) +$server_or_client_class\::$server_or_client_class(std::auto_ptr<SocketStream> apConn) +: $custom_protocol_subclass(apConn) { } __E } @@ -903,49 +987,58 @@ $server_or_client_class\::~$server_or_client_class() __E # write receive and send functions - print CPP <<__E; -std::auto_ptr<Message> $server_or_client_class\::MakeMessage(int ObjType) -{ - switch(ObjType) - { -__E - - # do objects within this - for my $cmd (@cmd_list) + if($writing_local) { print CPP <<__E; - case $cmd_id{$cmd}: - return std::auto_ptr<Message>(new $cmd_class{$cmd}()); - break; -__E - } - - print CPP <<__E; - default: - THROW_EXCEPTION(ConnectionException, Conn_Protocol_UnknownCommandRecieved) - } +std::auto_ptr<$message_base_class> $server_or_client_class\::Receive() +{ + return mapLastReply; +} +void $server_or_client_class\::Send(const $message_base_class &rObject) +{ + mapLastReply = rObject.DoCommand(*this, mrContext); } __E - - if(not $writing_local) + } + else { print CPP <<__E; std::auto_ptr<$message_base_class> $server_or_client_class\::Receive() { - std::auto_ptr<$message_base_class> preply(($message_base_class *) - Protocol::ReceiveInternal().release()); + std::auto_ptr<$message_base_class> apReply; + + try + { + apReply = std::auto_ptr<$message_base_class>( + static_cast<$message_base_class *> + ($custom_protocol_subclass\::ReceiveInternal().release())); + } + catch(ConnectionException &e) + { + if(e.GetSubType() == ConnectionException::Protocol_StreamWhenObjExpected) + { + THROW_EXCEPTION_MESSAGE(ConnectionException, + Protocol_StreamWhenObjExpected, + "Last exchange was " << mPreviousCommand << + " => " << mPreviousReply); + } + else + { + throw; + } + } if(GetLogToSysLog()) { - preply->LogSysLog("Receive"); + apReply->LogSysLog("Receive"); } if(GetLogToFile() != 0) { - preply->LogFile("Receive", GetLogToFile()); + apReply->LogFile("Receive", GetLogToFile()); } - return preply; + return apReply; } void $server_or_client_class\::Send(const $message_base_class &rObject) @@ -981,10 +1074,40 @@ void $server_or_client_class\::DoServer($context_class &rContext) { // Get an object from the conversation std::auto_ptr<$message_base_class> pobj = Receive(); + std::auto_ptr<$message_base_class> preply; // Run the command - std::auto_ptr<$message_base_class> preply = pobj->DoCommand(*this, rContext); - + try + { + try + { + if(pobj->HasStreamWithCommand()) + { + std::auto_ptr<IOStream> apDataStream = ReceiveStream(); + SelfFlushingStream autoflush(*apDataStream); + preply = pobj->DoCommand(*this, rContext, *apDataStream); + } + else + { + preply = pobj->DoCommand(*this, rContext); + } + } + catch(BoxException &e) + { + // First try a the built-in exception handler + preply = HandleException(e); + } + } + catch (...) + { + // Fallback in case the exception isn't a BoxException + // or the exception handler fails as well. This path + // throws the exception upwards, killing the process + // that handles the current client. + Send($cmd_classes{$error_message}(-1)); + throw; + } + // Send the reply Send(*preply); @@ -995,7 +1118,16 @@ void $server_or_client_class\::DoServer($context_class &rContext) { SendStream(**i); } - + + // As a server, if we get an unexpected message later, we'll + // want to know the last command that we received, and the + // reply, to help debug our response to it. + mPreviousCommand = pobj->ToString(); + std::ostringstream reply; + reply << preply->ToString() << " and " << + mStreamsToSend.size() << " streams"; + mPreviousReply = reply.str(); + // Delete these streams DeleteStreamsToSend(); @@ -1004,7 +1136,7 @@ void $server_or_client_class\::DoServer($context_class &rContext) { inProgress = false; } - } + } } __E @@ -1017,67 +1149,86 @@ __E { if(obj_is_type($cmd,'Command')) { - my $request_class = $cmd_class{$cmd}; + my $request_class = $cmd_classes{$cmd}; my $reply_msg = obj_get_type_params($cmd,'Command'); - my $reply_class = $cmd_class{$reply_msg}; + my $reply_class = $cmd_classes{$reply_msg}; my $reply_id = $cmd_id{$reply_msg}; my $has_stream = obj_is_type($cmd,'StreamWithCommand'); - my $argextra = $has_stream?', std::auto_ptr<IOStream> apStream':''; + my $argextra = $has_stream?', std::auto_ptr<IOStream> apDataStream':''; my $send_stream_extra = ''; - my $send_stream_method = $writing_client ? "SendStream" - : "SendStreamAfterCommand"; - + + print CPP <<__E; +std::auto_ptr<$reply_class> $server_or_client_class\::Query(const $request_class &rQuery$argextra) +{ +__E + if($writing_client) { if($has_stream) { $send_stream_extra = <<__E; // Send stream after the command - SendStream(*apStream); + try + { + SendStream(*apDataStream); + } + catch (BoxException &e) + { + BOX_WARNING("Failed to send stream after command: " << + rQuery.ToString() << ": " << e.what()); + throw; + } __E } print CPP <<__E; -std::auto_ptr<$reply_class> $server_or_client_class\::Query(const $request_class &rQuery$argextra) -{ // Send query Send(rQuery); - $send_stream_extra +$send_stream_extra // Wait for the reply - std::auto_ptr<$message_base_class> preply = Receive(); - - CheckReply("$cmd", *preply, $reply_id); - - // Correct response, if no exception thrown by CheckReply - return std::auto_ptr<$reply_class>(($reply_class *)preply.release()); -} + std::auto_ptr<$message_base_class> apReply = Receive(); __E } elsif($writing_local) { + print CPP <<__E; + std::auto_ptr<$message_base_class> apReply; + try + { +__E if($has_stream) { - $send_stream_extra = <<__E; - // Send stream after the command - SendStreamAfterCommand(apStream); + print CPP <<__E; + apReply = rQuery.DoCommand(*this, mrContext, *apDataStream); +__E + } + else + { + print CPP <<__E; + apReply = rQuery.DoCommand(*this, mrContext); __E } print CPP <<__E; -std::auto_ptr<$reply_class> $server_or_client_class\::Query(const $request_class &rQuery$argextra) -{ - // Send query - $send_stream_extra - std::auto_ptr<$message_base_class> preply = rQuery.DoCommand(*this, mrContext); - - CheckReply("$cmd", *preply, $reply_id); + } + catch(BoxException &e) + { + // First try a the built-in exception handler + apReply = HandleException(e); + } +__E + } + + # Common to both client and local + print CPP <<__E; + CheckReply("$cmd", rQuery, *apReply, $reply_id); // Correct response, if no exception thrown by CheckReply - return std::auto_ptr<$reply_class>(($reply_class *)preply.release()); + return std::auto_ptr<$reply_class>( + static_cast<$reply_class *>(apReply.release())); } __E - } } } } @@ -1110,7 +1261,7 @@ sub obj_get_type_params { return $1 if $_ =~ m/\A$ty\((.+?)\)\Z/; } - die "Can't find attribute $ty\n" + die "Can't find attribute $ty on command $c\n" } # returns (is basic type, typename) |