The InspIRCd Project
Home | Developers | Wiki | Forums | Bug Tracker | SVN | Download
Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members

m_mysql.cpp

Go to the documentation of this file.
00001 /*       +------------------------------------+
00002  *       | Inspire Internet Relay Chat Daemon |
00003  *       +------------------------------------+
00004  *
00005  *  InspIRCd: (C) 2002-2008 InspIRCd Development Team
00006  * See: http://www.inspircd.org/wiki/index.php/Credits
00007  *
00008  * This program is free but copyrighted software; see
00009  *          the file COPYING for details.
00010  *
00011  * ---------------------------------------------------
00012  */
00013 
00014 /* Stop mysql wanting to use long long */
00015 #define NO_CLIENT_LONG_LONG
00016 
#include "inspircd.h"
#include <climits>
#include <mysql.h>
#include "m_sqlv2.h"
00020 
00021 #ifdef WINDOWS
00022 #pragma comment(lib, "mysqlclient.lib")
00023 #endif
00024 
00025 /* VERSION 2 API: With nonblocking (threaded) requests */
00026 
00027 /* $ModDesc: SQL Service Provider module for all other m_sql* modules */
00028 /* $CompileFlags: exec("mysql_config --include") */
00029 /* $LinkerFlags: exec("mysql_config --libs_r") rpath("mysql_config --libs_r") */
00030 /* $ModDep: m_sqlv2.h */
00031 
00032 /* THE NONBLOCKING MYSQL API!
00033  *
 * MySQL provides no nonblocking (asynchronous) API of its own, and its developers recommend
 * that instead, you should thread your program. This is what I've done here to allow for
 * asynchronous SQL requests via mysql. The way this works is as follows:
00037  *
00038  * The module spawns a thread via class Thread, and performs its mysql queries in this thread,
00039  * using a queue with priorities. There is a mutex on either end which prevents two threads
00040  * adjusting the queue at the same time, and crashing the ircd. Every 50 milliseconds, the
00041  * worker thread wakes up, and checks if there is a request at the head of its queue.
00042  * If there is, it processes this request, blocking the worker thread but leaving the ircd
00043  * thread to go about its business as usual. During this period, the ircd thread is able
 * to insert further pending requests into the queue.
00045  *
00046  * Once the processing of a request is complete, it is removed from the incoming queue to
00047  * an outgoing queue, and initialized as a 'response'. The worker thread then signals the
00048  * ircd thread (via a loopback socket) of the fact a result is available, by sending the
00049  * connection ID through the connection.
00050  *
00051  * The ircd thread then mutexes the queue once more, reads the outbound response off the head
00052  * of the queue, and sends it on its way to the original calling module.
00053  *
 * XXX: You might be asking "why doesn't he just send the response from within the worker thread?"
 * The answer to this is simple. The majority of InspIRCd, and in fact most ircds, are not
 * threadsafe. This module is designed to be threadsafe and is careful with its use of threads;
 * however, if we were to call a module's OnRequest from within a thread other than the
 * one the module was originally instantiated upon, there is a chance of all hell breaking loose
 * if a module is ever put in a re-entrant state (stack corruption could occur, crashes, data
 * corruption, and worse), so DON'T think about it until the day comes when InspIRCd is 100%
 * guaranteed threadsafe!
00062  *
00063  * For a diagram of this system please see http://www.inspircd.org/wiki/Mysql2
00064  */
00065 
00066 
/* Classes defined further down in this file, declared early so the
 * globals and typedefs below can refer to them by pointer. */
class SQLConnection;
class MySQLListener;
class DispatcherThread;

/* Every configured database connection, keyed by its <database:id> value. */
typedef std::map<std::string, SQLConnection*> ConnMap;

/* The loopback listening socket that result notifications arrive on. */
static MySQLListener* MessagePipe = NULL;

/* Descriptor NotifyMainThread() writes notification bytes to.
 * Starts at -1; NOTE(review): it is assigned outside this section of the file. */
int QueueFD = -1;
00076 
/** SQL service provider module backed by MySQL.
 * Owns the config reader, the dispatcher (worker) thread and the mutexes used
 * to share state between that thread and the main thread. Requests arrive via
 * OnRequest() and are queued on the matching SQLConnection.
 */
class ModuleSQL : public Module
{
 public:

