// Copyright The Mumble Developers. All rights reserved.
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file at the root of the
// Mumble source tree or at <https://www.mumble.info/LICENSE>.

#include "BanTable.h"
#include "ChronoUtils.h"
#include "ServerTable.h"

#include "database/AccessException.h"
#include "database/Backend.h"
#include "database/Column.h"
#include "database/Constraint.h"
#include "database/DataType.h"
#include "database/Database.h"
#include "database/ForeignKey.h"
#include "database/MigrationException.h"
#include "database/PrimaryKey.h"
#include "database/TransactionHolder.h"
#include "database/Utils.h"

#include <soci/soci.h>

#include <cassert>
#include <exception>

namespace mdb = ::mumble::db;

namespace mumble {
namespace server {
	namespace db {

		// Namespace-scope definitions of the static constexpr members of BanTable (table name and column names).
		// Before C++17, a static constexpr data member that is ODR-used (e.g. bound to a reference) needs such an
		// out-of-class definition; since C++17 these members are implicitly inline and the lines below are redundant
		// but harmless.
		constexpr const char *BanTable::NAME;
		constexpr const char *BanTable::column::server_id;
		constexpr const char *BanTable::column::base_address;
		constexpr const char *BanTable::column::prefix_length;
		constexpr const char *BanTable::column::user_name;
		constexpr const char *BanTable::column::cert_hash;
		constexpr const char *BanTable::column::reason;
		constexpr const char *BanTable::column::start_date;
		constexpr const char *BanTable::column::duration;


		BanTable::BanTable(soci::session &sql, ::mdb::Backend backend, const ServerTable &serverTable)
			: ::mdb::Table(sql, backend, NAME) {
			// Marks the given column as mandatory (NOT NULL)
			const auto requireValue = [](::mdb::Column &col) {
				col.addConstraint(::mdb::Constraint(::mdb::Constraint::NotNull));
			};

			::mdb::Column serverIdColumn(column::server_id, ::mdb::DataType(::mdb::DataType::Integer));
			requireValue(serverIdColumn);

			// The textual form of an IPv6 address is never longer than 45 characters
			// (see https://stackoverflow.com/a/166157)
			::mdb::Column baseAddressColumn(column::base_address, ::mdb::DataType(::mdb::DataType::VarChar, 45));
			requireValue(baseAddressColumn);

			::mdb::Column prefixLengthColumn(column::prefix_length, ::mdb::DataType(::mdb::DataType::Integer));
			requireValue(prefixLengthColumn);

			// Starting with Mumble 1.6 the certificate hash is part of the primary key, which lets admins ban an IP,
			// a hash or both. Hence the column must not be NULL; an empty string means "no hash".
			::mdb::Column certHashColumn(column::cert_hash, ::mdb::DataType(::mdb::DataType::VarChar, 255));
			requireValue(certHashColumn);

			// Optional details default to NULL
			::mdb::Column userNameColumn(column::user_name, ::mdb::DataType(::mdb::DataType::VarChar, 255));
			userNameColumn.setDefaultValue("NULL");

			::mdb::Column reasonColumn(column::reason, ::mdb::DataType(::mdb::DataType::Text));
			reasonColumn.setDefaultValue("NULL");

			// Mandatory timing information, defaulting to 0
			::mdb::Column startDateColumn(column::start_date, ::mdb::DataType(::mdb::DataType::EpochTime));
			requireValue(startDateColumn);
			startDateColumn.setDefaultValue("0");

			::mdb::Column durationColumn(column::duration, ::mdb::DataType(::mdb::DataType::Integer));
			requireValue(durationColumn);
			durationColumn.setDefaultValue("0");

			setColumns({ serverIdColumn, baseAddressColumn, prefixLengthColumn, certHashColumn, userNameColumn,
						 reasonColumn, startDateColumn, durationColumn });

			// A ban is identified by its server, its address range and the banned certificate hash
			::mdb::PrimaryKey primaryKey(std::vector< std::string >{ column::server_id, column::base_address,
																	  column::prefix_length, column::cert_hash });
			setPrimaryKey(primaryKey);

			::mdb::ForeignKey serverReference(serverTable, { serverIdColumn });
			addForeignKey(serverReference);
		}

