using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using KfChatDotNetBot.Commands;
using KfChatDotNetBot.Models;
using KfChatDotNetBot.Models.DbModels;
using NLog;

namespace KfChatDotNetBot.Services;

/// <summary>
/// Per-command rate limit bookkeeping.
/// Each command has one "bucket" stored in <see cref="MemoryCache.Default"/> under the key
/// <c>RateLimitBucket:{commandName}</c>; the bucket is a JSON-serialized list of
/// <see cref="RateLimitBucketEntryModel"/>.
/// </summary>
/// <remarks>
/// Buckets are updated with a read-modify-write sequence and this class takes no locks of its own.
/// NOTE(review): confirm callers do not invoke these methods concurrently for the same command.
/// </remarks>
public static class RateLimitService
{
    /// <summary>Prefix of every cache key owned by this service.</summary>
    private const string BucketKeyPrefix = "RateLimitBucket:";

    private static readonly Logger _logger = LogManager.GetCurrentClassLogger();

    /// <summary>
    /// Check whether a user is rate limited for a given command
    /// </summary>
    /// <param name="user">User you wish to check</param>
    /// <param name="command">Command the user is invoking</param>
    /// <param name="message">Message the user sent</param>
    /// <returns>
    /// A model whose <c>IsRateLimited</c> is true when the number of unexpired matching entries has
    /// reached <c>MaxInvocations</c>. In that case <c>OldestEntryExpires</c> is the expiry of the
    /// matching entry that was created first (left unset if there are no matching entries at all).
    /// </returns>
    public static IsRateLimitedModel IsRateLimited(UserDbModel user, ICommand command, string message)
    {
        var result = new IsRateLimitedModel { IsRateLimited = false };
        var options = command.RateLimitOptions;
        // Commands without options are never rate limited
        if (options == null) return result;
        if (options.Flags.HasFlag(RateLimitFlags.ExemptPrivilegedUsers) && user.UserRight > UserRight.Guest) return result;

        var now = DateTimeOffset.UtcNow;
        // Build the filter lazily and materialize once instead of copying the list after every step
        var matching = GetBucketEntries(command.GetType().Name).Where(x => x.EntryExpires > now);
        if (!options.Flags.HasFlag(RateLimitFlags.Global))
        {
            // Non-global limits only count the invoking user's own entries
            matching = matching.Where(x => x.UserId == user.Id);
        }

        if (options.Flags.HasFlag(RateLimitFlags.UseEntireMessage))
        {
            // Only entries created by an identical message count towards the limit
            var hash = HashMessage(message);
            matching = matching.Where(x => x.MessageHash == hash);
        }

        var entries = matching.ToList();
        if (entries.Count < options.MaxInvocations) return result;

        result.IsRateLimited = true;
        // Guard against MaxInvocations <= 0 with an empty bucket, where First() would throw
        if (entries.Count > 0)
        {
            result.OldestEntryExpires = entries.OrderBy(x => x.EntryCreated).First().EntryExpires;
        }

        return result;
    }

    /// <summary>
    /// Get all the bucket entries for a given command
    /// </summary>
    /// <param name="commandName">String representation of the command.
    /// Get it by running command.GetType().Name</param>
    /// <returns>
    /// A list of entries. The list is empty when nothing is cached for the command, or when the
    /// cached value could not be deserialized (the failure is logged, not thrown).
    /// </returns>
    public static List<RateLimitBucketEntryModel> GetBucketEntries(string commandName)
    {
        var entries = MemoryCache.Default.Get(GetCacheKey(commandName));
        if (entries == null) return [];
        try
        {
            // A non-string cache value or a JSON literal "null" both end up in the catch below
            return JsonSerializer.Deserialize<List<RateLimitBucketEntryModel>>((string)entries) ??
                   throw new InvalidOperationException("Deserialized rate limit bucket was null");
        }
        catch (Exception e)
        {
            _logger.Error($"Caught an exception when trying to deserialize RateLimitBucket entries for {commandName}. JSON follows");
            _logger.Error(entries);
            _logger.Error("Exception follows");
            _logger.Error(e);
            return [];
        }
    }

