Compare commits

..

6 Commits

Author SHA1 Message Date
da6e56c4eb 4.55.0
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m43s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 4m10s
Build-Release-Image / Merge-Images (push) Successful in 27s
Build-Release-Image / Create-Release (push) Successful in 10s
Build-Release-Image / Notify (push) Successful in 3s
2024-10-18 12:00:06 +01:00
798b58529c 4.53.2
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m29s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m33s
Build-Release-Image / Merge-Images (push) Successful in 44s
Build-Release-Image / Create-Release (push) Successful in 7s
Build-Release-Image / Notify (push) Successful in 20s
2024-10-11 12:00:07 +01:00
3da6c983e1 4.53.1
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m13s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m54s
Build-Release-Image / Merge-Images (push) Successful in 16s
Build-Release-Image / Create-Release (push) Successful in 40s
Build-Release-Image / Notify (push) Successful in 5s
2024-10-09 12:00:06 +01:00
294232a329 4.52.1
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m56s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 4m45s
Build-Release-Image / Merge-Images (push) Successful in 22s
Build-Release-Image / Create-Release (push) Successful in 8s
Build-Release-Image / Notify (push) Successful in 3s
2024-10-02 12:00:06 +01:00
fae9d7bc17 4.52.0
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 4m44s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 4m31s
Build-Release-Image / Merge-Images (push) Successful in 23s
Build-Release-Image / Create-Release (push) Successful in 23s
Build-Release-Image / Notify (push) Successful in 17s
2024-10-01 12:00:06 +01:00
d666f5af3f 4.51.2
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m33s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m35s
Build-Release-Image / Merge-Images (push) Successful in 25s
Build-Release-Image / Create-Release (push) Successful in 10s
Build-Release-Image / Notify (push) Successful in 3s
2024-09-28 12:00:06 +01:00
81 changed files with 3222 additions and 1074 deletions

View File

@ -163,7 +163,7 @@ jobs:
uses: docker/build-push-action@v3 uses: docker/build-push-action@v3
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64 platforms: linux/amd64
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}

View File

@ -9,6 +9,7 @@ from sqlalchemy import or_
from app.db import Session from app.db import Session
from app.email_utils import send_welcome_email from app.email_utils import send_welcome_email
from app.partner_user_utils import create_partner_user, create_partner_subscription
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email, canonicalize_email
from app.errors import ( from app.errors import (
AccountAlreadyLinkedToAnotherPartnerException, AccountAlreadyLinkedToAnotherPartnerException,
@ -23,6 +24,7 @@ from app.models import (
User, User,
Alias, Alias,
) )
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string from app.utils import random_string
@ -66,9 +68,10 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
LOG.i( LOG.i(
f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]" f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
) )
PartnerSubscription.create( create_partner_subscription(
partner_user_id=partner_user.id, partner_user=partner_user,
end_at=plan.expiration, expiration=plan.expiration,
msg="Upgraded via partner. User did not have a previous partner subscription",
) )
agent.record_custom_event("PlanChange", {"plan": "premium", "type": "new"}) agent.record_custom_event("PlanChange", {"plan": "premium", "type": "new"})
else: else:
@ -80,6 +83,11 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
"PlanChange", {"plan": "premium", "type": "extension"} "PlanChange", {"plan": "premium", "type": "extension"}
) )
sub.end_at = plan.expiration sub.end_at = plan.expiration
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended partner subscription",
)
Session.commit() Session.commit()
@ -98,8 +106,8 @@ def ensure_partner_user_exists_for_user(
if res and res.partner_id != partner.id: if res and res.partner_id != partner.id:
raise AccountAlreadyLinkedToAnotherPartnerException() raise AccountAlreadyLinkedToAnotherPartnerException()
if not res: if not res:
res = PartnerUser.create( res = create_partner_user(
user_id=sl_user.id, user=sl_user,
partner_id=partner.id, partner_id=partner.id,
partner_email=link_request.email, partner_email=link_request.email,
external_user_id=link_request.external_user_id, external_user_id=link_request.external_user_id,
@ -140,8 +148,8 @@ class NewUserStrategy(ClientMergeStrategy):
activated=True, activated=True,
from_partner=self.link_request.from_partner, from_partner=self.link_request.from_partner,
) )
partner_user = PartnerUser.create( partner_user = create_partner_user(
user_id=new_user.id, user=new_user,
partner_id=self.partner.id, partner_id=self.partner.id,
external_user_id=self.link_request.external_user_id, external_user_id=self.link_request.external_user_id,
partner_email=self.link_request.email, partner_email=self.link_request.email,
@ -200,7 +208,7 @@ def get_login_strategy(
return ExistingUnlinkedUserStrategy(link_request, user, partner) return ExistingUnlinkedUserStrategy(link_request, user, partner)
def check_alias(email: str) -> bool: def check_alias(email: str):
alias = Alias.get_by(email=email) alias = Alias.get_by(email=email)
if alias is not None: if alias is not None:
raise AccountIsUsingAliasAsEmail() raise AccountIsUsingAliasAsEmail()
@ -275,10 +283,26 @@ def switch_already_linked_user(
LOG.i( LOG.i(
f"Deleting previous partner_user:{other_partner_user.id} from user:{current_user.id}" f"Deleting previous partner_user:{other_partner_user.id} from user:{current_user.id}"
) )
emit_user_audit_log(
user=other_partner_user.user,
action=UserAuditLogAction.UnlinkAccount,
message=f"Deleting partner_user {other_partner_user.id} (external_user_id={other_partner_user.external_user_id} | partner_email={other_partner_user.partner_email}) from user {current_user.id}, as we received a new link request for the same partner",
)
PartnerUser.delete(other_partner_user.id) PartnerUser.delete(other_partner_user.id)
LOG.i(f"Linking partner_user:{partner_user.id} to user:{current_user.id}") LOG.i(f"Linking partner_user:{partner_user.id} to user:{current_user.id}")
# Link this partner_user to the current user # Link this partner_user to the current user
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.UnlinkAccount,
message=f"Unlinking from partner, as user will now be tied to another external account. old=(id={partner_user.user.id} | email={partner_user.user.email}) | new=(id={current_user.id} | email={current_user.email})",
)
partner_user.user_id = current_user.id partner_user.user_id = current_user.id
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.LinkAccount,
message=f"Linking user {current_user.id} ({current_user.email}) to partner_user:{partner_user.id} (external_user_id={partner_user.external_user_id} | partner_email={partner_user.partner_email})",
)
# Set plan # Set plan
set_plan_for_partner_user(partner_user, link_request.plan) set_plan_for_partner_user(partner_user, link_request.plan)
Session.commit() Session.commit()

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional from typing import Optional, List
import arrow import arrow
import sqlalchemy import sqlalchemy
@ -34,6 +34,9 @@ from app.models import (
DeletedAlias, DeletedAlias,
DomainDeletedAlias, DomainDeletedAlias,
PartnerUser, PartnerUser,
AliasMailbox,
AliasAuditLog,
UserAuditLog,
) )
from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address
@ -736,11 +739,13 @@ class InvalidMailboxDomainAdmin(SLModelView):
class EmailSearchResult: class EmailSearchResult:
no_match: bool = True no_match: bool = True
alias: Optional[Alias] = None alias: Optional[Alias] = None
mailbox: list[Mailbox] = [] alias_audit_log: Optional[List[AliasAuditLog]] = None
mailbox: List[Mailbox] = []
mailbox_count: int = 0 mailbox_count: int = 0
deleted_alias: Optional[DeletedAlias] = None deleted_alias: Optional[DeletedAlias] = None
deleted_custom_alias: Optional[DomainDeletedAlias] = None deleted_custom_alias: Optional[DomainDeletedAlias] = None
user: Optional[User] = None user: Optional[User] = None
user_audit_log: Optional[List[UserAuditLog]] = None
@staticmethod @staticmethod
def from_email(email: str) -> EmailSearchResult: def from_email(email: str) -> EmailSearchResult:
@ -748,10 +753,20 @@ class EmailSearchResult:
alias = Alias.get_by(email=email) alias = Alias.get_by(email=email)
if alias: if alias:
output.alias = alias output.alias = alias
output.alias_audit_log = (
AliasAuditLog.filter_by(alias_id=alias.id)
.order_by(AliasAuditLog.created_at.desc())
.all()
)
output.no_match = False output.no_match = False
user = User.get_by(email=email) user = User.get_by(email=email)
if user: if user:
output.user = user output.user = user
output.user_audit_log = (
UserAuditLog.filter_by(user_id=user.id)
.order_by(UserAuditLog.created_at.desc())
.all()
)
output.no_match = False output.no_match = False
mailboxes = ( mailboxes = (
Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all() Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all()
@ -785,6 +800,25 @@ class EmailSearchHelpers:
def mailbox_count(user: User) -> int: def mailbox_count(user: User) -> int:
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count() return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count()
@staticmethod
def alias_mailboxes(alias: Alias) -> list[Mailbox]:
return (
Session.query(Mailbox)
.filter(Mailbox.id == Alias.mailbox_id, Alias.id == alias.id)
.union(
Session.query(Mailbox)
.join(AliasMailbox, Mailbox.id == AliasMailbox.mailbox_id)
.filter(AliasMailbox.alias_id == alias.id)
)
.order_by(Mailbox.id)
.limit(10)
.all()
)
@staticmethod
def alias_mailbox_count(alias: Alias) -> int:
return len(alias.mailboxes)
@staticmethod @staticmethod
def alias_list(user: User) -> list[Alias]: def alias_list(user: User) -> list[Alias]:
return ( return (

View File

@ -0,0 +1,38 @@
from enum import Enum
from typing import Optional
from app.models import Alias, AliasAuditLog
class AliasAuditLogAction(Enum):
CreateAlias = "create"
ChangeAliasStatus = "change_status"
DeleteAlias = "delete"
UpdateAlias = "update"
InitiateTransferAlias = "initiate_transfer_alias"
AcceptTransferAlias = "accept_transfer_alias"
TransferredAlias = "transferred_alias"
ChangedMailboxes = "changed_mailboxes"
CreateContact = "create_contact"
UpdateContact = "update_contact"
DeleteContact = "delete_contact"
def emit_alias_audit_log(
alias: Alias,
action: AliasAuditLogAction,
message: str,
user_id: Optional[int] = None,
commit: bool = False,
):
AliasAuditLog.create(
user_id=user_id or alias.user_id,
alias_id=alias.id,
alias_email=alias.email,
action=action.value,
message=message,
commit=commit,
)

View File

@ -0,0 +1,61 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.db import Session
from app.models import Alias, AliasMailbox, Mailbox
_MAX_MAILBOXES_PER_ALIAS = 20
class CannotSetMailboxesForAliasCause(Enum):
Forbidden = "Forbidden"
EmptyMailboxes = "Must choose at least one mailbox"
TooManyMailboxes = "Too many mailboxes"
@dataclass
class SetMailboxesForAliasResult:
performed_change: bool
reason: Optional[CannotSetMailboxesForAliasCause]
def set_mailboxes_for_alias(
user_id: int, alias: Alias, mailbox_ids: List[int]
) -> Optional[CannotSetMailboxesForAliasCause]:
if len(mailbox_ids) == 0:
return CannotSetMailboxesForAliasCause.EmptyMailboxes
if len(mailbox_ids) > _MAX_MAILBOXES_PER_ALIAS:
return CannotSetMailboxesForAliasCause.TooManyMailboxes
mailboxes = (
Session.query(Mailbox)
.filter(
Mailbox.id.in_(mailbox_ids),
Mailbox.user_id == user_id,
Mailbox.verified == True, # noqa: E712
)
.all()
)
if len(mailboxes) != len(mailbox_ids):
return CannotSetMailboxesForAliasCause.Forbidden
# first remove all existing alias-mailboxes links
AliasMailbox.filter_by(alias_id=alias.id).delete()
Session.flush()
# then add all new mailboxes, being the first the one associated with the alias
for i, mailbox in enumerate(mailboxes):
if i == 0:
alias.mailbox_id = mailboxes[0].id
else:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mailbox.id)
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.ChangedMailboxes,
message=",".join([f"{mailbox.id} ({mailbox.email})" for mailbox in mailboxes]),
)
return None

View File

@ -1,12 +1,14 @@
import csv import csv
from io import StringIO from io import StringIO
import re import re
from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from email_validator import validate_email, EmailNotValidError from email_validator import validate_email, EmailNotValidError
from sqlalchemy.exc import IntegrityError, DataError from sqlalchemy.exc import IntegrityError, DataError
from flask import make_response from flask import make_response
from app.alias_audit_log_utils import AliasAuditLogAction, emit_alias_audit_log
from app.config import ( from app.config import (
BOUNCE_PREFIX_FOR_REPLY_PHASE, BOUNCE_PREFIX_FOR_REPLY_PHASE,
BOUNCE_PREFIX, BOUNCE_PREFIX,
@ -23,6 +25,7 @@ from app.email_utils import (
send_cannot_create_domain_alias, send_cannot_create_domain_alias,
send_email, send_email,
render, render,
sl_formataddr,
) )
from app.errors import AliasInTrashError from app.errors import AliasInTrashError
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
@ -30,6 +33,7 @@ from app.events.generated.event_pb2 import (
AliasDeleted, AliasDeleted,
AliasStatusChanged, AliasStatusChanged,
EventContent, EventContent,
AliasCreated,
) )
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
@ -365,6 +369,10 @@ def delete_alias(
alias_id = alias.id alias_id = alias.id
alias_email = alias.email alias_email = alias.email
emit_alias_audit_log(
alias, AliasAuditLogAction.DeleteAlias, "Alias deleted by user action"
)
Alias.filter(Alias.id == alias.id).delete() Alias.filter(Alias.id == alias.id).delete()
Session.commit() Session.commit()
@ -447,7 +455,7 @@ def alias_export_csv(user, csv_direct_export=False):
return output return output
def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]): def transfer_alias(alias: Alias, new_user: User, new_mailboxes: [Mailbox]):
# cannot transfer alias which is used for receiving newsletter # cannot transfer alias which is used for receiving newsletter
if User.get_by(newsletter_alias_id=alias.id): if User.get_by(newsletter_alias_id=alias.id):
raise Exception("Cannot transfer alias that's used to receive newsletter") raise Exception("Cannot transfer alias that's used to receive newsletter")
@ -501,10 +509,47 @@ def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
alias.disable_pgp = False alias.disable_pgp = False
alias.pinned = False alias.pinned = False
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.TransferredAlias,
message=f"Lost ownership of alias due to alias transfer confirmed. New owner is {new_user.id}",
user_id=old_user.id,
)
EventDispatcher.send_event(
old_user,
EventContent(
alias_deleted=AliasDeleted(
id=alias.id,
email=alias.email,
)
),
)
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.AcceptTransferAlias,
message=f"Accepted alias transfer from user {old_user.id}",
user_id=new_user.id,
)
EventDispatcher.send_event(
new_user,
EventContent(
alias_created=AliasCreated(
id=alias.id,
email=alias.email,
note=alias.note,
enabled=alias.enabled,
created_at=int(alias.created_at.timestamp),
)
),
)
Session.commit() Session.commit()
def change_alias_status(alias: Alias, enabled: bool, commit: bool = False): def change_alias_status(
alias: Alias, enabled: bool, message: Optional[str] = None, commit: bool = False
):
LOG.i(f"Changing alias {alias} enabled to {enabled}") LOG.i(f"Changing alias {alias} enabled to {enabled}")
alias.enabled = enabled alias.enabled = enabled
@ -515,6 +560,39 @@ def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
created_at=int(alias.created_at.timestamp), created_at=int(alias.created_at.timestamp),
) )
EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event)) EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event))
audit_log_message = f"Set alias status to {enabled}"
if message is not None:
audit_log_message += f". {message}"
emit_alias_audit_log(
alias, AliasAuditLogAction.ChangeAliasStatus, audit_log_message
)
if commit: if commit:
Session.commit() Session.commit()
@dataclass
class AliasRecipientName:
name: str
message: Optional[str] = None
def get_alias_recipient_name(alias: Alias) -> AliasRecipientName:
"""
Logic:
1. If alias has name, use it
2. If alias has custom domain, and custom domain has name, use it
3. Otherwise, use the alias email as the recipient
"""
if alias.name:
return AliasRecipientName(
name=sl_formataddr((alias.name, alias.email)),
message=f"Put alias name {alias.name} in from header",
)
elif alias.custom_domain:
if alias.custom_domain.name:
return AliasRecipientName(
name=sl_formataddr((alias.custom_domain.name, alias.email)),
message=f"Put domain default alias name {alias.custom_domain.name} in from header",
)
return AliasRecipientName(name=alias.email)

View File

@ -1,9 +1,13 @@
from typing import Optional
from deprecated import deprecated from deprecated import deprecated
from flask import g from flask import g
from flask import jsonify from flask import jsonify
from flask import request from flask import request
from app import alias_utils from app import alias_utils
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_mailbox_utils import set_mailboxes_for_alias
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.api.serializer import ( from app.api.serializer import (
AliasInfo, AliasInfo,
@ -26,7 +30,7 @@ from app.errors import (
) )
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, Mailbox, AliasMailbox, AliasDeleteReason from app.models import Alias, Contact, Mailbox, AliasDeleteReason
@deprecated @deprecated
@ -185,7 +189,11 @@ def toggle_alias(alias_id):
if not alias or alias.user_id != user.id: if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
alias_utils.change_alias_status(alias, enabled=not alias.enabled) alias_utils.change_alias_status(
alias,
enabled=not alias.enabled,
message=f"Set enabled={not alias.enabled} via API",
)
LOG.i(f"User {user} changed alias {alias} enabled status to {alias.enabled}") LOG.i(f"User {user} changed alias {alias} enabled status to {alias.enabled}")
Session.commit() Session.commit()
@ -272,10 +280,12 @@ def update_alias(alias_id):
if not alias or alias.user_id != user.id: if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
changed_fields = []
changed = False changed = False
if "note" in data: if "note" in data:
new_note = data.get("note") new_note = data.get("note")
alias.note = new_note alias.note = new_note
changed_fields.append("note")
changed = True changed = True
if "mailbox_id" in data: if "mailbox_id" in data:
@ -285,35 +295,19 @@ def update_alias(alias_id):
return jsonify(error="Forbidden"), 400 return jsonify(error="Forbidden"), 400
alias.mailbox_id = mailbox_id alias.mailbox_id = mailbox_id
changed_fields.append(f"mailbox_id ({mailbox_id})")
changed = True changed = True
if "mailbox_ids" in data: if "mailbox_ids" in data:
mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")] mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")]
mailboxes: [Mailbox] = [] err = set_mailboxes_for_alias(
user_id=user.id, alias=alias, mailbox_ids=mailbox_ids
# check if all mailboxes belong to user )
for mailbox_id in mailbox_ids: if err:
mailbox = Mailbox.get(mailbox_id) return jsonify(error=err.value), 400
if not mailbox or mailbox.user_id != user.id or not mailbox.verified:
return jsonify(error="Forbidden"), 400
mailboxes.append(mailbox)
if not mailboxes:
return jsonify(error="Must choose at least one mailbox"), 400
# <<< update alias mailboxes >>>
# first remove all existing alias-mailboxes links
AliasMailbox.filter_by(alias_id=alias.id).delete()
Session.flush()
# then add all new mailboxes
for i, mailbox in enumerate(mailboxes):
if i == 0:
alias.mailbox_id = mailboxes[0].id
else:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mailbox.id)
# <<< END update alias mailboxes >>>
mailbox_ids_string = ",".join(map(str, mailbox_ids))
changed_fields.append(f"mailbox_ids ({mailbox_ids_string})")
changed = True changed = True
if "name" in data: if "name" in data:
@ -325,17 +319,26 @@ def update_alias(alias_id):
if new_name: if new_name:
new_name = new_name.replace("\n", "") new_name = new_name.replace("\n", "")
alias.name = new_name alias.name = new_name
changed_fields.append("name")
changed = True changed = True
if "disable_pgp" in data: if "disable_pgp" in data:
alias.disable_pgp = data.get("disable_pgp") alias.disable_pgp = data.get("disable_pgp")
changed_fields.append("disable_pgp")
changed = True changed = True
if "pinned" in data: if "pinned" in data:
alias.pinned = data.get("pinned") alias.pinned = data.get("pinned")
changed_fields.append("pinned")
changed = True changed = True
if changed: if changed:
changed_fields_string = ",".join(changed_fields)
emit_alias_audit_log(
alias,
AliasAuditLogAction.UpdateAlias,
f"Alias fields updated ({changed_fields_string})",
)
Session.commit() Session.commit()
return jsonify(ok=True), 200 return jsonify(ok=True), 200
@ -424,7 +427,7 @@ def create_contact_route(alias_id):
contact_address = data.get("contact") contact_address = data.get("contact")
try: try:
contact = create_contact(g.user, alias, contact_address) contact = create_contact(alias, contact_address)
except ErrContactErrorUpgradeNeeded as err: except ErrContactErrorUpgradeNeeded as err:
return jsonify(error=err.error_for_user()), 403 return jsonify(error=err.error_for_user()), 403
except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err: except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err:
@ -446,11 +449,16 @@ def delete_contact(contact_id):
200 200
""" """
user = g.user user = g.user
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.alias.user_id != user.id: if not contact or contact.alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
emit_alias_audit_log(
alias=contact.alias,
action=AliasAuditLogAction.DeleteContact,
message=f"Deleted contact {contact_id} ({contact.email})",
)
Contact.delete(contact_id) Contact.delete(contact_id)
Session.commit() Session.commit()
@ -468,12 +476,17 @@ def toggle_contact(contact_id):
200 200
""" """
user = g.user user = g.user
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.alias.user_id != user.id: if not contact or contact.alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
contact.block_forward = not contact.block_forward contact.block_forward = not contact.block_forward
emit_alias_audit_log(
alias=contact.alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Set contact state {contact.id} {contact.email} -> {contact.website_email} to blocked {contact.block_forward}",
)
Session.commit() Session.commit()
return jsonify(block_forward=contact.block_forward), 200 return jsonify(block_forward=contact.block_forward), 200

View File

@ -52,8 +52,12 @@ def auth_login():
password = data.get("password") password = data.get("password")
device = data.get("device") device = data.get("device")
email = sanitize_email(data.get("email")) email = data.get("email")
canonical_email = canonicalize_email(data.get("email")) if not email:
LoginEvent(LoginEvent.ActionType.failed, LoginEvent.Source.api).send()
return jsonify(error="Email or password incorrect"), 400
email = sanitize_email(email)
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email) user = User.get_by(email=email) or User.get_by(email=canonical_email)

View File

@ -2,8 +2,10 @@ from flask import g, request
from flask import jsonify from flask import jsonify
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.custom_domain_utils import set_custom_domain_mailboxes
from app.db import Session from app.db import Session
from app.models import CustomDomain, DomainDeletedAlias, Mailbox, DomainMailbox from app.log import LOG
from app.models import CustomDomain, DomainDeletedAlias
def custom_domain_to_dict(custom_domain: CustomDomain): def custom_domain_to_dict(custom_domain: CustomDomain):
@ -100,23 +102,14 @@ def update_custom_domain(custom_domain_id):
if "mailbox_ids" in data: if "mailbox_ids" in data:
mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")] mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")]
if mailbox_ids: result = set_custom_domain_mailboxes(user.id, custom_domain, mailbox_ids)
# check if mailbox is not tempered with if result.success:
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != user.id or not mailbox.verified:
return jsonify(error="Forbidden"), 400
mailboxes.append(mailbox)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
changed = True changed = True
else:
LOG.info(
f"Prevented from updating mailboxes [custom_domain_id={custom_domain.id}]: {result.reason.value}"
)
return jsonify(error="Forbidden"), 400
if changed: if changed:
Session.commit() Session.commit()

