diff --git a/app/.github/workflows/main.yml b/app/.github/workflows/main.yml index 6e2ffd6..360a4ba 100644 --- a/app/.github/workflows/main.yml +++ b/app/.github/workflows/main.yml @@ -109,7 +109,7 @@ jobs: GITHUB_ACTIONS_TEST: true - name: Archive code coverage results - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: code-coverage-report path: htmlcov diff --git a/app/app/alias_utils.py b/app/app/alias_utils.py index faed385..49aade1 100644 --- a/app/app/alias_utils.py +++ b/app/app/alias_utils.py @@ -363,11 +363,16 @@ def delete_alias( Session.commit() LOG.i(f"Moving {alias} to global trash {deleted_alias}") + alias_id = alias.id + alias_email = alias.email Alias.filter(Alias.id == alias.id).delete() Session.commit() EventDispatcher.send_event( - user, EventContent(alias_deleted=AliasDeleted(alias_id=alias.id)) + user, + EventContent( + alias_deleted=AliasDeleted(alias_id=alias_id, alias_email=alias_email) + ), ) if commit: Session.commit() diff --git a/app/app/dashboard/views/mailbox.py b/app/app/dashboard/views/mailbox.py index 2fba1b5..36c7f78 100644 --- a/app/app/dashboard/views/mailbox.py +++ b/app/app/dashboard/views/mailbox.py @@ -128,7 +128,7 @@ def mailbox_verify(): except mailbox_utils.MailboxError as e: LOG.i(f"Cannot verify mailbox {mailbox_id} because of {e}") flash(f"Cannot verify mailbox: {e.msg}", "error") - return render_template("dashboard/mailbox_validation.html", mailbox=mailbox) + return redirect(url_for("dashboard.mailbox_route")) LOG.d("Mailbox %s is verified", mailbox) return render_template("dashboard/mailbox_validation.html", mailbox=mailbox) diff --git a/app/app/events/event_dispatcher.py b/app/app/events/event_dispatcher.py index ec2f0a4..7c281cc 100644 --- a/app/app/events/event_dispatcher.py +++ b/app/app/events/event_dispatcher.py @@ -30,14 +30,30 @@ class PostgresDispatcher(Dispatcher): return PostgresDispatcher() +class GlobalDispatcher: + __dispatcher: Optional[Dispatcher] = None + + @staticmethod + def get_dispatcher() -> Dispatcher: + if not GlobalDispatcher.__dispatcher: + GlobalDispatcher.__dispatcher = PostgresDispatcher.get() + return GlobalDispatcher.__dispatcher + + @staticmethod + def set_dispatcher(dispatcher: Optional[Dispatcher]): + GlobalDispatcher.__dispatcher = dispatcher + + class EventDispatcher: @staticmethod def send_event( user: User, content: event_pb2.EventContent, - dispatcher: Dispatcher = PostgresDispatcher.get(), + dispatcher: Optional[Dispatcher] = None, skip_if_webhook_missing: bool = True, ): + if dispatcher is None: + dispatcher = GlobalDispatcher.get_dispatcher() if config.EVENT_WEBHOOK_DISABLE: LOG.i("Not sending events because webhook is disabled") return diff --git a/app/app/proton/proton_callback_handler.py b/app/app/proton/proton_callback_handler.py index 53c8076..f726d48 100644 --- a/app/app/proton/proton_callback_handler.py +++ b/app/app/proton/proton_callback_handler.py @@ -2,9 +2,11 @@ from dataclasses import dataclass from enum import Enum from flask import url_for from typing import Optional +import arrow +from app import config from app.errors import LinkException -from app.models import User, Partner +from app.models import User, Partner, Job from app.proton.proton_client import ProtonClient, ProtonUser from app.account_linking import ( process_login_case, @@ -41,12 +43,21 @@ class ProtonCallbackHandler: def __init__(self, proton_client: ProtonClient): self.proton_client = proton_client + def _initial_alias_sync(self, user: User): + Job.create( + name=config.JOB_SEND_ALIAS_CREATION_EVENTS, + payload={"user_id": user.id}, + run_at=arrow.now(), + commit=True, + ) + def handle_login(self, partner: Partner) -> ProtonCallbackResult: try: user = self.__get_partner_user() if user is None: return generate_account_not_allowed_to_log_in() res = process_login_case(user, partner) + self._initial_alias_sync(res.user) return ProtonCallbackResult( redirect_to_login=False, flash_message=None, @@ -75,6 +86,7 @@ class ProtonCallbackHandler: if user is None: return generate_account_not_allowed_to_log_in() res = process_link_case(user, current_user, partner) + self._initial_alias_sync(res.user) return ProtonCallbackResult( redirect_to_login=False, flash_message="Account successfully linked", diff --git a/app/email_handler.py b/app/email_handler.py index 84301ac..69f3785 100644 --- a/app/email_handler.py +++ b/app/email_handler.py @@ -262,8 +262,6 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con Session.commit() except IntegrityError: - # If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh - Session.close() LOG.info( f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB" ) @@ -818,7 +816,7 @@ def forward_email_to_mailbox( email_log = EmailLog.create( contact_id=contact.id, - user_id=user.id, + user_id=contact.user_id, mailbox_id=mailbox.id, alias_id=contact.alias_id, message_id=str(msg[headers.MESSAGE_ID]), diff --git a/app/oneshot/alias_partner_set_flag_and_clear_note.py b/app/oneshot/alias_partner_set_flag_and_clear_note.py new file mode 100644 index 0000000..b2d60db --- /dev/null +++ b/app/oneshot/alias_partner_set_flag_and_clear_note.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +import argparse +import time + +from sqlalchemy import func +from app.models import Alias +from app.db import Session + +parser = argparse.ArgumentParser( + prog="Backfill alias", description="Update alias notes and backfill flag" +) +parser.add_argument( + "-s", "--start_alias_id", default=0, type=int, help="Initial alias_id" +) +parser.add_argument("-e", "--end_alias_id", default=0, type=int, help="Last alias_id") + +args = parser.parse_args() +alias_id_start = args.start_alias_id +max_alias_id = args.end_alias_id +if max_alias_id == 0: + max_alias_id = Session.query(func.max(Alias.id)).scalar() + +print(f"Checking alias {alias_id_start} to {max_alias_id}") +step = 1000 +noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')" +alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}" +updated = 0 +start_time = time.time() +for batch_start in range(alias_id_start, max_alias_id, step): + rows_done = Session.execute( + alias_query, + { + "start": batch_start, + "end": batch_start + step, + "flag": Alias.FLAG_PARTNER_CREATED, + }, + ) + updated += rows_done.rowcount + Session.commit() + elapsed = time.time() - start_time + time_per_alias = elapsed / (updated + 1) + last_batch_id = batch_start + step + remaining = max_alias_id - last_batch_id + time_remaining = (max_alias_id - last_batch_id) * time_per_alias + hours_remaining = time_remaining / 3600.0 + print( + f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f}hrs remaining" + ) +print("") diff --git a/app/templates/dashboard/enter_sudo.html b/app/templates/dashboard/enter_sudo.html index 3175f37..e1b3248 100644 --- a/app/templates/dashboard/enter_sudo.html +++ b/app/templates/dashboard/enter_sudo.html @@ -21,8 +21,8 @@

Alternatively you can use your Proton credentials to ensure it's you.

- + Authenticate with Proton @@ -38,4 +38,4 @@ {% endif %} - {% endblock %} + {% endblock %} \ No newline at end of file diff --git a/app/tests/events/event_test_utils.py b/app/tests/events/event_test_utils.py new file mode 100644 index 0000000..3a5342c --- /dev/null +++ b/app/tests/events/event_test_utils.py @@ -0,0 +1,32 @@ +from app.events.event_dispatcher import Dispatcher +from app.models import PartnerUser, User +from app.proton.utils import get_proton_partner +from tests.utils import create_new_user, random_token +from typing import Tuple + + +class OnMemoryDispatcher(Dispatcher): + def __init__(self): + self.memory = [] + + def send(self, event: bytes): + self.memory.append(event) + + def clear(self): + self.memory = [] + + +def _create_unlinked_user() -> User: + return create_new_user() + + +def _create_linked_user() -> Tuple[User, PartnerUser]: + user = _create_unlinked_user() + partner_user = PartnerUser.create( + partner_id=get_proton_partner().id, + user_id=user.id, + external_user_id=random_token(10), + flush=True, + ) + + return user, partner_user diff --git a/app/tests/events/test_event_dispatcher.py b/app/tests/events/test_event_dispatcher.py index 810e147..c142970 100644 --- a/app/tests/events/test_event_dispatcher.py +++ b/app/tests/events/test_event_dispatcher.py @@ -1,33 +1,10 @@ -from app.events.event_dispatcher import EventDispatcher, Dispatcher +from app.events.event_dispatcher import EventDispatcher from app.events.generated.event_pb2 import EventContent, UserDeleted -from app.models import PartnerUser, User -from app.proton.utils import get_proton_partner -from tests.utils import create_new_user, random_token -from typing import Tuple - - -class OnMemoryDispatcher(Dispatcher): - def __init__(self): - self.memory = [] - - def send(self, event: bytes): - self.memory.append(event) - - -def _create_unlinked_user() -> User: - return create_new_user() - - -def _create_linked_user() -> Tuple[User, PartnerUser]: - user = _create_unlinked_user() - partner_user = PartnerUser.create( - partner_id=get_proton_partner().id, - user_id=user.id, - external_user_id=random_token(10), - flush=True, - ) - - return user, partner_user +from .event_test_utils import ( + _create_unlinked_user, + OnMemoryDispatcher, + _create_linked_user, +) def test_event_dispatcher_stores_events(): diff --git a/app/tests/events/test_sent_events.py b/app/tests/events/test_sent_events.py new file mode 100644 index 0000000..ab946dd --- /dev/null +++ b/app/tests/events/test_sent_events.py @@ -0,0 +1,100 @@ +from app import config, alias_utils +from app.db import Session +from app.events.event_dispatcher import GlobalDispatcher +from app.events.generated import event_pb2 +from app.models import Alias, User, PartnerUser +from tests.utils import random_token +from .event_test_utils import ( + OnMemoryDispatcher, + _create_linked_user, +) + +on_memory_dispatcher = OnMemoryDispatcher() + + +def setup_module(): + GlobalDispatcher.set_dispatcher(on_memory_dispatcher) + config.EVENT_WEBHOOK = "http://test" + + +def teardown_module(): + GlobalDispatcher.set_dispatcher(None) + config.EVENT_WEBHOOK = None + + +def setup_function(func): + on_memory_dispatcher.clear() + + +def _get_event_from_string( + data: str, user: User, pu: PartnerUser +) -> event_pb2.EventContent: + event = event_pb2.Event() + event.ParseFromString(data) + assert user.id == event.user_id + assert pu.external_user_id == event.external_user_id + assert pu.partner_id == event.partner_id + return event.content + + +def test_fire_event_on_alias_creation(): + (user, pu) = _create_linked_user() + alias = Alias.create_new_random(user) + Session.flush() + assert len(on_memory_dispatcher.memory) == 1 + event_data = on_memory_dispatcher.memory[0] + event_content = _get_event_from_string(event_data, user, pu) + assert event_content.alias_created is not None + alias_created = event_content.alias_created + assert alias.id == alias_created.alias_id + assert alias.email == alias_created.alias_email + assert "" == alias_created.alias_note + assert alias.enabled == alias_created.enabled + + +def test_fire_event_on_alias_creation_with_note(): + (user, pu) = _create_linked_user() + note = random_token(10) + alias = Alias.create_new_random(user, note=note) + Session.flush() + assert len(on_memory_dispatcher.memory) == 1 + event_data = on_memory_dispatcher.memory[0] + event_content = _get_event_from_string(event_data, user, pu) + assert event_content.alias_created is not None + alias_created = event_content.alias_created + assert alias.id == alias_created.alias_id + assert alias.email == alias_created.alias_email + assert note == alias_created.alias_note + assert alias.enabled == alias_created.enabled + + +def test_fire_event_on_alias_deletion(): + (user, pu) = _create_linked_user() + alias = Alias.create_new_random(user) + alias_id = alias.id + Session.flush() + on_memory_dispatcher.clear() + alias_utils.delete_alias(alias, user) + assert len(on_memory_dispatcher.memory) == 1 + event_data = on_memory_dispatcher.memory[0] + event_content = _get_event_from_string(event_data, user, pu) + assert event_content.alias_deleted is not None + alias_deleted = event_content.alias_deleted + assert alias_id == alias_deleted.alias_id + assert alias.email == alias_deleted.alias_email + + +def test_fire_event_on_alias_status_change(): + (user, pu) = _create_linked_user() + alias = Alias.create_new_random(user) + Session.flush() + on_memory_dispatcher.clear() + alias_utils.change_alias_status(alias, True) + assert len(on_memory_dispatcher.memory) == 1 + event_data = on_memory_dispatcher.memory[0] + event_content = _get_event_from_string(event_data, user, pu) + assert event_content.alias_status_change is not None + event = event_content.alias_status_change + assert alias.id == event.alias_id + assert alias.email == event.alias_email + assert event.enabled diff --git a/app/tests/proton/test_proton_callback_handler.py b/app/tests/proton/test_proton_callback_handler.py index fd7f8bb..9916f99 100644 --- a/app/tests/proton/test_proton_callback_handler.py +++ b/app/tests/proton/test_proton_callback_handler.py @@ -1,4 +1,6 @@ from arrow import Arrow + +from app import config from app.account_linking import ( SLPlan, SLPlanType, @@ -8,7 +10,7 @@ from app.proton.proton_callback_handler import ( ProtonCallbackHandler, generate_account_not_allowed_to_log_in, ) -from app.models import User, PartnerUser +from app.models import User, PartnerUser, Job, JobState from app.proton.utils import get_proton_partner from app.utils import random_string from typing import Optional @@ -23,6 +25,17 @@ class MockProtonClient(ProtonClient): return self.user +def check_initial_sync_job(user: User): + for job in Job.yield_per_query(10).filter_by( + name=config.JOB_SEND_ALIAS_CREATION_EVENTS, + state=JobState.ready.value, + ): + if job.payload.get("user_id") == user.id: + Job.delete(job.id) + return + assert False + + def test_proton_callback_handler_unexistant_sl_user(): email = random_email() name = random_string() @@ -56,6 +69,7 @@ def test_proton_callback_handler_unexistant_sl_user(): ) assert partner_user is not None assert partner_user.external_user_id == external_id + check_initial_sync_job(res.user) def test_proton_callback_handler_existant_sl_user(): @@ -84,6 +98,7 @@ def test_proton_callback_handler_existant_sl_user(): sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id) assert sa is not None assert sa.partner_email == user.email + check_initial_sync_job(res.user) def test_proton_callback_handler_none_user_login():