    /// <summary>
    /// Save the current state of bucket entries for a given command
    /// </summary>
    /// <param name="commandName">String representation of the command.
    /// Get it by running command.GetType().Name</param>
    /// <param name="entries">Entries you wish to save</param>
    public static void SaveBucketEntries(string commandName, List<RateLimitBucketEntryModel> entries)
    {
        // Every save replaces the bucket and restarts its one day absolute expiration
        var policy = new CacheItemPolicy { AbsoluteExpiration = DateTimeOffset.UtcNow.AddDays(1) };
        MemoryCache.Default.Set(GetCacheKey(commandName), JsonSerializer.Serialize(entries), policy);
    }

    /// <summary>
    /// Remove the most recent entry for a given user and command
    /// Use this if you want to invalidate an entry as forgiveness for invalid user input
    /// </summary>
    /// <param name="user">User to remove the entry for</param>
    /// <param name="command">Command the user ran</param>
    public static void RemoveMostRecentEntry(UserDbModel user, ICommand command)
    {
        var commandName = command.GetType().Name;
        var entries = GetBucketEntries(commandName);
        var lastEntry = entries.Where(x => x.UserId == user.Id).OrderBy(x => x.EntryCreated).LastOrDefault();
        // Nothing recorded for this user, so there is nothing to forgive
        if (lastEntry == null) return;
        entries.Remove(lastEntry);
        SaveBucketEntries(commandName, entries);
    }

    /// <summary>
    /// Add an entry to the rate limit bucket for the given command
    /// </summary>
    /// <param name="user">User the entry belongs to</param>
    /// <param name="command">Command the user ran</param>
    /// <param name="message">The user's message</param>
    public static void AddEntry(UserDbModel user, ICommand command, string message)
    {
        // Commands without rate limit options are not tracked
        if (command.RateLimitOptions == null) return;
        var commandName = command.GetType().Name;
        // Read the clock once so EntryExpires is exactly EntryCreated + Window
        var now = DateTimeOffset.UtcNow;
        var entries = GetBucketEntries(commandName);
        entries.Add(new RateLimitBucketEntryModel
        {
            UserId = user.Id,
            EntryCreated = now,
            EntryExpires = now + command.RateLimitOptions.Window,
            CommandInvoked = commandName,
            MessageHash = HashMessage(message)
        });
        SaveBucketEntries(commandName, entries);
    }

    /// <summary>
    /// Removes entries which have expired for all commands in the rate limit bucket
    /// </summary>
    public static void CleanupExpiredEntries()
    {
        var cache = MemoryCache.Default;
        var now = DateTimeOffset.UtcNow;
        // Snapshot the keys first: the loop below writes to the cache it would otherwise be enumerating
        var bucketKeys = cache.Select(kvp => kvp.Key)
            .Where(key => key.StartsWith(BucketKeyPrefix, StringComparison.Ordinal))
            .ToList();
        foreach (var entry in bucketKeys)
        {
            _logger.Info($"Cleaning up expired entries for {entry}");
            // Strip only the leading prefix to recover the command name
            var commandName = entry[BucketKeyPrefix.Length..];
            var unexpired = GetBucketEntries(commandName).Where(x => x.EntryExpires > now).ToList();
            if (unexpired.Count == 0)
            {
                // Drop empty buckets instead of re-saving them with a fresh expiration on every run;
                // GetBucketEntries returns an empty list for a missing key
                cache.Remove(entry);
                continue;
            }

            SaveBucketEntries(commandName, unexpired);
        }
    }

    /// <summary>Builds the cache key a command's bucket is stored under.</summary>
    private static string GetCacheKey(string commandName) => $"{BucketKeyPrefix}{commandName}";

    /// <summary>Uppercase hex SHA-256 of the UTF-8 encoded message, as stored in <c>MessageHash</c>.</summary>
    private static string HashMessage(string message) =>
        Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(message)));
}