View File

@ -6,6 +6,7 @@ from app import config
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import Job, ApiToCookieToken from app.models import Job, ApiToCookieToken
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@api_bp.route("/user", methods=["DELETE"]) @api_bp.route("/user", methods=["DELETE"])
@ -16,6 +17,11 @@ def delete_user():
""" """
# Schedule delete account job # Schedule delete account job
emit_user_audit_log(
user=g.user,
action=UserAuditLogAction.UserMarkedForDeletion,
message=f"Marked user {g.user.id} ({g.user.email}) for deletion from API",
)
LOG.w("schedule delete account job for %s", g.user) LOG.w("schedule delete account job for %s", g.user)
Job.create( Job.create(
name=config.JOB_DELETE_ACCOUNT, name=config.JOB_DELETE_ACCOUNT,

View File

@ -87,7 +87,7 @@ def update_user_info():
File.delete(file.id) File.delete(file.id)
s3.delete(file.path) s3.delete(file.path)
Session.flush() Session.flush()
else: if data["profile_picture"] is not None:
raw_data = base64.decodebytes(data["profile_picture"].encode()) raw_data = base64.decodebytes(data["profile_picture"].encode())
if detect_image_format(raw_data) == ImageFormat.Unknown: if detect_image_format(raw_data) == ImageFormat.Unknown:
return jsonify(error="Unsupported image format"), 400 return jsonify(error="Unsupported image format"), 400

View File

@ -601,7 +601,6 @@ SKIP_MX_LOOKUP_ON_CHECK = False
DISABLE_RATE_LIMIT = "DISABLE_RATE_LIMIT" in os.environ DISABLE_RATE_LIMIT = "DISABLE_RATE_LIMIT" in os.environ
SUBSCRIPTION_CHANGE_WEBHOOK = os.environ.get("SUBSCRIPTION_CHANGE_WEBHOOK", None)
MAX_API_KEYS = int(os.environ.get("MAX_API_KEYS", 30)) MAX_API_KEYS = int(os.environ.get("MAX_API_KEYS", 30))
UPCLOUD_USERNAME = os.environ.get("UPCLOUD_USERNAME", None) UPCLOUD_USERNAME = os.environ.get("UPCLOUD_USERNAME", None)
@ -653,7 +652,15 @@ def read_partner_dict(var: str) -> dict[int, str]:
return res return res
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS") PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict(
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( "PARTNER_DNS_CUSTOM_DOMAINS"
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
) )
PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES"
)
MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(
"MAILBOX_VERIFICATION_OVERRIDE_CODE", None
)
AUDIT_LOG_MAX_DAYS = int(os.environ.get("AUDIT_LOG_MAX_DAYS", 30))

View File

@ -4,8 +4,9 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.db import Session from app.db import Session
from app.email_utils import generate_reply_email from app.email_utils import generate_reply_email, parse_full_address
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
from app.log import LOG from app.log import LOG
from app.models import Contact, Alias from app.models import Contact, Alias
@ -14,11 +15,13 @@ from app.utils import sanitize_email
class ContactCreateError(Enum): class ContactCreateError(Enum):
InvalidEmail = "Invalid email" InvalidEmail = "Invalid email"
NotAllowed = "Your plan does not allow to create contacts"
@dataclass @dataclass
class ContactCreateResult: class ContactCreateResult:
contact: Optional[Contact] contact: Optional[Contact]
created: bool
error: Optional[ContactCreateError] error: Optional[ContactCreateError]
@ -33,34 +36,56 @@ def __update_contact_if_needed(
LOG.d(f"Setting {contact} mail_from to {mail_from}") LOG.d(f"Setting {contact} mail_from to {mail_from}")
contact.mail_from = mail_from contact.mail_from = mail_from
Session.commit() Session.commit()
return ContactCreateResult(contact, None) return ContactCreateResult(contact, created=False, error=None)
def create_contact( def create_contact(
email: str, email: str,
name: Optional[str],
alias: Alias, alias: Alias,
name: Optional[str] = None,
mail_from: Optional[str] = None, mail_from: Optional[str] = None,
allow_empty_email: bool = False, allow_empty_email: bool = False,
automatic_created: bool = False, automatic_created: bool = False,
from_partner: bool = False, from_partner: bool = False,
) -> ContactCreateResult: ) -> ContactCreateResult:
if name is not None: # If user cannot create contacts, they still need to be created when receiving an email for an alias
if not automatic_created and not alias.user.can_create_contacts():
return ContactCreateResult(
None, created=False, error=ContactCreateError.NotAllowed
)
# Parse emails with form 'name <email>'
try:
email_name, email = parse_full_address(email)
except ValueError:
email = ""
email_name = ""
# If no name is explicitly given try to get it from the parsed email
if name is None:
name = email_name[: Contact.MAX_NAME_LENGTH]
else:
name = name[: Contact.MAX_NAME_LENGTH] name = name[: Contact.MAX_NAME_LENGTH]
# If still no name is there, make sure the name is None instead of empty string
if not name:
name = None
if name is not None and "\x00" in name: if name is not None and "\x00" in name:
LOG.w("Cannot use contact name because has \\x00") LOG.w("Cannot use contact name because has \\x00")
name = "" name = ""
# Sanitize email and if it's not valid only allow to create a contact if it's explicitly allowed. Otherwise fail
email = sanitize_email(email, not_lower=True)
if not is_valid_email(email): if not is_valid_email(email):
LOG.w(f"invalid contact email {email}") LOG.w(f"invalid contact email {email}")
if not allow_empty_email: if not allow_empty_email:
return ContactCreateResult(None, ContactCreateError.InvalidEmail) return ContactCreateResult(
None, created=False, error=ContactCreateError.InvalidEmail
)
LOG.d("Create a contact with invalid email for %s", alias) LOG.d("Create a contact with invalid email for %s", alias)
# either reuse a contact with empty email or create a new contact with empty email # either reuse a contact with empty email or create a new contact with empty email
email = "" email = ""
email = sanitize_email(email, not_lower=True) # If contact exists, update name and mail_from if needed
contact = Contact.get_by(alias_id=alias.id, website_email=email) contact = Contact.get_by(alias_id=alias.id, website_email=email)
if contact is not None: if contact is not None:
return __update_contact_if_needed(contact, name, mail_from) return __update_contact_if_needed(contact, name, mail_from)
# Create the contact
reply_email = generate_reply_email(email, alias) reply_email = generate_reply_email(email, alias)
try: try:
flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0 flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
@ -76,6 +101,16 @@ def create_contact(
invalid_email=email == "", invalid_email=email == "",
commit=True, commit=True,
) )
if automatic_created:
trail = ". Automatically created"
else:
trail = ". Created by user action"
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.CreateContact,
message=f"Created contact {contact.id} ({contact.email}){trail}",
commit=True,
)
LOG.d( LOG.d(
f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}" f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}"
) )
@ -86,4 +121,4 @@ def create_contact(
) )
contact = Contact.get_by(alias_id=alias.id, website_email=email) contact = Contact.get_by(alias_id=alias.id, website_email=email)
return __update_contact_if_needed(contact, name, mail_from) return __update_contact_if_needed(contact, name, mail_from)
return ContactCreateResult(contact, None) return ContactCreateResult(contact, created=True, error=None)

View File

@ -3,15 +3,17 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional from typing import List, Optional
from app.config import JOB_DELETE_DOMAIN from app.config import JOB_DELETE_DOMAIN
from app.db import Session from app.db import Session
from app.email_utils import get_email_domain_part from app.email_utils import get_email_domain_part
from app.log import LOG from app.log import LOG
from app.models import User, CustomDomain, SLDomain, Mailbox, Job from app.models import User, CustomDomain, SLDomain, Mailbox, Job, DomainMailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$") _ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$")
_MAX_MAILBOXES_PER_DOMAIN = 20
@dataclass @dataclass
@ -45,6 +47,20 @@ class CannotUseDomainReason(Enum):
raise Exception("Invalid CannotUseDomainReason") raise Exception("Invalid CannotUseDomainReason")
class CannotSetCustomDomainMailboxesCause(Enum):
InvalidMailbox = "Something went wrong, please retry"
NoMailboxes = "You must select at least 1 mailbox"
TooManyMailboxes = (
f"You can only set up to {_MAX_MAILBOXES_PER_DOMAIN} mailboxes per domain"
)
@dataclass
class SetCustomDomainMailboxesResult:
success: bool
reason: Optional[CannotSetCustomDomainMailboxesCause] = None
def is_valid_domain(domain: str) -> bool: def is_valid_domain(domain: str) -> bool:
""" """
Checks that a domain is valid according to RFC 1035 Checks that a domain is valid according to RFC 1035
@ -122,6 +138,11 @@ def create_custom_domain(
if partner_id is not None: if partner_id is not None:
new_custom_domain.partner_id = partner_id new_custom_domain.partner_id = partner_id
emit_user_audit_log(
user=user,
action=UserAuditLogAction.CreateCustomDomain,
message=f"Created custom domain {new_custom_domain.id} ({new_domain})",
)
Session.commit() Session.commit()
return CreateCustomDomainResult( return CreateCustomDomainResult(
@ -140,3 +161,46 @@ def delete_custom_domain(domain: CustomDomain):
run_at=arrow.now(), run_at=arrow.now(),
commit=True, commit=True,
) )
def set_custom_domain_mailboxes(
user_id: int, custom_domain: CustomDomain, mailbox_ids: List[int]
) -> SetCustomDomainMailboxesResult:
if len(mailbox_ids) == 0:
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.NoMailboxes
)
elif len(mailbox_ids) > _MAX_MAILBOXES_PER_DOMAIN:
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.TooManyMailboxes
)
mailboxes = (
Session.query(Mailbox)
.filter(
Mailbox.id.in_(mailbox_ids),
Mailbox.user_id == user_id,
Mailbox.verified == True, # noqa: E712
)
.all()
)
if len(mailboxes) != len(mailbox_ids):
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.InvalidMailbox
)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
mailboxes_as_str = ",".join(map(str, mailbox_ids))
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Updated custom domain {custom_domain.id} mailboxes (domain={custom_domain.domain}) (mailboxes={mailboxes_as_str})",
)
Session.commit()
return SetCustomDomainMailboxesResult(success=True)

View File

@ -1,15 +1,18 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import List, Optional
from app import config from app import config
from app.constants import DMARC_RECORD from app.constants import DMARC_RECORD
from app.db import Session from app.db import Session
from app.dns_utils import ( from app.dns_utils import (
MxRecord,
DNSClient, DNSClient,
is_mx_equivalent, is_mx_equivalent,
get_network_dns_client, get_network_dns_client,
) )
from app.models import CustomDomain from app.models import CustomDomain
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string
@dataclass @dataclass
@ -28,10 +31,10 @@ class CustomDomainValidation:
): ):
self.dkim_domain = dkim_domain self.dkim_domain = dkim_domain
self._dns_client = dns_client self._dns_client = dns_client
self._partner_domains = partner_domains or config.PARTNER_DOMAINS self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS
self._partner_domain_validation_prefixes = ( self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes partner_domains_validation_prefixes
or config.PARTNER_DOMAIN_VALIDATION_PREFIXES or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
) )
def get_ownership_verification_record(self, domain: CustomDomain) -> str: def get_ownership_verification_record(self, domain: CustomDomain) -> str:
@ -41,8 +44,36 @@ class CustomDomainValidation:
and domain.partner_id in self._partner_domain_validation_prefixes and domain.partner_id in self._partner_domain_validation_prefixes
): ):
prefix = self._partner_domain_validation_prefixes[domain.partner_id] prefix = self._partner_domain_validation_prefixes[domain.partner_id]
if not domain.ownership_txt_token:
domain.ownership_txt_token = random_string(30)
Session.commit()
return f"{prefix}-verification={domain.ownership_txt_token}" return f"{prefix}-verification={domain.ownership_txt_token}"
def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
records = []
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
domain = self._partner_domains[domain.partner_id]
records.append(MxRecord(10, f"mx1.{domain}."))
records.append(MxRecord(20, f"mx2.{domain}."))
else:
# Default ones
for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
records.append(MxRecord(priority, domain))
return records
def get_expected_spf_domain(self, domain: CustomDomain) -> str:
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
return self._partner_domains[domain.partner_id]
else:
return config.EMAIL_DOMAIN
def get_expected_spf_record(self, domain: CustomDomain) -> str:
spf_domain = self.get_expected_spf_domain(domain)
return f"v=spf1 include:{spf_domain} ~all"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}: def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
""" """
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not, Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
@ -91,6 +122,12 @@ class CustomDomainValidation:
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the # Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag # rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0 custom_domain.dkim_verified = len(invalid_records) == 0
if custom_domain.dkim_verified:
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified DKIM records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return invalid_records return invalid_records
@ -107,6 +144,11 @@ class CustomDomainValidation:
if expected_verification_record in txt_records: if expected_verification_record in txt_records:
custom_domain.ownership_verified = True custom_domain.ownership_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified ownership for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
else: else:
@ -116,14 +158,20 @@ class CustomDomainValidation:
self, custom_domain: CustomDomain self, custom_domain: CustomDomain
) -> DomainValidationResult: ) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain) mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
expected_mx_records = self.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult( return DomainValidationResult(
success=False, success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains], errors=[f"{record.priority} {record.domain}" for record in mx_domains],
) )
else: else:
custom_domain.verified = True custom_domain.verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified MX records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
@ -131,16 +179,24 @@ class CustomDomainValidation:
self, custom_domain: CustomDomain self, custom_domain: CustomDomain
) -> DomainValidationResult: ) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains: expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True custom_domain.spf_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified SPF records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
else: else:
custom_domain.spf_verified = False custom_domain.spf_verified = False
Session.commit() Session.commit()
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
cleaned_records = self.__clean_spf_records(txt_records, custom_domain)
return DomainValidationResult( return DomainValidationResult(
success=False, success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain), errors=cleaned_records,
) )
def validate_dmarc_records( def validate_dmarc_records(
@ -149,9 +205,24 @@ class CustomDomainValidation:
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain) txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records: if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True custom_domain.dmarc_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified DMARC records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
else: else:
custom_domain.dmarc_verified = False custom_domain.dmarc_verified = False
Session.commit() Session.commit()
return DomainValidationResult(success=False, errors=txt_records) return DomainValidationResult(success=False, errors=txt_records)
def __clean_spf_records(
self, txt_records: List[str], custom_domain: CustomDomain
) -> List[str]:
final_records = []
verification_record = self.get_ownership_verification_record(custom_domain)
for record in txt_records:
if record != verification_record:
final_records.append(record)
return final_records

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from operator import or_ from operator import or_
from typing import Optional
from flask import render_template, request, redirect, flash from flask import render_template, request, redirect, flash
from flask import url_for from flask import url_for
@ -9,13 +10,11 @@ from sqlalchemy import and_, func, case
from wtforms import StringField, validators, ValidationError from wtforms import StringField, validators, ValidationError
# Need to import directly from config to allow modification from the tests # Need to import directly from config to allow modification from the tests
from app import config, parallel_limiter from app import config, parallel_limiter, contact_utils
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.contact_utils import ContactCreateError
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.email_utils import (
generate_reply_email,
parse_full_address,
)
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
from app.errors import ( from app.errors import (
CannotCreateContactForReverseAlias, CannotCreateContactForReverseAlias,
@ -24,8 +23,8 @@ from app.errors import (
ErrContactAlreadyExists, ErrContactAlreadyExists,
) )
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, EmailLog, User from app.models import Alias, Contact, EmailLog
from app.utils import sanitize_email, CSRFValidationForm from app.utils import CSRFValidationForm
def email_validator(): def email_validator():
@ -51,7 +50,7 @@ def email_validator():
return _check return _check
def create_contact(user: User, alias: Alias, contact_address: str) -> Contact: def create_contact(alias: Alias, contact_address: str) -> Contact:
""" """
Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS. Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS.
Can throw exceptions: Can throw exceptions:
@ -61,37 +60,23 @@ def create_contact(user: User, alias: Alias, contact_address: str) -> Contact:
""" """
if not contact_address: if not contact_address:
raise ErrAddressInvalid("Empty address") raise ErrAddressInvalid("Empty address")
try: output = contact_utils.create_contact(email=contact_address, alias=alias)
contact_name, contact_email = parse_full_address(contact_address) if output.error == ContactCreateError.InvalidEmail:
except ValueError:
raise ErrAddressInvalid(contact_address) raise ErrAddressInvalid(contact_address)
elif output.error == ContactCreateError.NotAllowed:
contact_email = sanitize_email(contact_email)
if not is_valid_email(contact_email):
raise ErrAddressInvalid(contact_email)
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
raise ErrContactAlreadyExists(contact)
if not user.can_create_contacts():
raise ErrContactErrorUpgradeNeeded() raise ErrContactErrorUpgradeNeeded()
elif output.error is not None:
raise ErrAddressInvalid("Invalid address")
elif not output.created:
raise ErrContactAlreadyExists(output.contact)
contact = Contact.create( contact = output.contact
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_email,
name=contact_name,
reply_email=generate_reply_email(contact_email, alias),
)
LOG.d( LOG.d(
"create reverse-alias for %s %s, reverse alias:%s", "create reverse-alias for %s %s, reverse alias:%s",
contact_address, contact_address,
alias, alias,
contact.reply_email, contact.reply_email,
) )
Session.commit()
return contact return contact
@ -207,7 +192,7 @@ def get_contact_infos(
def delete_contact(alias: Alias, contact_id: int): def delete_contact(alias: Alias, contact_id: int):
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact: if not contact:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -215,6 +200,11 @@ def delete_contact(alias: Alias, contact_id: int):
flash("You cannot delete reverse-alias", "warning") flash("You cannot delete reverse-alias", "warning")
else: else:
delete_contact_email = contact.website_email delete_contact_email = contact.website_email
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.DeleteContact,
message=f"Delete contact {contact_id} ({contact.email})",
)
Contact.delete(contact_id) Contact.delete(contact_id)
Session.commit() Session.commit()
@ -261,7 +251,7 @@ def alias_contact_manager(alias_id):
if new_contact_form.validate(): if new_contact_form.validate():
contact_address = new_contact_form.email.data.strip() contact_address = new_contact_form.email.data.strip()
try: try:
contact = create_contact(current_user, alias, contact_address) contact = create_contact(alias, contact_address)
except ( except (
ErrContactErrorUpgradeNeeded, ErrContactErrorUpgradeNeeded,
ErrAddressInvalid, ErrAddressInvalid,

View File

@ -7,6 +7,7 @@ from flask import render_template, redirect, url_for, flash, request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app import config from app import config
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_utils import transfer_alias from app.alias_utils import transfer_alias
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
@ -57,6 +58,12 @@ def alias_transfer_send_route(alias_id):
transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}" transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}"
alias.transfer_token = hmac_alias_transfer_token(transfer_token) alias.transfer_token = hmac_alias_transfer_token(transfer_token)
alias.transfer_token_expiration = arrow.utcnow().shift(hours=24) alias.transfer_token_expiration = arrow.utcnow().shift(hours=24)
emit_alias_audit_log(
alias,
AliasAuditLogAction.InitiateTransferAlias,
"Initiated alias transfer",
)
Session.commit() Session.commit()
alias_transfer_url = ( alias_transfer_url = (
config.URL config.URL

View File

@ -1,8 +1,11 @@
from typing import Optional
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.models import Contact from app.models import Contact
@ -20,7 +23,7 @@ class PGPContactForm(FlaskForm):
@dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"])
@login_required @login_required
def contact_detail_route(contact_id): def contact_detail_route(contact_id):
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.user_id != current_user.id: if not contact or contact.user_id != current_user.id:
flash("You cannot see this page", "warning") flash("You cannot see this page", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
@ -50,6 +53,11 @@ def contact_detail_route(contact_id):
except PGPException: except PGPException:
flash("Cannot add the public key, please verify it", "error") flash("Cannot add the public key, please verify it", "error")
else: else:
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Added PGP key {contact.pgp_public_key} for contact {contact_id} ({contact.email})",
)
Session.commit() Session.commit()
flash( flash(
f"PGP public key for {contact.email} is saved successfully", f"PGP public key for {contact.email} is saved successfully",
@ -62,6 +70,11 @@ def contact_detail_route(contact_id):
) )
elif pgp_form.action.data == "remove": elif pgp_form.action.data == "remove":
# Free user can decide to remove contact PGP key # Free user can decide to remove contact PGP key
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Removed PGP key {contact.pgp_public_key} for contact {contact_id} ({contact.email})",
)
contact.pgp_public_key = None contact.pgp_public_key = None
contact.pgp_finger_print = None contact.pgp_finger_print = None
Session.commit() Session.commit()

View File

@ -21,7 +21,9 @@ class NewCustomDomainForm(FlaskForm):
@parallel_limiter.lock(only_when=lambda: request.method == "POST") @parallel_limiter.lock(only_when=lambda: request.method == "POST")
def custom_domain(): def custom_domain():
custom_domains = CustomDomain.filter_by( custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False user_id=current_user.id,
is_sl_subdomain=False,
pending_deletion=False,
).all() ).all()
new_custom_domain_form = NewCustomDomainForm() new_custom_domain_form = NewCustomDomainForm()

View File

@ -8,6 +8,7 @@ from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.log import LOG from app.log import LOG
from app.models import Subscription, Job from app.models import Subscription, Job
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class DeleteDirForm(FlaskForm): class DeleteDirForm(FlaskForm):
@ -33,6 +34,11 @@ def delete_account():
# Schedule delete account job # Schedule delete account job
LOG.w("schedule delete account job for %s", current_user) LOG.w("schedule delete account job for %s", current_user)
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UserMarkedForDeletion,
message=f"User {current_user.id} ({current_user.email}) marked for deletion via webapp",
)
Job.create( Job.create(
name=JOB_DELETE_ACCOUNT, name=JOB_DELETE_ACCOUNT,
payload={"user_id": current_user.id}, payload={"user_id": current_user.id},

View File

@ -1,3 +1,5 @@
from typing import Optional
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
@ -20,6 +22,7 @@ from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.errors import DirectoryInTrashError from app.errors import DirectoryInTrashError
from app.models import Directory, Mailbox, DirectoryMailbox from app.models import Directory, Mailbox, DirectoryMailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class NewDirForm(FlaskForm): class NewDirForm(FlaskForm):
@ -69,7 +72,9 @@ def directory():
if not delete_dir_form.validate(): if not delete_dir_form.validate():
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
dir_obj = Directory.get(delete_dir_form.directory_id.data) dir_obj: Optional[Directory] = Directory.get(
delete_dir_form.directory_id.data
)
if not dir_obj: if not dir_obj:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -79,6 +84,11 @@ def directory():
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
name = dir_obj.name name = dir_obj.name
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.DeleteDirectory,
message=f"Delete directory {dir_obj.id} ({dir_obj.name})",
)
Directory.delete(dir_obj.id) Directory.delete(dir_obj.id)
Session.commit() Session.commit()
flash(f"Directory {name} has been deleted", "success") flash(f"Directory {name} has been deleted", "success")
@ -90,7 +100,7 @@ def directory():
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
dir_id = toggle_dir_form.directory_id.data dir_id = toggle_dir_form.directory_id.data
dir_obj = Directory.get(dir_id) dir_obj: Optional[Directory] = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir_obj or dir_obj.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -103,6 +113,11 @@ def directory():
dir_obj.disabled = True dir_obj.disabled = True
flash(f"On-the-fly is disabled for {dir_obj.name}", "warning") flash(f"On-the-fly is disabled for {dir_obj.name}", "warning")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateDirectory,
message=f"Updated directory {dir_obj.id} ({dir_obj.name}) set disabled = {dir_obj.disabled}",
)
Session.commit() Session.commit()
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
@ -112,7 +127,7 @@ def directory():
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
dir_id = update_dir_form.directory_id.data dir_id = update_dir_form.directory_id.data
dir_obj = Directory.get(dir_id) dir_obj: Optional[Directory] = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir_obj or dir_obj.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -143,6 +158,12 @@ def directory():
for mailbox in mailboxes: for mailbox in mailboxes:
DirectoryMailbox.create(directory_id=dir_obj.id, mailbox_id=mailbox.id) DirectoryMailbox.create(directory_id=dir_obj.id, mailbox_id=mailbox.id)
mailboxes_as_str = ",".join(map(str, mailbox_ids))
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateDirectory,
message=f"Updated directory {dir_obj.id} ({dir_obj.name}) mailboxes ({mailboxes_as_str})",
)
Session.commit() Session.commit()
flash(f"Directory {dir_obj.name} has been updated", "success") flash(f"Directory {dir_obj.name} has been updated", "success")
@ -181,6 +202,11 @@ def directory():
new_dir = Directory.create( new_dir = Directory.create(
name=new_dir_name, user_id=current_user.id name=new_dir_name, user_id=current_user.id
) )
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.CreateDirectory,
message=f"New directory {new_dir.name} ({new_dir.name})",
)
except DirectoryInTrashError: except DirectoryInTrashError:
flash( flash(
f"{new_dir_name} has been used before and cannot be reused", f"{new_dir_name} has been used before and cannot be reused",

View File

@ -7,7 +7,7 @@ from wtforms import StringField, validators, IntegerField
from app.constants import DMARC_RECORD from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.custom_domain_utils import delete_custom_domain from app.custom_domain_utils import delete_custom_domain, set_custom_domain_mailboxes
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -16,11 +16,11 @@ from app.models import (
Alias, Alias,
DomainDeletedAlias, DomainDeletedAlias,
Mailbox, Mailbox,
DomainMailbox,
AutoCreateRule, AutoCreateRule,
AutoCreateRuleMailbox, AutoCreateRuleMailbox,
) )
from app.regex_utils import regex_match from app.regex_utils import regex_match
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string, CSRFValidationForm from app.utils import random_string, CSRFValidationForm
@ -37,8 +37,6 @@ def domain_detail_dns(custom_domain_id):
custom_domain.ownership_txt_token = random_string(30) custom_domain.ownership_txt_token = random_string(30)
Session.commit() Session.commit()
spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"
domain_validator = CustomDomainValidation(EMAIL_DOMAIN) domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm() csrf_form = CSRFValidationForm()
@ -142,7 +140,9 @@ def domain_detail_dns(custom_domain_id):
ownership_record=domain_validator.get_ownership_verification_record( ownership_record=domain_validator.get_ownership_verification_record(
custom_domain custom_domain
), ),
expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),
dkim_records=domain_validator.get_dkim_records(custom_domain), dkim_records=domain_validator.get_dkim_records(custom_domain),
spf_record=domain_validator.get_expected_spf_record(custom_domain),
dmarc_record=DMARC_RECORD, dmarc_record=DMARC_RECORD,
**locals(), **locals(),
) )
@ -165,6 +165,11 @@ def domain_detail(custom_domain_id):
return redirect(request.url) return redirect(request.url)
if request.form.get("form-name") == "switch-catch-all": if request.form.get("form-name") == "switch-catch-all":
custom_domain.catch_all = not custom_domain.catch_all custom_domain.catch_all = not custom_domain.catch_all
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) catch all to {custom_domain.catch_all}",
)
Session.commit() Session.commit()
if custom_domain.catch_all: if custom_domain.catch_all:
@ -183,6 +188,11 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "set-name": elif request.form.get("form-name") == "set-name":
if request.form.get("action") == "save": if request.form.get("action") == "save":
custom_domain.name = request.form.get("alias-name").replace("\n", "") custom_domain.name = request.form.get("alias-name").replace("\n", "")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) name",
)
Session.commit() Session.commit()
flash( flash(
f"Default alias name for Domain {custom_domain.domain} has been set", f"Default alias name for Domain {custom_domain.domain} has been set",
@ -190,6 +200,11 @@ def domain_detail(custom_domain_id):
) )
else: else:
custom_domain.name = None custom_domain.name = None
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Cleared custom domain {custom_domain.id} ({custom_domain.domain}) name",
)
Session.commit() Session.commit()
flash( flash(
f"Default alias name for Domain {custom_domain.domain} has been removed", f"Default alias name for Domain {custom_domain.domain} has been removed",
@ -203,6 +218,11 @@ def domain_detail(custom_domain_id):
custom_domain.random_prefix_generation = ( custom_domain.random_prefix_generation = (
not custom_domain.random_prefix_generation not custom_domain.random_prefix_generation
) )
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) random prefix generation to {custom_domain.random_prefix_generation}",
)
Session.commit() Session.commit()
if custom_domain.random_prefix_generation: if custom_domain.random_prefix_generation:
@ -220,40 +240,16 @@ def domain_detail(custom_domain_id):
) )
elif request.form.get("form-name") == "update": elif request.form.get("form-name") == "update":
mailbox_ids = request.form.getlist("mailbox_ids") mailbox_ids = request.form.getlist("mailbox_ids")
# check if mailbox is not tempered with result = set_custom_domain_mailboxes(
mailboxes = [] user_id=current_user.id,
for mailbox_id in mailbox_ids: custom_domain=custom_domain,
mailbox = Mailbox.get(mailbox_id) mailbox_ids=mailbox_ids,
if ( )
not mailbox
or mailbox.user_id != current_user.id
or not mailbox.verified
):
flash("Something went wrong, please retry", "warning")
return redirect(
url_for(
"dashboard.domain_detail", custom_domain_id=custom_domain.id
)
)
mailboxes.append(mailbox)
if not mailboxes: if result.success:
flash("You must select at least 1 mailbox", "warning") flash(f"{custom_domain.domain} mailboxes has been updated", "success")
return redirect( else:
url_for( flash(result.reason.value, "warning")
"dashboard.domain_detail", custom_domain_id=custom_domain.id
)
)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
Session.commit()
flash(f"{custom_domain.domain} mailboxes has been updated", "success")
return redirect( return redirect(
url_for("dashboard.domain_detail", custom_domain_id=custom_domain.id) url_for("dashboard.domain_detail", custom_domain_id=custom_domain.id)

View File

@ -149,7 +149,9 @@ def index():
) )
flash(f"Alias {email} has been deleted", "success") flash(f"Alias {email} has been deleted", "success")
elif request.form.get("form-name") == "disable-alias": elif request.form.get("form-name") == "disable-alias":
alias_utils.change_alias_status(alias, enabled=False) alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False from dashboard"
)
Session.commit() Session.commit()
flash(f"Alias {alias.email} has been disabled", "success") flash(f"Alias {alias.email} has been disabled", "success")

