4.55.0
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m43s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 4m10s
Build-Release-Image / Merge-Images (push) Successful in 27s
Build-Release-Image / Create-Release (push) Successful in 10s
Build-Release-Image / Notify (push) Successful in 3s

This commit is contained in:
MrMeeb 2024-10-18 12:00:06 +01:00
parent 798b58529c
commit da6e56c4eb
58 changed files with 2002 additions and 637 deletions

View File

@ -9,6 +9,7 @@ from sqlalchemy import or_
from app.db import Session from app.db import Session
from app.email_utils import send_welcome_email from app.email_utils import send_welcome_email
from app.partner_user_utils import create_partner_user, create_partner_subscription
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email, canonicalize_email
from app.errors import ( from app.errors import (
AccountAlreadyLinkedToAnotherPartnerException, AccountAlreadyLinkedToAnotherPartnerException,
@ -23,6 +24,7 @@ from app.models import (
User, User,
Alias, Alias,
) )
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string from app.utils import random_string
@ -66,9 +68,10 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
LOG.i( LOG.i(
f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]" f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
) )
PartnerSubscription.create( create_partner_subscription(
partner_user_id=partner_user.id, partner_user=partner_user,
end_at=plan.expiration, expiration=plan.expiration,
msg="Upgraded via partner. User did not have a previous partner subscription",
) )
agent.record_custom_event("PlanChange", {"plan": "premium", "type": "new"}) agent.record_custom_event("PlanChange", {"plan": "premium", "type": "new"})
else: else:
@ -80,6 +83,11 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
"PlanChange", {"plan": "premium", "type": "extension"} "PlanChange", {"plan": "premium", "type": "extension"}
) )
sub.end_at = plan.expiration sub.end_at = plan.expiration
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended partner subscription",
)
Session.commit() Session.commit()
@ -98,8 +106,8 @@ def ensure_partner_user_exists_for_user(
if res and res.partner_id != partner.id: if res and res.partner_id != partner.id:
raise AccountAlreadyLinkedToAnotherPartnerException() raise AccountAlreadyLinkedToAnotherPartnerException()
if not res: if not res:
res = PartnerUser.create( res = create_partner_user(
user_id=sl_user.id, user=sl_user,
partner_id=partner.id, partner_id=partner.id,
partner_email=link_request.email, partner_email=link_request.email,
external_user_id=link_request.external_user_id, external_user_id=link_request.external_user_id,
@ -140,8 +148,8 @@ class NewUserStrategy(ClientMergeStrategy):
activated=True, activated=True,
from_partner=self.link_request.from_partner, from_partner=self.link_request.from_partner,
) )
partner_user = PartnerUser.create( partner_user = create_partner_user(
user_id=new_user.id, user=new_user,
partner_id=self.partner.id, partner_id=self.partner.id,
external_user_id=self.link_request.external_user_id, external_user_id=self.link_request.external_user_id,
partner_email=self.link_request.email, partner_email=self.link_request.email,
@ -200,7 +208,7 @@ def get_login_strategy(
return ExistingUnlinkedUserStrategy(link_request, user, partner) return ExistingUnlinkedUserStrategy(link_request, user, partner)
def check_alias(email: str) -> bool: def check_alias(email: str):
alias = Alias.get_by(email=email) alias = Alias.get_by(email=email)
if alias is not None: if alias is not None:
raise AccountIsUsingAliasAsEmail() raise AccountIsUsingAliasAsEmail()
@ -275,10 +283,26 @@ def switch_already_linked_user(
LOG.i( LOG.i(
f"Deleting previous partner_user:{other_partner_user.id} from user:{current_user.id}" f"Deleting previous partner_user:{other_partner_user.id} from user:{current_user.id}"
) )
emit_user_audit_log(
user=other_partner_user.user,
action=UserAuditLogAction.UnlinkAccount,
message=f"Deleting partner_user {other_partner_user.id} (external_user_id={other_partner_user.external_user_id} | partner_email={other_partner_user.partner_email}) from user {current_user.id}, as we received a new link request for the same partner",
)
PartnerUser.delete(other_partner_user.id) PartnerUser.delete(other_partner_user.id)
LOG.i(f"Linking partner_user:{partner_user.id} to user:{current_user.id}") LOG.i(f"Linking partner_user:{partner_user.id} to user:{current_user.id}")
# Link this partner_user to the current user # Link this partner_user to the current user
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.UnlinkAccount,
message=f"Unlinking from partner, as user will now be tied to another external account. old=(id={partner_user.user.id} | email={partner_user.user.email}) | new=(id={current_user.id} | email={current_user.email})",
)
partner_user.user_id = current_user.id partner_user.user_id = current_user.id
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.LinkAccount,
message=f"Linking user {current_user.id} ({current_user.email}) to partner_user:{partner_user.id} (external_user_id={partner_user.external_user_id} | partner_email={partner_user.partner_email})",
)
# Set plan # Set plan
set_plan_for_partner_user(partner_user, link_request.plan) set_plan_for_partner_user(partner_user, link_request.plan)
Session.commit() Session.commit()

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional from typing import Optional, List
import arrow import arrow
import sqlalchemy import sqlalchemy
@ -35,6 +35,8 @@ from app.models import (
DomainDeletedAlias, DomainDeletedAlias,
PartnerUser, PartnerUser,
AliasMailbox, AliasMailbox,
AliasAuditLog,
UserAuditLog,
) )
from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address
@ -737,11 +739,13 @@ class InvalidMailboxDomainAdmin(SLModelView):
class EmailSearchResult: class EmailSearchResult:
no_match: bool = True no_match: bool = True
alias: Optional[Alias] = None alias: Optional[Alias] = None
mailbox: list[Mailbox] = [] alias_audit_log: Optional[List[AliasAuditLog]] = None
mailbox: List[Mailbox] = []
mailbox_count: int = 0 mailbox_count: int = 0
deleted_alias: Optional[DeletedAlias] = None deleted_alias: Optional[DeletedAlias] = None
deleted_custom_alias: Optional[DomainDeletedAlias] = None deleted_custom_alias: Optional[DomainDeletedAlias] = None
user: Optional[User] = None user: Optional[User] = None
user_audit_log: Optional[List[UserAuditLog]] = None
@staticmethod @staticmethod
def from_email(email: str) -> EmailSearchResult: def from_email(email: str) -> EmailSearchResult:
@ -749,10 +753,20 @@ class EmailSearchResult:
alias = Alias.get_by(email=email) alias = Alias.get_by(email=email)
if alias: if alias:
output.alias = alias output.alias = alias
output.alias_audit_log = (
AliasAuditLog.filter_by(alias_id=alias.id)
.order_by(AliasAuditLog.created_at.desc())
.all()
)
output.no_match = False output.no_match = False
user = User.get_by(email=email) user = User.get_by(email=email)
if user: if user:
output.user = user output.user = user
output.user_audit_log = (
UserAuditLog.filter_by(user_id=user.id)
.order_by(UserAuditLog.created_at.desc())
.all()
)
output.no_match = False output.no_match = False
mailboxes = ( mailboxes = (
Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all() Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all()

View File

@ -0,0 +1,38 @@
from enum import Enum
from typing import Optional
from app.models import Alias, AliasAuditLog
class AliasAuditLogAction(Enum):
CreateAlias = "create"
ChangeAliasStatus = "change_status"
DeleteAlias = "delete"
UpdateAlias = "update"
InitiateTransferAlias = "initiate_transfer_alias"
AcceptTransferAlias = "accept_transfer_alias"
TransferredAlias = "transferred_alias"
ChangedMailboxes = "changed_mailboxes"
CreateContact = "create_contact"
UpdateContact = "update_contact"
DeleteContact = "delete_contact"
def emit_alias_audit_log(
alias: Alias,
action: AliasAuditLogAction,
message: str,
user_id: Optional[int] = None,
commit: bool = False,
):
AliasAuditLog.create(
user_id=user_id or alias.user_id,
alias_id=alias.id,
alias_email=alias.email,
action=action.value,
message=message,
commit=commit,
)

View File

@ -0,0 +1,61 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.db import Session
from app.models import Alias, AliasMailbox, Mailbox
_MAX_MAILBOXES_PER_ALIAS = 20
class CannotSetMailboxesForAliasCause(Enum):
Forbidden = "Forbidden"
EmptyMailboxes = "Must choose at least one mailbox"
TooManyMailboxes = "Too many mailboxes"
@dataclass
class SetMailboxesForAliasResult:
performed_change: bool
reason: Optional[CannotSetMailboxesForAliasCause]
def set_mailboxes_for_alias(
user_id: int, alias: Alias, mailbox_ids: List[int]
) -> Optional[CannotSetMailboxesForAliasCause]:
if len(mailbox_ids) == 0:
return CannotSetMailboxesForAliasCause.EmptyMailboxes
if len(mailbox_ids) > _MAX_MAILBOXES_PER_ALIAS:
return CannotSetMailboxesForAliasCause.TooManyMailboxes
mailboxes = (
Session.query(Mailbox)
.filter(
Mailbox.id.in_(mailbox_ids),
Mailbox.user_id == user_id,
Mailbox.verified == True, # noqa: E712
)
.all()
)
if len(mailboxes) != len(mailbox_ids):
return CannotSetMailboxesForAliasCause.Forbidden
# first remove all existing alias-mailboxes links
AliasMailbox.filter_by(alias_id=alias.id).delete()
Session.flush()
# then add all new mailboxes, being the first the one associated with the alias
for i, mailbox in enumerate(mailboxes):
if i == 0:
alias.mailbox_id = mailboxes[0].id
else:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mailbox.id)
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.ChangedMailboxes,
message=",".join([f"{mailbox.id} ({mailbox.email})" for mailbox in mailboxes]),
)
return None

View File

@ -8,6 +8,7 @@ from email_validator import validate_email, EmailNotValidError
from sqlalchemy.exc import IntegrityError, DataError from sqlalchemy.exc import IntegrityError, DataError
from flask import make_response from flask import make_response
from app.alias_audit_log_utils import AliasAuditLogAction, emit_alias_audit_log
from app.config import ( from app.config import (
BOUNCE_PREFIX_FOR_REPLY_PHASE, BOUNCE_PREFIX_FOR_REPLY_PHASE,
BOUNCE_PREFIX, BOUNCE_PREFIX,
@ -368,6 +369,10 @@ def delete_alias(
alias_id = alias.id alias_id = alias.id
alias_email = alias.email alias_email = alias.email
emit_alias_audit_log(
alias, AliasAuditLogAction.DeleteAlias, "Alias deleted by user action"
)
Alias.filter(Alias.id == alias.id).delete() Alias.filter(Alias.id == alias.id).delete()
Session.commit() Session.commit()
@ -450,7 +455,7 @@ def alias_export_csv(user, csv_direct_export=False):
return output return output
def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]): def transfer_alias(alias: Alias, new_user: User, new_mailboxes: [Mailbox]):
# cannot transfer alias which is used for receiving newsletter # cannot transfer alias which is used for receiving newsletter
if User.get_by(newsletter_alias_id=alias.id): if User.get_by(newsletter_alias_id=alias.id):
raise Exception("Cannot transfer alias that's used to receive newsletter") raise Exception("Cannot transfer alias that's used to receive newsletter")
@ -504,6 +509,12 @@ def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
alias.disable_pgp = False alias.disable_pgp = False
alias.pinned = False alias.pinned = False
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.TransferredAlias,
message=f"Lost ownership of alias due to alias transfer confirmed. New owner is {new_user.id}",
user_id=old_user.id,
)
EventDispatcher.send_event( EventDispatcher.send_event(
old_user, old_user,
EventContent( EventContent(
@ -513,6 +524,13 @@ def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
) )
), ),
) )
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.AcceptTransferAlias,
message=f"Accepted alias transfer from user {old_user.id}",
user_id=new_user.id,
)
EventDispatcher.send_event( EventDispatcher.send_event(
new_user, new_user,
EventContent( EventContent(
@ -529,7 +547,9 @@ def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
Session.commit() Session.commit()
def change_alias_status(alias: Alias, enabled: bool, commit: bool = False): def change_alias_status(
alias: Alias, enabled: bool, message: Optional[str] = None, commit: bool = False
):
LOG.i(f"Changing alias {alias} enabled to {enabled}") LOG.i(f"Changing alias {alias} enabled to {enabled}")
alias.enabled = enabled alias.enabled = enabled
@ -540,6 +560,12 @@ def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
created_at=int(alias.created_at.timestamp), created_at=int(alias.created_at.timestamp),
) )
EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event)) EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event))
audit_log_message = f"Set alias status to {enabled}"
if message is not None:
audit_log_message += f". {message}"
emit_alias_audit_log(
alias, AliasAuditLogAction.ChangeAliasStatus, audit_log_message
)
if commit: if commit:
Session.commit() Session.commit()

View File

