fix: Add helper function to normalize finish_reason in TetoCog and update related logic

This commit is contained in:
Slipstream 2025-05-28 23:53:41 -06:00
parent 55af27a133
commit 308f893df6
Signed by: slipstream
GPG Key ID: 13E498CE010AC6FD

View File

@ -143,6 +143,13 @@ def _get_response_text(response: Optional[types.GenerateContentResponse]) -> Opt
return None
class TetoCog(commands.Cog):
# Helper function to normalize finish_reason
def _as_str(self, fr):
if fr is None:
return None
# Enum -> take .name ; string -> leave as is
return fr.name if hasattr(fr, "name") else str(fr)
# Define command groups at class level
ame_group = app_commands.Group(
name="ame",
@ -425,9 +432,12 @@ class TetoCog(commands.Cog):
raise RuntimeError("Vertex AI response had no candidates.")
candidate = response.candidates[0]
finish_reason = getattr(candidate, "finish_reason", None)
finish_reason_str = self._as_str(finish_reason)
# Check for function calls
if candidate.finish_reason == types.FinishReason.FUNCTION_CALL:
if finish_reason_str == "FUNCTION_CALL":
if not candidate.content or not candidate.content.parts:
# Model asked to call a function but provided no content/parts
return "(Model asked to call a function I didnt give it—check tool config.)"
@ -490,7 +500,6 @@ class TetoCog(commands.Cog):
return final_ai_text_response
else:
# If response has no text part (e.g. only safety block or empty)
finish_reason_str = types.FinishReason(candidate.finish_reason).name if candidate.finish_reason else "UNKNOWN"
safety_ratings_str = ""
if candidate.safety_ratings:
safety_ratings_str = ", ".join([f"{rating.category.name}: {rating.probability.name}" for rating in candidate.safety_ratings])
@ -501,7 +510,7 @@ class TetoCog(commands.Cog):
# If blocked by safety, we should inform the user or log appropriately.
# For now, returning a generic message.
if candidate.finish_reason == types.FinishReason.SAFETY:
if finish_reason_str == "SAFETY":
return f"(Teto AI response was blocked due to safety settings: {safety_ratings_str})"
print(f"[TETO DEBUG] {error_detail}") # Log it