View File

@ -1,6 +1,7 @@
import base64 import base64
import binascii import binascii
import json import json
from typing import Optional
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
@ -15,6 +16,7 @@ from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import Mailbox from app.models import Mailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import CSRFValidationForm from app.utils import CSRFValidationForm
@ -151,7 +153,7 @@ def verify_with_signed_secret(request: str):
flash("Invalid link. Please delete and re-add your mailbox", "error") flash("Invalid link. Please delete and re-add your mailbox", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
mailbox_id = mailbox_data[0] mailbox_id = mailbox_data[0]
mailbox = Mailbox.get(mailbox_id) mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox: if not mailbox:
flash("Invalid link", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
@ -161,6 +163,11 @@ def verify_with_signed_secret(request: str):
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
mailbox.verified = True mailbox.verified = True
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.VerifyMailbox,
message=f"Verified mailbox {mailbox.id} ({mailbox.email})",
)
Session.commit() Session.commit()
LOG.d("Mailbox %s is verified", mailbox) LOG.d("Mailbox %s is verified", mailbox)

View File

@ -16,10 +16,11 @@ from app.db import Session
from app.email_utils import email_can_be_used_as_mailbox from app.email_utils import email_can_be_used_as_mailbox
from app.email_utils import mailbox_already_used, render, send_email from app.email_utils import mailbox_already_used, render, send_email
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.mailbox_utils import perform_mailbox_email_change, MailboxEmailChangeError
from app.models import Alias, AuthorizedAddress from app.models import Alias, AuthorizedAddress
from app.models import Mailbox from app.models import Mailbox
from app.pgp_utils import PGPException, load_public_key_and_check from app.pgp_utils import PGPException, load_public_key_and_check
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email, CSRFValidationForm from app.utils import sanitize_email, CSRFValidationForm
@ -88,8 +89,12 @@ def mailbox_detail_route(mailbox_id):
flash("SPF enforcement globally not enabled", "error") flash("SPF enforcement globally not enabled", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
mailbox.force_spf = ( force_spf_value = request.form.get("spf-status") == "on"
True if request.form.get("spf-status") == "on" else False mailbox.force_spf = force_spf_value
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set force_spf to {force_spf_value} on mailbox {mailbox_id} ({mailbox.email})",
) )
Session.commit() Session.commit()
flash( flash(
@ -113,6 +118,11 @@ def mailbox_detail_route(mailbox_id):
if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address): if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address):
flash(f"{address} already added", "error") flash(f"{address} already added", "error")
else: else:
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Add authorized address {address} to mailbox {mailbox_id} ({mailbox.email})",
)
AuthorizedAddress.create( AuthorizedAddress.create(
user_id=current_user.id, user_id=current_user.id,
mailbox_id=mailbox.id, mailbox_id=mailbox.id,
@ -133,6 +143,11 @@ def mailbox_detail_route(mailbox_id):
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
else: else:
address = authorized_address.email address = authorized_address.email
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove authorized address {address} from mailbox {mailbox_id} ({mailbox.email})",
)
AuthorizedAddress.delete(authorized_address_id) AuthorizedAddress.delete(authorized_address_id)
Session.commit() Session.commit()
flash(f"{address} has been deleted", "success") flash(f"{address} has been deleted", "success")
@ -165,6 +180,11 @@ def mailbox_detail_route(mailbox_id):
except PGPException: except PGPException:
flash("Cannot add the public key, please verify it", "error") flash("Cannot add the public key, please verify it", "error")
else: else:
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Add PGP Key {mailbox.pgp_finger_print} to mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit() Session.commit()
flash("Your PGP public key is saved successfully", "success") flash("Your PGP public key is saved successfully", "success")
return redirect( return redirect(
@ -172,6 +192,11 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("action") == "remove": elif request.form.get("action") == "remove":
# Free user can decide to remove their added PGP key # Free user can decide to remove their added PGP key
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove PGP Key {mailbox.pgp_finger_print} from mailbox {mailbox_id} ({mailbox.email})",
)
mailbox.pgp_public_key = None mailbox.pgp_public_key = None
mailbox.pgp_finger_print = None mailbox.pgp_finger_print = None
mailbox.disable_pgp = False mailbox.disable_pgp = False
@ -191,9 +216,19 @@ def mailbox_detail_route(mailbox_id):
) )
else: else:
mailbox.disable_pgp = False mailbox.disable_pgp = False
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Enabled PGP for mailbox {mailbox_id} ({mailbox.email})",
)
flash(f"PGP is enabled on {mailbox.email}", "info") flash(f"PGP is enabled on {mailbox.email}", "info")
else: else:
mailbox.disable_pgp = True mailbox.disable_pgp = True
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Disabled PGP for mailbox {mailbox_id} ({mailbox.email})",
)
flash(f"PGP is disabled on {mailbox.email}", "info") flash(f"PGP is disabled on {mailbox.email}", "info")
Session.commit() Session.commit()
@ -203,6 +238,11 @@ def mailbox_detail_route(mailbox_id):
elif request.form.get("form-name") == "generic-subject": elif request.form.get("form-name") == "generic-subject":
if request.form.get("action") == "save": if request.form.get("action") == "save":
mailbox.generic_subject = request.form.get("generic-subject") mailbox.generic_subject = request.form.get("generic-subject")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set generic subject for mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit() Session.commit()
flash("Generic subject is enabled", "success") flash("Generic subject is enabled", "success")
return redirect( return redirect(
@ -210,6 +250,11 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("action") == "remove": elif request.form.get("action") == "remove":
mailbox.generic_subject = None mailbox.generic_subject = None
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove generic subject for mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit() Session.commit()
flash("Generic subject is disabled", "success") flash("Generic subject is disabled", "success")
return redirect( return redirect(
@ -272,7 +317,7 @@ def cancel_mailbox_change_route(mailbox_id):
@dashboard_bp.route("/mailbox/confirm_change") @dashboard_bp.route("/mailbox/confirm_change")
def mailbox_confirm_change_route(): def mailbox_confirm_email_change_route():
s = TimestampSigner(MAILBOX_SECRET) s = TimestampSigner(MAILBOX_SECRET)
signed_mailbox_id = request.args.get("mailbox_id") signed_mailbox_id = request.args.get("mailbox_id")
@ -281,30 +326,20 @@ def mailbox_confirm_change_route():
except Exception: except Exception:
flash("Invalid link", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
else:
mailbox = Mailbox.get(mailbox_id)
# new_email can be None if user cancels change in the meantime res = perform_mailbox_email_change(mailbox_id)
if mailbox and mailbox.new_email:
user = mailbox.user
if Mailbox.get_by(email=mailbox.new_email, user_id=user.id):
flash(f"{mailbox.new_email} is already used", "error")
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id)
)
mailbox.email = mailbox.new_email flash(res.message, res.message_category)
mailbox.new_email = None if res.error:
if res.error == MailboxEmailChangeError.EmailAlreadyUsed:
# mark mailbox as verified if the change request is sent from an unverified mailbox
mailbox.verified = True
Session.commit()
LOG.d("Mailbox change %s is verified", mailbox)
flash(f"The {mailbox.email} is updated", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
else: elif res.error == MailboxEmailChangeError.InvalidId:
flash("Invalid link", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
else:
raise Exception("Unhandled MailboxEmailChangeError")
else:
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)

View File

@ -11,6 +11,7 @@ from app.dashboard.base import dashboard_bp
from app.errors import SubdomainInTrashError from app.errors import SubdomainInTrashError
from app.log import LOG from app.log import LOG
from app.models import CustomDomain, Mailbox, SLDomain from app.models import CustomDomain, Mailbox, SLDomain
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
# Only lowercase letters, numbers, dashes (-) are currently supported # Only lowercase letters, numbers, dashes (-) are currently supported
_SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}" _SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}"
@ -102,6 +103,12 @@ def subdomain_route():
ownership_verified=True, ownership_verified=True,
commit=True, commit=True,
) )
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.CreateCustomDomain,
message=f"Create subdomain {new_custom_domain.id} ({full_domain})",
commit=True,
)
except SubdomainInTrashError: except SubdomainInTrashError:
flash( flash(
f"{full_domain} has been used before and cannot be reused", f"{full_domain} has been used before and cannot be reused",

View File

@ -32,7 +32,9 @@ def unsubscribe(alias_id):
# automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058 # automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058
if request.method == "POST": if request.method == "POST":
alias_utils.change_alias_status(alias, False) alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False from unsubscribe request"
)
flash(f"Alias {alias.email} has been blocked", "success") flash(f"Alias {alias.email} has been blocked", "success")
Session.commit() Session.commit()

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Optional from dataclasses import dataclass
from typing import List, Optional
import dns.resolver import dns.resolver
@ -8,8 +9,14 @@ from app.config import NAMESERVERS
_include_spf = "include:" _include_spf = "include:"
@dataclass
class MxRecord:
priority: int
domain: str
def is_mx_equivalent( def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool: ) -> bool:
""" """
Compare mx_domains with ref_mx_domains to see if they are equivalent. Compare mx_domains with ref_mx_domains to see if they are equivalent.
@ -18,14 +25,14 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number. The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
""" """
mx_domains = sorted(mx_domains, key=lambda x: x[0]) mx_domains = sorted(mx_domains, key=lambda x: x.priority)
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)
if len(mx_domains) < len(ref_mx_domains): if len(mx_domains) < len(ref_mx_domains):
return False return False
for i in range(len(ref_mx_domains)): for actual, expected in zip(mx_domains, ref_mx_domains):
if mx_domains[i][1] != ref_mx_domains[i][1]: if actual.domain != expected.domain:
return False return False
return True return True
@ -37,7 +44,7 @@ class DNSClient(ABC):
pass pass
@abstractmethod @abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: def get_mx_domains(self, hostname: str) -> List[MxRecord]:
pass pass
def get_spf_domain(self, hostname: str) -> List[str]: def get_spf_domain(self, hostname: str) -> List[str]:
@ -81,7 +88,7 @@ class NetworkDNSClient(DNSClient):
except Exception: except Exception:
return None return None
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: def get_mx_domains(self, hostname: str) -> List[MxRecord]:
""" """
return list of (priority, domain name) sorted by priority (lowest priority first) return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end. domain name ends with a "." at the end.
@ -92,14 +99,14 @@ class NetworkDNSClient(DNSClient):
for a in answers: for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ") parts = record.split(" ")
ret.append((int(parts[0]), parts[1])) ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
return sorted(ret, key=lambda x: x[0]) return sorted(ret, key=lambda x: x.priority)
except Exception: except Exception:
return [] return []
def get_txt_record(self, hostname: str) -> List[str]: def get_txt_record(self, hostname: str) -> List[str]:
try: try:
answers = self._resolver.resolve(hostname, "TXT", search=True) answers = self._resolver.resolve(hostname, "TXT", search=False)
ret = [] ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings: for record in a.strings:
@ -112,14 +119,14 @@ class NetworkDNSClient(DNSClient):
class InMemoryDNSClient(DNSClient): class InMemoryDNSClient(DNSClient):
def __init__(self): def __init__(self):
self.cname_records: dict[str, Optional[str]] = {} self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {} self.mx_records: dict[str, List[MxRecord]] = {}
self.spf_records: dict[str, List[str]] = {} self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {} self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str): def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
self.mx_records[hostname] = mx_list self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]): def set_txt_record(self, hostname: str, txt_list: List[str]):
@ -128,9 +135,9 @@ class InMemoryDNSClient(DNSClient):
def get_cname_record(self, hostname: str) -> Optional[str]: def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname) return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: def get_mx_domains(self, hostname: str) -> List[MxRecord]:
mx_list = self.mx_records.get(hostname, []) mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0]) return sorted(mx_list, key=lambda x: x.priority)
def get_txt_record(self, hostname: str) -> List[str]: def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, []) return self.txt_records.get(hostname, [])
@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS) return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> [(int, str)]: def get_mx_domains(hostname: str) -> List[MxRecord]:
return get_network_dns_client().get_mx_domains(hostname) return get_network_dns_client().get_mx_domains(hostname)

View File

@ -592,7 +592,7 @@ def email_can_be_used_as_mailbox(email_address: str) -> bool:
from app.models import CustomDomain from app.models import CustomDomain
if CustomDomain.get_by(domain=domain, verified=True): if CustomDomain.get_by(domain=domain, is_sl_subdomain=True, verified=True):
LOG.d("domain %s is a SimpleLogin custom domain", domain) LOG.d("domain %s is a SimpleLogin custom domain", domain)
return False return False
@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]:
""" """
priority_domains = get_mx_domains(domain) priority_domains = get_mx_domains(domain)
return [d[:-1] for _, d in priority_domains] return [d.domain[:-1] for d in priority_domains]
def personal_email_already_used(email_address: str) -> bool: def personal_email_already_used(email_address: str) -> bool:

View File

@ -64,10 +64,6 @@ class EventDispatcher:
) )
return return
if config.EVENT_WEBHOOK_ENABLED_USER_IDS is not None:
if user.id not in config.EVENT_WEBHOOK_ENABLED_USER_IDS:
return
partner_user = EventDispatcher.__partner_user(user.id) partner_user = EventDispatcher.__partner_user(user.id)
if not partner_user: if not partner_user:
LOG.i(f"Not sending events because there's no partner user for user {user}") LOG.i(f"Not sending events because there's no partner user for user {user}")

View File

@ -103,7 +103,9 @@ class UnsubscribeHandler:
): ):
return status.E509 return status.E509
LOG.i(f"User disabled alias {alias} via unsubscribe header") LOG.i(f"User disabled alias {alias} via unsubscribe header")
alias_utils.change_alias_status(alias, enabled=False) alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False via unsubscribe header"
)
Session.commit() Session.commit()
enable_alias_url = config.URL + f"/dashboard/?highlight_alias_id={alias.id}" enable_alias_url = config.URL + f"/dashboard/?highlight_alias_id={alias.id}"
for mailbox in alias.mailboxes: for mailbox in alias.mailboxes:

View File

