using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using KfChatDotNetBot.Commands;
using KfChatDotNetBot.Models;
using KfChatDotNetBot.Models.DbModels;
using NLog;
namespace KfChatDotNetBot.Services;
public static class RateLimitService
{
    private static readonly Logger _logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Prefix prepended to a command name to form the MemoryCache key of that command's rate limit bucket
    /// </summary>
    private const string BucketKeyPrefix = "RateLimitBucket:";

    // NOTE(review): every mutating method below does GetBucketEntries -> modify -> SaveBucketEntries with no
    // synchronization visible here, so concurrent calls for the same command can overwrite each other's changes.
    // Confirm whether callers serialize access before relying on exact counts.

    /// <summary>
    /// Check whether a user is rate limited for a given command
    /// </summary>
    /// <param name="user">User you wish to check</param>
    /// <param name="command">Command the user is invoking</param>
    /// <param name="message">Message the user sent</param>
    /// <returns>
    /// A model with IsRateLimited = false if the command has no rate limit options, the user is exempt, or the
    /// number of unexpired matching entries is below MaxInvocations. Otherwise IsRateLimited = true and, when at
    /// least one matching entry exists, OldestEntryExpires is the expiry of the earliest-created matching entry.
    /// Entries are matched per user unless the Global flag is set, and additionally by message hash when the
    /// UseEntireMessage flag is set.
    /// </returns>
    public static IsRateLimitedModel IsRateLimited(UserDbModel user, ICommand command, string message)
    {
        var result = new IsRateLimitedModel
        {
            IsRateLimited = false
        };
        var options = command.RateLimitOptions;
        if (options == null) return result;
        if (options.Flags.HasFlag(RateLimitFlags.ExemptPrivilegedUsers) && user.UserRight > UserRight.Guest)
        {
            return result;
        }

        var now = DateTimeOffset.UtcNow;
        IEnumerable<RateLimitBucketEntryModel> matching = GetBucketEntries(command.GetType().Name)
            .Where(x => x.EntryExpires > now);
        if (!options.Flags.HasFlag(RateLimitFlags.Global))
        {
            matching = matching.Where(x => x.UserId == user.Id);
        }
        if (options.Flags.HasFlag(RateLimitFlags.UseEntireMessage))
        {
            var hash = HashMessage(message);
            matching = matching.Where(x => x.MessageHash == hash);
        }

        var entries = matching.ToList();
        if (entries.Count < options.MaxInvocations) return result;

        result.IsRateLimited = true;
        // With MaxInvocations <= 0 the limit trips even when there are no entries, so the list may be empty here;
        // MinBy returns null instead of throwing like OrderBy(...).First() did
        var oldest = entries.MinBy(x => x.EntryCreated);
        if (oldest is not null)
        {
            result.OldestEntryExpires = oldest.EntryExpires;
        }
        return result;
    }

    /// <summary>
    /// Get all the bucket entries for a given command
    /// </summary>
    /// <param name="commandName">String representation of the command.
    /// Get it by running command.GetType().Name</param>
    /// <returns>
    /// A list of entries. An empty list is returned when no bucket is cached for the command, or when the cached
    /// value could not be read or deserialized (that failure is logged, not thrown).
    /// </returns>
    public static List<RateLimitBucketEntryModel> GetBucketEntries(string commandName)
    {
        var cached = MemoryCache.Default.Get(BucketKey(commandName));
        if (cached == null) return [];
        try
        {
            // A JSON "null" deserializes to null; treat it as corrupt data like any other failure
            return JsonSerializer.Deserialize<List<RateLimitBucketEntryModel>>((string)cached) ??
                   throw new InvalidOperationException();
        }
        catch (Exception e)
        {
            // Best-effort: a corrupt bucket should not break command handling, so log and start from empty
            _logger.Error($"Caught an exception when trying to deserialize RateLimitBucket entries for {commandName}. JSON follows");
            _logger.Error(cached);
            _logger.Error("Exception follows");
            _logger.Error(e);
            return [];
        }
    }

    /// <summary>
    /// Save the current state of bucket entries for a given command
    /// </summary>
    /// <param name="commandName">String representation of the command.
    /// Get it by running command.GetType().Name</param>
    /// <param name="entries">Entries you wish to save. They are stored as JSON and the cached bucket expires
    /// one day after this call.</param>
    public static void SaveBucketEntries(string commandName, List<RateLimitBucketEntryModel> entries)
    {
        MemoryCache.Default.Set(BucketKey(commandName), JsonSerializer.Serialize(entries),
            new CacheItemPolicy { AbsoluteExpiration = DateTimeOffset.UtcNow.AddDays(1) });
    }

    /// <summary>
    /// Remove the most recent entry for a given user and command.
    /// Use this if you want to invalidate an entry as forgiveness for invalid user input.
    /// Does nothing if the user has no entries for the command.
    /// </summary>
    /// <param name="user">User to remove the entry for</param>
    /// <param name="command">Command the user ran</param>
    public static void RemoveMostRecentEntry(UserDbModel user, ICommand command)
    {
        var commandName = command.GetType().Name;
        var entries = GetBucketEntries(commandName);
        // Stable sort + LastOrDefault: among entries created at the same instant, the last one stored wins
        var lastEntry = entries.Where(x => x.UserId == user.Id).OrderBy(x => x.EntryCreated).LastOrDefault();
        if (lastEntry == null) return;
        entries.Remove(lastEntry);
        SaveBucketEntries(commandName, entries);
    }

    /// <summary>
    /// Add an entry to the rate limit bucket for the given command.
    /// Does nothing if the command has no rate limit options.
    /// </summary>
    /// <param name="user">User the entry belongs to</param>
    /// <param name="command">Command the user ran</param>
    /// <param name="message">The user's message (only its SHA-256 hash is stored)</param>
    public static void AddEntry(UserDbModel user, ICommand command, string message)
    {
        if (command.RateLimitOptions == null) return;
        var commandName = command.GetType().Name;
        var entries = GetBucketEntries(commandName);
        // Read the clock once so EntryExpires - EntryCreated is exactly the configured window
        var now = DateTimeOffset.UtcNow;
        entries.Add(new RateLimitBucketEntryModel
        {
            UserId = user.Id,
            EntryCreated = now,
            EntryExpires = now + command.RateLimitOptions.Window,
            CommandInvoked = commandName,
            MessageHash = HashMessage(message)
        });
        SaveBucketEntries(commandName, entries);
    }

    /// <summary>
    /// Removes entries which have expired for all commands in the rate limit bucket
    /// </summary>
    public static void CleanupExpiredEntries()
    {
        var now = DateTimeOffset.UtcNow;
        // Snapshot the matching keys first because SaveBucketEntries writes to the cache inside the loop
        var bucketKeys = MemoryCache.Default
            .Select(kvp => kvp.Key)
            .Where(key => key.StartsWith(BucketKeyPrefix, StringComparison.Ordinal))
            .ToList();
        foreach (var key in bucketKeys)
        {
            _logger.Info($"Cleaning up expired entries for {key}");
            var commandName = key[BucketKeyPrefix.Length..];
            var entries = GetBucketEntries(commandName);
            SaveBucketEntries(commandName, entries.Where(x => x.EntryExpires > now).ToList());
        }
    }

    /// <summary>
    /// Build the MemoryCache key under which a command's bucket is stored
    /// </summary>
    private static string BucketKey(string commandName) => $"{BucketKeyPrefix}{commandName}";

    /// <summary>
    /// Hash a message for storage/comparison: uppercase hex SHA-256 of its UTF-8 bytes
    /// </summary>
    private static string HashMessage(string message) =>
        Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(message)));
}