@ -1,9 +1,13 @@
from typing import Optional
from deprecated import deprecated from deprecated import deprecated
from flask import g from flask import g
from flask import jsonify from flask import jsonify
from flask import request from flask import request
from app import alias_utils from app import alias_utils
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_mailbox_utils import set_mailboxes_for_alias
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.api.serializer import ( from app.api.serializer import (
AliasInfo, AliasInfo,
@ -26,7 +30,7 @@ from app.errors import (
) )
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, Mailbox, AliasMailbox, AliasDeleteReason from app.models import Alias, Contact, Mailbox, AliasDeleteReason
@deprecated @deprecated
@ -185,7 +189,11 @@ def toggle_alias(alias_id):
if not alias or alias.user_id != user.id: if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
alias_utils.change_alias_status(alias, enabled=not alias.enabled) alias_utils.change_alias_status(
alias,
enabled=not alias.enabled,
message=f"Set enabled={not alias.enabled} via API",
)
LOG.i(f"User {user} changed alias {alias} enabled status to {alias.enabled}") LOG.i(f"User {user} changed alias {alias} enabled status to {alias.enabled}")
Session.commit() Session.commit()
@ -272,10 +280,12 @@ def update_alias(alias_id):
if not alias or alias.user_id != user.id: if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
changed_fields = []
changed = False changed = False
if "note" in data: if "note" in data:
new_note = data.get("note") new_note = data.get("note")
alias.note = new_note alias.note = new_note
changed_fields.append("note")
changed = True changed = True
if "mailbox_id" in data: if "mailbox_id" in data:
@ -285,35 +295,19 @@ def update_alias(alias_id):
return jsonify(error="Forbidden"), 400 return jsonify(error="Forbidden"), 400
alias.mailbox_id = mailbox_id alias.mailbox_id = mailbox_id
changed_fields.append(f"mailbox_id ({mailbox_id})")
changed = True changed = True
if "mailbox_ids" in data: if "mailbox_ids" in data:
mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")] mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")]
mailboxes: [Mailbox] = [] err = set_mailboxes_for_alias(
user_id=user.id, alias=alias, mailbox_ids=mailbox_ids
# check if all mailboxes belong to user )
for mailbox_id in mailbox_ids: if err:
mailbox = Mailbox.get(mailbox_id) return jsonify(error=err.value), 400
if not mailbox or mailbox.user_id != user.id or not mailbox.verified:
return jsonify(error="Forbidden"), 400
mailboxes.append(mailbox)
if not mailboxes:
return jsonify(error="Must choose at least one mailbox"), 400
# <<< update alias mailboxes >>>
# first remove all existing alias-mailboxes links
AliasMailbox.filter_by(alias_id=alias.id).delete()
Session.flush()
# then add all new mailboxes
for i, mailbox in enumerate(mailboxes):
if i == 0:
alias.mailbox_id = mailboxes[0].id
else:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mailbox.id)
# <<< END update alias mailboxes >>>
mailbox_ids_string = ",".join(map(str, mailbox_ids))
changed_fields.append(f"mailbox_ids ({mailbox_ids_string})")
changed = True changed = True
if "name" in data: if "name" in data:
@ -325,17 +319,26 @@ def update_alias(alias_id):
if new_name: if new_name:
new_name = new_name.replace("\n", "") new_name = new_name.replace("\n", "")
alias.name = new_name alias.name = new_name
changed_fields.append("name")
changed = True changed = True
if "disable_pgp" in data: if "disable_pgp" in data:
alias.disable_pgp = data.get("disable_pgp") alias.disable_pgp = data.get("disable_pgp")
changed_fields.append("disable_pgp")
changed = True changed = True
if "pinned" in data: if "pinned" in data:
alias.pinned = data.get("pinned") alias.pinned = data.get("pinned")
changed_fields.append("pinned")
changed = True changed = True
if changed: if changed:
changed_fields_string = ",".join(changed_fields)
emit_alias_audit_log(
alias,
AliasAuditLogAction.UpdateAlias,
f"Alias fields updated ({changed_fields_string})",
)
Session.commit() Session.commit()
return jsonify(ok=True), 200 return jsonify(ok=True), 200
@ -446,11 +449,16 @@ def delete_contact(contact_id):
200 200
""" """
user = g.user user = g.user
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.alias.user_id != user.id: if not contact or contact.alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
emit_alias_audit_log(
alias=contact.alias,
action=AliasAuditLogAction.DeleteContact,
message=f"Deleted contact {contact_id} ({contact.email})",
)
Contact.delete(contact_id) Contact.delete(contact_id)
Session.commit() Session.commit()
@ -468,12 +476,17 @@ def toggle_contact(contact_id):
200 200
""" """
user = g.user user = g.user
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.alias.user_id != user.id: if not contact or contact.alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
contact.block_forward = not contact.block_forward contact.block_forward = not contact.block_forward
emit_alias_audit_log(
alias=contact.alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Set contact state {contact.id} {contact.email} -> {contact.website_email} to blocked {contact.block_forward}",
)
Session.commit() Session.commit()
return jsonify(block_forward=contact.block_forward), 200 return jsonify(block_forward=contact.block_forward), 200

View File

@ -2,8 +2,10 @@ from flask import g, request
from flask import jsonify from flask import jsonify
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.custom_domain_utils import set_custom_domain_mailboxes
from app.db import Session from app.db import Session
from app.models import CustomDomain, DomainDeletedAlias, Mailbox, DomainMailbox from app.log import LOG
from app.models import CustomDomain, DomainDeletedAlias
def custom_domain_to_dict(custom_domain: CustomDomain): def custom_domain_to_dict(custom_domain: CustomDomain):
@ -100,23 +102,14 @@ def update_custom_domain(custom_domain_id):
if "mailbox_ids" in data: if "mailbox_ids" in data:
mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")] mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")]
if mailbox_ids: result = set_custom_domain_mailboxes(user.id, custom_domain, mailbox_ids)
# check if mailbox is not tempered with if result.success:
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != user.id or not mailbox.verified:
return jsonify(error="Forbidden"), 400
mailboxes.append(mailbox)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
changed = True changed = True
else:
LOG.info(
f"Prevented from updating mailboxes [custom_domain_id={custom_domain.id}]: {result.reason.value}"
)
return jsonify(error="Forbidden"), 400
if changed: if changed:
Session.commit() Session.commit()

View File

@ -6,6 +6,7 @@ from app import config
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import Job, ApiToCookieToken from app.models import Job, ApiToCookieToken
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@api_bp.route("/user", methods=["DELETE"]) @api_bp.route("/user", methods=["DELETE"])
@ -16,6 +17,11 @@ def delete_user():
""" """
# Schedule delete account job # Schedule delete account job
emit_user_audit_log(
user=g.user,
action=UserAuditLogAction.UserMarkedForDeletion,
message=f"Marked user {g.user.id} ({g.user.email}) for deletion from API",
)
LOG.w("schedule delete account job for %s", g.user) LOG.w("schedule delete account job for %s", g.user)
Job.create( Job.create(
name=config.JOB_DELETE_ACCOUNT, name=config.JOB_DELETE_ACCOUNT,

View File

@ -601,7 +601,6 @@ SKIP_MX_LOOKUP_ON_CHECK = False
DISABLE_RATE_LIMIT = "DISABLE_RATE_LIMIT" in os.environ DISABLE_RATE_LIMIT = "DISABLE_RATE_LIMIT" in os.environ
SUBSCRIPTION_CHANGE_WEBHOOK = os.environ.get("SUBSCRIPTION_CHANGE_WEBHOOK", None)
MAX_API_KEYS = int(os.environ.get("MAX_API_KEYS", 30)) MAX_API_KEYS = int(os.environ.get("MAX_API_KEYS", 30))
UPCLOUD_USERNAME = os.environ.get("UPCLOUD_USERNAME", None) UPCLOUD_USERNAME = os.environ.get("UPCLOUD_USERNAME", None)
@ -663,3 +662,5 @@ PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get( MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(
"MAILBOX_VERIFICATION_OVERRIDE_CODE", None "MAILBOX_VERIFICATION_OVERRIDE_CODE", None
) )
AUDIT_LOG_MAX_DAYS = int(os.environ.get("AUDIT_LOG_MAX_DAYS", 30))

View File

@ -4,6 +4,7 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.db import Session from app.db import Session
from app.email_utils import generate_reply_email, parse_full_address from app.email_utils import generate_reply_email, parse_full_address
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
@ -100,6 +101,16 @@ def create_contact(
invalid_email=email == "", invalid_email=email == "",
commit=True, commit=True,
) )
if automatic_created:
trail = ". Automatically created"
else:
trail = ". Created by user action"
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.CreateContact,
message=f"Created contact {contact.id} ({contact.email}){trail}",
commit=True,
)
LOG.d( LOG.d(
f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}" f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}"
) )

View File

@ -10,6 +10,7 @@ from app.db import Session
from app.email_utils import get_email_domain_part from app.email_utils import get_email_domain_part
from app.log import LOG from app.log import LOG
from app.models import User, CustomDomain, SLDomain, Mailbox, Job, DomainMailbox from app.models import User, CustomDomain, SLDomain, Mailbox, Job, DomainMailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$") _ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$")
_MAX_MAILBOXES_PER_DOMAIN = 20 _MAX_MAILBOXES_PER_DOMAIN = 20
@ -137,6 +138,11 @@ def create_custom_domain(
if partner_id is not None: if partner_id is not None:
new_custom_domain.partner_id = partner_id new_custom_domain.partner_id = partner_id
emit_user_audit_log(
user=user,
action=UserAuditLogAction.CreateCustomDomain,
message=f"Created custom domain {new_custom_domain.id} ({new_domain})",
)
Session.commit() Session.commit()
return CreateCustomDomainResult( return CreateCustomDomainResult(
@ -190,5 +196,11 @@ def set_custom_domain_mailboxes(
for mailbox in mailboxes: for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id) DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
mailboxes_as_str = ",".join(map(str, mailbox_ids))
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Updated custom domain {custom_domain.id} mailboxes (domain={custom_domain.domain}) (mailboxes={mailboxes_as_str})",
)
Session.commit() Session.commit()
return SetCustomDomainMailboxesResult(success=True) return SetCustomDomainMailboxesResult(success=True)

View File

@ -11,6 +11,7 @@ from app.dns_utils import (
get_network_dns_client, get_network_dns_client,
) )
from app.models import CustomDomain from app.models import CustomDomain
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string from app.utils import random_string
@ -121,6 +122,12 @@ class CustomDomainValidation:
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the # Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag # rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0 custom_domain.dkim_verified = len(invalid_records) == 0
if custom_domain.dkim_verified:
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified DKIM records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return invalid_records return invalid_records
@ -137,6 +144,11 @@ class CustomDomainValidation:
if expected_verification_record in txt_records: if expected_verification_record in txt_records:
custom_domain.ownership_verified = True custom_domain.ownership_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified ownership for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
else: else:
@ -155,6 +167,11 @@ class CustomDomainValidation:
) )
else: else:
custom_domain.verified = True custom_domain.verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified MX records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
@ -165,6 +182,11 @@ class CustomDomainValidation:
expected_spf_domain = self.get_expected_spf_domain(custom_domain) expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains: if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True custom_domain.spf_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified SPF records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
else: else:
@ -183,6 +205,11 @@ class CustomDomainValidation:
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain) txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records: if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True custom_domain.dmarc_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified DMARC records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
else: else:

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from operator import or_ from operator import or_
from typing import Optional
from flask import render_template, request, redirect, flash from flask import render_template, request, redirect, flash
from flask import url_for from flask import url_for
@ -10,6 +11,7 @@ from wtforms import StringField, validators, ValidationError
# Need to import directly from config to allow modification from the tests # Need to import directly from config to allow modification from the tests
from app import config, parallel_limiter, contact_utils from app import config, parallel_limiter, contact_utils
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.contact_utils import ContactCreateError from app.contact_utils import ContactCreateError
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -190,7 +192,7 @@ def get_contact_infos(
def delete_contact(alias: Alias, contact_id: int): def delete_contact(alias: Alias, contact_id: int):
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact: if not contact:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -198,6 +200,11 @@ def delete_contact(alias: Alias, contact_id: int):
flash("You cannot delete reverse-alias", "warning") flash("You cannot delete reverse-alias", "warning")
else: else:
delete_contact_email = contact.website_email delete_contact_email = contact.website_email
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.DeleteContact,
message=f"Delete contact {contact_id} ({contact.email})",
)
Contact.delete(contact_id) Contact.delete(contact_id)
Session.commit() Session.commit()

View File

