-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
apps/ai_reports: connect with external API
- add a celery task which sends a comment to the XAI server and saves the response as AiReport - connect comment post_save signal with celery task - rename category to label to be in line with the XAI response - change AiReport fields to JSONField for now - add tests - **BREAKING CHANGE** Reset migrations for the ai_reports app (see changelog)
- Loading branch information
Showing
17 changed files
with
188 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 0 additions & 17 deletions
17
apps/ai_reports/migrations/0002_aireport_show_in_discussion.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,24 @@ | ||
import json | ||
|
||
from rest_framework import serializers | ||
|
||
from apps.ai_reports.models import AiReport | ||
|
||
|
||
class AiReportSerializer(serializers.ModelSerializer): | ||
explanation = serializers.SerializerMethodField() | ||
|
||
class Meta: | ||
model = AiReport | ||
fields = ( | ||
"category", | ||
"label", | ||
"confidence", | ||
"explanation", | ||
"is_pending", | ||
"comment", | ||
"show_in_discussion", | ||
) | ||
|
||
# FIXME: remove once frontend knows what to do with this | ||
def get_explanation(self, ai_report: AiReport) -> str: | ||
return json.dumps(ai_report.explanation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import httpx | ||
from django.conf import settings | ||
from django.db.models import signals | ||
from django.dispatch import receiver | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps.ai_reports.tasks import get_classification_for_comment | ||
|
||
client = httpx.Client() | ||
|
||
|
||
@receiver(signals.post_save, sender=Comment) | ||
def get_ai_classification(sender, instance, created, update_fields, **kwargs): | ||
if getattr(settings, "XAI_API_URL"): | ||
comment_text_changed = getattr(instance, "_former_comment") != getattr( | ||
instance, "comment" | ||
) | ||
if created or comment_text_changed: | ||
# FIXME: use delay_on_commit() once updated to celery 5.x | ||
get_classification_for_comment.delay(instance.pk) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import ast | ||
|
||
import backoff | ||
import httpx | ||
from celery import shared_task | ||
from django.conf import settings | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps import logger | ||
from apps.ai_reports.models import AiReport | ||
|
||
client = httpx.Client() | ||
|
||
|
||
@shared_task | ||
def get_classification_for_comment(comment_pk: int) -> None: | ||
try: | ||
comment = Comment.objects.get(pk=comment_pk) | ||
response = call_ai_api(comment=comment.comment) | ||
if response.status_code == 200: | ||
extract_and_save_ai_classifications(comment=comment, report=response.json()) | ||
else: | ||
logger.error("Error: XAI server returned %s", response.status_code) | ||
except httpx.HTTPError as e: | ||
logger.error("Error connecting to %s: %s", settings.AI_API_URL, str(e)) | ||
|
||
|
||
def skip_retry(e: Exception) -> bool: | ||
if isinstance(e, httpx.HTTPStatusError): | ||
return 400 <= e.response.status_code < 500 | ||
return False | ||
|
||
|
||
@backoff.on_exception( | ||
backoff.expo, httpx.HTTPError, max_tries=4, factor=2, giveup=skip_retry | ||
) | ||
def call_ai_api(comment: str) -> httpx.Response: | ||
response = client.post( | ||
settings.XAI_API_URL, | ||
json={"comment": comment}, | ||
headers={"Accept": "application/json", "Content-Type": "application/json"}, | ||
timeout=25.0, | ||
) | ||
response.raise_for_status() | ||
return response | ||
|
||
|
||
def extract_and_save_ai_classifications(comment: Comment, report: dict) -> None: | ||
# FIXME: the data returned from the api is not actually valid json, so we need | ||
# to use ast to explicitly convert it. This should be fixed on their side. | ||
confidence = ast.literal_eval((report["confidence"])) | ||
label = ast.literal_eval(report["label"]) | ||
explanation = ast.literal_eval(report["explanation"]) | ||
|
||
ai_report = AiReport( | ||
comment=comment, confidence=confidence, label=label, explanation=explanation | ||
) | ||
ai_report.save() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
### Added | ||
|
||
- add celery task which sends a comment to the xai and stores the response as | ||
AiReport | ||
- add comment post_save signal which connects comment creation / editing with | ||
the xai celery task | ||
- added tests | ||
- add tests | ||
- add pytest-mock package to dev requirements | ||
- added backup and httpx to fork requirements | ||
|
||
### Changed | ||
|
||
- rename category to label to be in line with the XAI response | ||
- change AiReport fields to JSONField for now | ||
- **BREAKING CHANGE** Reset migrations for the ai_reports app as there's no | ||
automatic way to convert the `confidence` field to JSONField |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ pytest==7.3.2 | |
pytest-cov==4.1.0 | ||
pytest-django==4.5.2 | ||
pytest-factoryboy==2.5.1 | ||
pytest-mock==3.14.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
# requirements needed in this fork, but not a+ | ||
backoff==2.2.1 | ||
httpx==0.27.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
from pytest_factoryboy import register | ||
|
||
from tests.ideas import factories as idea_factories | ||
|
||
from .factories import AiReportFactory | ||
|
||
register(AiReportFactory) | ||
register(idea_factories.IdeaFactory) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import pytest | ||
from django.db.models import signals | ||
from django.test import override_settings | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps.ai_reports.signals import get_ai_classification | ||
|
||
|
||
@pytest.mark.django_db | ||
def test_comment_not_sent_to_xai_no_url(mocker, idea, comment_factory, caplog): | ||
task_mock = mocker.patch( | ||
"apps.ai_reports.tasks.get_classification_for_comment" ".delay" | ||
) | ||
assert get_ai_classification in signals.post_save._live_receivers(Comment) | ||
comment_factory(content_object=idea, comment="lala") | ||
task_mock.assert_not_called() | ||
|
||
|
||
@override_settings(XAI_API_URL="https://liqd.net") | ||
@pytest.mark.django_db | ||
def test_comment_sent_to_xai_on_comment_text_change(mocker, idea, comment_factory): | ||
task_mock = mocker.patch( | ||
"apps.ai_reports.tasks.get_classification_for_comment" ".delay" | ||
) | ||
|
||
assert get_ai_classification in signals.post_save._live_receivers(Comment) | ||
comment = comment_factory(content_object=idea, comment="lala") | ||
task_mock.assert_called_once_with(comment.pk) | ||
task_mock.reset_mock() | ||
|
||
comment.comment = "modified comment" | ||
comment.save() | ||
task_mock.assert_called_once_with(comment.pk) | ||
|
||
|
||
@override_settings(XAI_API_URL="https://liqd.net") | ||
@pytest.mark.django_db | ||
def test_comment_not_sent_to_xai_without_comment_text_change( | ||
mocker, idea, comment_factory, caplog | ||
): | ||
task_mock = mocker.patch( | ||
"apps.ai_reports.tasks.get_classification_for_comment" ".delay" | ||
) | ||
assert get_ai_classification in signals.post_save._live_receivers(Comment) | ||
comment = comment_factory(content_object=idea, comment="lala") | ||
task_mock.assert_called_once_with(comment.pk) | ||
task_mock.reset_mock() | ||
|
||
comment.is_blocked = True | ||
comment.save() | ||
task_mock.assert_not_called() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters