feat: Make Grok generic

This commit is contained in:
2025-04-19 16:24:20 +02:00
parent acb065d3ef
commit f996297de9
2 changed files with 37 additions and 12 deletions

View File

@@ -5,6 +5,9 @@ from google import genai
API_KEY = os.getenv("GOOGLE_API_KEY") API_KEY = os.getenv("GOOGLE_API_KEY")
GENERIC_INSTRUCTION = "You are a Discord chatbot named 'AlterWare' who helps users with all kinds of topics across various subjects. You should limit your answers to fewer than 2000 characters."
SPECIFIC_INSTRUCTION = "You are a Discord chatbot named 'AlterWare' who helps users. You should limit your answers to fewer than 2000 characters."
class DiscourseSummarizer: class DiscourseSummarizer:
def __init__(self): def __init__(self):
@@ -41,10 +44,7 @@ class DiscourseSummarizer:
model=self.model, model=self.model,
config=types.CreateCachedContentConfig( config=types.CreateCachedContentConfig(
display_name=self.display_name, display_name=self.display_name,
system_instruction=system_instruction system_instruction=system_instruction or (SPECIFIC_INSTRUCTION),
or (
"You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters."
),
contents=[topic_data], contents=[topic_data],
ttl=self.ttl, ttl=self.ttl,
), ),
@@ -85,7 +85,7 @@ class DiscourseSummarizer:
contents=prompt, contents=prompt,
config=types.GenerateContentConfig( config=types.GenerateContentConfig(
max_output_tokens=400, max_output_tokens=400,
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.", system_instruction=SPECIFIC_INSTRUCTION,
cached_content=self.cache.name, cached_content=self.cache.name,
), ),
) )
@@ -112,13 +112,26 @@ class DiscourseSummarizer:
contents=prompt, contents=prompt,
config=types.GenerateContentConfig( config=types.GenerateContentConfig(
max_output_tokens=400, max_output_tokens=400,
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.", system_instruction=SPECIFIC_INSTRUCTION,
),
)
return response.text
def ask_without_context(self, prompt):
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config=types.GenerateContentConfig(
max_output_tokens=400,
system_instruction=GENERIC_INSTRUCTION,
), ),
) )
return response.text return response.text
async def forward_to_google_api(prompt, bot, image_object=None, reply=None): async def forward_to_google_api(
prompt, bot, image_object=None, reply=None, no_context=False
):
""" """
Forwards the message content and optional image object to a Google API. Forwards the message content and optional image object to a Google API.
@@ -127,6 +140,7 @@ async def forward_to_google_api(prompt, bot, image_object=None, reply=None):
bot (discord.Client): The Discord bot instance. bot (discord.Client): The Discord bot instance.
image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")). image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")).
reply (discord.Message, optional): The message that was referenced by prompt. reply (discord.Message, optional): The message that was referenced by prompt.
no_context (bool, optional): If True, the bot will not use any cached content or context.
""" """
if not API_KEY: if not API_KEY:
await prompt.reply( await prompt.reply(
@@ -155,7 +169,12 @@ async def forward_to_google_api(prompt, bot, image_object=None, reply=None):
await prompt.reply(f"Failed to fetch the image", mention_author=True) await prompt.reply(f"Failed to fetch the image", mention_author=True)
return return
response = bot.ai_helper.ask_without_cache(input) response = None
if no_context:
response = bot.ai_helper.ask_without_context(input)
else:
response = bot.ai_helper.ask_without_cache(input)
await prompt.reply( await prompt.reply(
response, response,

View File

@@ -53,7 +53,7 @@ def fetch_image_from_message(message):
return image_object return image_object
async def handle_bot_mention(message, bot): async def handle_bot_mention(message, bot, no_context=False):
staff_role = message.guild.get_role(STAFF_ROLE_ID) staff_role = message.guild.get_role(STAFF_ROLE_ID)
member = message.guild.get_member(message.author.id) member = message.guild.get_member(message.author.id)
if staff_role in member.roles: if staff_role in member.roles:
@@ -81,7 +81,9 @@ async def handle_bot_mention(message, bot):
print(f"An error occurred while fetching the referenced message: {e}") print(f"An error occurred while fetching the referenced message: {e}")
# Pass the reply content to forward_to_google_api # Pass the reply content to forward_to_google_api
await forward_to_google_api(message, bot, image_object, reply_content) await forward_to_google_api(
message, bot, image_object, reply_content, no_context
)
return True return True
return False return False
@@ -361,11 +363,15 @@ async def handle_message(message, bot):
await handle_dm(message) await handle_dm(message)
return return
grok_role = message.guild.get_role(GROK_ROLE_ID) if bot.user in message.mentions:
if grok_role in message.role_mentions or bot.user in message.mentions:
if await handle_bot_mention(message, bot): if await handle_bot_mention(message, bot):
return return
grok_role = message.guild.get_role(GROK_ROLE_ID)
if grok_role in message.role_mentions:
if await handle_bot_mention(message, bot, True):
return
# Too many mentions # Too many mentions
if len(message.mentions) >= 3: if len(message.mentions) >= 3:
member = message.guild.get_member(message.author.id) member = message.guild.get_member(message.author.id)