	/** Config reader; created in the constructor, deleted in the destructor. */
	ConfigReader *Conf;
	/** Copy of ServerInstance taken in the constructor.
	 * NOTE(review): no reader of this field is visible in this section. */
	InspIRCd* PublicServerInstance;
	/** Last request id handed out by NewID(); set to 0 in the constructor. */
	int currid;
	/** Initialised to false in the constructor; not otherwise used in this section. */
	bool rehashing;
	/** Worker thread executing queued queries; deleting it joins the thread. */
	DispatcherThread* Dispatcher;
	/** Held while OnRequest() queues a request and while the worker rewrites a request's query text. */
	Mutex* QueueMutex;
	/** Guards each SQLConnection's result queue (rq). */
	Mutex* ResultsMutex;
	/** Serialises calls into ServerInstance->Logs from this module. */
	Mutex* LoggingMutex;
	/** Guards the global Connections map. */
	Mutex* ConnMutex;

	ModuleSQL(InspIRCd* Me);
	~ModuleSQL();
	/** Returns a fresh, non-zero request id. */
	unsigned long NewID();
	const char* OnRequest(Request* request);
	void OnRehash(User* user, const std::string &parameter);
	Version GetVersion();
};
00100 
00101 
00102 
/* Compatibility shim: when MYSQL_VERSION_ID is missing or older than 32224,
 * mysql_field_count is mapped onto mysql_num_fields (presumably because such
 * client libraries do not provide it).
 * NOTE(review): mysql_field_count is not referenced in the visible part of this file. */
#if !defined(MYSQL_VERSION_ID) || MYSQL_VERSION_ID<32224
#define mysql_field_count mysql_num_fields
#endif

