Compare commits

...

3 Commits

Author SHA1 Message Date
357f0cca57 4.49.10
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m7s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-13 12:00:28 +01:00
8ce90e27f7 4.49.9
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m4s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-07 12:00:06 +01:00
3ecc8d36f9 4.49.8
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m18s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-04 12:00:07 +01:00
15 changed files with 294 additions and 75 deletions

View File

@ -109,7 +109,7 @@ jobs:
GITHUB_ACTIONS_TEST: true GITHUB_ACTIONS_TEST: true
- name: Archive code coverage results - name: Archive code coverage results
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v4
with: with:
name: code-coverage-report name: code-coverage-report
path: htmlcov path: htmlcov

View File

@ -736,7 +736,8 @@ class InvalidMailboxDomainAdmin(SLModelView):
class EmailSearchResult: class EmailSearchResult:
no_match: bool = True no_match: bool = True
alias: Optional[Alias] = None alias: Optional[Alias] = None
mailbox: Optional[Mailbox] = None mailbox: list[Mailbox] = []
mailbox_count: int = 0
deleted_alias: Optional[DeletedAlias] = None deleted_alias: Optional[DeletedAlias] = None
deleted_custom_alias: Optional[DomainDeletedAlias] = None deleted_custom_alias: Optional[DomainDeletedAlias] = None
user: Optional[User] = None user: Optional[User] = None
@ -752,9 +753,12 @@ class EmailSearchResult:
if user: if user:
output.user = user output.user = user
output.no_match = False output.no_match = False
mailbox = Mailbox.get_by(email=email) mailboxes = (
if mailbox: Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all()
output.mailbox = mailbox )
if mailboxes:
output.mailbox = mailboxes
output.mailbox_count = Mailbox.filter_by(email=email).count()
output.no_match = False output.no_match = False
deleted_alias = DeletedAlias.get_by(email=email) deleted_alias = DeletedAlias.get_by(email=email)
if deleted_alias: if deleted_alias:
@ -779,11 +783,13 @@ class EmailSearchHelpers:
@staticmethod @staticmethod
def mailbox_count(user: User) -> int: def mailbox_count(user: User) -> int:
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.asc()).count() return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count()
@staticmethod @staticmethod
def alias_list(user: User) -> list[Alias]: def alias_list(user: User) -> list[Alias]:
return Alias.filter_by(user_id=user.id).order_by(Alias.id.asc()).limit(10).all() return (
Alias.filter_by(user_id=user.id).order_by(Alias.id.desc()).limit(10).all()
)
@staticmethod @staticmethod
def alias_count(user: User) -> int: def alias_count(user: User) -> int:
@ -806,9 +812,8 @@ class EmailSearchAdmin(BaseView):
@expose("/", methods=["GET", "POST"]) @expose("/", methods=["GET", "POST"])
def index(self): def index(self):
search = EmailSearchResult() search = EmailSearchResult()
email = "" email = request.args.get("email")
if request.form and request.form["email"]: if email is not None and len(email) > 0:
email = request.form["email"]
email = email.strip() email = email.strip()
search = EmailSearchResult.from_email(email) search = EmailSearchResult.from_email(email)

View File

@ -363,11 +363,16 @@ def delete_alias(
Session.commit() Session.commit()
LOG.i(f"Moving {alias} to global trash {deleted_alias}") LOG.i(f"Moving {alias} to global trash {deleted_alias}")
alias_id = alias.id
alias_email = alias.email
Alias.filter(Alias.id == alias.id).delete() Alias.filter(Alias.id == alias.id).delete()
Session.commit() Session.commit()
EventDispatcher.send_event( EventDispatcher.send_event(
user, EventContent(alias_deleted=AliasDeleted(alias_id=alias.id)) user,
EventContent(
alias_deleted=AliasDeleted(alias_id=alias_id, alias_email=alias_email)
),
) )
if commit: if commit:
Session.commit() Session.commit()

View File

@ -123,7 +123,12 @@ def mailbox_verify():
if not code: if not code:
# Old way # Old way
return verify_with_signed_secret(mailbox_id) return verify_with_signed_secret(mailbox_id)
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code) try:
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
except mailbox_utils.MailboxError as e:
LOG.i(f"Cannot verify mailbox {mailbox_id} because of {e}")
flash(f"Cannot verify mailbox: {e.msg}", "error")
return redirect(url_for("dashboard.mailbox_route"))
LOG.d("Mailbox %s is verified", mailbox) LOG.d("Mailbox %s is verified", mailbox)
return render_template("dashboard/mailbox_validation.html", mailbox=mailbox) return render_template("dashboard/mailbox_validation.html", mailbox=mailbox)

View File