@ -1,6 +1,7 @@
import dataclasses import dataclasses
import secrets import secrets
import random import random
from enum import Enum
from typing import Optional from typing import Optional
import arrow import arrow
@ -16,6 +17,7 @@ from app.email_utils import (
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
from app.log import LOG from app.log import LOG
from app.models import User, Mailbox, Job, MailboxActivation from app.models import User, Mailbox, Job, MailboxActivation
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@dataclasses.dataclass @dataclasses.dataclass
@ -70,9 +72,15 @@ def create_mailbox(
f"User {user} has tried to create mailbox with {email} but email is invalid" f"User {user} has tried to create mailbox with {email} but email is invalid"
) )
raise MailboxError("Invalid email") raise MailboxError("Invalid email")
new_mailbox = Mailbox.create( new_mailbox: Mailbox = Mailbox.create(
email=email, user_id=user.id, verified=verified, commit=True email=email, user_id=user.id, verified=verified, commit=True
) )
emit_user_audit_log(
user=user,
action=UserAuditLogAction.CreateMailbox,
message=f"Create mailbox {new_mailbox.id} ({new_mailbox.email}). Verified={verified}",
commit=True,
)
if verified: if verified:
LOG.i(f"User {user} as created a pre-verified mailbox with {email}") LOG.i(f"User {user} as created a pre-verified mailbox with {email}")
@ -129,7 +137,7 @@ def delete_mailbox(
if not transfer_mailbox.verified: if not transfer_mailbox.verified:
LOG.i(f"User {user} has tried to transfer to a non verified mailbox") LOG.i(f"User {user} has tried to transfer to a non verified mailbox")
MailboxError("Your new mailbox is not verified") raise MailboxError("Your new mailbox is not verified")
# Schedule delete account job # Schedule delete account job
LOG.i( LOG.i(
@ -204,6 +212,11 @@ def verify_mailbox_code(user: User, mailbox_id: int, code: str) -> Mailbox:
raise CannotVerifyError("Invalid activation code") raise CannotVerifyError("Invalid activation code")
LOG.i(f"User {user} has verified mailbox {mailbox_id}") LOG.i(f"User {user} has verified mailbox {mailbox_id}")
mailbox.verified = True mailbox.verified = True
emit_user_audit_log(
user=user,
action=UserAuditLogAction.VerifyMailbox,
message=f"Verify mailbox {mailbox_id} ({mailbox.email})",
)
clear_activation_codes_for_mailbox(mailbox) clear_activation_codes_for_mailbox(mailbox)
return mailbox return mailbox
@ -213,7 +226,10 @@ def generate_activation_code(
) -> MailboxActivation: ) -> MailboxActivation:
clear_activation_codes_for_mailbox(mailbox) clear_activation_codes_for_mailbox(mailbox)
if use_digit_code: if use_digit_code:
code = "{:06d}".format(random.randint(1, 999999)) if config.MAILBOX_VERIFICATION_OVERRIDE_CODE:
code = config.MAILBOX_VERIFICATION_OVERRIDE_CODE
else:
code = "{:06d}".format(random.randint(1, 999999))
else: else:
code = secrets.token_urlsafe(16) code = secrets.token_urlsafe(16)
return MailboxActivation.create( return MailboxActivation.create(
@ -258,3 +274,54 @@ def send_verification_email(
mailbox_email=mailbox.email, mailbox_email=mailbox.email,
), ),
) )
class MailboxEmailChangeError(Enum):
InvalidId = 1
EmailAlreadyUsed = 2
@dataclasses.dataclass
class MailboxEmailChangeResult:
error: Optional[MailboxEmailChangeError]
message: str
message_category: str
def perform_mailbox_email_change(mailbox_id: int) -> MailboxEmailChangeResult:
mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
# new_email can be None if user cancels change in the meantime
if mailbox and mailbox.new_email:
user = mailbox.user
if Mailbox.get_by(email=mailbox.new_email, user_id=user.id):
return MailboxEmailChangeResult(
error=MailboxEmailChangeError.EmailAlreadyUsed,
message=f"{mailbox.new_email} is already used",
message_category="error",
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Change mailbox email for mailbox {mailbox_id} (old={mailbox.email} | new={mailbox.new_email})",
)
mailbox.email = mailbox.new_email
mailbox.new_email = None
# mark mailbox as verified if the change request is sent from an unverified mailbox
mailbox.verified = True
Session.commit()
LOG.d("Mailbox change %s is verified", mailbox)
return MailboxEmailChangeResult(
error=None,
message=f"The {mailbox.email} is updated",
message_category="success",
)
else:
return MailboxEmailChangeResult(
error=MailboxEmailChangeError.InvalidId,
message="Invalid link",
message_category="error",
)

View File

@ -336,7 +336,7 @@ class Fido(Base, ModelMixin):
class User(Base, ModelMixin, UserMixin, PasswordOracle): class User(Base, ModelMixin, UserMixin, PasswordOracle):
__tablename__ = "users" __tablename__ = "users"
FLAG_FREE_DISABLE_CREATE_ALIAS = 1 << 0 FLAG_DISABLE_CREATE_CONTACTS = 1 << 0
FLAG_CREATED_FROM_PARTNER = 1 << 1 FLAG_CREATED_FROM_PARTNER = 1 << 1
FLAG_FREE_OLD_ALIAS_LIMIT = 1 << 2 FLAG_FREE_OLD_ALIAS_LIMIT = 1 << 2
FLAG_CREATED_ALIAS_FROM_PARTNER = 1 << 3 FLAG_CREATED_ALIAS_FROM_PARTNER = 1 << 3
@ -543,7 +543,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
# bitwise flags. Allow for future expansion # bitwise flags. Allow for future expansion
flags = sa.Column( flags = sa.Column(
sa.BigInteger, sa.BigInteger,
default=FLAG_FREE_DISABLE_CREATE_ALIAS, default=FLAG_DISABLE_CREATE_CONTACTS,
server_default="0", server_default="0",
nullable=False, nullable=False,
) )
@ -1168,7 +1168,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def can_create_contacts(self) -> bool: def can_create_contacts(self) -> bool:
if self.is_premium(): if self.is_premium():
return True return True
if self.flags & User.FLAG_FREE_DISABLE_CREATE_ALIAS == 0: if self.flags & User.FLAG_DISABLE_CREATE_CONTACTS == 0:
return True return True
return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS
@ -1673,6 +1673,7 @@ class Alias(Base, ModelMixin):
Session.flush() Session.flush()
# Internal import to avoid global import cycles # Internal import to avoid global import cycles
from app.alias_audit_log_utils import AliasAuditLogAction, emit_alias_audit_log
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent from app.events.generated.event_pb2 import AliasCreated, EventContent
@ -1684,6 +1685,9 @@ class Alias(Base, ModelMixin):
created_at=int(new_alias.created_at.timestamp), created_at=int(new_alias.created_at.timestamp),
) )
EventDispatcher.send_event(user, EventContent(alias_created=event)) EventDispatcher.send_event(user, EventContent(alias_created=event))
emit_alias_audit_log(
new_alias, AliasAuditLogAction.CreateAlias, "New alias created"
)
return new_alias return new_alias
@ -2443,6 +2447,8 @@ class CustomDomain(Base, ModelMixin):
unique=True, unique=True,
postgresql_where=Column("ownership_verified"), postgresql_where=Column("ownership_verified"),
), # The condition ), # The condition
Index("ix_custom_domain_user_id", "user_id"),
Index("ix_custom_domain_pending_deletion", "pending_deletion"),
) )
user = orm.relationship(User, foreign_keys=[user_id], backref="custom_domains") user = orm.relationship(User, foreign_keys=[user_id], backref="custom_domains")
@ -2764,9 +2770,9 @@ class Mailbox(Base, ModelMixin):
from app.email_utils import get_email_local_part from app.email_utils import get_email_local_part
mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email)) mx_domains = get_mx_domains(get_email_local_part(self.email))
# Proton is the first domain # Proton is the first domain
if mx_domains and mx_domains[0][1] in ( if mx_domains and mx_domains[0].domain in (
"mail.protonmail.ch.", "mail.protonmail.ch.",
"mailsec.protonmail.ch.", "mailsec.protonmail.ch.",
): ):
@ -3765,15 +3771,14 @@ class SyncEvent(Base, ModelMixin):
sa.Index("ix_sync_event_taken_time", "taken_time"), sa.Index("ix_sync_event_taken_time", "taken_time"),
) )
def mark_as_taken(self) -> bool: def mark_as_taken(self, allow_taken_older_than: Optional[Arrow] = None) -> bool:
sql = """ taken_condition = ["taken_time IS NULL"]
UPDATE sync_event
SET taken_time = :taken_time
WHERE id = :sync_event_id
AND taken_time IS NULL
"""
args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id} args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id}
if allow_taken_older_than:
taken_condition.append("taken_time < :taken_older_than")
args["taken_older_than"] = allow_taken_older_than.datetime
sql_taken_condition = "({})".format(" OR ".join(taken_condition))
sql = f"UPDATE sync_event SET taken_time = :taken_time WHERE id = :sync_event_id AND {sql_taken_condition}"
res = Session.execute(sql, args) res = Session.execute(sql, args)
Session.commit() Session.commit()
@ -3799,3 +3804,39 @@ class SyncEvent(Base, ModelMixin):
.limit(100) .limit(100)
.all() .all()
) )
class AliasAuditLog(Base, ModelMixin):
"""This model holds an audit log for all the actions performed to an alias"""
__tablename__ = "alias_audit_log"
user_id = sa.Column(sa.Integer, nullable=False)
alias_id = sa.Column(sa.Integer, nullable=False)
alias_email = sa.Column(sa.String(255), nullable=False)
action = sa.Column(sa.String(255), nullable=False)
message = sa.Column(sa.Text, default=None, nullable=True)
__table_args__ = (
sa.Index("ix_alias_audit_log_user_id", "user_id"),
sa.Index("ix_alias_audit_log_alias_id", "alias_id"),
sa.Index("ix_alias_audit_log_alias_email", "alias_email"),
sa.Index("ix_alias_audit_log_created_at", "created_at"),
)
class UserAuditLog(Base, ModelMixin):
"""This model holds an audit log for all the actions performed by a user"""
__tablename__ = "user_audit_log"
user_id = sa.Column(sa.Integer, nullable=False)
user_email = sa.Column(sa.String(255), nullable=False)
action = sa.Column(sa.String(255), nullable=False)
message = sa.Column(sa.Text, default=None, nullable=True)
__table_args__ = (
sa.Index("ix_user_audit_log_user_id", "user_id"),
sa.Index("ix_user_audit_log_user_email", "user_email"),
sa.Index("ix_user_audit_log_created_at", "created_at"),
)

View File

@ -0,0 +1,46 @@
from typing import Optional
from arrow import Arrow
from app.models import PartnerUser, PartnerSubscription, User
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
def create_partner_user(
user: User, partner_id: int, partner_email: str, external_user_id: str
) -> PartnerUser:
instance = PartnerUser.create(
user_id=user.id,
partner_id=partner_id,
partner_email=partner_email,
external_user_id=external_user_id,
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.LinkAccount,
message=f"Linked account to partner_id={partner_id} | partner_email={partner_email} | external_user_id={external_user_id}",
)
return instance
def create_partner_subscription(
partner_user: PartnerUser,
expiration: Optional[Arrow],
msg: Optional[str] = None,
) -> PartnerSubscription:
instance = PartnerSubscription.create(
partner_user_id=partner_user.id,
end_at=expiration,
)
message = "User upgraded through partner subscription"
if msg:
message += f" | {msg}"
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.Upgrade,
message=message,
)
return instance

View File

View File

@ -0,0 +1,121 @@
from typing import Optional
import arrow
from coinbase_commerce.error import WebhookInvalidPayload, SignatureVerificationError
from coinbase_commerce.webhook import Webhook
from flask import Flask, request
from app.config import COINBASE_WEBHOOK_SECRET
from app.db import Session
from app.email_utils import send_email, render
from app.log import LOG
from app.models import CoinbaseSubscription, User
from app.subscription_webhook import execute_subscription_webhook
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
def setup_coinbase_commerce(app: Flask):
@app.route("/coinbase", methods=["POST"])
def coinbase_webhook():
# event payload
request_data = request.data.decode("utf-8")
# webhook signature
request_sig = request.headers.get("X-CC-Webhook-Signature", None)
try:
# signature verification and event object construction
event = Webhook.construct_event(
request_data, request_sig, COINBASE_WEBHOOK_SECRET
)
except (WebhookInvalidPayload, SignatureVerificationError) as e:
LOG.e("Invalid Coinbase webhook")
return str(e), 400
LOG.d("Coinbase event %s", event)
if event["type"] == "charge:confirmed":
if handle_coinbase_event(event):
return "success", 200
else:
return "error", 400
return "success", 200
def handle_coinbase_event(event) -> bool:
server_user_id = event["data"]["metadata"]["user_id"]
try:
user_id = int(server_user_id)
except ValueError:
user_id = int(float(server_user_id))
code = event["data"]["code"]
user: Optional[User] = User.get(user_id)
if not user:
LOG.e("User not found %s", user_id)
return False
coinbase_subscription: CoinbaseSubscription = CoinbaseSubscription.get_by(
user_id=user_id
)
if not coinbase_subscription:
LOG.d("Create a coinbase subscription for %s", user)
coinbase_subscription = CoinbaseSubscription.create(
user_id=user_id, end_at=arrow.now().shift(years=1), code=code, commit=True
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.Upgrade,
message="Upgraded though Coinbase",
commit=True,
)
send_email(
user.email,
"Your SimpleLogin account has been upgraded",
render(
"transactional/coinbase/new-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/new-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
else:
if coinbase_subscription.code != code:
LOG.d("Update code from %s to %s", coinbase_subscription.code, code)
coinbase_subscription.code = code
if coinbase_subscription.is_active():
coinbase_subscription.end_at = coinbase_subscription.end_at.shift(years=1)
else: # already expired subscription
coinbase_subscription.end_at = arrow.now().shift(years=1)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended coinbase subscription",
)
Session.commit()
send_email(
user.email,
"Your SimpleLogin account has been extended",
render(
"transactional/coinbase/extend-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/extend-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
execute_subscription_webhook(user)
return True

286
app/app/payments/paddle.py Normal file
View File

@ -0,0 +1,286 @@
import arrow
import json
from dateutil.relativedelta import relativedelta
from flask import Flask, request
from app import paddle_utils, paddle_callback
from app.config import (
PADDLE_MONTHLY_PRODUCT_ID,
PADDLE_MONTHLY_PRODUCT_IDS,
PADDLE_YEARLY_PRODUCT_IDS,
PADDLE_COUPON_ID,
)
from app.db import Session
from app.email_utils import send_email, render
from app.log import LOG
from app.models import Subscription, PlanEnum, User, Coupon
from app.subscription_webhook import execute_subscription_webhook
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string
def setup_paddle_callback(app: Flask):
@app.route("/paddle", methods=["GET", "POST"])
def paddle():
LOG.d(f"paddle callback {request.form.get('alert_name')} {request.form}")
# make sure the request comes from Paddle
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
if (
request.form.get("alert_name") == "subscription_created"
): # new user subscribes
# the passthrough is json encoded, e.g.
# request.form.get("passthrough") = '{"user_id": 88 }'
passthrough = json.loads(request.form.get("passthrough"))
user_id = passthrough.get("user_id")
user = User.get(user_id)
subscription_plan_id = int(request.form.get("subscription_plan_id"))
if subscription_plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
plan = PlanEnum.monthly
elif subscription_plan_id in PADDLE_YEARLY_PRODUCT_IDS:
plan = PlanEnum.yearly
else:
LOG.e(
"Unknown subscription_plan_id %s %s",
subscription_plan_id,
request.form,
)
return "No such subscription", 400
sub = Subscription.get_by(user_id=user.id)
if not sub:
LOG.d(f"create a new Subscription for user {user}")
Subscription.create(
user_id=user.id,
cancel_url=request.form.get("cancel_url"),
update_url=request.form.get("update_url"),
subscription_id=request.form.get("subscription_id"),
event_time=arrow.now(),
next_bill_date=arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date(),
plan=plan,
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.Upgrade,
message="Upgraded through Paddle",
)
else:
LOG.d(f"Update an existing Subscription for user {user}")
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.subscription_id = request.form.get("subscription_id")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
# in case user cancels a plan and subscribes a new plan
sub.cancelled = False
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended Paddle subscription",
)
execute_subscription_webhook(user)
LOG.d("User %s upgrades!", user)
Session.commit()
elif request.form.get("alert_name") == "subscription_payment_succeeded":
subscription_id = request.form.get("subscription_id")
LOG.d("Update subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
# when user subscribes, the "subscription_payment_succeeded" can arrive BEFORE "subscription_created"
# at that time, subscription object does not exist yet
if sub:
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
Session.commit()
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_cancelled":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
# cancellation_effective_date should be the same as next_bill_date
LOG.w(
"Cancel subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
sub.event_time = arrow.now()
sub.cancelled = True
emit_user_audit_log(
user=sub.user,
action=UserAuditLogAction.SubscriptionCancelled,
message="Cancelled Paddle subscription",
)
Session.commit()
user = sub.user
send_email(
user.email,
"SimpleLogin - your subscription is canceled",
render(
"transactional/subscription-cancel.txt",
user=user,
end_date=request.form.get("cancellation_effective_date"),
),
)
execute_subscription_webhook(sub.user)
else:
# user might have deleted their account
LOG.i(f"Cancel non-exist subscription {subscription_id}")
return "OK"
elif request.form.get("alert_name") == "subscription_updated":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
next_bill_date = request.form.get("next_bill_date")
if not next_bill_date:
paddle_callback.failed_payment(sub, subscription_id)
return "OK"
LOG.d(
"Update subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
if (
int(request.form.get("subscription_plan_id"))
== PADDLE_MONTHLY_PRODUCT_ID
):
plan = PlanEnum.monthly
else:
plan = PlanEnum.yearly
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
sub.cancelled = False
emit_user_audit_log(
user=sub.user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended Paddle subscription",
)
Session.commit()
execute_subscription_webhook(sub.user)
else:
LOG.w(
f"update non-exist subscription {subscription_id}. {request.form}"
)
return "No such subscription", 400
elif request.form.get("alert_name") == "payment_refunded":
subscription_id = request.form.get("subscription_id")
LOG.d("Refund request for subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
user = sub.user
Subscription.delete(sub.id)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionCancelled,
message="Paddle subscription cancelled as user requested a refund",
)
Session.commit()
LOG.e("%s requests a refund", user)
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_payment_refunded":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
LOG.d(
"Handle subscription_payment_refunded for subscription %s",
subscription_id,
)
if not sub:
LOG.w(
"No such subscription for %s, payload %s",
subscription_id,
request.form,
)
return "No such subscription"
plan_id = int(request.form["subscription_plan_id"])
if request.form["refund_type"] == "full":
if plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
LOG.d("subtract 1 month from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(months=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
elif plan_id in PADDLE_YEARLY_PRODUCT_IDS:
LOG.d("subtract 1 year from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(years=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
else:
LOG.e("Unknown plan_id %s", plan_id)
else:
LOG.w("partial subscription_payment_refunded, not handled")
execute_subscription_webhook(sub.user)
return "OK"
@app.route("/paddle_coupon", methods=["GET", "POST"])
def paddle_coupon():
LOG.d("paddle coupon callback %s", request.form)
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
product_id = request.form.get("p_product_id")
if product_id != PADDLE_COUPON_ID:
LOG.e("product_id %s not match with %s", product_id, PADDLE_COUPON_ID)
return "KO", 400
email = request.form.get("email")
LOG.d("Paddle coupon request for %s", email)
coupon = Coupon.create(
code=random_string(30),
comment="For 1-year coupon",
expires_date=arrow.now().shift(years=1, days=-1),
commit=True,
)
return (
f"Your 1-year coupon is <b>{coupon.code}</b> <br> "
f"It's valid until <b>{coupon.expires_date.date().isoformat()}</b>"
)

View File

@ -5,6 +5,7 @@ from app.db import Session
from app.log import LOG from app.log import LOG
from app.errors import ProtonPartnerNotSetUp from app.errors import ProtonPartnerNotSetUp
from app.models import Partner, PartnerUser, User from app.models import Partner, PartnerUser, User
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
PROTON_PARTNER_NAME = "Proton" PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None _PROTON_PARTNER: Optional[Partner] = None
@ -32,6 +33,11 @@ def perform_proton_account_unlink(current_user: User):
) )
if partner_user is not None: if partner_user is not None:
LOG.info(f"User {current_user} has unlinked the account from {partner_user}") LOG.info(f"User {current_user} has unlinked the account from {partner_user}")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UnlinkAccount,
message=f"User has unlinked the account (email={partner_user.partner_email} | external_user_id={partner_user.external_user_id})",
)
PartnerUser.delete(partner_user.id) PartnerUser.delete(partner_user.id)
Session.commit() Session.commit()
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name}) agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})

21
app/app/sentry_utils.py Normal file
View File

@ -0,0 +1,21 @@
from typing import Optional
from sentry_sdk.types import Event, Hint
_HTTP_CODES_TO_IGNORE = [416]
def _should_send(_event: Event, hint: Hint) -> bool:
# Check if this is an HTTP Exception event
if "exc_info" in hint:
exc_type, exc_value, exc_traceback = hint["exc_info"]
# Check if it's a Werkzeug HTTPException (raised for HTTP status codes)
if hasattr(exc_value, "code") and exc_value.code in _HTTP_CODES_TO_IGNORE:
return False
return True
def sentry_before_send(event: Event, hint: Hint) -> Optional[Event]:
if _should_send(event, hint):
return event
return None

View File

@ -1,40 +1,16 @@
import requests
from requests import RequestException
from app import config
from app.db import Session from app.db import Session
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserPlanChanged from app.events.generated.event_pb2 import EventContent, UserPlanChanged
from app.log import LOG
from app.models import User from app.models import User
def execute_subscription_webhook(user: User): def execute_subscription_webhook(user: User):
webhook_url = config.SUBSCRIPTION_CHANGE_WEBHOOK
if webhook_url is None:
return
subscription_end = user.get_active_subscription_end( subscription_end = user.get_active_subscription_end(
include_partner_subscription=False include_partner_subscription=False
) )
sl_subscription_end = None sl_subscription_end = None
if subscription_end: if subscription_end:
sl_subscription_end = subscription_end.timestamp sl_subscription_end = subscription_end.timestamp
payload = {
"user_id": user.id,
"is_premium": user.is_premium(),
"active_subscription_end": sl_subscription_end,
}
try:
response = requests.post(webhook_url, json=payload, timeout=2)
if response.status_code == 200:
LOG.i("Sent request to subscription update webhook successfully")
else:
LOG.i(
f"Request to webhook failed with status {response.status_code}: {response.text}"
)
except RequestException as e:
LOG.error(f"Subscription request exception: {e}")
event = UserPlanChanged(plan_end_time=sl_subscription_end) event = UserPlanChanged(plan_end_time=sl_subscription_end)
EventDispatcher.send_event(user, EventContent(user_plan_change=event)) EventDispatcher.send_event(user, EventContent(user_plan_change=event))
Session.commit() Session.commit()

View File

@ -0,0 +1,40 @@
from enum import Enum
from app.models import User, UserAuditLog
class UserAuditLogAction(Enum):
Upgrade = "upgrade"
SubscriptionExtended = "subscription_extended"
SubscriptionCancelled = "subscription_cancelled"
LinkAccount = "link_account"
UnlinkAccount = "unlink_account"
CreateMailbox = "create_mailbox"
VerifyMailbox = "verify_mailbox"
UpdateMailbox = "update_mailbox"
DeleteMailbox = "delete_mailbox"
CreateCustomDomain = "create_custom_domain"
VerifyCustomDomain = "verify_custom_domain"
UpdateCustomDomain = "update_custom_domain"
DeleteCustomDomain = "delete_custom_domain"
CreateDirectory = "create_directory"
UpdateDirectory = "update_directory"
DeleteDirectory = "delete_directory"
UserMarkedForDeletion = "user_marked_for_deletion"
DeleteUser = "delete_user"
def emit_user_audit_log(
user: User, action: UserAuditLogAction, message: str, commit: bool = False
):
UserAuditLog.create(
user_id=user.id,
user_email=user.email,
action=action.value,
message=message,
commit=commit,
)

View File

@ -3,6 +3,7 @@ from typing import Optional
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import User, SLDomain, CustomDomain, Mailbox from app.models import User, SLDomain, CustomDomain, Mailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class CannotSetAlias(Exception): class CannotSetAlias(Exception):
@ -54,7 +55,7 @@ def set_default_alias_domain(user: User, domain_name: Optional[str]):
def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox: def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox:
mailbox = Mailbox.get(mailbox_id) mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != user.id: if not mailbox or mailbox.user_id != user.id:
raise CannotSetMailbox("Invalid mailbox") raise CannotSetMailbox("Invalid mailbox")
@ -67,5 +68,11 @@ def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox:
LOG.i(f"User {user} has set mailbox {mailbox} as his default one") LOG.i(f"User {user} has set mailbox {mailbox} as his default one")
user.default_mailbox_id = mailbox.id user.default_mailbox_id = mailbox.id
emit_user_audit_log(
user=user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set mailbox {mailbox.id} ({mailbox.email}) as default",
)
Session.commit() Session.commit()
return mailbox return mailbox

View File