@ -7,6 +7,7 @@ from flask import render_template, redirect, url_for, flash, request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app import config from app import config
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_utils import transfer_alias from app.alias_utils import transfer_alias
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
@ -57,6 +58,12 @@ def alias_transfer_send_route(alias_id):
transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}" transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}"
alias.transfer_token = hmac_alias_transfer_token(transfer_token) alias.transfer_token = hmac_alias_transfer_token(transfer_token)
alias.transfer_token_expiration = arrow.utcnow().shift(hours=24) alias.transfer_token_expiration = arrow.utcnow().shift(hours=24)
emit_alias_audit_log(
alias,
AliasAuditLogAction.InitiateTransferAlias,
"Initiated alias transfer",
)
Session.commit() Session.commit()
alias_transfer_url = ( alias_transfer_url = (
config.URL config.URL

View File

@ -1,8 +1,11 @@
from typing import Optional
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.models import Contact from app.models import Contact
@ -20,7 +23,7 @@ class PGPContactForm(FlaskForm):
@dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"])
@login_required @login_required
def contact_detail_route(contact_id): def contact_detail_route(contact_id):
contact = Contact.get(contact_id) contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.user_id != current_user.id: if not contact or contact.user_id != current_user.id:
flash("You cannot see this page", "warning") flash("You cannot see this page", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
@ -50,6 +53,11 @@ def contact_detail_route(contact_id):
except PGPException: except PGPException:
flash("Cannot add the public key, please verify it", "error") flash("Cannot add the public key, please verify it", "error")
else: else:
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Added PGP key {contact.pgp_public_key} for contact {contact_id} ({contact.email})",
)
Session.commit() Session.commit()
flash( flash(
f"PGP public key for {contact.email} is saved successfully", f"PGP public key for {contact.email} is saved successfully",
@ -62,6 +70,11 @@ def contact_detail_route(contact_id):
) )
elif pgp_form.action.data == "remove": elif pgp_form.action.data == "remove":
# Free user can decide to remove contact PGP key # Free user can decide to remove contact PGP key
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Removed PGP key {contact.pgp_public_key} for contact {contact_id} ({contact.email})",
)
contact.pgp_public_key = None contact.pgp_public_key = None
contact.pgp_finger_print = None contact.pgp_finger_print = None
Session.commit() Session.commit()

View File

@ -8,6 +8,7 @@ from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.log import LOG from app.log import LOG
from app.models import Subscription, Job from app.models import Subscription, Job
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class DeleteDirForm(FlaskForm): class DeleteDirForm(FlaskForm):
@ -33,6 +34,11 @@ def delete_account():
# Schedule delete account job # Schedule delete account job
LOG.w("schedule delete account job for %s", current_user) LOG.w("schedule delete account job for %s", current_user)
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UserMarkedForDeletion,
message=f"User {current_user.id} ({current_user.email}) marked for deletion via webapp",
)
Job.create( Job.create(
name=JOB_DELETE_ACCOUNT, name=JOB_DELETE_ACCOUNT,
payload={"user_id": current_user.id}, payload={"user_id": current_user.id},

View File

@ -1,3 +1,5 @@
from typing import Optional
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
@ -20,6 +22,7 @@ from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.errors import DirectoryInTrashError from app.errors import DirectoryInTrashError
from app.models import Directory, Mailbox, DirectoryMailbox from app.models import Directory, Mailbox, DirectoryMailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class NewDirForm(FlaskForm): class NewDirForm(FlaskForm):
@ -69,7 +72,9 @@ def directory():
if not delete_dir_form.validate(): if not delete_dir_form.validate():
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
dir_obj = Directory.get(delete_dir_form.directory_id.data) dir_obj: Optional[Directory] = Directory.get(
delete_dir_form.directory_id.data
)
if not dir_obj: if not dir_obj:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -79,6 +84,11 @@ def directory():
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
name = dir_obj.name name = dir_obj.name
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.DeleteDirectory,
message=f"Delete directory {dir_obj.id} ({dir_obj.name})",
)
Directory.delete(dir_obj.id) Directory.delete(dir_obj.id)
Session.commit() Session.commit()
flash(f"Directory {name} has been deleted", "success") flash(f"Directory {name} has been deleted", "success")
@ -90,7 +100,7 @@ def directory():
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
dir_id = toggle_dir_form.directory_id.data dir_id = toggle_dir_form.directory_id.data
dir_obj = Directory.get(dir_id) dir_obj: Optional[Directory] = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir_obj or dir_obj.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -103,6 +113,11 @@ def directory():
dir_obj.disabled = True dir_obj.disabled = True
flash(f"On-the-fly is disabled for {dir_obj.name}", "warning") flash(f"On-the-fly is disabled for {dir_obj.name}", "warning")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateDirectory,
message=f"Updated directory {dir_obj.id} ({dir_obj.name}) set disabled = {dir_obj.disabled}",
)
Session.commit() Session.commit()
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
@ -112,7 +127,7 @@ def directory():
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
dir_id = update_dir_form.directory_id.data dir_id = update_dir_form.directory_id.data
dir_obj = Directory.get(dir_id) dir_obj: Optional[Directory] = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir_obj or dir_obj.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
@ -143,6 +158,12 @@ def directory():
for mailbox in mailboxes: for mailbox in mailboxes:
DirectoryMailbox.create(directory_id=dir_obj.id, mailbox_id=mailbox.id) DirectoryMailbox.create(directory_id=dir_obj.id, mailbox_id=mailbox.id)
mailboxes_as_str = ",".join(map(str, mailbox_ids))
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateDirectory,
message=f"Updated directory {dir_obj.id} ({dir_obj.name}) mailboxes ({mailboxes_as_str})",
)
Session.commit() Session.commit()
flash(f"Directory {dir_obj.name} has been updated", "success") flash(f"Directory {dir_obj.name} has been updated", "success")
@ -181,6 +202,11 @@ def directory():
new_dir = Directory.create( new_dir = Directory.create(
name=new_dir_name, user_id=current_user.id name=new_dir_name, user_id=current_user.id
) )
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.CreateDirectory,
message=f"New directory {new_dir.name} ({new_dir.name})",
)
except DirectoryInTrashError: except DirectoryInTrashError:
flash( flash(
f"{new_dir_name} has been used before and cannot be reused", f"{new_dir_name} has been used before and cannot be reused",

View File

@ -20,6 +20,7 @@ from app.models import (
AutoCreateRuleMailbox, AutoCreateRuleMailbox,
) )
from app.regex_utils import regex_match from app.regex_utils import regex_match
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string, CSRFValidationForm from app.utils import random_string, CSRFValidationForm
@ -164,6 +165,11 @@ def domain_detail(custom_domain_id):
return redirect(request.url) return redirect(request.url)
if request.form.get("form-name") == "switch-catch-all": if request.form.get("form-name") == "switch-catch-all":
custom_domain.catch_all = not custom_domain.catch_all custom_domain.catch_all = not custom_domain.catch_all
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) catch all to {custom_domain.catch_all}",
)
Session.commit() Session.commit()
if custom_domain.catch_all: if custom_domain.catch_all:
@ -182,6 +188,11 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "set-name": elif request.form.get("form-name") == "set-name":
if request.form.get("action") == "save": if request.form.get("action") == "save":
custom_domain.name = request.form.get("alias-name").replace("\n", "") custom_domain.name = request.form.get("alias-name").replace("\n", "")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) name",
)
Session.commit() Session.commit()
flash( flash(
f"Default alias name for Domain {custom_domain.domain} has been set", f"Default alias name for Domain {custom_domain.domain} has been set",
@ -189,6 +200,11 @@ def domain_detail(custom_domain_id):
) )
else: else:
custom_domain.name = None custom_domain.name = None
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Cleared custom domain {custom_domain.id} ({custom_domain.domain}) name",
)
Session.commit() Session.commit()
flash( flash(
f"Default alias name for Domain {custom_domain.domain} has been removed", f"Default alias name for Domain {custom_domain.domain} has been removed",
@ -202,6 +218,11 @@ def domain_detail(custom_domain_id):
custom_domain.random_prefix_generation = ( custom_domain.random_prefix_generation = (
not custom_domain.random_prefix_generation not custom_domain.random_prefix_generation
) )
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) random prefix generation to {custom_domain.random_prefix_generation}",
)
Session.commit() Session.commit()
if custom_domain.random_prefix_generation: if custom_domain.random_prefix_generation:

View File

@ -149,7 +149,9 @@ def index():
) )
flash(f"Alias {email} has been deleted", "success") flash(f"Alias {email} has been deleted", "success")
elif request.form.get("form-name") == "disable-alias": elif request.form.get("form-name") == "disable-alias":
alias_utils.change_alias_status(alias, enabled=False) alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False from dashboard"
)
Session.commit() Session.commit()
flash(f"Alias {alias.email} has been disabled", "success") flash(f"Alias {alias.email} has been disabled", "success")

View File

@ -1,6 +1,7 @@
import base64 import base64
import binascii import binascii
import json import json
from typing import Optional
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
@ -15,6 +16,7 @@ from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import Mailbox from app.models import Mailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import CSRFValidationForm from app.utils import CSRFValidationForm
@ -151,7 +153,7 @@ def verify_with_signed_secret(request: str):
flash("Invalid link. Please delete and re-add your mailbox", "error") flash("Invalid link. Please delete and re-add your mailbox", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
mailbox_id = mailbox_data[0] mailbox_id = mailbox_data[0]
mailbox = Mailbox.get(mailbox_id) mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox: if not mailbox:
flash("Invalid link", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
@ -161,6 +163,11 @@ def verify_with_signed_secret(request: str):
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
mailbox.verified = True mailbox.verified = True
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.VerifyMailbox,
message=f"Verified mailbox {mailbox.id} ({mailbox.email})",
)
Session.commit() Session.commit()
LOG.d("Mailbox %s is verified", mailbox) LOG.d("Mailbox %s is verified", mailbox)

View File

@ -16,10 +16,11 @@ from app.db import Session
from app.email_utils import email_can_be_used_as_mailbox from app.email_utils import email_can_be_used_as_mailbox
from app.email_utils import mailbox_already_used, render, send_email from app.email_utils import mailbox_already_used, render, send_email
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.mailbox_utils import perform_mailbox_email_change, MailboxEmailChangeError
from app.models import Alias, AuthorizedAddress from app.models import Alias, AuthorizedAddress
from app.models import Mailbox from app.models import Mailbox
from app.pgp_utils import PGPException, load_public_key_and_check from app.pgp_utils import PGPException, load_public_key_and_check
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email, CSRFValidationForm from app.utils import sanitize_email, CSRFValidationForm
@ -88,8 +89,12 @@ def mailbox_detail_route(mailbox_id):
flash("SPF enforcement globally not enabled", "error") flash("SPF enforcement globally not enabled", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
mailbox.force_spf = ( force_spf_value = request.form.get("spf-status") == "on"
True if request.form.get("spf-status") == "on" else False mailbox.force_spf = force_spf_value
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set force_spf to {force_spf_value} on mailbox {mailbox_id} ({mailbox.email})",
) )
Session.commit() Session.commit()
flash( flash(
@ -113,6 +118,11 @@ def mailbox_detail_route(mailbox_id):
if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address): if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address):
flash(f"{address} already added", "error") flash(f"{address} already added", "error")
else: else:
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Add authorized address {address} to mailbox {mailbox_id} ({mailbox.email})",
)
AuthorizedAddress.create( AuthorizedAddress.create(
user_id=current_user.id, user_id=current_user.id,
mailbox_id=mailbox.id, mailbox_id=mailbox.id,
@ -133,6 +143,11 @@ def mailbox_detail_route(mailbox_id):
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
else: else:
address = authorized_address.email address = authorized_address.email
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove authorized address {address} from mailbox {mailbox_id} ({mailbox.email})",
)
AuthorizedAddress.delete(authorized_address_id) AuthorizedAddress.delete(authorized_address_id)
Session.commit() Session.commit()
flash(f"{address} has been deleted", "success") flash(f"{address} has been deleted", "success")
@ -165,6 +180,11 @@ def mailbox_detail_route(mailbox_id):
except PGPException: except PGPException:
flash("Cannot add the public key, please verify it", "error") flash("Cannot add the public key, please verify it", "error")
else: else:
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Add PGP Key {mailbox.pgp_finger_print} to mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit() Session.commit()
flash("Your PGP public key is saved successfully", "success") flash("Your PGP public key is saved successfully", "success")
return redirect( return redirect(
@ -172,6 +192,11 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("action") == "remove": elif request.form.get("action") == "remove":
# Free user can decide to remove their added PGP key # Free user can decide to remove their added PGP key
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove PGP Key {mailbox.pgp_finger_print} from mailbox {mailbox_id} ({mailbox.email})",
)
mailbox.pgp_public_key = None mailbox.pgp_public_key = None
mailbox.pgp_finger_print = None mailbox.pgp_finger_print = None
mailbox.disable_pgp = False mailbox.disable_pgp = False
@ -191,9 +216,19 @@ def mailbox_detail_route(mailbox_id):
) )
else: else:
mailbox.disable_pgp = False mailbox.disable_pgp = False
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Enabled PGP for mailbox {mailbox_id} ({mailbox.email})",
)
flash(f"PGP is enabled on {mailbox.email}", "info") flash(f"PGP is enabled on {mailbox.email}", "info")
else: else:
mailbox.disable_pgp = True mailbox.disable_pgp = True
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Disabled PGP for mailbox {mailbox_id} ({mailbox.email})",
)
flash(f"PGP is disabled on {mailbox.email}", "info") flash(f"PGP is disabled on {mailbox.email}", "info")
Session.commit() Session.commit()
@ -203,6 +238,11 @@ def mailbox_detail_route(mailbox_id):
elif request.form.get("form-name") == "generic-subject": elif request.form.get("form-name") == "generic-subject":
if request.form.get("action") == "save": if request.form.get("action") == "save":
mailbox.generic_subject = request.form.get("generic-subject") mailbox.generic_subject = request.form.get("generic-subject")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set generic subject for mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit() Session.commit()
flash("Generic subject is enabled", "success") flash("Generic subject is enabled", "success")
return redirect( return redirect(
@ -210,6 +250,11 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("action") == "remove": elif request.form.get("action") == "remove":
mailbox.generic_subject = None mailbox.generic_subject = None
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove generic subject for mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit() Session.commit()
flash("Generic subject is disabled", "success") flash("Generic subject is disabled", "success")
return redirect( return redirect(
@ -272,7 +317,7 @@ def cancel_mailbox_change_route(mailbox_id):
@dashboard_bp.route("/mailbox/confirm_change") @dashboard_bp.route("/mailbox/confirm_change")
def mailbox_confirm_change_route(): def mailbox_confirm_email_change_route():
s = TimestampSigner(MAILBOX_SECRET) s = TimestampSigner(MAILBOX_SECRET)
signed_mailbox_id = request.args.get("mailbox_id") signed_mailbox_id = request.args.get("mailbox_id")
@ -281,30 +326,20 @@ def mailbox_confirm_change_route():
except Exception: except Exception:
flash("Invalid link", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
else:
mailbox = Mailbox.get(mailbox_id)
# new_email can be None if user cancels change in the meantime res = perform_mailbox_email_change(mailbox_id)
if mailbox and mailbox.new_email:
user = mailbox.user flash(res.message, res.message_category)
if Mailbox.get_by(email=mailbox.new_email, user_id=user.id): if res.error:
flash(f"{mailbox.new_email} is already used", "error") if res.error == MailboxEmailChangeError.EmailAlreadyUsed:
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
elif res.error == MailboxEmailChangeError.InvalidId:
mailbox.email = mailbox.new_email
mailbox.new_email = None
# mark mailbox as verified if the change request is sent from an unverified mailbox
mailbox.verified = True
Session.commit()
LOG.d("Mailbox change %s is verified", mailbox)
flash(f"The {mailbox.email} is updated", "success")
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id)
)
else:
flash("Invalid link", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
else:
raise Exception("Unhandled MailboxEmailChangeError")
else:
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)

View File

@ -11,6 +11,7 @@ from app.dashboard.base import dashboard_bp
from app.errors import SubdomainInTrashError from app.errors import SubdomainInTrashError
from app.log import LOG from app.log import LOG
from app.models import CustomDomain, Mailbox, SLDomain from app.models import CustomDomain, Mailbox, SLDomain
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
# Only lowercase letters, numbers, dashes (-) are currently supported # Only lowercase letters, numbers, dashes (-) are currently supported
_SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}" _SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}"
@ -102,6 +103,12 @@ def subdomain_route():
ownership_verified=True, ownership_verified=True,
commit=True, commit=True,
) )
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.CreateCustomDomain,
message=f"Create subdomain {new_custom_domain.id} ({full_domain})",
commit=True,
)
except SubdomainInTrashError: except SubdomainInTrashError:
flash( flash(
f"{full_domain} has been used before and cannot be reused", f"{full_domain} has been used before and cannot be reused",

View File

@ -32,7 +32,9 @@ def unsubscribe(alias_id):
# automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058 # automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058
if request.method == "POST": if request.method == "POST":
alias_utils.change_alias_status(alias, False) alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False from unsubscribe request"
)
flash(f"Alias {alias.email} has been blocked", "success") flash(f"Alias {alias.email} has been blocked", "success")
Session.commit() Session.commit()

View File

@ -103,7 +103,9 @@ class UnsubscribeHandler:
): ):
return status.E509 return status.E509
LOG.i(f"User disabled alias {alias} via unsubscribe header") LOG.i(f"User disabled alias {alias} via unsubscribe header")
alias_utils.change_alias_status(alias, enabled=False) alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False via unsubscribe header"
)
Session.commit() Session.commit()
enable_alias_url = config.URL + f"/dashboard/?highlight_alias_id={alias.id}" enable_alias_url = config.URL + f"/dashboard/?highlight_alias_id={alias.id}"
for mailbox in alias.mailboxes: for mailbox in alias.mailboxes:

