mirror of
https://github.com/barelyprofessional/KfChatDotNet.git
synced 2026-05-02 04:22:04 -04:00
Completely untested and totally experimental rate limit feature
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
using System.Text.RegularExpressions;
|
||||
using Humanizer;
|
||||
using Humanizer.Localisation;
|
||||
using KfChatDotNetBot.Commands;
|
||||
using KfChatDotNetBot.Extensions;
|
||||
using KfChatDotNetBot.Models;
|
||||
using KfChatDotNetBot.Models.DbModels;
|
||||
using KfChatDotNetBot.Settings;
|
||||
using KfChatDotNetWsClient.Models.Events;
|
||||
@@ -34,6 +36,8 @@ internal class BotCommands
|
||||
{
|
||||
_logger.Debug($"Found command {command.GetType().Name}");
|
||||
}
|
||||
|
||||
_ = CleanupExpiredRateLimitEntriesTask();
|
||||
}
|
||||
|
||||
internal void ProcessMessage(MessageModel message)
|
||||
@@ -89,12 +93,26 @@ internal class BotCommands
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (user.UserRight < command.RequiredRight)
|
||||
{
|
||||
_bot.SendChatMessage($"@{message.Author.Username}, you do not have access to use this command. Your rank: {user.UserRight.Humanize()}; Required rank: {command.RequiredRight.Humanize()}", true);
|
||||
if (continueAfterProcess) continue;
|
||||
break;
|
||||
}
|
||||
|
||||
if (command.RateLimitOptions != null)
|
||||
{
|
||||
var isRateLimited = RateLimitService.IsRateLimited(user, command, message.MessageRawHtmlDecoded);
|
||||
if (isRateLimited.IsRateLimited)
|
||||
{
|
||||
_ = SendCooldownResponse(user, command.RateLimitOptions, isRateLimited.OldestEntryExpires!.Value, command.GetType().Name);
|
||||
}
|
||||
else
|
||||
{
|
||||
RateLimitService.AddEntry(user, command, message.MessageRawHtmlDecoded);
|
||||
}
|
||||
}
|
||||
_ = ProcessMessageAsync(command, message, user, match.Groups);
|
||||
if (!continueAfterProcess) break;
|
||||
}
|
||||
@@ -136,6 +154,48 @@ internal class BotCommands
|
||||
$"🤑🤑 {user.KfUsername} has leveled up to to {newLevel.VipLevel.Icon} {newLevel.VipLevel.Name} Tier {newLevel.Tier} " +
|
||||
$"and received a bonus of {await payout.FormatKasinoCurrencyAsync()}", true);
|
||||
}
|
||||
|
||||
private async Task SendCooldownResponse(UserDbModel user, RateLimitOptionsModel options, DateTimeOffset oldestEntryExpires, string commandName)
|
||||
{
|
||||
if (options.Flags.HasFlag(RateLimitFlags.NoResponse)) return;
|
||||
var timeRemaining = oldestEntryExpires - DateTimeOffset.UtcNow;
|
||||
var message = await _bot.SendChatMessageAsync($"{user.FormatUsername()}, please wait {timeRemaining.Humanize(maxUnit: TimeUnit.Minute, minUnit: TimeUnit.Millisecond, precision: 2)} before attempting to run {commandName} again.", true);
|
||||
if (!options.Flags.HasFlag(RateLimitFlags.AutoDeleteCooldownResponse)) return;
|
||||
var i = 0;
|
||||
while (message.ChatMessageId == null)
|
||||
{
|
||||
i++;
|
||||
await Task.Delay(250, _cancellationToken);
|
||||
if (i > 30)
|
||||
{
|
||||
_logger.Error("Gave up waiting for Sneedchat to give us the message ID for removing a cooldown notification");
|
||||
return;
|
||||
}
|
||||
|
||||
if (message.Status is SentMessageTrackerStatus.NotSending or SentMessageTrackerStatus.Lost)
|
||||
{
|
||||
_logger.Error("Cooldown message was suppressed or lost");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
var autoDeleteInterval =
|
||||
(await SettingsProvider.GetValueAsync(BuiltIn.Keys.BotRateLimitCooldownAutoDeleteDelay)).ToType<int>();
|
||||
await Task.Delay(autoDeleteInterval, _cancellationToken);
|
||||
await _bot.KfClient.DeleteMessageAsync(message.ChatMessageId.Value);
|
||||
}
|
||||
|
||||
private async Task CleanupExpiredRateLimitEntriesTask()
|
||||
{
|
||||
while (!_cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
var interval = (await SettingsProvider.GetValueAsync(BuiltIn.Keys.BotRateLimitExpiredEntryCleanupInterval))
|
||||
.ToType<int>();
|
||||
await Task.Delay(TimeSpan.FromSeconds(interval), _cancellationToken);
|
||||
_logger.Info("Cleaning up expired rate limit entries");
|
||||
RateLimitService.CleanupExpiredEntries();
|
||||
}
|
||||
}
|
||||
|
||||
private static bool HasAttribute<T>(ICommand command) where T : Attribute
|
||||
{
|
||||
|
||||
149
KfChatDotNetBot/Services/RateLimitService.cs
Normal file
149
KfChatDotNetBot/Services/RateLimitService.cs
Normal file
@@ -0,0 +1,149 @@
|
||||
using System.Runtime.Caching;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using KfChatDotNetBot.Commands;
|
||||
using KfChatDotNetBot.Models;
|
||||
using KfChatDotNetBot.Models.DbModels;
|
||||
using NLog;
|
||||
|
||||
namespace KfChatDotNetBot.Services;
|
||||
|
||||
public static class RateLimitService
|
||||
{
|
||||
private static Logger _logger = LogManager.GetCurrentClassLogger();
|
||||
|
||||
/// <summary>
|
||||
/// Check whether a user is rate limited for a given command
|
||||
/// </summary>
|
||||
/// <param name="user">User you wish to check</param>
|
||||
/// <param name="command">Command the user is invoking</param>
|
||||
/// <param name="message">Message the user sent</param>
|
||||
/// <returns></returns>
|
||||
public static IsRateLimitedModel IsRateLimited(UserDbModel user, ICommand command, string message)
|
||||
{
|
||||
var result = new IsRateLimitedModel
|
||||
{
|
||||
IsRateLimited = false
|
||||
};
|
||||
if (command.RateLimitOptions == null) return result;
|
||||
if (command.RateLimitOptions.Flags.HasFlag(RateLimitFlags.ExemptPrivilegedUsers) &&
|
||||
user.UserRight > UserRight.Guest) return result;
|
||||
var entries = GetBucketEntries(command.GetType().Name);
|
||||
if (!command.RateLimitOptions.Flags.HasFlag(RateLimitFlags.Global))
|
||||
{
|
||||
entries = entries.Where(x => x.UserId == user.Id).ToList();
|
||||
}
|
||||
|
||||
if (command.RateLimitOptions.Flags.HasFlag(RateLimitFlags.UseEntireMessage))
|
||||
{
|
||||
var hash = Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(message)));
|
||||
entries = entries.Where(x => x.MessageHash == hash).ToList();
|
||||
}
|
||||
|
||||
var now = DateTimeOffset.UtcNow;
|
||||
entries = entries.Where(x => x.EntryExpires > now).ToList();
|
||||
if (entries.Count >= command.RateLimitOptions.MaxInvocations)
|
||||
{
|
||||
result.IsRateLimited = true;
|
||||
result.OldestEntryExpires = entries.OrderBy(x => x.EntryCreated).Last().EntryExpires;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get all the bucket entries for a given command
|
||||
/// </summary>
|
||||
/// <param name="commandName">String representation of the command.
|
||||
/// Get it by running command.GetType().Name</param>
|
||||
/// <returns>A list of entries</returns>
|
||||
/// <exception cref="InvalidOperationException">Thrown if the cached entries were somehow null when converted to a string</exception>
|
||||
public static List<RateLimitBucketEntryModel> GetBucketEntries(string commandName)
|
||||
{
|
||||
var cache = MemoryCache.Default;
|
||||
var entries = cache.Get($"RateLimitBucket:{commandName}");
|
||||
if (entries == null) return [];
|
||||
List<RateLimitBucketEntryModel> bucketEntries;
|
||||
try
|
||||
{
|
||||
bucketEntries = JsonSerializer.Deserialize<List<RateLimitBucketEntryModel>>((string)entries) ??
|
||||
throw new InvalidOperationException();
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
_logger.Error($"Caught an exception when trying to deserialize RateLimitBucket entries for {commandName}. JSON follows");
|
||||
_logger.Error(entries);
|
||||
_logger.Error("Exception follows");
|
||||
_logger.Error(e);
|
||||
return [];
|
||||
}
|
||||
|
||||
return bucketEntries;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Save the current state of bucket entries for a given command
|
||||
/// </summary>
|
||||
/// <param name="commandName">String representation of the command.
|
||||
/// Get it by running command.GetType().Name</param>
|
||||
/// <param name="entries">Entries you wish to save</param>
|
||||
public static void SaveBucketEntries(string commandName, List<RateLimitBucketEntryModel> entries)
|
||||
{
|
||||
var cache = MemoryCache.Default;
|
||||
cache.Set($"RateLimitBucket:{commandName}", JsonSerializer.Serialize(entries),
|
||||
new CacheItemPolicy { AbsoluteExpiration = DateTimeOffset.UtcNow.AddDays(1) });
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Remove the most recent entry for a given user and command
|
||||
/// Use this if you want to invalidate an entry as forgiveness for invalid user input
|
||||
/// </summary>
|
||||
/// <param name="user">User to remove the entry for</param>
|
||||
/// <param name="command">Command the user ran</param>
|
||||
public static void RemoveMostRecentEntry(UserDbModel user, ICommand command)
|
||||
{
|
||||
var entries = GetBucketEntries(command.GetType().Name);
|
||||
var lastEntry = entries.Where(x => x.UserId == user.Id).OrderBy(x => x.EntryCreated).LastOrDefault();
|
||||
if (lastEntry == null) return;
|
||||
entries.Remove(lastEntry);
|
||||
SaveBucketEntries(command.GetType().Name, entries);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Add an entry to the rate limit bucket for the given command
|
||||
/// </summary>
|
||||
/// <param name="user">User the entry belongs to</param>
|
||||
/// <param name="command">Command the user ran</param>
|
||||
/// <param name="message">The user's message</param>
|
||||
public static void AddEntry(UserDbModel user, ICommand command, string message)
|
||||
{
|
||||
if (command.RateLimitOptions == null) return;
|
||||
var commandName = command.GetType().Name;
|
||||
var entries = GetBucketEntries(commandName);
|
||||
entries.Add(new RateLimitBucketEntryModel
|
||||
{
|
||||
UserId = user.Id,
|
||||
EntryCreated = DateTimeOffset.UtcNow,
|
||||
EntryExpires = DateTimeOffset.UtcNow + command.RateLimitOptions.Window,
|
||||
CommandInvoked = commandName,
|
||||
MessageHash = Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(message)))
|
||||
});
|
||||
SaveBucketEntries(commandName, entries);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Removes entries which have expired for all commands in the rate limit bucket
|
||||
/// </summary>
|
||||
public static void CleanupExpiredEntries()
|
||||
{
|
||||
var cache = MemoryCache.Default;
|
||||
var now = DateTimeOffset.UtcNow;
|
||||
foreach (var entry in cache.Select(kvp => kvp.Key).Where(kvp => kvp.StartsWith("RateLimitBucket:")).ToList().OfType<string>())
|
||||
{
|
||||
_logger.Info($"Cleaning up expired entries for {entry}");
|
||||
var commandName = entry.Replace("RateLimitBucket:", string.Empty);
|
||||
var entries = GetBucketEntries(commandName);
|
||||
SaveBucketEntries(commandName, entries.Where(x => x.EntryExpires > now).ToList());
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user