		void BanTable::addBan(const DBBan &ban) {
			std::string baseAddress   = DBBan::ipv6ToString(ban.baseAddress.value_or(DBBan::INVALID_IP));
			std::uint8_t prefixLength = ban.prefixLength.value_or(0);
			std::string userCert      = ban.bannedUserCertHash.value_or("");
			try {
				auto startEpoch = static_cast< unsigned int >(toEpochSeconds(ban.startDate));
				auto duration   = static_cast< unsigned int >(ban.duration.count());
				std::string userName;
				std::string reason;
				soci::indicator nameInd   = soci::i_null;
				soci::indicator reasonInd = soci::i_null;

				if (ban.bannedUserName) {
					userName = ban.bannedUserName.value();
					nameInd  = soci::i_ok;
				}
				if (ban.reason) {
					reason    = ban.reason.value();
					reasonInd = soci::i_ok;
				}

				::mdb::TransactionHolder transaction = ensureTransaction();

				m_sql << "INSERT INTO \"" << NAME << "\" (\"" << column::server_id << "\", \"" << column::base_address
					  << "\", \"" << column::prefix_length << "\", \"" << column::cert_hash << "\", \""
					  << column::user_name << "\", \"" << column::reason << "\", \"" << column::start_date << "\", \""
					  << column::duration
					  << "\") VALUES (:serverID, LOWER(:baseAddr), :prefixLength, :userCert, :userName, :reason, "
						 ":startDate, :duration)",
					soci::use(ban.serverID), soci::use(baseAddress), soci::use(prefixLength), soci::use(userCert),
					soci::use(userName, nameInd), soci::use(reason, reasonInd), soci::use(startEpoch),
					soci::use(duration);

				transaction.commit();
			} catch (const soci::soci_error &) {
				std::throw_with_nested(::mdb::AccessException("Failed at adding new Ban for " + baseAddress + "/"
															  + std::to_string(prefixLength) + " Hash: '" + userCert
															  + "' on server with ID " + std::to_string(ban.serverID)));
			}
		}

		void BanTable::removeBan(const DBBan &ban) {
			// Translate unset key components into the sentinel values stored in the table, then delegate to the
			// overload that works on the raw column values.
			const std::string address  = DBBan::ipv6ToString(ban.baseAddress.value_or(DBBan::INVALID_IP));
			const std::uint8_t prefix  = ban.prefixLength.value_or(0);
			const std::string certHash = ban.bannedUserCertHash.value_or("");

			removeBan(ban.serverID, address, prefix, certHash);
		}

		void BanTable::removeBan(unsigned int serverID, const std::string &baseAddress, std::uint8_t prefixLength,
								 const std::string &bannedUserCertHash) {
			try {
				::mdb::TransactionHolder transaction = ensureTransaction();

				m_sql << "DELETE FROM \"" << NAME << "\" WHERE \"" << column::server_id << "\" = :serverID AND LOWER(\""
					  << column::base_address << "\") = LOWER(:baseAddress) AND \"" << column::prefix_length
					  << "\" = :prefixLength AND \"" << column::cert_hash << "\" = :userCert",
					soci::use(serverID), soci::use(baseAddress), soci::use(prefixLength), soci::use(bannedUserCertHash);

				transaction.commit();
			} catch (const soci::soci_error &) {
				std::throw_with_nested(::mdb::AccessException(
					"Failed at removing Ban for " + baseAddress + "/" + std::to_string(prefixLength) + " Hash: '"
					+ bannedUserCertHash + "' on server with ID " + std::to_string(serverID)));
			}
		}

		bool BanTable::banExists(const DBBan &ban) {
			// Map unset key components onto the sentinel values used in the table and delegate to the overload that
			// queries by raw column values.
			const std::string address  = DBBan::ipv6ToString(ban.baseAddress.value_or(DBBan::INVALID_IP));
			const std::uint8_t prefix  = ban.prefixLength.value_or(0);
			const std::string certHash = ban.bannedUserCertHash.value_or("");

			return banExists(ban.serverID, address, prefix, certHash);
		}

