diff --git a/configs/__init__.py b/configs/__init__.py index 345d461..21440b2 100644 --- a/configs/__init__.py +++ b/configs/__init__.py @@ -43,6 +43,8 @@ def __init__(self): load_dotenv(config_parent / ".env") self.owner_ids: list = cfg["OWNER_IDS"] + self.cohort_channels: list[int] = cfg["COHORT_CHANNEL_IDS"] + self.staff_role_ids: list[int] = cfg["STAFF_ROLE_IDS"] self.TOKEN: str = getenv("TOKEN") self.DBINFO: dict[str: str] = {"host": getenv("DBIP"), "user": getenv("DBUN"), "password": getenv("DBPW"), "database": getenv("DBNAME")} diff --git a/configs/config-sample.json b/configs/config-sample.json index b97f12a..b41da59 100644 --- a/configs/config-sample.json +++ b/configs/config-sample.json @@ -1 +1,5 @@ -{"OWNER_IDS": [1234567890]} \ No newline at end of file +{ + "OWNER_IDS": [1234567890], + "STAFF_ROLE_IDS": [1234567890, 1234567890], + "COHORT_CHANNEL_IDS" : [ATSACHANNELID, TCACHANNELID] +} \ No newline at end of file diff --git a/presentation_utils.py b/presentation_utils.py new file mode 100644 index 0000000..8520dd6 --- /dev/null +++ b/presentation_utils.py @@ -0,0 +1,9 @@ +"""Contains utility functions for communication/presenting text to users""" + + +def get_ord(n: int) -> str: + suffixes = ("th", "st", "nd", "rd", "th", "th", "th", "th", "th", "th", "th", "th", "th", "th") + if n % 100 > 13: + return suffixes[n % 10] + else: + return suffixes[n % 100] diff --git a/tb_db.py b/tb_db.py index 73c7f35..9feb0a5 100644 --- a/tb_db.py +++ b/tb_db.py @@ -15,7 +15,8 @@ def sql_func(func: Callable) -> Callable: def wrapper(*args, **kwargs): with connect_to_db() as conn: with conn.cursor() as cursor: - func(conn, cursor, *args, **kwargs) + return func(conn, cursor, *args, **kwargs) + return wrapper @@ -67,11 +68,15 @@ def sql_op(sql_cmd: list[str] | str, args: list[tuple] | tuple, "uid BIGINT UNSIGNED PRIMARY KEY," "requests BLOB DEFAULT ''," "attended_sessions BLOB DEFAULT ''," - "completed_sessions BLOB DEFAULT '');", + "completed_sessions BLOB DEFAULT ''," + "cohort INT UNSIGNED DEFAULT 0);", "CREATE OR REPLACE TABLE server_data(" "id TINYINT UNSIGNED PRIMARY KEY," "data BLOB DEFAULT '');", - "INSERT INTO server_data VALUES (0, %s);"], [(), (), (), (), (int(0).to_bytes(2, "big") * 256)]) + "INSERT INTO server_data VALUES (0, %s);", + "INSERT INTO server_data VALUES (1, %s);"], + [(), (), (), (), (int(0).to_bytes(2, "big") * 256,), + (int(1 << 16 + 1).to_bytes(4, "big"),)]) print(*sql_op(["SELECT * FROM user_comms;", "SELECT * FROM persistent_messages;", "SELECT * FROM students", "SELECT * FROM server_data"], [(), (), (), ()], fetch_all=True), sep="\n") else: @@ -87,7 +92,8 @@ def sql_op(sql_cmd: list[str] | str, args: list[tuple] | tuple, "uid BIGINT UNSIGNED PRIMARY KEY," "requests BLOB DEFAULT ''," "attended_sessions BLOB DEFAULT ''," - "completed_sessions BLOB DEFAULT '');", + "completed_sessions BLOB DEFAULT ''," + "cohort INT UNSIGNED DEFAULT 0);", "CREATE TABLE IF NOT EXISTS server_data(" "id TINYINT UNSIGNED PRIMARY KEY," "data BLOB DEFAULT '');"], [(), (), (), ()]) diff --git a/tb_discord/__init__.py b/tb_discord/__init__.py index 40f934b..4ef0d41 100644 --- a/tb_discord/__init__.py +++ b/tb_discord/__init__.py @@ -8,6 +8,7 @@ intents.guilds = True intents.members = True intents.guild_messages = True +intents.voice_states = True bot = Bot(command_prefix="t?", intents=intents) diff --git a/tb_discord/tb_commands/filters.py b/tb_discord/tb_commands/filters.py index 2840469..ea01cdb 100644 --- a/tb_discord/tb_commands/filters.py +++ b/tb_discord/tb_commands/filters.py @@ -4,7 +4,27 @@ from configs import configs -__all__ = ["check_is_owner"] +__all__ = ["check_is_owner", "check_is_staff"] + + +def check_is_staff(): + """ + Checks if author of a message has a DC staff or administrator role + Usage: + @bot.command() + @check_is_owner() + async def command(...): + Commands with this check should not appear to any non-admin + Returns: + True or False | If owner is or is not in config + """ + def predicate(interaction: Interaction): + for role in configs.staff_role_ids: + if role in map(lambda x: x.id, interaction.user.roles): + return True + return False + + return app_commands.check(predicate) def check_is_owner(): diff --git a/tb_discord/tb_commands/lesson_tracking.py b/tb_discord/tb_commands/lesson_tracking.py index 376a913..5039237 100644 --- a/tb_discord/tb_commands/lesson_tracking.py +++ b/tb_discord/tb_commands/lesson_tracking.py @@ -1,6 +1,11 @@ -from discord import app_commands, Interaction -from tb_discord.tb_ui.lesson_tracking import Requests -from tb_db import sql_op +import pymysql +from configs import configs +from discord import app_commands, Interaction, VoiceChannel +from presentation_utils import get_ord +from tb_discord import bot +from tb_discord.tb_commands.filters import check_is_staff +from tb_discord.tb_ui.lesson_tracking import CohortUI, Requests +from tb_db import sql_func, sql_op __all__ = ["command_list"] @@ -9,11 +14,12 @@ @app_commands.command() async def register(inter: Interaction): if sql_op("SELECT COUNT(1) FROM students WHERE uid = %s", (inter.user.id,)) == (0,): - sql_op("INSERT INTO students VALUES (%s, default, default, default)", (inter.user.id,)) + sql_op("INSERT INTO students VALUES (%s, default, default, default, NULL)", (inter.user.id,)) await inter.response.send_message("Registered successfully", ephemeral=True) else: - await inter.response.send_message("You have already registered with Towerbot. To unregister, please contact a DC Staff member", - ephemeral=True) + await inter.response.send_message( + "You have already registered with Towerbot. To unregister, please contact a DC Staff member", ephemeral=True + ) @app_commands.command() @@ -46,7 +52,7 @@ async def request_training(inter: Interaction, branch: int, lesson_num: int): request_counts[int.from_bytes(to_request, "big", signed=False)] += 1 # Request counts are stored in a bytestring indexed by the request number as described above. Request counts themselves are - # a two byte big-endian unsigned integer. Ex. ACAD-02 is stored in bytes 2 & 3 + # a two byte big-endian unsigned integer. Ex. ACAD-02 is stored in bytes 4 & 5 request_count_bytes = b"".join(map(lambda x: x.to_bytes(2, "big", signed=False), request_counts)) await inter.response.send_message(f"Request processed", ephemeral=True) @@ -64,4 +70,95 @@ async def lesson_requests(inter: Interaction): tca_requests = request_counts[128:] await inter.response.send_message(embed=Requests(atsa_requests, tca_requests)) -command_list = [lesson_requests, register, request_training] + +@app_commands.command() +@app_commands.choices(branch=[ + app_commands.Choice(name="ATSA PRAC", value=0), + app_commands.Choice(name="TACAD", value=1) +]) +@check_is_staff() +async def clear_lesson_request(inter: Interaction, branch: int, lesson_num: int): + await inter.response.defer(thinking=True, ephemeral=True) + + index = (branch << 7) + lesson_num + + requests: tuple[tuple[bytes, int]] = sql_op("SELECT requests, uid FROM students", (), fetch_all=True) + for request in requests: + request[0].replace(index.to_bytes(1, "big", signed=False), b"") + update_student_requests(requests) + + counts: bytes = sql_op("SELECT data FROM server_data WHERE id = 0", ())[0] + new_counts = counts[:index * 2] + b"\x00\x00" + counts[index * 2 + 2:] + sql_op("UPDATE server_data SET data = %s WHERE id = 0", (new_counts,)) + + await inter.followup.send("Lessons cleared", ephemeral=True) + + +@sql_func +def update_student_requests(conn: pymysql.Connection, cursor: pymysql.connections.Cursor, requests: tuple[tuple[bytes, int]]): + for student in requests: + cursor.execute("UPDATE students SET requests = %s WHERE uid = %s", (student[0], student[1])) + conn.commit() + + +@app_commands.command() +@app_commands.choices(branch=[ + app_commands.Choice(name="ATSA", value=0), + app_commands.Choice(name="TCA", value=1) +]) +@check_is_staff() +async def create_cohort(inter: Interaction, channel: VoiceChannel, branch: int): + members = channel.voice_states.keys() + await inter.response.defer(ephemeral=True, thinking=True) + + cohorts = get_cohorts(members) + if not cohorts: + await inter.followup.send("There are no registered users in the voice channel", ephemeral=True) + return + + filter(lambda x: ((x[1] >> (16 * branch)) & 0x0000ffff) == 0, cohorts) + + await inter.followup.send(f"Creating cohort with {len(cohorts)} members", ephemeral=True) + + cohort_data = int.from_bytes(sql_op("SELECT data FROM server_data WHERE id = 1", ())[0], "big") + next_cohort = (cohort_data >> (16 * branch)) & 0x0000ffff + + cohort_channel = bot.get_channel(configs.cohort_channels[branch]) + cohort_thread = await cohort_channel.create_thread(invitable=True, name= + f"{next_cohort}{get_ord(next_cohort)} {'ATSA' if branch == 0 else 'TCA'} Prospective Cohort") + + cohort_message = await cohort_thread.send("<@" + "><@".join([str(x[0]) for x in cohorts]) + + f"> Welcome to the {next_cohort}{get_ord(next_cohort)} {'ATSA' if branch == 0 else 'TCA'} Prospective Cohort!\n\n" + "Cohorts are small groups of students who attend Digital Controllers sessions around the same time. We encourage " + "you to get to know each other, ask each other questions, and attend future sessions together! By creating cohorts " + "we hope to both ease the difficulty of your learning and create a sense of close community.\n\nTo confirm your " + "interest in this cohort, press the \"Join Cohort\" button below. Don't worry, nothing's permanent, you can always " + "click the \"Leave Cohort\" button and be removed from this thread. If you would like to invite friends, just @ them " + "in this thread and they will be added.\n\nWe wish you the best of luck in your future endeavours here at DC!\n\\- DC " + "Staff and Moderation Team") + + await CohortUI.create(cohort_message, cohort_thread, branch, next_cohort) + + # REMEMBER TO CONVERT TO BYTES FOR BLOB, I SPENT TWO HOURS DEBUGGING THIS + new_cohort_data = (((next_cohort + 1) << (16 * branch)) + (cohort_data & (0xffff << (16 * (1 - branch))))) \ + .to_bytes(4, "big", signed=False) + + sql_op("UPDATE server_data SET data = %s WHERE id = 1", (new_cohort_data,)) + sql_op("INSERT INTO persistent_messages VALUES (%s, %s, 2, %s)", + (cohort_message.id, cohort_thread.id, (branch << 16) + next_cohort)) + + await inter.followup.send("Created cohort", ephemeral=True) + + +@sql_func +def get_cohorts(conn: pymysql.Connection, cursor: pymysql.connections.Cursor, members: tuple[int]) -> list[(int, int)]: + out = [] + for member in members: + cursor.execute("SELECT cohort FROM students WHERE uid = %s", (member,)) + cohort = cursor.fetchone() + if cohort is not None: + out.append((member, cohort[0])) + return out + + +command_list = [create_cohort, clear_lesson_request, lesson_requests, register, request_training] diff --git a/tb_discord/tb_commands/mission_planning.py b/tb_discord/tb_commands/mission_planning.py index 74e8c66..1563655 100644 --- a/tb_discord/tb_commands/mission_planning.py +++ b/tb_discord/tb_commands/mission_planning.py @@ -1,5 +1,6 @@ -"""Towerbot commands dealing with mission planning things""" -from discord import app_commands, Interaction +"""Towerbot commands dealing with mission and event planning""" +from discord import app_commands, Interaction, ScheduledEvent +from tb_discord.tb_commands.filters import check_is_staff import server_data @@ -24,4 +25,22 @@ async def opt_out(interaction: Interaction, dcs_username: str): await interaction.response.send_message("DCS Usernames have a length limit of 25 characters, please try again.") -command_list = [opt_in, opt_out] +@app_commands.command() +@check_is_staff() +async def ping_event(inter: Interaction, event_name: str): + await inter.response.defer(thinking=True) + events = await inter.guild.fetch_scheduled_events() + events = list(filter(lambda event: event.name.lower() == event_name.lower(), events)) + if len(events) > 1: + await inter.followup.send("There is more than one event with that name", ephemeral=True) + elif len(events) == 0: + await inter.followup.send("There are no events with that name", ephemeral=True) + else: + users = [] + async for user in events[0].users(): + users.append(user) + + await inter.followup.send("<@" + "><@".join([str(user.id) for user in users]) + ">") + + +command_list = [opt_in, opt_out, ping_event] diff --git a/tb_discord/tb_commands/roles.py b/tb_discord/tb_commands/roles.py index e6f7699..f38f057 100644 --- a/tb_discord/tb_commands/roles.py +++ b/tb_discord/tb_commands/roles.py @@ -1,7 +1,7 @@ from datetime import datetime, timedelta from discord import app_commands, Message, Interaction, TextChannel, utils from discord.errors import NotFound -from tb_discord.tb_commands.filters import check_is_owner +from tb_discord.tb_commands.filters import check_is_staff from tb_discord.tb_ui import RolesMessage, RoleButtonEmbed, RoleChoiceView, RoleDeleteView @@ -9,7 +9,7 @@ @app_commands.command() -@check_is_owner() +@check_is_staff() async def create_role_buttons(interaction: Interaction, channel: TextChannel, message_id: str = None): # Get original message if message_id: @@ -38,14 +38,14 @@ async def create_role_buttons(interaction: Interaction, channel: TextChannel, me @app_commands.command() -@check_is_owner() +@check_is_staff() async def list_role_buttons(interaction: Interaction): guild_messages = tuple(filter(lambda x: x.message.guild.id == interaction.guild.id, RolesMessage.role_messages)) await interaction.response.send_message(embed=RoleButtonEmbed(guild_messages)) @app_commands.command() -@check_is_owner() +@check_is_staff() async def delete_role_buttons(interaction: Interaction): guild_messages = tuple(filter(lambda x: x.message.guild.id == interaction.guild.id, RolesMessage.role_messages)) embed = RoleButtonEmbed(guild_messages) diff --git a/tb_discord/tb_events.py b/tb_discord/tb_events.py index 4d8043d..c1a242d 100644 --- a/tb_discord/tb_events.py +++ b/tb_discord/tb_events.py @@ -1,5 +1,5 @@ from datetime import datetime -from discord import File, AllowedMentions +from discord import File from discord.errors import NotFound from os import remove from pathlib import Path @@ -7,7 +7,7 @@ from sys import argv from tb_db import sql_op from tb_discord import bot -from tb_discord.tb_ui import RolesMessage, ServersEmbed +from tb_discord.tb_ui import RolesMessage, ServersEmbed, CohortUI import logging import random import re @@ -16,7 +16,7 @@ started = False # message_types returns handler function to initialize persistent message, args structured as (message, channel, data) -message_types = [ServersEmbed.find, RolesMessage.find] +message_types = [ServersEmbed.find, RolesMessage.find, CohortUI.find] JETS = ["F16", "F18", "F15", "F35", "F22", "A10", "F14", "MIR2"] HOLDING_POINTS = ["A", "B", "C", "D"] AERODROMES = ["UG5X", "UG24", "UGKO", "UGKS", "URKA", "URKN", "URMM", "URSS"] diff --git a/tb_discord/tb_ui/__init__.py b/tb_discord/tb_ui/__init__.py index f3694b5..31de9b5 100644 --- a/tb_discord/tb_ui/__init__.py +++ b/tb_discord/tb_ui/__init__.py @@ -1,3 +1,4 @@ -"""Discord embed subclasses for self-containted updates""" +"""Discord embed subclasses for self-contained updates""" +from tb_discord.tb_ui.lesson_tracking import * from tb_discord.tb_ui.server_embeds import * from tb_discord.tb_ui.role_ui import * diff --git a/tb_discord/tb_ui/lesson_tracking.py b/tb_discord/tb_ui/lesson_tracking.py index 453815f..dad046f 100644 --- a/tb_discord/tb_ui/lesson_tracking.py +++ b/tb_discord/tb_ui/lesson_tracking.py @@ -1,12 +1,11 @@ """Module containing ui elements intended for use in the lesson_tracking commands""" -from discord import Embed, Message, TextChannel -from discord.abc import GuildChannel -from discord.errors import NotFound -from discord.ext import tasks +from discord import Embed, Thread, ButtonStyle, Interaction, Message +from discord.ui import Button, View +from presentation_utils import get_ord from tb_db import sql_op -from time import time -import logging -import server_data + + +__all__ = ['CohortUI', 'Requests'] class Requests(Embed): @@ -15,13 +14,13 @@ def __init__(self, atsa_data: list[int], tca_data: list[int]): self.set_author(name='Digital Controllers') self.set_thumbnail(url="https://raw.githubusercontent.com/Digital-Controllers/website/main/docs/assets/logo.png") - if (atsa_max := max(atsa_data)) != 0: + if max(atsa_data) != 0: requested_atsa = str(atsa_data.index(max(atsa_data))).zfill(2) - self.add_field(name="Most Requested ATSA Lesson:", value=f"ACAD-{requested_atsa}") + self.add_field(name="Most Requested ATSA Practical:", value=f"PRAC-{requested_atsa}") else: - self.add_field(name="Most Requested ATSA Lesson:", value=f"No requested ATSA lessons") + self.add_field(name="Most Requested ATSA Practical:", value=f"No requested ATSA lessons") - if (tca_max := max(tca_data)) != 0: + if max(tca_data) != 0: requested_tca = str(tca_data.index(max(tca_data))).zfill(2) self.add_field(name="Most Requested TCA Lesson:", value=f"TACAD-{requested_tca}") else: @@ -32,7 +31,7 @@ def __init__(self, atsa_data: list[int], tca_data: list[int]): atsa_max = max(atsa_data) tca_max = max(tca_data) if atsa_max > tca_max: - top_requested.append(f"ACAD-{str(atsa_data.index(atsa_max)).zfill(2)}") + top_requested.append(f"PRAC-{str(atsa_data.index(atsa_max)).zfill(2)}") atsa_data[atsa_data.index(atsa_max)] = 0 else: if tca_max == 0: @@ -42,3 +41,64 @@ def __init__(self, atsa_data: list[int], tca_data: list[int]): tca_data[tca_data.index(tca_max)] = 0 self.add_field(name="Top 5 Requested Lessons:", value="\n".join(top_requested), inline=False) + + +class CohortUI(View): + @classmethod + async def create(cls, msg: Message, thread: Thread, branch: int, num: int): + return await msg.edit(view=CohortUI(thread, branch, num)) + + @classmethod + async def find(cls, msg: Message, thread: Thread, view_data): + view_data = int(view_data) + branch = view_data >> 16 + num = view_data & 0x1111 + return await msg.edit(view=CohortUI(thread, branch, num)) + + def __init__(self, thread: Thread, branch: int, num: int): + super().__init__() + self.thread = thread + self.branch = branch + self.num = num + self.add_item(CohortJoinButton(branch, num)) + self.add_item(CohortLeaveButton(thread, branch, num)) + + +class CohortJoinButton(Button): + def __init__(self, branch: int, num: int): + super().__init__(style=ButtonStyle.primary, label="Join Cohort") + self.branch = branch + self.num = num + + async def callback(self, inter: Interaction): + existing_cohorts = sql_op("SELECT cohort FROM students WHERE uid = %s", (inter.user.id,))[0] + + if (existing_cohorts >> (self.branch * 16)) & 0x0000ffff != 0: + await inter.response.send_message("You are already in a cohort, please leave that one before trying to join a new one.", + ephemeral=True) + else: + await inter.response.send_message(f"Welcome to the {self.num}{get_ord(self.num)} {'ATSA' if self.branch == 0 else 'TCA'} " + f"Cohort <@{inter.user.id}>", ephemeral=True) + existing_cohorts |= self.num << (self.branch * 16) + sql_op("UPDATE students SET cohort = %s WHERE uid = %s", (existing_cohorts, inter.user.id)) + + +class CohortLeaveButton(Button): + def __init__(self, thread, branch: int, num: int): + super().__init__(style=ButtonStyle.primary, label="Leave Cohort") + self.thread = thread + self.branch = branch + self.num = num + + async def callback(self, inter: Interaction): + existing_cohorts = sql_op("SELECT cohort FROM students WHERE uid = %s", (inter.user.id,))[0] + if (existing_cohorts >> (self.branch * 16)) & 0x0000ffff != self.num: + await inter.response.send_message("You do not belong to this cohort", ephemeral=True) + else: + existing_cohorts &= ~(0xffff << (self.branch * 16)) + await self.thread.remove_user(inter.user) + await inter.response.defer(ephemeral=True) + sql_op("UPDATE students SET cohort = %s WHERE uid = %s", (existing_cohorts, inter.user.id)) + + +