The InspIRCd Project
Home | Developers | Wiki | Forums | Bug Tracker | SVN | Download
Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members

m_sqlite3.cpp

Go to the documentation of this file.
00001 /*               +------------------------------------+
00002  *               | Inspire Internet Relay Chat Daemon |
00003  *               +------------------------------------+
00004  *
00005  *      InspIRCd: (C) 2002-2008 InspIRCd Development Team
00006  * See: http://www.inspircd.org/wiki/index.php/Credits
00007  *
00008  * This program is free but copyrighted software; see
00009  *                        the file COPYING for details.
00010  *
00011  * ---------------------------------------------------
00012  */
00013 
00014 #include "inspircd.h"
00015 #include <sqlite3.h>
00016 #include "m_sqlv2.h"
00017 
00018 /* $ModDesc: sqlite3 provider */
00019 /* $CompileFlags: pkgconfversion("sqlite3","3.3") pkgconfincludes("sqlite3","/sqlite3.h","") */
00020 /* $LinkerFlags: pkgconflibs("sqlite3","/libsqlite3.so","-lsqlite3") */
00021 /* $ModDep: m_sqlv2.h */
00022 /* $NoPedantic */
00023 
/* Forward declarations of the classes defined later in this file. */
00024 class SQLConn;
00025 class SQLite3Result;
00026 class ResultNotifier;
00027 class SQLiteListener;
00028 class ModuleSQLite3;
00029 
/* ConnMap:     open connections; ModuleSQLite3::AddConn() inserts them keyed by SQLhost::id.
 * paramlist:   the context SQLConn::Query() hands to the sqlite3 callbacks
 *              (element 0 is the SQLConn, element 1 is the SQLite3Result).
 * ResultQueue: results a SQLConn has collected but not yet sent.
 */
00030 typedef std::map<std::string, SQLConn*> ConnMap;
00031 typedef std::deque<classbase*> paramlist;
00032 typedef std::deque<SQLite3Result*> ResultQueue;
00033 
/* notifier: checked and torn down in ~ModuleSQLite3().
 * listener: created by the ModuleSQLite3 constructor; SQLConn::SendNotify() asks it for its port.
 * QueueFD:  socket used by SQLConn::SendNotify(); -1 until the first notification is sent.
 */