		bool BanTable::banExists(unsigned int serverID, const std::string &baseAddress, std::uint8_t prefixLength,
								 const std::string &bannedUserCertHash) {
			try {
				int exists = false;

				::mdb::TransactionHolder transaction = ensureTransaction();

				m_sql << "SELECT 1 FROM \"" << NAME << "\" WHERE \"" << column::server_id
					  << "\" = :serverID AND LOWER(\"" << column::base_address << "\") = LOWER(:baseAddress) AND \""
					  << column::prefix_length << "\" = :prefixLength AND \"" << column::cert_hash
					  << "\" = :userCert LIMIT 1",
					soci::use(serverID), soci::use(baseAddress), soci::use(prefixLength), soci::use(bannedUserCertHash),
					soci::into(exists);

				transaction.commit();

				return exists;
			} catch (const soci::soci_error &) {
				std::throw_with_nested(::mdb::AccessException(
					"Failed at checking whether Ban for " + baseAddress + "/" + std::to_string(prefixLength)
					+ " Hash: '" + bannedUserCertHash + "' exists on server with ID " + std::to_string(serverID)));
			}
		}

		DBBan BanTable::getBanDetails(unsigned int serverID, std::string baseAddress, std::uint8_t prefixLength,
									  std::optional< std::string > bannedUserCertHash) {
			// Convenience overload taking the address in textual form: parse it and delegate to the overload working
			// on the parsed representation. The by-value certificate hash is not used afterwards, so it is moved
			// into the delegate call instead of being copied a second time.
			return getBanDetails(serverID, DBBan::ipv6FromString(baseAddress), prefixLength,
								 std::move(bannedUserCertHash));
		}

		/**
		 * Looks up the ban identified by the given key and fills in its details (user name, reason, start date,
		 * duration).
		 *
		 * Unset key components are mapped onto the sentinel values used in the table (INVALID_IP, prefix length 0,
		 * empty certificate hash). The key components of the returned ban are the ones passed in.
		 *
		 * @throws ::mdb::AccessException (with the SOCI error nested) if the database access fails
		 */
		DBBan BanTable::getBanDetails(unsigned int serverID, std::optional< DBBan::ipv6_type > baseAddress,
									  std::optional< std::uint8_t > prefixLength,
									  std::optional< std::string > bannedUserCertHash) {
			DBBan ban;
			ban.serverID           = serverID;
			ban.baseAddress        = baseAddress;
			ban.prefixLength       = prefixLength;
			ban.bannedUserCertHash = bannedUserCertHash;

			std::string baseAddressQuery        = DBBan::ipv6ToString(ban.baseAddress.value_or(DBBan::INVALID_IP));
			std::uint8_t prefixLengthQuery      = ban.prefixLength.value_or(0);
			std::string bannedUserCertHashQuery = ban.bannedUserCertHash.value_or("");

			try {
				// start_date is read as a 64-bit value (as in getAllBans()) so that large epoch values are not
				// truncated. All outputs are initialized so that no indeterminate value can ever be read.
				long long startEpoch  = 0;
				unsigned int duration = 0;
				std::string userName;
				std::string reason;
				soci::indicator nameInd   = soci::i_null;
				soci::indicator reasonInd = soci::i_null;

				::mdb::TransactionHolder transaction = ensureTransaction();

				// The address is compared case-insensitively, exactly as in removeBan() and banExists(), so that this
				// lookup matches the same rows those functions match.
				m_sql << "SELECT \"" << column::user_name << "\", \"" << column::reason << "\", \""
					  << column::start_date << "\", \"" << column::duration << "\" FROM \"" << NAME << "\" WHERE \""
					  << column::server_id << "\" = :serverID AND LOWER(\"" << column::base_address
					  << "\") = LOWER(:baseAddress) AND \"" << column::prefix_length << "\" = :prefixLength AND \""
					  << column::cert_hash << "\" = :userCert",
					soci::into(userName, nameInd), soci::into(reason, reasonInd), soci::into(startEpoch),
					soci::into(duration), soci::use(serverID), soci::use(baseAddressQuery),
					soci::use(prefixLengthQuery), soci::use(bannedUserCertHashQuery);

				::mdb::utils::verifyQueryResultedInData(m_sql);

				if (nameInd == soci::i_ok) {
					ban.bannedUserName = std::move(userName);
				}
				if (reasonInd == soci::i_ok) {
					ban.reason = std::move(reason);
				}
				ban.startDate = std::chrono::time_point< std::chrono::system_clock >(std::chrono::seconds(startEpoch));
				ban.duration  = std::chrono::seconds(duration);

				transaction.commit();

				return ban;
			} catch (const soci::soci_error &) {
				std::throw_with_nested(::mdb::AccessException(
					"Failed at getting details for Ban of " + baseAddressQuery + "/" + std::to_string(prefixLengthQuery)
					+ " Hash: '" + bannedUserCertHashQuery + "' on server with ID " + std::to_string(serverID)));
			}
		}