View File

@ -1,6 +1,7 @@
import dataclasses import dataclasses
import secrets import secrets
import random import random
from enum import Enum
from typing import Optional from typing import Optional
import arrow import arrow
@ -16,6 +17,7 @@ from app.email_utils import (
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
from app.log import LOG from app.log import LOG
from app.models import User, Mailbox, Job, MailboxActivation from app.models import User, Mailbox, Job, MailboxActivation
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@dataclasses.dataclass @dataclasses.dataclass
@ -70,9 +72,15 @@ def create_mailbox(
f"User {user} has tried to create mailbox with {email} but email is invalid" f"User {user} has tried to create mailbox with {email} but email is invalid"
) )
raise MailboxError("Invalid email") raise MailboxError("Invalid email")
new_mailbox = Mailbox.create( new_mailbox: Mailbox = Mailbox.create(
email=email, user_id=user.id, verified=verified, commit=True email=email, user_id=user.id, verified=verified, commit=True
) )
emit_user_audit_log(
user=user,
action=UserAuditLogAction.CreateMailbox,
message=f"Create mailbox {new_mailbox.id} ({new_mailbox.email}). Verified={verified}",
commit=True,
)
if verified: if verified:
LOG.i(f"User {user} as created a pre-verified mailbox with {email}") LOG.i(f"User {user} as created a pre-verified mailbox with {email}")
@ -129,7 +137,7 @@ def delete_mailbox(
if not transfer_mailbox.verified: if not transfer_mailbox.verified:
LOG.i(f"User {user} has tried to transfer to a non verified mailbox") LOG.i(f"User {user} has tried to transfer to a non verified mailbox")
MailboxError("Your new mailbox is not verified") raise MailboxError("Your new mailbox is not verified")
# Schedule delete account job # Schedule delete account job
LOG.i( LOG.i(
@ -204,6 +212,11 @@ def verify_mailbox_code(user: User, mailbox_id: int, code: str) -> Mailbox:
raise CannotVerifyError("Invalid activation code") raise CannotVerifyError("Invalid activation code")
LOG.i(f"User {user} has verified mailbox {mailbox_id}") LOG.i(f"User {user} has verified mailbox {mailbox_id}")
mailbox.verified = True mailbox.verified = True
emit_user_audit_log(
user=user,
action=UserAuditLogAction.VerifyMailbox,
message=f"Verify mailbox {mailbox_id} ({mailbox.email})",
)
clear_activation_codes_for_mailbox(mailbox) clear_activation_codes_for_mailbox(mailbox)
return mailbox return mailbox
@ -261,3 +274,54 @@ def send_verification_email(
mailbox_email=mailbox.email, mailbox_email=mailbox.email,
), ),
) )
class MailboxEmailChangeError(Enum):
InvalidId = 1
EmailAlreadyUsed = 2
@dataclasses.dataclass
class MailboxEmailChangeResult:
error: Optional[MailboxEmailChangeError]
message: str
message_category: str
def perform_mailbox_email_change(mailbox_id: int) -> MailboxEmailChangeResult:
mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
# new_email can be None if user cancels change in the meantime
if mailbox and mailbox.new_email:
user = mailbox.user
if Mailbox.get_by(email=mailbox.new_email, user_id=user.id):
return MailboxEmailChangeResult(
error=MailboxEmailChangeError.EmailAlreadyUsed,
message=f"{mailbox.new_email} is already used",
message_category="error",
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Change mailbox email for mailbox {mailbox_id} (old={mailbox.email} | new={mailbox.new_email})",
)
mailbox.email = mailbox.new_email
mailbox.new_email = None
# mark mailbox as verified if the change request is sent from an unverified mailbox
mailbox.verified = True
Session.commit()
LOG.d("Mailbox change %s is verified", mailbox)
return MailboxEmailChangeResult(
error=None,
message=f"The {mailbox.email} is updated",
message_category="success",
)
else:
return MailboxEmailChangeResult(
error=MailboxEmailChangeError.InvalidId,
message="Invalid link",
message_category="error",
)

View File

@ -1673,6 +1673,7 @@ class Alias(Base, ModelMixin):
Session.flush() Session.flush()
# Internal import to avoid global import cycles # Internal import to avoid global import cycles
from app.alias_audit_log_utils import AliasAuditLogAction, emit_alias_audit_log
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent from app.events.generated.event_pb2 import AliasCreated, EventContent
@ -1684,6 +1685,9 @@ class Alias(Base, ModelMixin):
created_at=int(new_alias.created_at.timestamp), created_at=int(new_alias.created_at.timestamp),
) )
EventDispatcher.send_event(user, EventContent(alias_created=event)) EventDispatcher.send_event(user, EventContent(alias_created=event))
emit_alias_audit_log(
new_alias, AliasAuditLogAction.CreateAlias, "New alias created"
)
return new_alias return new_alias
@ -3767,15 +3771,14 @@ class SyncEvent(Base, ModelMixin):
sa.Index("ix_sync_event_taken_time", "taken_time"), sa.Index("ix_sync_event_taken_time", "taken_time"),
) )
def mark_as_taken(self) -> bool: def mark_as_taken(self, allow_taken_older_than: Optional[Arrow] = None) -> bool:
sql = """ taken_condition = ["taken_time IS NULL"]
UPDATE sync_event
SET taken_time = :taken_time
WHERE id = :sync_event_id
AND taken_time IS NULL
"""
args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id} args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id}
if allow_taken_older_than:
taken_condition.append("taken_time < :taken_older_than")
args["taken_older_than"] = allow_taken_older_than.datetime
sql_taken_condition = "({})".format(" OR ".join(taken_condition))
sql = f"UPDATE sync_event SET taken_time = :taken_time WHERE id = :sync_event_id AND {sql_taken_condition}"
res = Session.execute(sql, args) res = Session.execute(sql, args)
Session.commit() Session.commit()
@ -3801,3 +3804,39 @@ class SyncEvent(Base, ModelMixin):
.limit(100) .limit(100)
.all() .all()
) )
class AliasAuditLog(Base, ModelMixin):
"""This model holds an audit log for all the actions performed to an alias"""
__tablename__ = "alias_audit_log"
user_id = sa.Column(sa.Integer, nullable=False)
alias_id = sa.Column(sa.Integer, nullable=False)
alias_email = sa.Column(sa.String(255), nullable=False)
action = sa.Column(sa.String(255), nullable=False)
message = sa.Column(sa.Text, default=None, nullable=True)
__table_args__ = (
sa.Index("ix_alias_audit_log_user_id", "user_id"),
sa.Index("ix_alias_audit_log_alias_id", "alias_id"),
sa.Index("ix_alias_audit_log_alias_email", "alias_email"),
sa.Index("ix_alias_audit_log_created_at", "created_at"),
)
class UserAuditLog(Base, ModelMixin):
"""This model holds an audit log for all the actions performed by a user"""
__tablename__ = "user_audit_log"
user_id = sa.Column(sa.Integer, nullable=False)
user_email = sa.Column(sa.String(255), nullable=False)
action = sa.Column(sa.String(255), nullable=False)
message = sa.Column(sa.Text, default=None, nullable=True)
__table_args__ = (
sa.Index("ix_user_audit_log_user_id", "user_id"),
sa.Index("ix_user_audit_log_user_email", "user_email"),
sa.Index("ix_user_audit_log_created_at", "created_at"),
)

View File

@ -0,0 +1,46 @@
from typing import Optional
from arrow import Arrow
from app.models import PartnerUser, PartnerSubscription, User
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
def create_partner_user(
user: User, partner_id: int, partner_email: str, external_user_id: str
) -> PartnerUser:
instance = PartnerUser.create(
user_id=user.id,
partner_id=partner_id,
partner_email=partner_email,
external_user_id=external_user_id,
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.LinkAccount,
message=f"Linked account to partner_id={partner_id} | partner_email={partner_email} | external_user_id={external_user_id}",
)
return instance
def create_partner_subscription(
partner_user: PartnerUser,
expiration: Optional[Arrow],
msg: Optional[str] = None,
) -> PartnerSubscription:
instance = PartnerSubscription.create(
partner_user_id=partner_user.id,
end_at=expiration,
)
message = "User upgraded through partner subscription"
if msg:
message += f" | {msg}"
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.Upgrade,
message=message,
)
return instance

View File

View File

@ -0,0 +1,121 @@
from typing import Optional
import arrow
from coinbase_commerce.error import WebhookInvalidPayload, SignatureVerificationError
from coinbase_commerce.webhook import Webhook
from flask import Flask, request
from app.config import COINBASE_WEBHOOK_SECRET
from app.db import Session
from app.email_utils import send_email, render
from app.log import LOG
from app.models import CoinbaseSubscription, User
from app.subscription_webhook import execute_subscription_webhook
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
def setup_coinbase_commerce(app: Flask):
@app.route("/coinbase", methods=["POST"])
def coinbase_webhook():
# event payload
request_data = request.data.decode("utf-8")
# webhook signature
request_sig = request.headers.get("X-CC-Webhook-Signature", None)
try:
# signature verification and event object construction
event = Webhook.construct_event(
request_data, request_sig, COINBASE_WEBHOOK_SECRET
)
except (WebhookInvalidPayload, SignatureVerificationError) as e:
LOG.e("Invalid Coinbase webhook")
return str(e), 400
LOG.d("Coinbase event %s", event)
if event["type"] == "charge:confirmed":
if handle_coinbase_event(event):
return "success", 200
else:
return "error", 400
return "success", 200
def handle_coinbase_event(event) -> bool:
server_user_id = event["data"]["metadata"]["user_id"]
try:
user_id = int(server_user_id)
except ValueError:
user_id = int(float(server_user_id))
code = event["data"]["code"]
user: Optional[User] = User.get(user_id)
if not user:
LOG.e("User not found %s", user_id)
return False
coinbase_subscription: CoinbaseSubscription = CoinbaseSubscription.get_by(
user_id=user_id
)
if not coinbase_subscription:
LOG.d("Create a coinbase subscription for %s", user)
coinbase_subscription = CoinbaseSubscription.create(
user_id=user_id, end_at=arrow.now().shift(years=1), code=code, commit=True
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.Upgrade,
message="Upgraded though Coinbase",
commit=True,
)
send_email(
user.email,
"Your SimpleLogin account has been upgraded",
render(
"transactional/coinbase/new-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/new-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
else:
if coinbase_subscription.code != code:
LOG.d("Update code from %s to %s", coinbase_subscription.code, code)
coinbase_subscription.code = code
if coinbase_subscription.is_active():
coinbase_subscription.end_at = coinbase_subscription.end_at.shift(years=1)
else: # already expired subscription
coinbase_subscription.end_at = arrow.now().shift(years=1)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended coinbase subscription",
)
Session.commit()
send_email(
user.email,
"Your SimpleLogin account has been extended",
render(
"transactional/coinbase/extend-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/extend-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
execute_subscription_webhook(user)
return True

286
app/app/payments/paddle.py Normal file
View File

@ -0,0 +1,286 @@
import arrow
import json
from dateutil.relativedelta import relativedelta
from flask import Flask, request
from app import paddle_utils, paddle_callback
from app.config import (
PADDLE_MONTHLY_PRODUCT_ID,
PADDLE_MONTHLY_PRODUCT_IDS,
PADDLE_YEARLY_PRODUCT_IDS,
PADDLE_COUPON_ID,
)
from app.db import Session
from app.email_utils import send_email, render
from app.log import LOG
from app.models import Subscription, PlanEnum, User, Coupon
from app.subscription_webhook import execute_subscription_webhook
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string
def setup_paddle_callback(app: Flask):
@app.route("/paddle", methods=["GET", "POST"])
def paddle():
LOG.d(f"paddle callback {request.form.get('alert_name')} {request.form}")
# make sure the request comes from Paddle
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
if (
request.form.get("alert_name") == "subscription_created"
): # new user subscribes
# the passthrough is json encoded, e.g.
# request.form.get("passthrough") = '{"user_id": 88 }'
passthrough = json.loads(request.form.get("passthrough"))
user_id = passthrough.get("user_id")
user = User.get(user_id)
subscription_plan_id = int(request.form.get("subscription_plan_id"))
if subscription_plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
plan = PlanEnum.monthly
elif subscription_plan_id in PADDLE_YEARLY_PRODUCT_IDS:
plan = PlanEnum.yearly
else:
LOG.e(
"Unknown subscription_plan_id %s %s",
subscription_plan_id,
request.form,
)
return "No such subscription", 400
sub = Subscription.get_by(user_id=user.id)
if not sub:
LOG.d(f"create a new Subscription for user {user}")
Subscription.create(
user_id=user.id,
cancel_url=request.form.get("cancel_url"),
update_url=request.form.get("update_url"),
subscription_id=request.form.get("subscription_id"),
event_time=arrow.now(),
next_bill_date=arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date(),
plan=plan,
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.Upgrade,
message="Upgraded through Paddle",
)
else:
LOG.d(f"Update an existing Subscription for user {user}")
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.subscription_id = request.form.get("subscription_id")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
# in case user cancels a plan and subscribes a new plan
sub.cancelled = False
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended Paddle subscription",
)
execute_subscription_webhook(user)
LOG.d("User %s upgrades!", user)
Session.commit()
elif request.form.get("alert_name") == "subscription_payment_succeeded":
subscription_id = request.form.get("subscription_id")
LOG.d("Update subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
# when user subscribes, the "subscription_payment_succeeded" can arrive BEFORE "subscription_created"
# at that time, subscription object does not exist yet
if sub:
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
Session.commit()
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_cancelled":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
# cancellation_effective_date should be the same as next_bill_date
LOG.w(
"Cancel subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
sub.event_time = arrow.now()
sub.cancelled = True
emit_user_audit_log(
user=sub.user,
action=UserAuditLogAction.SubscriptionCancelled,
message="Cancelled Paddle subscription",
)
Session.commit()
user = sub.user
send_email(
user.email,
"SimpleLogin - your subscription is canceled",
render(
"transactional/subscription-cancel.txt",
user=user,
end_date=request.form.get("cancellation_effective_date"),
),
)
execute_subscription_webhook(sub.user)
else:
# user might have deleted their account
LOG.i(f"Cancel non-exist subscription {subscription_id}")
return "OK"
elif request.form.get("alert_name") == "subscription_updated":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
next_bill_date = request.form.get("next_bill_date")
if not next_bill_date:
paddle_callback.failed_payment(sub, subscription_id)
return "OK"
LOG.d(
"Update subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
if (
int(request.form.get("subscription_plan_id"))
== PADDLE_MONTHLY_PRODUCT_ID
):
plan = PlanEnum.monthly
else:
plan = PlanEnum.yearly
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
sub.cancelled = False
emit_user_audit_log(
user=sub.user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended Paddle subscription",
)
Session.commit()
execute_subscription_webhook(sub.user)
else:
LOG.w(
f"update non-exist subscription {subscription_id}. {request.form}"
)
return "No such subscription", 400
elif request.form.get("alert_name") == "payment_refunded":
subscription_id = request.form.get("subscription_id")
LOG.d("Refund request for subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
user = sub.user
Subscription.delete(sub.id)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionCancelled,
message="Paddle subscription cancelled as user requested a refund",
)
Session.commit()
LOG.e("%s requests a refund", user)
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_payment_refunded":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
LOG.d(
"Handle subscription_payment_refunded for subscription %s",
subscription_id,
)
if not sub:
LOG.w(
"No such subscription for %s, payload %s",
subscription_id,
request.form,
)
return "No such subscription"
plan_id = int(request.form["subscription_plan_id"])
if request.form["refund_type"] == "full":
if plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
LOG.d("subtract 1 month from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(months=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
elif plan_id in PADDLE_YEARLY_PRODUCT_IDS:
LOG.d("subtract 1 year from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(years=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
else:
LOG.e("Unknown plan_id %s", plan_id)
else:
LOG.w("partial subscription_payment_refunded, not handled")
execute_subscription_webhook(sub.user)
return "OK"
@app.route("/paddle_coupon", methods=["GET", "POST"])
def paddle_coupon():
LOG.d("paddle coupon callback %s", request.form)
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
product_id = request.form.get("p_product_id")
if product_id != PADDLE_COUPON_ID:
LOG.e("product_id %s not match with %s", product_id, PADDLE_COUPON_ID)
return "KO", 400
email = request.form.get("email")
LOG.d("Paddle coupon request for %s", email)
coupon = Coupon.create(
code=random_string(30),
comment="For 1-year coupon",
expires_date=arrow.now().shift(years=1, days=-1),
commit=True,
)
return (
f"Your 1-year coupon is <b>{coupon.code}</b> <br> "
f"It's valid until <b>{coupon.expires_date.date().isoformat()}</b>"
)

View File

@ -5,6 +5,7 @@ from app.db import Session
from app.log import LOG from app.log import LOG
from app.errors import ProtonPartnerNotSetUp from app.errors import ProtonPartnerNotSetUp
from app.models import Partner, PartnerUser, User from app.models import Partner, PartnerUser, User
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
PROTON_PARTNER_NAME = "Proton" PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None _PROTON_PARTNER: Optional[Partner] = None
@ -32,6 +33,11 @@ def perform_proton_account_unlink(current_user: User):
) )
if partner_user is not None: if partner_user is not None:
LOG.info(f"User {current_user} has unlinked the account from {partner_user}") LOG.info(f"User {current_user} has unlinked the account from {partner_user}")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UnlinkAccount,
message=f"User has unlinked the account (email={partner_user.partner_email} | external_user_id={partner_user.external_user_id})",
)
PartnerUser.delete(partner_user.id) PartnerUser.delete(partner_user.id)
Session.commit() Session.commit()
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name}) agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})