@ -14,6 +14,7 @@ from sqlalchemy.sql import Insert, text
from app import s3, config from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import ( from app.email_utils import (
@ -59,8 +60,11 @@ from app.models import (
) )
from app.pgp_utils import load_public_key_and_check, PGPException from app.pgp_utils import load_public_key_and_check, PGPException
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email from app.utils import sanitize_email
from server import create_light_app from server import create_light_app
from tasks.clean_alias_audit_log import cleanup_alias_audit_log
from tasks.clean_user_audit_log import cleanup_user_audit_log
from tasks.cleanup_old_imports import cleanup_old_imports from tasks.cleanup_old_imports import cleanup_old_imports
from tasks.cleanup_old_jobs import cleanup_old_jobs from tasks.cleanup_old_jobs import cleanup_old_jobs
from tasks.cleanup_old_notifications import cleanup_old_notifications from tasks.cleanup_old_notifications import cleanup_old_notifications
@ -905,9 +909,11 @@ def check_custom_domain():
LOG.i("custom domain has been deleted") LOG.i("custom domain has been deleted")
def check_single_custom_domain(custom_domain): def check_single_custom_domain(custom_domain: CustomDomain):
mx_domains = get_mx_domains(custom_domain.domain) mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN)
expected_custom_domains = validator.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_custom_domains):
user = custom_domain.user user = custom_domain.user
LOG.w( LOG.w(
"The MX record is not correctly set for %s %s %s", "The MX record is not correctly set for %s %s %s",
@ -1215,7 +1221,7 @@ def notify_hibp():
def clear_users_scheduled_to_be_deleted(dry_run=False): def clear_users_scheduled_to_be_deleted(dry_run=False):
users = User.filter( users: List[User] = User.filter(
and_( and_(
User.delete_on.isnot(None), User.delete_on.isnot(None),
User.delete_on <= arrow.now().shift(days=-DELETE_GRACE_DAYS), User.delete_on <= arrow.now().shift(days=-DELETE_GRACE_DAYS),
@ -1227,6 +1233,11 @@ def clear_users_scheduled_to_be_deleted(dry_run=False):
) )
if dry_run: if dry_run:
continue continue
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteUser,
message=f"Delete user {user.id} ({user.email})",
)
User.delete(user.id) User.delete(user.id)
Session.commit() Session.commit()
@ -1238,6 +1249,16 @@ def delete_old_data():
cleanup_old_notifications(oldest_valid) cleanup_old_notifications(oldest_valid)
def clear_alias_audit_log():
oldest_valid = arrow.now().shift(days=-config.AUDIT_LOG_MAX_DAYS)
cleanup_alias_audit_log(oldest_valid)
def clear_user_audit_log():
oldest_valid = arrow.now().shift(days=-config.AUDIT_LOG_MAX_DAYS)
cleanup_user_audit_log(oldest_valid)
if __name__ == "__main__": if __name__ == "__main__":
LOG.d("Start running cronjob") LOG.d("Start running cronjob")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -1246,22 +1267,6 @@ if __name__ == "__main__":
"--job", "--job",
help="Choose a cron job to run", help="Choose a cron job to run",
type=str, type=str,
choices=[
"stats",
"notify_trial_end",
"notify_manual_subscription_end",
"notify_premium_end",
"delete_logs",
"delete_old_data",
"poll_apple_subscription",
"sanity_check",
"delete_old_monitoring",
"check_custom_domain",
"check_hibp",
"notify_hibp",
"cleanup_tokens",
"send_undelivered_mails",
],
) )
args = parser.parse_args() args = parser.parse_args()
# wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc # wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc
@ -1310,4 +1315,10 @@ if __name__ == "__main__":
load_unsent_mails_from_fs_and_resend() load_unsent_mails_from_fs_and_resend()
elif args.job == "delete_scheduled_users": elif args.job == "delete_scheduled_users":
LOG.d("Deleting users scheduled to be deleted") LOG.d("Deleting users scheduled to be deleted")
clear_users_scheduled_to_be_deleted(dry_run=True) clear_users_scheduled_to_be_deleted()
elif args.job == "clear_alias_audit_log":
LOG.d("Clearing alias audit log")
clear_alias_audit_log()
elif args.job == "clear_user_audit_log":
LOG.d("Clearing user audit log")
clear_user_audit_log()

View File

@ -80,3 +80,17 @@ jobs:
schedule: "*/5 * * * *" schedule: "*/5 * * * *"
captureStderr: true captureStderr: true
concurrencyPolicy: Forbid concurrencyPolicy: Forbid
- name: SimpleLogin clear alias_audit_log old entries
command: python /code/cron.py -j clear_alias_audit_log
shell: /bin/bash
schedule: "0 * * * *" # Once every hour
captureStderr: true
concurrencyPolicy: Forbid
- name: SimpleLogin clear user_audit_log old entries
command: python /code/cron.py -j clear_user_audit_log
shell: /bin/bash
schedule: "0 * * * *" # Once every hour
captureStderr: true
concurrencyPolicy: Forbid

View File

@ -53,7 +53,11 @@ from flanker.addresslib.address import EmailAddress
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app import pgp_utils, s3, config, contact_utils from app import pgp_utils, s3, config, contact_utils
from app.alias_utils import try_auto_create, change_alias_status from app.alias_utils import (
try_auto_create,
change_alias_status,
get_alias_recipient_name,
)
from app.config import ( from app.config import (
EMAIL_DOMAIN, EMAIL_DOMAIN,
URL, URL,
@ -197,8 +201,8 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
contact_email = mail_from contact_email = mail_from
contact_result = contact_utils.create_contact( contact_result = contact_utils.create_contact(
email=contact_email, email=contact_email,
name=contact_name,
alias=alias, alias=alias,
name=contact_name,
mail_from=mail_from, mail_from=mail_from,
allow_empty_email=True, allow_empty_email=True,
automatic_created=True, automatic_created=True,
@ -229,7 +233,7 @@ def get_or_create_reply_to_contact(
) )
return None return None
return contact_utils.create_contact(contact_address, contact_name, alias).contact return contact_utils.create_contact(contact_address, alias, contact_name).contact
def replace_header_when_forward(msg: Message, alias: Alias, header: str): def replace_header_when_forward(msg: Message, alias: Alias, header: str):
@ -1161,23 +1165,11 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
Session.commit() Session.commit()
# make the email comes from alias recipient_name = get_alias_recipient_name(alias)
from_header = alias.email if recipient_name.message:
# add alias name from alias LOG.d(recipient_name.message)
if alias.name: LOG.d("From header is %s", recipient_name.name)
LOG.d("Put alias name %s in from header", alias.name) add_or_replace_header(msg, headers.FROM, recipient_name.name)
from_header = sl_formataddr((alias.name, alias.email))
elif alias.custom_domain:
# add alias name from domain
if alias.custom_domain.name:
LOG.d(
"Put domain default alias name %s in from header",
alias.custom_domain.name,
)
from_header = sl_formataddr((alias.custom_domain.name, alias.email))
LOG.d("From header is %s", from_header)
add_or_replace_header(msg, headers.FROM, from_header)
try: try:
if str(msg[headers.TO]).lower() == "undisclosed-recipients:;": if str(msg[headers.TO]).lower() == "undisclosed-recipients:;":
@ -1510,7 +1502,9 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog):
LOG.w( LOG.w(
f"Disable alias {alias} because {reason}. {alias.mailboxes} {alias.user}. Last contact {contact}" f"Disable alias {alias} because {reason}. {alias.mailboxes} {alias.user}. Last contact {contact}"
) )
change_alias_status(alias, enabled=False) change_alias_status(
alias, enabled=False, message=f"Set enabled=False due to {reason}"
)
Notification.create( Notification.create(
user_id=user.id, user_id=user.id,

View File

@ -72,7 +72,9 @@ class PostgresEventSource(EventSource):
Session.close() # Ensure we get a new connection and we don't leave a dangling tx Session.close() # Ensure we get a new connection and we don't leave a dangling tx
def __connect(self): def __connect(self):
self.__connection = psycopg2.connect(self.__connection_string) self.__connection = psycopg2.connect(
self.__connection_string, application_name="sl-event-listen"
)
from app.db import Session from app.db import Session
@ -83,24 +85,28 @@ class DeadLetterEventSource(EventSource):
def __init__(self, max_retries: int): def __init__(self, max_retries: int):
self.__max_retries = max_retries self.__max_retries = max_retries
def execute_loop(
self, on_event: Callable[[SyncEvent], NoReturn]
) -> list[SyncEvent]:
threshold = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES)
events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries
)
if events:
LOG.info(f"Got {len(events)} dead letter events")
newrelic.agent.record_custom_metric(
"Custom/dead_letter_events_to_process", len(events)
)
for event in events:
if event.mark_as_taken(allow_taken_older_than=threshold):
on_event(event)
return events
@newrelic.agent.background_task() @newrelic.agent.background_task()
def run(self, on_event: Callable[[SyncEvent], NoReturn]): def run(self, on_event: Callable[[SyncEvent], NoReturn]):
while True: while True:
try: try:
threshold = arrow.utcnow().shift( events = self.execute_loop(on_event)
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
)
events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries
)
if events:
LOG.info(f"Got {len(events)} dead letter events")
if events:
newrelic.agent.record_custom_metric(
"Custom/dead_letter_events_to_process", len(events)
)
for event in events:
on_event(event)
Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
if not events: if not events:
LOG.debug("No dead letter events") LOG.debug("No dead letter events")

View File

@ -20,6 +20,7 @@ from app.jobs.event_jobs import send_alias_creation_events_for_user
from app.jobs.export_user_data_job import ExportUserDataJob from app.jobs.export_user_data_job import ExportUserDataJob
from app.log import LOG from app.log import LOG
from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from server import create_light_app from server import create_light_app
@ -128,7 +129,7 @@ def welcome_proton(user):
def delete_mailbox_job(job: Job): def delete_mailbox_job(job: Job):
mailbox_id = job.payload.get("mailbox_id") mailbox_id = job.payload.get("mailbox_id")
mailbox = Mailbox.get(mailbox_id) mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox: if not mailbox:
return return
@ -152,6 +153,12 @@ def delete_mailbox_job(job: Job):
mailbox_email = mailbox.email mailbox_email = mailbox.email
user = mailbox.user user = mailbox.user
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteMailbox,
message=f"Delete mailbox {mailbox.id} ({mailbox.email})",
)
Mailbox.delete(mailbox_id) Mailbox.delete(mailbox_id)
Session.commit() Session.commit()
LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email) LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email)
@ -244,15 +251,27 @@ def process_job(job: Job):
if not custom_domain: if not custom_domain:
return return
is_subdomain = custom_domain.is_sl_subdomain
domain_name = custom_domain.domain domain_name = custom_domain.domain
user = custom_domain.user user = custom_domain.user
custom_domain_partner_id = custom_domain.partner_id
CustomDomain.delete(custom_domain.id) CustomDomain.delete(custom_domain.id)
Session.commit() Session.commit()
if is_subdomain:
message = f"Delete subdomain {custom_domain_id} ({domain_name})"
else:
message = f"Delete custom domain {custom_domain_id} ({domain_name})"
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteCustomDomain,
message=message,
)
LOG.d("Domain %s deleted", domain_name) LOG.d("Domain %s deleted", domain_name)
if custom_domain.partner_id is None: if custom_domain_partner_id is None:
send_email( send_email(
user.email, user.email,
f"Your domain {domain_name} has been deleted", f"Your domain {domain_name} has been deleted",

View File

@ -0,0 +1,27 @@
"""custom domain indices
Revision ID: 62afa3a10010
Revises: 88dd7a0abf54
Create Date: 2024-09-30 11:40:04.127791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '62afa3a10010'
down_revision = '88dd7a0abf54'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_custom_domain_pending_deletion', 'custom_domain', ['pending_deletion'], unique=False, postgresql_concurrently=True)
op.create_index('ix_custom_domain_user_id', 'custom_domain', ['user_id'], unique=False, postgresql_concurrently=True)
def downgrade():
with op.get_context().autocommit_block():
op.drop_index('ix_custom_domain_user_id', table_name='custom_domain', postgresql_concurrently=True)
op.drop_index('ix_custom_domain_pending_deletion', table_name='custom_domain', postgresql_concurrently=True)

View File

@ -0,0 +1,45 @@
"""alias_audit_log
Revision ID: 91ed7f46dc81
Revises: 62afa3a10010
Create Date: 2024-10-11 13:22:11.594054
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '91ed7f46dc81'
down_revision = '62afa3a10010'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('alias_audit_log',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias_email', sa.String(length=255), nullable=False),
sa.Column('action', sa.String(length=255), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_alias_audit_log_alias_email', 'alias_audit_log', ['alias_email'], unique=False)
op.create_index('ix_alias_audit_log_alias_id', 'alias_audit_log', ['alias_id'], unique=False)
op.create_index('ix_alias_audit_log_user_id', 'alias_audit_log', ['user_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_alias_audit_log_user_id', table_name='alias_audit_log')
op.drop_index('ix_alias_audit_log_alias_id', table_name='alias_audit_log')
op.drop_index('ix_alias_audit_log_alias_email', table_name='alias_audit_log')
op.drop_table('alias_audit_log')
# ### end Alembic commands ###

View File

@ -0,0 +1,44 @@
"""user_audit_log
Revision ID: 7d7b84779837
Revises: 91ed7f46dc81
Create Date: 2024-10-16 11:52:49.128644
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '7d7b84779837'
down_revision = '91ed7f46dc81'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('user_audit_log',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('user_email', sa.String(length=255), nullable=False),
sa.Column('action', sa.String(length=255), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_user_audit_log_user_email', 'user_audit_log', ['user_email'], unique=False)
op.create_index('ix_user_audit_log_user_id', 'user_audit_log', ['user_id'], unique=False)
op.create_index('ix_user_audit_log_created_at', 'user_audit_log', ['created_at'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_user_audit_log_user_id', table_name='user_audit_log')
op.drop_index('ix_user_audit_log_user_email', table_name='user_audit_log')
op.drop_index('ix_user_audit_log_created_at', table_name='user_audit_log')
op.drop_table('user_audit_log')
# ### end Alembic commands ###

View File

@ -0,0 +1,27 @@
"""alias_audit_log_index_created_at
Revision ID: 32f25cbf12f6
Revises: 7d7b84779837
Create Date: 2024-10-16 16:45:36.827161
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '32f25cbf12f6'
down_revision = '7d7b84779837'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_alias_audit_log_created_at', 'alias_audit_log', ['created_at'], unique=False, postgresql_concurrently=True)
def downgrade():
with op.get_context().autocommit_block():
op.drop_index('ix_alias_audit_log_created_at', table_name='alias_audit_log', postgresql_concurrently=True)

View File

@ -94,6 +94,20 @@ def log_nb_db_connection():
newrelic.agent.record_custom_metric("Custom/nb_db_connections", nb_connection) newrelic.agent.record_custom_metric("Custom/nb_db_connections", nb_connection)
@newrelic.agent.background_task()
def log_nb_db_connection_by_app_name():
# get the number of connections to the DB
rows = Session.execute(
"SELECT application_name, count(datid) FROM pg_stat_activity group by application_name"
)
for row in rows:
if row[0].find("sl-") == 0:
LOG.d("number of db connections for app %s = %s", row[0], row[1])
newrelic.agent.record_custom_metric(
f"Custom/nb_db_app_connection/{row[0]}", row[1]
)
@newrelic.agent.background_task() @newrelic.agent.background_task()
def log_pending_to_process_events(): def log_pending_to_process_events():
r = Session.execute("select count(*) from sync_event WHERE taken_time IS NULL;") r = Session.execute("select count(*) from sync_event WHERE taken_time IS NULL;")
@ -148,6 +162,7 @@ if __name__ == "__main__":
log_pending_to_process_events() log_pending_to_process_events()
log_events_pending_dead_letter() log_events_pending_dead_letter()
log_failed_events() log_failed_events()
log_nb_db_connection_by_app_name()
Session.close() Session.close()
exporter.run() exporter.run()

View File

@ -21,7 +21,7 @@ if max_alias_id == 0:
max_alias_id = Session.query(func.max(Alias.id)).scalar() max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Checking alias {alias_id_start} to {max_alias_id}") print(f"Checking alias {alias_id_start} to {max_alias_id}")
step = 1000 step = 10000
noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')" noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')"
alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}" alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}"
updated = 0 updated = 0
@ -38,12 +38,12 @@ for batch_start in range(alias_id_start, max_alias_id, step):
updated += rows_done.rowcount updated += rows_done.rowcount
Session.commit() Session.commit()
elapsed = time.time() - start_time elapsed = time.time() - start_time
time_per_alias = elapsed / (updated + 1)
last_batch_id = batch_start + step last_batch_id = batch_start + step
time_per_alias = elapsed / (last_batch_id)
remaining = max_alias_id - last_batch_id remaining = max_alias_id - last_batch_id
time_remaining = (max_alias_id - last_batch_id) * time_per_alias time_remaining = remaining / time_per_alias
hours_remaining = time_remaining / 3600.0 hours_remaining = time_remaining / 60.0
print( print(
f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f}hrs remaining" f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f} mins remaining"
) )
print("") print("")

View File

@ -0,0 +1,63 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import UserPlanChanged, EventContent
from app.models import PartnerUser
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
)
parser.add_argument(
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
)
args = parser.parse_args()
pu_id_start = args.start_pu_id
max_pu_id = args.end_pu_id
if max_pu_id == 0:
max_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
print(f"Checking partner user {pu_id_start} to {max_pu_id}")
step = 100
updated = 0
start_time = time.time()
with_premium = 0
for batch_start in range(pu_id_start, max_pu_id, step):
partner_users = (
Session.query(PartnerUser).filter(
PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step
)
).all()
for partner_user in partner_users:
subscription_end = partner_user.user.get_active_subscription_end(
include_partner_subscription=False
)
end_timestamp = None
if subscription_end:
with_premium += 1
end_timestamp = subscription_end.timestamp
event = UserPlanChanged(plan_end_time=end_timestamp)
EventDispatcher.send_event(
partner_user.user, EventContent(user_plan_change=event)
)
Session.flush()
updated += 1
Session.commit()
elapsed = time.time() - start_time
last_batch_id = batch_start + step
time_per_alias = elapsed / (last_batch_id)
remaining = max_pu_id - last_batch_id
time_remaining = remaining / time_per_alias
hours_remaining = time_remaining / 60.0
print(
f"\PartnerUser {batch_start}/{max_pu_id} {updated} {hours_remaining:.2f} mins remaining"
)
print(f"With SL premium {with_premium}")

98
app/poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. # This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
[[package]] [[package]]
name = "aiohttp" name = "aiohttp"
@ -360,35 +360,41 @@ files = [
[[package]] [[package]]
name = "boto3" name = "boto3"
version = "1.15.9" version = "1.35.37"
description = "The AWS SDK for Python" description = "The AWS SDK for Python"
optional = false optional = false
python-versions = "*" python-versions = ">=3.8"
files = [ files = [
{file = "boto3-1.15.9-py2.py3-none-any.whl", hash = "sha256:e0a1dbc0a0e460dc6de2f4144b5015edad3ab5c17ee83c6194b1a010d815bc60"}, {file = "boto3-1.35.37-py3-none-any.whl", hash = "sha256:385ca77bf8ea4ab2d97f6e2435bdb29f77d9301e2f7ac796c2f465753c2adf3c"},
{file = "boto3-1.15.9.tar.gz", hash = "sha256:02f5f7a2b1349760b030c34f90a9cb4600bf8fe3cbc76b801d122bc4cecf3a7f"}, {file = "boto3-1.35.37.tar.gz", hash = "sha256:470d981583885859fed2fd1c185eeb01cc03e60272d499bafe41b12625b158c8"},
] ]
[package.dependencies] [package.dependencies]
botocore = ">=1.18.9,<1.19.0" botocore = ">=1.35.37,<1.36.0"
jmespath = ">=0.7.1,<1.0.0" jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.3.0,<0.4.0" s3transfer = ">=0.10.0,<0.11.0"
[package.extras]
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]] [[package]]
name = "botocore" name = "botocore"
version = "1.18.9" version = "1.35.37"
description = "Low-level, data-driven core of boto 3." description = "Low-level, data-driven core of boto 3."
optional = false optional = false
python-versions = "*" python-versions = ">=3.8"
files = [ files = [
{file = "botocore-1.18.9-py2.py3-none-any.whl", hash = "sha256:dc3244170254cbba7dfde00b0489f830069d93dd6a9e555178d989072d7ee7c2"}, {file = "botocore-1.35.37-py3-none-any.whl", hash = "sha256:64f965d4ba7adb8d79ce044c3aef7356e05dd74753cf7e9115b80f477845d920"},
{file = "botocore-1.18.9.tar.gz", hash = "sha256:35b06b8801eb2dd7e708de35581f9c0304740645874f3af5b8b0c1648f8d6365"}, {file = "botocore-1.35.37.tar.gz", hash = "sha256:b2b4d29bafd95b698344f2f0577bb67064adbf1735d8a0e3c7473daa59c23ba6"},
] ]
[package.dependencies] [package.dependencies]
jmespath = ">=0.7.1,<1.0.0" jmespath = ">=0.7.1,<2.0.0"
python-dateutil = ">=2.1,<3.0.0" python-dateutil = ">=2.1,<3.0.0"
urllib3 = {version = ">=1.20,<1.26", markers = "python_version != \"3.4\""} urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}
[package.extras]
crt = ["awscrt (==0.22.0)"]
[[package]] [[package]]
name = "cachetools" name = "cachetools"
@ -1372,6 +1378,7 @@ files = [
{file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"},
{file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"},
{file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"},
{file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"},
{file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"},
{file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"},
{file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"},
@ -1380,6 +1387,7 @@ files = [
{file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"},
{file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"},
{file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"},
{file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"},
{file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"},
{file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"},
{file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"},
@ -1409,6 +1417,7 @@ files = [
{file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"},
{file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"},
{file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"},
{file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"},
{file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"},
{file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"},
{file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"},
@ -1417,6 +1426,7 @@ files = [
{file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"},
{file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"},
{file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"},
{file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"},
{file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"},
{file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"},
{file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"},
@ -2891,50 +2901,72 @@ files = [
[[package]] [[package]]
name = "s3transfer" name = "s3transfer"
version = "0.3.3" version = "0.10.3"
description = "An Amazon S3 Transfer Manager" description = "An Amazon S3 Transfer Manager"
optional = false optional = false
python-versions = "*" python-versions = ">=3.8"
files = [ files = [
{file = "s3transfer-0.3.3-py2.py3-none-any.whl", hash = "sha256:2482b4259524933a022d59da830f51bd746db62f047d6eb213f2f8855dcb8a13"}, {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
{file = "s3transfer-0.3.3.tar.gz", hash = "sha256:921a37e2aefc64145e7b73d50c71bb4f26f46e4c9f414dc648c6245ff92cf7db"}, {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
] ]
[package.dependencies] [package.dependencies]
botocore = ">=1.12.36,<2.0a.0" botocore = ">=1.33.2,<2.0a.0"
[package.extras]
crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
[[package]] [[package]]
name = "sentry-sdk" name = "sentry-sdk"
version = "1.5.11" version = "2.16.0"
description = "Python client for Sentry (https://sentry.io)" description = "Python client for Sentry (https://sentry.io)"
optional = false optional = false
python-versions = "*" python-versions = ">=3.6"
files = [ files = [
{file = "sentry-sdk-1.5.11.tar.gz", hash = "sha256:6c01d9d0b65935fd275adc120194737d1df317dce811e642cbf0394d0d37a007"}, {file = "sentry_sdk-2.16.0-py2.py3-none-any.whl", hash = "sha256:49139c31ebcd398f4f6396b18910610a0c1602f6e67083240c33019d1f6aa30c"},
{file = "sentry_sdk-1.5.11-py2.py3-none-any.whl", hash = "sha256:c17179183cac614e900cbd048dab03f49a48e2820182ec686c25e7ce46f8548f"}, {file = "sentry_sdk-2.16.0.tar.gz", hash = "sha256:90f733b32e15dfc1999e6b7aca67a38688a567329de4d6e184154a73f96c6892"},
] ]
[package.dependencies] [package.dependencies]
certifi = "*" certifi = "*"
urllib3 = ">=1.10.0" urllib3 = ">=1.26.11"
[package.extras] [package.extras]
aiohttp = ["aiohttp (>=3.5)"] aiohttp = ["aiohttp (>=3.5)"]
anthropic = ["anthropic (>=0.16)"]
arq = ["arq (>=0.23)"]
asyncpg = ["asyncpg (>=0.23)"]
beam = ["apache-beam (>=2.12)"] beam = ["apache-beam (>=2.12)"]
bottle = ["bottle (>=0.12.13)"] bottle = ["bottle (>=0.12.13)"]
celery = ["celery (>=3)"] celery = ["celery (>=3)"]
celery-redbeat = ["celery-redbeat (>=2)"]
chalice = ["chalice (>=1.16.0)"] chalice = ["chalice (>=1.16.0)"]
clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
django = ["django (>=1.8)"] django = ["django (>=1.8)"]
falcon = ["falcon (>=1.4)"] falcon = ["falcon (>=1.4)"]
flask = ["blinker (>=1.1)", "flask (>=0.11)"] fastapi = ["fastapi (>=0.79.0)"]
flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
grpcio = ["grpcio (>=1.21.1)", "protobuf (>=3.8.0)"]
http2 = ["httpcore[http2] (==1.*)"]
httpx = ["httpx (>=0.16.0)"] httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
huggingface-hub = ["huggingface-hub (>=0.22)"]
langchain = ["langchain (>=0.0.210)"]
litestar = ["litestar (>=2.0.0)"]
loguru = ["loguru (>=0.5)"]
openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro"]
pure-eval = ["asttokens", "executing", "pure-eval"] pure-eval = ["asttokens", "executing", "pure-eval"]
pymongo = ["pymongo (>=3.1)"]
pyspark = ["pyspark (>=2.4.4)"] pyspark = ["pyspark (>=2.4.4)"]
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"] quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
rq = ["rq (>=0.6)"] rq = ["rq (>=0.6)"]
sanic = ["sanic (>=0.8)"] sanic = ["sanic (>=0.8)"]
sqlalchemy = ["sqlalchemy (>=1.2)"] sqlalchemy = ["sqlalchemy (>=1.2)"]
tornado = ["tornado (>=5)"] starlette = ["starlette (>=0.19.1)"]
starlite = ["starlite (>=1.48)"]
tornado = ["tornado (>=6)"]
[[package]] [[package]]
name = "setuptools" name = "setuptools"
@ -3295,18 +3327,18 @@ files = [
[[package]] [[package]]
name = "urllib3" name = "urllib3"
version = "1.25.10" version = "1.26.20"
description = "HTTP library with thread-safe connection pooling, file post, and more." description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [ files = [
{file = "urllib3-1.25.10-py2.py3-none-any.whl", hash = "sha256:e7983572181f5e1522d9c98453462384ee92a0be7fac5f1413a1e35c56cc0461"}, {file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"},
{file = "urllib3-1.25.10.tar.gz", hash = "sha256:91056c15fa70756691db97756772bb1eb9678fa585d9184f24534b100dc60f4a"}, {file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"},
] ]
[package.extras] [package.extras]
brotli = ["brotlipy (>=0.6.0)"] brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]] [[package]]
@ -3704,4 +3736,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "22b9a61e9999a215aacb889b3790ee1a6840ce249aea2e3d16c6113243d5c126" content-hash = "314f199bd50ccbf636ce1c6c753f8c79a1f5a16aa7c1a330a2ec514a13dbad2d"

View File

@ -69,7 +69,7 @@ python-dotenv = "^0.14.0"
ipython = "^7.31.1" ipython = "^7.31.1"
sqlalchemy_utils = "^0.36.8" sqlalchemy_utils = "^0.36.8"
psycopg2-binary = "^2.9.3" psycopg2-binary = "^2.9.3"
sentry_sdk = "^1.5.11" sentry_sdk = "^2.16.0"
blinker = "^1.4" blinker = "^1.4"
arrow = "^0.16.0" arrow = "^0.16.0"
Flask-WTF = "^0.14.3" Flask-WTF = "^0.14.3"

View File

@ -1,4 +1,3 @@
import json
import os import os
import time import time
from datetime import timedelta from datetime import timedelta
@ -7,10 +6,9 @@ import arrow
import click import click
import flask_limiter import flask_limiter
import flask_profiler import flask_profiler
import newrelic.agent
import sentry_sdk import sentry_sdk
from coinbase_commerce.error import WebhookInvalidPayload, SignatureVerificationError
from coinbase_commerce.webhook import Webhook
from dateutil.relativedelta import relativedelta
from flask import ( from flask import (
Flask, Flask,
redirect, redirect,
@ -29,7 +27,7 @@ from sentry_sdk.integrations.flask import FlaskIntegration
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
from app import paddle_utils, config, paddle_callback, constants from app import config, constants
from app.admin_model import ( from app.admin_model import (
SLAdminIndexView, SLAdminIndexView,
UserAdmin, UserAdmin,
@ -55,7 +53,6 @@ from app.config import (
FLASK_SECRET, FLASK_SECRET,
SENTRY_DSN, SENTRY_DSN,
URL, URL,
PADDLE_MONTHLY_PRODUCT_ID,
FLASK_PROFILER_PATH, FLASK_PROFILER_PATH,
FLASK_PROFILER_PASSWORD, FLASK_PROFILER_PASSWORD,
SENTRY_FRONT_END_DSN, SENTRY_FRONT_END_DSN,
@ -69,22 +66,16 @@ from app.config import (
LANDING_PAGE_URL, LANDING_PAGE_URL,
STATUS_PAGE_URL, STATUS_PAGE_URL,
SUPPORT_EMAIL, SUPPORT_EMAIL,
PADDLE_MONTHLY_PRODUCT_IDS,
PADDLE_YEARLY_PRODUCT_IDS,
PGP_SIGNER, PGP_SIGNER,
COINBASE_WEBHOOK_SECRET,
PAGE_LIMIT, PAGE_LIMIT,
PADDLE_COUPON_ID,
ZENDESK_ENABLED, ZENDESK_ENABLED,
MAX_NB_EMAIL_FREE_PLAN, MAX_NB_EMAIL_FREE_PLAN,
MEM_STORE_URI, MEM_STORE_URI,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.subscription_webhook import execute_subscription_webhook
from app.db import Session from app.db import Session
from app.developer.base import developer_bp from app.developer.base import developer_bp
from app.discover.base import discover_bp from app.discover.base import discover_bp
from app.email_utils import send_email, render
from app.extensions import login_manager, limiter from app.extensions import login_manager, limiter
from app.fake_data import fake_data from app.fake_data import fake_data
from app.internal.base import internal_bp from app.internal.base import internal_bp
@ -93,11 +84,8 @@ from app.log import LOG
from app.models import ( from app.models import (
User, User,
Alias, Alias,
Subscription,
PlanEnum,
CustomDomain, CustomDomain,
Mailbox, Mailbox,
CoinbaseSubscription,
EmailLog, EmailLog,
Contact, Contact,
ManualSubscription, ManualSubscription,
@ -114,9 +102,11 @@ from app.monitor.base import monitor_bp
from app.newsletter_utils import send_newsletter_to_user from app.newsletter_utils import send_newsletter_to_user
from app.oauth.base import oauth_bp from app.oauth.base import oauth_bp
from app.onboarding.base import onboarding_bp from app.onboarding.base import onboarding_bp
from app.payments.coinbase import setup_coinbase_commerce
from app.payments.paddle import setup_paddle_callback
from app.phone.base import phone_bp from app.phone.base import phone_bp
from app.redis_services import initialize_redis_services from app.redis_services import initialize_redis_services
from app.utils import random_string from app.sentry_utils import sentry_before_send
if SENTRY_DSN: if SENTRY_DSN:
LOG.d("enable sentry") LOG.d("enable sentry")
@ -127,6 +117,7 @@ if SENTRY_DSN:
FlaskIntegration(), FlaskIntegration(),
SqlalchemyIntegration(), SqlalchemyIntegration(),
], ],
before_send=sentry_before_send,
) )
# the app is served behind nginx which uses http and not https # the app is served behind nginx which uses http and not https
@ -299,7 +290,9 @@ def set_index_page(app):
res.status_code, res.status_code,
time.time() - start_time, time.time() - start_time,
) )
newrelic.agent.record_custom_event(
"HttpResponseStatus", {"code": res.status_code}
)
return res return res
@ -441,341 +434,6 @@ def jinja2_filter(app):
) )
def setup_paddle_callback(app: Flask):
@app.route("/paddle", methods=["GET", "POST"])
def paddle():
LOG.d(f"paddle callback {request.form.get('alert_name')} {request.form}")
# make sure the request comes from Paddle
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
if (
request.form.get("alert_name") == "subscription_created"
): # new user subscribes
# the passthrough is json encoded, e.g.
# request.form.get("passthrough") = '{"user_id": 88 }'
passthrough = json.loads(request.form.get("passthrough"))
user_id = passthrough.get("user_id")
user = User.get(user_id)
subscription_plan_id = int(request.form.get("subscription_plan_id"))
if subscription_plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
plan = PlanEnum.monthly
elif subscription_plan_id in PADDLE_YEARLY_PRODUCT_IDS:
plan = PlanEnum.yearly
else:
LOG.e(
"Unknown subscription_plan_id %s %s",
subscription_plan_id,
request.form,
)
return "No such subscription", 400
sub = Subscription.get_by(user_id=user.id)
if not sub:
LOG.d(f"create a new Subscription for user {user}")
Subscription.create(
user_id=user.id,
cancel_url=request.form.get("cancel_url"),
update_url=request.form.get("update_url"),
subscription_id=request.form.get("subscription_id"),
event_time=arrow.now(),
next_bill_date=arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date(),
plan=plan,
)
else:
LOG.d(f"Update an existing Subscription for user {user}")
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.subscription_id = request.form.get("subscription_id")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
# in case user cancels a plan and subscribes a new plan
sub.cancelled = False
execute_subscription_webhook(user)
LOG.d("User %s upgrades!", user)
Session.commit()
elif request.form.get("alert_name") == "subscription_payment_succeeded":
subscription_id = request.form.get("subscription_id")
LOG.d("Update subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
# when user subscribes, the "subscription_payment_succeeded" can arrive BEFORE "subscription_created"
# at that time, subscription object does not exist yet
if sub:
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
Session.commit()
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_cancelled":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
# cancellation_effective_date should be the same as next_bill_date
LOG.w(
"Cancel subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
sub.event_time = arrow.now()
sub.cancelled = True
Session.commit()
user = sub.user
send_email(
user.email,
"SimpleLogin - your subscription is canceled",
render(
"transactional/subscription-cancel.txt",
user=user,
end_date=request.form.get("cancellation_effective_date"),
),
)
execute_subscription_webhook(sub.user)
else:
# user might have deleted their account
LOG.i(f"Cancel non-exist subscription {subscription_id}")
return "OK"
elif request.form.get("alert_name") == "subscription_updated":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
next_bill_date = request.form.get("next_bill_date")
if not next_bill_date:
paddle_callback.failed_payment(sub, subscription_id)
return "OK"
LOG.d(
"Update subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
if (
int(request.form.get("subscription_plan_id"))
== PADDLE_MONTHLY_PRODUCT_ID
):
plan = PlanEnum.monthly
else:
plan = PlanEnum.yearly
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
sub.cancelled = False
Session.commit()
execute_subscription_webhook(sub.user)
else:
LOG.w(
f"update non-exist subscription {subscription_id}. {request.form}"
)
return "No such subscription", 400
elif request.form.get("alert_name") == "payment_refunded":
subscription_id = request.form.get("subscription_id")
LOG.d("Refund request for subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
user = sub.user
Subscription.delete(sub.id)
Session.commit()
LOG.e("%s requests a refund", user)
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_payment_refunded":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
LOG.d(
"Handle subscription_payment_refunded for subscription %s",
subscription_id,
)
if not sub:
LOG.w(
"No such subscription for %s, payload %s",
subscription_id,
request.form,
)
return "No such subscription"
plan_id = int(request.form["subscription_plan_id"])
if request.form["refund_type"] == "full":
if plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
LOG.d("subtract 1 month from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(months=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
elif plan_id in PADDLE_YEARLY_PRODUCT_IDS:
LOG.d("subtract 1 year from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(years=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
else:
LOG.e("Unknown plan_id %s", plan_id)
else:
LOG.w("partial subscription_payment_refunded, not handled")
execute_subscription_webhook(sub.user)
return "OK"
@app.route("/paddle_coupon", methods=["GET", "POST"])
def paddle_coupon():
LOG.d("paddle coupon callback %s", request.form)
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
product_id = request.form.get("p_product_id")
if product_id != PADDLE_COUPON_ID:
LOG.e("product_id %s not match with %s", product_id, PADDLE_COUPON_ID)
return "KO", 400
email = request.form.get("email")
LOG.d("Paddle coupon request for %s", email)
coupon = Coupon.create(
code=random_string(30),
comment="For 1-year coupon",
expires_date=arrow.now().shift(years=1, days=-1),
commit=True,
)
return (
f"Your 1-year coupon is <b>{coupon.code}</b> <br> "
f"It's valid until <b>{coupon.expires_date.date().isoformat()}</b>"
)
def setup_coinbase_commerce(app):
@app.route("/coinbase", methods=["POST"])
def coinbase_webhook():
# event payload
request_data = request.data.decode("utf-8")
# webhook signature
request_sig = request.headers.get("X-CC-Webhook-Signature", None)
try:
# signature verification and event object construction
event = Webhook.construct_event(
request_data, request_sig, COINBASE_WEBHOOK_SECRET
)
except (WebhookInvalidPayload, SignatureVerificationError) as e:
LOG.e("Invalid Coinbase webhook")
return str(e), 400
LOG.d("Coinbase event %s", event)
if event["type"] == "charge:confirmed":
if handle_coinbase_event(event):
return "success", 200
else:
return "error", 400
return "success", 200
def handle_coinbase_event(event) -> bool:
server_user_id = event["data"]["metadata"]["user_id"]
try:
user_id = int(server_user_id)
except ValueError:
user_id = int(float(server_user_id))
code = event["data"]["code"]
user = User.get(user_id)
if not user:
LOG.e("User not found %s", user_id)
return False
coinbase_subscription: CoinbaseSubscription = CoinbaseSubscription.get_by(
user_id=user_id
)
if not coinbase_subscription:
LOG.d("Create a coinbase subscription for %s", user)
coinbase_subscription = CoinbaseSubscription.create(
user_id=user_id, end_at=arrow.now().shift(years=1), code=code, commit=True
)
send_email(
user.email,
"Your SimpleLogin account has been upgraded",
render(
"transactional/coinbase/new-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/new-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
else:
if coinbase_subscription.code != code:
LOG.d("Update code from %s to %s", coinbase_subscription.code, code)
coinbase_subscription.code = code
if coinbase_subscription.is_active():
coinbase_subscription.end_at = coinbase_subscription.end_at.shift(years=1)
else: # already expired subscription
coinbase_subscription.end_at = arrow.now().shift(years=1)
Session.commit()
send_email(
user.email,
"Your SimpleLogin account has been extended",
render(
"transactional/coinbase/extend-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/extend-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
execute_subscription_webhook(user)
return True
def init_extensions(app: Flask): def init_extensions(app: Flask):
login_manager.init_app(app) login_manager.init_app(app)

View File

@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import AliasAuditLog
def cleanup_alias_audit_log(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting alias_audit_log older than {oldest_allowed}")
count = AliasAuditLog.filter(AliasAuditLog.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} alias_audit_log entries")

View File

@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import UserAuditLog
def cleanup_user_audit_log(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting user_audit_log older than {oldest_allowed}")
count = UserAuditLog.filter(UserAuditLog.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} user_audit_log entries")

View File

@ -1,220 +1,272 @@
{% extends 'admin/master.html' %} {% extends 'admin/master.html' %}
{% macro show_user(user) -%} {% macro show_user(user) -%}
<h4>User {{ user.email }} with ID {{ user.id }}.</h4> <h4>User {{ user.email }} with ID {{ user.id }}.</h4>
{% set pu = helper.partner_user(user) %} {% set pu = helper.partner_user(user) %}
<table class="table"> <table class="table">
<thead> <thead>
<tr>
<th scope="col">User ID</th>
<th scope="col">Email</th>
<th scope="col">Status</th>
<th scope="col">Paid</th>
<th>Subscription</th>
<th>Created At</th>
<th>Updated At</th>
<th>Connected with Proton account</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ user.id }}</td>
<td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %}
<td class="text-danger">Disabled</td>
{% else %}
<td class="text-success">Enabled</td>
{% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ user.get_active_subscription() }}</td>
<td>{{ user.created_at }}</td>
<td>{{ user.updated_at }}</td>
{% if pu %}
<td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %}
<td>No</td>
{% endif %}
</tr>
</tbody>
</table>
{%- endmacro %}
{% macro list_mailboxes(mbox_count, mboxes) %}
<h4>
{{ mbox_count }} Mailboxes found.
{% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<th>Mailbox ID</th>
<th>Email</th>
<th>Verified</th>
<th>Created At</th>
</tr>
</thead>
<tbody>
{% for mailbox in mboxes %}
<tr> <tr>
<td>{{ mailbox.id }}</td> <th scope="col">User ID</th>
<td><a href="?email={{mailbox.email}}">{{mailbox.email}}</a></td> <th scope="col">Email</th>
<td>{{ "Yes" if mailbox.verified else "No" }}</td> <th scope="col">Status</th>
<td> <th scope="col">Paid</th>
{{ mailbox.created_at }} <th>Subscription</th>
</td> <th>Created At</th>
<th>Updated At</th>
<th>Connected with Proton account</th>
</tr> </tr>
{% endfor %} </thead>
</tbody> <tbody>
</table> <tr>
<td>{{ user.id }}</td>
<td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %}
<td class="text-danger">Disabled</td>
{% else %}
<td class="text-success">Enabled</td>
{% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ user.get_active_subscription() }}</td>
<td>{{ user.created_at }}</td>
<td>{{ user.updated_at }}</td>
{% if pu %}
<td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %}
<td>No</td>
{% endif %}
</tr>
</tbody>
</table>
{%- endmacro %}
{% macro list_mailboxes(message, mbox_count, mboxes) %}
<h4>
{{ mbox_count }} {{ message }}.
{% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<th>Mailbox ID</th>
<th>Email</th>
<th>Verified</th>
<th>Created At</th>
</tr>
</thead>
<tbody>
{% for mailbox in mboxes %}
<tr>
<td>{{ mailbox.id }}</td>
<td><a href="?email={{ mailbox.email }}">{{ mailbox.email }}</a></td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td>
{{ mailbox.created_at }}
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %} {% endmacro %}
{% macro list_alias(alias_count, aliases) %} {% macro list_alias(alias_count, aliases) %}
<h4> <h4>
{{ alias_count }} Aliases found. {{ alias_count }} Aliases found.
{% if alias_count>10 %}Showing only the last 10.{% endif %} {% if alias_count>10 %}Showing only the last 10.{% endif %}
</h4> </h4>
<table class="table"> <table class="table">
<thead> <thead>
<tr>
<th>
Alias ID
</th>
<th>
Email
</th>
<th>
Verified
</th>
<th>
Created At
</th>
</tr>
</thead>
<tbody>
{% for alias in aliases %}
<tr> <tr>
<td>{{ alias.id }}</td> <th>
<td><a href="?email={{alias.email}}">{{alias.email}}</a></td> Alias ID
<td>{{ "Yes" if alias.verified else "No" }}</td> </th>
<td>{{ alias.created_at }}</td> <th>
Email
</th>
<th>
Enabled
</th>
<th>
Created At
</th>
</tr> </tr>
{% endfor %} </thead>
</tbody> <tbody>
</table> {% for alias in aliases %}
<tr>
<td>{{ alias.id }}</td>
<td><a href="?email={{ alias.email }}">{{ alias.email }}</a></td>
<td>{{ "Yes" if alias.enabled else "No" }}</td>
<td>{{ alias.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %} {% endmacro %}
{% macro show_deleted_alias(deleted_alias) -%} {% macro show_deleted_alias(deleted_alias) -%}
<h4>Deleted Alias {{ deleted_alias.email }} with ID {{ deleted_alias.id }}.</h4> <h4>Deleted Alias {{ deleted_alias.email }} with ID {{ deleted_alias.id }}.</h4>
<table class="table"> <table class="table">
<thead> <thead>
<tr> <tr>
<th scope="col">Deleted Alias ID</th> <th scope="col">Deleted Alias ID</th>
<th scope="col">Email</th> <th scope="col">Email</th>
<th scope="col">Deleted At</th> <th scope="col">Deleted At</th>
<th scope="col">Reason</th> <th scope="col">Reason</th>
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
<tr> <tr>
<td>{{ deleted_alias.id }}</td> <td>{{ deleted_alias.id }}</td>
<td>{{ deleted_alias.email }}</td> <td>{{ deleted_alias.email }}</td>
<td>{{ deleted_alias.created_at }}</td> <td>{{ deleted_alias.created_at }}</td>
<td>{{ deleted_alias.reason }}</td> <td>{{ deleted_alias.reason }}</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>
{%- endmacro %} {%- endmacro %}
{% macro show_domain_deleted_alias(dom_deleted_alias) -%} {% macro show_domain_deleted_alias(dom_deleted_alias) -%}
<h4> <h4>
Domain Deleted Alias {{ dom_deleted_alias.email }} with ID {{ dom_deleted_alias.id }} for domain {{ dom_deleted_alias.domain.domain }} Domain Deleted Alias {{ dom_deleted_alias.email }} with ID {{ dom_deleted_alias.id }} for
</h4> domain {{ dom_deleted_alias.domain.domain }}
<table class="table"> </h4>
<thead> <table class="table">
<tr> <thead>
<th scope="col">Deleted Alias ID</th> <tr>
<th scope="col">Email</th> <th scope="col">Deleted Alias ID</th>
<th scope="col">Domain</th> <th scope="col">Email</th>
<th scope="col">Domain ID</th> <th scope="col">Domain</th>
<th scope="col">Domain owner user ID</th> <th scope="col">Domain ID</th>
<th scope="col">Domain owner user email</th> <th scope="col">Domain owner user ID</th>
<th scope="col">Deleted At</th> <th scope="col">Domain owner user email</th>
</tr> <th scope="col">Deleted At</th>
</thead> </tr>
<tbody> </thead>
<tr> <tbody>
<td>{{ dom_deleted_alias.id }}</td> <tr>
<td>{{ dom_deleted_alias.email }}</td> <td>{{ dom_deleted_alias.id }}</td>
<td>{{ dom_deleted_alias.domain.domain }}</td> <td>{{ dom_deleted_alias.email }}</td>
<td>{{ dom_deleted_alias.domain.id }}</td> <td>{{ dom_deleted_alias.domain.domain }}</td>
<td>{{ dom_deleted_alias.domain.user_id }}</td> <td>{{ dom_deleted_alias.domain.id }}</td>
<td>{{ dom_deleted_alias.created_at }}</td> <td>{{ dom_deleted_alias.domain.user_id }}</td>
</tr> <td>{{ dom_deleted_alias.created_at }}</td>
</tbody> </tr>
</table> </tbody>
{{ show_user(data.domain_deleted_alias.domain.user) }} </table>
{{ show_user(data.domain_deleted_alias.domain.user) }}
{%- endmacro %} {%- endmacro %}
{% macro list_alias_audit_log(alias_audit_log) %}
<h4>Alias Audit Log</h4>
<table class="table">
<thead>
<tr>
<th>User ID</th>
<th>Alias ID</th>
<th>Alias Email</th>
<th>Action</th>
<th>Message</th>
<th>Time</th>
</tr>
</thead>
<tbody>
{% for entry in alias_audit_log %}
<tr>
<td>{{ entry.user_id }}</td>
<td>{{ entry.alias_id }}</td>
<td><a href="?email={{ entry.alias_email }}">{{ entry.alias_email }}</a></td>
<td>{{ entry.action }}</td>
<td>{{ entry.message }}</td>
<td>{{ entry.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% macro list_user_audit_log(user_audit_log) %}
<h4>User Audit Log</h4>
<table class="table">
<thead>
<tr>
<th>User email</th>
<th>Action</th>
<th>Message</th>
<th>Time</th>
</tr>
</thead>
<tbody>
{% for entry in user_audit_log %}
<tr>
<td><a href="?email={{ entry.user_email }}">{{ entry.user_email }}</a></td>
<td>{{ entry.action }}</td>
<td>{{ entry.message }}</td>
<td>{{ entry.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% block body %} {% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="get">
<div class="form-group">
<label for="email">Email to search:</label>
<input type="text"
class="form-control"
name="email"
value="{{ email or '' }}" />
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
{% if data.no_match and email %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div>
{% endif %}
{% if data.alias %}
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found Alias {{ data.alias.email }}</h3> <form method="get">
{{ list_alias(1,[data.alias]) }} <div class="form-group">
{{ show_user(data.alias.user) }} <label for="email">Email to search:</label>
{{ list_mailboxes(helper.mailbox_count(data.alias.user) , helper.mailbox_list(data.alias.user) ) }} <input type="text"
class="form-control"
name="email"
value="{{ email or '' }}"/>
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div> </div>
{% endif %} {% if data.no_match and email %}
{% if data.user %} <div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div>
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found User {{ data.user.email }}</h3>
{{ show_user(data.user) }}
{{ list_mailboxes(helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }}
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div>
{% endif %}
{% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %} {% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3"> {% if data.alias %}
<h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3> <div class="border border-dark border-2 mt-1 mb-2 p-3">
{{ list_mailboxes(1, [mailbox]) }} <h3 class="mb-3">Found Alias {{ data.alias.email }}</h3>
{{ show_user(mailbox.user) }} {{ list_alias(1,[data.alias]) }}
</div> {{ list_alias_audit_log(data.alias_audit_log) }}
{% endfor %} {{ list_mailboxes("Mailboxes for alias", helper.alias_mailbox_count(data.alias), helper.alias_mailboxes(data.alias)) }}
{% if data.deleted_alias %} {{ show_user(data.alias.user) }}
</div>
{% endif %}
<div class="border border-dark mt-1 mb-2 p-3"> {% if data.user %}
<h3 class="mb-3">Found DeletedAlias {{ data.deleted_alias.email }}</h3> <div class="border border-dark border-2 mt-1 mb-2 p-3">
{{ show_deleted_alias(data.deleted_alias) }} <h3 class="mb-3">Found User {{ data.user.email }}</h3>
</div> {{ show_user(data.user) }}
{% endif %} {{ list_mailboxes("Mailboxes for user", helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }}
{% if data.domain_deleted_alias %} {{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
{{ list_user_audit_log(data.user_audit_log) }}
</div>
{% endif %}
{% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3"> <div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DomainDeletedAlias {{ data.domain_deleted_alias.email }}</h3> <h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3>
{{ show_domain_deleted_alias(data.domain_deleted_alias) }} {{ list_mailboxes("Mailbox found", 1, [mailbox]) }}
</div> {{ show_user(mailbox.user) }}
{% endif %} </div>
{% endfor %}
{% if data.deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DeletedAlias {{ data.deleted_alias.email }}</h3>
{{ show_deleted_alias(data.deleted_alias) }}
</div>
{% endif %}
{% if data.domain_deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DomainDeletedAlias {{ data.domain_deleted_alias.email }}</h3>
{{ show_domain_deleted_alias(data.domain_deleted_alias) }}
</div>
{% endif %}
{% endblock %} {% endblock %}

View File

@ -91,7 +91,8 @@
<br /> <br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain. Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div> </div>
{% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %}
{% for record in expected_mx_records %}
<div class="mb-3 p-3 dns-record"> <div class="mb-3 p-3 dns-record">
Record: MX Record: MX
@ -99,14 +100,15 @@
Domain: {{ custom_domain.domain }} or Domain: {{ custom_domain.domain }} or
<b>@</b> <b>@</b>
<br /> <br />
Priority: {{ priority }} Priority: {{ record.priority }}
<br /> <br />
Target: <em data-toggle="tooltip" Target: <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ email_server }}">{{ email_server }}</em> data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em>
</div> </div>
{% endfor %} {% endfor %}
<form method="post" action="#mx-form"> <form method="post" action="#mx-form">
{{ csrf_form.csrf_token }} {{ csrf_form.csrf_token }}
<input type="hidden" name="form-name" value="check-mx"> <input type="hidden" name="form-name" value="check-mx">

View File

@ -536,7 +536,7 @@ def test_create_contact_route_free_users(flask_client):
assert r.status_code == 201 assert r.status_code == 201
# End trial and disallow for new free users. Config should allow it # End trial and disallow for new free users. Config should allow it
user.flags = User.FLAG_FREE_DISABLE_CREATE_ALIAS user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.create_contact_route", alias_id=alias.id), url_for("api.create_contact_route", alias_id=alias.id),

View File

@ -4,7 +4,7 @@ from app.models import (
Alias, Alias,
Contact, Contact,
) )
from tests.utils import login from tests.utils import login, random_email
def test_add_contact_success(flask_client): def test_add_contact_success(flask_client):
@ -13,26 +13,28 @@ def test_add_contact_success(flask_client):
assert Contact.filter_by(user_id=user.id).count() == 0 assert Contact.filter_by(user_id=user.id).count() == 0
email = random_email()
# <<< Create a new contact >>> # <<< Create a new contact >>>
flask_client.post( flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id), url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={ data={
"form-name": "create", "form-name": "create",
"email": "abcd@gmail.com", "email": email,
}, },
follow_redirects=True, follow_redirects=True,
) )
# a new contact is added # a new contact is added
assert Contact.filter_by(user_id=user.id).count() == 1 assert Contact.filter_by(user_id=user.id).count() == 1
contact = Contact.filter_by(user_id=user.id).first() contact = Contact.filter_by(user_id=user.id).first()
assert contact.website_email == "abcd@gmail.com" assert contact.website_email == email
# <<< Create a new contact using a full email format >>> # <<< Create a new contact using a full email format >>>
email = random_email()
flask_client.post( flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id), url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={ data={
"form-name": "create", "form-name": "create",
"email": "First Last <another@gmail.com>", "email": f"First Last <{email}>",
}, },
follow_redirects=True, follow_redirects=True,
) )
@ -41,7 +43,7 @@ def test_add_contact_success(flask_client):
contact = ( contact = (
Contact.filter_by(user_id=user.id).filter(Contact.id != contact.id).first() Contact.filter_by(user_id=user.id).filter(Contact.id != contact.id).first()
) )
assert contact.website_email == "another@gmail.com" assert contact.website_email == email
assert contact.name == "First Last" assert contact.name == "First Last"
# <<< Create a new contact with invalid email address >>> # <<< Create a new contact with invalid email address >>>

View File

@ -1,38 +1,72 @@
import app.alias_utils import app.alias_utils
from app import config
from app.db import Session from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.models import ( from app.models import (
Alias, Alias,
Mailbox, Mailbox,
User,
AliasMailbox, AliasMailbox,
) )
from tests.events.event_test_utils import (
OnMemoryDispatcher,
_get_event_from_string,
_create_linked_user,
)
from tests.utils import login from tests.utils import login
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def test_alias_transfer(flask_client): def test_alias_transfer(flask_client):
user = login(flask_client) (source_user, source_user_pu) = _create_linked_user()
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com", commit=True) source_user = login(flask_client, source_user)
mb = Mailbox.create(user_id=source_user.id, email="mb@gmail.com", commit=True)
alias = Alias.create_new_random(user) alias = Alias.create_new_random(source_user)
Session.commit() Session.commit()
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True) AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True)
new_user = User.create( (target_user, target_user_pu) = _create_linked_user()
email="hey@example.com",
password="password",
activated=True,
commit=True,
)
Mailbox.create( Mailbox.create(
user_id=new_user.id, email="hey2@example.com", verified=True, commit=True user_id=target_user.id, email="hey2@example.com", verified=True, commit=True
) )
app.alias_utils.transfer_alias(alias, new_user, new_user.mailboxes()) on_memory_dispatcher.clear()
app.alias_utils.transfer_alias(alias, target_user, target_user.mailboxes())
# refresh from db # refresh from db
alias = Alias.get(alias.id) alias = Alias.get(alias.id)
assert alias.user == new_user assert alias.user == target_user
assert set(alias.mailboxes) == set(new_user.mailboxes()) assert set(alias.mailboxes) == set(target_user.mailboxes())
assert len(alias.mailboxes) == 2 assert len(alias.mailboxes) == 2
# Check events
assert len(on_memory_dispatcher.memory) == 2
# 1st delete event
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, source_user, source_user_pu)
assert event_content.alias_deleted is not None
alias_deleted = event_content.alias_deleted
assert alias_deleted.id == alias.id
assert alias_deleted.email == alias.email
# 2nd create event
event_data = on_memory_dispatcher.memory[1]
event_content = _get_event_from_string(event_data, target_user, target_user_pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert alias.note or "" == alias_created.note
assert alias.enabled == alias_created.enabled

View File

@ -1,4 +1,5 @@
from app.events.event_dispatcher import Dispatcher from app.events.event_dispatcher import Dispatcher
from app.events.generated import event_pb2
from app.models import PartnerUser, User from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token from tests.utils import create_new_user, random_token
@ -30,3 +31,14 @@ def _create_linked_user() -> Tuple[User, PartnerUser]:
) )
return user, partner_user return user, partner_user
def _get_event_from_string(
data: str, user: User, pu: PartnerUser
) -> event_pb2.EventContent:
event = event_pb2.Event()
event.ParseFromString(data)
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id
return event.content

View File

@ -0,0 +1,54 @@
import arrow
from app.db import Session
from app.models import SyncEvent
from events.event_source import DeadLetterEventSource, _DEAD_LETTER_THRESHOLD_MINUTES
class EventCounter:
def __init__(self):
self.processed_events = 0
def on_event(self, event: SyncEvent):
self.processed_events += 1
def setup_function(func):
Session.query(SyncEvent).delete()
def test_dead_letter_does_not_take_untaken_events():
source = DeadLetterEventSource(1)
counter = EventCounter()
threshold_time = arrow.utcnow().shift(minutes=-(_DEAD_LETTER_THRESHOLD_MINUTES) + 1)
SyncEvent.create(
content="test".encode("utf-8"), created_at=threshold_time, flush=True
)
SyncEvent.create(
content="test".encode("utf-8"), taken_time=threshold_time, flush=True
)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 0
assert counter.processed_events == 0
def test_dead_letter_takes_untaken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_create = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), created_at=old_create, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_create
assert counter.processed_events == 1
def test_dead_letter_takes_taken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_taken = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), taken_time=old_taken, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_taken
assert counter.processed_events == 1

View File

@ -1,12 +1,14 @@
import arrow
from app import config, alias_utils from app import config, alias_utils
from app.db import Session from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated import event_pb2 from app.models import Alias, SyncEvent
from app.models import Alias, User, PartnerUser
from tests.utils import random_token from tests.utils import random_token
from .event_test_utils import ( from .event_test_utils import (
OnMemoryDispatcher, OnMemoryDispatcher,
_create_linked_user, _create_linked_user,
_get_event_from_string,
) )
on_memory_dispatcher = OnMemoryDispatcher() on_memory_dispatcher = OnMemoryDispatcher()
@ -26,15 +28,31 @@ def setup_function(func):
on_memory_dispatcher.clear() on_memory_dispatcher.clear()
def _get_event_from_string( def test_event_taken_updates():
data: str, user: User, pu: PartnerUser event = SyncEvent.create(content="test".encode("utf-8"), flush=True)
) -> event_pb2.EventContent: assert event.taken_time is None
event = event_pb2.Event() assert event.mark_as_taken()
event.ParseFromString(data) assert event.taken_time is not None
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id def test_event_mark_as_taken_does_nothing_for_taken_events():
return event.content now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
assert not event.mark_as_taken()
def test_event_mark_as_taken_does_nothing_for_not_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=-1)
assert not event.mark_as_taken(allow_taken_older_than=older_than)
def test_event_mark_as_taken_works_for_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=+1)
assert event.mark_as_taken(allow_taken_older_than=older_than)
def test_fire_event_on_alias_creation(): def test_fire_event_on_alias_creation():
@ -90,7 +108,7 @@ def test_fire_event_on_alias_status_change():
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.flush() Session.flush()
on_memory_dispatcher.clear() on_memory_dispatcher.clear()
alias_utils.change_alias_status(alias, True) alias_utils.change_alias_status(alias, enabled=True)
assert len(on_memory_dispatcher.memory) == 1 assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0] event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu) event_content = _get_event_from_string(event_data, user, pu)

View File

@ -0,0 +1,109 @@
import arrow
from app import config
from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated.event_pb2 import UserPlanChanged
from app.models import (
Subscription,
AppleSubscription,
CoinbaseSubscription,
ManualSubscription,
User,
PartnerUser,
)
from .event_test_utils import (
OnMemoryDispatcher,
_create_linked_user,
_get_event_from_string,
)
from tests.utils import random_token
from app.subscription_webhook import execute_subscription_webhook
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def setup_function(func):
on_memory_dispatcher.clear()
def check_event(user: User, pu: PartnerUser) -> UserPlanChanged:
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.user_plan_change is not None
plan_change = event_content.user_plan_change
return plan_change
def test_webhook_with_trial():
(user, pu) = _create_linked_user()
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == 0
def test_webhook_with_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)
Subscription.create(
user_id=user.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=end_at.date(),
plan="yearly",
flush=True,
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_apple_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=2).replace(hour=0, minute=0, second=0)
AppleSubscription.create(
user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id=random_token(10),
plan="yearly",
product_id="",
flush=True,
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_coinbase_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
CoinbaseSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_manual_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
ManualSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp

View File

@ -1,3 +1,5 @@
from typing import List
import pytest import pytest
from arrow import Arrow from arrow import Arrow
@ -16,8 +18,9 @@ from app.account_linking import (
) )
from app.db import Session from app.db import Session
from app.errors import AccountAlreadyLinkedToAnotherPartnerException from app.errors import AccountAlreadyLinkedToAnotherPartnerException
from app.models import Partner, PartnerUser, User from app.models import Partner, PartnerUser, User, UserAuditLog
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import UserAuditLogAction
from app.utils import random_string, canonicalize_email from app.utils import random_string, canonicalize_email
from tests.utils import random_email from tests.utils import random_email
@ -91,6 +94,11 @@ def test_login_case_from_partner():
) )
assert res.user.activated is True assert res.user.activated is True
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_login_case_from_partner_with_uppercase_email(): def test_login_case_from_partner_with_uppercase_email():
partner = get_proton_partner() partner = get_proton_partner()
@ -125,6 +133,11 @@ def test_login_case_from_web():
assert 0 == (res.user.flags & User.FLAG_CREATED_FROM_PARTNER) assert 0 == (res.user.flags & User.FLAG_CREATED_FROM_PARTNER)
assert res.user.activated is True assert res.user.activated is True
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_get_strategy_existing_sl_user(): def test_get_strategy_existing_sl_user():
email = random_email() email = random_email()
@ -205,6 +218,10 @@ def test_link_account_with_proton_account_same_address(flask_client):
) )
assert partner_user.partner_id == get_proton_partner().id assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id assert partner_user.external_user_id == partner_user_id
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_link_account_with_proton_account_different_address(flask_client): def test_link_account_with_proton_account_different_address(flask_client):
@ -229,6 +246,11 @@ def test_link_account_with_proton_account_different_address(flask_client):
assert partner_user.partner_id == get_proton_partner().id assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id assert partner_user.external_user_id == partner_user_id
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_link_account_with_proton_account_same_address_but_linked_to_other_user( def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
flask_client, flask_client,
@ -248,22 +270,54 @@ def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
partner_user_id, email=random_email() partner_user_id, email=random_email()
) # User already linked with the proton account ) # User already linked with the proton account
# START Ensure sl_user_2 has a partner_user with the right data
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_2.id
)
assert partner_user is not None
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
assert partner_user.partner_email == sl_user_2.email
assert partner_user.user_id == sl_user_2.id
# END Ensure sl_user_2 has a partner_user with the right data
# Proceed to link sl_user_1
res = process_link_case(link_request, sl_user_1, get_proton_partner()) res = process_link_case(link_request, sl_user_1, get_proton_partner())
# Check that the result is linking sl_user_1
assert res.user.id == sl_user_1.id assert res.user.id == sl_user_1.id
assert res.user.email == partner_email assert res.user.email == partner_email
assert res.strategy == "Link" assert res.strategy == "Link"
# Ensure partner_user for sl_user_1 exists
partner_user = PartnerUser.get_by( partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_1.id partner_id=get_proton_partner().id, user_id=sl_user_1.id
) )
assert partner_user.partner_id == get_proton_partner().id assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id assert partner_user.external_user_id == partner_user_id
# Ensure partner_user for sl_user_2 does not exist anymore
partner_user = PartnerUser.get_by( partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_2.id partner_id=get_proton_partner().id, user_id=sl_user_2.id
) )
assert partner_user is None assert partner_user is None
# Ensure audit logs for sl_user_1 show the link action
sl_user_1_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_1.id
).all()
assert len(sl_user_1_audit_logs) == 1
assert sl_user_1_audit_logs[0].user_id == sl_user_1.id
assert sl_user_1_audit_logs[0].action == UserAuditLogAction.LinkAccount.value
# Ensure audit logs for sl_user_2 show the unlink action
sl_user_2_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_2.id
).all()
assert len(sl_user_2_audit_logs) == 1
assert sl_user_2_audit_logs[0].user_id == sl_user_2.id
assert sl_user_2_audit_logs[0].action == UserAuditLogAction.UnlinkAccount.value
def test_link_account_with_proton_account_different_address_and_linked_to_other_user( def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
flask_client, flask_client,
@ -300,6 +354,22 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_
) )
assert partner_user_2 is None assert partner_user_2 is None
# Ensure audit logs for sl_user_1 show the link action
sl_user_1_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_1.id
).all()
assert len(sl_user_1_audit_logs) == 1
assert sl_user_1_audit_logs[0].user_id == sl_user_1.id
assert sl_user_1_audit_logs[0].action == UserAuditLogAction.LinkAccount.value
# Ensure audit logs for sl_user_2 show the unlink action
sl_user_2_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_2.id
).all()
assert len(sl_user_2_audit_logs) == 1
assert sl_user_2_audit_logs[0].user_id == sl_user_2.id
assert sl_user_2_audit_logs[0].action == UserAuditLogAction.UnlinkAccount.value
def test_cannot_create_instance_of_base_strategy(): def test_cannot_create_instance_of_base_strategy():
with pytest.raises(Exception): with pytest.raises(Exception):

View File

@ -0,0 +1,95 @@
import random
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_utils import delete_alias, transfer_alias
from app.models import Alias, AliasAuditLog, AliasDeleteReason
from app.utils import random_string
from tests.utils import create_new_user, random_email
def test_emit_alias_audit_log_for_random_data():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
random_user_id = random.randint(1000, 2000)
message = random_string()
action = AliasAuditLogAction.ChangeAliasStatus
emit_alias_audit_log(
alias=alias,
user_id=random_user_id,
action=action,
message=message,
commit=True,
)
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 2
last_log = logs_for_alias[-1]
assert last_log.alias_id == alias.id
assert last_log.alias_email == alias.email
assert last_log.user_id == random_user_id
assert last_log.action == action.value
assert last_log.message == message
def test_emit_alias_audit_log_on_alias_creation():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
log_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(log_for_alias) == 1
assert log_for_alias[0].alias_id == alias.id
assert log_for_alias[0].alias_email == alias.email
assert log_for_alias[0].user_id == user.id
assert log_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
def test_alias_audit_log_exists_after_alias_deletion():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
alias_id = alias.id
emit_alias_audit_log(alias, AliasAuditLogAction.UpdateAlias, "")
emit_alias_audit_log(alias, AliasAuditLogAction.UpdateAlias, "")
delete_alias(alias, user, AliasDeleteReason.ManualAction, commit=True)
db_alias = Alias.get_by(id=alias_id)
assert db_alias is None
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 4
assert logs_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
assert logs_for_alias[1].action == AliasAuditLogAction.UpdateAlias.value
assert logs_for_alias[2].action == AliasAuditLogAction.UpdateAlias.value
assert logs_for_alias[3].action == AliasAuditLogAction.DeleteAlias.value
def test_alias_audit_log_for_transfer():
original_user = create_new_user()
new_user = create_new_user()
alias = Alias.create(
user_id=original_user.id,
email=random_email(),
mailbox_id=original_user.default_mailbox_id,
)
transfer_alias(alias, new_user, [new_user.default_mailbox])
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 3
assert logs_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
assert logs_for_alias[1].action == AliasAuditLogAction.TransferredAlias.value
assert logs_for_alias[1].user_id == original_user.id
assert logs_for_alias[2].action == AliasAuditLogAction.AcceptTransferAlias.value
assert logs_for_alias[2].user_id == new_user.id

View File

@ -0,0 +1,70 @@
from typing import Tuple
from app.alias_audit_log_utils import AliasAuditLogAction
from app.alias_mailbox_utils import (
set_mailboxes_for_alias,
CannotSetMailboxesForAliasCause,
)
from app.models import Alias, Mailbox, User, AliasMailbox, AliasAuditLog
from tests.utils import create_new_user, random_email
def setup() -> Tuple[User, Alias]:
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
commit=True,
)
return user, alias
def test_set_mailboxes_for_alias_empty_list():
user, alias = setup()
err = set_mailboxes_for_alias(user.id, alias, [])
assert err is CannotSetMailboxesForAliasCause.EmptyMailboxes
def test_set_mailboxes_for_alias_mailbox_for_other_user():
user, alias = setup()
another_user = create_new_user()
err = set_mailboxes_for_alias(user.id, alias, [another_user.default_mailbox_id])
assert err is CannotSetMailboxesForAliasCause.Forbidden
def test_set_mailboxes_for_alias_mailbox_not_exists():
user, alias = setup()
err = set_mailboxes_for_alias(user.id, alias, [9999999])
assert err is CannotSetMailboxesForAliasCause.Forbidden
def test_set_mailboxes_for_alias_mailbox_success():
user, alias = setup()
mb1 = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
)
mb2 = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
commit=True,
)
err = set_mailboxes_for_alias(user.id, alias, [mb1.id, mb2.id])
assert err is None
db_alias = Alias.get_by(id=alias.id)
assert db_alias is not None
assert db_alias.mailbox_id == mb1.id
alias_mailboxes = AliasMailbox.filter_by(alias_id=alias.id).all()
assert len(alias_mailboxes) == 1
assert alias_mailboxes[0].mailbox_id == mb2.id
audit_logs = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(audit_logs) == 2
assert audit_logs[0].action == AliasAuditLogAction.CreateAlias.value
assert audit_logs[1].action == AliasAuditLogAction.ChangedMailboxes.value
assert audit_logs[1].message == f"{mb1.id} ({mb1.email}),{mb2.id} ({mb2.email})"

View File

@ -4,6 +4,7 @@ from app.alias_utils import (
delete_alias, delete_alias,
check_alias_prefix, check_alias_prefix,
get_user_if_alias_would_auto_create, get_user_if_alias_would_auto_create,
get_alias_recipient_name,
try_auto_create, try_auto_create,
) )
from app.config import ALIAS_DOMAINS from app.config import ALIAS_DOMAINS
@ -18,7 +19,8 @@ from app.models import (
User, User,
DomainDeletedAlias, DomainDeletedAlias,
) )
from tests.utils import create_new_user, random_domain, random_token from app.utils import random_string
from tests.utils import create_new_user, random_domain, random_token, random_email
def test_delete_alias(flask_client): def test_delete_alias(flask_client):
@ -131,3 +133,91 @@ def test_auto_create_alias(flask_client):
assert result, f"Case {test_id} - Failed address {address}" assert result, f"Case {test_id} - Failed address {address}"
else: else:
assert result is None, f"Case {test_id} - Failed address {address}" assert result is None, f"Case {test_id} - Failed address {address}"
# get_alias_recipient_name
def test_get_alias_recipient_name_no_overrides():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is None
assert res.name == alias.email
def test_get_alias_recipient_name_alias_name():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
name=random_string(),
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{alias.name} <{alias.email}>"
def test_get_alias_recipient_alias_with_name_and_custom_domain_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
name=random_string(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
name=random_string(),
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{alias.name} <{alias.email}>"
def test_get_alias_recipient_alias_without_name_and_custom_domain_without_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is None
assert res.name == alias.email
def test_get_alias_recipient_alias_without_name_and_custom_domain_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
name=random_string(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{custom_domain.name} <{alias.email}>"

View File

@ -1,15 +1,26 @@
from typing import Optional from typing import Optional
import pytest import pytest
from app import config
from app.contact_utils import create_contact, ContactCreateError from app.contact_utils import create_contact, ContactCreateError
from app.db import Session from app.db import Session
from app.models import ( from app.models import (
Alias, Alias,
Contact, Contact,
User,
) )
from tests.utils import create_new_user, random_email, random_token from tests.utils import create_new_user, random_email, random_token
def setup_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = True
def teardown_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = False
def create_provider(): def create_provider():
# name auto_created from_partner # name auto_created from_partner
yield ["name", "a@b.c", True, True] yield ["name", "a@b.c", True, True]
@ -34,8 +45,8 @@ def test_create_contact(
email = random_email() email = random_email()
contact_result = create_contact( contact_result = create_contact(
email, email,
name,
alias, alias,
name=name,
mail_from=mail_from, mail_from=mail_from,
automatic_created=automatic_created, automatic_created=automatic_created,
from_partner=from_partner, from_partner=from_partner,
@ -57,7 +68,7 @@ def test_create_contact_email_email_not_allowed():
user = create_new_user() user = create_new_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
contact_result = create_contact("", "", alias) contact_result = create_contact("", alias)
assert contact_result.contact is None assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail assert contact_result.error == ContactCreateError.InvalidEmail
@ -66,21 +77,84 @@ def test_create_contact_email_email_allowed():
user = create_new_user() user = create_new_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
contact_result = create_contact("", "", alias, allow_empty_email=True) contact_result = create_contact("", alias, allow_empty_email=True)
assert contact_result.error is None assert contact_result.error is None
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.website_email == "" assert contact_result.contact.website_email == ""
assert contact_result.contact.invalid_email assert contact_result.contact.invalid_email
def test_create_contact_name_overrides_email_name():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"superseeded <{email}>", alias, name=name)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_name_taken_from_email():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"{name} <{email}>", alias)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_empty_name_is_none():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias, name="")
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name is None
def test_create_contact_free_user():
user = create_new_user()
user.trial_end = None
user.flags = 0
alias = Alias.create_new_random(user)
Session.flush()
# Free users without the FREE_DISABLE_CREATE_CONTACTS
result = create_contact(random_email(), alias)
assert result.error is None
assert result.created
assert result.contact is not None
assert not result.contact.automatic_created
# Free users with the flag should be able to still create automatic emails
user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.flush()
result = create_contact(random_email(), alias, automatic_created=True)
assert result.error is None
assert result.created
assert result.contact is not None
assert result.contact.automatic_created
# Free users with the flag cannot create non-automatic emails
result = create_contact(random_email(), alias)
assert result.error == ContactCreateError.NotAllowed
def test_do_not_allow_invalid_email(): def test_do_not_allow_invalid_email():
user = create_new_user() user = create_new_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
contact_result = create_contact("potato", "", alias) contact_result = create_contact("potato", alias)
assert contact_result.contact is None assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail assert contact_result.error == ContactCreateError.InvalidEmail
contact_result = create_contact("asdf\x00@gmail.com", "", alias) contact_result = create_contact("asdf\x00@gmail.com", alias)
assert contact_result.contact is None assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail assert contact_result.error == ContactCreateError.InvalidEmail
@ -90,13 +164,15 @@ def test_update_name_for_existing():
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
email = random_email() email = random_email()
contact_result = create_contact(email, "", alias) contact_result = create_contact(email, alias)
assert contact_result.error is None assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.name == "" assert contact_result.contact.name is None
name = random_token() name = random_token()
contact_result = create_contact(email, name, alias) contact_result = create_contact(email, alias, name=name)
assert contact_result.error is None assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.name == name assert contact_result.contact.name == name
@ -106,12 +182,15 @@ def test_update_mail_from_for_existing():
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
email = random_email() email = random_email()
contact_result = create_contact(email, "", alias) contact_result = create_contact(email, alias)
assert contact_result.error is None assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.mail_from is None assert contact_result.contact.mail_from is None
mail_from = random_email() mail_from = random_email()
contact_result = create_contact(email, "", alias, mail_from=mail_from) contact_result = create_contact(email, alias, mail_from=mail_from)
assert contact_result.error is None assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.mail_from == mail_from assert contact_result.contact.mail_from == mail_from

View File

@ -7,11 +7,13 @@ from app.custom_domain_utils import (
create_custom_domain, create_custom_domain,
is_valid_domain, is_valid_domain,
sanitize_domain, sanitize_domain,
set_custom_domain_mailboxes,
CannotUseDomainReason, CannotUseDomainReason,
CannotSetCustomDomainMailboxesCause,
) )
from app.db import Session from app.db import Session
from app.models import User, CustomDomain, Mailbox from app.models import User, CustomDomain, Mailbox, DomainMailbox
from tests.utils import get_proton_partner from tests.utils import get_proton_partner, random_email
from tests.utils import create_new_user, random_string, random_domain from tests.utils import create_new_user, random_string, random_domain
user: Optional[User] = None user: Optional[User] = None
@ -147,3 +149,119 @@ def test_creates_custom_domain_with_partner_id():
assert res.instance.domain == domain assert res.instance.domain == domain
assert res.instance.user_id == user.id assert res.instance.user_id == user.id
assert res.instance.partner_id == proton_partner.id assert res.instance.partner_id == proton_partner.id
# set_custom_domain_mailboxes
def test_set_custom_domain_mailboxes_empty_list():
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [])
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.NoMailboxes
def test_set_custom_domain_mailboxes_mailbox_from_another_user():
other_user = create_new_user()
other_mailbox = Mailbox.create(
user_id=other_user.id, email=random_email(), verified=True
)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id])
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox
def test_set_custom_domain_mailboxes_mailbox_from_current_user_and_another_user():
other_user = create_new_user()
other_mailbox = Mailbox.create(
user_id=other_user.id, email=random_email(), verified=True
)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox
def test_set_custom_domain_mailboxes_success():
other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
assert domain_mailboxes[1].domain_id == domain.id
assert domain_mailboxes[1].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_set_twice():
other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
assert domain_mailboxes[1].domain_id == domain.id
assert domain_mailboxes[1].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_removes_old_association():
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [user.default_mailbox_id])
assert res.success is True
assert res.reason is None
other_mailbox = Mailbox.create(
user_id=user.id, email=random_email(), verified=True, commit=True
)
res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id])
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 1
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_with_unverified_mailbox():
domain = CustomDomain.create(user_id=user.id, domain=random_domain())
verified_mailbox = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
)
unverified_mailbox = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=False,
)
res = set_custom_domain_mailboxes(
user.id, domain, [verified_mailbox.id, unverified_mailbox.id]
)
assert res.success is False
assert res.reason is CannotSetCustomDomainMailboxesCause.InvalidMailbox

View File

@ -5,7 +5,7 @@ from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation
from app.db import Session from app.db import Session
from app.models import CustomDomain, User from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient from app.dns_utils import InMemoryDNSClient, MxRecord
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import random_string from app.utils import random_string
from tests.utils import create_new_user, random_domain from tests.utils import create_new_user, random_domain
@ -58,6 +58,123 @@ def test_custom_domain_validation_get_dkim_records_for_partner():
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}" assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
# get_expected_mx_records
def test_custom_domain_validation_get_expected_mx_records_regular_domain():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is not a partner_domain,default records should be used even if
# there is a config for the partner
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0]
assert records[i].domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(dkim_domain)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain but there is no custom config for partner, default records
# should be used
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0]
assert records[i].domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
expected_mx_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain, partner_domains={partner_id: expected_mx_domain}
)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain and there is a custom config for partner, partner records
# should be used
assert len(records) == 2
assert records[0].priority == 10
assert records[0].domain == f"mx1.{expected_mx_domain}."
assert records[1].priority == 20
assert records[1].domain == f"mx2.{expected_mx_domain}."
# get_expected_spf_records
def test_custom_domain_validation_get_expected_spf_record_regular_domain():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is not a partner_domain, default records should be used even if
# there is a config for the partner
assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(dkim_domain)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is a partner_domain but there is no custom config for partner, default records
# should be used
assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner_with_custom_config():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
expected_mx_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain, partner_domains={partner_id: expected_mx_domain}
)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is a partner_domain and there is a custom config for partner, partner records
# should be used
assert record == f"v=spf1 include:{expected_mx_domain} ~all"
# validate_dkim_records # validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure(): def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient() dns_client = InMemoryDNSClient()
@ -253,7 +370,7 @@ def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
wrong_record_1 = random_string() wrong_record_1 = random_string()
wrong_record_2 = random_string() wrong_record_2 = random_string()
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)] wrong_records = [MxRecord(10, wrong_record_1), MxRecord(20, wrong_record_2)]
dns_client.set_mx_records(domain.domain, wrong_records) dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain) res = validator.validate_mx_records(domain)
@ -270,7 +387,7 @@ def test_custom_domain_validation_validate_mx_records_success():
domain = create_custom_domain(random_domain()) domain = create_custom_domain(random_domain())
dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY) dns_client.set_mx_records(domain.domain, validator.get_expected_mx_records(domain))
res = validator.validate_mx_records(domain) res = validator.validate_mx_records(domain)
assert res.success is True assert res.success is True
@ -328,6 +445,58 @@ def test_custom_domain_validation_validate_spf_records_success():
assert db_domain.spf_verified is True assert db_domain.spf_verified is True
def test_custom_domain_validation_validate_spf_records_partner_domain_success():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id
expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)
domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{expected_domain}"])
res = validator.validate_spf_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True
def test_custom_domain_validation_validate_spf_cleans_verification_record():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id
expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)
domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()
wrong_record = random_string()
dns_client.set_txt_record(
hostname=domain.domain,
txt_list=[wrong_record, validator.get_ownership_verification_record(domain)],
)
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 1
assert res.errors[0] == wrong_record
# validate_dmarc_records # validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure(): def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient() dns_client = InMemoryDNSClient()

View File

@ -3,6 +3,7 @@ from app.dns_utils import (
get_network_dns_client, get_network_dns_client,
is_mx_equivalent, is_mx_equivalent,
InMemoryDNSClient, InMemoryDNSClient,
MxRecord,
) )
from tests.utils import random_domain from tests.utils import random_domain
@ -17,8 +18,8 @@ def test_get_mx_domains():
assert len(r) > 0 assert len(r) > 0
for x in r: for x in r:
assert x[0] > 0 assert x.priority > 0
assert x[1] assert x.domain
def test_get_spf_domain(): def test_get_spf_domain():
@ -33,20 +34,32 @@ def test_get_txt_record():
def test_is_mx_equivalent(): def test_is_mx_equivalent():
assert is_mx_equivalent([], []) assert is_mx_equivalent([], [])
assert is_mx_equivalent([(1, "domain")], [(1, "domain")])
assert is_mx_equivalent( assert is_mx_equivalent(
[(10, "domain1"), (20, "domain2")], [(10, "domain1"), (20, "domain2")] mx_domains=[MxRecord(1, "domain")], ref_mx_domains=[MxRecord(1, "domain")]
) )
assert is_mx_equivalent( assert is_mx_equivalent(
[(5, "domain1"), (10, "domain2")], [(10, "domain1"), (20, "domain2")] mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
) )
assert is_mx_equivalent( assert is_mx_equivalent(
[(5, "domain1"), (10, "domain2"), (20, "domain3")], mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")],
[(10, "domain1"), (20, "domain2")], ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
)
assert is_mx_equivalent(
mx_domains=[
MxRecord(5, "domain1"),
MxRecord(10, "domain2"),
MxRecord(20, "domain3"),
],
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
) )
assert not is_mx_equivalent( assert not is_mx_equivalent(
[(5, "domain1"), (10, "domain2")], mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")], ref_mx_domains=[
MxRecord(10, "domain1"),
MxRecord(20, "domain2"),
MxRecord(20, "domain3"),
],
) )