		/**
		 * Returns all bans stored for the given server.
		 *
		 * Sentinel values stored for key components (INVALID_IP, prefix length 0, empty certificate hash) are mapped
		 * back to std::nullopt; NULL user names and reasons stay unset.
		 *
		 * @throws ::mdb::AccessException (with the SOCI error nested) if the database access fails
		 */
		std::vector< DBBan > BanTable::getAllBans(unsigned int serverID) {
			try {
				std::vector< DBBan > bans;
				soci::row row;

				::mdb::TransactionHolder transaction = ensureTransaction();

				// All identifiers, including server_id in the WHERE clause, are quoted like in every other query on
				// this table so that they are interpreted identically on all backends.
				soci::statement stmt =
					(m_sql.prepare << "SELECT \"" << column::base_address << "\", \"" << column::prefix_length
								   << "\", \"" << column::cert_hash << "\", \"" << column::user_name << "\", \""
								   << column::reason << "\", \"" << column::start_date << "\", \"" << column::duration
								   << "\" FROM \"" << NAME << "\" WHERE \"" << column::server_id << "\" = :serverID",
					 soci::use(serverID), soci::into(row));

				stmt.execute(false);

				while (stmt.fetch()) {
					assert(row.size() == 7);
					assert(row.get_properties(0).get_data_type() == soci::dt_string);
					assert(row.get_properties(1).get_data_type() == soci::dt_integer);
					assert(row.get_properties(2).get_data_type() == soci::dt_string);
					assert(row.get_properties(3).get_data_type() == soci::dt_string);
					assert(row.get_properties(4).get_data_type() == soci::dt_string);
					assert(row.get_properties(5).get_data_type() == soci::dt_long_long);
					assert(row.get_properties(6).get_data_type() == soci::dt_integer);

					DBBan ban;
					ban.serverID           = serverID;
					ban.baseAddress        = DBBan::ipv6FromString(row.get< std::string >(0));
					ban.prefixLength       = static_cast< std::uint8_t >(row.get< int >(1));
					ban.bannedUserCertHash = row.get< std::string >(2);
					// Map the sentinel values used for "not set" key components back to empty optionals
					if (ban.baseAddress.value() == DBBan::INVALID_IP) {
						ban.baseAddress = std::nullopt;
					}
					if (ban.prefixLength.value() == 0) {
						ban.prefixLength = std::nullopt;
					}
					if (ban.bannedUserCertHash.value().empty()) {
						ban.bannedUserCertHash = std::nullopt;
					}
					if (row.get_indicator(3) == soci::i_ok) {
						ban.bannedUserName = row.get< std::string >(3);
					}
					if (row.get_indicator(4) == soci::i_ok) {
						ban.reason = row.get< std::string >(4);
					}
					ban.startDate = std::chrono::time_point< std::chrono::system_clock >(
						std::chrono::seconds(row.get< long long >(5)));
					ban.duration = std::chrono::seconds(row.get< int >(6));

					bans.push_back(std::move(ban));
				}

				transaction.commit();

				return bans;
			} catch (const soci::soci_error &) {
				std::throw_with_nested(
					::mdb::AccessException("Failed at getting Bans on server with ID " + std::to_string(serverID)));
			}
		}