/* Per-connection FIFO of finished results waiting for the main thread to deliver them. */
typedef std::deque<SQLresult*> ResultQueue;
00108 
00111 class MySQLresult : public SQLresult
00112 {
00113         int currentrow;
00114         std::vector<std::string> colnames;
00115         std::vector<SQLfieldList> fieldlists;
00116         SQLfieldMap* fieldmap;
00117         SQLfieldMap fieldmap2;
00118         SQLfieldList emptyfieldlist;
00119         int rows;
00120  public:
00121 
00122         MySQLresult(Module* self, Module* to, MYSQL_RES* res, int affected_rows, unsigned int rid) : SQLresult(self, to, rid), currentrow(0), fieldmap(NULL)
00123         {
00124                 /* A number of affected rows from from mysql_affected_rows.
00125                  */
00126                 fieldlists.clear();
00127                 rows = 0;
00128                 if (affected_rows >= 1)
00129                 {
00130                         rows = affected_rows;
00131                         fieldlists.resize(rows);
00132                 }
00133                 unsigned int field_count = 0;
00134                 if (res)
00135                 {
00136                         MYSQL_ROW row;
00137                         int n = 0;
00138                         while ((row = mysql_fetch_row(res)))
00139                         {
00140                                 if (fieldlists.size() < (unsigned int)rows+1)
00141                                 {
00142                                         fieldlists.resize(fieldlists.size()+1);
00143                                 }
00144                                 field_count = 0;
00145                                 MYSQL_FIELD *fields = mysql_fetch_fields(res);
00146                                 if(mysql_num_fields(res) == 0)
00147                                         break;
00148                                 if (fields && mysql_num_fields(res))
00149                                 {
00150                                         colnames.clear();
00151                                         while (field_count < mysql_num_fields(res))
00152                                         {
00153                                                 std::string a = (fields[field_count].name ? fields[field_count].name : "");
00154                                                 std::string b = (row[field_count] ? row[field_count] : "");
00155                                                 SQLfield sqlf(b, !row[field_count]);
00156                                                 colnames.push_back(a);
00157                                                 fieldlists[n].push_back(sqlf);
00158                                                 field_count++;
00159                                         }
00160                                         n++;
00161                                 }
00162                                 rows++;
00163                         }
00164                         mysql_free_result(res);
00165                 }
00166         }
00167 
00168         MySQLresult(Module* self, Module* to, SQLerror e, unsigned int rid) : SQLresult(self, to, rid), currentrow(0)
00169         {
00170                 rows = 0;
00171                 error = e;
00172         }
00173 
00174         ~MySQLresult()
00175         {
00176         }
00177 
00178         virtual int Rows()
00179         {
00180                 return rows;
00181         }
00182 
00183         virtual int Cols()
00184         {
00185                 return colnames.size();
00186         }
00187 
00188         virtual std::string ColName(int column)
00189         {
00190                 if (column < (int)colnames.size())
00191                 {
00192                         return colnames[column];
00193                 }
00194                 else
00195                 {
00196                         throw SQLbadColName();
00197                 }
00198                 return "";
00199         }
00200 
00201         virtual int ColNum(const std::string &column)
00202         {
00203                 for (unsigned int i = 0; i < colnames.size(); i++)
00204                 {
00205                         if (column == colnames[i])
00206                                 return i;
00207                 }
00208                 throw SQLbadColName();
00209                 return 0;
00210         }
00211 
00212         virtual SQLfield GetValue(int row, int column)
00213         {
00214                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
00215                 {
00216                         return fieldlists[row][column];
00217                 }
00218 
00219                 throw SQLbadColName();
00220 
00221                 /* XXX: We never actually get here because of the throw */
00222                 return SQLfield("",true);
00223         }
00224 
00225         virtual SQLfieldList& GetRow()
00226         {
00227                 if (currentrow < rows)
00228                         return fieldlists[currentrow++];
00229                 else
00230                         return emptyfieldlist;
00231         }
00232 
00233         virtual SQLfieldMap& GetRowMap()
00234         {
00235                 fieldmap2.clear();
00236 
00237                 if (currentrow < rows)
00238                 {
00239                         for (int i = 0; i < Cols(); i++)
00240                         {
00241                                 fieldmap2.insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
00242                         }
00243                         currentrow++;
00244                 }
00245 
00246                 return fieldmap2;
00247         }
00248 
00249         virtual SQLfieldList* GetRowPtr()
00250         {
00251                 SQLfieldList* fieldlist = new SQLfieldList();
00252 
00253                 if (currentrow < rows)
00254                 {
00255                         for (int i = 0; i < Rows(); i++)
00256                         {
00257                                 fieldlist->push_back(fieldlists[currentrow][i]);
00258                         }
00259                         currentrow++;
00260                 }
00261                 return fieldlist;
00262         }
00263 
00264         virtual SQLfieldMap* GetRowMapPtr()
00265         {
00266                 fieldmap = new SQLfieldMap();
00267 
00268                 if (currentrow < rows)
00269                 {
00270                         for (int i = 0; i < Cols(); i++)
00271                         {
00272                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
00273                         }
00274                         currentrow++;
00275                 }
00276 
00277                 return fieldmap;
00278         }
00279 
00280         virtual void Free(SQLfieldMap* fm)
00281         {
00282                 delete fm;
00283         }
00284 
00285         virtual void Free(SQLfieldList* fl)
00286         {
00287                 delete fl;
00288         }
00289 };
00290 
class SQLConnection;

/* Wakes the main thread (by writing a one-byte connection id to QueueFD) so it
 * delivers the result just queued on connection_with_new_result's rq. */