View File

@ -90,12 +90,19 @@ def test_can_be_used_as_personal_email(flask_client):
assert not email_can_be_used_as_mailbox("ab@sl.local") assert not email_can_be_used_as_mailbox("ab@sl.local")
assert not email_can_be_used_as_mailbox("hey@d1.test") assert not email_can_be_used_as_mailbox("hey@d1.test")
# custom domain # custom domain as SL domain
domain = random_domain() domain = random_domain()
user = create_new_user() user = create_new_user()
CustomDomain.create(user_id=user.id, domain=domain, verified=True, commit=True) domain_obj = CustomDomain.create(
user_id=user.id, domain=domain, verified=True, is_sl_subdomain=True, flush=True
)
assert not email_can_be_used_as_mailbox(f"hey@{domain}") assert not email_can_be_used_as_mailbox(f"hey@{domain}")
# custom domain is NOT SL domain
domain_obj.is_sl_subdomain = False
Session.flush()
assert email_can_be_used_as_mailbox(f"hey@{domain}")
# disposable domain # disposable domain
disposable_domain = random_domain() disposable_domain = random_domain()
InvalidMailboxDomain.create(domain=disposable_domain, commit=True) InvalidMailboxDomain.create(domain=disposable_domain, commit=True)

View File

@ -6,7 +6,9 @@ import pytest
from app import mailbox_utils, config from app import mailbox_utils, config
from app.db import Session from app.db import Session
from app.mail_sender import mail_sender from app.mail_sender import mail_sender
from app.models import Mailbox, MailboxActivation, User, Job from app.mailbox_utils import MailboxEmailChangeError
from app.models import Mailbox, MailboxActivation, User, Job, UserAuditLog
from app.user_audit_log_utils import UserAuditLogAction
from tests.utils import create_new_user, random_email from tests.utils import create_new_user, random_email
@ -218,7 +220,11 @@ def test_delete_with_transfer():
user, random_email(), use_digit_codes=True, send_link=False user, random_email(), use_digit_codes=True, send_link=False
).mailbox ).mailbox
transfer_mailbox = mailbox_utils.create_mailbox( transfer_mailbox = mailbox_utils.create_mailbox(
user, random_email(), use_digit_codes=True, send_link=False user,
random_email(),
use_digit_codes=True,
send_link=False,
verified=True,
).mailbox ).mailbox
mailbox_utils.delete_mailbox( mailbox_utils.delete_mailbox(
user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id
@ -236,6 +242,28 @@ def test_delete_with_transfer():
assert job.payload["transfer_mailbox_id"] is None assert job.payload["transfer_mailbox_id"] is None
def test_cannot_delete_with_transfer_to_unverified_mailbox():
mailbox = mailbox_utils.create_mailbox(
user, random_email(), use_digit_codes=True, send_link=False
).mailbox
transfer_mailbox = mailbox_utils.create_mailbox(
user,
random_email(),
use_digit_codes=True,
send_link=False,
verified=False,
).mailbox
with pytest.raises(mailbox_utils.MailboxError):
mailbox_utils.delete_mailbox(
user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id
)
# Verify mailbox still exists
db_mailbox = Mailbox.get_by(id=mailbox.id)
assert db_mailbox is not None
def test_verify_non_existing_mailbox(): def test_verify_non_existing_mailbox():
with pytest.raises(mailbox_utils.MailboxError): with pytest.raises(mailbox_utils.MailboxError):
mailbox_utils.verify_mailbox_code(user, 999999999, "9999999") mailbox_utils.verify_mailbox_code(user, 999999999, "9999999")
@ -302,3 +330,74 @@ def test_verify_ok():
assert activation is None assert activation is None
mailbox = Mailbox.get(id=output.mailbox.id) mailbox = Mailbox.get(id=output.mailbox.id)
assert mailbox.verified assert mailbox.verified
# perform_mailbox_email_change
def test_perform_mailbox_email_change_invalid_id():
res = mailbox_utils.perform_mailbox_email_change(99999)
assert res.error == MailboxEmailChangeError.InvalidId
assert res.message_category == "error"
def test_perform_mailbox_email_change_valid_id_not_new_email():
user = create_new_user()
mb = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=None,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb.id)
assert res.error == MailboxEmailChangeError.InvalidId
assert res.message_category == "error"
audit_log_entries = UserAuditLog.filter_by(user_id=user.id).count()
assert audit_log_entries == 0
def test_perform_mailbox_email_change_valid_id_email_already_used():
user = create_new_user()
new_email = random_email()
# Create mailbox with that email
Mailbox.create(
user_id=user.id,
email=new_email,
verified=True,
)
mb_to_change = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=new_email,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb_to_change.id)
assert res.error == MailboxEmailChangeError.EmailAlreadyUsed
assert res.message_category == "error"
audit_log_entries = UserAuditLog.filter_by(user_id=user.id).count()
assert audit_log_entries == 0
def test_perform_mailbox_email_change_success():
user = create_new_user()
new_email = random_email()
mb = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=new_email,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb.id)
assert res.error is None
assert res.message_category == "success"
db_mailbox = Mailbox.get_by(id=mb.id)
assert db_mailbox is not None
assert db_mailbox.verified is True
assert db_mailbox.email == new_email
assert db_mailbox.new_email is None
audit_log_entries = UserAuditLog.filter_by(user_id=user.id).all()
assert len(audit_log_entries) == 1
assert audit_log_entries[0].action == UserAuditLogAction.UpdateMailbox.value