@ -30,14 +30,30 @@ class PostgresDispatcher(Dispatcher):
return PostgresDispatcher() return PostgresDispatcher()
class GlobalDispatcher:
__dispatcher: Optional[Dispatcher] = None
@staticmethod
def get_dispatcher() -> Dispatcher:
if not GlobalDispatcher.__dispatcher:
GlobalDispatcher.__dispatcher = PostgresDispatcher.get()
return GlobalDispatcher.__dispatcher
@staticmethod
def set_dispatcher(dispatcher: Optional[Dispatcher]):
GlobalDispatcher.__dispatcher = dispatcher
class EventDispatcher: class EventDispatcher:
@staticmethod @staticmethod
def send_event( def send_event(
user: User, user: User,
content: event_pb2.EventContent, content: event_pb2.EventContent,
dispatcher: Dispatcher = PostgresDispatcher.get(), dispatcher: Optional[Dispatcher] = None,
skip_if_webhook_missing: bool = True, skip_if_webhook_missing: bool = True,
): ):
if dispatcher is None:
dispatcher = GlobalDispatcher.get_dispatcher()
if config.EVENT_WEBHOOK_DISABLE: if config.EVENT_WEBHOOK_DISABLE:
LOG.i("Not sending events because webhook is disabled") LOG.i("Not sending events because webhook is disabled")
return return

View File