View File

@ -1,40 +1,16 @@
import requests
from requests import RequestException
from app import config
from app.db import Session from app.db import Session
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserPlanChanged from app.events.generated.event_pb2 import EventContent, UserPlanChanged
from app.log import LOG
from app.models import User from app.models import User
def execute_subscription_webhook(user: User): def execute_subscription_webhook(user: User):
webhook_url = config.SUBSCRIPTION_CHANGE_WEBHOOK
if webhook_url is None:
return
subscription_end = user.get_active_subscription_end( subscription_end = user.get_active_subscription_end(
include_partner_subscription=False include_partner_subscription=False
) )
sl_subscription_end = None sl_subscription_end = None
if subscription_end: if subscription_end:
sl_subscription_end = subscription_end.timestamp sl_subscription_end = subscription_end.timestamp
payload = {
"user_id": user.id,
"is_premium": user.is_premium(),
"active_subscription_end": sl_subscription_end,
}
try:
response = requests.post(webhook_url, json=payload, timeout=2)
if response.status_code == 200:
LOG.i("Sent request to subscription update webhook successfully")
else:
LOG.i(
f"Request to webhook failed with status {response.status_code}: {response.text}"
)
except RequestException as e:
LOG.error(f"Subscription request exception: {e}")
event = UserPlanChanged(plan_end_time=sl_subscription_end) event = UserPlanChanged(plan_end_time=sl_subscription_end)
EventDispatcher.send_event(user, EventContent(user_plan_change=event)) EventDispatcher.send_event(user, EventContent(user_plan_change=event))
Session.commit() Session.commit()

View File

@ -0,0 +1,40 @@
from enum import Enum
from app.models import User, UserAuditLog
class UserAuditLogAction(Enum):
Upgrade = "upgrade"
SubscriptionExtended = "subscription_extended"
SubscriptionCancelled = "subscription_cancelled"
LinkAccount = "link_account"
UnlinkAccount = "unlink_account"
CreateMailbox = "create_mailbox"
VerifyMailbox = "verify_mailbox"
UpdateMailbox = "update_mailbox"
DeleteMailbox = "delete_mailbox"
CreateCustomDomain = "create_custom_domain"
VerifyCustomDomain = "verify_custom_domain"
UpdateCustomDomain = "update_custom_domain"
DeleteCustomDomain = "delete_custom_domain"
CreateDirectory = "create_directory"
UpdateDirectory = "update_directory"
DeleteDirectory = "delete_directory"
UserMarkedForDeletion = "user_marked_for_deletion"
DeleteUser = "delete_user"
def emit_user_audit_log(
user: User, action: UserAuditLogAction, message: str, commit: bool = False
):
UserAuditLog.create(
user_id=user.id,
user_email=user.email,
action=action.value,
message=message,
commit=commit,
)

View File

@ -3,6 +3,7 @@ from typing import Optional
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import User, SLDomain, CustomDomain, Mailbox from app.models import User, SLDomain, CustomDomain, Mailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class CannotSetAlias(Exception): class CannotSetAlias(Exception):
@ -54,7 +55,7 @@ def set_default_alias_domain(user: User, domain_name: Optional[str]):
def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox: def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox:
mailbox = Mailbox.get(mailbox_id) mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != user.id: if not mailbox or mailbox.user_id != user.id:
raise CannotSetMailbox("Invalid mailbox") raise CannotSetMailbox("Invalid mailbox")
@ -67,5 +68,11 @@ def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox:
LOG.i(f"User {user} has set mailbox {mailbox} as his default one") LOG.i(f"User {user} has set mailbox {mailbox} as his default one")
user.default_mailbox_id = mailbox.id user.default_mailbox_id = mailbox.id
emit_user_audit_log(
user=user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set mailbox {mailbox.id} ({mailbox.email}) as default",
)
Session.commit() Session.commit()
return mailbox return mailbox

View File

@ -60,8 +60,11 @@ from app.models import (
) )
from app.pgp_utils import load_public_key_and_check, PGPException from app.pgp_utils import load_public_key_and_check, PGPException
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email from app.utils import sanitize_email
from server import create_light_app from server import create_light_app
from tasks.clean_alias_audit_log import cleanup_alias_audit_log
from tasks.clean_user_audit_log import cleanup_user_audit_log
from tasks.cleanup_old_imports import cleanup_old_imports from tasks.cleanup_old_imports import cleanup_old_imports
from tasks.cleanup_old_jobs import cleanup_old_jobs from tasks.cleanup_old_jobs import cleanup_old_jobs
from tasks.cleanup_old_notifications import cleanup_old_notifications from tasks.cleanup_old_notifications import cleanup_old_notifications
@ -1218,7 +1221,7 @@ def notify_hibp():
def clear_users_scheduled_to_be_deleted(dry_run=False): def clear_users_scheduled_to_be_deleted(dry_run=False):
users = User.filter( users: List[User] = User.filter(
and_( and_(
User.delete_on.isnot(None), User.delete_on.isnot(None),
User.delete_on <= arrow.now().shift(days=-DELETE_GRACE_DAYS), User.delete_on <= arrow.now().shift(days=-DELETE_GRACE_DAYS),
@ -1230,6 +1233,11 @@ def clear_users_scheduled_to_be_deleted(dry_run=False):
) )
if dry_run: if dry_run:
continue continue
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteUser,
message=f"Delete user {user.id} ({user.email})",
)
User.delete(user.id) User.delete(user.id)
Session.commit() Session.commit()
@ -1241,6 +1249,16 @@ def delete_old_data():
cleanup_old_notifications(oldest_valid) cleanup_old_notifications(oldest_valid)
def clear_alias_audit_log():
oldest_valid = arrow.now().shift(days=-config.AUDIT_LOG_MAX_DAYS)
cleanup_alias_audit_log(oldest_valid)
def clear_user_audit_log():
oldest_valid = arrow.now().shift(days=-config.AUDIT_LOG_MAX_DAYS)
cleanup_user_audit_log(oldest_valid)
if __name__ == "__main__": if __name__ == "__main__":
LOG.d("Start running cronjob") LOG.d("Start running cronjob")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -1249,22 +1267,6 @@ if __name__ == "__main__":
"--job", "--job",
help="Choose a cron job to run", help="Choose a cron job to run",
type=str, type=str,
choices=[
"stats",
"notify_trial_end",
"notify_manual_subscription_end",
"notify_premium_end",
"delete_logs",
"delete_old_data",
"poll_apple_subscription",
"sanity_check",
"delete_old_monitoring",
"check_custom_domain",
"check_hibp",
"notify_hibp",
"cleanup_tokens",
"send_undelivered_mails",
],
) )
args = parser.parse_args() args = parser.parse_args()
# wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc # wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc
@ -1313,4 +1315,10 @@ if __name__ == "__main__":
load_unsent_mails_from_fs_and_resend() load_unsent_mails_from_fs_and_resend()
elif args.job == "delete_scheduled_users": elif args.job == "delete_scheduled_users":
LOG.d("Deleting users scheduled to be deleted") LOG.d("Deleting users scheduled to be deleted")
clear_users_scheduled_to_be_deleted(dry_run=True) clear_users_scheduled_to_be_deleted()
elif args.job == "clear_alias_audit_log":
LOG.d("Clearing alias audit log")
clear_alias_audit_log()
elif args.job == "clear_user_audit_log":
LOG.d("Clearing user audit log")
clear_user_audit_log()

View File

@ -80,3 +80,17 @@ jobs:
schedule: "*/5 * * * *" schedule: "*/5 * * * *"
captureStderr: true captureStderr: true
concurrencyPolicy: Forbid concurrencyPolicy: Forbid
- name: SimpleLogin clear alias_audit_log old entries
command: python /code/cron.py -j clear_alias_audit_log
shell: /bin/bash
schedule: "0 * * * *" # Once every hour
captureStderr: true
concurrencyPolicy: Forbid
- name: SimpleLogin clear user_audit_log old entries
command: python /code/cron.py -j clear_user_audit_log
shell: /bin/bash
schedule: "0 * * * *" # Once every hour
captureStderr: true
concurrencyPolicy: Forbid

View File

@ -1502,7 +1502,9 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog):
LOG.w( LOG.w(
f"Disable alias {alias} because {reason}. {alias.mailboxes} {alias.user}. Last contact {contact}" f"Disable alias {alias} because {reason}. {alias.mailboxes} {alias.user}. Last contact {contact}"
) )
change_alias_status(alias, enabled=False) change_alias_status(
alias, enabled=False, message=f"Set enabled=False due to {reason}"
)
Notification.create( Notification.create(
user_id=user.id, user_id=user.id,

View File

@ -85,24 +85,28 @@ class DeadLetterEventSource(EventSource):
def __init__(self, max_retries: int): def __init__(self, max_retries: int):
self.__max_retries = max_retries self.__max_retries = max_retries
@newrelic.agent.background_task() def execute_loop(
def run(self, on_event: Callable[[SyncEvent], NoReturn]): self, on_event: Callable[[SyncEvent], NoReturn]
while True: ) -> list[SyncEvent]:
try: threshold = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES)
threshold = arrow.utcnow().shift(
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
)
events = SyncEvent.get_dead_letter( events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries older_than=threshold, max_retries=self.__max_retries
) )
if events: if events:
LOG.info(f"Got {len(events)} dead letter events") LOG.info(f"Got {len(events)} dead letter events")
if events:
newrelic.agent.record_custom_metric( newrelic.agent.record_custom_metric(
"Custom/dead_letter_events_to_process", len(events) "Custom/dead_letter_events_to_process", len(events)
) )
for event in events: for event in events:
if event.mark_as_taken(allow_taken_older_than=threshold):
on_event(event) on_event(event)
return events
@newrelic.agent.background_task()
def run(self, on_event: Callable[[SyncEvent], NoReturn]):
while True:
try:
events = self.execute_loop(on_event)
Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
if not events: if not events:
LOG.debug("No dead letter events") LOG.debug("No dead letter events")

