The InspIRCd Project
Home | Developers | Wiki | Forums | Bug Tracker | SVN | Download
Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members

m_sqlauth.cpp

Go to the documentation of this file.
00001 /*       +------------------------------------+
00002  *       | Inspire Internet Relay Chat Daemon |
00003  *       +------------------------------------+
00004  *
00005  *  InspIRCd: (C) 2002-2008 InspIRCd Development Team
00006  * See: http://www.inspircd.org/wiki/index.php/Credits
00007  *
00008  * This program is free but copyrighted software; see
00009  *            the file COPYING for details.
00010  *
00011  * ---------------------------------------------------
00012  */
00013 
00014 #include "inspircd.h"
00015 #include "m_sqlv2.h"
00016 #include "m_sqlutils.h"
00017 #include "m_hash.h"
00018 
00019 /* $ModDesc: Allow/Deny connections based upon an arbitary SQL table */
00020 /* $ModDep: m_sqlv2.h m_sqlutils.h m_hash.h */
00021 
00022 class ModuleSQLAuth : public Module
00023 {
00024         Module* SQLutils;
00025         Module* SQLprovider;
00026 
00027         std::string freeformquery;
00028         std::string killreason;
00029         std::string allowpattern;
00030         std::string databaseid;
00031 
00032         bool verbose;
00033 
00034 public:
00035         ModuleSQLAuth(InspIRCd* Me)
00036         : Module(Me)
00037         {
00038                 ServerInstance->Modules->UseInterface("SQLutils");
00039                 ServerInstance->Modules->UseInterface("SQL");
00040 
00041                 SQLutils = ServerInstance->Modules->Find("m_sqlutils.so");
00042                 if (!SQLutils)
00043                         throw ModuleException("Can't find m_sqlutils.so. Please load m_sqlutils.so before m_sqlauth.so.");
00044 
00045                 SQLprovider = ServerInstance->Modules->FindFeature("SQL");
00046                 if (!SQLprovider)
00047                         throw ModuleException("Can't find an SQL provider module. Please load one before attempting to load m_sqlauth.");
00048 
00049                 OnRehash(NULL,"");
00050                 Implementation eventlist[] = { I_OnUserDisconnect, I_OnCheckReady, I_OnRequest, I_OnRehash, I_OnUserRegister };
00051                 ServerInstance->Modules->Attach(eventlist, this, 5);
00052         }
00053 
00054         virtual ~ModuleSQLAuth()
00055         {
00056                 ServerInstance->Modules->DoneWithInterface("SQL");
00057                 ServerInstance->Modules->DoneWithInterface("SQLutils");
00058         }
00059 
00060 
00061         virtual void OnRehash(User* user, const std::string &parameter)
00062         {
00063                 ConfigReader Conf(ServerInstance);
00064 
00065                 databaseid      = Conf.ReadValue("sqlauth", "dbid", 0);                 /* Database ID, given to the SQL service provider */
00066                 freeformquery   = Conf.ReadValue("sqlauth", "query", 0);        /* Field name where username can be found */
00067                 killreason      = Conf.ReadValue("sqlauth", "killreason", 0);   /* Reason to give when access is denied to a user (put your reg details here) */
00068                 allowpattern    = Conf.ReadValue("sqlauth", "allowpattern",0 ); /* Allow nicks matching this pattern without requiring auth */
00069                 verbose         = Conf.ReadFlag("sqlauth", "verbose", 0);               /* Set to true if failed connects should be reported to operators */
00070         }
00071 
00072         virtual int OnUserRegister(User* user)
00073         {
00074                 if ((!allowpattern.empty()) && (InspIRCd::Match(user->nick,allowpattern)))
00075                 {
00076                         user->Extend("sqlauthed");
00077                         return 0;
00078                 }
00079 
00080                 if (!CheckCredentials(user))
00081                 {
00082                         ServerInstance->Users->QuitUser(user, killreason);
00083                         return 1;
00084                 }
00085                 return 0;
00086         }
00087 
00088         void SearchAndReplace(std::string& newline, const std::string &find, const std::string &replace)
00089         {
00090                 std::string::size_type x = newline.find(find);
00091                 while (x != std::string::npos)
00092                 {
00093                         newline.erase(x, find.length());
00094                         if (!replace.empty())
00095                                 newline.insert(x, replace);
00096                         x = newline.find(find);
00097                 }
00098         }
00099 
00100         bool CheckCredentials(User* user)
00101         {
00102                 std::string thisquery = freeformquery;
00103                 std::string safepass = user->password;
00104 
00105                 /* Search and replace the escaped nick and escaped pass into the query */
00106 
00107                 SearchAndReplace(safepass, "\"", "");
00108 
00109                 SearchAndReplace(thisquery, "$nick", user->nick);
00110                 SearchAndReplace(thisquery, "$pass", safepass);
00111                 SearchAndReplace(thisquery, "$host", user->host);
00112                 SearchAndReplace(thisquery, "$ip", user->GetIPString());
00113 
00114                 Module* HashMod = ServerInstance->Modules->Find("m_md5.so");
00115 
00116                 if (HashMod)
00117                 {
00118                         HashResetRequest(this, HashMod).Send();
00119                         SearchAndReplace(thisquery, "$md5pass", HashSumRequest(this, HashMod, user->password).Send());
00120                 }
00121 
00122                 HashMod = ServerInstance->Modules->Find("m_sha256.so");
00123 
00124                 if (HashMod)
00125                 {
00126                         HashResetRequest(this, HashMod).Send();
00127                         SearchAndReplace(thisquery, "$sha256pass", HashSumRequest(this, HashMod, user->password).Send());
00128                 }
00129 
00130                 /* Build the query */
00131                 SQLrequest req = SQLrequest(this, SQLprovider, databaseid, SQLquery(thisquery));
00132 
00133                 if(req.Send())
00134                 {
00135                         /* When we get the query response from the service provider we will be given an ID to play with,
00136                          * just an ID number which is unique to this query. We need a way of associating that ID with a User
00137                          * so we insert it into a map mapping the IDs to users.
00138                          * Thankfully m_sqlutils provides this, it will associate a ID with a user or channel, and if the user quits it removes the
00139                          * association. This means that if the user quits during a query we will just get a failed lookup from m_sqlutils - telling
00140                          * us to discard the query.
00141                          */
00142                         AssociateUser(this, SQLutils, req.id, user).Send();
00143 
00144                         return true;
00145                 }
00146                 else
00147                 {
00148                         if (verbose)
00149                                 ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host.c_str(), req.error.Str());
00150                         return false;
00151                 }
00152         }
00153 
00154         virtual const char* OnRequest(Request* request)
00155         {
00156                 if(strcmp(SQLRESID, request->GetId()) == 0)
00157                 {
00158                         SQLresult* res = static_cast<SQLresult*>(request);
00159 
00160                         User* user = GetAssocUser(this, SQLutils, res->id).S().user;
00161                         UnAssociate(this, SQLutils, res->id).S();
00162 
00163                         if(user)
00164                         {
00165                                 if(res->error.Id() == SQL_NO_ERROR)
00166                                 {
00167                                         if(res->Rows())
00168                                         {
00169                                                 /* We got a row in the result, this is enough really */
00170                                                 user->Extend("sqlauthed");
00171                                         }
00172                                         else if (verbose)
00173                                         {
00174                                                 /* No rows in result, this means there was no record matching the user */
00175                                                 ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query returned no matches)", user->nick.c_str(), user->ident.c_str(), user->host.c_str());
00176                                                 user->Extend("sqlauth_failed");
00177                                         }
00178                                 }
00179                                 else if (verbose)
00180                                 {
00181                                         ServerInstance->SNO->WriteToSnoMask('A', "Forbidden connection from %s!%s@%s (SQL query failed: %s)", user->nick.c_str(), user->ident.c_str(), user->host.c_str(), res->error.Str());
00182                                         user->Extend("sqlauth_failed");
00183                                 }
00184                         }
00185                         else
00186                         {
00187                                 return NULL;
00188                         }
00189 
00190                         if (!user->GetExt("sqlauthed"))
00191                         {
00192                                 ServerInstance->Users->QuitUser(user, killreason);
00193                         }
00194                         return SQLSUCCESS;
00195                 }
00196                 return NULL;
00197         }
00198 
00199         virtual void OnUserDisconnect(User* user)
00200         {
00201                 user->Shrink("sqlauthed");
00202                 user->Shrink("sqlauth_failed");
00203         }
00204 
00205         virtual bool OnCheckReady(User* user)
00206         {
00207                 return user->GetExt("sqlauthed");
00208         }
00209 
00210         virtual Version GetVersion()
00211         {
00212                 return Version("$Id: m_sqlauth.cpp 10622 2008-10-04 21:27:52Z brain $", VF_VENDOR, API_VERSION);
00213         }
00214 
00215 };
00216 
00217 MODULE_INIT(ModuleSQLAuth)