		void doClearBans(soci::session &sql, unsigned int serverID) {
			// Deletes every ban belonging to the given server. This helper does no transaction handling and no error
			// translation; its callers take care of both.
			const std::string query = std::string("DELETE FROM \"") + BanTable::NAME + "\" WHERE \""
									  + BanTable::column::server_id + "\" = :serverID";

			sql << query, soci::use(serverID);
		}

		void BanTable::clearBans(unsigned int serverID) {
			try {
				::mdb::TransactionHolder transaction = ensureTransaction();

				doClearBans(m_sql, serverID);

				transaction.commit();
			} catch (const soci::soci_error &) {
				std::throw_with_nested(
					::mdb::AccessException("Failed at clearing Bans on server with ID " + std::to_string(serverID)));
			}
		}

		/**
		 * Replaces all bans of the given server by the provided list within a single transaction.
		 *
		 * Every ban in the list is expected to belong to serverID (checked via assert in debug builds); the serverID
		 * parameter is what gets written for each row.
		 *
		 * @throws ::mdb::AccessException (with the SOCI error nested) if the database access fails
		 */
		void BanTable::setBans(unsigned int serverID, const std::vector< DBBan > &bans) {
			try {
				::mdb::TransactionHolder transaction = ensureTransaction();

				// Step 1: Clear old bans
				doClearBans(m_sql, serverID);

				// Step 2: Insert new bans. The statement is prepared once and re-bound for every ban.
				soci::statement stmt =
					m_sql.prepare
					<< "INSERT INTO \"" << NAME << "\" (\"" << column::server_id << "\", \"" << column::base_address
					<< "\", \"" << column::prefix_length << "\", \"" << column::cert_hash << "\", \""
					<< column::user_name << "\", \"" << column::reason << "\", \"" << column::start_date << "\", \""
					<< column::duration
					<< "\") VALUES (:serverID, LOWER(:baseAddr), :prefixLength, :userCert, :userName, :reason, "
					   ":startDate, :duration)";

				for (const DBBan &currentBan : bans) {
					assert(currentBan.serverID == serverID);

					// Unset key components are stored as sentinel values (the key columns are NOT NULL)
					std::string baseAddress   = DBBan::ipv6ToString(currentBan.baseAddress.value_or(DBBan::INVALID_IP));
					std::uint8_t prefixLength = currentBan.prefixLength.value_or(0);
					std::string userCert      = currentBan.bannedUserCertHash.value_or("");
					// Bind epoch seconds as a 64-bit value; a cast to unsigned int would silently wrap for dates
					// before 1970 or after 2106. getAllBans() reads this column back as long long as well.
					long long startEpoch = static_cast< long long >(toEpochSeconds(currentBan.startDate));
					auto duration        = static_cast< unsigned int >(currentBan.duration.count());
					std::string userName;
					std::string reason;
					soci::indicator nameInd   = soci::i_null;
					soci::indicator reasonInd = soci::i_null;

					if (currentBan.bannedUserName) {
						userName = currentBan.bannedUserName.value();
						nameInd  = soci::i_ok;
					}
					if (currentBan.reason) {
						reason    = currentBan.reason.value();
						reasonInd = soci::i_ok;
					}

					// The exchange order must match the placeholder order of the prepared statement
					stmt.exchange(soci::use(serverID));
					stmt.exchange(soci::use(baseAddress));
					stmt.exchange(soci::use(prefixLength));
					stmt.exchange(soci::use(userCert));
					stmt.exchange(soci::use(userName, nameInd));
					stmt.exchange(soci::use(reason, reasonInd));
					stmt.exchange(soci::use(startEpoch));
					stmt.exchange(soci::use(duration));

					stmt.define_and_bind();
					stmt.execute(true);
					// Release the bindings: they refer to locals of this loop iteration
					stmt.bind_clean_up();
				}

				transaction.commit();
			} catch (const soci::soci_error &) {
				std::throw_with_nested(
					::mdb::AccessException("Failed at setting Bans on server with ID " + std::to_string(serverID)));
			}
		}