View File

@ -20,6 +20,7 @@ from app.jobs.event_jobs import send_alias_creation_events_for_user
from app.jobs.export_user_data_job import ExportUserDataJob from app.jobs.export_user_data_job import ExportUserDataJob
from app.log import LOG from app.log import LOG
from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from server import create_light_app from server import create_light_app
@ -128,7 +129,7 @@ def welcome_proton(user):
def delete_mailbox_job(job: Job): def delete_mailbox_job(job: Job):
mailbox_id = job.payload.get("mailbox_id") mailbox_id = job.payload.get("mailbox_id")
mailbox = Mailbox.get(mailbox_id) mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox: if not mailbox:
return return
@ -152,6 +153,12 @@ def delete_mailbox_job(job: Job):
mailbox_email = mailbox.email mailbox_email = mailbox.email
user = mailbox.user user = mailbox.user
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteMailbox,
message=f"Delete mailbox {mailbox.id} ({mailbox.email})",
)
Mailbox.delete(mailbox_id) Mailbox.delete(mailbox_id)
Session.commit() Session.commit()
LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email) LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email)
@ -244,6 +251,7 @@ def process_job(job: Job):
if not custom_domain: if not custom_domain:
return return
is_subdomain = custom_domain.is_sl_subdomain
domain_name = custom_domain.domain domain_name = custom_domain.domain
user = custom_domain.user user = custom_domain.user
@ -251,6 +259,16 @@ def process_job(job: Job):
CustomDomain.delete(custom_domain.id) CustomDomain.delete(custom_domain.id)
Session.commit() Session.commit()
if is_subdomain:
message = f"Delete subdomain {custom_domain_id} ({domain_name})"
else:
message = f"Delete custom domain {custom_domain_id} ({domain_name})"
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteCustomDomain,
message=message,
)
LOG.d("Domain %s deleted", domain_name) LOG.d("Domain %s deleted", domain_name)
if custom_domain_partner_id is None: if custom_domain_partner_id is None:

View File