View File

@ -2,7 +2,7 @@ import arrow
from app.db import Session from app.db import Session
from app.models import CoinbaseSubscription from app.models import CoinbaseSubscription
from server import handle_coinbase_event from app.payments.coinbase import handle_coinbase_event
from tests.utils import create_new_user from tests.utils import create_new_user

View File

@ -1,113 +0,0 @@
import http.server
import json
import threading
import arrow
from app import config
from app.models import (
Subscription,
AppleSubscription,
CoinbaseSubscription,
ManualSubscription,
)
from tests.utils import create_new_user, random_token
from app.subscription_webhook import execute_subscription_webhook
http_server = None
last_http_request = None
def setup_module():
global http_server
http_server = http.server.ThreadingHTTPServer(("", 0), HTTPTestServer)
print(http_server.server_port)
threading.Thread(target=http_server.serve_forever, daemon=True).start()
config.SUBSCRIPTION_CHANGE_WEBHOOK = f"http://localhost:{http_server.server_port}"
def teardown_module():
global http_server
config.SUBSCRIPTION_CHANGE_WEBHOOK = None
http_server.shutdown()
class HTTPTestServer(http.server.BaseHTTPRequestHandler):
def do_POST(self):
global last_http_request
content_len = int(self.headers.get("Content-Length"))
body_data = self.rfile.read(content_len)
last_http_request = json.loads(body_data)
self.send_response(200)
def test_webhook_with_trial():
user = create_new_user()
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] is None
def test_webhook_with_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)
Subscription.create(
user_id=user.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=end_at.date(),
plan="yearly",
flush=True,
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_apple_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=2).replace(hour=0, minute=0, second=0)
AppleSubscription.create(
user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id=random_token(10),
plan="yearly",
product_id="",
flush=True,
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_coinbase_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
CoinbaseSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_manual_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
ManualSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp

View File

@ -0,0 +1,52 @@
from typing import List
from app import config, mailbox_utils
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.models import UserAuditLog
from app.utils import random_string
from tests.utils import create_new_user, random_email
def setup_module():
config.SKIP_MX_LOOKUP_ON_CHECK = True
def teardown_module():
config.SKIP_MX_LOOKUP_ON_CHECK = False
def test_emit_alias_audit_log_for_random_data():
user = create_new_user()
message = random_string()
action = UserAuditLogAction.CreateMailbox
emit_user_audit_log(
user=user,
action=action,
message=message,
commit=True,
)
logs_for_user: List[UserAuditLog] = UserAuditLog.filter_by(user_id=user.id).all()
assert len(logs_for_user) == 1
assert logs_for_user[0].user_id == user.id
assert logs_for_user[0].user_email == user.email
assert logs_for_user[0].action == action.value
assert logs_for_user[0].message == message
def test_emit_audit_log_on_mailbox_creation():
user = create_new_user()
output = mailbox_utils.create_mailbox(
user=user, email=random_email(), verified=True
)
logs_for_user: List[UserAuditLog] = UserAuditLog.filter_by(user_id=user.id).all()
assert len(logs_for_user) == 1
assert logs_for_user[0].user_id == user.id
assert logs_for_user[0].user_email == user.email
assert logs_for_user[0].action == UserAuditLogAction.CreateMailbox.value
assert (
logs_for_user[0].message
== f"Create mailbox {output.mailbox.id} ({output.mailbox.email}). Verified=True"
)