		/**
		 * Migrates the ban data from the old table (named "bans" plus Database::OLD_TABLE_SUFFIX) into this table.
		 *
		 * @throws ::mdb::MigrationException if a database error occurs (SOCI error nested) or if the old data
		 * contains values that cannot be converted (invalid hex address, unparsable start date)
		 */
		void BanTable::migrate(unsigned int fromSchemaVersion, unsigned int toSchemaVersion) {
			// Note: Always hard-code table and column names in this function in order to ensure that this
			// migration path always stays the same regardless of whether the respective named constants change.
			assert(fromSchemaVersion <= toSchemaVersion);

			try {
				if (fromSchemaVersion < 10) {
					// Before v4, we stored IPv4 addresses in the DB and there only were the fields server_id, base (the
					// IPv4 address) and mask.
					// In v10 the following columns have been renamed:
					// "base" -> "ipv6_base_address" Also we switched from storing a binary representation of the IP
					//		address to storing a textual representation, instead.
					// "mask" -> "prefix_length"
					// "name" -> "banned_user_name"
					// "hash" -> "banned_user_cert_hash"
					// "start" -> "start_date" Also we changed its type from a native DATE format into using epoch
					//		seconds.
					soci::row row;

					std::string startConversion = mdb::utils::dateToEpoch("\"start\"", m_backend);
					std::string baseConversion;
					switch (m_backend) {
						case ::mdb::Backend::SQLite:
						case ::mdb::Backend::MySQL:
							baseConversion = "HEX(\"base\")";
							break;
						case ::mdb::Backend::PostgreSQL:
							baseConversion = "ENCODE(\"base\"::bytea, 'hex')";
							break;
					}
					assert(!baseConversion.empty());

					soci::statement selectStmt =
						(m_sql.prepare << "SELECT \"server_id\", " + baseConversion
											  + ", \"mask\", \"hash\", \"name\", \"reason\", " + startConversion
											  + ", \"duration\" FROM \"bans"
											  + std::string(::mdb::Database::OLD_TABLE_SUFFIX) + "\"",
						 soci::into(row));

					soci::statement insertStmt =
						m_sql.prepare
						<< "INSERT INTO \"" << NAME << "\" (\"" << column::server_id << "\", \"" << column::base_address
						<< "\", \"" << column::prefix_length << "\", \"" << column::cert_hash << "\", \""
						<< column::user_name << "\", \"" << column::reason << "\", \"" << column::start_date << "\", \""
						<< column::duration
						<< "\") VALUES (:serverID, :baseAddr, :prefixLength, :certHash, :userName, :reason, "
						   ":startDate, :duration)";

					selectStmt.execute(false);

					while (selectStmt.fetch()) {
						int serverID;
						std::string baseAddress;
						int prefixLength;
						// Since v11 a missing certificate hash is stored as the empty string instead of NULL
						std::string bannedCertHash = "";
						std::string bannedName;
						soci::indicator nameInd = soci::i_null;
						std::string reason;
						soci::indicator reasonInd = soci::i_null;
						long long startDate       = 0;
						int duration              = 0;

						assert(row.size() == 8);
						assert(row.get_properties(0).get_data_type() == soci::dt_integer);
						assert(row.get_indicator(0) == soci::i_ok);
						assert(row.get_properties(1).get_data_type() == soci::dt_string);
						assert(row.get_indicator(1) == soci::i_ok);
						assert(row.get_properties(2).get_data_type() == soci::dt_integer);
						assert(row.get_indicator(2) == soci::i_ok);

						serverID     = row.get< int >(0);
						baseAddress  = row.get< std::string >(1);
						prefixLength = row.get< int >(2);

						assert(row.get_properties(3).get_data_type() == soci::dt_string);
						assert(row.get_properties(4).get_data_type() == soci::dt_string);
						assert(row.get_properties(5).get_data_type() == soci::dt_string);
						// The actual datatype for the start time is long, but SOCI returns a string when using SQLite
						// due to issues with SQLite's type system (or rather the lack thereof)
						assert(row.get_properties(6).get_data_type() == soci::dt_long_long
							   || row.get_properties(6).get_data_type() == soci::dt_string);
						assert(row.get_indicator(6) == soci::i_ok);
						assert(row.get_properties(7).get_data_type() == soci::dt_integer);

						bool success          = false;
						DBBan::ipv6_type ipv6 = ::mdb::utils::hexToBinary< DBBan::ipv6_type >(baseAddress, &success);
						if (!success) {
							throw ::mdb::MigrationException("Encountered invalid hex representation of IPv6 address '"
															+ baseAddress + "' while migrating table \"" + NAME + "\"");
						}
						baseAddress = DBBan::ipv6ToString(ipv6);

						if (row.get_indicator(3) == soci::i_ok) {
							bannedCertHash = row.get< std::string >(3);
						}
						if (row.get_indicator(4) == soci::i_ok) {
							bannedName = row.get< std::string >(4);
							nameInd    = soci::i_ok;
						}
						if (row.get_indicator(5) == soci::i_ok) {
							reason    = row.get< std::string >(5);
							reasonInd = soci::i_ok;
						}

						if (row.get_properties(6).get_data_type() == soci::dt_long_long) {
							startDate = row.get< long long >(6);
						} else {
							// SQLite code path
							assert(row.get_properties(6).get_data_type() == soci::dt_string);
							std::string strStartDate = row.get< std::string >(6);
							try {
								startDate = std::stoll(strStartDate);
							} catch (const std::exception &) {
								// std::stoll throws std::invalid_argument / std::out_of_range for malformed input.
								// These are not soci errors and would otherwise escape without any context, so
								// report them as a migration failure (like invalid addresses above).
								std::throw_with_nested(::mdb::MigrationException(
									"Encountered invalid start date '" + strStartDate + "' while migrating table \""
									+ NAME + "\""));
							}
						}
						duration = row.get_indicator(7) != soci::i_null ? row.get< int >(7) : 0;

						// The exchange order must match the placeholder order of the prepared INSERT statement
						insertStmt.exchange(soci::use(serverID));
						insertStmt.exchange(soci::use(baseAddress));
						insertStmt.exchange(soci::use(prefixLength));
						insertStmt.exchange(soci::use(bannedCertHash));
						insertStmt.exchange(soci::use(bannedName, nameInd));
						insertStmt.exchange(soci::use(reason, reasonInd));
						insertStmt.exchange(soci::use(startDate));
						insertStmt.exchange(soci::use(duration));

						insertStmt.define_and_bind();
						insertStmt.execute(true);
						insertStmt.bind_clean_up();
					}
				} else if (fromSchemaVersion < 11) {
					// Before v11, a ban was allowed to have a NULL value for the certificate hash. v11 enables admins
					// to selectively ban IP or Hash or both. This requires to add the Hash to the primary key.
					// Since not all database backends allow to have NULL in a primary key or unique index, we have to
					// convert all NULL values for certificate hashes to an empty string.
					m_sql << "INSERT INTO \"" << NAME << "\" (\"" << column::server_id << "\", \""
						  << column::base_address << "\", \"" << column::prefix_length << "\", \"" << column::cert_hash
						  << "\", \"" << column::user_name << "\", \"" << column::reason << "\", \""
						  << column::start_date << "\", \"" << column::duration
						  << "\") SELECT \"server_id\", \"ipv6_base_address\", \"prefix_length\", "
						  << ::mdb::utils::nonNullOf("\"banned_user_cert_hash\"").otherwise("''")
						  << ", \"banned_user_name\", \"reason\", \"start_date\", \"duration\" FROM \"bans"
						  << ::mdb::Database::OLD_TABLE_SUFFIX << "\"";

				} else {
					// Use default implementation to handle migration without change of format
					mdb::Table::migrate(fromSchemaVersion, toSchemaVersion);
				}
			} catch (const soci::soci_error &) {
				std::throw_with_nested(::mdb::MigrationException(
					std::string("Failed at migrating table \"") + NAME + "\" from schema version "
					+ std::to_string(fromSchemaVersion) + " to " + std::to_string(toSchemaVersion)));
			}
		}

	} // namespace db
} // namespace server
} // namespace mumble
