diff options
Diffstat (limited to 'lib/server')
35 files changed, 8093 insertions, 0 deletions
diff --git a/lib/server/ConnectionException.txt b/lib/server/ConnectionException.txt new file mode 100644 index 00000000..c3429116 --- /dev/null +++ b/lib/server/ConnectionException.txt @@ -0,0 +1,27 @@ +EXCEPTION Connection 7 + +# for historic reasons not all numbers are used + +SocketWriteError 6 Probably a network issue between client and server. +SocketReadError 7 Probably a network issue between client and server. +SocketNameLookupError 9 Check hostname specified. +SocketShutdownError 12 +SocketConnectError 15 Probably a network issue between client and server, bad hostname, or server not running. +TLSHandshakeFailed 30 +TLSShutdownFailed 32 +TLSWriteFailed 33 Probably a network issue between client and server. +TLSReadFailed 34 Probably a network issue between client and server, or a problem with the server. +TLSNoPeerCertificate 36 +TLSPeerCertificateInvalid 37 Check certification process +TLSClosedWhenWriting 38 +TLSHandshakeTimedOut 39 +Protocol_Timeout 41 Probably a network issue between client and server. +Protocol_ObjTooBig 42 +Protocol_BadCommandRecieved 44 +Protocol_UnknownCommandRecieved 45 +Protocol_TriedToExecuteReplyCommand 46 +Protocol_UnexpectedReply 47 Server probably reported an error. +Protocol_HandshakeFailed 48 +Protocol_StreamWhenObjExpected 49 +Protocol_ObjWhenStreamExpected 50 +Protocol_TimeOutWhenSendingStream 52 Probably a network issue between client and server. diff --git a/lib/server/Daemon.cpp b/lib/server/Daemon.cpp new file mode 100644 index 00000000..8b4f1d0c --- /dev/null +++ b/lib/server/Daemon.cpp @@ -0,0 +1,1024 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: Daemon.cpp +// Purpose: Basic daemon functionality +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#ifdef HAVE_UNISTD_H + #include <unistd.h> +#endif + +#include <errno.h> +#include <stdio.h> +#include <signal.h> +#include <string.h> +#include <stdarg.h> + +#ifdef HAVE_BSD_UNISTD_H + #include <bsd/unistd.h> +#endif + +#ifdef WIN32 + #include <ws2tcpip.h> +#endif + +#include <iostream> + +#include "Daemon.h" +#include "Configuration.h" +#include "ServerException.h" +#include "Guards.h" +#include "UnixUser.h" +#include "FileModificationTime.h" +#include "Logging.h" +#include "Utils.h" + +#include "MemLeakFindOn.h" + +Daemon *Daemon::spDaemon = 0; + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::Daemon() +// Purpose: Constructor +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +Daemon::Daemon() + : mReloadConfigWanted(false), + mTerminateWanted(false), + #ifdef WIN32 + mSingleProcess(true), + mRunInForeground(true), + mKeepConsoleOpenAfterFork(true), + #else + mSingleProcess(false), + mRunInForeground(false), + mKeepConsoleOpenAfterFork(false), + #endif + mHaveConfigFile(false), + mAppName(DaemonName()) +{ + // In debug builds, switch on assert failure logging to syslog + ASSERT_FAILS_TO_SYSLOG_ON + // And trace goes to syslog too + TRACE_TO_SYSLOG(true) +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::~Daemon() +// Purpose: Destructor +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +Daemon::~Daemon() +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::GetOptionString() +// Purpose: Returns the valid Getopt command-line options. +// This should be overridden by subclasses to add +// their own options, which should override +// ProcessOption, handle their own, and delegate to +// ProcessOption for the standard options. +// Created: 2007/09/18 +// +// -------------------------------------------------------------------------- +std::string Daemon::GetOptionString() +{ + return "c:" + #ifndef WIN32 + "DFK" + #endif + "hkPqQt:TUvVW:"; +} + +void Daemon::Usage() +{ + std::cout << + DaemonBanner() << "\n" + "\n" + "Usage: " << mAppName << " [options] [config file]\n" + "\n" + "Options:\n" + " -c <file> Use the specified configuration file. If -c is omitted, the last\n" + " argument is the configuration file, or else the default \n" + " [" << GetConfigFileName() << "]\n" +#ifndef WIN32 + " -D Debugging mode, do not fork, one process only, one client only\n" + " -F Do not fork into background, but fork to serve multiple clients\n" +#endif + " -k Keep console open after fork, keep writing log messages to it\n" +#ifndef WIN32 + " -K Stop writing log messages to console while daemon is running\n" + " -P Show process ID (PID) in console output\n" +#endif + " -q Run more quietly, reduce verbosity level by one, can repeat\n" + " -Q Run at minimum verbosity, log nothing\n" + " -v Run more verbosely, increase verbosity level by one, can repeat\n" + " -V Run at maximum verbosity, log everything\n" + " -W <level> Set verbosity to error/warning/notice/info/trace/everything\n" + " -t <tag> Tag console output with specified marker\n" + " -T Timestamp console output\n" + " -U Timestamp console output with microseconds\n"; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::ProcessOption(int option) +// Purpose: Processes the supplied option (equivalent to the +// return code from getopt()). Return zero if the +// option was handled successfully, or nonzero to +// abort the program with that return value. +// Created: 2007/09/18 +// +// -------------------------------------------------------------------------- +int Daemon::ProcessOption(signed int option) +{ + switch(option) + { + case 'c': + { + mConfigFileName = optarg; + mHaveConfigFile = true; + } + break; + +#ifndef WIN32 + case 'D': + { + mSingleProcess = true; + } + break; + + case 'F': + { + mRunInForeground = true; + } + break; +#endif // !WIN32 + + case 'k': + { + mKeepConsoleOpenAfterFork = true; + } + break; + + case 'K': + { + mKeepConsoleOpenAfterFork = false; + } + break; + + case 'h': + { + Usage(); + return 2; + } + break; + + case 'P': + { + Console::SetShowPID(true); + } + break; + + case 'q': + { + if(mLogLevel == Log::NOTHING) + { + BOX_FATAL("Too many '-q': " + "Cannot reduce logging " + "level any more"); + return 2; + } + mLogLevel--; + } + break; + + case 'Q': + { + mLogLevel = Log::NOTHING; + } + break; + + + case 'v': + { + if(mLogLevel == Log::EVERYTHING) + { + BOX_FATAL("Too many '-v': " + "Cannot increase logging " + "level any more"); + return 2; + } + mLogLevel++; + } + break; + + case 'V': + { + mLogLevel = Log::EVERYTHING; + } + break; + + case 'W': + { + mLogLevel = Logging::GetNamedLevel(optarg); + if (mLogLevel == Log::INVALID) + { + BOX_FATAL("Invalid logging level"); + return 2; + } + } + break; + + case 't': + { + Logging::SetProgramName(optarg); + Console::SetShowTag(true); + } + break; + + case 'T': + { + Console::SetShowTime(true); + } + break; + + case 'U': + { + Console::SetShowTime(true); + Console::SetShowTimeMicros(true); + } + break; + + case '?': + { + BOX_FATAL("Unknown option on command line: " + << "'" << (char)optopt << "'"); + return 2; + } + break; + + default: + { + BOX_FATAL("Unknown error in getopt: returned " + << "'" << option << "'"); + return 1; + } + } + + return 0; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::Main(const char *, int, const char *[]) +// Purpose: Parses command-line options, and then calls +// Main(std::string& configFile, bool singleProcess) +// to start the daemon. +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +int Daemon::Main(const char *DefaultConfigFile, int argc, const char *argv[]) +{ + // Find filename of config file + mConfigFileName = DefaultConfigFile; + mAppName = argv[0]; + + #ifdef BOX_RELEASE_BUILD + mLogLevel = Log::NOTICE; // need an int to do math with + #else + mLogLevel = Log::INFO; // need an int to do math with + #endif + + if (argc == 2 && strcmp(argv[1], "/?") == 0) + { + Usage(); + return 2; + } + + signed int c; + + // reset getopt, just in case anybody used it before. + // unfortunately glibc and BSD differ on this point! + // http://www.ussg.iu.edu/hypermail/linux/kernel/0305.3/0262.html + #if HAVE_DECL_OPTRESET == 1 || defined WIN32 + optind = 1; + optreset = 1; + #elif defined __GLIBC__ + optind = 0; + #else // Solaris, any others? + optind = 1; + #endif + + while((c = getopt(argc, (char * const *)argv, + GetOptionString().c_str())) != -1) + { + int returnCode = ProcessOption(c); + + if (returnCode != 0) + { + return returnCode; + } + } + + if (argc > optind && !mHaveConfigFile) + { + mConfigFileName = argv[optind]; optind++; + mHaveConfigFile = true; + } + + if (argc > optind && ::strcmp(argv[optind], "SINGLEPROCESS") == 0) + { + mSingleProcess = true; optind++; + } + + if (argc > optind) + { + BOX_FATAL("Unknown parameter on command line: " + << "'" << std::string(argv[optind]) << "'"); + return 2; + } + + Logging::FilterConsole((Log::Level)mLogLevel); + Logging::FilterSyslog ((Log::Level)mLogLevel); + + return Main(mConfigFileName); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::Configure(const std::string& rConfigFileName) +// Purpose: Loads daemon configuration. Useful when you have +// a local Daemon object and don't intend to fork() +// or call Main(). +// Created: 2008/04/19 +// +// -------------------------------------------------------------------------- + +bool Daemon::Configure(const std::string& rConfigFileName) +{ + // Load the configuration file. + std::string errors; + std::auto_ptr<Configuration> apConfig; + + try + { + if (!FileExists(rConfigFileName.c_str())) + { + BOX_FATAL("The main configuration file for " << + DaemonName() << " was not found: " << + rConfigFileName); + if (!mHaveConfigFile) + { + BOX_WARNING("The default configuration " + "directory has changed from /etc/box " + "to /etc/boxbackup"); + } + return false; + } + + apConfig = Configuration::LoadAndVerify(rConfigFileName, + GetConfigVerify(), errors); + } + catch(BoxException &e) + { + if(e.GetType() == CommonException::ExceptionType && + e.GetSubType() == CommonException::OSFileOpenError) + { + BOX_ERROR("Failed to open configuration file: " << + rConfigFileName); + return false; + } + + throw; + } + + // Got errors? + if(apConfig.get() == 0) + { + BOX_ERROR("Failed to load or verify configuration file"); + return false; + } + + if(!Configure(*apConfig)) + { + BOX_ERROR("Failed to verify configuration file"); + return false; + } + + // Store configuration + mConfigFileName = rConfigFileName; + mLoadedConfigModifiedTime = GetConfigFileModifiedTime(); + + return true; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::Configure(const Configuration& rConfig) +// Purpose: Loads daemon configuration. Useful when you have +// a local Daemon object and don't intend to fork() +// or call Main(). +// Created: 2008/08/12 +// +// -------------------------------------------------------------------------- + +bool Daemon::Configure(const Configuration& rConfig) +{ + std::string errors; + + // Verify() may modify the configuration, e.g. adding default values + // for required keys, so need to make a copy here + std::auto_ptr<Configuration> apConf(new Configuration(rConfig)); + apConf->Verify(*GetConfigVerify(), errors); + + // Got errors? + if(!errors.empty()) + { + BOX_ERROR("Configuration errors: " << errors); + return false; + } + + // Store configuration + mapConfiguration = apConf; + + // Let the derived class have a go at setting up stuff + // in the initial process + SetupInInitialProcess(); + + return true; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::Main(const std::string& rConfigFileName) +// Purpose: Starts the daemon off -- equivalent of C main() function +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +int Daemon::Main(const std::string &rConfigFileName) +{ + // Banner (optional) + { + BOX_SYSLOG(Log::NOTICE, DaemonBanner()); + } + + std::string pidFileName; + + bool asDaemon = !mSingleProcess && !mRunInForeground; + + try + { + if (!Configure(rConfigFileName)) + { + BOX_FATAL("Failed to start: failed to load " + "configuration file: " << rConfigFileName); + return 1; + } + + // Server configuration + const Configuration &serverConfig( + mapConfiguration->GetSubConfiguration("Server")); + + if(serverConfig.KeyExists("LogFacility")) + { + std::string facility = + serverConfig.GetKeyValue("LogFacility"); + Logging::SetFacility(Syslog::GetNamedFacility(facility)); + } + + // Open PID file for writing + pidFileName = serverConfig.GetKeyValue("PidFile"); + FileHandleGuard<(O_WRONLY | O_CREAT | O_TRUNC), (S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)> pidFile(pidFileName.c_str()); + +#ifndef WIN32 + // Handle changing to a different user + if(serverConfig.KeyExists("User")) + { + // Config file specifies an user -- look up + UnixUser daemonUser(serverConfig.GetKeyValue("User").c_str()); + + // Change the owner on the PID file, so it can be deleted properly on termination + if(::fchown(pidFile, daemonUser.GetUID(), daemonUser.GetGID()) != 0) + { + THROW_EXCEPTION(ServerException, CouldNotChangePIDFileOwner) + } + + // Change the process ID + daemonUser.ChangeProcessUser(); + } + + if(asDaemon) + { + // Let's go... Daemonise... + switch(::fork()) + { + case -1: + // error + THROW_EXCEPTION(ServerException, DaemoniseFailed) + break; + + default: + // parent + // _exit(0); + return 0; + break; + + case 0: + // child + break; + } + + // In child + + // Set new session + if(::setsid() == -1) + { + BOX_LOG_SYS_ERROR("Failed to setsid()"); + THROW_EXCEPTION(ServerException, DaemoniseFailed) + } + + // Fork again... + switch(::fork()) + { + case -1: + // error + BOX_LOG_SYS_ERROR("Failed to fork() a child"); + THROW_EXCEPTION(ServerException, DaemoniseFailed) + break; + + default: + // parent + _exit(0); + return 0; + break; + + case 0: + // child + break; + } + } +#endif // !WIN32 + + // Must set spDaemon before installing signal handler, + // otherwise the handler will crash if invoked too soon. + if(spDaemon != NULL) + { + THROW_EXCEPTION(ServerException, AlreadyDaemonConstructed) + } + spDaemon = this; + +#ifndef WIN32 + // Set signal handler + // Don't do this in the parent, since it might be anything + // (e.g. test/bbackupd) + + struct sigaction sa; + sa.sa_handler = SignalHandler; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); // macro + if(::sigaction(SIGHUP, &sa, NULL) != 0 || + ::sigaction(SIGTERM, &sa, NULL) != 0) + { + BOX_LOG_SYS_ERROR("Failed to set signal handlers"); + THROW_EXCEPTION(ServerException, DaemoniseFailed) + } +#endif // !WIN32 + + // Write PID to file + char pid[32]; + + int pidsize = sprintf(pid, "%d", (int)getpid()); + + if(::write(pidFile, pid, pidsize) != pidsize) + { + BOX_LOG_SYS_FATAL("Failed to write PID file: " << + pidFileName); + THROW_EXCEPTION(ServerException, DaemoniseFailed) + } + + // Set up memory leak reporting + #ifdef BOX_MEMORY_LEAK_TESTING + { + char filename[256]; + sprintf(filename, "%s.memleaks", DaemonName()); + memleakfinder_setup_exit_report(filename, DaemonName()); + } + #endif // BOX_MEMORY_LEAK_TESTING + + if(asDaemon && !mKeepConsoleOpenAfterFork) + { +#ifndef WIN32 + // Close standard streams + ::close(0); + ::close(1); + ::close(2); + + // Open and redirect them into /dev/null + int devnull = ::open(PLATFORM_DEV_NULL, O_RDWR, 0); + if(devnull == -1) + { + BOX_LOG_SYS_ERROR("Failed to open /dev/null"); + THROW_EXCEPTION(CommonException, OSFileError); + } + // Then duplicate them to all three handles + if(devnull != 0) dup2(devnull, 0); + if(devnull != 1) dup2(devnull, 1); + if(devnull != 2) dup2(devnull, 2); + // Close the original handle if it was opened above the std* range + if(devnull > 2) + { + ::close(devnull); + } + + // And definitely don't try and send anything to those file descriptors + // -- this has in the past sent text to something which isn't expecting it. + TRACE_TO_STDOUT(false); +#endif // ! WIN32 + Logging::ToConsole(false); + } + + // Log the start message + BOX_NOTICE("Starting daemon, version: " << BOX_VERSION); + BOX_NOTICE("Using configuration file: " << mConfigFileName); + } + catch(BoxException &e) + { + BOX_FATAL("Failed to start: exception " << e.what() + << " (" << e.GetType() + << "/" << e.GetSubType() << ")"); + return 1; + } + catch(std::exception &e) + { + BOX_FATAL("Failed to start: exception " << e.what()); + return 1; + } + catch(...) + { + BOX_FATAL("Failed to start: unknown error"); + return 1; + } + +#ifdef WIN32 + // Under win32 we must initialise the Winsock library + // before using sockets + + WSADATA info; + + if (WSAStartup(0x0101, &info) == SOCKET_ERROR) + { + // will not run without sockets + BOX_FATAL("Failed to initialise Windows Sockets"); + THROW_EXCEPTION(CommonException, Internal) + } +#endif + + int retcode = 0; + + // Main Daemon running + try + { + while(!mTerminateWanted) + { + Run(); + + if(mReloadConfigWanted && !mTerminateWanted) + { + // Need to reload that config file... + BOX_NOTICE("Reloading configuration file: " + << mConfigFileName); + std::string errors; + std::auto_ptr<Configuration> pconfig( + Configuration::LoadAndVerify( + mConfigFileName.c_str(), + GetConfigVerify(), errors)); + + // Got errors? + if(pconfig.get() == 0 || !errors.empty()) + { + // Tell user about errors + BOX_FATAL("Error in configuration " + << "file: " << mConfigFileName + << ": " << errors); + // And give up + retcode = 1; + break; + } + + // Store configuration + mapConfiguration = pconfig; + mLoadedConfigModifiedTime = + GetConfigFileModifiedTime(); + + // Stop being marked for loading config again + mReloadConfigWanted = false; + } + } + + // Delete the PID file + ::unlink(pidFileName.c_str()); + + // Log + BOX_NOTICE("Terminating daemon"); + } + catch(BoxException &e) + { + BOX_FATAL("Terminating due to exception " << e.what() + << " (" << e.GetType() + << "/" << e.GetSubType() << ")"); + retcode = 1; + } + catch(std::exception &e) + { + BOX_FATAL("Terminating due to exception " << e.what()); + retcode = 1; + } + catch(...) + { + BOX_FATAL("Terminating due to unknown exception"); + retcode = 1; + } + +#ifdef WIN32 + WSACleanup(); +#else + // Should clean up here, but it breaks memory leak tests. + /* + if(asDaemon) + { + // we are running in the child by now, and should not return + mapConfiguration.reset(); + exit(0); + } + */ +#endif + + ASSERT(spDaemon == this); + spDaemon = NULL; + + return retcode; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::EnterChild() +// Purpose: Sets up for a child task of the main server. Call +// just after fork(). +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void Daemon::EnterChild() +{ +#ifndef WIN32 + // Unset signal handlers + struct sigaction sa; + sa.sa_handler = SIG_DFL; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); // macro + ::sigaction(SIGHUP, &sa, NULL); + ::sigaction(SIGTERM, &sa, NULL); +#endif +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::SignalHandler(int) +// Purpose: Signal handler +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +void Daemon::SignalHandler(int sigraised) +{ +#ifndef WIN32 + if(spDaemon != 0) + { + switch(sigraised) + { + case SIGHUP: + spDaemon->mReloadConfigWanted = true; + break; + + case SIGTERM: + spDaemon->mTerminateWanted = true; + break; + + default: + break; + } + } +#endif +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::DaemonName() +// Purpose: Returns name of the daemon +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +const char *Daemon::DaemonName() const +{ + return "generic-daemon"; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::DaemonBanner() +// Purpose: Returns the text banner for this daemon's startup +// Created: 1/1/04 +// +// -------------------------------------------------------------------------- +std::string Daemon::DaemonBanner() const +{ + return "Generic daemon using the Box Application Framework"; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::Run() +// Purpose: Main run function after basic Daemon initialisation +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +void Daemon::Run() +{ + while(!StopRun()) + { + ::sleep(10); + } +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::GetConfigVerify() +// Purpose: Returns the configuration file verification structure for this daemon +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +const ConfigurationVerify *Daemon::GetConfigVerify() const +{ + static ConfigurationVerifyKey verifyserverkeys[] = + { + DAEMON_VERIFY_SERVER_KEYS + }; + + static ConfigurationVerify verifyserver[] = + { + { + "Server", + 0, + verifyserverkeys, + ConfigTest_Exists | ConfigTest_LastEntry, + 0 + } + }; + + static ConfigurationVerify verify = + { + "root", + verifyserver, + 0, + ConfigTest_Exists | ConfigTest_LastEntry, + 0 + }; + + return &verify; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::GetConfiguration() +// Purpose: Returns the daemon configuration object +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +const Configuration &Daemon::GetConfiguration() const +{ + if(mapConfiguration.get() == 0) + { + // Shouldn't get anywhere near this if a configuration file can't be loaded + THROW_EXCEPTION(ServerException, Internal) + } + + return *mapConfiguration; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::SetupInInitialProcess() +// Purpose: A chance for the daemon to do something initial +// setting up in the process which initiates +// everything, and after the configuration file has +// been read and verified. +// Created: 2003/08/20 +// +// -------------------------------------------------------------------------- +void Daemon::SetupInInitialProcess() +{ + // Base class doesn't do anything. +} + + +void Daemon::SetProcessTitle(const char *format, ...) +{ + // On OpenBSD, setproctitle() sets the process title to imagename: <text> (imagename) + // -- make sure other platforms include the image name somewhere so ps listings give + // useful information. + +#ifdef HAVE_SETPROCTITLE + // optional arguments + va_list args; + va_start(args, format); + + // Make the string + char title[256]; + ::vsnprintf(title, sizeof(title), format, args); + + // Set process title + ::setproctitle("%s", title); + +#endif // HAVE_SETPROCTITLE +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::GetConfigFileModifiedTime() +// Purpose: Returns the timestamp when the configuration file +// was last modified +// +// Created: 2006/01/29 +// +// -------------------------------------------------------------------------- + +box_time_t Daemon::GetConfigFileModifiedTime() const +{ + EMU_STRUCT_STAT st; + + if(EMU_STAT(GetConfigFileName().c_str(), &st) != 0) + { + if (errno == ENOENT) + { + return 0; + } + BOX_LOG_SYS_ERROR("Failed to stat configuration file: " << + GetConfigFileName()); + THROW_EXCEPTION(CommonException, OSFileError) + } + + return FileModificationTime(st); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Daemon::GetLoadedConfigModifiedTime() +// Purpose: Returns the timestamp when the configuration file +// had been last modified, at the time when it was +// loaded +// +// Created: 2006/01/29 +// +// -------------------------------------------------------------------------- + +box_time_t Daemon::GetLoadedConfigModifiedTime() const +{ + return mLoadedConfigModifiedTime; +} + diff --git a/lib/server/Daemon.h b/lib/server/Daemon.h new file mode 100644 index 00000000..a3212a00 --- /dev/null +++ b/lib/server/Daemon.h @@ -0,0 +1,112 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: Daemon.h +// Purpose: Basic daemon functionality +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- + +/* NOTE: will log to local6: include a line like + local6.info /var/log/box + in /etc/syslog.conf +*/ + + +#ifndef DAEMON__H +#define DAEMON__H + +#include <string> + +#include "BoxTime.h" +#include "Configuration.h" + +class ConfigurationVerify; + +// -------------------------------------------------------------------------- +// +// Class +// Name: Daemon +// Purpose: Basic daemon functionality +// Created: 2003/07/29 +// +// -------------------------------------------------------------------------- +class Daemon +{ +public: + Daemon(); + virtual ~Daemon(); +private: + Daemon(const Daemon &rToCopy); +public: + + virtual int Main(const char *DefaultConfigFile, int argc, + const char *argv[]); + + /* override this Main() if you want custom option processing: */ + virtual int Main(const std::string &rConfigFile); + + virtual void Run(); + const Configuration &GetConfiguration() const; + const std::string &GetConfigFileName() const {return mConfigFileName;} + + virtual const char *DaemonName() const; + virtual std::string DaemonBanner() const; + virtual const ConfigurationVerify *GetConfigVerify() const; + virtual void Usage(); + + virtual bool Configure(const std::string& rConfigFileName); + virtual bool Configure(const Configuration& rConfig); + + bool StopRun() {return mReloadConfigWanted | mTerminateWanted;} + bool IsReloadConfigWanted() {return mReloadConfigWanted;} + bool IsTerminateWanted() {return mTerminateWanted;} + + // To allow derived classes to get these signals in other ways + void SetReloadConfigWanted() {mReloadConfigWanted = true;} + void SetTerminateWanted() {mTerminateWanted = true;} + + virtual void EnterChild(); + + static void SetProcessTitle(const char *format, ...); + void SetRunInForeground(bool foreground) + { + mRunInForeground = foreground; + } + void SetSingleProcess(bool value) + { + mSingleProcess = value; + } + +protected: + virtual void SetupInInitialProcess(); + box_time_t GetLoadedConfigModifiedTime() const; + bool IsSingleProcess() { return mSingleProcess; } + virtual std::string GetOptionString(); + virtual int ProcessOption(signed int option); + +private: + static void SignalHandler(int sigraised); + box_time_t GetConfigFileModifiedTime() const; + + std::string mConfigFileName; + std::auto_ptr<Configuration> mapConfiguration; + box_time_t mLoadedConfigModifiedTime; + bool mReloadConfigWanted; + bool mTerminateWanted; + bool mSingleProcess; + bool mRunInForeground; + bool mKeepConsoleOpenAfterFork; + bool mHaveConfigFile; + int mLogLevel; // need an int to do math with + static Daemon *spDaemon; + std::string mAppName; +}; + +#define DAEMON_VERIFY_SERVER_KEYS \ + ConfigurationVerifyKey("PidFile", ConfigTest_Exists), \ + ConfigurationVerifyKey("LogFacility", 0), \ + ConfigurationVerifyKey("User", ConfigTest_LastEntry) + +#endif // DAEMON__H + diff --git a/lib/server/LocalProcessStream.cpp b/lib/server/LocalProcessStream.cpp new file mode 100644 index 00000000..c331a135 --- /dev/null +++ b/lib/server/LocalProcessStream.cpp @@ -0,0 +1,180 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: LocalProcessStream.cpp +// Purpose: Opens a process, and presents stdin/stdout as a stream. +// Created: 12/3/04 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#ifdef HAVE_SYS_SOCKET_H + #include <sys/socket.h> +#endif + +#ifdef HAVE_UNISTD_H + #include <unistd.h> +#endif + +#include "LocalProcessStream.h" +#include "autogen_ServerException.h" +#include "Utils.h" + +#ifdef WIN32 + #include "FileStream.h" +#else + #include "SocketStream.h" +#endif + +#include "MemLeakFindOn.h" + +#define MAX_ARGUMENTS 64 + +// -------------------------------------------------------------------------- +// +// Function +// Name: LocalProcessStream(const char *, pid_t &) +// Purpose: Run a new process, and return a stream giving access +// to its stdin and stdout (stdout and stderr on +// Win32). Returns the PID of the new process -- this +// must be waited on at some point to avoid zombies +// (except on Win32). +// Created: 12/3/04 +// +// -------------------------------------------------------------------------- +std::auto_ptr<IOStream> LocalProcessStream(const std::string& rCommandLine, + pid_t &rPidOut) +{ +#ifndef WIN32 + + // Split up command + std::vector<std::string> command; + SplitString(rCommandLine, ' ', command); + + // Build arguments + char *args[MAX_ARGUMENTS + 4]; + { + int a = 0; + std::vector<std::string>::const_iterator i(command.begin()); + while(a < MAX_ARGUMENTS && i != command.end()) + { + args[a++] = (char*)(*(i++)).c_str(); + } + args[a] = NULL; + } + + // Create a socket pair to communicate over. + int sv[2] = {-1,-1}; + if(::socketpair(AF_UNIX, SOCK_STREAM, PF_UNSPEC, sv) != 0) + { + THROW_EXCEPTION(ServerException, SocketPairFailed) + } + + std::auto_ptr<IOStream> stream(new SocketStream(sv[0])); + + // Fork + pid_t pid = 0; + switch(pid = vfork()) + { + case -1: // error + ::close(sv[0]); + ::close(sv[1]); + THROW_EXCEPTION(ServerException, ServerForkError) + break; + + case 0: // child + // Close end of the socket not being used + ::close(sv[0]); + // Duplicate the file handles to stdin and stdout + if(sv[1] != 0) ::dup2(sv[1], 0); + if(sv[1] != 1) ::dup2(sv[1], 1); + // Close the now redundant socket + if(sv[1] != 0 && sv[1] != 1) + { + ::close(sv[1]); + } + // Execute command! + ::execv(args[0], args); + ::_exit(127); // report error + break; + + default: + // just continue... + break; + } + + // Close the file descriptor not being used + ::close(sv[1]); + + // Return the stream object and PID + rPidOut = pid; + return stream; + +#else // WIN32 + + SECURITY_ATTRIBUTES secAttr; + secAttr.nLength = sizeof(SECURITY_ATTRIBUTES); + secAttr.bInheritHandle = TRUE; + secAttr.lpSecurityDescriptor = NULL; + + HANDLE writeInChild, readFromChild; + if(!CreatePipe(&readFromChild, &writeInChild, &secAttr, 0)) + { + BOX_ERROR("Failed to CreatePipe for child process: " << + GetErrorMessage(GetLastError())); + THROW_EXCEPTION(ServerException, SocketPairFailed) + } + SetHandleInformation(readFromChild, HANDLE_FLAG_INHERIT, 0); + + PROCESS_INFORMATION procInfo; + STARTUPINFO startupInfo; + + ZeroMemory(&procInfo, sizeof(procInfo)); + ZeroMemory(&startupInfo, sizeof(startupInfo)); + startupInfo.cb = sizeof(startupInfo); + startupInfo.hStdError = writeInChild; + startupInfo.hStdOutput = writeInChild; + startupInfo.hStdInput = INVALID_HANDLE_VALUE; + startupInfo.dwFlags |= STARTF_USESTDHANDLES; + + CHAR* commandLineCopy = (CHAR*)malloc(rCommandLine.size() + 1); + strcpy(commandLineCopy, rCommandLine.c_str()); + + BOOL result = CreateProcess(NULL, + commandLineCopy, // command line + NULL, // process security attributes + NULL, // primary thread security attributes + TRUE, // handles are inherited + 0, // creation flags + NULL, // use parent's environment + NULL, // use parent's current directory + &startupInfo, // STARTUPINFO pointer + &procInfo); // receives PROCESS_INFORMATION + + free(commandLineCopy); + + if(!result) + { + BOX_ERROR("Failed to CreateProcess: '" << rCommandLine << + "': " << GetErrorMessage(GetLastError())); + CloseHandle(writeInChild); + CloseHandle(readFromChild); + THROW_EXCEPTION(ServerException, ServerForkError) + } + + CloseHandle(procInfo.hProcess); + CloseHandle(procInfo.hThread); + CloseHandle(writeInChild); + + rPidOut = (int)(procInfo.dwProcessId); + + std::auto_ptr<IOStream> stream(new FileStream(readFromChild)); + return stream; + +#endif // ! WIN32 +} + + + + diff --git a/lib/server/LocalProcessStream.h b/lib/server/LocalProcessStream.h new file mode 100644 index 00000000..51e51f8a --- /dev/null +++ b/lib/server/LocalProcessStream.h @@ -0,0 +1,20 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: LocalProcessStream.h +// Purpose: Opens a process, and presents stdin/stdout as a stream. +// Created: 12/3/04 +// +// -------------------------------------------------------------------------- + +#ifndef LOCALPROCESSSTREAM__H +#define LOCALPROCESSSTREAM__H + +#include <memory> +#include "IOStream.h" + +std::auto_ptr<IOStream> LocalProcessStream(const std::string& rCommandLine, + pid_t &rPidOut); + +#endif // LOCALPROCESSSTREAM__H + diff --git a/lib/server/Makefile.extra b/lib/server/Makefile.extra new file mode 100644 index 00000000..7fc6baf9 --- /dev/null +++ b/lib/server/Makefile.extra @@ -0,0 +1,11 @@ + +MAKEEXCEPTION = ../../lib/common/makeexception.pl + +# AUTOGEN SEEDING +autogen_ServerException.h autogen_ServerException.cpp: $(MAKEEXCEPTION) ServerException.txt + $(_PERL) $(MAKEEXCEPTION) ServerException.txt + +# AUTOGEN SEEDING +autogen_ConnectionException.h autogen_ConnectionException.cpp: $(MAKEEXCEPTION) ConnectionException.txt + $(_PERL) $(MAKEEXCEPTION) ConnectionException.txt + diff --git a/lib/server/OverlappedIO.h b/lib/server/OverlappedIO.h new file mode 100644 index 00000000..12495053 --- /dev/null +++ b/lib/server/OverlappedIO.h @@ -0,0 +1,42 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: OverlappedIO.h +// Purpose: Windows overlapped IO handle guard +// Created: 2008/09/30 +// +// -------------------------------------------------------------------------- + +#ifndef OVERLAPPEDIO__H +#define OVERLAPPEDIO__H + +class OverlappedIO +{ +public: + OVERLAPPED mOverlapped; + + OverlappedIO() + { + ZeroMemory(&mOverlapped, sizeof(mOverlapped)); + mOverlapped.hEvent = CreateEvent(NULL, TRUE, FALSE, + NULL); + if (mOverlapped.hEvent == INVALID_HANDLE_VALUE) + { + BOX_LOG_WIN_ERROR("Failed to create event for " + "overlapped I/O"); + THROW_EXCEPTION(ServerException, BadSocketHandle); + } + } + + ~OverlappedIO() + { + if (CloseHandle(mOverlapped.hEvent) != TRUE) + { + BOX_LOG_WIN_ERROR("Failed to delete event for " + "overlapped I/O"); + THROW_EXCEPTION(ServerException, BadSocketHandle); + } + } +}; + +#endif // !OVERLAPPEDIO__H diff --git a/lib/server/Protocol.cpp b/lib/server/Protocol.cpp new file mode 100644 index 00000000..5dc5d0b1 --- /dev/null +++ b/lib/server/Protocol.cpp @@ -0,0 +1,1160 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: Protocol.cpp +// Purpose: Generic protocol support +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#include <sys/types.h> + +#include <stdlib.h> +#include <string.h> + +#include <new> + +#include "Protocol.h" +#include "ProtocolWire.h" +#include "IOStream.h" +#include "ServerException.h" +#include "PartialReadStream.h" +#include "ProtocolUncertainStream.h" +#include "Logging.h" + +#include "MemLeakFindOn.h" + +#ifdef BOX_RELEASE_BUILD + #define PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK 1024 +#else +// #define PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK 1024 + #define PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK 4 +#endif + +#define UNCERTAIN_STREAM_SIZE_BLOCK (64*1024) + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Protocol(IOStream &rStream) +// Purpose: Constructor +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +Protocol::Protocol(IOStream &rStream) + : mrStream(rStream), + mHandshakeDone(false), + mMaxObjectSize(PROTOCOL_DEFAULT_MAXOBJSIZE), + mTimeout(PROTOCOL_DEFAULT_TIMEOUT), + mpBuffer(0), + mBufferSize(0), + mReadOffset(-1), + mWriteOffset(-1), + mValidDataSize(-1), + mLastErrorType(NoError), + mLastErrorSubType(NoError) +{ + BOX_TRACE("Send block allocation size is " << + PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::~Protocol() +// Purpose: Destructor +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +Protocol::~Protocol() +{ + // Free buffer? + if(mpBuffer != 0) + { + free(mpBuffer); + mpBuffer = 0; + } +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::GetLastError(int &, int &) +// Purpose: Returns true if there was an error, and type and subtype if there was. +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +bool Protocol::GetLastError(int &rTypeOut, int &rSubTypeOut) +{ + if(mLastErrorType == NoError) + { + // no error. + return false; + } + + // Return type and subtype in args + rTypeOut = mLastErrorType; + rSubTypeOut = mLastErrorSubType; + + // and unset them + mLastErrorType = NoError; + mLastErrorSubType = NoError; + + return true; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Handshake() +// Purpose: Handshake with peer (exchange ident strings) +// Created: 2003/08/20 +// +// -------------------------------------------------------------------------- +void Protocol::Handshake() +{ + // Already done? + if(mHandshakeDone) + { + THROW_EXCEPTION(CommonException, Internal) + } + + // Make handshake block + PW_Handshake hsSend; + ::memset(&hsSend, 0, sizeof(hsSend)); + // Copy in ident string + ::strncpy(hsSend.mIdent, GetIdentString(), sizeof(hsSend.mIdent)); + + // Send it + mrStream.Write(&hsSend, sizeof(hsSend)); + mrStream.WriteAllBuffered(); + + // Receive a handshake from the peer + PW_Handshake hsReceive; + ::memset(&hsReceive, 0, sizeof(hsReceive)); + char *readInto = (char*)&hsReceive; + int bytesToRead = sizeof(hsReceive); + while(bytesToRead > 0) + { + // Get some data from the stream + int bytesRead = mrStream.Read(readInto, bytesToRead, mTimeout); + if(bytesRead == 0) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) + } + readInto += bytesRead; + bytesToRead -= bytesRead; + } + ASSERT(bytesToRead == 0); + + // Are they the same? + if(::memcmp(&hsSend, &hsReceive, sizeof(hsSend)) != 0) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_HandshakeFailed) + } + + // Mark as done + mHandshakeDone = true; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::CheckAndReadHdr(void *) +// Purpose: Check read for recieve call and get object header from stream. +// Don't use type here to avoid dependency in .h file. +// Created: 2003/08/26 +// +// -------------------------------------------------------------------------- +void Protocol::CheckAndReadHdr(void *hdr) +{ + // Check usage + if(mValidDataSize != -1 || mWriteOffset != -1 || mReadOffset != -1) + { + THROW_EXCEPTION(ServerException, Protocol_BadUsage) + } + + // Handshake done? + if(!mHandshakeDone) + { + Handshake(); + } + + // Get some data into this header + if(!mrStream.ReadFullBuffer(hdr, sizeof(PW_ObjectHeader), 0 /* not interested in bytes read if this fails */, mTimeout)) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) + } +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Recieve() +// Purpose: Recieves an object from the stream, creating it from the factory object type +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +std::auto_ptr<ProtocolObject> Protocol::Receive() +{ + // Get object header + PW_ObjectHeader objHeader; + CheckAndReadHdr(&objHeader); + + // Hope it's not a stream + if(ntohl(objHeader.mObjType) == SPECIAL_STREAM_OBJECT_TYPE) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_StreamWhenObjExpected) + } + + // Check the object size + u_int32_t objSize = ntohl(objHeader.mObjSize); + if(objSize < sizeof(objHeader) || objSize > mMaxObjectSize) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_ObjTooBig) + } + + // Create a blank object + std::auto_ptr<ProtocolObject> obj(MakeProtocolObject(ntohl(objHeader.mObjType))); + + // Make sure memory is allocated to read it into + EnsureBufferAllocated(objSize); + + // Read data + if(!mrStream.ReadFullBuffer(mpBuffer, objSize - sizeof(objHeader), 0 /* not interested in bytes read if this fails */, mTimeout)) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_Timeout) + } + + // Setup ready to read out data from the buffer + mValidDataSize = objSize - sizeof(objHeader); + mReadOffset = 0; + + // Get the object to read its properties from the data recieved + try + { + obj->SetPropertiesFromStreamData(*this); + } + catch(...) + { + // Make sure state is reset! + mValidDataSize = -1; + mReadOffset = -1; + throw; + } + + // Any data left over? + bool dataLeftOver = (mValidDataSize != mReadOffset); + + // Unset read state, so future read calls don't fail + mValidDataSize = -1; + mReadOffset = -1; + + // Exception if not all the data was consumed + if(dataLeftOver) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_BadCommandRecieved) + } + + return obj; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Send() +// Purpose: Send an object to the other side of the connection. +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Send(const ProtocolObject &rObject) +{ + // Check usage + if(mValidDataSize != -1 || mWriteOffset != -1 || mReadOffset != -1) + { + THROW_EXCEPTION(ServerException, Protocol_BadUsage) + } + + // Handshake done? + if(!mHandshakeDone) + { + Handshake(); + } + + // Make sure there's a little bit of space allocated + EnsureBufferAllocated(((sizeof(PW_ObjectHeader) + PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK - 1) / PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK) * PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK); + ASSERT(mBufferSize >= (int)sizeof(PW_ObjectHeader)); + + // Setup for write operation + mValidDataSize = 0; // Not used, but must not be -1 + mWriteOffset = sizeof(PW_ObjectHeader); + + try + { + rObject.WritePropertiesToStreamData(*this); + } + catch(...) + { + // Make sure state is reset! + mValidDataSize = -1; + mWriteOffset = -1; + throw; + } + + // How big? + int writtenSize = mWriteOffset; + + // Reset write state + mValidDataSize = -1; + mWriteOffset = -1; + + // Make header in the existing block + PW_ObjectHeader *pobjHeader = (PW_ObjectHeader*)(mpBuffer); + pobjHeader->mObjSize = htonl(writtenSize); + pobjHeader->mObjType = htonl(rObject.GetType()); + + // Write data + mrStream.Write(mpBuffer, writtenSize); + mrStream.WriteAllBuffered(); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::EnsureBufferAllocated(int) +// Purpose: Private. Ensures the buffer is at least the size requested. +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::EnsureBufferAllocated(int Size) +{ + if(mpBuffer != 0 && mBufferSize >= Size) + { + // Nothing to do! + return; + } + + // Need to allocate, or reallocate, the block + if(mpBuffer != 0) + { + // Reallocate + void *b = realloc(mpBuffer, Size); + if(b == 0) + { + throw std::bad_alloc(); + } + mpBuffer = (char*)b; + mBufferSize = Size; + } + else + { + // Just allocate + mpBuffer = (char*)malloc(Size); + if(mpBuffer == 0) + { + throw std::bad_alloc(); + } + mBufferSize = Size; + } +} + + +#define READ_START_CHECK \ + if(mValidDataSize == -1 || mWriteOffset != -1 || mReadOffset == -1) \ + { \ + THROW_EXCEPTION(ServerException, Protocol_BadUsage) \ + } + +#define READ_CHECK_BYTES_AVAILABLE(bytesRequired) \ + if((mReadOffset + (int)(bytesRequired)) > mValidDataSize) \ + { \ + THROW_EXCEPTION(ConnectionException, Conn_Protocol_BadCommandRecieved) \ + } + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Read(void *, int) +// Purpose: Read raw data from the stream (buffered) +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Read(void *Buffer, int Size) +{ + READ_START_CHECK + READ_CHECK_BYTES_AVAILABLE(Size) + + // Copy data out + ::memmove(Buffer, mpBuffer + mReadOffset, Size); + mReadOffset += Size; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Read(std::string &, int) +// Purpose: Read raw data from the stream (buffered), into a std::string +// Created: 2003/08/26 +// +// -------------------------------------------------------------------------- +void Protocol::Read(std::string &rOut, int Size) +{ + READ_START_CHECK + READ_CHECK_BYTES_AVAILABLE(Size) + + rOut.assign(mpBuffer + mReadOffset, Size); + mReadOffset += Size; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Read(int64_t &) +// Purpose: Read a value from the stream (buffered) +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Read(int64_t &rOut) +{ + READ_START_CHECK + READ_CHECK_BYTES_AVAILABLE(sizeof(int64_t)) + +#ifdef HAVE_ALIGNED_ONLY_INT64 + int64_t nvalue; + memcpy(&nvalue, mpBuffer + mReadOffset, sizeof(int64_t)); +#else + int64_t nvalue = *((int64_t*)(mpBuffer + mReadOffset)); +#endif + rOut = box_ntoh64(nvalue); + + mReadOffset += sizeof(int64_t); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Read(int32_t &) +// Purpose: Read a value from the stream (buffered) +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Read(int32_t &rOut) +{ + READ_START_CHECK + READ_CHECK_BYTES_AVAILABLE(sizeof(int32_t)) + +#ifdef HAVE_ALIGNED_ONLY_INT32 + int32_t nvalue; + memcpy(&nvalue, mpBuffer + mReadOffset, sizeof(int32_t)); +#else + int32_t nvalue = *((int32_t*)(mpBuffer + mReadOffset)); +#endif + rOut = ntohl(nvalue); + mReadOffset += sizeof(int32_t); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Read(int16_t &) +// Purpose: Read a value from the stream (buffered) +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Read(int16_t &rOut) +{ + READ_START_CHECK + READ_CHECK_BYTES_AVAILABLE(sizeof(int16_t)) + + rOut = ntohs(*((int16_t*)(mpBuffer + mReadOffset))); + mReadOffset += sizeof(int16_t); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Read(int8_t &) +// Purpose: Read a value from the stream (buffered) +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Read(int8_t &rOut) +{ + READ_START_CHECK + READ_CHECK_BYTES_AVAILABLE(sizeof(int8_t)) + + rOut = *((int8_t*)(mpBuffer + mReadOffset)); + mReadOffset += sizeof(int8_t); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Read(std::string &) +// Purpose: Read a value from the stream (buffered) +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Read(std::string &rOut) +{ + // READ_START_CHECK implied + int32_t size; + Read(size); + + READ_CHECK_BYTES_AVAILABLE(size) + + // initialise string + rOut.assign(mpBuffer + mReadOffset, size); + mReadOffset += size; +} + + + + +#define WRITE_START_CHECK \ + if(mValidDataSize == -1 || mWriteOffset == -1 || mReadOffset != -1) \ + { \ + THROW_EXCEPTION(ServerException, Protocol_BadUsage) \ + } + +#define WRITE_ENSURE_BYTES_AVAILABLE(bytesToWrite) \ + if(mWriteOffset + (int)(bytesToWrite) > mBufferSize) \ + { \ + EnsureBufferAllocated((((mWriteOffset + (int)(bytesToWrite)) + PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK - 1) / PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK) * PROTOCOL_ALLOCATE_SEND_BLOCK_CHUNK); \ + ASSERT(mWriteOffset + (int)(bytesToWrite) <= mBufferSize); \ + } + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Write(const void *, int) +// Purpose: Writes the contents of a buffer to the stream +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Write(const void *Buffer, int Size) +{ + WRITE_START_CHECK + WRITE_ENSURE_BYTES_AVAILABLE(Size) + + ::memmove(mpBuffer + mWriteOffset, Buffer, Size); + mWriteOffset += Size; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Write(int64_t) +// Purpose: Writes a value to the stream +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Write(int64_t Value) +{ + WRITE_START_CHECK + WRITE_ENSURE_BYTES_AVAILABLE(sizeof(int64_t)) + + int64_t nvalue = box_hton64(Value); +#ifdef HAVE_ALIGNED_ONLY_INT64 + memcpy(mpBuffer + mWriteOffset, &nvalue, sizeof(int64_t)); +#else + *((int64_t*)(mpBuffer + mWriteOffset)) = nvalue; +#endif + mWriteOffset += sizeof(int64_t); +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Write(int32_t) +// Purpose: Writes a value to the stream +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Write(int32_t Value) +{ + WRITE_START_CHECK + WRITE_ENSURE_BYTES_AVAILABLE(sizeof(int32_t)) + + int32_t nvalue = htonl(Value); +#ifdef HAVE_ALIGNED_ONLY_INT32 + memcpy(mpBuffer + mWriteOffset, &nvalue, sizeof(int32_t)); +#else + *((int32_t*)(mpBuffer + mWriteOffset)) = nvalue; +#endif + mWriteOffset += sizeof(int32_t); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Write(int16_t) +// Purpose: Writes a value to the stream +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Write(int16_t Value) +{ + WRITE_START_CHECK + WRITE_ENSURE_BYTES_AVAILABLE(sizeof(int16_t)) + + *((int16_t*)(mpBuffer + mWriteOffset)) = htons(Value); + mWriteOffset += sizeof(int16_t); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Write(int8_t) +// Purpose: Writes a value to the stream +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Write(int8_t Value) +{ + WRITE_START_CHECK + WRITE_ENSURE_BYTES_AVAILABLE(sizeof(int8_t)) + + *((int8_t*)(mpBuffer + mWriteOffset)) = Value; + mWriteOffset += sizeof(int8_t); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::Write(const std::string &) +// Purpose: Writes a value to the stream +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void Protocol::Write(const std::string &rValue) +{ + // WRITE_START_CHECK implied + Write((int32_t)(rValue.size())); + + WRITE_ENSURE_BYTES_AVAILABLE(rValue.size()) + Write(rValue.c_str(), rValue.size()); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::ReceieveStream() +// Purpose: Receive a stream from the remote side +// Created: 2003/08/26 +// +// -------------------------------------------------------------------------- +std::auto_ptr<IOStream> Protocol::ReceiveStream() +{ + // Get object header + PW_ObjectHeader objHeader; + CheckAndReadHdr(&objHeader); + + // Hope it's not an object + if(ntohl(objHeader.mObjType) != SPECIAL_STREAM_OBJECT_TYPE) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_ObjWhenStreamExpected) + } + + // Get the stream size + u_int32_t streamSize = ntohl(objHeader.mObjSize); + + // Inform sub class + InformStreamReceiving(streamSize); + + // Return a stream object + if(streamSize == ProtocolStream_SizeUncertain) + { + BOX_TRACE("Receiving stream, size uncertain"); + return std::auto_ptr<IOStream>( + new ProtocolUncertainStream(mrStream)); + } + else + { + BOX_TRACE("Receiving stream, size " << streamSize << " bytes"); + return std::auto_ptr<IOStream>( + new PartialReadStream(mrStream, streamSize)); + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::SendStream(IOStream &) +// Purpose: Send a stream to the remote side +// Created: 2003/08/26 +// +// -------------------------------------------------------------------------- +void Protocol::SendStream(IOStream &rStream) +{ + // Check usage + if(mValidDataSize != -1 || mWriteOffset != -1 || mReadOffset != -1) + { + THROW_EXCEPTION(ServerException, Protocol_BadUsage) + } + + // Handshake done? + if(!mHandshakeDone) + { + Handshake(); + } + + // How should this be streamed? + bool uncertainSize = false; + IOStream::pos_type streamSize = rStream.BytesLeftToRead(); + if(streamSize == IOStream::SizeOfStreamUnknown + || streamSize > 0x7fffffff) + { + // Can't send this using the fixed size header + uncertainSize = true; + } + + // Inform sub class + InformStreamSending(streamSize); + + // Make header + PW_ObjectHeader objHeader; + objHeader.mObjSize = htonl(uncertainSize?(ProtocolStream_SizeUncertain):streamSize); + objHeader.mObjType = htonl(SPECIAL_STREAM_OBJECT_TYPE); + + // Write header + mrStream.Write(&objHeader, sizeof(objHeader)); + // Could be sent in one of two ways + if(uncertainSize) + { + // Don't know how big this is going to be -- so send it in chunks + + // Allocate memory + uint8_t *blockA = (uint8_t *)malloc(UNCERTAIN_STREAM_SIZE_BLOCK + sizeof(int)); + if(blockA == 0) + { + throw std::bad_alloc(); + } + uint8_t *block = blockA + sizeof(int); // so that everything is word aligned for reading, but can put the one byte header before it + + try + { + int bytesInBlock = 0; + while(rStream.StreamDataLeft()) + { + // Read some of it + bytesInBlock += rStream.Read(block + bytesInBlock, UNCERTAIN_STREAM_SIZE_BLOCK - bytesInBlock); + + // Send as much as we can out + bytesInBlock -= SendStreamSendBlock(block, bytesInBlock); + } + + // Everything recieved from stream, but need to send whatevers left in the block + while(bytesInBlock > 0) + { + bytesInBlock -= SendStreamSendBlock(block, bytesInBlock); + } + + // Send final byte to finish the stream + BOX_TRACE("Sending end of stream byte"); + uint8_t endOfStream = ProtocolStreamHeader_EndOfStream; + mrStream.Write(&endOfStream, 1); + BOX_TRACE("Sent end of stream byte"); + } + catch(...) + { + free(blockA); + throw; + } + + // Clean up + free(blockA); + } + else + { + // Fixed size stream, send it all in one go + if(!rStream.CopyStreamTo(mrStream, mTimeout, 4096 /* slightly larger buffer */)) + { + THROW_EXCEPTION(ConnectionException, Conn_Protocol_TimeOutWhenSendingStream) + } + } + // Make sure everything is written + mrStream.WriteAllBuffered(); + +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::SendStreamSendBlock(uint8_t *, int) +// Purpose: Sends as much of the block as can be sent, moves the remainer down to the beginning, +// and returns the number of bytes sent. WARNING: Will write to Block[-1] +// Created: 5/12/03 +// +// -------------------------------------------------------------------------- +int Protocol::SendStreamSendBlock(uint8_t *Block, int BytesInBlock) +{ + // Quick sanity check + if(BytesInBlock == 0) + { + BOX_TRACE("Zero size block, not sending anything"); + return 0; + } + + // Work out the header byte + uint8_t header = 0; + int writeSize = 0; + if(BytesInBlock >= (64*1024)) + { + header = ProtocolStreamHeader_SizeIs64k; + writeSize = (64*1024); + } + else + { + // Scan the table to find the most that can be written + for(int s = ProtocolStreamHeader_MaxEncodedSizeValue; s > 0; --s) + { + if(sProtocolStreamHeaderLengths[s] <= BytesInBlock) + { + header = s; + writeSize = sProtocolStreamHeaderLengths[s]; + break; + } + } + } + ASSERT(header > 0); + BOX_TRACE("Sending header byte " << (int)header << " plus " << + writeSize << " bytes to stream"); + + // Store the header + Block[-1] = header; + + // Write everything out + mrStream.Write(Block - 1, writeSize + 1); + + BOX_TRACE("Sent " << (writeSize+1) << " bytes to stream"); + // move the remainer to the beginning of the block for the next time round + if(writeSize != BytesInBlock) + { + ::memmove(Block, Block + writeSize, BytesInBlock - writeSize); + } + + return writeSize; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::InformStreamReceiving(u_int32_t) +// Purpose: Informs sub classes about streams being received +// Created: 2003/10/27 +// +// -------------------------------------------------------------------------- +void Protocol::InformStreamReceiving(u_int32_t Size) +{ + // Do nothing +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Protocol::InformStreamSending(u_int32_t) +// Purpose: Informs sub classes about streams being sent +// Created: 2003/10/27 +// +// -------------------------------------------------------------------------- +void Protocol::InformStreamSending(u_int32_t Size) +{ + // Do nothing +} + + +/* +perl code to generate the table below + +#!/usr/bin/perl +use strict; +open OUT,">protolengths.txt"; +my $len = 0; +for(0 .. 255) +{ + print OUT "\t$len,\t// $_\n"; + my $inc = 1; + $inc = 8 if $_ >= 64; + $inc = 16 if $_ >= 96; + $inc = 32 if $_ >= 112; + $inc = 64 if $_ >= 128; + $inc = 128 if $_ >= 135; + $inc = 256 if $_ >= 147; + $inc = 512 if $_ >= 159; + $inc = 1024 if $_ >= 231; + $len += $inc; +} +close OUT; + +*/ +const uint16_t Protocol::sProtocolStreamHeaderLengths[256] = +{ + 0, // 0 + 1, // 1 + 2, // 2 + 3, // 3 + 4, // 4 + 5, // 5 + 6, // 6 + 7, // 7 + 8, // 8 + 9, // 9 + 10, // 10 + 11, // 11 + 12, // 12 + 13, // 13 + 14, // 14 + 15, // 15 + 16, // 16 + 17, // 17 + 18, // 18 + 19, // 19 + 20, // 20 + 21, // 21 + 22, // 22 + 23, // 23 + 24, // 24 + 25, // 25 + 26, // 26 + 27, // 27 + 28, // 28 + 29, // 29 + 30, // 30 + 31, // 31 + 32, // 32 + 33, // 33 + 34, // 34 + 35, // 35 + 36, // 36 + 37, // 37 + 38, // 38 + 39, // 39 + 40, // 40 + 41, // 41 + 42, // 42 + 43, // 43 + 44, // 44 + 45, // 45 + 46, // 46 + 47, // 47 + 48, // 48 + 49, // 49 + 50, // 50 + 51, // 51 + 52, // 52 + 53, // 53 + 54, // 54 + 55, // 55 + 56, // 56 + 57, // 57 + 58, // 58 + 59, // 59 + 60, // 60 + 61, // 61 + 62, // 62 + 63, // 63 + 64, // 64 + 72, // 65 + 80, // 66 + 88, // 67 + 96, // 68 + 104, // 69 + 112, // 70 + 120, // 71 + 128, // 72 + 136, // 73 + 144, // 74 + 152, // 75 + 160, // 76 + 168, // 77 + 176, // 78 + 184, // 79 + 192, // 80 + 200, // 81 + 208, // 82 + 216, // 83 + 224, // 84 + 232, // 85 + 240, // 86 + 248, // 87 + 256, // 88 + 264, // 89 + 272, // 90 + 280, // 91 + 288, // 92 + 296, // 93 + 304, // 94 + 312, // 95 + 320, // 96 + 336, // 97 + 352, // 98 + 368, // 99 + 384, // 100 + 400, // 101 + 416, // 102 + 432, // 103 + 448, // 104 + 464, // 105 + 480, // 106 + 496, // 107 + 512, // 108 + 528, // 109 + 544, // 110 + 560, // 111 + 576, // 112 + 608, // 113 + 640, // 114 + 672, // 115 + 704, // 116 + 736, // 117 + 768, // 118 + 800, // 119 + 832, // 120 + 864, // 121 + 896, // 122 + 928, // 123 + 960, // 124 + 992, // 125 + 1024, // 126 + 1056, // 127 + 1088, // 128 + 1152, // 129 + 1216, // 130 + 1280, // 131 + 1344, // 132 + 1408, // 133 + 1472, // 134 + 1536, // 135 + 1664, // 136 + 1792, // 137 + 1920, // 138 + 2048, // 139 + 2176, // 140 + 2304, // 141 + 2432, // 142 + 2560, // 143 + 2688, // 144 + 2816, // 145 + 2944, // 146 + 3072, // 147 + 3328, // 148 + 3584, // 149 + 3840, // 150 + 4096, // 151 + 4352, // 152 + 4608, // 153 + 4864, // 154 + 5120, // 155 + 5376, // 156 + 5632, // 157 + 5888, // 158 + 6144, // 159 + 6656, // 160 + 7168, // 161 + 7680, // 162 + 8192, // 163 + 8704, // 164 + 9216, // 165 + 9728, // 166 + 10240, // 167 + 10752, // 168 + 11264, // 169 + 11776, // 170 + 12288, // 171 + 12800, // 172 + 13312, // 173 + 13824, // 174 + 14336, // 175 + 14848, // 176 + 15360, // 177 + 15872, // 178 + 16384, // 179 + 16896, // 180 + 17408, // 181 + 17920, // 182 + 18432, // 183 + 18944, // 184 + 19456, // 185 + 19968, // 186 + 20480, // 187 + 20992, // 188 + 21504, // 189 + 22016, // 190 + 22528, // 191 + 23040, // 192 + 23552, // 193 + 24064, // 194 + 24576, // 195 + 25088, // 196 + 25600, // 197 + 26112, // 198 + 26624, // 199 + 27136, // 200 + 27648, // 201 + 28160, // 202 + 28672, // 203 + 29184, // 204 + 29696, // 205 + 30208, // 206 + 30720, // 207 + 31232, // 208 + 31744, // 209 + 32256, // 210 + 32768, // 211 + 33280, // 212 + 33792, // 213 + 34304, // 214 + 34816, // 215 + 35328, // 216 + 35840, // 217 + 36352, // 218 + 36864, // 219 + 37376, // 220 + 37888, // 221 + 38400, // 222 + 38912, // 223 + 39424, // 224 + 39936, // 225 + 40448, // 226 + 40960, // 227 + 41472, // 228 + 41984, // 229 + 42496, // 230 + 43008, // 231 + 44032, // 232 + 45056, // 233 + 46080, // 234 + 47104, // 235 + 48128, // 236 + 49152, // 237 + 50176, // 238 + 51200, // 239 + 52224, // 240 + 53248, // 241 + 54272, // 242 + 55296, // 243 + 56320, // 244 + 57344, // 245 + 58368, // 246 + 59392, // 247 + 60416, // 248 + 61440, // 249 + 62464, // 250 + 63488, // 251 + 64512, // 252 + 0, // 253 = 65536 / 64k + 0, // 254 = special (reserved) + 0 // 255 = special (reserved) +}; + + + + diff --git a/lib/server/Protocol.h b/lib/server/Protocol.h new file mode 100644 index 00000000..e037e33c --- /dev/null +++ b/lib/server/Protocol.h @@ -0,0 +1,201 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: Protocol.h +// Purpose: Generic protocol support +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- + +#ifndef PROTOCOL__H +#define PROTOCOL__H + +#include <sys/types.h> + +class IOStream; +#include "ProtocolObject.h" +#include <memory> +#include <vector> +#include <string> + +// default timeout is 15 minutes +#define PROTOCOL_DEFAULT_TIMEOUT (15*60*1000) +// 16 default maximum object size -- should be enough +#define PROTOCOL_DEFAULT_MAXOBJSIZE (16*1024) + +// -------------------------------------------------------------------------- +// +// Class +// Name: Protocol +// Purpose: Generic command / response protocol support +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +class Protocol +{ +public: + Protocol(IOStream &rStream); + virtual ~Protocol(); + +private: + Protocol(const Protocol &rToCopy); + +public: + void Handshake(); + std::auto_ptr<ProtocolObject> Receive(); + void Send(const ProtocolObject &rObject); + + std::auto_ptr<IOStream> ReceiveStream(); + void SendStream(IOStream &rStream); + + enum + { + NoError = -1, + UnknownError = 0 + }; + + bool GetLastError(int &rTypeOut, int &rSubTypeOut); + + // -------------------------------------------------------------------------- + // + // Function + // Name: Protocol::SetTimeout(int) + // Purpose: Sets the timeout for sending and reciving + // Created: 2003/08/19 + // + // -------------------------------------------------------------------------- + void SetTimeout(int NewTimeout) {mTimeout = NewTimeout;} + + + // -------------------------------------------------------------------------- + // + // Function + // Name: Protocol::GetTimeout() + // Purpose: Get current timeout for sending and receiving + // Created: 2003/09/06 + // + // -------------------------------------------------------------------------- + int GetTimeout() {return mTimeout;} + + // -------------------------------------------------------------------------- + // + // Function + // Name: Protocol::SetMaxObjectSize(int) + // Purpose: Sets the maximum size of an object which will be accepted + // Created: 2003/08/19 + // + // -------------------------------------------------------------------------- + void SetMaxObjectSize(unsigned int NewMaxObjSize) {mMaxObjectSize = NewMaxObjSize;} + + // For ProtocolObject derived classes + void Read(void *Buffer, int Size); + void Read(std::string &rOut, int Size); + void Read(int64_t &rOut); + void Read(int32_t &rOut); + void Read(int16_t &rOut); + void Read(int8_t &rOut); + void Read(bool &rOut) {int8_t read; Read(read); rOut = (read == true);} + void Read(std::string &rOut); + template<typename type> + void Read(type &rOut) + { + rOut.ReadFromProtocol(*this); + } + // -------------------------------------------------------------------------- + // + // Function + // Name: Protocol::ReadVector(std::vector<> &) + // Purpose: Reads a vector/list of items from the stream + // Created: 2003/08/19 + // + // -------------------------------------------------------------------------- + template<typename type> + void ReadVector(std::vector<type> &rOut) + { + rOut.clear(); + int16_t num = 0; + Read(num); + for(int16_t n = 0; n < num; ++n) + { + type v; + Read(v); + rOut.push_back(v); + } + } + + void Write(const void *Buffer, int Size); + void Write(int64_t Value); + void Write(int32_t Value); + void Write(int16_t Value); + void Write(int8_t Value); + void Write(bool Value) {int8_t write = Value; Write(write);} + void Write(const std::string &rValue); + template<typename type> + void Write(const type &rValue) + { + rValue.WriteToProtocol(*this); + } + template<typename type> + // -------------------------------------------------------------------------- + // + // Function + // Name: Protocol::WriteVector(const std::vector<> &) + // Purpose: Writes a vector/list of items from the stream + // Created: 2003/08/19 + // + // -------------------------------------------------------------------------- + void WriteVector(const std::vector<type> &rValue) + { + int16_t num = rValue.size(); + Write(num); + for(int16_t n = 0; n < num; ++n) + { + Write(rValue[n]); + } + } + +public: + static const uint16_t sProtocolStreamHeaderLengths[256]; + enum + { + ProtocolStreamHeader_EndOfStream = 0, + ProtocolStreamHeader_MaxEncodedSizeValue = 252, + ProtocolStreamHeader_SizeIs64k = 253, + ProtocolStreamHeader_Reserved1 = 254, + ProtocolStreamHeader_Reserved2 = 255 + }; + enum + { + ProtocolStream_SizeUncertain = 0xffffffff + }; + +protected: + virtual std::auto_ptr<ProtocolObject> MakeProtocolObject(int ObjType) = 0; + virtual const char *GetIdentString() = 0; + void SetError(int Type, int SubType) {mLastErrorType = Type; mLastErrorSubType = SubType;} + void CheckAndReadHdr(void *hdr); // don't use type here to avoid dependency + + // Will be used for logging + virtual void InformStreamReceiving(u_int32_t Size); + virtual void InformStreamSending(u_int32_t Size); + +private: + void EnsureBufferAllocated(int Size); + int SendStreamSendBlock(uint8_t *Block, int BytesInBlock); + +private: + IOStream &mrStream; + bool mHandshakeDone; + unsigned int mMaxObjectSize; + int mTimeout; + char *mpBuffer; + int mBufferSize; + int mReadOffset; + int mWriteOffset; + int mValidDataSize; + int mLastErrorType; + int mLastErrorSubType; +}; + +#endif // PROTOCOL__H + diff --git a/lib/server/ProtocolObject.cpp b/lib/server/ProtocolObject.cpp new file mode 100644 index 00000000..fb09f820 --- /dev/null +++ b/lib/server/ProtocolObject.cpp @@ -0,0 +1,125 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ProtocolObject.h +// Purpose: Protocol object base class +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- + +#include "Box.h" +#include "ProtocolObject.h" +#include "CommonException.h" + +#include "MemLeakFindOn.h" + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::ProtocolObject() +// Purpose: Default constructor +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +ProtocolObject::ProtocolObject() +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::ProtocolObject() +// Purpose: Destructor +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +ProtocolObject::~ProtocolObject() +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::ProtocolObject() +// Purpose: Copy constructor +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +ProtocolObject::ProtocolObject(const ProtocolObject &rToCopy) +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::IsError(int &, int &) +// Purpose: Does this represent an error, and if so, what is the type and subtype? +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +bool ProtocolObject::IsError(int &rTypeOut, int &rSubTypeOut) const +{ + return false; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::IsConversationEnd() +// Purpose: Does this command end the conversation? +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +bool ProtocolObject::IsConversationEnd() const +{ + return false; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::GetType() +// Purpose: Return type of the object +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +int ProtocolObject::GetType() const +{ + // This isn't implemented in the base class! + THROW_EXCEPTION(CommonException, Internal) +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::SetPropertiesFromStreamData(Protocol &) +// Purpose: Set the properties of the object from the stream data ready in the Protocol object +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void ProtocolObject::SetPropertiesFromStreamData(Protocol &rProtocol) +{ + // This isn't implemented in the base class! + THROW_EXCEPTION(CommonException, Internal) +} + + + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolObject::WritePropertiesToStreamData(Protocol &) +// Purpose: Write the properties of the object into the stream data in the Protocol object +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +void ProtocolObject::WritePropertiesToStreamData(Protocol &rProtocol) const +{ + // This isn't implemented in the base class! + THROW_EXCEPTION(CommonException, Internal) +} + + + diff --git a/lib/server/ProtocolObject.h b/lib/server/ProtocolObject.h new file mode 100644 index 00000000..0a127ab5 --- /dev/null +++ b/lib/server/ProtocolObject.h @@ -0,0 +1,41 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ProtocolObject.h +// Purpose: Protocol object base class +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- + +#ifndef PROTOCOLOBJECT__H +#define PROTOCOLOBJECT__H + +class Protocol; + +// -------------------------------------------------------------------------- +// +// Class +// Name: ProtocolObject +// Purpose: Basic object representation of objects to pass through a Protocol session +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- +class ProtocolObject +{ +public: + ProtocolObject(); + virtual ~ProtocolObject(); + ProtocolObject(const ProtocolObject &rToCopy); + + // Info about this object + virtual int GetType() const; + virtual bool IsError(int &rTypeOut, int &rSubTypeOut) const; + virtual bool IsConversationEnd() const; + + // reading and writing with Protocol objects + virtual void SetPropertiesFromStreamData(Protocol &rProtocol); + virtual void WritePropertiesToStreamData(Protocol &rProtocol) const; +}; + +#endif // PROTOCOLOBJECT__H + diff --git a/lib/server/ProtocolUncertainStream.cpp b/lib/server/ProtocolUncertainStream.cpp new file mode 100644 index 00000000..84a213a8 --- /dev/null +++ b/lib/server/ProtocolUncertainStream.cpp @@ -0,0 +1,206 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ProtocolUncertainStream.h +// Purpose: Read part of another stream +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- + +#include "Box.h" +#include "ProtocolUncertainStream.h" +#include "ServerException.h" +#include "Protocol.h" + +#include "MemLeakFindOn.h" + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolUncertainStream::ProtocolUncertainStream(IOStream &, int) +// Purpose: Constructor, taking another stream. +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +ProtocolUncertainStream::ProtocolUncertainStream(IOStream &rSource) + : mrSource(rSource), + mBytesLeftInCurrentBlock(0), + mFinished(false) +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolUncertainStream::~ProtocolUncertainStream() +// Purpose: Destructor. Won't absorb any unread bytes. +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +ProtocolUncertainStream::~ProtocolUncertainStream() +{ + if(!mFinished) + { + BOX_WARNING("ProtocolUncertainStream destroyed before " + "stream finished"); + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolUncertainStream::Read(void *, int, int) +// Purpose: As interface. +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +int ProtocolUncertainStream::Read(void *pBuffer, int NBytes, int Timeout) +{ + // Finished? + if(mFinished) + { + return 0; + } + + int read = 0; + while(read < NBytes) + { + // Anything we can get from the current block? + ASSERT(mBytesLeftInCurrentBlock >= 0); + if(mBytesLeftInCurrentBlock > 0) + { + // Yes, let's use some of these up + int toRead = (NBytes - read); + if(toRead > mBytesLeftInCurrentBlock) + { + // Adjust downwards to only read stuff out of the current block + toRead = mBytesLeftInCurrentBlock; + } + + BOX_TRACE("Reading " << toRead << " bytes from stream"); + + // Read it + int r = mrSource.Read(((uint8_t*)pBuffer) + read, toRead, Timeout); + // Give up now if it didn't return anything + if(r == 0) + { + BOX_TRACE("Read " << r << " bytes from " + "stream, returning"); + return read; + } + + // Adjust counts of bytes by the bytes recieved + read += r; + mBytesLeftInCurrentBlock -= r; + + // stop now if the stream returned less than we asked for -- avoid blocking + if(r != toRead) + { + BOX_TRACE("Read " << r << " bytes from " + "stream, returning"); + return read; + } + } + else + { + // Read the header byte to find out how much there is + // in the next block + uint8_t header; + if(mrSource.Read(&header, 1, Timeout) == 0) + { + // Didn't get the byte, return now + BOX_TRACE("Read 0 bytes of block header, " + "returning with " << read << " bytes " + "read this time"); + return read; + } + + // Interpret the byte... + if(header == Protocol::ProtocolStreamHeader_EndOfStream) + { + // All done. + mFinished = true; + BOX_TRACE("Stream finished, returning with " << + read << " bytes read this time"); + return read; + } + else if(header <= Protocol::ProtocolStreamHeader_MaxEncodedSizeValue) + { + // get size of the block from the Protocol's lovely list + mBytesLeftInCurrentBlock = Protocol::sProtocolStreamHeaderLengths[header]; + } + else if(header == Protocol::ProtocolStreamHeader_SizeIs64k) + { + // 64k + mBytesLeftInCurrentBlock = (64*1024); + } + else + { + // Bad. It used the reserved values. + THROW_EXCEPTION(ServerException, ProtocolUncertainStreamBadBlockHeader) + } + + BOX_TRACE("Read header byte " << (int)header << ", " + "next block has " << + mBytesLeftInCurrentBlock << " bytes"); + } + } + + // Return the number read + return read; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolUncertainStream::BytesLeftToRead() +// Purpose: As interface. +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +IOStream::pos_type ProtocolUncertainStream::BytesLeftToRead() +{ + // Only know how much is left if everything is finished + return mFinished?(0):(IOStream::SizeOfStreamUnknown); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolUncertainStream::Write(const void *, int) +// Purpose: As interface. But will exception. +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +void ProtocolUncertainStream::Write(const void *pBuffer, int NBytes) +{ + THROW_EXCEPTION(ServerException, CantWriteToProtocolUncertainStream) +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolUncertainStream::StreamDataLeft() +// Purpose: As interface. +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +bool ProtocolUncertainStream::StreamDataLeft() +{ + return !mFinished; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: ProtocolUncertainStream::StreamClosed() +// Purpose: As interface. +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +bool ProtocolUncertainStream::StreamClosed() +{ + // always closed + return true; +} + diff --git a/lib/server/ProtocolUncertainStream.h b/lib/server/ProtocolUncertainStream.h new file mode 100644 index 00000000..4954cf88 --- /dev/null +++ b/lib/server/ProtocolUncertainStream.h @@ -0,0 +1,47 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ProtocolUncertainStream.h +// Purpose: Read part of another stream +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- + +#ifndef PROTOCOLUNCERTAINSTREAM__H +#define PROTOCOLUNCERTAINSTREAM__H + +#include "IOStream.h" + +// -------------------------------------------------------------------------- +// +// Class +// Name: ProtocolUncertainStream +// Purpose: Read part of another stream +// Created: 2003/12/05 +// +// -------------------------------------------------------------------------- +class ProtocolUncertainStream : public IOStream +{ +public: + ProtocolUncertainStream(IOStream &rSource); + ~ProtocolUncertainStream(); +private: + // no copying allowed + ProtocolUncertainStream(const IOStream &); + ProtocolUncertainStream(const ProtocolUncertainStream &); + +public: + virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite); + virtual pos_type BytesLeftToRead(); + virtual void Write(const void *pBuffer, int NBytes); + virtual bool StreamDataLeft(); + virtual bool StreamClosed(); + +private: + IOStream &mrSource; + int mBytesLeftInCurrentBlock; + bool mFinished; +}; + +#endif // PROTOCOLUNCERTAINSTREAM__H + diff --git a/lib/server/ProtocolWire.h b/lib/server/ProtocolWire.h new file mode 100644 index 00000000..ff62b66e --- /dev/null +++ b/lib/server/ProtocolWire.h @@ -0,0 +1,43 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ProtocolWire.h +// Purpose: On the wire structures for Protocol +// Created: 2003/08/19 +// +// -------------------------------------------------------------------------- + +#ifndef PROTOCOLWIRE__H +#define PROTOCOLWIRE__H + +#include <sys/types.h> + +// set packing to one byte +#ifdef STRUCTURE_PACKING_FOR_WIRE_USE_HEADERS +#include "BeginStructPackForWire.h" +#else +BEGIN_STRUCTURE_PACKING_FOR_WIRE +#endif + +typedef struct +{ + char mIdent[32]; +} PW_Handshake; + +typedef struct +{ + u_int32_t mObjSize; + u_int32_t mObjType; +} PW_ObjectHeader; + +#define SPECIAL_STREAM_OBJECT_TYPE 0xffffffff + +// Use default packing +#ifdef STRUCTURE_PACKING_FOR_WIRE_USE_HEADERS +#include "EndStructPackForWire.h" +#else +END_STRUCTURE_PACKING_FOR_WIRE +#endif + +#endif // PROTOCOLWIRE__H + diff --git a/lib/server/SSLLib.cpp b/lib/server/SSLLib.cpp new file mode 100644 index 00000000..de7a941b --- /dev/null +++ b/lib/server/SSLLib.cpp @@ -0,0 +1,111 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: SSLLib.cpp +// Purpose: Utility functions for dealing with the OpenSSL library +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#define TLS_CLASS_IMPLEMENTATION_CPP +#include <openssl/ssl.h> +#include <openssl/err.h> +#include <openssl/rand.h> + +#ifdef WIN32 + #include <wincrypt.h> +#endif + +#include "SSLLib.h" +#include "ServerException.h" + +#include "MemLeakFindOn.h" + +#ifndef BOX_RELEASE_BUILD + bool SSLLib__TraceErrors = false; +#endif + +// -------------------------------------------------------------------------- +// +// Function +// Name: SSLLib::Initialise() +// Purpose: Initialise SSL library +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void SSLLib::Initialise() +{ + if(!::SSL_library_init()) + { + LogError("initialising OpenSSL"); + THROW_EXCEPTION(ServerException, SSLLibraryInitialisationError) + } + + // More helpful error messages + ::SSL_load_error_strings(); + + // Extra seeding over and above what's already done by the library +#ifdef WIN32 + HCRYPTPROV provider; + if(!CryptAcquireContext(&provider, NULL, NULL, PROV_RSA_FULL, + CRYPT_VERIFYCONTEXT)) + { + BOX_LOG_WIN_ERROR("Failed to acquire crypto context"); + BOX_WARNING("No random device -- additional seeding of " + "random number generator not performed."); + } + else + { + // must free provider + BYTE buf[1024]; + + if(!CryptGenRandom(provider, sizeof(buf), buf)) + { + BOX_LOG_WIN_ERROR("Failed to get random data"); + BOX_WARNING("No random device -- additional seeding of " + "random number generator not performed."); + } + else + { + RAND_seed(buf, sizeof(buf)); + } + + if(!CryptReleaseContext(provider, 0)) + { + BOX_LOG_WIN_ERROR("Failed to release crypto context"); + } + } +#elif HAVE_RANDOM_DEVICE + if(::RAND_load_file(RANDOM_DEVICE, 1024) != 1024) + { + THROW_EXCEPTION(ServerException, SSLRandomInitFailed) + } +#else + BOX_WARNING("No random device -- additional seeding of " + "random number generator not performed."); +#endif +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: SSLLib::LogError(const char *) +// Purpose: Logs an error +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void SSLLib::LogError(const std::string& rErrorDuringAction) +{ + unsigned long errcode; + char errname[256]; // SSL docs say at least 120 bytes + while((errcode = ERR_get_error()) != 0) + { + ::ERR_error_string_n(errcode, errname, sizeof(errname)); + BOX_ERROR("SSL error while " << rErrorDuringAction << ": " << + errname); + } +} + diff --git a/lib/server/SSLLib.h b/lib/server/SSLLib.h new file mode 100644 index 00000000..ff4aab19 --- /dev/null +++ b/lib/server/SSLLib.h @@ -0,0 +1,36 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: SSLLib.h +// Purpose: Utility functions for dealing with the OpenSSL library +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- + +#ifndef SSLLIB__H +#define SSLLIB__H + +#ifndef BOX_RELEASE_BUILD + extern bool SSLLib__TraceErrors; + #define SET_DEBUG_SSLLIB_TRACE_ERRORS {SSLLib__TraceErrors = true;} +#else + #define SET_DEBUG_SSLLIB_TRACE_ERRORS +#endif + + +// -------------------------------------------------------------------------- +// +// Namespace +// Name: SSLLib +// Purpose: Utility functions for dealing with the OpenSSL library +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +namespace SSLLib +{ + void Initialise(); + void LogError(const std::string& rErrorDuringAction); +}; + +#endif // SSLLIB__H + diff --git a/lib/server/ServerControl.cpp b/lib/server/ServerControl.cpp new file mode 100644 index 00000000..b9650cee --- /dev/null +++ b/lib/server/ServerControl.cpp @@ -0,0 +1,228 @@ +#include "Box.h" + +#include <errno.h> +#include <stdio.h> + +#ifdef HAVE_SYS_TYPES_H + #include <sys/types.h> +#endif + +#ifdef HAVE_SYS_WAIT_H + #include <sys/wait.h> +#endif + +#ifdef HAVE_SIGNAL_H + #include <signal.h> +#endif + +#include "ServerControl.h" +#include "Test.h" + +#ifdef WIN32 + +#include "WinNamedPipeStream.h" +#include "IOStreamGetLine.h" +#include "BoxPortsAndFiles.h" + +static std::string sPipeName; + +void SetNamedPipeName(const std::string& rPipeName) +{ + sPipeName = rPipeName; +} + +bool SendCommands(const std::string& rCmd) +{ + WinNamedPipeStream connection; + + try + { + connection.Connect(sPipeName); + } + catch(...) + { + BOX_ERROR("Failed to connect to daemon control socket"); + return false; + } + + // For receiving data + IOStreamGetLine getLine(connection); + + // Wait for the configuration summary + std::string configSummary; + if(!getLine.GetLine(configSummary)) + { + BOX_ERROR("Failed to receive configuration summary from daemon"); + return false; + } + + // Was the connection rejected by the server? + if(getLine.IsEOF()) + { + BOX_ERROR("Server rejected the connection"); + return false; + } + + // Decode it + int autoBackup, updateStoreInterval, minimumFileAge, maxUploadWait; + if(::sscanf(configSummary.c_str(), "bbackupd: %d %d %d %d", + &autoBackup, &updateStoreInterval, + &minimumFileAge, &maxUploadWait) != 4) + { + BOX_ERROR("Config summary didn't decode"); + return false; + } + + std::string cmds; + bool expectResponse; + + if (rCmd != "") + { + cmds = rCmd; + cmds += "\nquit\n"; + expectResponse = true; + } + else + { + cmds = "quit\n"; + expectResponse = false; + } + + connection.Write(cmds.c_str(), cmds.size()); + + // Read the response + std::string line; + bool statusOk = !expectResponse; + + while (expectResponse && !getLine.IsEOF() && getLine.GetLine(line)) + { + // Is this an OK or error line? + if (line == "ok") + { + statusOk = true; + } + else if (line == "error") + { + BOX_ERROR(rCmd); + break; + } + else + { + BOX_WARNING("Unexpected response to command '" << + rCmd << "': " << line) + } + } + + return statusOk; +} + +bool HUPServer(int pid) +{ + return SendCommands("reload"); +} + +bool KillServerInternal(int pid) +{ + HANDLE hProcess = OpenProcess(PROCESS_TERMINATE, false, pid); + if (hProcess == NULL) + { + BOX_ERROR("Failed to open process " << pid << ": " << + GetErrorMessage(GetLastError())); + return false; + } + + if (!TerminateProcess(hProcess, 1)) + { + BOX_ERROR("Failed to terminate process " << pid << ": " << + GetErrorMessage(GetLastError())); + CloseHandle(hProcess); + return false; + } + + CloseHandle(hProcess); + return true; +} + +#else // !WIN32 + +bool HUPServer(int pid) +{ + if(pid == 0) return false; + return ::kill(pid, SIGHUP) == 0; +} + +bool KillServerInternal(int pid) +{ + if(pid == 0 || pid == -1) return false; + bool killed = (::kill(pid, SIGTERM) == 0); + if (!killed) + { + BOX_LOG_SYS_ERROR("Failed to kill process " << pid); + } + TEST_THAT(killed); + return killed; +} + +#endif // WIN32 + +bool KillServer(int pid, bool WaitForProcess) +{ + if (!KillServerInternal(pid)) + { + return false; + } + + #ifdef HAVE_WAITPID + if (WaitForProcess) + { + int status, result; + + result = waitpid(pid, &status, 0); + if (result != pid) + { + BOX_LOG_SYS_ERROR("waitpid failed"); + } + TEST_THAT(result == pid); + + TEST_THAT(WIFEXITED(status)); + if (WIFEXITED(status)) + { + if (WEXITSTATUS(status) != 0) + { + BOX_WARNING("process exited with code " << + WEXITSTATUS(status)); + } + TEST_THAT(WEXITSTATUS(status) == 0); + } + } + #endif + + for (int i = 0; i < 30; i++) + { + if (i == 0) + { + printf("Waiting for server to die (pid %d): ", pid); + } + + printf("."); + fflush(stdout); + + if (!ServerIsAlive(pid)) break; + ::sleep(1); + if (!ServerIsAlive(pid)) break; + } + + if (!ServerIsAlive(pid)) + { + printf(" done.\n"); + } + else + { + printf(" failed!\n"); + } + + fflush(stdout); + + return !ServerIsAlive(pid); +} + diff --git a/lib/server/ServerControl.h b/lib/server/ServerControl.h new file mode 100644 index 00000000..b2e51864 --- /dev/null +++ b/lib/server/ServerControl.h @@ -0,0 +1,18 @@ +#ifndef SERVER_CONTROL_H +#define SERVER_CONTROL_H + +#include "Test.h" + +bool HUPServer(int pid); +bool KillServer(int pid, bool WaitForProcess = false); + +#ifdef WIN32 + #include "WinNamedPipeStream.h" + #include "IOStreamGetLine.h" + #include "BoxPortsAndFiles.h" + + void SetNamedPipeName(const std::string& rPipeName); + // bool SendCommands(const std::string& rCmd); +#endif // WIN32 + +#endif // SERVER_CONTROL_H diff --git a/lib/server/ServerException.h b/lib/server/ServerException.h new file mode 100644 index 00000000..8851b90a --- /dev/null +++ b/lib/server/ServerException.h @@ -0,0 +1,46 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ServerException.h +// Purpose: Exception +// Created: 2003/07/08 +// +// -------------------------------------------------------------------------- + +#ifndef SERVEREXCEPTION__H +#define SERVEREXCEPTION__H + +// Compatibility header +#include "autogen_ServerException.h" +#include "autogen_ConnectionException.h" + +// Rename old connection exception names to new names without Conn_ prefix +// This is all because ConnectionException used to be derived from ServerException +// with some funky magic with subtypes. Perhaps a little unreliable, and the +// usefulness of it never really was used. +#define Conn_SocketWriteError SocketWriteError +#define Conn_SocketReadError SocketReadError +#define Conn_SocketNameLookupError SocketNameLookupError +#define Conn_SocketShutdownError SocketShutdownError +#define Conn_SocketConnectError SocketConnectError +#define Conn_TLSHandshakeFailed TLSHandshakeFailed +#define Conn_TLSShutdownFailed TLSShutdownFailed +#define Conn_TLSWriteFailed TLSWriteFailed +#define Conn_TLSReadFailed TLSReadFailed +#define Conn_TLSNoPeerCertificate TLSNoPeerCertificate +#define Conn_TLSPeerCertificateInvalid TLSPeerCertificateInvalid +#define Conn_TLSClosedWhenWriting TLSClosedWhenWriting +#define Conn_TLSHandshakeTimedOut TLSHandshakeTimedOut +#define Conn_Protocol_Timeout Protocol_Timeout +#define Conn_Protocol_ObjTooBig Protocol_ObjTooBig +#define Conn_Protocol_BadCommandRecieved Protocol_BadCommandRecieved +#define Conn_Protocol_UnknownCommandRecieved Protocol_UnknownCommandRecieved +#define Conn_Protocol_TriedToExecuteReplyCommand Protocol_TriedToExecuteReplyCommand +#define Conn_Protocol_UnexpectedReply Protocol_UnexpectedReply +#define Conn_Protocol_HandshakeFailed Protocol_HandshakeFailed +#define Conn_Protocol_StreamWhenObjExpected Protocol_StreamWhenObjExpected +#define Conn_Protocol_ObjWhenStreamExpected Protocol_ObjWhenStreamExpected +#define Conn_Protocol_TimeOutWhenSendingStream Protocol_TimeOutWhenSendingStream + +#endif // SERVEREXCEPTION__H + diff --git a/lib/server/ServerException.txt b/lib/server/ServerException.txt new file mode 100644 index 00000000..ed591b73 --- /dev/null +++ b/lib/server/ServerException.txt @@ -0,0 +1,39 @@ +EXCEPTION Server 3 + +# for historic reasons, some codes are not used + +Internal 0 +FailedToLoadConfiguration 1 +DaemoniseFailed 2 +AlreadyDaemonConstructed 3 +BadSocketHandle 4 +DupError 5 +SocketAlreadyOpen 8 +SocketOpenError 10 +SocketPollError 11 +SocketCloseError 13 +SocketNameUNIXPathTooLong 14 +SocketBindError 16 Check the ListenAddresses directive in your config file -- must refer to local IP addresses only +SocketAcceptError 17 +ServerStreamBadListenAddrs 18 +ServerForkError 19 +ServerWaitOnChildError 20 +TooManySocketsInMultiListen 21 There is a limit on how many addresses you can listen on simulatiously. +ServerStreamTooManyListenAddresses 22 +TLSContextNotInitialised 23 +TLSAllocationFailed 24 +TLSLoadCertificatesFailed 25 +TLSLoadPrivateKeyFailed 26 +TLSLoadTrustedCAsFailed 27 +TLSSetCiphersFailed 28 +SSLLibraryInitialisationError 29 +TLSNoSSLObject 31 +TLSAlreadyHandshaked 35 +SocketSetNonBlockingFailed 40 +Protocol_BadUsage 43 +Protocol_UnsuitableStreamTypeForSending 51 +CantWriteToProtocolUncertainStream 53 +ProtocolUncertainStreamBadBlockHeader 54 +SocketPairFailed 55 +CouldNotChangePIDFileOwner 56 +SSLRandomInitFailed 57 Read from /dev/*random device failed diff --git a/lib/server/ServerStream.h b/lib/server/ServerStream.h new file mode 100644 index 00000000..e49dbcbe --- /dev/null +++ b/lib/server/ServerStream.h @@ -0,0 +1,418 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ServerStream.h +// Purpose: Stream based server daemons +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- + +#ifndef SERVERSTREAM__H +#define SERVERSTREAM__H + +#include <stdlib.h> +#include <errno.h> + +#ifndef WIN32 + #include <sys/wait.h> +#endif + +#include "Daemon.h" +#include "SocketListen.h" +#include "Utils.h" +#include "Configuration.h" +#include "WaitForEvent.h" + +#include "MemLeakFindOn.h" + +// -------------------------------------------------------------------------- +// +// Class +// Name: ServerStream +// Purpose: Stream based server daemon +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +template<typename StreamType, int Port, int ListenBacklog = 128, bool ForkToHandleRequests = true> +class ServerStream : public Daemon +{ +public: + ServerStream() + { + } + ~ServerStream() + { + DeleteSockets(); + } +private: + ServerStream(const ServerStream &rToCopy) + { + } +public: + + virtual const char *DaemonName() const + { + return "generic-stream-server"; + } + + virtual void OnIdle() { } + + virtual void Run() + { + // Set process title as appropriate + SetProcessTitle(ForkToHandleRequests?"server":"idle"); + + // Handle exceptions and child task quitting gracefully. + bool childExit = false; + try + { + Run2(childExit); + } + catch(BoxException &e) + { + if(childExit) + { + BOX_ERROR("Error in child process, " + "terminating connection: exception " << + e.what() << "(" << e.GetType() << + "/" << e.GetSubType() << ")"); + _exit(1); + } + else throw; + } + catch(std::exception &e) + { + if(childExit) + { + BOX_ERROR("Error in child process, " + "terminating connection: exception " << + e.what()); + _exit(1); + } + else throw; + } + catch(...) + { + if(childExit) + { + BOX_ERROR("Error in child process, " + "terminating connection: " + "unknown exception"); + _exit(1); + } + else throw; + } + + // if it's a child fork, exit the process now + if(childExit) + { + // Child task, dump leaks to trace, which we make sure is on + #ifdef BOX_MEMORY_LEAK_TESTING + #ifndef BOX_RELEASE_BUILD + TRACE_TO_SYSLOG(true); + TRACE_TO_STDOUT(true); + #endif + memleakfinder_traceblocksinsection(); + #endif + + // If this is a child quitting, exit now to stop bad things happening + _exit(0); + } + } + +protected: + virtual void NotifyListenerIsReady() { } + +public: + virtual void Run2(bool &rChildExit) + { + try + { + // Wait object with a timeout of 1 second, which + // is a reasonable time to wait before cleaning up + // finished child processes, and allows the daemon + // to terminate reasonably quickly on request. + WaitForEvent connectionWait(1000); + + // BLOCK + { + // Get the address we need to bind to + // this-> in next line required to build under some gcc versions + const Configuration &config(this->GetConfiguration()); + const Configuration &server(config.GetSubConfiguration("Server")); + std::string addrs = server.GetKeyValue("ListenAddresses"); + + // split up the list of addresses + std::vector<std::string> addrlist; + SplitString(addrs, ',', addrlist); + + for(unsigned int a = 0; a < addrlist.size(); ++a) + { + // split the address up into components + std::vector<std::string> c; + SplitString(addrlist[a], ':', c); + + // listen! + SocketListen<StreamType, ListenBacklog> *psocket = new SocketListen<StreamType, ListenBacklog>; + try + { + if(c[0] == "inet") + { + // Check arguments + if(c.size() != 2 && c.size() != 3) + { + THROW_EXCEPTION(ServerException, ServerStreamBadListenAddrs) + } + + // Which port? + int port = Port; + + if(c.size() == 3) + { + // Convert to number + port = ::atol(c[2].c_str()); + if(port <= 0 || port > ((64*1024)-1)) + { + THROW_EXCEPTION(ServerException, ServerStreamBadListenAddrs) + } + } + + // Listen + psocket->Listen(Socket::TypeINET, c[1].c_str(), port); + } + else if(c[0] == "unix") + { + #ifdef WIN32 + BOX_WARNING("Ignoring request to listen on a Unix socket on Windows: " << addrlist[a]); + delete psocket; + psocket = NULL; + #else + // Check arguments size + if(c.size() != 2) + { + THROW_EXCEPTION(ServerException, ServerStreamBadListenAddrs) + } + + // unlink anything there + ::unlink(c[1].c_str()); + + psocket->Listen(Socket::TypeUNIX, c[1].c_str()); + #endif // WIN32 + } + else + { + delete psocket; + THROW_EXCEPTION(ServerException, ServerStreamBadListenAddrs) + } + + if (psocket != NULL) + { + // Add to list of sockets + mSockets.push_back(psocket); + } + } + catch(...) + { + delete psocket; + throw; + } + + if (psocket != NULL) + { + // Add to the list of things to wait on + connectionWait.Add(psocket); + } + } + } + + NotifyListenerIsReady(); + + while(!StopRun()) + { + // Wait for a connection, or timeout + SocketListen<StreamType, ListenBacklog> *psocket + = (SocketListen<StreamType, ListenBacklog> *)connectionWait.Wait(); + + if(psocket) + { + // Get the incoming connection + // (with zero wait time) + std::string logMessage; + std::auto_ptr<StreamType> connection(psocket->Accept(0, &logMessage)); + + // Was there one (there should be...) + if(connection.get()) + { + // Since this is a template parameter, the if() will be optimised out by the compiler + #ifndef WIN32 // no fork on Win32 + if(ForkToHandleRequests && !IsSingleProcess()) + { + pid_t pid = ::fork(); + switch(pid) + { + case -1: + // Error! + THROW_EXCEPTION(ServerException, ServerForkError) + break; + + case 0: + // Child process + rChildExit = true; + // Close listening sockets + DeleteSockets(); + + // Set up daemon + EnterChild(); + SetProcessTitle("transaction"); + + // Memory leak test the forked process + #ifdef BOX_MEMORY_LEAK_TESTING + memleakfinder_startsectionmonitor(); + #endif + + // The derived class does some server magic with the connection + HandleConnection(*connection); + // Since rChildExit == true, the forked process will call _exit() on return from this fn + return; + + default: + // parent daemon process + break; + } + + // Log it + BOX_NOTICE("Message from child process " << pid << ": " << logMessage); + } + else + { + #endif // !WIN32 + // Just handle in this process + SetProcessTitle("handling"); + HandleConnection(*connection); + SetProcessTitle("idle"); + #ifndef WIN32 + } + #endif // !WIN32 + } + } + + OnIdle(); + + #ifndef WIN32 + // Clean up child processes (if forking daemon) + if(ForkToHandleRequests && !IsSingleProcess()) + { + WaitForChildren(); + } + #endif // !WIN32 + } + } + catch(...) + { + DeleteSockets(); + throw; + } + + // Delete the sockets + DeleteSockets(); + } + + #ifndef WIN32 // no waitpid() on Windows + void WaitForChildren() + { + int p = 0; + do + { + int status = 0; + p = ::waitpid(0 /* any child in process group */, + &status, WNOHANG); + + if(p == -1 && errno != ECHILD && errno != EINTR) + { + THROW_EXCEPTION(ServerException, + ServerWaitOnChildError) + } + else if(p == 0) + { + // no children exited, will return from + // function + } + else if(WIFEXITED(status)) + { + BOX_INFO("child process " << p << " " + "terminated normally"); + } + else if(WIFSIGNALED(status)) + { + int sig = WTERMSIG(status); + BOX_ERROR("child process " << p << " " + "terminated abnormally with " + "signal " << sig); + } + else + { + BOX_WARNING("something unknown happened " + "to child process " << p << ": " + "status = " << status); + } + } + while(p > 0); + } + #endif + + virtual void HandleConnection(StreamType &rStream) + { + Connection(rStream); + } + + virtual void Connection(StreamType &rStream) = 0; + +protected: + // For checking code in derived classes -- use if you have an algorithm which + // depends on the forking model in case someone changes it later. + bool WillForkToHandleRequests() + { + #ifdef WIN32 + return false; + #else + return ForkToHandleRequests && !IsSingleProcess(); + #endif // WIN32 + } + +private: + // -------------------------------------------------------------------------- + // + // Function + // Name: ServerStream::DeleteSockets() + // Purpose: Delete sockets + // Created: 9/3/04 + // + // -------------------------------------------------------------------------- + void DeleteSockets() + { + for(unsigned int l = 0; l < mSockets.size(); ++l) + { + if(mSockets[l]) + { + mSockets[l]->Close(); + delete mSockets[l]; + } + mSockets[l] = 0; + } + mSockets.clear(); + } + +private: + std::vector<SocketListen<StreamType, ListenBacklog> *> mSockets; +}; + +#define SERVERSTREAM_VERIFY_SERVER_KEYS(DEFAULT_ADDRESSES) \ + ConfigurationVerifyKey("ListenAddresses", 0, DEFAULT_ADDRESSES), \ + DAEMON_VERIFY_SERVER_KEYS + +#include "MemLeakFindOff.h" + +#endif // SERVERSTREAM__H + + + diff --git a/lib/server/ServerTLS.h b/lib/server/ServerTLS.h new file mode 100644 index 00000000..a74a671e --- /dev/null +++ b/lib/server/ServerTLS.h @@ -0,0 +1,80 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: ServerTLS.h +// Purpose: Implementation of a server using TLS streams +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- + +#ifndef SERVERTLS__H +#define SERVERTLS__H + +#include "ServerStream.h" +#include "SocketStreamTLS.h" +#include "SSLLib.h" +#include "TLSContext.h" + +// -------------------------------------------------------------------------- +// +// Class +// Name: ServerTLS +// Purpose: Implementation of a server using TLS streams +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +template<int Port, int ListenBacklog = 128, bool ForkToHandleRequests = true> +class ServerTLS : public ServerStream<SocketStreamTLS, Port, ListenBacklog, ForkToHandleRequests> +{ +public: + ServerTLS() + { + // Safe to call this here, as the Daemon class makes sure there is only one instance every of a Daemon. + SSLLib::Initialise(); + } + + ~ServerTLS() + { + } +private: + ServerTLS(const ServerTLS &) + { + } +public: + + virtual void Run2(bool &rChildExit) + { + // First, set up the SSL context. + // Get parameters from the configuration + // this-> in next line required to build under some gcc versions + const Configuration &conf(this->GetConfiguration()); + const Configuration &serverconf(conf.GetSubConfiguration("Server")); + std::string certFile(serverconf.GetKeyValue("CertificateFile")); + std::string keyFile(serverconf.GetKeyValue("PrivateKeyFile")); + std::string caFile(serverconf.GetKeyValue("TrustedCAsFile")); + mContext.Initialise(true /* as server */, certFile.c_str(), keyFile.c_str(), caFile.c_str()); + + // Then do normal stream server stuff + ServerStream<SocketStreamTLS, Port, ListenBacklog, + ForkToHandleRequests>::Run2(rChildExit); + } + + virtual void HandleConnection(SocketStreamTLS &rStream) + { + rStream.Handshake(mContext, true /* is server */); + // this-> in next line required to build under some gcc versions + this->Connection(rStream); + } + +private: + TLSContext mContext; +}; + +#define SERVERTLS_VERIFY_SERVER_KEYS(DEFAULT_ADDRESSES) \ + ConfigurationVerifyKey("CertificateFile", ConfigTest_Exists), \ + ConfigurationVerifyKey("PrivateKeyFile", ConfigTest_Exists), \ + ConfigurationVerifyKey("TrustedCAsFile", ConfigTest_Exists), \ + SERVERSTREAM_VERIFY_SERVER_KEYS(DEFAULT_ADDRESSES) + +#endif // SERVERTLS__H + diff --git a/lib/server/Socket.cpp b/lib/server/Socket.cpp new file mode 100644 index 00000000..4a83bdb0 --- /dev/null +++ b/lib/server/Socket.cpp @@ -0,0 +1,184 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: Socket.cpp +// Purpose: Socket related stuff +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#ifdef HAVE_UNISTD_H + #include <unistd.h> +#endif + +#include <sys/types.h> +#ifndef WIN32 +#include <sys/socket.h> +#include <netdb.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#endif + +#include <string.h> +#include <stdio.h> + +#include "Socket.h" +#include "ServerException.h" + +#include "MemLeakFindOn.h" + +// -------------------------------------------------------------------------- +// +// Function +// Name: Socket::NameLookupToSockAddr(SocketAllAddr &, int, +// char *, int) +// Purpose: Sets up a sockaddr structure given a name and type +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void Socket::NameLookupToSockAddr(SocketAllAddr &addr, int &sockDomain, + enum Type Type, const std::string& rName, int Port, + int &rSockAddrLenOut) +{ + int sockAddrLen = 0; + + switch(Type) + { + case TypeINET: + sockDomain = AF_INET; + { + // Lookup hostname + struct hostent *phost = ::gethostbyname(rName.c_str()); + if(phost != NULL) + { + if(phost->h_addr_list[0] != 0) + { + sockAddrLen = sizeof(addr.sa_inet); +#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN + addr.sa_inet.sin_len = sizeof(addr.sa_inet); +#endif + addr.sa_inet.sin_family = PF_INET; + addr.sa_inet.sin_port = htons(Port); + addr.sa_inet.sin_addr = *((in_addr*)phost->h_addr_list[0]); + for(unsigned int l = 0; l < sizeof(addr.sa_inet.sin_zero); ++l) + { + addr.sa_inet.sin_zero[l] = 0; + } + } + else + { + THROW_EXCEPTION(ConnectionException, Conn_SocketNameLookupError); + } + } + else + { + THROW_EXCEPTION(ConnectionException, Conn_SocketNameLookupError); + } + } + break; + +#ifndef WIN32 + case TypeUNIX: + sockDomain = AF_UNIX; + { + // Check length of name is OK + unsigned int nameLen = rName.length(); + if(nameLen >= (sizeof(addr.sa_unix.sun_path) - 1)) + { + THROW_EXCEPTION(ServerException, SocketNameUNIXPathTooLong); + } + sockAddrLen = nameLen + (((char*)(&(addr.sa_unix.sun_path[0]))) - ((char*)(&addr.sa_unix))); +#ifdef HAVE_STRUCT_SOCKADDR_IN_SIN_LEN + addr.sa_unix.sun_len = sockAddrLen; +#endif + addr.sa_unix.sun_family = PF_UNIX; + ::strncpy(addr.sa_unix.sun_path, rName.c_str(), + sizeof(addr.sa_unix.sun_path) - 1); + addr.sa_unix.sun_path[sizeof(addr.sa_unix.sun_path)-1] = 0; + } + break; +#endif + + default: + THROW_EXCEPTION(CommonException, BadArguments) + break; + } + + // Return size of structure to caller + rSockAddrLenOut = sockAddrLen; +} + + + + +// -------------------------------------------------------------------------- +// +// Function +// Name: Socket::LogIncomingConnection(const struct sockaddr *, socklen_t) +// Purpose: Writes a message logging the connection to syslog +// Created: 2003/08/01 +// +// -------------------------------------------------------------------------- +void Socket::LogIncomingConnection(const struct sockaddr *addr, socklen_t addrlen) +{ + if(addr == NULL) {THROW_EXCEPTION(CommonException, BadArguments)} + + switch(addr->sa_family) + { + case AF_UNIX: + BOX_INFO("Incoming connection from local (UNIX socket)"); + break; + + case AF_INET: + { + sockaddr_in *a = (sockaddr_in*)addr; + BOX_INFO("Incoming connection from " << + inet_ntoa(a->sin_addr) << " port " << + ntohs(a->sin_port)); + } + break; + + default: + BOX_WARNING("Incoming connection of unknown type"); + break; + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: Socket::IncomingConnectionLogMessage(const struct sockaddr *, socklen_t) +// Purpose: Returns a string for use in log messages +// Created: 2003/08/01 +// +// -------------------------------------------------------------------------- +std::string Socket::IncomingConnectionLogMessage(const struct sockaddr *addr, socklen_t addrlen) +{ + if(addr == NULL) {THROW_EXCEPTION(CommonException, BadArguments)} + + switch(addr->sa_family) + { + case AF_UNIX: + return std::string("Incoming connection from local (UNIX socket)"); + break; + + case AF_INET: + { + char msg[256]; // more than enough + sockaddr_in *a = (sockaddr_in*)addr; + sprintf(msg, "Incoming connection from %s port %d", inet_ntoa(a->sin_addr), ntohs(a->sin_port)); + return std::string(msg); + } + break; + + default: + return std::string("Incoming connection of unknown type"); + break; + } + + // Dummy. + return std::string(); +} + diff --git a/lib/server/Socket.h b/lib/server/Socket.h new file mode 100644 index 00000000..5034dbd8 --- /dev/null +++ b/lib/server/Socket.h @@ -0,0 +1,56 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: Socket.h +// Purpose: Socket related stuff +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- + +#ifndef SOCKET__H +#define SOCKET__H + +#ifdef WIN32 +#include "emu.h" +typedef int socklen_t; +#else +#include <sys/socket.h> +#include <netinet/in.h> +#include <sys/un.h> +#endif + +#include <string> + +typedef union { + struct sockaddr sa_generic; + struct sockaddr_in sa_inet; +#ifndef WIN32 + struct sockaddr_un sa_unix; +#endif +} SocketAllAddr; + +// -------------------------------------------------------------------------- +// +// Namespace +// Name: Socket +// Purpose: Socket utilities +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +namespace Socket +{ + enum Type + { + TypeINET = 1, + TypeUNIX = 2 + }; + + void NameLookupToSockAddr(SocketAllAddr &addr, int &sockDomain, + enum Type type, const std::string& rName, int Port, + int &rSockAddrLenOut); + void LogIncomingConnection(const struct sockaddr *addr, socklen_t addrlen); + std::string IncomingConnectionLogMessage(const struct sockaddr *addr, socklen_t addrlen); +}; + +#endif // SOCKET__H + diff --git a/lib/server/SocketListen.h b/lib/server/SocketListen.h new file mode 100644 index 00000000..586adf22 --- /dev/null +++ b/lib/server/SocketListen.h @@ -0,0 +1,312 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: SocketListen.h +// Purpose: Stream based sockets for servers +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- + +#ifndef SOCKETLISTEN__H +#define SOCKETLISTEN__H + +#include <errno.h> + +#ifdef HAVE_UNISTD_H + #include <unistd.h> +#endif + +#ifdef HAVE_KQUEUE + #include <sys/event.h> + #include <sys/time.h> +#endif + +#ifndef WIN32 + #include <poll.h> +#endif + +#include <new> +#include <memory> +#include <string> + +#include "Socket.h" +#include "ServerException.h" + +#include "MemLeakFindOn.h" + +// -------------------------------------------------------------------------- +// +// Class +// Name: _NoSocketLocking +// Purpose: Default locking class for SocketListen +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +class _NoSocketLocking +{ +public: + _NoSocketLocking(int sock) + { + } + + ~_NoSocketLocking() + { + } + + bool HaveLock() + { + return true; + } + +private: + _NoSocketLocking(const _NoSocketLocking &rToCopy) + { + } +}; + + +// -------------------------------------------------------------------------- +// +// Class +// Name: SocketListen +// Purpose: +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +template<typename SocketType, int ListenBacklog = 128, typename SocketLockingType = _NoSocketLocking, int MaxMultiListenSockets = 16> +class SocketListen +{ +public: + // Initialise + SocketListen() + : mSocketHandle(-1) + { + } + // Close socket nicely + ~SocketListen() + { + Close(); + } +private: + SocketListen(const SocketListen &rToCopy) + { + } +public: + + enum + { + MaxMultipleListenSockets = MaxMultiListenSockets + }; + + void Close() + { + if(mSocketHandle != -1) + { +#ifdef WIN32 + if(::closesocket(mSocketHandle) == -1) +#else + if(::close(mSocketHandle) == -1) +#endif + { + BOX_LOG_SYS_ERROR("Failed to close network " + "socket"); + THROW_EXCEPTION(ServerException, + SocketCloseError) + } + } + mSocketHandle = -1; + } + + // ------------------------------------------------------------------ + // + // Function + // Name: SocketListen::Listen(int, char*, int) + // Purpose: Initialises, starts the socket listening. + // Created: 2003/07/31 + // + // ------------------------------------------------------------------ + void Listen(Socket::Type Type, const char *Name, int Port = 0) + { + if(mSocketHandle != -1) + { + THROW_EXCEPTION(ServerException, SocketAlreadyOpen); + } + + // Setup parameters based on type, looking up names if required + int sockDomain = 0; + SocketAllAddr addr; + int addrLen = 0; + Socket::NameLookupToSockAddr(addr, sockDomain, Type, Name, + Port, addrLen); + + // Create the socket + mSocketHandle = ::socket(sockDomain, SOCK_STREAM, + 0 /* let OS choose protocol */); + if(mSocketHandle == -1) + { + BOX_LOG_SYS_ERROR("Failed to create a network socket"); + THROW_EXCEPTION(ServerException, SocketOpenError) + } + + // Set an option to allow reuse (useful for -HUP situations!) +#ifdef WIN32 + if(::setsockopt(mSocketHandle, SOL_SOCKET, SO_REUSEADDR, "", + 0) == -1) +#else + int option = true; + if(::setsockopt(mSocketHandle, SOL_SOCKET, SO_REUSEADDR, + &option, sizeof(option)) == -1) +#endif + { + BOX_LOG_SYS_ERROR("Failed to set socket options"); + THROW_EXCEPTION(ServerException, SocketOpenError) + } + + // Bind it to the right port, and start listening + if(::bind(mSocketHandle, &addr.sa_generic, addrLen) == -1 + || ::listen(mSocketHandle, ListenBacklog) == -1) + { + // Dispose of the socket + ::close(mSocketHandle); + mSocketHandle = -1; + THROW_EXCEPTION(ServerException, SocketBindError) + } + } + + // ------------------------------------------------------------------ + // + // Function + // Name: SocketListen::Accept(int) + // Purpose: Accepts a connection, returning a pointer to + // a class of the specified type. May return a + // null pointer if a signal happens, or there's + // a timeout. Timeout specified in + // milliseconds, defaults to infinite time. + // Created: 2003/07/31 + // + // ------------------------------------------------------------------ + std::auto_ptr<SocketType> Accept(int Timeout = INFTIM, + std::string *pLogMsg = 0) + { + if(mSocketHandle == -1) + { + THROW_EXCEPTION(ServerException, BadSocketHandle); + } + + // Do the accept, using the supplied locking type + int sock; + struct sockaddr addr; + socklen_t addrlen = sizeof(addr); + // BLOCK + { + SocketLockingType socklock(mSocketHandle); + + if(!socklock.HaveLock()) + { + // Didn't get the lock for some reason. + // Wait a while, then return nothing. + BOX_ERROR("Failed to get a lock on incoming " + "connection"); + ::sleep(1); + return std::auto_ptr<SocketType>(); + } + + // poll this socket + struct pollfd p; + p.fd = mSocketHandle; + p.events = POLLIN; + p.revents = 0; + switch(::poll(&p, 1, Timeout)) + { + case -1: + // signal? + if(errno == EINTR) + { + BOX_ERROR("Failed to accept " + "connection: interrupted by " + "signal"); + // return nothing + return std::auto_ptr<SocketType>(); + } + else + { + BOX_LOG_SYS_ERROR("Failed to poll " + "connection"); + THROW_EXCEPTION(ServerException, + SocketPollError) + } + break; + case 0: // timed out + return std::auto_ptr<SocketType>(); + break; + default: // got some thing... + // control flows on... + break; + } + + sock = ::accept(mSocketHandle, &addr, &addrlen); + } + + // Got socket (or error), unlock (implicit in destruction) + if(sock == -1) + { + BOX_LOG_SYS_ERROR("Failed to accept connection"); + THROW_EXCEPTION(ServerException, SocketAcceptError) + } + + // Log it + if(pLogMsg) + { + *pLogMsg = Socket::IncomingConnectionLogMessage(&addr, + addrlen); + } + else + { + // Do logging ourselves + Socket::LogIncomingConnection(&addr, addrlen); + } + + return std::auto_ptr<SocketType>(new SocketType(sock)); + } + + // Functions to allow adding to WaitForEvent class, for efficient waiting + // on multiple sockets. +#ifdef HAVE_KQUEUE + // ------------------------------------------------------------------ + // + // Function + // Name: SocketListen::FillInKEevent + // Purpose: Fills in a kevent structure for this socket + // Created: 9/3/04 + // + // ------------------------------------------------------------------ + void FillInKEvent(struct kevent &rEvent, int Flags = 0) const + { + EV_SET(&rEvent, mSocketHandle, EVFILT_READ, 0, 0, 0, + (void*)this); + } +#else + // ------------------------------------------------------------------ + // + // Function + // Name: SocketListen::FillInPoll + // Purpose: Fills in the data necessary for a poll + // operation + // Created: 9/3/04 + // + // ------------------------------------------------------------------ + void FillInPoll(int &fd, short &events, int Flags = 0) const + { + fd = mSocketHandle; + events = POLLIN; + } +#endif + +private: + int mSocketHandle; +}; + +#include "MemLeakFindOff.h" + +#endif // SOCKETLISTEN__H + diff --git a/lib/server/SocketStream.cpp b/lib/server/SocketStream.cpp new file mode 100644 index 00000000..95b4b4f4 --- /dev/null +++ b/lib/server/SocketStream.cpp @@ -0,0 +1,514 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: SocketStream.cpp +// Purpose: I/O stream interface for sockets +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#ifdef HAVE_UNISTD_H + #include <unistd.h> +#endif + +#include <sys/types.h> +#include <errno.h> +#include <string.h> + +#ifndef WIN32 + #include <poll.h> +#endif + +#ifdef HAVE_UCRED_H + #include <ucred.h> +#endif + +#include "SocketStream.h" +#include "ServerException.h" +#include "CommonException.h" +#include "Socket.h" + +#include "MemLeakFindOn.h" + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::SocketStream() +// Purpose: Constructor (create stream ready for Open() call) +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +SocketStream::SocketStream() + : mSocketHandle(INVALID_SOCKET_VALUE), + mReadClosed(false), + mWriteClosed(false), + mBytesRead(0), + mBytesWritten(0) +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::SocketStream(int) +// Purpose: Create stream from existing socket handle +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +SocketStream::SocketStream(int socket) + : mSocketHandle(socket), + mReadClosed(false), + mWriteClosed(false), + mBytesRead(0), + mBytesWritten(0) +{ + if(socket < 0) + { + THROW_EXCEPTION(ServerException, BadSocketHandle); + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::SocketStream(const SocketStream &) +// Purpose: Copy constructor (dup()s socket) +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +SocketStream::SocketStream(const SocketStream &rToCopy) + : mSocketHandle(::dup(rToCopy.mSocketHandle)), + mReadClosed(rToCopy.mReadClosed), + mWriteClosed(rToCopy.mWriteClosed), + mBytesRead(rToCopy.mBytesRead), + mBytesWritten(rToCopy.mBytesWritten) + +{ + if(rToCopy.mSocketHandle < 0) + { + THROW_EXCEPTION(ServerException, BadSocketHandle); + } + if(mSocketHandle == INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, DupError); + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::~SocketStream() +// Purpose: Destructor, closes stream if open +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +SocketStream::~SocketStream() +{ + if(mSocketHandle != INVALID_SOCKET_VALUE) + { + Close(); + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::Attach(int) +// Purpose: Attach a socket handle to this stream +// Created: 11/12/03 +// +// -------------------------------------------------------------------------- +void SocketStream::Attach(int socket) +{ + if(mSocketHandle != INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, SocketAlreadyOpen) + } + + ResetCounters(); + + mSocketHandle = socket; + mReadClosed = false; + mWriteClosed = false; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::Open(Socket::Type, char *, int) +// Purpose: Opens a connection to a listening socket (INET or UNIX) +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void SocketStream::Open(Socket::Type Type, const std::string& rName, int Port) +{ + if(mSocketHandle != INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, SocketAlreadyOpen) + } + + // Setup parameters based on type, looking up names if required + int sockDomain = 0; + SocketAllAddr addr; + int addrLen = 0; + Socket::NameLookupToSockAddr(addr, sockDomain, Type, rName, Port, addrLen); + + // Create the socket + mSocketHandle = ::socket(sockDomain, SOCK_STREAM, + 0 /* let OS choose protocol */); + if(mSocketHandle == INVALID_SOCKET_VALUE) + { + BOX_LOG_SYS_ERROR("Failed to create a network socket"); + THROW_EXCEPTION(ServerException, SocketOpenError) + } + + // Connect it + if(::connect(mSocketHandle, &addr.sa_generic, addrLen) == -1) + { + // Dispose of the socket +#ifdef WIN32 + DWORD err = WSAGetLastError(); + ::closesocket(mSocketHandle); + BOX_LOG_WIN_ERROR_NUMBER("Failed to connect to socket " + "(type " << Type << ", name " << rName << + ", port " << Port << ")", err); +#else // !WIN32 + BOX_LOG_SYS_ERROR("Failed to connect to socket (type " << + Type << ", name " << rName << ", port " << Port << + ")"); + ::close(mSocketHandle); +#endif // WIN32 + + mSocketHandle = INVALID_SOCKET_VALUE; + THROW_EXCEPTION(ConnectionException, Conn_SocketConnectError) + } + + ResetCounters(); + + mReadClosed = false; + mWriteClosed = false; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::Read(void *pBuffer, int NBytes) +// Purpose: Reads data from stream. Maybe returns less than asked for. +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +int SocketStream::Read(void *pBuffer, int NBytes, int Timeout) +{ + if(mSocketHandle == INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + + if(Timeout != IOStream::TimeOutInfinite) + { + struct pollfd p; + p.fd = mSocketHandle; + p.events = POLLIN; + p.revents = 0; + switch(::poll(&p, 1, (Timeout == IOStream::TimeOutInfinite)?INFTIM:Timeout)) + { + case -1: + // error + if(errno == EINTR) + { + // Signal. Just return 0 bytes + return 0; + } + else + { + // Bad! + BOX_LOG_SYS_ERROR("Failed to poll socket"); + THROW_EXCEPTION(ServerException, + SocketPollError) + } + break; + + case 0: + // no data + return 0; + break; + + default: + // good to go! + break; + } + } + +#ifdef WIN32 + int r = ::recv(mSocketHandle, (char*)pBuffer, NBytes, 0); +#else + int r = ::read(mSocketHandle, pBuffer, NBytes); +#endif + if(r == -1) + { + if(errno == EINTR) + { + // Nothing could be read + return 0; + } + else + { + // Other error + BOX_LOG_SYS_ERROR("Failed to read from socket"); + THROW_EXCEPTION(ConnectionException, + Conn_SocketReadError); + } + } + + // Closed for reading? + if(r == 0) + { + mReadClosed = true; + } + + mBytesRead += r; + return r; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::Write(void *pBuffer, int NBytes) +// Purpose: Writes data, blocking until it's all done. +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void SocketStream::Write(const void *pBuffer, int NBytes) +{ + if(mSocketHandle == INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + + // Buffer in byte sized type. + ASSERT(sizeof(char) == 1); + const char *buffer = (char *)pBuffer; + + // Bytes left to send + int bytesLeft = NBytes; + + while(bytesLeft > 0) + { + // Try to send. +#ifdef WIN32 + int sent = ::send(mSocketHandle, buffer, bytesLeft, 0); +#else + int sent = ::write(mSocketHandle, buffer, bytesLeft); +#endif + if(sent == -1) + { + // Error. + mWriteClosed = true; // assume can't write again + BOX_LOG_SYS_ERROR("Failed to write to socket"); + THROW_EXCEPTION(ConnectionException, + Conn_SocketWriteError); + } + + // Knock off bytes sent + bytesLeft -= sent; + // Move buffer pointer + buffer += sent; + + mBytesWritten += sent; + + // Need to wait until it can send again? + if(bytesLeft > 0) + { + BOX_TRACE("Waiting to send data on socket " << + mSocketHandle << " (" << bytesLeft << + " of " << NBytes << " bytes left)"); + + // Wait for data to send. + struct pollfd p; + p.fd = mSocketHandle; + p.events = POLLOUT; + p.revents = 0; + + if(::poll(&p, 1, 16000 /* 16 seconds */) == -1) + { + // Don't exception if it's just a signal + if(errno != EINTR) + { + BOX_LOG_SYS_ERROR("Failed to poll " + "socket"); + THROW_EXCEPTION(ServerException, + SocketPollError) + } + } + } + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::Close() +// Purpose: Closes connection to remote socket +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void SocketStream::Close() +{ + if(mSocketHandle == INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } +#ifdef WIN32 + if(::closesocket(mSocketHandle) == -1) +#else + if(::close(mSocketHandle) == -1) +#endif + { + BOX_LOG_SYS_ERROR("Failed to close socket"); + // don't throw an exception here, assume that the socket was + // already closed or closing. + } + mSocketHandle = INVALID_SOCKET_VALUE; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::Shutdown(bool, bool) +// Purpose: Shuts down a socket for further reading and/or writing +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void SocketStream::Shutdown(bool Read, bool Write) +{ + if(mSocketHandle == INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + + // Do anything? + if(!Read && !Write) return; + + int how = SHUT_RDWR; + if(Read && !Write) how = SHUT_RD; + if(!Read && Write) how = SHUT_WR; + + // Shut it down! + if(::shutdown(mSocketHandle, how) == -1) + { + BOX_LOG_SYS_ERROR("Failed to shutdown socket"); + THROW_EXCEPTION(ConnectionException, Conn_SocketShutdownError) + } +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::StreamDataLeft() +// Purpose: Still capable of reading data? +// Created: 2003/08/02 +// +// -------------------------------------------------------------------------- +bool SocketStream::StreamDataLeft() +{ + return !mReadClosed; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::StreamClosed() +// Purpose: Connection been closed? +// Created: 2003/08/02 +// +// -------------------------------------------------------------------------- +bool SocketStream::StreamClosed() +{ + return mWriteClosed; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::GetSocketHandle() +// Purpose: Returns socket handle for this stream (derived classes only). +// Will exception if there's no valid socket. +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +tOSSocketHandle SocketStream::GetSocketHandle() +{ + if(mSocketHandle == INVALID_SOCKET_VALUE) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + return mSocketHandle; +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStream::GetPeerCredentials(uid_t &, gid_t &) +// Purpose: Returns true if the peer credientials are available. +// (will work on UNIX domain sockets only) +// Created: 19/2/04 +// +// -------------------------------------------------------------------------- +bool SocketStream::GetPeerCredentials(uid_t &rUidOut, gid_t &rGidOut) +{ +#ifdef HAVE_GETPEEREID + uid_t remoteEUID = 0xffff; + gid_t remoteEGID = 0xffff; + + if(::getpeereid(mSocketHandle, &remoteEUID, &remoteEGID) == 0) + { + rUidOut = remoteEUID; + rGidOut = remoteEGID; + return true; + } +#endif + +#if HAVE_DECL_SO_PEERCRED + struct ucred cred; + socklen_t credLen = sizeof(cred); + + if(::getsockopt(mSocketHandle, SOL_SOCKET, SO_PEERCRED, &cred, + &credLen) == 0) + { + rUidOut = cred.uid; + rGidOut = cred.gid; + return true; + } + + BOX_LOG_SYS_ERROR("Failed to get peer credentials on socket"); +#endif + +#if defined HAVE_UCRED_H && HAVE_GETPEERUCRED + ucred_t *pucred = NULL; + if(::getpeerucred(mSocketHandle, &pucred) == 0) + { + rUidOut = ucred_geteuid(pucred); + rGidOut = ucred_getegid(pucred); + ucred_free(pucred); + if (rUidOut == -1 || rGidOut == -1) + { + BOX_ERROR("Failed to get peer credentials on " + "socket: insufficient information"); + return false; + } + return true; + } + + BOX_LOG_SYS_ERROR("Failed to get peer credentials on socket"); +#endif + + // Not available + return false; +} + diff --git a/lib/server/SocketStream.h b/lib/server/SocketStream.h new file mode 100644 index 00000000..2b582f21 --- /dev/null +++ b/lib/server/SocketStream.h @@ -0,0 +1,75 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: SocketStream.h +// Purpose: I/O stream interface for sockets +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- + +#ifndef SOCKETSTREAM__H +#define SOCKETSTREAM__H + +#include "IOStream.h" +#include "Socket.h" + +#ifdef WIN32 + typedef SOCKET tOSSocketHandle; + #define INVALID_SOCKET_VALUE (tOSSocketHandle)(-1) +#else + typedef int tOSSocketHandle; + #define INVALID_SOCKET_VALUE -1 +#endif + +// -------------------------------------------------------------------------- +// +// Class +// Name: SocketStream +// Purpose: Stream interface for sockets +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +class SocketStream : public IOStream +{ +public: + SocketStream(); + SocketStream(int socket); + SocketStream(const SocketStream &rToCopy); + ~SocketStream(); + + void Open(Socket::Type Type, const std::string& rName, int Port = 0); + void Attach(int socket); + + virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite); + virtual void Write(const void *pBuffer, int NBytes); + virtual void Close(); + virtual bool StreamDataLeft(); + virtual bool StreamClosed(); + + virtual void Shutdown(bool Read = true, bool Write = true); + + virtual bool GetPeerCredentials(uid_t &rUidOut, gid_t &rGidOut); + +protected: + tOSSocketHandle GetSocketHandle(); + void MarkAsReadClosed() {mReadClosed = true;} + void MarkAsWriteClosed() {mWriteClosed = true;} + +private: + tOSSocketHandle mSocketHandle; + bool mReadClosed; + bool mWriteClosed; + +protected: + off_t mBytesRead; + off_t mBytesWritten; + +public: + off_t GetBytesRead() const {return mBytesRead;} + off_t GetBytesWritten() const {return mBytesWritten;} + void ResetCounters() {mBytesRead = mBytesWritten = 0;} + bool IsOpened() { return mSocketHandle != INVALID_SOCKET_VALUE; } +}; + +#endif // SOCKETSTREAM__H + diff --git a/lib/server/SocketStreamTLS.cpp b/lib/server/SocketStreamTLS.cpp new file mode 100644 index 00000000..19fdadd4 --- /dev/null +++ b/lib/server/SocketStreamTLS.cpp @@ -0,0 +1,492 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: SocketStreamTLS.cpp +// Purpose: Socket stream encrpyted and authenticated by TLS +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#define TLS_CLASS_IMPLEMENTATION_CPP +#include <openssl/ssl.h> +#include <openssl/bio.h> +#include <errno.h> +#include <fcntl.h> + +#ifndef WIN32 +#include <poll.h> +#endif + +#include "SocketStreamTLS.h" +#include "SSLLib.h" +#include "ServerException.h" +#include "TLSContext.h" +#include "BoxTime.h" + +#include "MemLeakFindOn.h" + +// Allow 5 minutes to handshake (in milliseconds) +#define TLS_HANDSHAKE_TIMEOUT (5*60*1000) + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::SocketStreamTLS() +// Purpose: Constructor +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +SocketStreamTLS::SocketStreamTLS() + : mpSSL(0), mpBIO(0) +{ + ResetCounters(); +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::SocketStreamTLS(int) +// Purpose: Constructor, taking previously connected socket +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +SocketStreamTLS::SocketStreamTLS(int socket) + : SocketStream(socket), + mpSSL(0), mpBIO(0) +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::~SocketStreamTLS() +// Purpose: Destructor +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +SocketStreamTLS::~SocketStreamTLS() +{ + if(mpSSL) + { + // Attempt to close to avoid problems + Close(); + + // And if that didn't work... + if(mpSSL) + { + ::SSL_free(mpSSL); + mpSSL = 0; + mpBIO = 0; // implicity freed by the SSL_free call + } + } + + // If we only got to creating that BIO. + if(mpBIO) + { + ::BIO_free(mpBIO); + mpBIO = 0; + } +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::Open(const TLSContext &, int, const char *, int) +// Purpose: Open connection, and perform TLS handshake +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void SocketStreamTLS::Open(const TLSContext &rContext, Socket::Type Type, + const std::string& rName, int Port) +{ + SocketStream::Open(Type, rName, Port); + Handshake(rContext); + ResetCounters(); +} + + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::Handshake(const TLSContext &, bool) +// Purpose: Perform TLS handshake +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer) +{ + if(mpBIO || mpSSL) {THROW_EXCEPTION(ServerException, TLSAlreadyHandshaked)} + + // Create a BIO for this socket + mpBIO = ::BIO_new(::BIO_s_socket()); + if(mpBIO == 0) + { + SSLLib::LogError("creating socket bio"); + THROW_EXCEPTION(ServerException, TLSAllocationFailed) + } + + tOSSocketHandle socket = GetSocketHandle(); + BIO_set_fd(mpBIO, socket, BIO_NOCLOSE); + + // Then the SSL object + mpSSL = ::SSL_new(rContext.GetRawContext()); + if(mpSSL == 0) + { + SSLLib::LogError("creating SSL object"); + THROW_EXCEPTION(ServerException, TLSAllocationFailed) + } + + // Make the socket non-blocking so timeouts on Read work + +#ifdef WIN32 + u_long nonblocking = 1; + ioctlsocket(socket, FIONBIO, &nonblocking); +#else // !WIN32 + // This is more portable than using ioctl with FIONBIO + int statusFlags = 0; + if(::fcntl(socket, F_GETFL, &statusFlags) < 0 + || ::fcntl(socket, F_SETFL, statusFlags | O_NONBLOCK) == -1) + { + THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed) + } +#endif + + // FIXME: This is less portable than the above. However, it MAY be needed + // for cygwin, which has/had bugs with fcntl + // + // int nonblocking = true; + // if(::ioctl(socket, FIONBIO, &nonblocking) == -1) + // { + // THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed) + // } + + // Set the two to know about each other + ::SSL_set_bio(mpSSL, mpBIO, mpBIO); + + bool waitingForHandshake = true; + while(waitingForHandshake) + { + // Attempt to do the handshake + int r = 0; + if(IsServer) + { + r = ::SSL_accept(mpSSL); + } + else + { + r = ::SSL_connect(mpSSL); + } + + // check return code + int se; + switch((se = ::SSL_get_error(mpSSL, r))) + { + case SSL_ERROR_NONE: + // No error, handshake succeeded + waitingForHandshake = false; + break; + + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + // wait for the requried data + if(WaitWhenRetryRequired(se, TLS_HANDSHAKE_TIMEOUT) == false) + { + // timed out + THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeTimedOut) + } + break; + + default: // (and SSL_ERROR_ZERO_RETURN) + // Error occured + if(IsServer) + { + SSLLib::LogError("accepting connection"); + THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed) + } + else + { + SSLLib::LogError("connecting"); + THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed) + } + } + } + + // And that's it +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WaitWhenRetryRequired(int, int) +// Purpose: Waits until the condition required by the TLS layer is met. +// Returns true if the condition is met, false if timed out. +// Created: 2003/08/15 +// +// -------------------------------------------------------------------------- +bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout) +{ + struct pollfd p; + p.fd = GetSocketHandle(); + switch(SSLErrorCode) + { + case SSL_ERROR_WANT_READ: + p.events = POLLIN; + break; + + case SSL_ERROR_WANT_WRITE: + p.events = POLLOUT; + break; + + default: + // Not good! + THROW_EXCEPTION(ServerException, Internal) + break; + } + p.revents = 0; + + int64_t start, end; + start = BoxTimeToMilliSeconds(GetCurrentBoxTime()); + end = start + Timeout; + int result; + + do + { + int64_t now = BoxTimeToMilliSeconds(GetCurrentBoxTime()); + int poll_timeout = (int)(end - now); + if (poll_timeout < 0) poll_timeout = 0; + if (Timeout == IOStream::TimeOutInfinite) + { + poll_timeout = INFTIM; + } + result = ::poll(&p, 1, poll_timeout); + } + while(result == -1 && errno == EINTR); + + switch(result) + { + case -1: + // error - Bad! + THROW_EXCEPTION(ServerException, SocketPollError) + break; + + case 0: + // Condition not met, timed out + return false; + break; + + default: + // good to go! + return true; + break; + } + + return true; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::Read(void *, int, int Timeout) +// Purpose: See base class +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout) +{ + if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)} + + // Make sure zero byte reads work as expected + if(NBytes == 0) + { + return 0; + } + + while(true) + { + int r = ::SSL_read(mpSSL, pBuffer, NBytes); + + int se; + switch((se = ::SSL_get_error(mpSSL, r))) + { + case SSL_ERROR_NONE: + // No error, return number of bytes read + mBytesRead += r; + return r; + break; + + case SSL_ERROR_ZERO_RETURN: + // Connection closed + MarkAsReadClosed(); + return 0; + break; + + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + // wait for the required data + // Will only get once around this loop, so don't need to calculate timeout values + if(WaitWhenRetryRequired(se, Timeout) == false) + { + // timed out + return 0; + } + break; + + default: + SSLLib::LogError("reading"); + THROW_EXCEPTION(ConnectionException, Conn_TLSReadFailed) + break; + } + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::Write(const void *, int) +// Purpose: See base class +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void SocketStreamTLS::Write(const void *pBuffer, int NBytes) +{ + if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)} + + // Make sure zero byte writes work as expected + if(NBytes == 0) + { + return; + } + + // from man SSL_write + // + // SSL_write() will only return with success, when the + // complete contents of buf of length num has been written. + // + // So no worries about partial writes and moving the buffer around + + while(true) + { + // try the write + int r = ::SSL_write(mpSSL, pBuffer, NBytes); + + int se; + switch((se = ::SSL_get_error(mpSSL, r))) + { + case SSL_ERROR_NONE: + // No error, data sent, return success + mBytesWritten += r; + return; + break; + + case SSL_ERROR_ZERO_RETURN: + // Connection closed + MarkAsWriteClosed(); + THROW_EXCEPTION(ConnectionException, Conn_TLSClosedWhenWriting) + break; + + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + // wait for the requried data + { + #ifndef BOX_RELEASE_BUILD + bool conditionmet = + #endif + WaitWhenRetryRequired(se, IOStream::TimeOutInfinite); + ASSERT(conditionmet); + } + break; + + default: + SSLLib::LogError("writing"); + THROW_EXCEPTION(ConnectionException, Conn_TLSWriteFailed) + break; + } + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::Close() +// Purpose: See base class +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void SocketStreamTLS::Close() +{ + if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)} + + // Base class to close + SocketStream::Close(); + + // Free resources + ::SSL_free(mpSSL); + mpSSL = 0; + mpBIO = 0; // implicitly freed by SSL_free +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::Shutdown() +// Purpose: See base class +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void SocketStreamTLS::Shutdown(bool Read, bool Write) +{ + if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)} + + if(::SSL_shutdown(mpSSL) < 0) + { + SSLLib::LogError("shutting down"); + THROW_EXCEPTION(ConnectionException, Conn_TLSShutdownFailed) + } + + // Don't ask the base class to shutdown -- BIO does this, apparently. +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: SocketStreamTLS::GetPeerCommonName() +// Purpose: Returns the common name of the other end of the connection +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +std::string SocketStreamTLS::GetPeerCommonName() +{ + if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)} + + // Get certificate + X509 *cert = ::SSL_get_peer_certificate(mpSSL); + if(cert == 0) + { + ::X509_free(cert); + THROW_EXCEPTION(ConnectionException, Conn_TLSNoPeerCertificate) + } + + // Subject details + X509_NAME *subject = ::X509_get_subject_name(cert); + if(subject == 0) + { + ::X509_free(cert); + THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid) + } + + // Common name + char commonName[256]; + if(::X509_NAME_get_text_by_NID(subject, NID_commonName, commonName, sizeof(commonName)) <= 0) + { + ::X509_free(cert); + THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid) + } + // Terminate just in case + commonName[sizeof(commonName)-1] = '\0'; + + // Done. + return std::string(commonName); +} diff --git a/lib/server/SocketStreamTLS.h b/lib/server/SocketStreamTLS.h new file mode 100644 index 00000000..bb40ed10 --- /dev/null +++ b/lib/server/SocketStreamTLS.h @@ -0,0 +1,61 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: SocketStreamTLS.h +// Purpose: Socket stream encrpyted and authenticated by TLS +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- + +#ifndef SOCKETSTREAMTLS__H +#define SOCKETSTREAMTLS__H + +#include <string> + +#include "SocketStream.h" + +class TLSContext; +#ifndef TLS_CLASS_IMPLEMENTATION_CPP + class SSL; + class BIO; +#endif + +// -------------------------------------------------------------------------- +// +// Class +// Name: SocketStreamTLS +// Purpose: Socket stream encrpyted and authenticated by TLS +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +class SocketStreamTLS : public SocketStream +{ +public: + SocketStreamTLS(); + SocketStreamTLS(int socket); + ~SocketStreamTLS(); +private: + SocketStreamTLS(const SocketStreamTLS &rToCopy); +public: + + void Open(const TLSContext &rContext, Socket::Type Type, + const std::string& rName, int Port = 0); + void Handshake(const TLSContext &rContext, bool IsServer = false); + + virtual int Read(void *pBuffer, int NBytes, int Timeout = IOStream::TimeOutInfinite); + virtual void Write(const void *pBuffer, int NBytes); + virtual void Close(); + virtual void Shutdown(bool Read = true, bool Write = true); + + std::string GetPeerCommonName(); + +private: + bool WaitWhenRetryRequired(int SSLErrorCode, int Timeout); + +private: + SSL *mpSSL; + BIO *mpBIO; +}; + +#endif // SOCKETSTREAMTLS__H + diff --git a/lib/server/TLSContext.cpp b/lib/server/TLSContext.cpp new file mode 100644 index 00000000..ebc7384a --- /dev/null +++ b/lib/server/TLSContext.cpp @@ -0,0 +1,131 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: TLSContext.h +// Purpose: TLS (SSL) context for connections +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#define TLS_CLASS_IMPLEMENTATION_CPP +#include <openssl/ssl.h> + +#include "TLSContext.h" +#include "ServerException.h" +#include "SSLLib.h" +#include "TLSContext.h" + +#include "MemLeakFindOn.h" + +#define MAX_VERIFICATION_DEPTH 2 +#define CIPHER_LIST "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH" + +// -------------------------------------------------------------------------- +// +// Function +// Name: TLSContext::TLSContext() +// Purpose: Constructor +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +TLSContext::TLSContext() + : mpContext(0) +{ +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: TLSContext::~TLSContext() +// Purpose: Destructor +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +TLSContext::~TLSContext() +{ + if(mpContext != 0) + { + ::SSL_CTX_free(mpContext); + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: TLSContext::Initialise(bool, const char *, const char *, const char *) +// Purpose: Initialise the context, loading in the specified certificate and private key files +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +void TLSContext::Initialise(bool AsServer, const char *CertificatesFile, const char *PrivateKeyFile, const char *TrustedCAsFile) +{ + if(mpContext != 0) + { + ::SSL_CTX_free(mpContext); + } + + mpContext = ::SSL_CTX_new(AsServer?TLSv1_server_method():TLSv1_client_method()); + if(mpContext == NULL) + { + THROW_EXCEPTION(ServerException, TLSAllocationFailed) + } + + // Setup our identity + if(::SSL_CTX_use_certificate_chain_file(mpContext, CertificatesFile) != 1) + { + std::string msg = "loading certificates from "; + msg += CertificatesFile; + SSLLib::LogError(msg); + THROW_EXCEPTION(ServerException, TLSLoadCertificatesFailed) + } + if(::SSL_CTX_use_PrivateKey_file(mpContext, PrivateKeyFile, SSL_FILETYPE_PEM) != 1) + { + std::string msg = "loading private key from "; + msg += PrivateKeyFile; + SSLLib::LogError(msg); + THROW_EXCEPTION(ServerException, TLSLoadPrivateKeyFailed) + } + + // Setup the identify of CAs we trust + if(::SSL_CTX_load_verify_locations(mpContext, TrustedCAsFile, NULL) != 1) + { + std::string msg = "loading CA cert from "; + msg += TrustedCAsFile; + SSLLib::LogError(msg); + THROW_EXCEPTION(ServerException, TLSLoadTrustedCAsFailed) + } + + // Setup options to require these certificates + ::SSL_CTX_set_verify(mpContext, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL); + // and a sensible maximum depth + ::SSL_CTX_set_verify_depth(mpContext, MAX_VERIFICATION_DEPTH); + + // Setup allowed ciphers + if(::SSL_CTX_set_cipher_list(mpContext, CIPHER_LIST) != 1) + { + SSLLib::LogError("setting cipher list to " CIPHER_LIST); + THROW_EXCEPTION(ServerException, TLSSetCiphersFailed) + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: TLSContext::GetRawContext() +// Purpose: Get the raw context for OpenSSL API +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +SSL_CTX *TLSContext::GetRawContext() const +{ + if(mpContext == 0) + { + THROW_EXCEPTION(ServerException, TLSContextNotInitialised) + } + return mpContext; +} + + + diff --git a/lib/server/TLSContext.h b/lib/server/TLSContext.h new file mode 100644 index 00000000..f52f5457 --- /dev/null +++ b/lib/server/TLSContext.h @@ -0,0 +1,41 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: TLSContext.h +// Purpose: TLS (SSL) context for connections +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- + +#ifndef TLSCONTEXT__H +#define TLSCONTEXT__H + +#ifndef TLS_CLASS_IMPLEMENTATION_CPP + class SSL_CTX; +#endif + +// -------------------------------------------------------------------------- +// +// Class +// Name: TLSContext +// Purpose: TLS (SSL) context for connections +// Created: 2003/08/06 +// +// -------------------------------------------------------------------------- +class TLSContext +{ +public: + TLSContext(); + ~TLSContext(); +private: + TLSContext(const TLSContext &); +public: + void Initialise(bool AsServer, const char *CertificatesFile, const char *PrivateKeyFile, const char *TrustedCAsFile); + SSL_CTX *GetRawContext() const; + +private: + SSL_CTX *mpContext; +}; + +#endif // TLSCONTEXT__H + diff --git a/lib/server/WinNamedPipeListener.h b/lib/server/WinNamedPipeListener.h new file mode 100644 index 00000000..26e76e3d --- /dev/null +++ b/lib/server/WinNamedPipeListener.h @@ -0,0 +1,232 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: WinNamedPipeListener.h +// Purpose: Windows named pipe socket connection listener +// for server +// Created: 2008/09/30 +// +// -------------------------------------------------------------------------- + +#ifndef WINNAMEDPIPELISTENER__H +#define WINNAMEDPIPELISTENER__H + +#include <OverlappedIO.h> +#include <WinNamedPipeStream.h> + +#include "ServerException.h" + +#include "MemLeakFindOn.h" + +// -------------------------------------------------------------------------- +// +// Class +// Name: WinNamedPipeListener +// Purpose: +// Created: 2008/09/30 +// +// -------------------------------------------------------------------------- +template<int ListenBacklog = 128> +class WinNamedPipeListener +{ +private: + std::auto_ptr<std::string> mapPipeName; + std::auto_ptr<OverlappedIO> mapOverlapConnect; + HANDLE mPipeHandle; + +public: + // Initialise + WinNamedPipeListener() + : mPipeHandle(INVALID_HANDLE_VALUE) + { } + +private: + WinNamedPipeListener(const WinNamedPipeListener &rToCopy) + { /* forbidden */ } + + HANDLE CreatePipeHandle(const std::string& rName) + { + std::string socket = WinNamedPipeStream::sPipeNamePrefix + + rName; + + HANDLE handle = CreateNamedPipeA( + socket.c_str(), // pipe name + PIPE_ACCESS_DUPLEX | // read/write access + FILE_FLAG_OVERLAPPED, // enabled overlapped I/O + PIPE_TYPE_BYTE | // message type pipe + PIPE_READMODE_BYTE | // message-read mode + PIPE_WAIT, // blocking mode + ListenBacklog + 1, // max. instances + 4096, // output buffer size + 4096, // input buffer size + NMPWAIT_USE_DEFAULT_WAIT, // client time-out + NULL); // default security attribute + + if (handle == INVALID_HANDLE_VALUE) + { + BOX_LOG_WIN_ERROR("Failed to create named pipe " << + socket); + THROW_EXCEPTION(ServerException, SocketOpenError) + } + + return handle; + } + +public: + ~WinNamedPipeListener() + { + Close(); + } + + void Close() + { + if (mPipeHandle != INVALID_HANDLE_VALUE) + { + if (mapOverlapConnect.get()) + { + // outstanding connect in progress + if (CancelIo(mPipeHandle) != TRUE) + { + BOX_LOG_WIN_ERROR("Failed to cancel " + "outstanding connect request " + "on named pipe"); + } + + mapOverlapConnect.reset(); + } + + if (CloseHandle(mPipeHandle) != TRUE) + { + BOX_LOG_WIN_ERROR("Failed to close named pipe " + "handle"); + } + + mPipeHandle = INVALID_HANDLE_VALUE; + } + } + + // ------------------------------------------------------------------ + // + // Function + // Name: WinNamedPipeListener::Listen(std::string name) + // Purpose: Initialises socket name + // Created: 2003/07/31 + // + // ------------------------------------------------------------------ + void Listen(const std::string& rName) + { + Close(); + mapPipeName.reset(new std::string(rName)); + mPipeHandle = CreatePipeHandle(rName); + } + + // ------------------------------------------------------------------ + // + // Function + // Name: WinNamedPipeListener::Accept(int) + // Purpose: Accepts a connection, returning a pointer to + // a class of the specified type. May return a + // null pointer if a signal happens, or there's + // a timeout. Timeout specified in + // milliseconds, defaults to infinite time. + // Created: 2003/07/31 + // + // ------------------------------------------------------------------ + std::auto_ptr<WinNamedPipeStream> Accept(int Timeout = INFTIM, + const char* pLogMsgOut = NULL) + { + if(!mapPipeName.get()) + { + THROW_EXCEPTION(ServerException, BadSocketHandle); + } + + BOOL connected = FALSE; + std::auto_ptr<WinNamedPipeStream> mapStream; + + if (!mapOverlapConnect.get()) + { + // start a new connect operation + mapOverlapConnect.reset(new OverlappedIO()); + connected = ConnectNamedPipe(mPipeHandle, + &mapOverlapConnect->mOverlapped); + + if (connected == FALSE) + { + if (GetLastError() == ERROR_PIPE_CONNECTED) + { + connected = TRUE; + } + else if (GetLastError() != ERROR_IO_PENDING) + { + BOX_LOG_WIN_ERROR("Failed to connect " + "named pipe"); + THROW_EXCEPTION(ServerException, + SocketAcceptError); + } + } + } + + if (connected == FALSE) + { + // wait for connection + DWORD result = WaitForSingleObject( + mapOverlapConnect->mOverlapped.hEvent, + (Timeout == INFTIM) ? INFINITE : Timeout); + + if (result == WAIT_OBJECT_0) + { + DWORD dummy; + + if (!GetOverlappedResult(mPipeHandle, + &mapOverlapConnect->mOverlapped, + &dummy, TRUE)) + { + BOX_LOG_WIN_ERROR("Failed to get " + "overlapped connect result"); + THROW_EXCEPTION(ServerException, + SocketAcceptError); + } + + connected = TRUE; + } + else if (result == WAIT_TIMEOUT) + { + return mapStream; // contains NULL + } + else if (result == WAIT_ABANDONED) + { + BOX_ERROR("Wait for named pipe connection " + "was abandoned by the system"); + THROW_EXCEPTION(ServerException, + SocketAcceptError); + } + else if (result == WAIT_FAILED) + { + BOX_LOG_WIN_ERROR("Failed to wait for named " + "pipe connection"); + THROW_EXCEPTION(ServerException, + SocketAcceptError); + } + else + { + BOX_ERROR("Failed to wait for named pipe " + "connection: unknown return code " << + result); + THROW_EXCEPTION(ServerException, + SocketAcceptError); + } + } + + ASSERT(connected == TRUE); + + mapStream.reset(new WinNamedPipeStream(mPipeHandle)); + mPipeHandle = CreatePipeHandle(*mapPipeName); + mapOverlapConnect.reset(); + + return mapStream; + } +}; + +#include "MemLeakFindOff.h" + +#endif // WINNAMEDPIPELISTENER__H diff --git a/lib/server/WinNamedPipeStream.cpp b/lib/server/WinNamedPipeStream.cpp new file mode 100644 index 00000000..1179516e --- /dev/null +++ b/lib/server/WinNamedPipeStream.cpp @@ -0,0 +1,620 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: WinNamedPipeStream.cpp +// Purpose: I/O stream interface for Win32 named pipes +// Created: 2005/12/07 +// +// -------------------------------------------------------------------------- + +#include "Box.h" + +#ifdef WIN32 + +#ifdef HAVE_UNISTD_H + #include <unistd.h> +#endif + +#include <sys/types.h> +#include <errno.h> +#include <windows.h> + +#include "WinNamedPipeStream.h" +#include "ServerException.h" +#include "CommonException.h" +#include "Socket.h" + +#include "MemLeakFindOn.h" + +std::string WinNamedPipeStream::sPipeNamePrefix = "\\\\.\\pipe\\"; + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::WinNamedPipeStream() +// Purpose: Constructor (create stream ready for Open() call) +// Created: 2005/12/07 +// +// -------------------------------------------------------------------------- +WinNamedPipeStream::WinNamedPipeStream() + : mSocketHandle(INVALID_HANDLE_VALUE), + mReadableEvent(INVALID_HANDLE_VALUE), + mBytesInBuffer(0), + mReadClosed(false), + mWriteClosed(false), + mIsServer(false), + mIsConnected(false) +{ } + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::WinNamedPipeStream(HANDLE) +// Purpose: Constructor (with already-connected pipe handle) +// Created: 2008/10/01 +// +// -------------------------------------------------------------------------- +WinNamedPipeStream::WinNamedPipeStream(HANDLE hNamedPipe) + : mSocketHandle(hNamedPipe), + mReadableEvent(INVALID_HANDLE_VALUE), + mBytesInBuffer(0), + mReadClosed(false), + mWriteClosed(false), + mIsServer(true), + mIsConnected(true) +{ + // create the Readable event + mReadableEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + + if (mReadableEvent == INVALID_HANDLE_VALUE) + { + BOX_ERROR("Failed to create the Readable event: " << + GetErrorMessage(GetLastError())); + Close(); + THROW_EXCEPTION(CommonException, Internal) + } + + // initialise the OVERLAPPED structure + memset(&mReadOverlap, 0, sizeof(mReadOverlap)); + mReadOverlap.hEvent = mReadableEvent; + + // start the first overlapped read + if (!ReadFile(mSocketHandle, mReadBuffer, sizeof(mReadBuffer), + NULL, &mReadOverlap)) + { + DWORD err = GetLastError(); + + if (err != ERROR_IO_PENDING) + { + BOX_ERROR("Failed to start overlapped read: " << + GetErrorMessage(err)); + Close(); + THROW_EXCEPTION(ConnectionException, + Conn_SocketReadError) + } + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::~WinNamedPipeStream() +// Purpose: Destructor, closes stream if open +// Created: 2005/12/07 +// +// -------------------------------------------------------------------------- +WinNamedPipeStream::~WinNamedPipeStream() +{ + if (mSocketHandle != INVALID_HANDLE_VALUE) + { + try + { + Close(); + } + catch (std::exception &e) + { + BOX_ERROR("Caught exception while destroying " + "named pipe, ignored: " << e.what()); + } + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::Accept(const std::string& rName) +// Purpose: Creates a new named pipe with the given name, +// and wait for a connection on it +// Created: 2005/12/07 +// +// -------------------------------------------------------------------------- +/* +void WinNamedPipeStream::Accept() +{ + if (mSocketHandle == INVALID_HANDLE_VALUE) + { + THROW_EXCEPTION(ServerException, BadSocketHandle); + } + + if (mIsConnected) + { + THROW_EXCEPTION(ServerException, SocketAlreadyOpen); + } + + bool connected = ConnectNamedPipe(mSocketHandle, (LPOVERLAPPED) NULL); + + if (!connected) + { + BOX_ERROR("Failed to ConnectNamedPipe(" << socket << "): " << + GetErrorMessage(GetLastError())); + Close(); + THROW_EXCEPTION(ServerException, SocketOpenError) + } + + mBytesInBuffer = 0; + mReadClosed = false; + mWriteClosed = false; + mIsServer = true; // must flush and disconnect before closing + mIsConnected = true; + + // create the Readable event + mReadableEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + + if (mReadableEvent == INVALID_HANDLE_VALUE) + { + BOX_ERROR("Failed to create the Readable event: " << + GetErrorMessage(GetLastError())); + Close(); + THROW_EXCEPTION(CommonException, Internal) + } + + // initialise the OVERLAPPED structure + memset(&mReadOverlap, 0, sizeof(mReadOverlap)); + mReadOverlap.hEvent = mReadableEvent; + + // start the first overlapped read + if (!ReadFile(mSocketHandle, mReadBuffer, sizeof(mReadBuffer), + NULL, &mReadOverlap)) + { + DWORD err = GetLastError(); + + if (err != ERROR_IO_PENDING) + { + BOX_ERROR("Failed to start overlapped read: " << + GetErrorMessage(err)); + Close(); + THROW_EXCEPTION(ConnectionException, + Conn_SocketReadError) + } + } +} +*/ + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::Connect(const std::string& rName) +// Purpose: Opens a connection to a listening named pipe +// Created: 2005/12/07 +// +// -------------------------------------------------------------------------- +void WinNamedPipeStream::Connect(const std::string& rName) +{ + if (mSocketHandle != INVALID_HANDLE_VALUE || mIsConnected) + { + THROW_EXCEPTION(ServerException, SocketAlreadyOpen) + } + + std::string socket = sPipeNamePrefix + rName; + + mSocketHandle = CreateFileA( + socket.c_str(), // pipe name + GENERIC_READ | // read and write access + GENERIC_WRITE, + 0, // no sharing + NULL, // default security attributes + OPEN_EXISTING, + 0, // default attributes + NULL); // no template file + + if (mSocketHandle == INVALID_HANDLE_VALUE) + { + DWORD err = GetLastError(); + if (err == ERROR_PIPE_BUSY) + { + BOX_ERROR("Failed to connect to backup daemon: " + "it is busy with another connection"); + } + else + { + BOX_ERROR("Failed to connect to backup daemon: " << + GetErrorMessage(err)); + } + THROW_EXCEPTION(ServerException, SocketOpenError) + } + + mReadClosed = false; + mWriteClosed = false; + mIsServer = false; // just close the socket + mIsConnected = true; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::Read(void *pBuffer, int NBytes) +// Purpose: Reads data from stream. Maybe returns less than asked for. +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +int WinNamedPipeStream::Read(void *pBuffer, int NBytes, int Timeout) +{ + // TODO no support for timeouts yet + if (!mIsServer && Timeout != IOStream::TimeOutInfinite) + { + THROW_EXCEPTION(CommonException, AssertFailed) + } + + if (mSocketHandle == INVALID_HANDLE_VALUE || !mIsConnected) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + + if (mReadClosed) + { + THROW_EXCEPTION(ConnectionException, SocketShutdownError) + } + + // ensure safe to cast NBytes to unsigned + if (NBytes < 0) + { + THROW_EXCEPTION(CommonException, AssertFailed) + } + + DWORD NumBytesRead; + + if (mIsServer) + { + // satisfy from buffer if possible, to avoid + // blocking on read. + bool needAnotherRead = false; + if (mBytesInBuffer == 0) + { + // overlapped I/O completed successfully? + // (wait if needed) + DWORD waitResult = WaitForSingleObject( + mReadOverlap.hEvent, Timeout); + + if (waitResult == WAIT_ABANDONED) + { + BOX_ERROR("Wait for command socket read " + "abandoned by system"); + THROW_EXCEPTION(ServerException, + BadSocketHandle); + } + else if (waitResult == WAIT_TIMEOUT) + { + // wait timed out, nothing to read + NumBytesRead = 0; + } + else if (waitResult != WAIT_OBJECT_0) + { + BOX_ERROR("Failed to wait for command " + "socket read: unknown result " << + waitResult); + } + // object is ready to read from + else if (GetOverlappedResult(mSocketHandle, + &mReadOverlap, &NumBytesRead, TRUE)) + { + needAnotherRead = true; + } + else + { + DWORD err = GetLastError(); + + if (err == ERROR_HANDLE_EOF) + { + mReadClosed = true; + } + else + { + if (err == ERROR_BROKEN_PIPE) + { + BOX_NOTICE("Control client " + "disconnected"); + } + else + { + BOX_ERROR("Failed to wait for " + "ReadFile to complete: " + << GetErrorMessage(err)); + } + + Close(); + THROW_EXCEPTION(ConnectionException, + Conn_SocketReadError) + } + } + } + else + { + NumBytesRead = 0; + } + + size_t BytesToCopy = NumBytesRead + mBytesInBuffer; + size_t BytesRemaining = 0; + + if (BytesToCopy > (size_t)NBytes) + { + BytesRemaining = BytesToCopy - NBytes; + BytesToCopy = NBytes; + } + + memcpy(pBuffer, mReadBuffer, BytesToCopy); + memmove(mReadBuffer, mReadBuffer + BytesToCopy, BytesRemaining); + + mBytesInBuffer = BytesRemaining; + NumBytesRead = BytesToCopy; + + if (needAnotherRead) + { + // reinitialise the OVERLAPPED structure + memset(&mReadOverlap, 0, sizeof(mReadOverlap)); + mReadOverlap.hEvent = mReadableEvent; + } + + // start the next overlapped read + if (needAnotherRead && !ReadFile(mSocketHandle, + mReadBuffer + mBytesInBuffer, + sizeof(mReadBuffer) - mBytesInBuffer, + NULL, &mReadOverlap)) + { + DWORD err = GetLastError(); + if (err == ERROR_IO_PENDING) + { + // Don't reset yet, there might be data + // in the buffer waiting to be read, + // will check below. + // ResetEvent(mReadableEvent); + } + else if (err == ERROR_HANDLE_EOF) + { + mReadClosed = true; + } + else if (err == ERROR_BROKEN_PIPE) + { + BOX_ERROR("Control client disconnected"); + mReadClosed = true; + } + else + { + BOX_ERROR("Failed to start overlapped read: " + << GetErrorMessage(err)); + Close(); + THROW_EXCEPTION(ConnectionException, + Conn_SocketReadError) + } + } + } + else + { + if (!ReadFile( + mSocketHandle, // pipe handle + pBuffer, // buffer to receive reply + NBytes, // size of buffer + &NumBytesRead, // number of bytes read + NULL)) // not overlapped + { + DWORD err = GetLastError(); + + Close(); + + // ERROR_NO_DATA is a strange name for + // "The pipe is being closed". No exception wanted. + + if (err == ERROR_NO_DATA || + err == ERROR_PIPE_NOT_CONNECTED) + { + NumBytesRead = 0; + } + else + { + BOX_ERROR("Failed to read from control socket: " + << GetErrorMessage(err)); + THROW_EXCEPTION(ConnectionException, + Conn_SocketReadError) + } + } + + // Closed for reading at EOF? + if (NumBytesRead == 0) + { + mReadClosed = true; + } + } + + return NumBytesRead; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::Write(void *pBuffer, int NBytes) +// Purpose: Writes data, blocking until it's all done. +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void WinNamedPipeStream::Write(const void *pBuffer, int NBytes) +{ + if (mSocketHandle == INVALID_HANDLE_VALUE || !mIsConnected) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + + // Buffer in byte sized type. + ASSERT(sizeof(char) == 1); + const char *pByteBuffer = (char *)pBuffer; + + int NumBytesWrittenTotal = 0; + + while (NumBytesWrittenTotal < NBytes) + { + DWORD NumBytesWrittenThisTime = 0; + + bool Success = WriteFile( + mSocketHandle, // pipe handle + pByteBuffer + NumBytesWrittenTotal, // message + NBytes - NumBytesWrittenTotal, // message length + &NumBytesWrittenThisTime, // bytes written this time + NULL); // not overlapped + + if (!Success) + { + // ERROR_NO_DATA is a strange name for + // "The pipe is being closed". + + DWORD err = GetLastError(); + + if (err != ERROR_NO_DATA) + { + BOX_ERROR("Failed to write to control " + "socket: " << GetErrorMessage(err)); + } + + Close(); + + THROW_EXCEPTION(ConnectionException, + Conn_SocketWriteError) + } + + NumBytesWrittenTotal += NumBytesWrittenThisTime; + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::Close() +// Purpose: Closes connection to remote socket +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +void WinNamedPipeStream::Close() +{ + if (mSocketHandle == INVALID_HANDLE_VALUE && mIsConnected) + { + BOX_ERROR("Named pipe: inconsistent connected state"); + mIsConnected = false; + } + + if (mSocketHandle == INVALID_HANDLE_VALUE) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + + if (mIsServer) + { + if (!CancelIo(mSocketHandle)) + { + BOX_ERROR("Failed to cancel outstanding I/O: " << + GetErrorMessage(GetLastError())); + } + + if (mReadableEvent == INVALID_HANDLE_VALUE) + { + BOX_ERROR("Failed to destroy Readable event: " + "invalid handle"); + } + else if (!CloseHandle(mReadableEvent)) + { + BOX_ERROR("Failed to destroy Readable event: " << + GetErrorMessage(GetLastError())); + } + + mReadableEvent = INVALID_HANDLE_VALUE; + + if (!FlushFileBuffers(mSocketHandle)) + { + BOX_ERROR("Failed to FlushFileBuffers: " << + GetErrorMessage(GetLastError())); + } + + if (!DisconnectNamedPipe(mSocketHandle)) + { + DWORD err = GetLastError(); + if (err != ERROR_PIPE_NOT_CONNECTED) + { + BOX_ERROR("Failed to DisconnectNamedPipe: " << + GetErrorMessage(err)); + } + } + + mIsServer = false; + } + + bool result = CloseHandle(mSocketHandle); + + mSocketHandle = INVALID_HANDLE_VALUE; + mIsConnected = false; + mReadClosed = true; + mWriteClosed = true; + + if (!result) + { + BOX_ERROR("Failed to CloseHandle: " << + GetErrorMessage(GetLastError())); + THROW_EXCEPTION(ServerException, SocketCloseError) + } +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::StreamDataLeft() +// Purpose: Still capable of reading data? +// Created: 2003/08/02 +// +// -------------------------------------------------------------------------- +bool WinNamedPipeStream::StreamDataLeft() +{ + return !mReadClosed; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: WinNamedPipeStream::StreamClosed() +// Purpose: Connection been closed? +// Created: 2003/08/02 +// +// -------------------------------------------------------------------------- +bool WinNamedPipeStream::StreamClosed() +{ + return mWriteClosed; +} + +// -------------------------------------------------------------------------- +// +// Function +// Name: IOStream::WriteAllBuffered() +// Purpose: Ensures that any data which has been buffered is written to the stream +// Created: 2003/08/26 +// +// -------------------------------------------------------------------------- +void WinNamedPipeStream::WriteAllBuffered() +{ + if (mSocketHandle == INVALID_HANDLE_VALUE || !mIsConnected) + { + THROW_EXCEPTION(ServerException, BadSocketHandle) + } + + if (!FlushFileBuffers(mSocketHandle)) + { + BOX_ERROR("Failed to FlushFileBuffers: " << + GetErrorMessage(GetLastError())); + } +} + + +#endif // WIN32 diff --git a/lib/server/WinNamedPipeStream.h b/lib/server/WinNamedPipeStream.h new file mode 100644 index 00000000..386ff7e3 --- /dev/null +++ b/lib/server/WinNamedPipeStream.h @@ -0,0 +1,67 @@ +// -------------------------------------------------------------------------- +// +// File +// Name: WinNamedPipeStream.h +// Purpose: I/O stream interface for Win32 named pipes +// Created: 2005/12/07 +// +// -------------------------------------------------------------------------- + +#if ! defined WINNAMEDPIPESTREAM__H && defined WIN32 +#define WINNAMEDPIPESTREAM__H + +#include "IOStream.h" + +// -------------------------------------------------------------------------- +// +// Class +// Name: WinNamedPipeStream +// Purpose: I/O stream interface for Win32 named pipes +// Created: 2003/07/31 +// +// -------------------------------------------------------------------------- +class WinNamedPipeStream : public IOStream +{ +public: + WinNamedPipeStream(); + WinNamedPipeStream(HANDLE hNamedPipe); + ~WinNamedPipeStream(); + + // server side - create the named pipe and listen for connections + // use WinNamedPipeListener to do this instead. + + // client side - connect to a waiting server + void Connect(const std::string& rName); + + // both sides + virtual int Read(void *pBuffer, int NBytes, + int Timeout = IOStream::TimeOutInfinite); + virtual void Write(const void *pBuffer, int NBytes); + virtual void WriteAllBuffered(); + virtual void Close(); + virtual bool StreamDataLeft(); + virtual bool StreamClosed(); + +protected: + void MarkAsReadClosed() {mReadClosed = true;} + void MarkAsWriteClosed() {mWriteClosed = true;} + +private: + WinNamedPipeStream(const WinNamedPipeStream &rToCopy) + { /* do not call */ } + + HANDLE mSocketHandle; + HANDLE mReadableEvent; + OVERLAPPED mReadOverlap; + uint8_t mReadBuffer[4096]; + size_t mBytesInBuffer; + bool mReadClosed; + bool mWriteClosed; + bool mIsServer; + bool mIsConnected; + +public: + static std::string sPipeNamePrefix; +}; + +#endif // WINNAMEDPIPESTREAM__H diff --git a/lib/server/makeprotocol.pl.in b/lib/server/makeprotocol.pl.in new file mode 100755 index 00000000..91ba55b0 --- /dev/null +++ b/lib/server/makeprotocol.pl.in @@ -0,0 +1,1093 @@ +#!@PERL@ +use strict; + +use lib "../../infrastructure"; +use BoxPlatform; + +# Make protocol C++ classes from a protocol description file + +# built in type info (values are is basic type, C++ typename) +# may get stuff added to it later if protocol uses extra types +my %translate_type_info = +( + 'int64' => [1, 'int64_t'], + 'int32' => [1, 'int32_t'], + 'int16' => [1, 'int16_t'], + 'int8' => [1, 'int8_t'], + 'bool' => [1, 'bool'], + 'string' => [0, 'std::string'] +); + +# built in instructions for logging various types +# may be added to +my %log_display_types = +( + 'int64' => ['0x%llx', 'VAR'], + 'int32' => ['0x%x', 'VAR'], + 'int16' => ['0x%x', 'VAR'], + 'int8' => ['0x%x', 'VAR'], + 'bool' => ['%s', '((VAR)?"true":"false")'], + 'string' => ['%s', 'VAR.c_str()'] +); + + + +my ($type, $file) = @ARGV; + +if($type ne 'Server' && $type ne 'Client') +{ + die "Neither Server or Client is specified on command line\n"; +} + +open IN, $file or die "Can't open input file $file\n"; + +print "Making $type protocol classes from $file...\n"; + +my @extra_header_files; + +my $implement_syslog = 0; +my $implement_filelog = 0; + +# read attributes +my %attr; +while(<IN>) +{ + # get and clean line + my $l = $_; $l =~ s/#.*\Z//; $l =~ s/\A\s+//; $l =~ s/\s+\Z//; next unless $l =~ m/\S/; + + last if $l eq 'BEGIN_OBJECTS'; + + my ($k,$v) = split /\s+/,$l,2; + + if($k eq 'ClientType') + { + add_type($v) if $type eq 'Client'; + } + elsif($k eq 'ServerType') + { + add_type($v) if $type eq 'Server'; + } + elsif($k eq 'ImplementLog') + { + my ($log_if_type,$log_type) = split /\s+/,$v; + if($type eq $log_if_type) + { + if($log_type eq 'syslog') + { + $implement_syslog = 1; + } + elsif($log_type eq 'file') + { + $implement_filelog = 1; + } + else + { + printf("ERROR: Unknown log type for implementation: $log_type\n"); + exit(1); + } + } + } + elsif($k eq 'LogTypeToText') + { + my ($log_if_type,$type_name,$printf_format,$arg_template) = split /\s+/,$v; + if($type eq $log_if_type) + { + $log_display_types{$type_name} = [$printf_format,$arg_template] + } + } + else + { + $attr{$k} = $v; + } +} + +sub add_type +{ + my ($protocol_name, $cpp_name, $header_file) = split /\s+/,$_[0]; + + $translate_type_info{$protocol_name} = [0, $cpp_name]; + push @extra_header_files, $header_file; +} + +# check attributes +for(qw/Name ServerContextClass IdentString/) +{ + if(!exists $attr{$_}) + { + die "Attribute $_ is required, but not specified\n"; + } +} + +my $protocol_name = $attr{'Name'}; +my ($context_class, $context_class_inc) = split /\s+/,$attr{'ServerContextClass'}; +my $ident_string = $attr{'IdentString'}; + +my $current_cmd = ''; +my %cmd_contents; +my %cmd_attributes; +my %cmd_constants; +my %cmd_id; +my @cmd_list; + +# read in the command definitions +while(<IN>) +{ + # get and clean line + my $l = $_; $l =~ s/#.*\Z//; $l =~ s/\s+\Z//; next unless $l =~ m/\S/; + + # definitions or new command thing? + if($l =~ m/\A\s+/) + { + die "No command defined yet" if $current_cmd eq ''; + + # definition of component + $l =~ s/\A\s+//; + + my ($type,$name,$value) = split /\s+/,$l; + if($type eq 'CONSTANT') + { + push @{$cmd_constants{$current_cmd}},"$name = $value" + } + else + { + push @{$cmd_contents{$current_cmd}},$type,$name; + } + } + else + { + # new command + my ($name,$id,@attributes) = split /\s+/,$l; + $cmd_attributes{$name} = [@attributes]; + $cmd_id{$name} = int($id); + $current_cmd = $name; + push @cmd_list,$name; + } +} + +close IN; + + + +# open files +my $h_filename = 'autogen_'.$protocol_name.'Protocol'.$type.'.h'; +open CPP,'>autogen_'.$protocol_name.'Protocol'.$type.'.cpp'; +open H,">$h_filename"; + +print CPP <<__E; + +// Auto-generated file -- do not edit + +#include "Box.h" + +#include <sstream> + +#include "$h_filename" +#include "IOStream.h" + +__E + +if($implement_syslog) +{ + print H <<EOF; +#ifndef WIN32 +#include <syslog.h> +#endif +EOF +} + + +my $guardname = uc 'AUTOGEN_'.$protocol_name.'Protocol'.$type.'_H'; +print H <<__E; + +// Auto-generated file -- do not edit + +#ifndef $guardname +#define $guardname + +#include "Protocol.h" +#include "ProtocolObject.h" +#include "ServerException.h" + +class IOStream; + +__E + +if($implement_filelog) +{ + print H qq~#include <stdio.h>\n~; +} + +# extra headers +for(@extra_header_files) +{ + print H qq~#include "$_"\n~ +} +print H "\n"; + +if($type eq 'Server') +{ + # need utils file for the server + print H '#include "Utils.h"',"\n\n" +} + + +my $derive_objects_from = 'ProtocolObject'; +my $objects_extra_h = ''; +my $objects_extra_cpp = ''; +if($type eq 'Server') +{ + # define the context + print H "class $context_class;\n\n"; + print CPP "#include \"$context_class_inc\"\n\n"; + + # change class we derive the objects from + $derive_objects_from = $protocol_name.'ProtocolObject'; + + $objects_extra_h = <<__E; + virtual std::auto_ptr<ProtocolObject> DoCommand(${protocol_name}ProtocolServer &rProtocol, $context_class &rContext); +__E + $objects_extra_cpp = <<__E; +std::auto_ptr<ProtocolObject> ${derive_objects_from}::DoCommand(${protocol_name}ProtocolServer &rProtocol, $context_class &rContext) +{ + THROW_EXCEPTION(ConnectionException, Conn_Protocol_TriedToExecuteReplyCommand) +} +__E +} + +print CPP qq~#include "MemLeakFindOn.h"\n~; + +if($type eq 'Client' && ($implement_syslog || $implement_filelog)) +{ + # change class we derive the objects from + $derive_objects_from = $protocol_name.'ProtocolObjectCl'; +} +if($implement_syslog) +{ + $objects_extra_h .= <<__E; + virtual void LogSysLog(const char *Action) const = 0; +__E +} +if($implement_filelog) +{ + $objects_extra_h .= <<__E; + virtual void LogFile(const char *Action, FILE *file) const = 0; +__E +} + +if($derive_objects_from ne 'ProtocolObject') +{ + # output a definition for the protocol object derived class + print H <<__E; +class ${protocol_name}ProtocolServer; + +class $derive_objects_from : public ProtocolObject +{ +public: + $derive_objects_from(); + virtual ~$derive_objects_from(); + $derive_objects_from(const $derive_objects_from &rToCopy); + +$objects_extra_h +}; +__E + + # and some cpp definitions + print CPP <<__E; +${derive_objects_from}::${derive_objects_from}() +{ +} +${derive_objects_from}::~${derive_objects_from}() +{ +} +${derive_objects_from}::${derive_objects_from}(const $derive_objects_from &rToCopy) +{ +} +$objects_extra_cpp +__E +} + + + +my $classname_base = $protocol_name.'Protocol'.$type; + +# output the classes +for my $cmd (@cmd_list) +{ + print H <<__E; +class $classname_base$cmd : public $derive_objects_from +{ +public: + $classname_base$cmd(); + $classname_base$cmd(const $classname_base$cmd &rToCopy); + ~$classname_base$cmd(); + int GetType() const; + enum + { + TypeID = $cmd_id{$cmd} + }; +__E + # constants + if(exists $cmd_constants{$cmd}) + { + print H "\tenum\n\t{\n\t\t"; + print H join(",\n\t\t",@{$cmd_constants{$cmd}}); + print H "\n\t};\n"; + } + # flags + if(obj_is_type($cmd,'EndsConversation')) + { + print H "\tbool IsConversationEnd() const;\n"; + } + if(obj_is_type($cmd,'IsError')) + { + print H "\tbool IsError(int &rTypeOut, int &rSubTypeOut) const;\n"; + print H "\tstd::string GetMessage() const;\n"; + } + if($type eq 'Server' && obj_is_type($cmd, 'Command')) + { + print H "\tstd::auto_ptr<ProtocolObject> DoCommand(${protocol_name}ProtocolServer &rProtocol, $context_class &rContext); // IMPLEMENT THIS\n" + } + + # want to be able to read from streams? + my $read_from_streams = (obj_is_type($cmd,'Command') && $type eq 'Server') || (obj_is_type($cmd,'Reply') && $type eq 'Client'); + my $write_to_streams = (obj_is_type($cmd,'Command') && $type eq 'Client') || (obj_is_type($cmd,'Reply') && $type eq 'Server'); + + if($read_from_streams) + { + print H "\tvoid SetPropertiesFromStreamData(Protocol &rProtocol);\n"; + + # write Get functions + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + + print H "\t".translate_type_to_arg_type($ty)." Get$nm() {return m$nm;}\n"; + } + } + my $param_con_args = ''; + if($write_to_streams) + { + # extra constructor? + if($#{$cmd_contents{$cmd}} >= 0) + { + my @a; + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + + push @a,translate_type_to_arg_type($ty)." $nm"; + } + $param_con_args = join(', ',@a); + print H "\t$classname_base$cmd(".$param_con_args.");\n"; + } + print H "\tvoid WritePropertiesToStreamData(Protocol &rProtocol) const;\n"; + # set functions + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + + print H "\tvoid Set$nm(".translate_type_to_arg_type($ty)." $nm) {m$nm = $nm;}\n"; + } + } + + if($implement_syslog) + { + print H "\tvirtual void LogSysLog(const char *Action) const;\n"; + } + if($implement_filelog) + { + print H "\tvirtual void LogFile(const char *Action, FILE *file) const;\n"; + } + + + # write member variables and setup for cpp file + my @def_constructor_list; + my @copy_constructor_list; + my @param_constructor_list; + + print H "private:\n"; + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + + print H "\t".translate_type_to_member_type($ty)." m$nm;\n"; + + my ($basic,$typename) = translate_type($ty); + if($basic) + { + push @def_constructor_list, "m$nm(0)"; + } + push @copy_constructor_list, "m$nm(rToCopy.m$nm)"; + push @param_constructor_list, "m$nm($nm)"; + } + + # finish off + print H "};\n\n"; + + # now the cpp file... + my $def_con_vars = join(",\n\t ",@def_constructor_list); + $def_con_vars = "\n\t: ".$def_con_vars if $def_con_vars ne ''; + my $copy_con_vars = join(",\n\t ",@copy_constructor_list); + $copy_con_vars = "\n\t: ".$copy_con_vars if $copy_con_vars ne ''; + my $param_con_vars = join(",\n\t ",@param_constructor_list); + $param_con_vars = "\n\t: ".$param_con_vars if $param_con_vars ne ''; + + my $class = "$classname_base$cmd".'::'; + print CPP <<__E; +$class$classname_base$cmd()$def_con_vars +{ +} +$class$classname_base$cmd(const $classname_base$cmd &rToCopy)$copy_con_vars +{ +} +$class~$classname_base$cmd() +{ +} +int ${class}GetType() const +{ + return $cmd_id{$cmd}; +} +__E + if($read_from_streams) + { + print CPP "void ${class}SetPropertiesFromStreamData(Protocol &rProtocol)\n{\n"; + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + if($ty =~ m/\Avector/) + { + print CPP "\trProtocol.ReadVector(m$nm);\n"; + } + else + { + print CPP "\trProtocol.Read(m$nm);\n"; + } + } + print CPP "}\n"; + } + if($write_to_streams) + { + # implement extra constructor? + if($param_con_vars ne '') + { + print CPP "$class$classname_base$cmd($param_con_args)$param_con_vars\n{\n}\n"; + } + print CPP "void ${class}WritePropertiesToStreamData(Protocol &rProtocol) const\n{\n"; + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + if($ty =~ m/\Avector/) + { + print CPP "\trProtocol.WriteVector(m$nm);\n"; + } + else + { + print CPP "\trProtocol.Write(m$nm);\n"; + } + } + print CPP "}\n"; + } + if(obj_is_type($cmd,'EndsConversation')) + { + print CPP "bool ${class}IsConversationEnd() const\n{\n\treturn true;\n}\n"; + } + if(obj_is_type($cmd,'IsError')) + { + # get parameters + my ($mem_type,$mem_subtype) = split /,/,obj_get_type_params($cmd,'IsError'); + print CPP <<__E; +bool ${class}IsError(int &rTypeOut, int &rSubTypeOut) const +{ + rTypeOut = m$mem_type; + rSubTypeOut = m$mem_subtype; + return true; +} +std::string ${class}GetMessage() const +{ + switch(m$mem_subtype) + { +__E + foreach my $const (@{$cmd_constants{$cmd}}) + { + next unless $const =~ /^Err_(.*)/; + my $shortname = $1; + $const =~ s/ = .*//; + print CPP <<__E; + case $const: return "$shortname"; +__E + } + print CPP <<__E; + default: + std::ostringstream out; + out << "Unknown subtype " << m$mem_subtype; + return out.str(); + } +} +__E + } + + if($implement_syslog) + { + my ($log) = make_log_strings_framework($cmd); + print CPP <<__E; +void ${class}LogSysLog(const char *Action) const +{ + BOX_TRACE($log); +} +__E + } + if($implement_filelog) + { + my ($log) = make_log_strings_framework($cmd); + print CPP <<__E; +void ${class}LogFile(const char *Action, FILE *File) const +{ + std::ostringstream oss; + oss << $log; + ::fprintf(File, "%s\\n", oss.str().c_str()); + ::fflush(File); +} +__E + } +} + +# finally, the protocol object itself +print H <<__E; +class $classname_base : public Protocol +{ +public: + $classname_base(IOStream &rStream); + virtual ~$classname_base(); + + std::auto_ptr<$derive_objects_from> Receive(); + void Send(const ${derive_objects_from} &rObject); +__E +if($implement_syslog) +{ + print H "\tvoid SetLogToSysLog(bool Log = false) {mLogToSysLog = Log;}\n"; +} +if($implement_filelog) +{ + print H "\tvoid SetLogToFile(FILE *File = 0) {mLogToFile = File;}\n"; +} +if($type eq 'Server') +{ + # need to put in the conversation function + print H "\tvoid DoServer($context_class &rContext);\n\n"; + # and the send vector thing + print H "\tvoid SendStreamAfterCommand(IOStream *pStream);\n\n"; +} +if($type eq 'Client') +{ + # add plain object taking query functions + my $with_params; + for my $cmd (@cmd_list) + { + if(obj_is_type($cmd,'Command')) + { + my $has_stream = obj_is_type($cmd,'StreamWithCommand'); + my $argextra = $has_stream?', IOStream &rStream':''; + my $queryextra = $has_stream?', rStream':''; + my $reply = obj_get_type_params($cmd,'Command'); + print H "\tstd::auto_ptr<$classname_base$reply> Query(const $classname_base$cmd &rQuery$argextra);\n"; + my @a; + my @na; + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + push @a,translate_type_to_arg_type($ty)." $nm"; + push @na,"$nm"; + } + my $ar = join(', ',@a); + my $nar = join(', ',@na); + $nar = "($nar)" if $nar ne ''; + + $with_params .= "\tinline std::auto_ptr<$classname_base$reply> Query$cmd($ar$argextra)\n\t{\n"; + $with_params .= "\t\t$classname_base$cmd send$nar;\n"; + $with_params .= "\t\treturn Query(send$queryextra);\n"; + $with_params .= "\t}\n"; + } + } + # quick hack to correct bad argument lists for commands with zero paramters but with streams + $with_params =~ s/\(, /(/g; + print H "\n",$with_params,"\n"; +} +print H <<__E; +private: + $classname_base(const $classname_base &rToCopy); +__E +if($type eq 'Server') +{ + # need to put the streams to send vector + print H "\tstd::vector<IOStream*> mStreamsToSend;\n\tvoid DeleteStreamsToSend();\n"; +} + +if($implement_filelog || $implement_syslog) +{ + print H <<__E; + virtual void InformStreamReceiving(u_int32_t Size); + virtual void InformStreamSending(u_int32_t Size); +__E +} + +if($implement_syslog) +{ + print H "private:\n\tbool mLogToSysLog;\n"; +} +if($implement_filelog) +{ + print H "private:\n\tFILE *mLogToFile;\n"; +} +print H <<__E; + +protected: + virtual std::auto_ptr<ProtocolObject> MakeProtocolObject(int ObjType); + virtual const char *GetIdentString(); +}; + +__E + +my $constructor_extra = ''; +$constructor_extra .= ', mLogToSysLog(false)' if $implement_syslog; +$constructor_extra .= ', mLogToFile(0)' if $implement_filelog; + +my $destructor_extra = ($type eq 'Server')?"\n\tDeleteStreamsToSend();":''; + +my $prefix = $classname_base.'::'; +print CPP <<__E; +$prefix$classname_base(IOStream &rStream) + : Protocol(rStream)$constructor_extra +{ +} +$prefix~$classname_base() +{$destructor_extra +} +const char *${prefix}GetIdentString() +{ + return "$ident_string"; +} +std::auto_ptr<ProtocolObject> ${prefix}MakeProtocolObject(int ObjType) +{ + switch(ObjType) + { +__E + +# do objects within this +for my $cmd (@cmd_list) +{ + print CPP <<__E; + case $cmd_id{$cmd}: + return std::auto_ptr<ProtocolObject>(new $classname_base$cmd); + break; +__E +} + +print CPP <<__E; + default: + THROW_EXCEPTION(ConnectionException, Conn_Protocol_UnknownCommandRecieved) + } +} +__E +# write receive and send functions +print CPP <<__E; +std::auto_ptr<$derive_objects_from> ${prefix}Receive() +{ + std::auto_ptr<${derive_objects_from}> preply((${derive_objects_from}*)(Protocol::Receive().release())); + +__E + if($implement_syslog) + { + print CPP <<__E; + if(mLogToSysLog) + { + preply->LogSysLog("Receive"); + } +__E + } + if($implement_filelog) + { + print CPP <<__E; + if(mLogToFile != 0) + { + preply->LogFile("Receive", mLogToFile); + } +__E + } +print CPP <<__E; + + return preply; +} + +void ${prefix}Send(const ${derive_objects_from} &rObject) +{ +__E + if($implement_syslog) + { + print CPP <<__E; + if(mLogToSysLog) + { + rObject.LogSysLog("Send"); + } +__E + } + if($implement_filelog) + { + print CPP <<__E; + if(mLogToFile != 0) + { + rObject.LogFile("Send", mLogToFile); + } +__E + } + +print CPP <<__E; + Protocol::Send(rObject); +} + +__E +# write server function? +if($type eq 'Server') +{ + print CPP <<__E; +void ${prefix}DoServer($context_class &rContext) +{ + // Handshake with client + Handshake(); + + // Command processing loop + bool inProgress = true; + while(inProgress) + { + // Get an object from the conversation + std::auto_ptr<${derive_objects_from}> pobj(Receive()); + + // Run the command + std::auto_ptr<${derive_objects_from}> preply((${derive_objects_from}*)(pobj->DoCommand(*this, rContext).release())); + + // Send the reply + Send(*(preply.get())); + + // Send any streams + for(unsigned int s = 0; s < mStreamsToSend.size(); s++) + { + // Send the streams + SendStream(*mStreamsToSend[s]); + } + // Delete these streams + DeleteStreamsToSend(); + + // Does this end the conversation? + if(pobj->IsConversationEnd()) + { + inProgress = false; + } + } +} + +void ${prefix}SendStreamAfterCommand(IOStream *pStream) +{ + ASSERT(pStream != NULL); + mStreamsToSend.push_back(pStream); +} + +void ${prefix}DeleteStreamsToSend() +{ + for(std::vector<IOStream*>::iterator i(mStreamsToSend.begin()); i != mStreamsToSend.end(); ++i) + { + delete (*i); + } + mStreamsToSend.clear(); +} + +__E +} + +# write logging functions? +if($implement_filelog || $implement_syslog) +{ + my ($fR,$fS); + + if($implement_syslog) + { + $fR .= <<__E; + if(mLogToSysLog) + { + if(Size==Protocol::ProtocolStream_SizeUncertain) + { + BOX_TRACE("Receiving stream, size uncertain"); + } + else + { + BOX_TRACE("Receiving stream, size " << Size); + } + } +__E + + $fS .= <<__E; + if(mLogToSysLog) + { + if(Size==Protocol::ProtocolStream_SizeUncertain) + { + BOX_TRACE("Sending stream, size uncertain"); + } + else + { + BOX_TRACE("Sending stream, size " << Size); + } + } +__E + } + + if($implement_filelog) + { + $fR .= <<__E; + if(mLogToFile) + { + ::fprintf(mLogToFile, + (Size==Protocol::ProtocolStream_SizeUncertain) + ?"Receiving stream, size uncertain\\n" + :"Receiving stream, size %d\\n", Size); + ::fflush(mLogToFile); + } +__E + $fS .= <<__E; + if(mLogToFile) + { + ::fprintf(mLogToFile, + (Size==Protocol::ProtocolStream_SizeUncertain) + ?"Sending stream, size uncertain\\n" + :"Sending stream, size %d\\n", Size); + ::fflush(mLogToFile); + } +__E + } + + print CPP <<__E; + +void ${prefix}InformStreamReceiving(u_int32_t Size) +{ +$fR} + +void ${prefix}InformStreamSending(u_int32_t Size) +{ +$fS} + +__E +} + + +# write client Query functions? +if($type eq 'Client') +{ + for my $cmd (@cmd_list) + { + if(obj_is_type($cmd,'Command')) + { + my $reply = obj_get_type_params($cmd,'Command'); + my $reply_id = $cmd_id{$reply}; + my $has_stream = obj_is_type($cmd,'StreamWithCommand'); + my $argextra = $has_stream?', IOStream &rStream':''; + my $send_stream_extra = ''; + if($has_stream) + { + $send_stream_extra = <<__E; + + // Send stream after the command + SendStream(rStream); +__E + } + print CPP <<__E; +std::auto_ptr<$classname_base$reply> ${classname_base}::Query(const $classname_base$cmd &rQuery$argextra) +{ + // Send query + Send(rQuery); + $send_stream_extra + // Wait for the reply + std::auto_ptr<${derive_objects_from}> preply(Receive().release()); + + if(preply->GetType() == $reply_id) + { + // Correct response + return std::auto_ptr<$classname_base$reply>(($classname_base$reply*)preply.release()); + } + else + { + // Set protocol error + int type, subType; + if(preply->IsError(type, subType)) + { + SetError(type, subType); + BOX_WARNING("$cmd command failed: received error " << + ((${classname_base}Error&)*preply).GetMessage()); + } + else + { + SetError(Protocol::UnknownError, Protocol::UnknownError); + BOX_WARNING("$cmd command failed: received " + "unexpected response type " << + preply->GetType()); + } + + // Throw an exception + THROW_EXCEPTION(ConnectionException, Conn_Protocol_UnexpectedReply) + } +} +__E + } + } +} + + + +print H <<__E; +#endif // $guardname + +__E + +# close files +close H; +close CPP; + + +sub obj_is_type +{ + my ($c,$ty) = @_; + for(@{$cmd_attributes{$c}}) + { + return 1 if $_ =~ m/\A$ty/; + } + + return 0; +} + +sub obj_get_type_params +{ + my ($c,$ty) = @_; + for(@{$cmd_attributes{$c}}) + { + return $1 if $_ =~ m/\A$ty\((.+?)\)\Z/; + } + die "Can't find attribute $ty\n" +} + +# returns (is basic type, typename) +sub translate_type +{ + my $ty = $_[0]; + + if($ty =~ m/\Avector\<(.+?)\>\Z/) + { + my $v_type = $1; + my (undef,$v_ty) = translate_type($v_type); + return (0, 'std::vector<'.$v_ty.'>') + } + else + { + if(!exists $translate_type_info{$ty}) + { + die "Don't know about type name $ty\n"; + } + return @{$translate_type_info{$ty}} + } +} + +sub translate_type_to_arg_type +{ + my ($basic,$typename) = translate_type(@_); + return $basic?$typename:'const '.$typename.'&' +} + +sub translate_type_to_member_type +{ + my ($basic,$typename) = translate_type(@_); + return $typename +} + +sub make_log_strings +{ + my ($cmd) = @_; + + my @str; + my @arg; + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + + if(exists $log_display_types{$ty}) + { + # need to translate it + my ($format,$arg) = @{$log_display_types{$ty}}; + $arg =~ s/VAR/m$nm/g; + + if ($format eq "0x%llx" and $target_windows) + { + $format = "0x%I64x"; + $arg = "(uint64_t)$arg"; + } + + push @str,$format; + push @arg,$arg; + } + else + { + # is opaque + push @str,'OPAQUE'; + } + } + return ($cmd.'('.join(',',@str).')', join(',','',@arg)); +} + +sub make_log_strings_framework +{ + my ($cmd) = @_; + + my @args; + + for(my $x = 0; $x < $#{$cmd_contents{$cmd}}; $x+=2) + { + my ($ty,$nm) = (${$cmd_contents{$cmd}}[$x], ${$cmd_contents{$cmd}}[$x+1]); + + if(exists $log_display_types{$ty}) + { + # need to translate it + my ($format,$arg) = @{$log_display_types{$ty}}; + $arg =~ s/VAR/m$nm/g; + + if ($format eq '\\"%s\\"') + { + $arg = "\"\\\"\" << $arg << \"\\\"\""; + } + elsif ($format =~ m'x$') + { + # my $width = 0; + # $ty =~ /^int(\d+)$/ and $width = $1 / 4; + $arg = "($arg == 0 ? \"0x\" : \"\") " . + "<< std::hex " . + "<< std::showbase " . + # "<< std::setw($width) " . + # "<< std::setfill('0') " . + # "<< std::internal " . + "<< $arg " . + "<< std::dec"; + } + + push @args, $arg; + } + else + { + # is opaque + push @args, '"OPAQUE"'; + } + } + + my $log_cmd = "Action << \" $cmd(\" "; + foreach my $arg (@args) + { + $arg = "<< $arg "; + } + $log_cmd .= join('<< "," ',@args); + $log_cmd .= '<< ")"'; + return $log_cmd; +} + + |