@ -0,0 +1,45 @@
"""alias_audit_log
Revision ID: 91ed7f46dc81
Revises: 62afa3a10010
Create Date: 2024-10-11 13:22:11.594054
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '91ed7f46dc81'
down_revision = '62afa3a10010'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('alias_audit_log',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias_email', sa.String(length=255), nullable=False),
sa.Column('action', sa.String(length=255), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_alias_audit_log_alias_email', 'alias_audit_log', ['alias_email'], unique=False)
op.create_index('ix_alias_audit_log_alias_id', 'alias_audit_log', ['alias_id'], unique=False)
op.create_index('ix_alias_audit_log_user_id', 'alias_audit_log', ['user_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_alias_audit_log_user_id', table_name='alias_audit_log')
op.drop_index('ix_alias_audit_log_alias_id', table_name='alias_audit_log')
op.drop_index('ix_alias_audit_log_alias_email', table_name='alias_audit_log')
op.drop_table('alias_audit_log')
# ### end Alembic commands ###

View File

@ -0,0 +1,44 @@
"""user_audit_log
Revision ID: 7d7b84779837
Revises: 91ed7f46dc81
Create Date: 2024-10-16 11:52:49.128644
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '7d7b84779837'
down_revision = '91ed7f46dc81'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('user_audit_log',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('user_email', sa.String(length=255), nullable=False),
sa.Column('action', sa.String(length=255), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_user_audit_log_user_email', 'user_audit_log', ['user_email'], unique=False)
op.create_index('ix_user_audit_log_user_id', 'user_audit_log', ['user_id'], unique=False)
op.create_index('ix_user_audit_log_created_at', 'user_audit_log', ['created_at'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_user_audit_log_user_id', table_name='user_audit_log')
op.drop_index('ix_user_audit_log_user_email', table_name='user_audit_log')
op.drop_index('ix_user_audit_log_created_at', table_name='user_audit_log')
op.drop_table('user_audit_log')
# ### end Alembic commands ###

View File

@ -0,0 +1,27 @@
"""alias_audit_log_index_created_at
Revision ID: 32f25cbf12f6
Revises: 7d7b84779837
Create Date: 2024-10-16 16:45:36.827161
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '32f25cbf12f6'
down_revision = '7d7b84779837'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_alias_audit_log_created_at', 'alias_audit_log', ['created_at'], unique=False, postgresql_concurrently=True)
def downgrade():
with op.get_context().autocommit_block():
op.drop_index('ix_alias_audit_log_created_at', table_name='alias_audit_log', postgresql_concurrently=True)

View File

@ -21,7 +21,7 @@ if max_alias_id == 0:
max_alias_id = Session.query(func.max(Alias.id)).scalar() max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Checking alias {alias_id_start} to {max_alias_id}") print(f"Checking alias {alias_id_start} to {max_alias_id}")
step = 1000 step = 10000
noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')" noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')"
alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}" alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}"
updated = 0 updated = 0
@ -38,12 +38,12 @@ for batch_start in range(alias_id_start, max_alias_id, step):
updated += rows_done.rowcount updated += rows_done.rowcount
Session.commit() Session.commit()
elapsed = time.time() - start_time elapsed = time.time() - start_time
time_per_alias = elapsed / (updated + 1)
last_batch_id = batch_start + step last_batch_id = batch_start + step
time_per_alias = elapsed / (last_batch_id)
remaining = max_alias_id - last_batch_id remaining = max_alias_id - last_batch_id
time_remaining = (max_alias_id - last_batch_id) * time_per_alias time_remaining = remaining / time_per_alias
hours_remaining = time_remaining / 3600.0 hours_remaining = time_remaining / 60.0
print( print(
f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f}hrs remaining" f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f} mins remaining"
) )
print("") print("")

View File

@ -0,0 +1,63 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import UserPlanChanged, EventContent
from app.models import PartnerUser
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
)
parser.add_argument(
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
)
args = parser.parse_args()
pu_id_start = args.start_pu_id
max_pu_id = args.end_pu_id
if max_pu_id == 0:
max_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
print(f"Checking partner user {pu_id_start} to {max_pu_id}")
step = 100
updated = 0
start_time = time.time()
with_premium = 0
for batch_start in range(pu_id_start, max_pu_id, step):
partner_users = (
Session.query(PartnerUser).filter(
PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step
)
).all()
for partner_user in partner_users:
subscription_end = partner_user.user.get_active_subscription_end(
include_partner_subscription=False
)
end_timestamp = None
if subscription_end:
with_premium += 1
end_timestamp = subscription_end.timestamp
event = UserPlanChanged(plan_end_time=end_timestamp)
EventDispatcher.send_event(
partner_user.user, EventContent(user_plan_change=event)
)
Session.flush()
updated += 1
Session.commit()
elapsed = time.time() - start_time
last_batch_id = batch_start + step
time_per_alias = elapsed / (last_batch_id)
remaining = max_pu_id - last_batch_id
time_remaining = remaining / time_per_alias
hours_remaining = time_remaining / 60.0
print(
f"\PartnerUser {batch_start}/{max_pu_id} {updated} {hours_remaining:.2f} mins remaining"
)
print(f"With SL premium {with_premium}")

View File

@ -1,4 +1,3 @@
import json
import os import os
import time import time
from datetime import timedelta from datetime import timedelta
@ -9,9 +8,7 @@ import flask_limiter
import flask_profiler import flask_profiler
import newrelic.agent import newrelic.agent
import sentry_sdk import sentry_sdk
from coinbase_commerce.error import WebhookInvalidPayload, SignatureVerificationError
from coinbase_commerce.webhook import Webhook
from dateutil.relativedelta import relativedelta
from flask import ( from flask import (
Flask, Flask,
redirect, redirect,
@ -30,7 +27,7 @@ from sentry_sdk.integrations.flask import FlaskIntegration
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.middleware.proxy_fix import ProxyFix
from app import paddle_utils, config, paddle_callback, constants from app import config, constants
from app.admin_model import ( from app.admin_model import (
SLAdminIndexView, SLAdminIndexView,
UserAdmin, UserAdmin,
@ -56,7 +53,6 @@ from app.config import (
FLASK_SECRET, FLASK_SECRET,
SENTRY_DSN, SENTRY_DSN,
URL, URL,
PADDLE_MONTHLY_PRODUCT_ID,
FLASK_PROFILER_PATH, FLASK_PROFILER_PATH,
FLASK_PROFILER_PASSWORD, FLASK_PROFILER_PASSWORD,
SENTRY_FRONT_END_DSN, SENTRY_FRONT_END_DSN,
@ -70,22 +66,16 @@ from app.config import (
LANDING_PAGE_URL, LANDING_PAGE_URL,
STATUS_PAGE_URL, STATUS_PAGE_URL,
SUPPORT_EMAIL, SUPPORT_EMAIL,
PADDLE_MONTHLY_PRODUCT_IDS,
PADDLE_YEARLY_PRODUCT_IDS,
PGP_SIGNER, PGP_SIGNER,
COINBASE_WEBHOOK_SECRET,
PAGE_LIMIT, PAGE_LIMIT,
PADDLE_COUPON_ID,
ZENDESK_ENABLED, ZENDESK_ENABLED,
MAX_NB_EMAIL_FREE_PLAN, MAX_NB_EMAIL_FREE_PLAN,
MEM_STORE_URI, MEM_STORE_URI,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.subscription_webhook import execute_subscription_webhook
from app.db import Session from app.db import Session
from app.developer.base import developer_bp from app.developer.base import developer_bp
from app.discover.base import discover_bp from app.discover.base import discover_bp
from app.email_utils import send_email, render
from app.extensions import login_manager, limiter from app.extensions import login_manager, limiter
from app.fake_data import fake_data from app.fake_data import fake_data
from app.internal.base import internal_bp from app.internal.base import internal_bp
@ -94,11 +84,8 @@ from app.log import LOG
from app.models import ( from app.models import (
User, User,
Alias, Alias,
Subscription,
PlanEnum,
CustomDomain, CustomDomain,
Mailbox, Mailbox,
CoinbaseSubscription,
EmailLog, EmailLog,
Contact, Contact,
ManualSubscription, ManualSubscription,
@ -115,10 +102,11 @@ from app.monitor.base import monitor_bp
from app.newsletter_utils import send_newsletter_to_user from app.newsletter_utils import send_newsletter_to_user
from app.oauth.base import oauth_bp from app.oauth.base import oauth_bp
from app.onboarding.base import onboarding_bp from app.onboarding.base import onboarding_bp
from app.payments.coinbase import setup_coinbase_commerce
from app.payments.paddle import setup_paddle_callback
from app.phone.base import phone_bp from app.phone.base import phone_bp
from app.redis_services import initialize_redis_services from app.redis_services import initialize_redis_services
from app.sentry_utils import sentry_before_send from app.sentry_utils import sentry_before_send
from app.utils import random_string
if SENTRY_DSN: if SENTRY_DSN:
LOG.d("enable sentry") LOG.d("enable sentry")
@ -446,341 +434,6 @@ def jinja2_filter(app):
) )
def setup_paddle_callback(app: Flask):
@app.route("/paddle", methods=["GET", "POST"])
def paddle():
LOG.d(f"paddle callback {request.form.get('alert_name')} {request.form}")
# make sure the request comes from Paddle
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
if (
request.form.get("alert_name") == "subscription_created"
): # new user subscribes
# the passthrough is json encoded, e.g.
# request.form.get("passthrough") = '{"user_id": 88 }'
passthrough = json.loads(request.form.get("passthrough"))
user_id = passthrough.get("user_id")
user = User.get(user_id)
subscription_plan_id = int(request.form.get("subscription_plan_id"))
if subscription_plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
plan = PlanEnum.monthly
elif subscription_plan_id in PADDLE_YEARLY_PRODUCT_IDS:
plan = PlanEnum.yearly
else:
LOG.e(
"Unknown subscription_plan_id %s %s",
subscription_plan_id,
request.form,
)
return "No such subscription", 400
sub = Subscription.get_by(user_id=user.id)
if not sub:
LOG.d(f"create a new Subscription for user {user}")
Subscription.create(
user_id=user.id,
cancel_url=request.form.get("cancel_url"),
update_url=request.form.get("update_url"),
subscription_id=request.form.get("subscription_id"),
event_time=arrow.now(),
next_bill_date=arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date(),
plan=plan,
)
else:
LOG.d(f"Update an existing Subscription for user {user}")
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.subscription_id = request.form.get("subscription_id")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
# in case user cancels a plan and subscribes a new plan
sub.cancelled = False
execute_subscription_webhook(user)
LOG.d("User %s upgrades!", user)
Session.commit()
elif request.form.get("alert_name") == "subscription_payment_succeeded":
subscription_id = request.form.get("subscription_id")
LOG.d("Update subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
# when user subscribes, the "subscription_payment_succeeded" can arrive BEFORE "subscription_created"
# at that time, subscription object does not exist yet
if sub:
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
Session.commit()
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_cancelled":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
# cancellation_effective_date should be the same as next_bill_date
LOG.w(
"Cancel subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
sub.event_time = arrow.now()
sub.cancelled = True
Session.commit()
user = sub.user
send_email(
user.email,
"SimpleLogin - your subscription is canceled",
render(
"transactional/subscription-cancel.txt",
user=user,
end_date=request.form.get("cancellation_effective_date"),
),
)
execute_subscription_webhook(sub.user)
else:
# user might have deleted their account
LOG.i(f"Cancel non-exist subscription {subscription_id}")
return "OK"
elif request.form.get("alert_name") == "subscription_updated":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
next_bill_date = request.form.get("next_bill_date")
if not next_bill_date:
paddle_callback.failed_payment(sub, subscription_id)
return "OK"
LOG.d(
"Update subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
if (
int(request.form.get("subscription_plan_id"))
== PADDLE_MONTHLY_PRODUCT_ID
):
plan = PlanEnum.monthly
else:
plan = PlanEnum.yearly
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
sub.cancelled = False
Session.commit()
execute_subscription_webhook(sub.user)
else:
LOG.w(
f"update non-exist subscription {subscription_id}. {request.form}"
)
return "No such subscription", 400
elif request.form.get("alert_name") == "payment_refunded":
subscription_id = request.form.get("subscription_id")
LOG.d("Refund request for subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
user = sub.user
Subscription.delete(sub.id)
Session.commit()
LOG.e("%s requests a refund", user)
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_payment_refunded":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
LOG.d(
"Handle subscription_payment_refunded for subscription %s",
subscription_id,
)
if not sub:
LOG.w(
"No such subscription for %s, payload %s",
subscription_id,
request.form,
)
return "No such subscription"
plan_id = int(request.form["subscription_plan_id"])
if request.form["refund_type"] == "full":
if plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
LOG.d("subtract 1 month from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(months=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
elif plan_id in PADDLE_YEARLY_PRODUCT_IDS:
LOG.d("subtract 1 year from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(years=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
else:
LOG.e("Unknown plan_id %s", plan_id)
else:
LOG.w("partial subscription_payment_refunded, not handled")
execute_subscription_webhook(sub.user)
return "OK"
@app.route("/paddle_coupon", methods=["GET", "POST"])
def paddle_coupon():
LOG.d("paddle coupon callback %s", request.form)
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
product_id = request.form.get("p_product_id")
if product_id != PADDLE_COUPON_ID:
LOG.e("product_id %s not match with %s", product_id, PADDLE_COUPON_ID)
return "KO", 400
email = request.form.get("email")
LOG.d("Paddle coupon request for %s", email)
coupon = Coupon.create(
code=random_string(30),
comment="For 1-year coupon",
expires_date=arrow.now().shift(years=1, days=-1),
commit=True,
)
return (
f"Your 1-year coupon is <b>{coupon.code}</b> <br> "
f"It's valid until <b>{coupon.expires_date.date().isoformat()}</b>"
)
def setup_coinbase_commerce(app):
@app.route("/coinbase", methods=["POST"])
def coinbase_webhook():
# event payload
request_data = request.data.decode("utf-8")
# webhook signature
request_sig = request.headers.get("X-CC-Webhook-Signature", None)
try:
# signature verification and event object construction
event = Webhook.construct_event(
request_data, request_sig, COINBASE_WEBHOOK_SECRET
)
except (WebhookInvalidPayload, SignatureVerificationError) as e:
LOG.e("Invalid Coinbase webhook")
return str(e), 400
LOG.d("Coinbase event %s", event)
if event["type"] == "charge:confirmed":
if handle_coinbase_event(event):
return "success", 200
else:
return "error", 400
return "success", 200
def handle_coinbase_event(event) -> bool:
server_user_id = event["data"]["metadata"]["user_id"]
try:
user_id = int(server_user_id)
except ValueError:
user_id = int(float(server_user_id))
code = event["data"]["code"]
user = User.get(user_id)
if not user:
LOG.e("User not found %s", user_id)
return False
coinbase_subscription: CoinbaseSubscription = CoinbaseSubscription.get_by(
user_id=user_id
)
if not coinbase_subscription:
LOG.d("Create a coinbase subscription for %s", user)
coinbase_subscription = CoinbaseSubscription.create(
user_id=user_id, end_at=arrow.now().shift(years=1), code=code, commit=True
)
send_email(
user.email,
"Your SimpleLogin account has been upgraded",
render(
"transactional/coinbase/new-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/new-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
else:
if coinbase_subscription.code != code:
LOG.d("Update code from %s to %s", coinbase_subscription.code, code)
coinbase_subscription.code = code
if coinbase_subscription.is_active():
coinbase_subscription.end_at = coinbase_subscription.end_at.shift(years=1)
else: # already expired subscription
coinbase_subscription.end_at = arrow.now().shift(years=1)
Session.commit()
send_email(
user.email,
"Your SimpleLogin account has been extended",
render(
"transactional/coinbase/extend-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/extend-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
execute_subscription_webhook(user)
return True
def init_extensions(app: Flask): def init_extensions(app: Flask):
login_manager.init_app(app) login_manager.init_app(app)

View File

@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import AliasAuditLog
def cleanup_alias_audit_log(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting alias_audit_log older than {oldest_allowed}")
count = AliasAuditLog.filter(AliasAuditLog.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} alias_audit_log entries")

View File

@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import UserAuditLog
def cleanup_user_audit_log(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting user_audit_log older than {oldest_allowed}")
count = UserAuditLog.filter(UserAuditLog.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} user_audit_log entries")

View File

@ -154,6 +154,56 @@
</table> </table>
{{ show_user(data.domain_deleted_alias.domain.user) }} {{ show_user(data.domain_deleted_alias.domain.user) }}
{%- endmacro %} {%- endmacro %}
{% macro list_alias_audit_log(alias_audit_log) %}
<h4>Alias Audit Log</h4>
<table class="table">
<thead>
<tr>
<th>User ID</th>
<th>Alias ID</th>
<th>Alias Email</th>
<th>Action</th>
<th>Message</th>
<th>Time</th>
</tr>
</thead>
<tbody>
{% for entry in alias_audit_log %}
<tr>
<td>{{ entry.user_id }}</td>
<td>{{ entry.alias_id }}</td>
<td><a href="?email={{ entry.alias_email }}">{{ entry.alias_email }}</a></td>
<td>{{ entry.action }}</td>
<td>{{ entry.message }}</td>
<td>{{ entry.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% macro list_user_audit_log(user_audit_log) %}
<h4>User Audit Log</h4>
<table class="table">
<thead>
<tr>
<th>User email</th>
<th>Action</th>
<th>Message</th>
<th>Time</th>
</tr>
</thead>
<tbody>
{% for entry in user_audit_log %}
<tr>
<td><a href="?email={{ entry.user_email }}">{{ entry.user_email }}</a></td>
<td>{{ entry.action }}</td>
<td>{{ entry.message }}</td>
<td>{{ entry.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% block body %} {% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
@ -177,6 +227,7 @@
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found Alias {{ data.alias.email }}</h3> <h3 class="mb-3">Found Alias {{ data.alias.email }}</h3>
{{ list_alias(1,[data.alias]) }} {{ list_alias(1,[data.alias]) }}
{{ list_alias_audit_log(data.alias_audit_log) }}
{{ list_mailboxes("Mailboxes for alias", helper.alias_mailbox_count(data.alias), helper.alias_mailboxes(data.alias)) }} {{ list_mailboxes("Mailboxes for alias", helper.alias_mailbox_count(data.alias), helper.alias_mailboxes(data.alias)) }}
{{ show_user(data.alias.user) }} {{ show_user(data.alias.user) }}
</div> </div>
@ -188,6 +239,7 @@
{{ show_user(data.user) }} {{ show_user(data.user) }}
{{ list_mailboxes("Mailboxes for user", helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }} {{ list_mailboxes("Mailboxes for user", helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }}
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }} {{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
{{ list_user_audit_log(data.user_audit_log) }}
</div> </div>
{% endif %} {% endif %}
{% if data.mailbox_count > 10 %} {% if data.mailbox_count > 10 %}

View File

@ -0,0 +1,54 @@
import arrow
from app.db import Session
from app.models import SyncEvent
from events.event_source import DeadLetterEventSource, _DEAD_LETTER_THRESHOLD_MINUTES
class EventCounter:
def __init__(self):
self.processed_events = 0
def on_event(self, event: SyncEvent):
self.processed_events += 1
def setup_function(func):
Session.query(SyncEvent).delete()
def test_dead_letter_does_not_take_untaken_events():
source = DeadLetterEventSource(1)
counter = EventCounter()
threshold_time = arrow.utcnow().shift(minutes=-(_DEAD_LETTER_THRESHOLD_MINUTES) + 1)
SyncEvent.create(
content="test".encode("utf-8"), created_at=threshold_time, flush=True
)
SyncEvent.create(
content="test".encode("utf-8"), taken_time=threshold_time, flush=True
)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 0
assert counter.processed_events == 0
def test_dead_letter_takes_untaken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_create = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), created_at=old_create, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_create
assert counter.processed_events == 1
def test_dead_letter_takes_taken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_taken = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), taken_time=old_taken, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_taken
assert counter.processed_events == 1

View File

@ -1,7 +1,9 @@
import arrow
from app import config, alias_utils from app import config, alias_utils
from app.db import Session from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher from app.events.event_dispatcher import GlobalDispatcher
from app.models import Alias from app.models import Alias, SyncEvent
from tests.utils import random_token from tests.utils import random_token
from .event_test_utils import ( from .event_test_utils import (
OnMemoryDispatcher, OnMemoryDispatcher,
@ -26,6 +28,33 @@ def setup_function(func):
on_memory_dispatcher.clear() on_memory_dispatcher.clear()
def test_event_taken_updates():
event = SyncEvent.create(content="test".encode("utf-8"), flush=True)
assert event.taken_time is None
assert event.mark_as_taken()
assert event.taken_time is not None
def test_event_mark_as_taken_does_nothing_for_taken_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
assert not event.mark_as_taken()
def test_event_mark_as_taken_does_nothing_for_not_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=-1)
assert not event.mark_as_taken(allow_taken_older_than=older_than)
def test_event_mark_as_taken_works_for_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=+1)
assert event.mark_as_taken(allow_taken_older_than=older_than)
def test_fire_event_on_alias_creation(): def test_fire_event_on_alias_creation():
(user, pu) = _create_linked_user() (user, pu) = _create_linked_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
@ -79,7 +108,7 @@ def test_fire_event_on_alias_status_change():
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.flush() Session.flush()
on_memory_dispatcher.clear() on_memory_dispatcher.clear()
alias_utils.change_alias_status(alias, True) alias_utils.change_alias_status(alias, enabled=True)
assert len(on_memory_dispatcher.memory) == 1 assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0] event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu) event_content = _get_event_from_string(event_data, user, pu)

View File

@ -0,0 +1,109 @@
import arrow
from app import config
from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated.event_pb2 import UserPlanChanged
from app.models import (
Subscription,
AppleSubscription,
CoinbaseSubscription,
ManualSubscription,
User,
PartnerUser,
)
from .event_test_utils import (
OnMemoryDispatcher,
_create_linked_user,
_get_event_from_string,
)
from tests.utils import random_token
from app.subscription_webhook import execute_subscription_webhook
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def setup_function(func):
on_memory_dispatcher.clear()
def check_event(user: User, pu: PartnerUser) -> UserPlanChanged:
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.user_plan_change is not None
plan_change = event_content.user_plan_change
return plan_change
def test_webhook_with_trial():
(user, pu) = _create_linked_user()
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == 0
def test_webhook_with_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)
Subscription.create(
user_id=user.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=end_at.date(),
plan="yearly",
flush=True,
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_apple_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=2).replace(hour=0, minute=0, second=0)
AppleSubscription.create(
user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id=random_token(10),
plan="yearly",
product_id="",
flush=True,
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_coinbase_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
CoinbaseSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_manual_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
ManualSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp

View File

@ -1,3 +1,5 @@
from typing import List
import pytest import pytest
from arrow import Arrow from arrow import Arrow
@ -16,8 +18,9 @@ from app.account_linking import (
) )
from app.db import Session from app.db import Session
from app.errors import AccountAlreadyLinkedToAnotherPartnerException from app.errors import AccountAlreadyLinkedToAnotherPartnerException
from app.models import Partner, PartnerUser, User from app.models import Partner, PartnerUser, User, UserAuditLog
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import UserAuditLogAction
from app.utils import random_string, canonicalize_email from app.utils import random_string, canonicalize_email
from tests.utils import random_email from tests.utils import random_email
@ -91,6 +94,11 @@ def test_login_case_from_partner():
) )
assert res.user.activated is True assert res.user.activated is True
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_login_case_from_partner_with_uppercase_email(): def test_login_case_from_partner_with_uppercase_email():
partner = get_proton_partner() partner = get_proton_partner()
@ -125,6 +133,11 @@ def test_login_case_from_web():
assert 0 == (res.user.flags & User.FLAG_CREATED_FROM_PARTNER) assert 0 == (res.user.flags & User.FLAG_CREATED_FROM_PARTNER)
assert res.user.activated is True assert res.user.activated is True
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_get_strategy_existing_sl_user(): def test_get_strategy_existing_sl_user():
email = random_email() email = random_email()
@ -205,6 +218,10 @@ def test_link_account_with_proton_account_same_address(flask_client):
) )
assert partner_user.partner_id == get_proton_partner().id assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id assert partner_user.external_user_id == partner_user_id
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_link_account_with_proton_account_different_address(flask_client): def test_link_account_with_proton_account_different_address(flask_client):
@ -229,6 +246,11 @@ def test_link_account_with_proton_account_different_address(flask_client):
assert partner_user.partner_id == get_proton_partner().id assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id assert partner_user.external_user_id == partner_user_id
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(user_id=res.user.id).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_link_account_with_proton_account_same_address_but_linked_to_other_user( def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
flask_client, flask_client,
@ -248,22 +270,54 @@ def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
partner_user_id, email=random_email() partner_user_id, email=random_email()
) # User already linked with the proton account ) # User already linked with the proton account
# START Ensure sl_user_2 has a partner_user with the right data
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_2.id
)
assert partner_user is not None
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
assert partner_user.partner_email == sl_user_2.email
assert partner_user.user_id == sl_user_2.id
# END Ensure sl_user_2 has a partner_user with the right data
# Proceed to link sl_user_1
res = process_link_case(link_request, sl_user_1, get_proton_partner()) res = process_link_case(link_request, sl_user_1, get_proton_partner())
# Check that the result is linking sl_user_1
assert res.user.id == sl_user_1.id assert res.user.id == sl_user_1.id
assert res.user.email == partner_email assert res.user.email == partner_email
assert res.strategy == "Link" assert res.strategy == "Link"
# Ensure partner_user for sl_user_1 exists
partner_user = PartnerUser.get_by( partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_1.id partner_id=get_proton_partner().id, user_id=sl_user_1.id
) )
assert partner_user.partner_id == get_proton_partner().id assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id assert partner_user.external_user_id == partner_user_id
# Ensure partner_user for sl_user_2 does not exist anymore
partner_user = PartnerUser.get_by( partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_2.id partner_id=get_proton_partner().id, user_id=sl_user_2.id
) )
assert partner_user is None assert partner_user is None
# Ensure audit logs for sl_user_1 show the link action
sl_user_1_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_1.id
).all()
assert len(sl_user_1_audit_logs) == 1
assert sl_user_1_audit_logs[0].user_id == sl_user_1.id
assert sl_user_1_audit_logs[0].action == UserAuditLogAction.LinkAccount.value
# Ensure audit logs for sl_user_2 show the unlink action
sl_user_2_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_2.id
).all()
assert len(sl_user_2_audit_logs) == 1
assert sl_user_2_audit_logs[0].user_id == sl_user_2.id
assert sl_user_2_audit_logs[0].action == UserAuditLogAction.UnlinkAccount.value
def test_link_account_with_proton_account_different_address_and_linked_to_other_user( def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
flask_client, flask_client,
@ -300,6 +354,22 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_
) )
assert partner_user_2 is None assert partner_user_2 is None
# Ensure audit logs for sl_user_1 show the link action
sl_user_1_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_1.id
).all()
assert len(sl_user_1_audit_logs) == 1
assert sl_user_1_audit_logs[0].user_id == sl_user_1.id
assert sl_user_1_audit_logs[0].action == UserAuditLogAction.LinkAccount.value
# Ensure audit logs for sl_user_2 show the unlink action
sl_user_2_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_2.id
).all()
assert len(sl_user_2_audit_logs) == 1
assert sl_user_2_audit_logs[0].user_id == sl_user_2.id
assert sl_user_2_audit_logs[0].action == UserAuditLogAction.UnlinkAccount.value
def test_cannot_create_instance_of_base_strategy(): def test_cannot_create_instance_of_base_strategy():
with pytest.raises(Exception): with pytest.raises(Exception):

View File

@ -0,0 +1,95 @@
import random
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_utils import delete_alias, transfer_alias
from app.models import Alias, AliasAuditLog, AliasDeleteReason
from app.utils import random_string
from tests.utils import create_new_user, random_email
def test_emit_alias_audit_log_for_random_data():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
random_user_id = random.randint(1000, 2000)
message = random_string()
action = AliasAuditLogAction.ChangeAliasStatus
emit_alias_audit_log(
alias=alias,
user_id=random_user_id,
action=action,
message=message,
commit=True,
)
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 2
last_log = logs_for_alias[-1]
assert last_log.alias_id == alias.id
assert last_log.alias_email == alias.email
assert last_log.user_id == random_user_id
assert last_log.action == action.value
assert last_log.message == message
def test_emit_alias_audit_log_on_alias_creation():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
log_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(log_for_alias) == 1
assert log_for_alias[0].alias_id == alias.id
assert log_for_alias[0].alias_email == alias.email
assert log_for_alias[0].user_id == user.id
assert log_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
def test_alias_audit_log_exists_after_alias_deletion():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
alias_id = alias.id
emit_alias_audit_log(alias, AliasAuditLogAction.UpdateAlias, "")
emit_alias_audit_log(alias, AliasAuditLogAction.UpdateAlias, "")
delete_alias(alias, user, AliasDeleteReason.ManualAction, commit=True)
db_alias = Alias.get_by(id=alias_id)
assert db_alias is None
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 4
assert logs_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
assert logs_for_alias[1].action == AliasAuditLogAction.UpdateAlias.value
assert logs_for_alias[2].action == AliasAuditLogAction.UpdateAlias.value
assert logs_for_alias[3].action == AliasAuditLogAction.DeleteAlias.value
def test_alias_audit_log_for_transfer():
original_user = create_new_user()
new_user = create_new_user()
alias = Alias.create(
user_id=original_user.id,
email=random_email(),
mailbox_id=original_user.default_mailbox_id,
)
transfer_alias(alias, new_user, [new_user.default_mailbox])
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 3
assert logs_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
assert logs_for_alias[1].action == AliasAuditLogAction.TransferredAlias.value
assert logs_for_alias[1].user_id == original_user.id
assert logs_for_alias[2].action == AliasAuditLogAction.AcceptTransferAlias.value
assert logs_for_alias[2].user_id == new_user.id

View File

@ -0,0 +1,70 @@
from typing import Tuple
from app.alias_audit_log_utils import AliasAuditLogAction
from app.alias_mailbox_utils import (
set_mailboxes_for_alias,
CannotSetMailboxesForAliasCause,
)
from app.models import Alias, Mailbox, User, AliasMailbox, AliasAuditLog
from tests.utils import create_new_user, random_email
def setup() -> Tuple[User, Alias]:
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
commit=True,
)
return user, alias
def test_set_mailboxes_for_alias_empty_list():
user, alias = setup()
err = set_mailboxes_for_alias(user.id, alias, [])
assert err is CannotSetMailboxesForAliasCause.EmptyMailboxes
def test_set_mailboxes_for_alias_mailbox_for_other_user():
user, alias = setup()
another_user = create_new_user()
err = set_mailboxes_for_alias(user.id, alias, [another_user.default_mailbox_id])
assert err is CannotSetMailboxesForAliasCause.Forbidden
def test_set_mailboxes_for_alias_mailbox_not_exists():
user, alias = setup()
err = set_mailboxes_for_alias(user.id, alias, [9999999])
assert err is CannotSetMailboxesForAliasCause.Forbidden
def test_set_mailboxes_for_alias_mailbox_success():
user, alias = setup()
mb1 = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
)
mb2 = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
commit=True,
)
err = set_mailboxes_for_alias(user.id, alias, [mb1.id, mb2.id])
assert err is None
db_alias = Alias.get_by(id=alias.id)
assert db_alias is not None
assert db_alias.mailbox_id == mb1.id
alias_mailboxes = AliasMailbox.filter_by(alias_id=alias.id).all()
assert len(alias_mailboxes) == 1
assert alias_mailboxes[0].mailbox_id == mb2.id
audit_logs = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(audit_logs) == 2
assert audit_logs[0].action == AliasAuditLogAction.CreateAlias.value
assert audit_logs[1].action == AliasAuditLogAction.ChangedMailboxes.value
assert audit_logs[1].message == f"{mb1.id} ({mb1.email}),{mb2.id} ({mb2.email})"

View File

@ -6,7 +6,9 @@ import pytest
from app import mailbox_utils, config from app import mailbox_utils, config
from app.db import Session from app.db import Session
from app.mail_sender import mail_sender from app.mail_sender import mail_sender
from app.models import Mailbox, MailboxActivation, User, Job from app.mailbox_utils import MailboxEmailChangeError
from app.models import Mailbox, MailboxActivation, User, Job, UserAuditLog
from app.user_audit_log_utils import UserAuditLogAction
from tests.utils import create_new_user, random_email from tests.utils import create_new_user, random_email
@ -218,7 +220,11 @@ def test_delete_with_transfer():
user, random_email(), use_digit_codes=True, send_link=False user, random_email(), use_digit_codes=True, send_link=False
).mailbox ).mailbox
transfer_mailbox = mailbox_utils.create_mailbox( transfer_mailbox = mailbox_utils.create_mailbox(
user, random_email(), use_digit_codes=True, send_link=False user,
random_email(),
use_digit_codes=True,
send_link=False,
verified=True,
).mailbox ).mailbox
mailbox_utils.delete_mailbox( mailbox_utils.delete_mailbox(
user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id
@ -236,6 +242,28 @@ def test_delete_with_transfer():
assert job.payload["transfer_mailbox_id"] is None assert job.payload["transfer_mailbox_id"] is None
def test_cannot_delete_with_transfer_to_unverified_mailbox():
mailbox = mailbox_utils.create_mailbox(
user, random_email(), use_digit_codes=True, send_link=False
).mailbox
transfer_mailbox = mailbox_utils.create_mailbox(
user,
random_email(),
use_digit_codes=True,
send_link=False,
verified=False,
).mailbox
with pytest.raises(mailbox_utils.MailboxError):
mailbox_utils.delete_mailbox(
user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id
)
# Verify mailbox still exists
db_mailbox = Mailbox.get_by(id=mailbox.id)
assert db_mailbox is not None
def test_verify_non_existing_mailbox(): def test_verify_non_existing_mailbox():
with pytest.raises(mailbox_utils.MailboxError): with pytest.raises(mailbox_utils.MailboxError):
mailbox_utils.verify_mailbox_code(user, 999999999, "9999999") mailbox_utils.verify_mailbox_code(user, 999999999, "9999999")
@ -302,3 +330,74 @@ def test_verify_ok():
assert activation is None assert activation is None
mailbox = Mailbox.get(id=output.mailbox.id) mailbox = Mailbox.get(id=output.mailbox.id)
assert mailbox.verified assert mailbox.verified
# perform_mailbox_email_change
def test_perform_mailbox_email_change_invalid_id():
res = mailbox_utils.perform_mailbox_email_change(99999)
assert res.error == MailboxEmailChangeError.InvalidId
assert res.message_category == "error"
def test_perform_mailbox_email_change_valid_id_not_new_email():
user = create_new_user()
mb = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=None,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb.id)
assert res.error == MailboxEmailChangeError.InvalidId
assert res.message_category == "error"
audit_log_entries = UserAuditLog.filter_by(user_id=user.id).count()
assert audit_log_entries == 0
def test_perform_mailbox_email_change_valid_id_email_already_used():
user = create_new_user()
new_email = random_email()
# Create mailbox with that email
Mailbox.create(
user_id=user.id,
email=new_email,
verified=True,
)
mb_to_change = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=new_email,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb_to_change.id)
assert res.error == MailboxEmailChangeError.EmailAlreadyUsed
assert res.message_category == "error"
audit_log_entries = UserAuditLog.filter_by(user_id=user.id).count()
assert audit_log_entries == 0
def test_perform_mailbox_email_change_success():
user = create_new_user()
new_email = random_email()
mb = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=new_email,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb.id)
assert res.error is None
assert res.message_category == "success"
db_mailbox = Mailbox.get_by(id=mb.id)
assert db_mailbox is not None
assert db_mailbox.verified is True
assert db_mailbox.email == new_email
assert db_mailbox.new_email is None
audit_log_entries = UserAuditLog.filter_by(user_id=user.id).all()
assert len(audit_log_entries) == 1
assert audit_log_entries[0].action == UserAuditLogAction.UpdateMailbox.value

View File

@ -2,7 +2,7 @@ import arrow
from app.db import Session from app.db import Session
from app.models import CoinbaseSubscription from app.models import CoinbaseSubscription
from server import handle_coinbase_event from app.payments.coinbase import handle_coinbase_event
from tests.utils import create_new_user from tests.utils import create_new_user

View File

@ -1,113 +0,0 @@
import http.server
import json
import threading
import arrow
from app import config
from app.models import (
Subscription,
AppleSubscription,
CoinbaseSubscription,
ManualSubscription,
)
from tests.utils import create_new_user, random_token
from app.subscription_webhook import execute_subscription_webhook
http_server = None
last_http_request = None
def setup_module():
global http_server
http_server = http.server.ThreadingHTTPServer(("", 0), HTTPTestServer)
print(http_server.server_port)
threading.Thread(target=http_server.serve_forever, daemon=True).start()
config.SUBSCRIPTION_CHANGE_WEBHOOK = f"http://localhost:{http_server.server_port}"
def teardown_module():
global http_server
config.SUBSCRIPTION_CHANGE_WEBHOOK = None
http_server.shutdown()
class HTTPTestServer(http.server.BaseHTTPRequestHandler):
def do_POST(self):
global last_http_request
content_len = int(self.headers.get("Content-Length"))
body_data = self.rfile.read(content_len)
last_http_request = json.loads(body_data)
self.send_response(200)
def test_webhook_with_trial():
user = create_new_user()
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] is None
def test_webhook_with_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)
Subscription.create(
user_id=user.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=end_at.date(),
plan="yearly",
flush=True,
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_apple_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=2).replace(hour=0, minute=0, second=0)
AppleSubscription.create(
user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id=random_token(10),
plan="yearly",
product_id="",
flush=True,
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_coinbase_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
CoinbaseSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_manual_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
ManualSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp

View File

@ -0,0 +1,52 @@
from typing import List
from app import config, mailbox_utils
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.models import UserAuditLog
from app.utils import random_string
from tests.utils import create_new_user, random_email
def setup_module():
config.SKIP_MX_LOOKUP_ON_CHECK = True
def teardown_module():
config.SKIP_MX_LOOKUP_ON_CHECK = False
def test_emit_alias_audit_log_for_random_data():
user = create_new_user()
message = random_string()
action = UserAuditLogAction.CreateMailbox
emit_user_audit_log(
user=user,
action=action,
message=message,
commit=True,
)
logs_for_user: List[UserAuditLog] = UserAuditLog.filter_by(user_id=user.id).all()
assert len(logs_for_user) == 1
assert logs_for_user[0].user_id == user.id
assert logs_for_user[0].user_email == user.email
assert logs_for_user[0].action == action.value
assert logs_for_user[0].message == message
def test_emit_audit_log_on_mailbox_creation():
user = create_new_user()
output = mailbox_utils.create_mailbox(
user=user, email=random_email(), verified=True
)
logs_for_user: List[UserAuditLog] = UserAuditLog.filter_by(user_id=user.id).all()
assert len(logs_for_user) == 1
assert logs_for_user[0].user_id == user.id
assert logs_for_user[0].user_email == user.email
assert logs_for_user[0].action == UserAuditLogAction.CreateMailbox.value
assert (
logs_for_user[0].message
== f"Create mailbox {output.mailbox.id} ({output.mailbox.email}). Verified=True"
)