void NotifyMainThread(SQLConnection* connection_with_new_result);
00294 
00297 class SQLConnection : public classbase
00298 {
00299  protected:
00300 
00301         MYSQL connection;
00302         MYSQL_RES *res;
00303         MYSQL_ROW row;
00304         SQLhost host;
00305         std::map<std::string,std::string> thisrow;
00306         bool Enabled;
00307         ModuleSQL* Parent;
00308 
00309  public:
00310 
00311         QueryQueue queue;
00312         ResultQueue rq;
00313 
00314         // This constructor creates an SQLConnection object with the given credentials, but does not connect yet.
00315         SQLConnection(const SQLhost &hi, ModuleSQL* Creator) : host(hi), Enabled(false), Parent(Creator)
00316         {
00317         }
00318 
00319         ~SQLConnection()
00320         {
00321                 Close();
00322         }
00323 
00324         // This method connects to the database using the credentials supplied to the constructor, and returns
00325         // true upon success.
00326         bool Connect()
00327         {
00328                 unsigned int timeout = 1;
00329                 mysql_init(&connection);
00330                 mysql_options(&connection,MYSQL_OPT_CONNECT_TIMEOUT,(char*)&timeout);
00331                 return mysql_real_connect(&connection, host.host.c_str(), host.user.c_str(), host.pass.c_str(), host.name.c_str(), host.port, NULL, 0);
00332         }
00333 
00334         void DoLeadingQuery()
00335         {
00336                 if (!CheckConnection())
00337                         return;
00338 
00339                 /* Parse the command string and dispatch it to mysql */
00340                 SQLrequest& req = queue.front();
00341 
00342                 /* Pointer to the buffer we screw around with substitution in */
00343                 char* query;
00344 
00345                 /* Pointer to the current end of query, where we append new stuff */
00346                 char* queryend;
00347 
00348                 /* Total length of the unescaped parameters */
00349                 unsigned long paramlen;
00350 
00351                 /* Total length of query, used for binary-safety in mysql_real_query */
00352                 unsigned long querylength = 0;
00353 
00354                 paramlen = 0;
00355 
00356                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
00357                 {
00358                         paramlen += i->size();
00359                 }
00360 
00361                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
00362                  * sizeofquery + (totalparamlength*2) + 1
00363                  *
00364                  * The +1 is for null-terminating the string for mysql_real_escape_string
00365                  */
00366 
00367                 query = new char[req.query.q.length() + (paramlen*2) + 1];
00368                 queryend = query;
00369 
00370                 /* Okay, now we have a buffer large enough we need to start copying the query into it and escaping and substituting
00371                  * the parameters into it...
00372                  */
00373 
00374                 for(unsigned long i = 0; i < req.query.q.length(); i++)
00375                 {
00376                         if(req.query.q[i] == '?')
00377                         {
00378                                 /* We found a place to substitute..what fun.
00379                                  * use mysql calls to escape and write the
00380                                  * escaped string onto the end of our query buffer,
00381                                  * then we "just" need to make sure queryend is
00382                                  * pointing at the right place.
00383                                  */
00384                                 if(req.query.p.size())
00385                                 {
00386                                         unsigned long len = mysql_real_escape_string(&connection, queryend, req.query.p.front().c_str(), req.query.p.front().length());
00387 
00388                                         queryend += len;
00389                                         req.query.p.pop_front();
00390                                 }
00391                                 else
00392                                         break;
00393                         }
00394                         else
00395                         {
00396                                 *queryend = req.query.q[i];
00397                                 queryend++;
00398                         }
00399                         querylength++;
00400                 }
00401 
00402                 *queryend = 0;
00403 
00404                 Parent->QueueMutex->Lock();
00405                 req.query.q = query;
00406                 Parent->QueueMutex->Unlock();
00407 
00408                 if (!mysql_real_query(&connection, req.query.q.data(), req.query.q.length()))
00409                 {
00410                         /* Successfull query */
00411                         res = mysql_use_result(&connection);
00412                         unsigned long rows = mysql_affected_rows(&connection);
00413                         MySQLresult* r = new MySQLresult(Parent, req.GetSource(), res, rows, req.id);
00414                         r->dbid = this->GetID();
00415                         r->query = req.query.q;
00416                         /* Put this new result onto the results queue.
00417                          * XXX: Remember to mutex the queue!
00418                          */
00419                         Parent->ResultsMutex->Lock();
00420                         rq.push_back(r);
00421                         Parent->ResultsMutex->Unlock();
00422                 }
00423                 else
00424                 {
00425                         /* XXX: See /usr/include/mysql/mysqld_error.h for a list of
00426                          * possible error numbers and error messages */
00427                         SQLerror e(SQL_QREPLY_FAIL, ConvToStr(mysql_errno(&connection)) + std::string(": ") + mysql_error(&connection));
00428                         MySQLresult* r = new MySQLresult(Parent, req.GetSource(), e, req.id);
00429                         r->dbid = this->GetID();
00430                         r->query = req.query.q;
00431 
00432                         Parent->ResultsMutex->Lock();
00433                         rq.push_back(r);
00434                         Parent->ResultsMutex->Unlock();
00435                 }
00436 
00437                 /* Now signal the main thread that we've got a result to process.
00438                  * Pass them this connection id as what to examine
00439                  */
00440 
00441                 delete[] query;
00442 
00443                 NotifyMainThread(this);
00444         }
00445 
00446         bool ConnectionLost()
00447         {
00448                 if (&connection) {
00449                         return (mysql_ping(&connection) != 0);
00450                 }
00451                 else return false;
00452         }
00453 
00454         bool CheckConnection()
00455         {
00456                 if (ConnectionLost()) {
00457                         return Connect();
00458                 }
00459                 else return true;
00460         }
00461 
00462         std::string GetError()
00463         {
00464                 return mysql_error(&connection);
00465         }
00466 
00467         const std::string& GetID()
00468         {
00469                 return host.id;
00470         }
00471 
00472         std::string GetHost()
00473         {
00474                 return host.host;
00475         }
00476 
00477         void SetEnable(bool Enable)
00478         {
00479                 Enabled = Enable;
00480         }
00481 
00482         bool IsEnabled()
00483         {
00484                 return Enabled;
00485         }
00486 
00487         void Close()
00488         {
00489                 mysql_close(&connection);
00490         }
00491 
00492         const SQLhost& GetConfHost()
00493         {
00494                 return host;
00495         }
00496 
00497 };
00498 
/* Every configured connection, keyed by <database:id>. LoadDatabases() and
 * Notifier::OnDataReady() hold ModuleSQL::ConnMutex while using it.
 * NOTE(review): FindCharId() reads it without taking that mutex itself. */