@ -1660,18 +1660,6 @@ class Alias(Base, ModelMixin):
Session.add(new_alias) Session.add(new_alias)
DailyMetric.get_or_create_today_metric().nb_alias += 1 DailyMetric.get_or_create_today_metric().nb_alias += 1
# Internal import to avoid global import cycles
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
alias_id=new_alias.id,
alias_email=new_alias.email,
alias_note=new_alias.note,
enabled=True,
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
if ( if (
new_alias.flags & cls.FLAG_PARTNER_CREATED > 0 new_alias.flags & cls.FLAG_PARTNER_CREATED > 0
and new_alias.user.flags & User.FLAG_CREATED_ALIAS_FROM_PARTNER == 0 and new_alias.user.flags & User.FLAG_CREATED_ALIAS_FROM_PARTNER == 0
@ -1684,6 +1672,18 @@ class Alias(Base, ModelMixin):
if flush: if flush:
Session.flush() Session.flush()
# Internal import to avoid global import cycles
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
alias_id=new_alias.id,
alias_email=new_alias.email,
alias_note=new_alias.note,
enabled=True,
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
return new_alias return new_alias
@classmethod @classmethod

View File

@ -2,9 +2,11 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from flask import url_for from flask import url_for
from typing import Optional from typing import Optional
import arrow
from app import config
from app.errors import LinkException from app.errors import LinkException
from app.models import User, Partner from app.models import User, Partner, Job
from app.proton.proton_client import ProtonClient, ProtonUser from app.proton.proton_client import ProtonClient, ProtonUser
from app.account_linking import ( from app.account_linking import (
process_login_case, process_login_case,
@ -41,12 +43,21 @@ class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient): def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client self.proton_client = proton_client
def _initial_alias_sync(self, user: User):
Job.create(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
payload={"user_id": user.id},
run_at=arrow.now(),
commit=True,
)
def handle_login(self, partner: Partner) -> ProtonCallbackResult: def handle_login(self, partner: Partner) -> ProtonCallbackResult:
try: try:
user = self.__get_partner_user() user = self.__get_partner_user()
if user is None: if user is None:
return generate_account_not_allowed_to_log_in() return generate_account_not_allowed_to_log_in()
res = process_login_case(user, partner) res = process_login_case(user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult( return ProtonCallbackResult(
redirect_to_login=False, redirect_to_login=False,
flash_message=None, flash_message=None,
@ -75,6 +86,7 @@ class ProtonCallbackHandler:
if user is None: if user is None:
return generate_account_not_allowed_to_log_in() return generate_account_not_allowed_to_log_in()
res = process_link_case(user, current_user, partner) res = process_link_case(user, current_user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult( return ProtonCallbackResult(
redirect_to_login=False, redirect_to_login=False,
flash_message="Account successfully linked", flash_message="Account successfully linked",

View File

@ -262,8 +262,6 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
Session.commit() Session.commit()
except IntegrityError: except IntegrityError:
# If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh
Session.close()
LOG.info( LOG.info(
f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB" f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB"
) )
@ -818,7 +816,7 @@ def forward_email_to_mailbox(
email_log = EmailLog.create( email_log = EmailLog.create(
contact_id=contact.id, contact_id=contact.id,
user_id=user.id, user_id=contact.user_id,
mailbox_id=mailbox.id, mailbox_id=mailbox.id,
alias_id=contact.alias_id, alias_id=contact.alias_id,
message_id=str(msg[headers.MESSAGE_ID]), message_id=str(msg[headers.MESSAGE_ID]),

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.models import Alias
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_alias_id", default=0, type=int, help="Initial alias_id"
)
parser.add_argument("-e", "--end_alias_id", default=0, type=int, help="Last alias_id")
args = parser.parse_args()
alias_id_start = args.start_alias_id
max_alias_id = args.end_alias_id
if max_alias_id == 0:
max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Checking alias {alias_id_start} to {max_alias_id}")
step = 1000
noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')"
alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}"
updated = 0
start_time = time.time()
for batch_start in range(alias_id_start, max_alias_id, step):
rows_done = Session.execute(
alias_query,
{
"start": batch_start,
"end": batch_start + step,
"flag": Alias.FLAG_PARTNER_CREATED,
},
)
updated += rows_done.rowcount
Session.commit()
elapsed = time.time() - start_time
time_per_alias = elapsed / (updated + 1)
last_batch_id = batch_start + step
remaining = max_alias_id - last_batch_id
time_remaining = (max_alias_id - last_batch_id) * time_per_alias
hours_remaining = time_remaining / 3600.0
print(
f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f}hrs remaining"
)
print("")

View File

@ -19,7 +19,7 @@
<tbody> <tbody>
<tr> <tr>
<td>{{ user.id }}</td> <td>{{ user.id }}</td>
<td>{{ user.email }}</td> <td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %} {% if user.disabled %}
<td class="text-danger">Disabled</td> <td class="text-danger">Disabled</td>
@ -32,7 +32,7 @@
<td>{{ user.updated_at }}</td> <td>{{ user.updated_at }}</td>
{% if pu %} {% if pu %}
<td>{{ pu.partner_email }}</td> <td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %} {% else %}
<td>No</td> <td>No</td>
{% endif %} {% endif %}
@ -43,7 +43,7 @@
{% macro list_mailboxes(mbox_count, mboxes) %} {% macro list_mailboxes(mbox_count, mboxes) %}
<h4> <h4>
{{ mbox_count }} Mailboxes found. {{ mbox_count }} Mailboxes found.
{% if mbox_count>10 %}Showing only the first 10.{% endif %} {% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4> </h4>
<table class="table"> <table class="table">
<thead> <thead>
@ -59,7 +59,7 @@
<tr> <tr>
<td>{{ mailbox.id }}</td> <td>{{ mailbox.id }}</td>
<td>{{ mailbox.email }}</td> <td><a href="?email={{mailbox.email}}">{{mailbox.email}}</a></td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td> <td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td> <td>
{{ mailbox.created_at }} {{ mailbox.created_at }}
@ -72,7 +72,7 @@
{% macro list_alias(alias_count, aliases) %} {% macro list_alias(alias_count, aliases) %}
<h4> <h4>
{{ alias_count }} Aliases found. {{ alias_count }} Aliases found.
{% if alias_count>10 %}Showing only the first 10.{% endif %} {% if alias_count>10 %}Showing only the last 10.{% endif %}
</h4> </h4>
<table class="table"> <table class="table">
<thead> <thead>
@ -95,7 +95,7 @@
{% for alias in aliases %} {% for alias in aliases %}
<tr> <tr>
<td>{{ alias.id }}</td> <td>{{ alias.id }}</td>
<td>{{ alias.email }}</td> <td><a href="?email={{alias.email}}">{{alias.email}}</a></td>
<td>{{ "Yes" if alias.verified else "No" }}</td> <td>{{ "Yes" if alias.verified else "No" }}</td>
<td>{{ alias.created_at }}</td> <td>{{ alias.created_at }}</td>
</tr> </tr>
@ -156,7 +156,7 @@
{% block body %} {% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="post"> <form method="get">
<div class="form-group"> <div class="form-group">
<label for="email">Email to search:</label> <label for="email">Email to search:</label>
<input type="text" <input type="text"
@ -167,7 +167,7 @@
<button type="submit" class="btn btn-primary">Submit</button> <button type="submit" class="btn btn-primary">Submit</button>
</form> </form>
</div> </div>
{% if no_match %} {% if data.no_match and email %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning" <div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div> role="alert">No user, alias or mailbox found for {{ email }}</div>
@ -190,14 +190,19 @@
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }} {{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div> </div>
{% endif %} {% endif %}
{% if data.mailbox %} {% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3"> <div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found Mailbox {{ data.mailbox.email }}</h3> <h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3>
{{ list_mailboxes(1, [data.mailbox]) }} {{ list_mailboxes(1, [mailbox]) }}
{{ show_user(data.mailbox.user) }} {{ show_user(mailbox.user) }}
</div> </div>
{% endif %} {% endfor %}
{% if data.deleted_alias %} {% if data.deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3"> <div class="border border-dark mt-1 mb-2 p-3">

View File

@ -21,8 +21,8 @@
<div class="my-3"> <div class="my-3">
<p>Alternatively you can use your Proton credentials to ensure it's you.</p> <p>Alternatively you can use your Proton credentials to ensure it's you.</p>
</div> </div>
<a class="btn btn-primary btn-block mt-2 proton-button w-25" <a class="btn btn-primary btn-block mt-2 proton-button"
href="{{ url_for('auth.proton_login', next=next) }}"> href="{{ url_for('auth.proton_login', next=next) }}" style="max-width: 400px">
<img class="mr-2" src="/static/images/proton.svg" /> <img class="mr-2" src="/static/images/proton.svg" />
Authenticate with Proton Authenticate with Proton
</a> </a>
@ -38,4 +38,4 @@
{% endif %} {% endif %}
</div> </div>
</div> </div>
{% endblock %} {% endblock %}

View File

@ -0,0 +1,32 @@
from app.events.event_dispatcher import Dispatcher
from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
from typing import Tuple
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def clear(self):
self.memory = []
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user

View File

@ -1,33 +1,10 @@
from app.events.event_dispatcher import EventDispatcher, Dispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserDeleted from app.events.generated.event_pb2 import EventContent, UserDeleted
from app.models import PartnerUser, User from .event_test_utils import (
from app.proton.utils import get_proton_partner _create_unlinked_user,
from tests.utils import create_new_user, random_token OnMemoryDispatcher,
from typing import Tuple _create_linked_user,
)
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user
def test_event_dispatcher_stores_events(): def test_event_dispatcher_stores_events():

View File

@ -0,0 +1,100 @@
from app import config, alias_utils
from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated import event_pb2
from app.models import Alias, User, PartnerUser
from tests.utils import random_token
from .event_test_utils import (
OnMemoryDispatcher,
_create_linked_user,
)
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def setup_function(func):
on_memory_dispatcher.clear()
def _get_event_from_string(
data: str, user: User, pu: PartnerUser
) -> event_pb2.EventContent:
event = event_pb2.Event()
event.ParseFromString(data)
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id
return event.content
def test_fire_event_on_alias_creation():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.alias_id
assert alias.email == alias_created.alias_email
assert "" == alias_created.alias_note
assert alias.enabled == alias_created.enabled
def test_fire_event_on_alias_creation_with_note():
(user, pu) = _create_linked_user()
note = random_token(10)
alias = Alias.create_new_random(user, note=note)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.alias_id
assert alias.email == alias_created.alias_email
assert note == alias_created.alias_note
assert alias.enabled == alias_created.enabled
def test_fire_event_on_alias_deletion():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
alias_id = alias.id
Session.flush()
on_memory_dispatcher.clear()
alias_utils.delete_alias(alias, user)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_deleted is not None
alias_deleted = event_content.alias_deleted
assert alias_id == alias_deleted.alias_id
assert alias.email == alias_deleted.alias_email
def test_fire_event_on_alias_status_change():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
on_memory_dispatcher.clear()
alias_utils.change_alias_status(alias, True)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_status_change is not None
event = event_content.alias_status_change
assert alias.id == event.alias_id
assert alias.email == event.alias_email
assert event.enabled

View File

@ -1,4 +1,6 @@
from arrow import Arrow from arrow import Arrow
from app import config
from app.account_linking import ( from app.account_linking import (
SLPlan, SLPlan,
SLPlanType, SLPlanType,
@ -8,7 +10,7 @@ from app.proton.proton_callback_handler import (
ProtonCallbackHandler, ProtonCallbackHandler,
generate_account_not_allowed_to_log_in, generate_account_not_allowed_to_log_in,
) )
from app.models import User, PartnerUser from app.models import User, PartnerUser, Job, JobState
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import random_string from app.utils import random_string
from typing import Optional from typing import Optional
@ -23,6 +25,17 @@ class MockProtonClient(ProtonClient):
return self.user return self.user
def check_initial_sync_job(user: User):
for job in Job.yield_per_query(10).filter_by(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
state=JobState.ready.value,
):
if job.payload.get("user_id") == user.id:
Job.delete(job.id)
return
assert False
def test_proton_callback_handler_unexistant_sl_user(): def test_proton_callback_handler_unexistant_sl_user():
email = random_email() email = random_email()
name = random_string() name = random_string()
@ -56,6 +69,7 @@ def test_proton_callback_handler_unexistant_sl_user():
) )
assert partner_user is not None assert partner_user is not None
assert partner_user.external_user_id == external_id assert partner_user.external_user_id == external_id
check_initial_sync_job(res.user)
def test_proton_callback_handler_existant_sl_user(): def test_proton_callback_handler_existant_sl_user():
@ -84,6 +98,7 @@ def test_proton_callback_handler_existant_sl_user():
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id) sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id)
assert sa is not None assert sa is not None
assert sa.partner_email == user.email assert sa.partner_email == user.email
check_initial_sync_job(res.user)
def test_proton_callback_handler_none_user_login(): def test_proton_callback_handler_none_user_login():