diff --git a/bot/ai/handle_request.py b/bot/ai/handle_request.py index 351771b..d3a8171 100644 --- a/bot/ai/handle_request.py +++ b/bot/ai/handle_request.py @@ -118,7 +118,7 @@ class DiscourseSummarizer: return response.text -async def forward_to_google_api(prompt, bot, image_object=None): +async def forward_to_google_api(prompt, bot, image_object=None, reply=None): """ Forwards the message content and optional image object to a Google API. @@ -126,6 +126,7 @@ async def forward_to_google_api(prompt, bot, image_object=None): prompt (discord.Message): The message object to forward. bot (discord.Client): The Discord bot instance. image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")). + reply (discord.Message, optional): The message that was referenced by prompt. """ if not API_KEY: await prompt.reply( @@ -135,12 +136,21 @@ async def forward_to_google_api(prompt, bot, image_object=None): return input = [prompt.content] + + # Have the reply come first in the prompt + if reply: + input.insert(0, reply.content) + if image_object: try: image_url, mime_type = image_object image = requests.get(image_url) image.raise_for_status() - input.append(types.Part.from_bytes(data=image.content, mime_type=mime_type)) + + # If there is an image, add it to the input before anything else + input.insert( + 0, types.Part.from_bytes(data=image.content, mime_type=mime_type) + ) except requests.RequestException: await prompt.reply(f"Failed to fetch the image", mention_author=True) return diff --git a/bot/events_handlers/message_events.py b/bot/events_handlers/message_events.py index 9e55c19..c5fbec9 100644 --- a/bot/events_handlers/message_events.py +++ b/bot/events_handlers/message_events.py @@ -38,6 +38,54 @@ SPAM_ROLE_ID = 1350511935677927514 STAFF_ROLE_ID = 1112016152873414707 +def fetch_image_from_message(message): + image_object = None + for attachment in message.attachments: + if attachment.filename.lower().endswith( + ".jpg" + ) or attachment.filename.lower().endswith(".jpeg"): + image_object = (attachment.url, "image/jpeg") + break + elif attachment.filename.lower().endswith(".png"): + image_object = (attachment.url, "image/png") + break + return image_object + + +async def handle_bot_mention(message, bot): + staff_role = message.guild.get_role(STAFF_ROLE_ID) + member = message.guild.get_member(message.author.id) + if staff_role in member.roles: + # Prioritize the image object from the first message + image_object = fetch_image_from_message(message) + + # Check if the message is a reply to another message + reply_content = None + if message.reference: + try: + referenced_message = await message.channel.fetch_message( + message.reference.message_id + ) + reply_content = referenced_message + + # Check if the referenced message has an image object (if not already set) + if image_object is None: + image_object = fetch_image_from_message(referenced_message) + + except discord.NotFound: + print("Referenced message not found.") + except discord.Forbidden: + print("Bot does not have permission to fetch the referenced message.") + except discord.HTTPException as e: + print(f"An error occurred while fetching the referenced message: {e}") + + # Pass the reply content to forward_to_google_api + await forward_to_google_api(message, bot, image_object, reply_content) + return True + + return False + + async def handle_dm(message): await message.channel.send( "If you DM this bot again, I will carpet-bomb your house." @@ -314,22 +362,7 @@ async def handle_message(message, bot): # Check if the bot is mentioned if bot.user in message.mentions: - staff_role = message.guild.get_role(STAFF_ROLE_ID) - member = message.guild.get_member(message.author.id) - if staff_role in member.roles: - image_object = None - - for attachment in message.attachments: - if attachment.filename.lower().endswith( - ".jpg" - ) or attachment.filename.lower().endswith(".jpeg"): - image_object = (attachment.url, "image/jpeg") - break - elif attachment.filename.lower().endswith(".png"): - image_object = (attachment.url, "image/png") - break - - await forward_to_google_api(message, bot, image_object) + if await handle_bot_mention(message, bot): return # Too many mentions