ConnMap Connections;
00500 
00501 bool HasHost(const SQLhost &host)
00502 {
00503         for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); iter++)
00504         {
00505                 if (host == iter->second->GetConfHost())
00506                         return true;
00507         }
00508         return false;
00509 }
00510 
00511 bool HostInConf(ConfigReader* conf, const SQLhost &h)
00512 {
00513         for(int i = 0; i < conf->Enumerate("database"); i++)
00514         {
00515                 SQLhost host;
00516                 host.id         = conf->ReadValue("database", "id", i);
00517                 host.host       = conf->ReadValue("database", "hostname", i);
00518                 host.port       = conf->ReadInteger("database", "port", i, true);
00519                 host.name       = conf->ReadValue("database", "name", i);
00520                 host.user       = conf->ReadValue("database", "username", i);
00521                 host.pass       = conf->ReadValue("database", "password", i);
00522                 host.ssl        = conf->ReadFlag("database", "ssl", i);
00523                 if (h == host)
00524                         return true;
00525         }
00526         return false;
00527 }
00528 
00529 void ClearOldConnections(ConfigReader* conf)
00530 {
00531         ConnMap::iterator i,safei;
00532         for (i = Connections.begin(); i != Connections.end(); i++)
00533         {
00534                 if (!HostInConf(conf, i->second->GetConfHost()))
00535                 {
00536                         delete i->second;
00537                         safei = i;
00538                         --i;
00539                         Connections.erase(safei);
00540                 }
00541         }
00542 }
00543 
00544 void ClearAllConnections()
00545 {
00546         ConnMap::iterator i;
00547         while ((i = Connections.begin()) != Connections.end())
00548         {
00549                 Connections.erase(i);
00550                 delete i->second;
00551         }
00552 }
00553 
00554 void ConnectDatabases(InspIRCd* ServerInstance, ModuleSQL* Parent)
00555 {
00556         for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
00557         {
00558                 if (i->second->IsEnabled())
00559                         continue;
00560 
00561                 i->second->SetEnable(true);
00562                 if (!i->second->Connect())
00563                 {
00564                         /* XXX: MUTEX */
00565                         Parent->LoggingMutex->Lock();
00566                         ServerInstance->Logs->Log("m_mysql",DEFAULT,"SQL: Failed to connect database "+i->second->GetHost()+": Error: "+i->second->GetError());
00567                         i->second->SetEnable(false);
00568                         Parent->LoggingMutex->Unlock();
00569                 }
00570         }
00571 }
00572 
00573 void LoadDatabases(ConfigReader* conf, InspIRCd* ServerInstance, ModuleSQL* Parent)
00574 {
00575         Parent->ConnMutex->Lock();
00576         ClearOldConnections(conf);
00577         for (int j =0; j < conf->Enumerate("database"); j++)
00578         {
00579                 SQLhost host;
00580                 host.id         = conf->ReadValue("database", "id", j);
00581                 host.host       = conf->ReadValue("database", "hostname", j);
00582                 host.port       = conf->ReadInteger("database", "port", j, true);
00583                 host.name       = conf->ReadValue("database", "name", j);
00584                 host.user       = conf->ReadValue("database", "username", j);
00585                 host.pass       = conf->ReadValue("database", "password", j);
00586                 host.ssl        = conf->ReadFlag("database", "ssl", j);
00587 
00588                 if (HasHost(host))
00589                         continue;
00590 
00591                 if (!host.id.empty() && !host.host.empty() && !host.name.empty() && !host.user.empty() && !host.pass.empty())
00592                 {
00593                         SQLConnection* ThisSQL = new SQLConnection(host, Parent);
00594                         Connections[host.id] = ThisSQL;
00595                 }
00596         }
00597         ConnectDatabases(ServerInstance, Parent);
00598         Parent->ConnMutex->Unlock();
00599 }
00600 
00601 char FindCharId(const std::string &id)
00602 {
00603         char i = 1;
00604         for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i)
00605         {
00606                 if (iter->first == id)
00607                 {
00608                         return i;
00609                 }
00610         }
00611         return 0;
00612 }
00613 
00614 ConnMap::iterator GetCharId(char id)
00615 {
00616         char i = 1;
00617         for (ConnMap::iterator iter = Connections.begin(); iter != Connections.end(); ++iter, ++i)
00618         {
00619                 if (i == id)
00620                         return iter;
00621         }
00622         return Connections.end();
00623 }
00624 
00625 void NotifyMainThread(SQLConnection* connection_with_new_result)
00626 {
00627         /* Here we write() to the socket the main thread has open
00628          * and we connect()ed back to before our thread became active.
00629          * The main thread is using a nonblocking socket tied into
00630          * the socket engine, so they wont block and they'll receive
00631          * nearly instant notification. Because we're in a seperate
00632          * thread, we can just use standard connect(), and we can
00633          * block if we like. We just send the connection id of the
00634          * connection back.
00635          *
00636          * NOTE: We only send a single char down the connection, this
00637          * way we know it wont get a partial read at the other end if
00638          * the system is especially congested (see bug #263).
00639          * The function FindCharId translates a connection name into a
00640          * one character id, and GetCharId translates a character id
00641          * back into an iterator.
00642          */
00643         char id = FindCharId(connection_with_new_result->GetID());
00644         send(QueueFD, &id, 1, 0);
00645 }
00646 
00647 class ModuleSQL;
00648 
00649 class DispatcherThread : public Thread
00650 {
00651  private:
00652         ModuleSQL* Parent;
00653         InspIRCd* ServerInstance;
00654  public:
00655         DispatcherThread(InspIRCd* Instance, ModuleSQL* CreatorModule) : Thread(), Parent(CreatorModule), ServerInstance(Instance) { }
00656         ~DispatcherThread() { }
00657         virtual void Run();
00658 };
00659 
00662 class Notifier : public BufferedSocket
00663 {
00664         ModuleSQL* Parent;
00665 
00666  public:
00667         Notifier(ModuleSQL* P, InspIRCd* SI, int newfd, char* ip) : BufferedSocket(SI, newfd, ip), Parent(P) { }
00668 
00669         virtual bool OnDataReady()
00670         {
00671                 char data = 0;
00672                 /* NOTE: Only a single character is read so we know we
00673                  * cant get a partial read. (We've been told that theres
00674                  * data waiting, so we wont ever get EAGAIN)
00675                  * The function GetCharId translates a single character
00676                  * back into an iterator.
00677                  */
00678 
00679                 if (ServerInstance->SE->Recv(this, &data, 1, 0) > 0)
00680                 {
00681                         Parent->ConnMutex->Lock();
00682                         ConnMap::iterator iter = GetCharId(data);
00683                         if (iter != Connections.end())
00684                         {
00685                                 /* Lock the mutex, send back the data */
00686                                 Parent->ResultsMutex->Lock();
00687                                 ResultQueue::iterator n = iter->second->rq.begin();
00688                                 (*n)->Send();
00689                                 delete (*n);
00690                                 iter->second->rq.pop_front();
00691                                 Parent->ResultsMutex->Unlock();
00692                                 Parent->ConnMutex->Unlock();
00693                                 return true;
00694                         }
00695                         /* No error, but unknown id */
00696                         Parent->ConnMutex->Unlock();
00697                         return true;
00698                 }
00699 
00700                 /* Erk, error on descriptor! */
00701                 return false;
00702         }
00703 };
00704 
00707 class MySQLListener : public ListenSocketBase
00708 {
00709         ModuleSQL* Parent;
00710         irc::sockets::insp_sockaddr sock_us;
00711         socklen_t uslen;
00712         FileReader* index;
00713 
00714  public:
00715         MySQLListener(ModuleSQL* P, InspIRCd* Instance, int port, const std::string &addr) : ListenSocketBase(Instance, port, addr), Parent(P)
00716         {
00717                 uslen = sizeof(sock_us);
00718                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
00719                 {
00720                         throw ModuleException("Could not getsockname() to find out port number for ITC port");
00721                 }
00722         }
00723 
00724         virtual void OnAcceptReady(const std::string &ipconnectedto, int nfd, const std::string &incomingip)
00725         {
00726                 new Notifier(this->Parent, this->ServerInstance, nfd, (char *)ipconnectedto.c_str()); // XXX unsafe casts suck
00727         }
00728 
00729         /* Using getsockname and ntohs, we can determine which port number we were allocated */
00730         int GetPort()
00731         {
00732 #ifdef IPV6
00733                 return ntohs(sock_us.sin6_port);
00734 #else
00735                 return ntohs(sock_us.sin_port);
00736 #endif
00737         }
00738 };
00739 
00740 ModuleSQL::ModuleSQL(InspIRCd* Me) : Module(Me), rehashing(false)
00741 {
00742         ServerInstance->Modules->UseInterface("SQLutils");
00743 
00744         Conf = new ConfigReader(ServerInstance);
00745         PublicServerInstance = ServerInstance;
00746         currid = 0;
00747 
00748         /* Create a socket on a random port. Let the tcp stack allocate us an available port */
00749 #ifdef IPV6
00750         MessagePipe = new MySQLListener(this, ServerInstance, 0, "::1");
00751 #else
00752         MessagePipe = new MySQLListener(this, ServerInstance, 0, "127.0.0.1");
00753 #endif
00754 
00755         LoggingMutex = ServerInstance->Mutexes->CreateMutex();
00756         ConnMutex = ServerInstance->Mutexes->CreateMutex();
00757 
00758         if (MessagePipe->GetFd() == -1)
00759         {
00760                 delete ConnMutex;
00761                 ServerInstance->Modules->DoneWithInterface("SQLutils");
00762                 throw ModuleException("m_mysql: unable to create ITC pipe");
00763         }
00764         else
00765         {
00766                 LoggingMutex->Lock();
00767                 ServerInstance->Logs->Log("m_mysql", DEBUG, "MySQL: Interthread comms port is %d", MessagePipe->GetPort());
00768                 LoggingMutex->Unlock();
00769         }
00770 
00771         Dispatcher = new DispatcherThread(ServerInstance, this);
00772         ServerInstance->Threads->Create(Dispatcher);
00773 
00774         ResultsMutex = ServerInstance->Mutexes->CreateMutex();
00775         QueueMutex = ServerInstance->Mutexes->CreateMutex();
00776 
00777         if (!ServerInstance->Modules->PublishFeature("SQL", this))
00778         {
00779                 /* Tell worker thread to exit NOW,
00780                  * Automatically joins */
00781                 delete Dispatcher;
00782                 delete LoggingMutex;
00783                 delete ResultsMutex;
00784                 delete QueueMutex;
00785                 delete ConnMutex;
00786                 ServerInstance->Modules->DoneWithInterface("SQLutils");
00787                 throw ModuleException("m_mysql: Unable to publish feature 'SQL'");
00788         }
00789 
00790         ServerInstance->Modules->PublishInterface("SQL", this);
00791         Implementation eventlist[] = { I_OnRehash, I_OnRequest };
00792         ServerInstance->Modules->Attach(eventlist, this, 2);
00793 }
00794 
00795 ModuleSQL::~ModuleSQL()
00796 {
00797         delete Dispatcher;
00798         ClearAllConnections();
00799         delete Conf;
00800         ServerInstance->Modules->UnpublishInterface("SQL", this);
00801         ServerInstance->Modules->UnpublishFeature("SQL");
00802         ServerInstance->Modules->DoneWithInterface("SQLutils");
00803         delete LoggingMutex;
00804         delete ResultsMutex;
00805         delete QueueMutex;
00806         delete ConnMutex;
00807 }
00808 
00809 unsigned long ModuleSQL::NewID()
00810 {
00811         if (currid+1 == 0)
00812                 currid++;
00813         return ++currid;
00814 }
00815 
00816 const char* ModuleSQL::OnRequest(Request* request)
00817 {
00818         if(strcmp(SQLREQID, request->GetId()) == 0)
00819         {
00820                 SQLrequest* req = (SQLrequest*)request;
00821 
00822                 /* XXX: Lock */
00823                 QueueMutex->Lock();
00824 
00825                 ConnMap::iterator iter;
00826 
00827                 const char* returnval = NULL;
00828 
00829                 ConnMutex->Lock();
00830                 if((iter = Connections.find(req->dbid)) != Connections.end())
00831                 {
00832                         req->id = NewID();
00833                         iter->second->queue.push(*req);
00834                         returnval = SQLSUCCESS;
00835                 }
00836                 else
00837                 {
00838                         req->error.Id(SQL_BAD_DBID);
00839                 }
00840 
00841                 ConnMutex->Unlock();
00842                 QueueMutex->Unlock();
00843 
00844                 return returnval;
00845         }
00846 
00847         return NULL;
00848 }
00849 
00850 void ModuleSQL::OnRehash(User* user, const std::string &parameter)
00851 {
00852         rehashing = true;
00853 }
00854 
00855 Version ModuleSQL::GetVersion()
00856 {
00857         return Version("$Id: m_mysql.cpp 10622 2008-10-04 21:27:52Z brain $", VF_VENDOR | VF_SERVICEPROVIDER, API_VERSION);
00858 }
00859 
00860 void DispatcherThread::Run()
00861 {
00862         LoadDatabases(Parent->Conf, Parent->PublicServerInstance, Parent);
00863 
00864         /* Connect back to the Notifier */
00865 
00866         if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
00867         {
00868                 /* crap, we're out of sockets... */
00869                 return;
00870         }
00871 
00872         irc::sockets::insp_sockaddr addr;
00873 
00874 #ifdef IPV6
00875         irc::sockets::insp_aton("::1", &addr.sin6_addr);
00876         addr.sin6_family = AF_FAMILY;
00877         addr.sin6_port = htons(MessagePipe->GetPort());
00878 #else
00879         irc::sockets::insp_inaddr ia;
00880         irc::sockets::insp_aton("127.0.0.1", &ia);
00881         addr.sin_family = AF_FAMILY;
00882         addr.sin_addr = ia;
00883         addr.sin_port = htons(MessagePipe->GetPort());
00884 #endif
00885 
00886         if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
00887         {
00888                 /* wtf, we cant connect to it, but we just created it! */
00889                 return;
00890         }
00891 
00892         while (this->GetExitFlag() == false)
00893         {
00894                 if (Parent->rehashing)
00895                 {
00896                 /* XXX: Lock */
00897                         Parent->QueueMutex->Lock();
00898                         Parent->rehashing = false;
00899                         LoadDatabases(Parent->Conf, Parent->PublicServerInstance, Parent);
00900                         Parent->QueueMutex->Unlock();
00901                         /* XXX: Unlock */
00902                 }
00903 
00904                 SQLConnection* conn = NULL;
00905                 /* XXX: Lock here for safety */
00906                 Parent->QueueMutex->Lock();
00907                 Parent->ConnMutex->Lock();
00908                 for (ConnMap::iterator i = Connections.begin(); i != Connections.end(); i++)
00909                 {
00910                         if (i->second->queue.totalsize())
00911                         {
00912                                 conn = i->second;
00913                                 break;
00914                         }
00915                 }
00916                 Parent->ConnMutex->Unlock();
00917                 Parent->QueueMutex->Unlock();
00918                 /* XXX: Unlock */
00919 
00920                 /* Theres an item! */
00921                 if (conn)
00922                 {
00923                         conn->DoLeadingQuery();
00924 
00925                         /* XXX: Lock */
00926                         Parent->QueueMutex->Lock();
00927                         conn->queue.pop();
00928                         Parent->QueueMutex->Unlock();
00929                         /* XXX: Unlock */
00930                 }
00931 
00932                 usleep(1000);
00933         }
00934 }
00935 
/* InspIRCd module registration macro for ModuleSQL; presumably it expands to the
 * exported entry point the core uses to instantiate the module on load. Verify
 * against the macro definition. */
MODULE_INIT(ModuleSQL)
00937