00034 ResultNotifier* notifier = NULL;
00035 SQLiteListener* listener = NULL;
00036 int QueueFD = -1;
00037 
00038 class ResultNotifier : public BufferedSocket
00039 {
00040         ModuleSQLite3* mod;
00041 
00042  public:
00043         ResultNotifier(ModuleSQLite3* m, InspIRCd* SI, int newfd, char* ip) : BufferedSocket(SI, newfd, ip), mod(m)
00044         {
00045         }
00046 
00047         virtual bool OnDataReady()
00048         {
00049                 char data = 0;
00050                 if (ServerInstance->SE->Recv(this, &data, 1, 0) > 0)
00051                 {
00052                         Dispatch();
00053                         return true;
00054                 }
00055                 return false;
00056         }
00057 
00058         void Dispatch();
00059 };
00060 
00061 class SQLiteListener : public ListenSocketBase
00062 {
00063         ModuleSQLite3* Parent;
00064         irc::sockets::insp_sockaddr sock_us;
00065         socklen_t uslen;
00066         FileReader* index;
00067 
00068  public:
00069         SQLiteListener(ModuleSQLite3* P, InspIRCd* Instance, int port, const std::string &addr) : ListenSocketBase(Instance, port, addr), Parent(P)
00070         {
00071                 uslen = sizeof(sock_us);
00072                 if (getsockname(this->fd,(sockaddr*)&sock_us,&uslen))
00073                 {
00074                         throw ModuleException("Could not getsockname() to find out port number for ITC port");
00075                 }
00076         }
00077 
00078         virtual void OnAcceptReady(const std::string &ipconnectedto, int nfd, const std::string &incomingip)
00079         {
00080                 new ResultNotifier(this->Parent, this->ServerInstance, nfd, (char *)ipconnectedto.c_str()); // XXX unsafe casts suck
00081         }
00082 
00083         /* Using getsockname and ntohs, we can determine which port number we were allocated */
00084         int GetPort()
00085         {
00086 #ifdef IPV6
00087                 return ntohs(sock_us.sin6_port);
00088 #else
00089                 return ntohs(sock_us.sin_port);
00090 #endif
00091         }
00092 };
00093 
00094 class SQLite3Result : public SQLresult
00095 {
00096  private:
00097         int currentrow;
00098         int rows;
00099         int cols;
00100 
00101         std::vector<std::string> colnames;
00102         std::vector<SQLfieldList> fieldlists;
00103         SQLfieldList emptyfieldlist;
00104 
00105         SQLfieldList* fieldlist;
00106         SQLfieldMap* fieldmap;
00107 
00108  public:
00109         SQLite3Result(Module* self, Module* to, unsigned int rid)
00110         : SQLresult(self, to, rid), currentrow(0), rows(0), cols(0), fieldlist(NULL), fieldmap(NULL)
00111         {
00112         }
00113 
00114         ~SQLite3Result()
00115         {
00116         }
00117 
00118         void AddRow(int colsnum, char **dat, char **colname)
00119         {
00120                 colnames.clear();
00121                 cols = colsnum;
00122                 for (int i = 0; i < colsnum; i++)
00123                 {
00124                         fieldlists.resize(fieldlists.size()+1);
00125                         colnames.push_back(colname[i]);
00126                         SQLfield sf(dat[i] ? dat[i] : "", dat[i] ? false : true);
00127                         fieldlists[rows].push_back(sf);
00128                 }
00129                 rows++;
00130         }
00131 
00132         void UpdateAffectedCount()
00133         {
00134                 rows++;
00135         }
00136 
00137         virtual int Rows()
00138         {
00139                 return rows;
00140         }
00141 
00142         virtual int Cols()
00143         {
00144                 return cols;
00145         }
00146 
00147         virtual std::string ColName(int column)
00148         {
00149                 if (column < (int)colnames.size())
00150                 {
00151                         return colnames[column];
00152                 }
00153                 else
00154                 {
00155                         throw SQLbadColName();
00156                 }
00157                 return "";
00158         }
00159 
00160         virtual int ColNum(const std::string &column)
00161         {
00162                 for (unsigned int i = 0; i < colnames.size(); i++)
00163                 {
00164                         if (column == colnames[i])
00165                                 return i;
00166                 }
00167                 throw SQLbadColName();
00168                 return 0;
00169         }
00170 
00171         virtual SQLfield GetValue(int row, int column)
00172         {
00173                 if ((row >= 0) && (row < rows) && (column >= 0) && (column < Cols()))
00174                 {
00175                         return fieldlists[row][column];
00176                 }
00177 
00178                 throw SQLbadColName();
00179 
00180                 /* XXX: We never actually get here because of the throw */
00181                 return SQLfield("",true);
00182         }
00183 
00184         virtual SQLfieldList& GetRow()
00185         {
00186                 if (currentrow < rows)
00187                         return fieldlists[currentrow];
00188                 else
00189                         return emptyfieldlist;
00190         }
00191 
00192         virtual SQLfieldMap& GetRowMap()
00193         {
00194                 /* In an effort to reduce overhead we don't actually allocate the map
00195                  * until the first time it's needed...so...
00196                  */
00197                 if(fieldmap)
00198                 {
00199                         fieldmap->clear();
00200                 }
00201                 else
00202                 {
00203                         fieldmap = new SQLfieldMap;
00204                 }
00205 
00206                 if (currentrow < rows)
00207                 {
00208                         for (int i = 0; i < Cols(); i++)
00209                         {
00210                                 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
00211                         }
00212                         currentrow++;
00213                 }
00214 
00215                 return *fieldmap;
00216         }
00217 
00218         virtual SQLfieldList* GetRowPtr()
00219         {
00220                 fieldlist = new SQLfieldList();
00221 
00222                 if (currentrow < rows)
00223                 {
00224                         for (int i = 0; i < Rows(); i++)
00225                         {
00226                                 fieldlist->push_back(fieldlists[currentrow][i]);
00227                         }
00228                         currentrow++;
00229                 }
00230                 return fieldlist;
00231         }
00232 
00233         virtual SQLfieldMap* GetRowMapPtr()
00234         {
00235                 fieldmap = new SQLfieldMap();
00236 
00237                 if (currentrow < rows)
00238                 {
00239                         for (int i = 0; i < Cols(); i++)
00240                         {
00241                                 fieldmap->insert(std::make_pair(colnames[i],GetValue(currentrow, i)));
00242                         }
00243                         currentrow++;
00244                 }
00245 
00246                 return fieldmap;
00247         }
00248 
00249         virtual void Free(SQLfieldMap* fm)
00250         {
00251                 delete fm;
00252         }
00253 
00254         virtual void Free(SQLfieldList* fl)
00255         {
00256                 delete fl;
00257         }
00258 
00259 
00260 };
00261 
00262 class SQLConn : public classbase
00263 {
00264  private:
00265         ResultQueue results;
00266         InspIRCd* ServerInstance;
00267         Module* mod;
00268         SQLhost host;
00269         sqlite3* conn;
00270 
00271  public:
00272         SQLConn(InspIRCd* SI, Module* m, const SQLhost& hi)
00273         : ServerInstance(SI), mod(m), host(hi)
00274         {
00275                 if (OpenDB() != SQLITE_OK)
00276                 {
00277                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: Could not open DB with id: " + host.id);
00278                         CloseDB();
00279                 }
00280         }
00281 
00282         ~SQLConn()
00283         {
00284                 CloseDB();
00285         }
00286 
00287         SQLerror Query(SQLrequest &req)
00288         {
00289                 /* Pointer to the buffer we screw around with substitution in */
00290                 char* query;
00291 
00292                 /* Pointer to the current end of query, where we append new stuff */
00293                 char* queryend;
00294 
00295                 /* Total length of the unescaped parameters */
00296                 unsigned long paramlen;
00297 
00298                 /* Total length of query, used for binary-safety */
00299                 unsigned long querylength = 0;
00300 
00301                 paramlen = 0;
00302                 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
00303                 {
00304                         paramlen += i->size();
00305                 }
00306 
00307                 /* To avoid a lot of allocations, allocate enough memory for the biggest the escaped query could possibly be.
00308                  * sizeofquery + (totalparamlength*2) + 1
00309                  *
00310                  * The +1 is for null-terminating the string
00311                  */
00312                 query = new char[req.query.q.length() + (paramlen*2) + 1];
00313                 queryend = query;
00314 
00315                 for(unsigned long i = 0; i < req.query.q.length(); i++)
00316                 {
00317                         if(req.query.q[i] == '?')
00318                         {
00319                                 if(req.query.p.size())
00320                                 {
00321                                         char* escaped;
00322                                         escaped = sqlite3_mprintf("%q", req.query.p.front().c_str());
00323                                         for (char* n = escaped; *n; n++)
00324                                         {
00325                                                 *queryend = *n;
00326                                                 queryend++;
00327                                         }
00328                                         sqlite3_free(escaped);
00329                                         req.query.p.pop_front();
00330                                 }
00331                                 else
00332                                         break;
00333                         }
00334                         else
00335                         {
00336                                 *queryend = req.query.q[i];
00337                                 queryend++;
00338                         }
00339                         querylength++;
00340                 }
00341                 *queryend = 0;
00342                 req.query.q = query;
00343 
00344                 SQLite3Result* res = new SQLite3Result(mod, req.GetSource(), req.id);
00345                 res->dbid = host.id;
00346                 res->query = req.query.q;
00347                 paramlist params;
00348                 params.push_back(this);
00349                 params.push_back(res);
00350 
00351                 char *errmsg = 0;
00352                 sqlite3_update_hook(conn, QueryUpdateHook, &params);
00353                 if (sqlite3_exec(conn, req.query.q.data(), QueryResult, &params, &errmsg) != SQLITE_OK)
00354                 {
00355                         std::string error(errmsg);
00356                         sqlite3_free(errmsg);
00357                         delete[] query;
00358                         delete res;
00359                         return SQLerror(SQL_QSEND_FAIL, error);
00360                 }
00361                 delete[] query;
00362 
00363                 results.push_back(res);
00364                 SendNotify();
00365                 return SQLerror();
00366         }
00367 
00368         static int QueryResult(void *params, int argc, char **argv, char **azColName)
00369         {
00370                 paramlist* p = (paramlist*)params;
00371                 ((SQLConn*)(*p)[0])->ResultReady(((SQLite3Result*)(*p)[1]), argc, argv, azColName);
00372                 return 0;
00373         }
00374 
00375         static void QueryUpdateHook(void *params, int eventid, char const * azSQLite, char const * azColName, sqlite_int64 rowid)
00376         {
00377                 paramlist* p = (paramlist*)params;
00378                 ((SQLConn*)(*p)[0])->AffectedReady(((SQLite3Result*)(*p)[1]));
00379         }
00380 
00381         void ResultReady(SQLite3Result *res, int cols, char **data, char **colnames)
00382         {
00383                 res->AddRow(cols, data, colnames);
00384         }
00385 
00386         void AffectedReady(SQLite3Result *res)
00387         {
00388                 res->UpdateAffectedCount();
00389         }
00390 
00391         int OpenDB()
00392         {
00393                 return sqlite3_open(host.host.c_str(), &conn);
00394         }
00395 
00396         void CloseDB()
00397         {
00398                 sqlite3_interrupt(conn);
00399                 sqlite3_close(conn);
00400         }
00401 
00402         SQLhost GetConfHost()
00403         {
00404                 return host;
00405         }
00406 
00407         void SendResults()
00408         {
00409                 while (results.size())
00410                 {
00411                         SQLite3Result* res = results[0];
00412                         if (res->GetDest())
00413                         {
00414                                 res->Send();
00415                         }
00416                         else
00417                         {
00418                                 /* If the client module is unloaded partway through a query then the provider will set
00419                                  * the pointer to NULL. We cannot just cancel the query as the result will still come
00420                                  * through at some point...and it could get messy if we play with invalid pointers...
00421                                  */
00422                                 delete res;
00423                         }
00424                         results.pop_front();
00425                 }
00426         }
00427 
00428         void ClearResults()
00429         {
00430                 while (results.size())
00431                 {
00432                         SQLite3Result* res = results[0];
00433                         delete res;
00434                         results.pop_front();
00435                 }
00436         }
00437 
00438         void SendNotify()
00439         {
00440                 if (QueueFD < 0)
00441                 {
00442                         if ((QueueFD = socket(AF_FAMILY, SOCK_STREAM, 0)) == -1)
00443                         {
00444                                 /* crap, we're out of sockets... */
00445                                 return;
00446                         }
00447 
00448                         irc::sockets::insp_sockaddr addr;
00449 
00450 #ifdef IPV6
00451                         irc::sockets::insp_aton("::1", &addr.sin6_addr);
00452                         addr.sin6_family = AF_FAMILY;
00453                         addr.sin6_port = htons(listener->GetPort());
00454 #else
00455                         irc::sockets::insp_inaddr ia;
00456                         irc::sockets::insp_aton("127.0.0.1", &ia);
00457                         addr.sin_family = AF_FAMILY;
00458                         addr.sin_addr = ia;
00459                         addr.sin_port = htons(listener->GetPort());
00460 #endif
00461 
00462                         if (connect(QueueFD, (sockaddr*)&addr,sizeof(addr)) == -1)
00463                         {
00464                                 /* wtf, we cant connect to it, but we just created it! */
00465                                 return;
00466                         }
00467                 }
00468                 char id = 0;
00469                 send(QueueFD, &id, 1, 0);
00470         }
00471 
00472 };
00473 
00474 
00475 class ModuleSQLite3 : public Module
00476 {
00477  private:
00478         ConnMap connections;
00479         unsigned long currid;
00480 
00481  public:
00482         ModuleSQLite3(InspIRCd* Me)
00483         : Module(Me), currid(0)
00484         {
00485                 ServerInstance->Modules->UseInterface("SQLutils");
00486 
00487                 if (!ServerInstance->Modules->PublishFeature("SQL", this))
00488                 {
00489                         throw ModuleException("m_sqlite3: Unable to publish feature 'SQL'");
00490                 }
00491 
00492                 /* Create a socket on a random port. Let the tcp stack allocate us an available port */
00493 #ifdef IPV6
00494                 listener = new SQLiteListener(this, ServerInstance, 0, "::1");
00495 #else
00496                 listener = new SQLiteListener(this, ServerInstance, 0, "127.0.0.1");
00497 #endif
00498 
00499                 if (listener->GetFd() == -1)
00500                 {
00501                         ServerInstance->Modules->DoneWithInterface("SQLutils");
00502                         throw ModuleException("m_sqlite3: unable to create ITC pipe");
00503                 }
00504                 else
00505                 {
00506                         ServerInstance->Logs->Log("m_sqlite3", DEBUG, "SQLite: Interthread comms port is %d", listener->GetPort());
00507                 }
00508 
00509                 ReadConf();
00510 
00511                 ServerInstance->Modules->PublishInterface("SQL", this);
00512                 Implementation eventlist[] = { I_OnRequest, I_OnRehash };
00513                 ServerInstance->Modules->Attach(eventlist, this, 2);
00514         }
00515 
00516         virtual ~ModuleSQLite3()
00517         {
00518                 ClearQueue();
00519                 ClearAllConnections();
00520 
00521                 ServerInstance->SE->DelFd(listener);
00522                 ServerInstance->BufferedSocketCull();
00523 
00524                 if (QueueFD >= 0)
00525                 {
00526                         shutdown(QueueFD, 2);
00527                         close(QueueFD);
00528                 }
00529 
00530                 if (notifier)
00531                 {
00532                         ServerInstance->SE->DelFd(notifier);
00533                         notifier->Close();
00534                         ServerInstance->BufferedSocketCull();
00535                 }
00536 
00537                 ServerInstance->Modules->UnpublishInterface("SQL", this);
00538                 ServerInstance->Modules->UnpublishFeature("SQL");
00539                 ServerInstance->Modules->DoneWithInterface("SQLutils");
00540         }
00541 
00542 
00543         void SendQueue()
00544         {
00545                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
00546                 {
00547                         iter->second->SendResults();
00548                 }
00549         }
00550 
00551         void ClearQueue()
00552         {
00553                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
00554                 {
00555                         iter->second->ClearResults();
00556                 }
00557         }
00558 
00559         bool HasHost(const SQLhost &host)
00560         {
00561                 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
00562                 {
00563                         if (host == iter->second->GetConfHost())
00564                                 return true;
00565                 }
00566                 return false;
00567         }
00568 
00569         bool HostInConf(const SQLhost &h)
00570         {
00571                 ConfigReader conf(ServerInstance);
00572                 for(int i = 0; i < conf.Enumerate("database"); i++)
00573                 {
00574                         SQLhost host;
00575                         host.id         = conf.ReadValue("database", "id", i);
00576                         host.host       = conf.ReadValue("database", "hostname", i);
00577                         host.port       = conf.ReadInteger("database", "port", i, true);
00578                         host.name       = conf.ReadValue("database", "name", i);
00579                         host.user       = conf.ReadValue("database", "username", i);
00580                         host.pass       = conf.ReadValue("database", "password", i);
00581                         if (h == host)
00582                                 return true;
00583                 }
00584                 return false;
00585         }
00586 
00587         void ReadConf()
00588         {
00589                 ClearOldConnections();
00590 
00591                 ConfigReader conf(ServerInstance);
00592                 for(int i = 0; i < conf.Enumerate("database"); i++)
00593                 {
00594                         SQLhost host;
00595 
00596                         host.id         = conf.ReadValue("database", "id", i);
00597                         host.host       = conf.ReadValue("database", "hostname", i);
00598                         host.port       = conf.ReadInteger("database", "port", i, true);
00599                         host.name       = conf.ReadValue("database", "name", i);
00600                         host.user       = conf.ReadValue("database", "username", i);
00601                         host.pass       = conf.ReadValue("database", "password", i);
00602 
00603                         if (HasHost(host))
00604                                 continue;
00605 
00606                         this->AddConn(host);
00607                 }
00608         }
00609 
00610         void AddConn(const SQLhost& hi)
00611         {
00612                 if (HasHost(hi))
00613                 {
00614                         ServerInstance->Logs->Log("m_sqlite3",DEFAULT, "WARNING: A sqlite connection with id: %s already exists. Aborting database open attempt.", hi.id.c_str());
00615                         return;
00616                 }
00617 
00618                 SQLConn* newconn;
00619 
00620                 newconn = new SQLConn(ServerInstance, this, hi);
00621 
00622                 connections.insert(std::make_pair(hi.id, newconn));
00623         }
00624 
00625         void ClearOldConnections()
00626         {
00627                 ConnMap::iterator iter,safei;
00628                 for (iter = connections.begin(); iter != connections.end(); iter++)
00629                 {
00630                         if (!HostInConf(iter->second->GetConfHost()))
00631                         {
00632                                 delete iter->second;
00633                                 safei = iter;
00634                                 --iter;
00635                                 connections.erase(safei);
00636                         }
00637                 }
00638         }
00639 
00640         void ClearAllConnections()
00641         {
00642                 ConnMap::iterator i;
00643                 while ((i = connections.begin()) != connections.end())
00644                 {
00645                         connections.erase(i);
00646                         delete i->second;
00647                 }
00648         }
00649 
00650         virtual void OnRehash(User* user, const std::string &parameter)
00651         {
00652                 ReadConf();
00653         }
00654 
00655         virtual const char* OnRequest(Request* request)
00656         {
00657                 if(strcmp(SQLREQID, request->GetId()) == 0)
00658                 {
00659                         SQLrequest* req = (SQLrequest*)request;
00660                         ConnMap::iterator iter;
00661                         if((iter = connections.find(req->dbid)) != connections.end())
00662                         {
00663                                 req->id = NewID();
00664                                 req->error = iter->second->Query(*req);
00665                                 return SQLSUCCESS;
00666                         }
00667                         else
00668                         {
00669                                 req->error.Id(SQL_BAD_DBID);
00670                                 return NULL;
00671                         }
00672                 }
00673                 return NULL;
00674         }
00675 
00676         unsigned long NewID()
00677         {
00678                 if (currid+1 == 0)
00679                         currid++;
00680 
00681                 return ++currid;
00682         }
00683 
00684         virtual Version GetVersion()
00685         {
00686                 return Version("$Id: m_sqlite3.cpp 10622 2008-10-04 21:27:52Z brain $", VF_VENDOR | VF_SERVICEPROVIDER, API_VERSION);
00687         }
00688 
00689 };
00690 
00691 void ResultNotifier::Dispatch()
00692 {
00693         mod->SendQueue();
00694 }
00695 
/* Passes ModuleSQLite3 to the MODULE_INIT macro, which is defined outside this
 * file (presumably it emits the module's factory/entry point - confirm in the core headers). */
00696 MODULE_INIT(ModuleSQLite3)