Browse Source

Automatically remove roles on mute

ghorsington 3 years ago
parent
commit
6dec6bc60e
2 changed files with 47 additions and 12 deletions
  1. 44 12
      bot/src/plugins/violation.ts
  2. 3 0
      shared/src/db/entity/Violation.ts

+ 44 - 12
bot/src/plugins/violation.ts

@@ -26,12 +26,14 @@ interface ViolationInfo {
 
 type TimedViolation = Violation & { endsAt: Date };
 
-type StartViolationFunction = (member: GuildMember | PartialGuildMember, settings: GuildViolationSettings) => Promise<void>;
-type StopViolationFunction = (guild: Guild, userId: string, settings: GuildViolationSettings) => Promise<void>;
+type ModifyViolationFunction = (member: GuildMember | PartialGuildMember, settings: GuildViolationSettings, violation: DeepPartial<TimedViolation>) => DeepPartial<TimedViolation>; 
+type StartViolationFunction = (member: GuildMember | PartialGuildMember, settings: GuildViolationSettings, violation: TimedViolation) => Promise<void>;
+type StopViolationFunction = (guild: Guild, userId: string, settings: GuildViolationSettings, violation: TimedViolation) => Promise<void>;
 interface TimedViolationStopHandler {
     type: ObjectType<TimedViolation>;
     start: StartViolationFunction;
     stop: StopViolationFunction;
+    modify?: ModifyViolationFunction;
     command: string;
 }
 
@@ -42,7 +44,7 @@ export class ViolationPlugin {
         {
             command: "mute",
             type: Mute,
-            start: async (member: GuildMember | PartialGuildMember, settings: GuildViolationSettings): Promise<void> => {
+            start: async (member: GuildMember | PartialGuildMember, settings: GuildViolationSettings, violation: TimedViolation): Promise<void> => {
                 const muteRoleResolve = await tryDo(member.guild.roles.fetch(settings.muteRoleId));
                 if (!muteRoleResolve.ok || !muteRoleResolve.result) {
                     logger.error(
@@ -53,9 +55,20 @@ export class ViolationPlugin {
                         settings.muteRoleId);
                     return;
                 }
+                // First mute, then remove other roles
                 await member.roles.add(muteRoleResolve.result);
+                const mute = violation as Mute;
+                if (mute.previousRoles) {
+                    const result = await tryDo(member.roles.remove(mute.previousRoles));
+                    if (!result.ok) {
+                        logger.error("mute: couldn't remove all roles from user %s#%s (%s)!" ,
+                            member.user?.username,
+                            member.user?.discriminator,
+                            member.user?.id);
+                    }
+                }
             },
-            stop: async (guild: Guild, userId: string, settings: GuildViolationSettings): Promise<void> => {
+            stop: async (guild: Guild, userId: string, settings: GuildViolationSettings, violation: TimedViolation): Promise<void> => {
                 const muteRoleResolve = await tryDo(guild.roles.fetch(settings.muteRoleId));
 
                 if (!muteRoleResolve.ok || !muteRoleResolve.result) {
@@ -72,6 +85,18 @@ export class ViolationPlugin {
                 }
 
                 await memberResolve.result.roles.remove(muteRole);
+                const mute = violation as Mute;
+                if (mute.previousRoles) {
+                    const result = await tryDo(memberResolve.result.roles.add(mute.previousRoles));
+                    if (!result.ok) {
+                        logger.warn("mute: couldn't readd all roles for user %s (tried to restore role ids: %s)", memberResolve.result.id, mute.previousRoles.join(", "));
+                    }
+                }
+            },
+            modify: (member: GuildMember | PartialGuildMember, settings: GuildViolationSettings, violation: DeepPartial<Mute>): DeepPartial<Mute> => {
+                const originalRoles = member.roles.cache.keyArray().filter(r => r != settings.muteRoleId);
+                violation.previousRoles = originalRoles;
+                return violation;
             }
         }
     ];
@@ -130,7 +155,7 @@ export class ViolationPlugin {
                 if (violation.endsAt < new Date())
                     await repo.update({ id: violation.id }, { valid: false });
                 else
-                    await handler.start(member, settings);
+                    await handler.start(member, settings, violation);
             }
         }
     }
@@ -159,7 +184,7 @@ export class ViolationPlugin {
         if (!info.dryRun) {
             eventLogger.warn("User %s#%s muted user %s#%s for %s because: %s", message.author.username, message.author.discriminator, info.member.user.username, info.member.user.discriminator, info.duration, info.reason);
         }
-        await this.applyTimedViolation(Mute, info, "mute", handler.start, handler.stop);
+        await this.applyTimedViolation(Mute, info, "mute", handler.start, handler.stop, handler.modify);
         await this.sendViolationMessage(message, info, "User has been muted for server violation");
     }
 
@@ -232,11 +257,11 @@ export class ViolationPlugin {
         delete this.jobs[existingViolation.id];
 
         const handler = this.getViolationHandler(type);
-        await handler.stop(message.guild, user.id, settings);
+        await handler.stop(message.guild, user.id, settings, existingViolation);
         await message.reply(`removed ${command} on user!`);
     }
 
-    private async applyTimedViolation<T extends TimedViolation>(type: ObjectType<T>, info: ViolationInfo, command = "violation", apply: StartViolationFunction, remove: StopViolationFunction) {
+    private async applyTimedViolation<T extends TimedViolation>(type: ObjectType<T>, info: ViolationInfo, command = "violation", apply: StartViolationFunction, remove: StopViolationFunction, modify?: ModifyViolationFunction) {
         if (info.dryRun)
             return;
 
@@ -249,22 +274,29 @@ export class ViolationPlugin {
             }
         });
 
+        let appliedViolation: T;
         if (existingViolation) {
             logger.warn("%s: trying to reapply on user %s#%s (%s)", command, info.member.user.username, info.member.user.discriminator, info.member.id);
             await violationRepo.update({ id: existingViolation.id } as unknown as FindConditions<T>, { endsAt: info.endDate } as unknown as QueryDeepPartialEntity<T>);
             const job = this.jobs[existingViolation.id];
             rescheduleJob(job, info.endDate);
+            appliedViolation = existingViolation;
         } else {
-            const newViolation = await violationRepo.save({
+            let rawViolation: DeepPartial<TimedViolation> = {
                 guildId: info.guild.id,
                 userId: info.member.id,
                 reason: info.reason,
                 endsAt: info.endDate,
                 valid: true,
-            } as unknown as DeepPartial<T>);
+            };
+            if (modify) {
+                rawViolation = modify(info.member, info.settings, rawViolation);
+            }
+            const newViolation = await violationRepo.save(rawViolation as unknown as DeepPartial<T>);
             this.jobs[newViolation.id] = scheduleJob(info.endDate, this.scheduleRemoveViolation(type, info.guild.id, info.member.id, remove, command));
+            appliedViolation = newViolation;
         }
-        await apply(info.member, info.settings);
+        await apply(info.member, info.settings, appliedViolation);
     }
 
     private scheduleRemoveViolation<T extends TimedViolation>(type: ObjectType<T>, guildId: string, userId: string, handle: StopViolationFunction, command = "violation") {
@@ -300,7 +332,7 @@ export class ViolationPlugin {
                 return;
             }
 
-            await handle(guild, userId, settings);
+            await handle(guild, userId, settings, violation);
         };
     }
 

+ 3 - 0
shared/src/db/entity/Violation.ts

@@ -23,6 +23,9 @@ export abstract class Violation {
 export class Mute extends Violation {
     @Column()
     endsAt: Date;
+
+    @Column("text", { array: true })
+    previousRoles?: string[];
 }
 
 @ChildEntity()