00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014 #include "inspircd.h"
00015 #include <cstdlib>
00016 #include <sstream>
00017 #include <libpq-fe.h>
00018 #include "m_sqlv2.h"
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
/* Forward declaration: SQLConn is defined further down, but ConnMap needs to name it. */
class SQLConn;

/** Open PostgreSQL connections, keyed by the id attribute of their <database> tag. */
typedef std::map<std::string, SQLConn*> ConnMap;
00037
00038
00039
00040
00041
00042
00043
00044
/** State of an SQLConn, used by SQLConn::DoEvent() to pick a poll routine.
 *  - CREAD / CWRITE: connection attempt in progress, handled by DoPoll().
 *  - WREAD / WWRITE: connected, handled by DoConnectedPoll(); DoQuery() only sends in these states.
 *  - RREAD / RWRITE: reset in progress, handled by DoResetPoll().
 *  The READ/WRITE suffix mirrors the PGRES_POLLING_READING/WRITING result that set it.
 */
enum SQLstatus
{
	CREAD,
	CWRITE,
	WREAD,
	WWRITE,
	RREAD,
	RWRITE
};
00046
00049 std::string SQLhost::GetDSN()
00050 {
00051 std::ostringstream conninfo("connect_timeout = '2'");
00052
00053 if (ip.length())
00054 conninfo << " hostaddr = '" << ip << "'";
00055
00056 if (port)
00057 conninfo << " port = '" << port << "'";
00058
00059 if (name.length())
00060 conninfo << " dbname = '" << name << "'";
00061
00062 if (user.length())
00063 conninfo << " user = '" << user << "'";
00064
00065 if (pass.length())
00066 conninfo << " password = '" << pass << "'";
00067
00068 if (ssl)
00069 {
00070 conninfo << " sslmode = 'require'";
00071 }
00072 else
00073 {
00074 conninfo << " sslmode = 'disable'";
00075 }
00076
00077 return conninfo.str();
00078 }
00079
00080 class ReconnectTimer : public Timer
00081 {
00082 private:
00083 Module* mod;
00084 public:
00085 ReconnectTimer(InspIRCd* SI, Module* m)
00086 : Timer(5, SI->Time(), false), mod(m)
00087 {
00088 }
00089 virtual void Tick(time_t TIME);
00090 };
00091
00092
00095 class SQLresolver : public Resolver
00096 {
00097 private:
00098 SQLhost host;
00099 Module* mod;
00100 public:
00101 SQLresolver(Module* m, InspIRCd* Instance, const SQLhost& hi, bool &cached)
00102 : Resolver(Instance, hi.host, DNS_QUERY_FORWARD, cached, (Module*)m), host(hi), mod(m)
00103 {
00104 }
00105
00106 virtual void OnLookupComplete(const std::string &result, unsigned int ttl, bool cached, int resultnum = 0);
00107
00108 virtual void OnError(ResolverError e, const std::string &errormessage)
00109 {
00110 ServerInstance->Logs->Log("m_pgsql",DEBUG, "PgSQL: DNS lookup failed (%s), dying horribly", errormessage.c_str());
00111 }
00112 };
00113
00121 class PgSQLresult : public SQLresult
00122 {
00123 PGresult* res;
00124 int currentrow;
00125 int rows;
00126 int cols;
00127
00128 SQLfieldList* fieldlist;
00129 SQLfieldMap* fieldmap;
00130 public:
00131 PgSQLresult(Module* self, Module* to, unsigned long rid, PGresult* result)
00132 : SQLresult(self, to, rid), res(result), currentrow(0), fieldlist(NULL), fieldmap(NULL)
00133 {
00134 rows = PQntuples(res);
00135 cols = PQnfields(res);
00136 }
00137
00138 ~PgSQLresult()
00139 {
00140
00141 if(fieldlist)
00142 delete fieldlist;
00143
00144 if(fieldmap)
00145 delete fieldmap;
00146
00147 PQclear(res);
00148 }
00149
00150 virtual int Rows()
00151 {
00152 if(!cols && !rows)
00153 {
00154 return atoi(PQcmdTuples(res));
00155 }
00156 else
00157 {
00158 return rows;
00159 }
00160 }
00161
00162 virtual int Cols()
00163 {
00164 return PQnfields(res);
00165 }
00166
00167 virtual std::string ColName(int column)
00168 {
00169 char* name = PQfname(res, column);
00170
00171 return (name) ? name : "";
00172 }
00173
00174 virtual int ColNum(const std::string &column)
00175 {
00176 int n = PQfnumber(res, column.c_str());
00177
00178 if(n == -1)
00179 {
00180 throw SQLbadColName();
00181 }
00182 else
00183 {
00184 return n;
00185 }
00186 }
00187
00188 virtual SQLfield GetValue(int row, int column)
00189 {
00190 char* v = PQgetvalue(res, row, column);
00191
00192 if(v)
00193 {
00194 return SQLfield(std::string(v, PQgetlength(res, row, column)), PQgetisnull(res, row, column));
00195 }
00196 else
00197 {
00198 throw SQLbadColName();
00199 }
00200 }
00201
00202 virtual SQLfieldList& GetRow()
00203 {
00204
00205
00206
00207 if(fieldlist)
00208 {
00209 fieldlist->clear();
00210 }
00211 else
00212 {
00213 fieldlist = new SQLfieldList;
00214 }
00215
00216 if(currentrow < PQntuples(res))
00217 {
00218 int ncols = PQnfields(res);
00219
00220 for(int i = 0; i < ncols; i++)
00221 {
00222 fieldlist->push_back(GetValue(currentrow, i));
00223 }
00224
00225 currentrow++;
00226 }
00227
00228 return *fieldlist;
00229 }
00230
00231 virtual SQLfieldMap& GetRowMap()
00232 {
00233
00234
00235
00236 if(fieldmap)
00237 {
00238 fieldmap->clear();
00239 }
00240 else
00241 {
00242 fieldmap = new SQLfieldMap;
00243 }
00244
00245 if(currentrow < PQntuples(res))
00246 {
00247 int ncols = PQnfields(res);
00248
00249 for(int i = 0; i < ncols; i++)
00250 {
00251 fieldmap->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
00252 }
00253
00254 currentrow++;
00255 }
00256
00257 return *fieldmap;
00258 }
00259
00260 virtual SQLfieldList* GetRowPtr()
00261 {
00262 SQLfieldList* fl = new SQLfieldList;
00263
00264 if(currentrow < PQntuples(res))
00265 {
00266 int ncols = PQnfields(res);
00267
00268 for(int i = 0; i < ncols; i++)
00269 {
00270 fl->push_back(GetValue(currentrow, i));
00271 }
00272
00273 currentrow++;
00274 }
00275
00276 return fl;
00277 }
00278
00279 virtual SQLfieldMap* GetRowMapPtr()
00280 {
00281 SQLfieldMap* fm = new SQLfieldMap;
00282
00283 if(currentrow < PQntuples(res))
00284 {
00285 int ncols = PQnfields(res);
00286
00287 for(int i = 0; i < ncols; i++)
00288 {
00289 fm->insert(std::make_pair(ColName(i), GetValue(currentrow, i)));
00290 }
00291
00292 currentrow++;
00293 }
00294
00295 return fm;
00296 }
00297
00298 virtual void Free(SQLfieldMap* fm)
00299 {
00300 delete fm;
00301 }
00302
00303 virtual void Free(SQLfieldList* fl)
00304 {
00305 delete fl;
00306 }
00307 };
00308
00311 class SQLConn : public EventHandler
00312 {
00313 private:
00314 InspIRCd* ServerInstance;
00315 SQLhost confhost;
00316 Module* us;
00317 PGconn* sql;
00318 SQLstatus status;
00319 bool qinprog;
00320 QueryQueue queue;
00321 time_t idle;
00322
00323 public:
00324 SQLConn(InspIRCd* SI, Module* self, const SQLhost& hi)
00325 : EventHandler(), ServerInstance(SI), confhost(hi), us(self), sql(NULL), status(CWRITE), qinprog(false)
00326 {
00327 idle = this->ServerInstance->Time();
00328 if(!DoConnect())
00329 {
00330 ServerInstance->Logs->Log("m_pgsql",DEFAULT, "WARNING: Could not connect to database with id: " + ConvToStr(hi.id));
00331 DelayReconnect();
00332 }
00333 }
00334
00335 ~SQLConn()
00336 {
00337 Close();
00338 }
00339
00340 virtual void HandleEvent(EventType et, int errornum)
00341 {
00342 switch (et)
00343 {
00344 case EVENT_READ:
00345 OnDataReady();
00346 break;
00347
00348 case EVENT_WRITE:
00349 OnWriteReady();
00350 break;
00351
00352 case EVENT_ERROR:
00353 DelayReconnect();
00354 break;
00355
00356 default:
00357 break;
00358 }
00359 }
00360
00361 bool DoConnect()
00362 {
00363 if(!(sql = PQconnectStart(confhost.GetDSN().c_str())))
00364 return false;
00365
00366 if(PQstatus(sql) == CONNECTION_BAD)
00367 return false;
00368
00369 if(PQsetnonblocking(sql, 1) == -1)
00370 return false;
00371
00372
00373
00374
00375 this->fd = PQsocket(sql);
00376
00377 if(this->fd <= -1)
00378 return false;
00379
00380 if (!this->ServerInstance->SE->AddFd(this))
00381 {
00382 ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Couldn't add pgsql socket to socket engine");
00383 return false;
00384 }
00385
00386
00387 return DoPoll();
00388 }
00389
00390 bool DoPoll()
00391 {
00392 switch(PQconnectPoll(sql))
00393 {
00394 case PGRES_POLLING_WRITING:
00395 ServerInstance->SE->WantWrite(this);
00396 status = CWRITE;
00397 return true;
00398 case PGRES_POLLING_READING:
00399 status = CREAD;
00400 return true;
00401 case PGRES_POLLING_FAILED:
00402 return false;
00403 case PGRES_POLLING_OK:
00404 status = WWRITE;
00405 return DoConnectedPoll();
00406 default:
00407 return true;
00408 }
00409 }
00410
00411 bool DoConnectedPoll()
00412 {
00413 if(!qinprog && queue.totalsize())
00414 {
00415
00416 SQLrequest& query = queue.front();
00417 DoQuery(query);
00418 }
00419
00420 if(PQconsumeInput(sql))
00421 {
00422
00423
00424
00425 idle = this->ServerInstance->Time();
00426
00427 if (PQisBusy(sql))
00428 {
00429
00430 }
00431 else if (qinprog)
00432 {
00433
00434 SQLrequest& query = queue.front();
00435
00436
00437 Module* to = query.GetSource();
00438
00439
00440 PGresult* result = PQgetResult(sql);
00441
00442
00443
00444
00445
00446
00447
00448 while (PGresult* temp = PQgetResult(sql))
00449 {
00450 PQclear(result);
00451 result = temp;
00452 }
00453
00454 if(to)
00455 {
00456
00457 PgSQLresult reply(us, to, query.id, result);
00458
00459
00460 reply.query = query.query.q;
00461
00462 switch(PQresultStatus(result))
00463 {
00464 case PGRES_EMPTY_QUERY:
00465 case PGRES_BAD_RESPONSE:
00466 case PGRES_FATAL_ERROR:
00467 reply.error.Id(SQL_QREPLY_FAIL);
00468 reply.error.Str(PQresultErrorMessage(result));
00469 default:;
00470
00471 }
00472
00473 reply.Send();
00474
00475
00476 }
00477 else
00478 {
00479
00480
00481
00482
00483 PQclear(result);
00484 }
00485 qinprog = false;
00486 queue.pop();
00487 DoConnectedPoll();
00488 }
00489 return true;
00490 }
00491 else
00492 {
00493
00494
00495
00496
00497
00498 DelayReconnect();
00499 return true;
00500 }
00501 }
00502
00503 bool DoResetPoll()
00504 {
00505 switch(PQresetPoll(sql))
00506 {
00507 case PGRES_POLLING_WRITING:
00508 ServerInstance->SE->WantWrite(this);
00509 status = CWRITE;
00510 return DoPoll();
00511 case PGRES_POLLING_READING:
00512 status = CREAD;
00513 return true;
00514 case PGRES_POLLING_FAILED:
00515 return false;
00516 case PGRES_POLLING_OK:
00517 status = WWRITE;
00518 return DoConnectedPoll();
00519 default:
00520 return true;
00521 }
00522 }
00523
00524 bool OnDataReady()
00525 {
00526
00527 return DoEvent();
00528 }
00529
00530 bool OnWriteReady()
00531 {
00532
00533 return DoEvent();
00534 }
00535
00536 bool OnConnected()
00537 {
00538 return DoEvent();
00539 }
00540
00541 void DelayReconnect();
00542
00543 bool DoEvent()
00544 {
00545 bool ret;
00546
00547 if((status == CREAD) || (status == CWRITE))
00548 {
00549 ret = DoPoll();
00550 }
00551 else if((status == RREAD) || (status == RWRITE))
00552 {
00553 ret = DoResetPoll();
00554 }
00555 else
00556 {
00557 ret = DoConnectedPoll();
00558 }
00559 return ret;
00560 }
00561
00562 SQLerror DoQuery(SQLrequest &req)
00563 {
00564 if((status == WREAD) || (status == WWRITE))
00565 {
00566 if(!qinprog)
00567 {
00568
00569
00570
00571 char* query;
00572
00573 char* queryend;
00574
00575 unsigned int paramlen;
00576
00577 paramlen = 0;
00578
00579 for(ParamL::iterator i = req.query.p.begin(); i != req.query.p.end(); i++)
00580 {
00581 paramlen += i->size();
00582 }
00583
00584
00585
00586
00587
00588
00589
00590 query = new char[req.query.q.length() + (paramlen*2) + 1];
00591 queryend = query;
00592
00593
00594
00595
00596
00597 for(unsigned int i = 0; i < req.query.q.length(); i++)
00598 {
00599 if(req.query.q[i] == '?')
00600 {
00601
00602
00603
00604
00605
00606
00607
00608 if(req.query.p.size())
00609 {
00610 int error = 0;
00611 size_t len = 0;
00612
00613 #ifdef PGSQL_HAS_ESCAPECONN
00614 len = PQescapeStringConn(sql, queryend, req.query.p.front().c_str(), req.query.p.front().length(), &error);
00615 #else
00616 len = PQescapeString (queryend, req.query.p.front().c_str(), req.query.p.front().length());
00617 #endif
00618 if(error)
00619 {
00620 ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Apparently PQescapeStringConn() failed somehow...don't know how or what to do...");
00621 }
00622
00623
00624 queryend += len;
00625
00626
00627 req.query.p.pop_front();
00628 }
00629 else
00630 {
00631 ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: Found a substitution location but no parameter to substitute :|");
00632 break;
00633 }
00634 }
00635 else
00636 {
00637 *queryend = req.query.q[i];
00638 queryend++;
00639 }
00640 }
00641
00642
00643 *queryend = 0;
00644 req.query.q = query;
00645
00646 if(PQsendQuery(sql, query))
00647 {
00648 qinprog = true;
00649 delete[] query;
00650 return SQLerror();
00651 }
00652 else
00653 {
00654 delete[] query;
00655 return SQLerror(SQL_QSEND_FAIL, PQerrorMessage(sql));
00656 }
00657 }
00658 }
00659 return SQLerror(SQL_BAD_CONN, "Can't query until connection is complete");
00660 }
00661
00662 SQLerror Query(const SQLrequest &req)
00663 {
00664 queue.push(req);
00665
00666 if(!qinprog && queue.totalsize())
00667 {
00668
00669 SQLrequest& query = queue.front();
00670 return DoQuery(query);
00671 }
00672 else
00673 {
00674 return SQLerror();
00675 }
00676 }
00677
00678 void OnUnloadModule(Module* mod)
00679 {
00680 queue.PurgeModule(mod);
00681 }
00682
00683 const SQLhost GetConfHost()
00684 {
00685 return confhost;
00686 }
00687
00688 void Close() {
00689 if (!this->ServerInstance->SE->DelFd(this))
00690 {
00691 if (sql && PQstatus(sql) == CONNECTION_BAD)
00692 {
00693 this->ServerInstance->SE->DelFd(this, true);
00694 }
00695 else
00696 {
00697 ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: PQsocket cant be removed from socket engine!");
00698 }
00699 }
00700
00701 if(sql)
00702 {
00703 PQfinish(sql);
00704 sql = NULL;
00705 }
00706 }
00707
00708 };
00709
00710 class ModulePgSQL : public Module
00711 {
00712 private:
00713 ConnMap connections;
00714 unsigned long currid;
00715 char* sqlsuccess;
00716 ReconnectTimer* retimer;
00717
00718 public:
00719 ModulePgSQL(InspIRCd* Me)
00720 : Module(Me), currid(0)
00721 {
00722 ServerInstance->Modules->UseInterface("SQLutils");
00723
00724 sqlsuccess = new char[strlen(SQLSUCCESS)+1];
00725
00726 strlcpy(sqlsuccess, SQLSUCCESS, strlen(SQLSUCCESS));
00727
00728 if (!ServerInstance->Modules->PublishFeature("SQL", this))
00729 {
00730 throw ModuleException("BUG: PgSQL Unable to publish feature 'SQL'");
00731 }
00732
00733 ReadConf();
00734
00735 ServerInstance->Modules->PublishInterface("SQL", this);
00736 Implementation eventlist[] = { I_OnUnloadModule, I_OnRequest, I_OnRehash, I_OnUserRegister, I_OnCheckReady, I_OnUserDisconnect };
00737 ServerInstance->Modules->Attach(eventlist, this, 6);
00738 }
00739
00740 virtual ~ModulePgSQL()
00741 {
00742 if (retimer)
00743 ServerInstance->Timers->DelTimer(retimer);
00744 ClearAllConnections();
00745 delete[] sqlsuccess;
00746 ServerInstance->Modules->UnpublishInterface("SQL", this);
00747 ServerInstance->Modules->UnpublishFeature("SQL");
00748 ServerInstance->Modules->DoneWithInterface("SQLutils");
00749 }
00750
00751
00752 virtual void OnRehash(User* user, const std::string ¶meter)
00753 {
00754 ReadConf();
00755 }
00756
00757 bool HasHost(const SQLhost &host)
00758 {
00759 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
00760 {
00761 if (host == iter->second->GetConfHost())
00762 return true;
00763 }
00764 return false;
00765 }
00766
00767 bool HostInConf(const SQLhost &h)
00768 {
00769 ConfigReader conf(ServerInstance);
00770 for(int i = 0; i < conf.Enumerate("database"); i++)
00771 {
00772 SQLhost host;
00773 host.id = conf.ReadValue("database", "id", i);
00774 host.host = conf.ReadValue("database", "hostname", i);
00775 host.port = conf.ReadInteger("database", "port", i, true);
00776 host.name = conf.ReadValue("database", "name", i);
00777 host.user = conf.ReadValue("database", "username", i);
00778 host.pass = conf.ReadValue("database", "password", i);
00779 host.ssl = conf.ReadFlag("database", "ssl", "0", i);
00780 if (h == host)
00781 return true;
00782 }
00783 return false;
00784 }
00785
00786 void ReadConf()
00787 {
00788 ClearOldConnections();
00789
00790 ConfigReader conf(ServerInstance);
00791 for(int i = 0; i < conf.Enumerate("database"); i++)
00792 {
00793 SQLhost host;
00794 int ipvalid;
00795
00796 host.id = conf.ReadValue("database", "id", i);
00797 host.host = conf.ReadValue("database", "hostname", i);
00798 host.port = conf.ReadInteger("database", "port", i, true);
00799 host.name = conf.ReadValue("database", "name", i);
00800 host.user = conf.ReadValue("database", "username", i);
00801 host.pass = conf.ReadValue("database", "password", i);
00802 host.ssl = conf.ReadFlag("database", "ssl", "0", i);
00803
00804 if (HasHost(host))
00805 continue;
00806
00807 #ifdef IPV6
00808 if (strchr(host.host.c_str(),':'))
00809 {
00810 in6_addr blargle;
00811 ipvalid = inet_pton(AF_INET6, host.host.c_str(), &blargle);
00812 }
00813 else
00814 #endif
00815 {
00816 in_addr blargle;
00817 ipvalid = inet_aton(host.host.c_str(), &blargle);
00818 }
00819
00820 if(ipvalid > 0)
00821 {
00822
00823 host.ip = host.host;
00824 this->AddConn(host);
00825 }
00826 else if(ipvalid == 0)
00827 {
00828
00829 SQLresolver* resolver;
00830
00831 try
00832 {
00833 bool cached;
00834 resolver = new SQLresolver(this, ServerInstance, host, cached);
00835 ServerInstance->AddResolver(resolver, cached);
00836 }
00837 catch(...)
00838 {
00839
00840 }
00841 }
00842 else
00843 {
00844
00845 ServerInstance->Logs->Log("m_pgsql",DEBUG, "BUG: insp_aton failed returning -1, oh noes.");
00846 }
00847 }
00848 }
00849
00850 void ClearOldConnections()
00851 {
00852 ConnMap::iterator iter,safei;
00853 for (iter = connections.begin(); iter != connections.end(); iter++)
00854 {
00855 if (!HostInConf(iter->second->GetConfHost()))
00856 {
00857 delete iter->second;
00858 safei = iter;
00859 --iter;
00860 connections.erase(safei);
00861 }
00862 }
00863 }
00864
00865 void ClearAllConnections()
00866 {
00867 ConnMap::iterator i;
00868 while ((i = connections.begin()) != connections.end())
00869 {
00870 connections.erase(i);
00871 delete i->second;
00872 }
00873 }
00874
00875 void AddConn(const SQLhost& hi)
00876 {
00877 if (HasHost(hi))
00878 {
00879 ServerInstance->Logs->Log("m_pgsql",DEFAULT, "WARNING: A pgsql connection with id: %s already exists, possibly due to DNS delay. Aborting connection attempt.", hi.id.c_str());
00880 return;
00881 }
00882
00883 SQLConn* newconn;
00884
00885
00886 newconn = new SQLConn(ServerInstance, this, hi);
00887
00888 connections.insert(std::make_pair(hi.id, newconn));
00889 }
00890
00891 void ReconnectConn(SQLConn* conn)
00892 {
00893 for (ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
00894 {
00895 if (conn == iter->second)
00896 {
00897 delete iter->second;
00898 connections.erase(iter);
00899 break;
00900 }
00901 }
00902 retimer = new ReconnectTimer(ServerInstance, this);
00903 ServerInstance->Timers->AddTimer(retimer);
00904 }
00905
00906 virtual const char* OnRequest(Request* request)
00907 {
00908 if(strcmp(SQLREQID, request->GetId()) == 0)
00909 {
00910 SQLrequest* req = (SQLrequest*)request;
00911 ConnMap::iterator iter;
00912 if((iter = connections.find(req->dbid)) != connections.end())
00913 {
00914
00915 req->id = NewID();
00916 req->error = iter->second->Query(*req);
00917
00918 return (req->error.Id() == SQL_NO_ERROR) ? sqlsuccess : NULL;
00919 }
00920 else
00921 {
00922 req->error.Id(SQL_BAD_DBID);
00923 return NULL;
00924 }
00925 }
00926 return NULL;
00927 }
00928
00929 virtual void OnUnloadModule(Module* mod, const std::string& name)
00930 {
00931
00932
00933
00934
00935
00936
00937 for(ConnMap::iterator iter = connections.begin(); iter != connections.end(); iter++)
00938 {
00939 iter->second->OnUnloadModule(mod);
00940 }
00941 }
00942
00943 unsigned long NewID()
00944 {
00945 if (currid+1 == 0)
00946 currid++;
00947
00948 return ++currid;
00949 }
00950
00951 virtual Version GetVersion()
00952 {
00953 return Version("$Id: m_pgsql.cpp 10622 2008-10-04 21:27:52Z brain $", VF_VENDOR|VF_SERVICEPROVIDER, API_VERSION);
00954 }
00955 };
00956
00957
00958
00959
00960 void SQLresolver::OnLookupComplete(const std::string &result, unsigned int ttl, bool cached, int resultnum)
00961 {
00962 if (!resultnum)
00963 {
00964 host.ip = result;
00965 ((ModulePgSQL*)mod)->AddConn(host);
00966 ((ModulePgSQL*)mod)->ClearOldConnections();
00967 }
00968 }
00969
00970 void ReconnectTimer::Tick(time_t time)
00971 {
00972 ((ModulePgSQL*)mod)->ReadConf();
00973 }
00974
00975 void SQLConn::DelayReconnect()
00976 {
00977 ((ModulePgSQL*)us)->ReconnectConn(this);
00978 }
00979
/* Module entry point boilerplate. MODULE_INIT comes from the InspIRCd headers
 * and presumably expands to the exported factory that creates ModulePgSQL. */
MODULE_INIT(ModulePgSQL)