Compare commits

...

13 Commits

SHA1        Message  Date                        Checks (Build-Release-Image workflow)
bc48198bb1  4.55.1   2024-10-19 12:00:05 +01:00
  All checks were successful: Build-Image (linux/arm64) 3m28s, Build-Image (linux/amd64) 3m31s, Merge-Images 16s, Create-Release 9s, Notify 4s
da6e56c4eb  4.55.0   2024-10-18 12:00:06 +01:00
  All checks were successful: Build-Image (linux/amd64) 3m43s, Build-Image (linux/arm64) 4m10s, Merge-Images 27s, Create-Release 10s, Notify 3s
798b58529c  4.53.2   2024-10-11 12:00:07 +01:00
  All checks were successful: Build-Image (linux/amd64) 3m29s, Build-Image (linux/arm64) 3m33s, Merge-Images 44s, Create-Release 7s, Notify 20s
3da6c983e1  4.53.1   2024-10-09 12:00:06 +01:00
  All checks were successful: Build-Image (linux/amd64) 3m13s, Build-Image (linux/arm64) 3m54s, Merge-Images 16s, Create-Release 40s, Notify 5s
294232a329  4.52.1   2024-10-02 12:00:06 +01:00
  All checks were successful: Build-Image (linux/amd64) 3m56s, Build-Image (linux/arm64) 4m45s, Merge-Images 22s, Create-Release 8s, Notify 3s
fae9d7bc17  4.52.0   2024-10-01 12:00:06 +01:00
  All checks were successful: Build-Image (linux/arm64) 4m44s, Build-Image (linux/amd64) 4m31s, Merge-Images 23s, Create-Release 23s, Notify 17s
d666f5af3f  4.51.2   2024-09-28 12:00:06 +01:00
  All checks were successful: Build-Image (linux/arm64) 3m33s, Build-Image (linux/amd64) 3m35s, Merge-Images 25s, Create-Release 10s, Notify 3s
556fae02d5  4.51.1   2024-09-26 12:00:06 +01:00
  All checks were successful: Build-Image (linux/amd64) 3m21s, Build-Image (linux/arm64) 3m40s, Merge-Images 23s, Create-Release 9s, Notify 4s
fd4c67c3d1  4.51.0   2024-09-25 12:00:07 +01:00
  All checks were successful: Build-Image (linux/arm64) 3m30s, Build-Image (linux/amd64) 3m31s, Merge-Images 11s, Create-Release 9s, Notify 2s
edef254529  4.50.0   2024-09-19 12:00:06 +01:00
  All checks were successful: Build-Image (linux/arm64) 3m43s, Build-Image (linux/amd64) 3m54s, Merge-Images 23s, Create-Release 14s, Notify 3s
357f0cca57  4.49.10  2024-09-13 12:00:28 +01:00
  Some checks failed: Build-Image (linux/arm64) 3m7s, Build-Image (linux/amd64) cancelled, Merge-Images cancelled, Create-Release cancelled, Notify cancelled
8ce90e27f7  4.49.9   2024-09-07 12:00:06 +01:00
  Some checks failed: Build-Image (linux/arm64) 3m4s, Build-Image (linux/amd64) cancelled, Merge-Images cancelled, Create-Release cancelled, Notify cancelled
3ecc8d36f9  4.49.8   2024-09-04 12:00:07 +01:00
  Some checks failed: Build-Image (linux/arm64) 3m18s, Build-Image (linux/amd64) cancelled, Merge-Images cancelled, Create-Release cancelled, Notify cancelled
99 changed files with 4939 additions and 1483 deletions

View File

@ -109,7 +109,7 @@ jobs:
GITHUB_ACTIONS_TEST: true
- name: Archive code coverage results
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: code-coverage-report
path: htmlcov
@ -163,7 +163,7 @@ jobs:
uses: docker/build-push-action@v3
with:
context: .
platforms: linux/amd64,linux/arm64
platforms: linux/amd64
push: true
tags: ${{ steps.meta.outputs.tags }}

View File

@ -9,6 +9,7 @@ from sqlalchemy import or_
from app.db import Session
from app.email_utils import send_welcome_email
from app.partner_user_utils import create_partner_user, create_partner_subscription
from app.utils import sanitize_email, canonicalize_email
from app.errors import (
AccountAlreadyLinkedToAnotherPartnerException,
@ -23,6 +24,7 @@ from app.models import (
User,
Alias,
)
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string
@ -66,9 +68,10 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
LOG.i(
f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
)
PartnerSubscription.create(
partner_user_id=partner_user.id,
end_at=plan.expiration,
create_partner_subscription(
partner_user=partner_user,
expiration=plan.expiration,
msg="Upgraded via partner. User did not have a previous partner subscription",
)
agent.record_custom_event("PlanChange", {"plan": "premium", "type": "new"})
else:
@ -80,6 +83,11 @@ def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
"PlanChange", {"plan": "premium", "type": "extension"}
)
sub.end_at = plan.expiration
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended partner subscription",
)
Session.commit()
@ -98,8 +106,8 @@ def ensure_partner_user_exists_for_user(
if res and res.partner_id != partner.id:
raise AccountAlreadyLinkedToAnotherPartnerException()
if not res:
res = PartnerUser.create(
user_id=sl_user.id,
res = create_partner_user(
user=sl_user,
partner_id=partner.id,
partner_email=link_request.email,
external_user_id=link_request.external_user_id,
@ -140,8 +148,8 @@ class NewUserStrategy(ClientMergeStrategy):
activated=True,
from_partner=self.link_request.from_partner,
)
partner_user = PartnerUser.create(
user_id=new_user.id,
partner_user = create_partner_user(
user=new_user,
partner_id=self.partner.id,
external_user_id=self.link_request.external_user_id,
partner_email=self.link_request.email,
@ -200,7 +208,7 @@ def get_login_strategy(
return ExistingUnlinkedUserStrategy(link_request, user, partner)
def check_alias(email: str) -> bool:
def check_alias(email: str):
alias = Alias.get_by(email=email)
if alias is not None:
raise AccountIsUsingAliasAsEmail()
@ -275,10 +283,26 @@ def switch_already_linked_user(
LOG.i(
f"Deleting previous partner_user:{other_partner_user.id} from user:{current_user.id}"
)
emit_user_audit_log(
user=other_partner_user.user,
action=UserAuditLogAction.UnlinkAccount,
message=f"Deleting partner_user {other_partner_user.id} (external_user_id={other_partner_user.external_user_id} | partner_email={other_partner_user.partner_email}) from user {current_user.id}, as we received a new link request for the same partner",
)
PartnerUser.delete(other_partner_user.id)
LOG.i(f"Linking partner_user:{partner_user.id} to user:{current_user.id}")
# Link this partner_user to the current user
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.UnlinkAccount,
message=f"Unlinking from partner, as user will now be tied to another external account. old=(id={partner_user.user.id} | email={partner_user.user.email}) | new=(id={current_user.id} | email={current_user.email})",
)
partner_user.user_id = current_user.id
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.LinkAccount,
message=f"Linking user {current_user.id} ({current_user.email}) to partner_user:{partner_user.id} (external_user_id={partner_user.external_user_id} | partner_email={partner_user.partner_email})",
)
# Set plan
set_plan_for_partner_user(partner_user, link_request.plan)
Session.commit()

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional
from typing import Optional, List
import arrow
import sqlalchemy
@ -34,6 +34,9 @@ from app.models import (
DeletedAlias,
DomainDeletedAlias,
PartnerUser,
AliasMailbox,
AliasAuditLog,
UserAuditLog,
)
from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address
@ -112,7 +115,7 @@ class SLAdminIndexView(AdminIndexView):
if not current_user.is_authenticated or not current_user.is_admin:
return redirect(url_for("auth.login", next=request.url))
return redirect("/admin/user")
return redirect("/admin/email_search")
class UserAdmin(SLModelView):
@ -736,33 +739,72 @@ class InvalidMailboxDomainAdmin(SLModelView):
class EmailSearchResult:
no_match: bool = True
alias: Optional[Alias] = None
mailbox: Optional[Mailbox] = None
alias_audit_log: Optional[List[AliasAuditLog]] = None
mailbox: List[Mailbox] = []
mailbox_count: int = 0
deleted_alias: Optional[DeletedAlias] = None
deleted_custom_alias: Optional[DomainDeletedAlias] = None
deleted_alias_audit_log: Optional[List[AliasAuditLog]] = None
domain_deleted_alias: Optional[DomainDeletedAlias] = None
domain_deleted_alias_audit_log: Optional[List[AliasAuditLog]] = None
user: Optional[User] = None
user_audit_log: Optional[List[UserAuditLog]] = None
query: str
@staticmethod
def from_email(email: str) -> EmailSearchResult:
output = EmailSearchResult()
output.query = email
alias = Alias.get_by(email=email)
if alias:
output.alias = alias
output.alias_audit_log = (
AliasAuditLog.filter_by(alias_id=alias.id)
.order_by(AliasAuditLog.created_at.desc())
.all()
)
output.no_match = False
user = User.get_by(email=email)
if user:
output.user = user
output.user_audit_log = (
UserAuditLog.filter_by(user_id=user.id)
.order_by(UserAuditLog.created_at.desc())
.all()
)
output.no_match = False
mailbox = Mailbox.get_by(email=email)
if mailbox:
output.mailbox = mailbox
user_audit_log = (
UserAuditLog.filter_by(user_email=email)
.order_by(UserAuditLog.created_at.desc())
.all()
)
if user_audit_log:
output.user_audit_log = user_audit_log
output.no_match = False
mailboxes = (
Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all()
)
if mailboxes:
output.mailbox = mailboxes
output.mailbox_count = Mailbox.filter_by(email=email).count()
output.no_match = False
deleted_alias = DeletedAlias.get_by(email=email)
if deleted_alias:
output.deleted_alias = deleted_alias
output.deleted_alias_audit_log = (
AliasAuditLog.filter_by(alias_email=deleted_alias.email)
.order_by(AliasAuditLog.created_at.desc())
.all()
)
output.no_match = False
domain_deleted_alias = DomainDeletedAlias.get_by(email=email)
if domain_deleted_alias:
output.domain_deleted_alias = domain_deleted_alias
output.domain_deleted_alias_audit_log = (
AliasAuditLog.filter_by(alias_email=domain_deleted_alias.email)
.order_by(AliasAuditLog.created_at.desc())
.all()
)
output.no_match = False
return output
@ -779,11 +821,32 @@ class EmailSearchHelpers:
@staticmethod
def mailbox_count(user: User) -> int:
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.asc()).count()
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count()
@staticmethod
def alias_mailboxes(alias: Alias) -> list[Mailbox]:
return (
Session.query(Mailbox)
.filter(Mailbox.id == Alias.mailbox_id, Alias.id == alias.id)
.union(
Session.query(Mailbox)
.join(AliasMailbox, Mailbox.id == AliasMailbox.mailbox_id)
.filter(AliasMailbox.alias_id == alias.id)
)
.order_by(Mailbox.id)
.limit(10)
.all()
)
@staticmethod
def alias_mailbox_count(alias: Alias) -> int:
return len(alias.mailboxes)
@staticmethod
def alias_list(user: User) -> list[Alias]:
return Alias.filter_by(user_id=user.id).order_by(Alias.id.asc()).limit(10).all()
return (
Alias.filter_by(user_id=user.id).order_by(Alias.id.desc()).limit(10).all()
)
@staticmethod
def alias_count(user: User) -> int:
@ -806,9 +869,8 @@ class EmailSearchAdmin(BaseView):
@expose("/", methods=["GET", "POST"])
def index(self):
search = EmailSearchResult()
email = ""
if request.form and request.form["email"]:
email = request.form["email"]
email = request.args.get("email")
if email is not None and len(email) > 0:
email = email.strip()
search = EmailSearchResult.from_email(email)

View File

@ -0,0 +1,38 @@
from enum import Enum
from typing import Optional
from app.models import Alias, AliasAuditLog
class AliasAuditLogAction(Enum):
CreateAlias = "create"
ChangeAliasStatus = "change_status"
DeleteAlias = "delete"
UpdateAlias = "update"
InitiateTransferAlias = "initiate_transfer_alias"
AcceptTransferAlias = "accept_transfer_alias"
TransferredAlias = "transferred_alias"
ChangedMailboxes = "changed_mailboxes"
CreateContact = "create_contact"
UpdateContact = "update_contact"
DeleteContact = "delete_contact"
def emit_alias_audit_log(
alias: Alias,
action: AliasAuditLogAction,
message: str,
user_id: Optional[int] = None,
commit: bool = False,
):
AliasAuditLog.create(
user_id=user_id or alias.user_id,
alias_id=alias.id,
alias_email=alias.email,
action=action.value,
message=message,
commit=commit,
)
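
The new module above centralizes alias audit logging. As a quick illustration (an editor's sketch, not part of the diff), this is how a caller might record a status change, assuming `alias` is an existing Alias loaded from the ORM:

from app.alias_audit_log_utils import AliasAuditLogAction, emit_alias_audit_log

# commit=True persists the log entry immediately instead of relying on the
# caller's later Session.commit()
emit_alias_audit_log(
    alias=alias,
    action=AliasAuditLogAction.ChangeAliasStatus,
    message="Set alias status to False. Disabled from an example maintenance script",
    commit=True,
)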

View File

@ -0,0 +1,61 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.db import Session
from app.models import Alias, AliasMailbox, Mailbox
_MAX_MAILBOXES_PER_ALIAS = 20
class CannotSetMailboxesForAliasCause(Enum):
Forbidden = "Forbidden"
EmptyMailboxes = "Must choose at least one mailbox"
TooManyMailboxes = "Too many mailboxes"
@dataclass
class SetMailboxesForAliasResult:
performed_change: bool
reason: Optional[CannotSetMailboxesForAliasCause]
def set_mailboxes_for_alias(
user_id: int, alias: Alias, mailbox_ids: List[int]
) -> Optional[CannotSetMailboxesForAliasCause]:
if len(mailbox_ids) == 0:
return CannotSetMailboxesForAliasCause.EmptyMailboxes
if len(mailbox_ids) > _MAX_MAILBOXES_PER_ALIAS:
return CannotSetMailboxesForAliasCause.TooManyMailboxes
mailboxes = (
Session.query(Mailbox)
.filter(
Mailbox.id.in_(mailbox_ids),
Mailbox.user_id == user_id,
Mailbox.verified == True, # noqa: E712
)
.all()
)
if len(mailboxes) != len(mailbox_ids):
return CannotSetMailboxesForAliasCause.Forbidden
# first remove all existing alias-mailboxes links
AliasMailbox.filter_by(alias_id=alias.id).delete()
Session.flush()
# then add all new mailboxes, being the first the one associated with the alias
for i, mailbox in enumerate(mailboxes):
if i == 0:
alias.mailbox_id = mailboxes[0].id
else:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mailbox.id)
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.ChangedMailboxes,
message=",".join([f"{mailbox.id} ({mailbox.email})" for mailbox in mailboxes]),
)
return None
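
Usage sketch (editor's illustration, not from the commits): callers treat a non-None return value as the failure reason, and its .value doubles as the user-facing error string, which is how the updated alias API further down uses it. `user`, `alias` and the mailbox ids are assumed to already exist and belong together:

from app.alias_mailbox_utils import set_mailboxes_for_alias

err = set_mailboxes_for_alias(
    user_id=user.id,
    alias=alias,
    mailbox_ids=[primary_mailbox.id, backup_mailbox.id],
)
if err is not None:
    # e.g. "Must choose at least one mailbox" or "Forbidden"
    raise ValueError(err.value)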

View File

@ -1,12 +1,14 @@
import csv
from io import StringIO
import re
from dataclasses import dataclass
from typing import Optional, Tuple
from email_validator import validate_email, EmailNotValidError
from sqlalchemy.exc import IntegrityError, DataError
from flask import make_response
from app.alias_audit_log_utils import AliasAuditLogAction, emit_alias_audit_log
from app.config import (
BOUNCE_PREFIX_FOR_REPLY_PHASE,
BOUNCE_PREFIX,
@ -23,6 +25,7 @@ from app.email_utils import (
send_cannot_create_domain_alias,
send_email,
render,
sl_formataddr,
)
from app.errors import AliasInTrashError
from app.events.event_dispatcher import EventDispatcher
@ -30,6 +33,7 @@ from app.events.generated.event_pb2 import (
AliasDeleted,
AliasStatusChanged,
EventContent,
AliasCreated,
)
from app.log import LOG
from app.models import (
@ -363,11 +367,18 @@ def delete_alias(
Session.commit()
LOG.i(f"Moving {alias} to global trash {deleted_alias}")
alias_id = alias.id
alias_email = alias.email
emit_alias_audit_log(
alias, AliasAuditLogAction.DeleteAlias, "Alias deleted by user action"
)
Alias.filter(Alias.id == alias.id).delete()
Session.commit()
EventDispatcher.send_event(
user, EventContent(alias_deleted=AliasDeleted(alias_id=alias.id))
user,
EventContent(alias_deleted=AliasDeleted(id=alias_id, email=alias_email)),
)
if commit:
Session.commit()
@ -444,7 +455,7 @@ def alias_export_csv(user, csv_direct_export=False):
return output
def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
def transfer_alias(alias: Alias, new_user: User, new_mailboxes: [Mailbox]):
# cannot transfer alias which is used for receiving newsletter
if User.get_by(newsletter_alias_id=alias.id):
raise Exception("Cannot transfer alias that's used to receive newsletter")
@ -498,17 +509,90 @@ def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
alias.disable_pgp = False
alias.pinned = False
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.TransferredAlias,
message=f"Lost ownership of alias due to alias transfer confirmed. New owner is {new_user.id}",
user_id=old_user.id,
)
EventDispatcher.send_event(
old_user,
EventContent(
alias_deleted=AliasDeleted(
id=alias.id,
email=alias.email,
)
),
)
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.AcceptTransferAlias,
message=f"Accepted alias transfer from user {old_user.id}",
user_id=new_user.id,
)
EventDispatcher.send_event(
new_user,
EventContent(
alias_created=AliasCreated(
id=alias.id,
email=alias.email,
note=alias.note,
enabled=alias.enabled,
created_at=int(alias.created_at.timestamp),
)
),
)
Session.commit()
def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
def change_alias_status(
alias: Alias, enabled: bool, message: Optional[str] = None, commit: bool = False
):
LOG.i(f"Changing alias {alias} enabled to {enabled}")
alias.enabled = enabled
event = AliasStatusChanged(
alias_id=alias.id, alias_email=alias.email, enabled=enabled
id=alias.id,
email=alias.email,
enabled=enabled,
created_at=int(alias.created_at.timestamp),
)
EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event))
audit_log_message = f"Set alias status to {enabled}"
if message is not None:
audit_log_message += f". {message}"
emit_alias_audit_log(
alias, AliasAuditLogAction.ChangeAliasStatus, audit_log_message
)
if commit:
Session.commit()
@dataclass
class AliasRecipientName:
name: str
message: Optional[str] = None
def get_alias_recipient_name(alias: Alias) -> AliasRecipientName:
"""
Logic:
1. If alias has name, use it
2. If alias has custom domain, and custom domain has name, use it
3. Otherwise, use the alias email as the recipient
"""
if alias.name:
return AliasRecipientName(
name=sl_formataddr((alias.name, alias.email)),
message=f"Put alias name {alias.name} in from header",
)
elif alias.custom_domain:
if alias.custom_domain.name:
return AliasRecipientName(
name=sl_formataddr((alias.custom_domain.name, alias.email)),
message=f"Put domain default alias name {alias.custom_domain.name} in from header",
)
return AliasRecipientName(name=alias.email)
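
change_alias_status now accepts an optional message that is appended to the "Set alias status to ..." audit-log entry, while still emitting the alias_status_change event. A minimal sketch (not part of the diff), assuming `alias` is an existing Alias:

from app.alias_utils import change_alias_status

change_alias_status(
    alias,
    enabled=False,
    message="Disabled from an example cleanup job",  # appended to the audit-log message
    commit=True,
)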

View File

@ -1,9 +1,13 @@
from typing import Optional
from deprecated import deprecated
from flask import g
from flask import jsonify
from flask import request
from app import alias_utils
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_mailbox_utils import set_mailboxes_for_alias
from app.api.base import api_bp, require_api_auth
from app.api.serializer import (
AliasInfo,
@ -26,7 +30,7 @@ from app.errors import (
)
from app.extensions import limiter
from app.log import LOG
from app.models import Alias, Contact, Mailbox, AliasMailbox, AliasDeleteReason
from app.models import Alias, Contact, Mailbox, AliasDeleteReason
@deprecated
@ -185,7 +189,11 @@ def toggle_alias(alias_id):
if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403
alias_utils.change_alias_status(alias, enabled=not alias.enabled)
alias_utils.change_alias_status(
alias,
enabled=not alias.enabled,
message=f"Set enabled={not alias.enabled} via API",
)
LOG.i(f"User {user} changed alias {alias} enabled status to {alias.enabled}")
Session.commit()
@ -272,10 +280,12 @@ def update_alias(alias_id):
if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403
changed_fields = []
changed = False
if "note" in data:
new_note = data.get("note")
alias.note = new_note
changed_fields.append("note")
changed = True
if "mailbox_id" in data:
@ -285,35 +295,19 @@ def update_alias(alias_id):
return jsonify(error="Forbidden"), 400
alias.mailbox_id = mailbox_id
changed_fields.append(f"mailbox_id ({mailbox_id})")
changed = True
if "mailbox_ids" in data:
mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")]
mailboxes: [Mailbox] = []
# check if all mailboxes belong to user
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != user.id or not mailbox.verified:
return jsonify(error="Forbidden"), 400
mailboxes.append(mailbox)
if not mailboxes:
return jsonify(error="Must choose at least one mailbox"), 400
# <<< update alias mailboxes >>>
# first remove all existing alias-mailboxes links
AliasMailbox.filter_by(alias_id=alias.id).delete()
Session.flush()
# then add all new mailboxes
for i, mailbox in enumerate(mailboxes):
if i == 0:
alias.mailbox_id = mailboxes[0].id
else:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mailbox.id)
# <<< END update alias mailboxes >>>
err = set_mailboxes_for_alias(
user_id=user.id, alias=alias, mailbox_ids=mailbox_ids
)
if err:
return jsonify(error=err.value), 400
mailbox_ids_string = ",".join(map(str, mailbox_ids))
changed_fields.append(f"mailbox_ids ({mailbox_ids_string})")
changed = True
if "name" in data:
@ -325,17 +319,26 @@ def update_alias(alias_id):
if new_name:
new_name = new_name.replace("\n", "")
alias.name = new_name
changed_fields.append("name")
changed = True
if "disable_pgp" in data:
alias.disable_pgp = data.get("disable_pgp")
changed_fields.append("disable_pgp")
changed = True
if "pinned" in data:
alias.pinned = data.get("pinned")
changed_fields.append("pinned")
changed = True
if changed:
changed_fields_string = ",".join(changed_fields)
emit_alias_audit_log(
alias,
AliasAuditLogAction.UpdateAlias,
f"Alias fields updated ({changed_fields_string})",
)
Session.commit()
return jsonify(ok=True), 200
@ -424,7 +427,7 @@ def create_contact_route(alias_id):
contact_address = data.get("contact")
try:
contact = create_contact(g.user, alias, contact_address)
contact = create_contact(alias, contact_address)
except ErrContactErrorUpgradeNeeded as err:
return jsonify(error=err.error_for_user()), 403
except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err:
@ -446,11 +449,16 @@ def delete_contact(contact_id):
200
"""
user = g.user
contact = Contact.get(contact_id)
contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.alias.user_id != user.id:
return jsonify(error="Forbidden"), 403
emit_alias_audit_log(
alias=contact.alias,
action=AliasAuditLogAction.DeleteContact,
message=f"Deleted contact {contact_id} ({contact.email})",
)
Contact.delete(contact_id)
Session.commit()
@ -468,12 +476,17 @@ def toggle_contact(contact_id):
200
"""
user = g.user
contact = Contact.get(contact_id)
contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.alias.user_id != user.id:
return jsonify(error="Forbidden"), 403
contact.block_forward = not contact.block_forward
emit_alias_audit_log(
alias=contact.alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Set contact state {contact.id} {contact.email} -> {contact.website_email} to blocked {contact.block_forward}",
)
Session.commit()
return jsonify(block_forward=contact.block_forward), 200

View File

@ -23,6 +23,7 @@ from app.events.auth_event import LoginEvent, RegisterEvent
from app.extensions import limiter
from app.log import LOG
from app.models import User, ApiKey, SocialAuth, AccountActivation
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email, canonicalize_email
@ -52,8 +53,12 @@ def auth_login():
password = data.get("password")
device = data.get("device")
email = sanitize_email(data.get("email"))
canonical_email = canonicalize_email(data.get("email"))
email = data.get("email")
if not email:
LoginEvent(LoginEvent.ActionType.failed, LoginEvent.Source.api).send()
return jsonify(error="Email or password incorrect"), 400
email = sanitize_email(email)
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
@ -183,6 +188,11 @@ def auth_activate():
LOG.d("activate user %s", user)
user.activated = True
emit_user_audit_log(
user=user,
action=UserAuditLogAction.ActivateUser,
message=f"User has been activated: {user.email}",
)
AccountActivation.delete(account_activation.id)
Session.commit()

View File

@ -2,8 +2,10 @@ from flask import g, request
from flask import jsonify
from app.api.base import api_bp, require_api_auth
from app.custom_domain_utils import set_custom_domain_mailboxes
from app.db import Session
from app.models import CustomDomain, DomainDeletedAlias, Mailbox, DomainMailbox
from app.log import LOG
from app.models import CustomDomain, DomainDeletedAlias
def custom_domain_to_dict(custom_domain: CustomDomain):
@ -100,23 +102,14 @@ def update_custom_domain(custom_domain_id):
if "mailbox_ids" in data:
mailbox_ids = [int(m_id) for m_id in data.get("mailbox_ids")]
if mailbox_ids:
# check if mailbox is not tempered with
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != user.id or not mailbox.verified:
return jsonify(error="Forbidden"), 400
mailboxes.append(mailbox)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
result = set_custom_domain_mailboxes(user.id, custom_domain, mailbox_ids)
if result.success:
changed = True
else:
LOG.info(
f"Prevented from updating mailboxes [custom_domain_id={custom_domain.id}]: {result.reason.value}"
)
return jsonify(error="Forbidden"), 400
if changed:
Session.commit()

View File

@ -6,6 +6,7 @@ from app import config
from app.extensions import limiter
from app.log import LOG
from app.models import Job, ApiToCookieToken
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@api_bp.route("/user", methods=["DELETE"])
@ -16,6 +17,11 @@ def delete_user():
"""
# Schedule delete account job
emit_user_audit_log(
user=g.user,
action=UserAuditLogAction.UserMarkedForDeletion,
message=f"Marked user {g.user.id} ({g.user.email}) for deletion from API",
)
LOG.w("schedule delete account job for %s", g.user)
Job.create(
name=config.JOB_DELETE_ACCOUNT,

View File

@ -87,7 +87,7 @@ def update_user_info():
File.delete(file.id)
s3.delete(file.path)
Session.flush()
else:
if data["profile_picture"] is not None:
raw_data = base64.decodebytes(data["profile_picture"].encode())
if detect_image_format(raw_data) == ImageFormat.Unknown:
return jsonify(error="Unsupported image format"), 400

View File

@ -7,6 +7,7 @@ from app.db import Session
from app.extensions import limiter
from app.log import LOG
from app.models import ActivationCode
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_next_url
@ -47,6 +48,11 @@ def activate():
user = activation_code.user
user.activated = True
emit_user_audit_log(
user=user,
action=UserAuditLogAction.ActivateUser,
message=f"User has been activated: {user.email}",
)
login_user(user)
# activation code is to be used only once

View File

@ -9,6 +9,7 @@ from app.auth.views.login_utils import after_login
from app.db import Session
from app.extensions import limiter
from app.models import ResetPasswordCode
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class ResetPasswordForm(FlaskForm):
@ -59,6 +60,11 @@ def reset_password():
# this can be served to activate user too
user.activated = True
emit_user_audit_log(
user=user,
action=UserAuditLogAction.ResetPassword,
message="User has reset their password",
)
# remove all reset password codes
ResetPasswordCode.filter_by(user_id=user.id).delete()

View File

@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None):
return literal_eval(value)
def get_env_dict(env_var: str) -> dict[str, str]:
"""
Get an env variable and convert it into a python dictionary with keys and values as strings.
Args:
env_var (str): env var, example: SL_DB
Syntax is: key1=value1;key2=value2
Components separated by ;
key and value separated by =
"""
value = os.getenv(env_var)
if not value:
return {}
components = value.split(";")
result = {}
for component in components:
if component == "":
continue
parts = component.split("=")
if len(parts) != 2:
raise Exception(f"Invalid config for env var {env_var}")
result[parts[0].strip()] = parts[1].strip()
return result
config_file = os.environ.get("CONFIG")
if config_file:
config_file = get_abs_path(config_file)
@ -574,7 +601,6 @@ SKIP_MX_LOOKUP_ON_CHECK = False
DISABLE_RATE_LIMIT = "DISABLE_RATE_LIMIT" in os.environ
SUBSCRIPTION_CHANGE_WEBHOOK = os.environ.get("SUBSCRIPTION_CHANGE_WEBHOOK", None)
MAX_API_KEYS = int(os.environ.get("MAX_API_KEYS", 30))
UPCLOUD_USERNAME = os.environ.get("UPCLOUD_USERNAME", None)
@ -609,3 +635,32 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_
# Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool
# It defaults to the regular DB_URI in case it's needed
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
def read_partner_dict(var: str) -> dict[int, str]:
partner_value = get_env_dict(var)
if len(partner_value) == 0:
return {}
res: dict[int, str] = {}
for partner_id in partner_value.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_value[partner_id]
except ValueError:
pass
return res
PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict(
"PARTNER_DNS_CUSTOM_DOMAINS"
)
PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES"
)
MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(
"MAILBOX_VERIFICATION_OVERRIDE_CODE", None
)
AUDIT_LOG_MAX_DAYS = int(os.environ.get("AUDIT_LOG_MAX_DAYS", 30))
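
For reference, a sketch of how the new ";"-separated partner settings are parsed (editor's illustration; the domain values are made up, and a configured environment is assumed so that importing app.config succeeds):

import os
from app.config import get_env_dict, read_partner_dict

os.environ["PARTNER_DNS_CUSTOM_DOMAINS"] = "1=mail.partner1.example;2=mail.partner2.example"

# get_env_dict keeps keys and values as stripped strings
assert get_env_dict("PARTNER_DNS_CUSTOM_DOMAINS") == {
    "1": "mail.partner1.example",
    "2": "mail.partner2.example",
}
# read_partner_dict casts the keys to int and silently drops non-numeric keys
assert read_partner_dict("PARTNER_DNS_CUSTOM_DOMAINS") == {
    1: "mail.partner1.example",
    2: "mail.partner2.example",
}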

View File

@ -1 +1,2 @@
HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"

app/app/contact_utils.py (new file, 124 lines)
View File

@ -0,0 +1,124 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from sqlalchemy.exc import IntegrityError
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.db import Session
from app.email_utils import generate_reply_email, parse_full_address
from app.email_validation import is_valid_email
from app.log import LOG
from app.models import Contact, Alias
from app.utils import sanitize_email
class ContactCreateError(Enum):
InvalidEmail = "Invalid email"
NotAllowed = "Your plan does not allow to create contacts"
@dataclass
class ContactCreateResult:
contact: Optional[Contact]
created: bool
error: Optional[ContactCreateError]
def __update_contact_if_needed(
contact: Contact, name: Optional[str], mail_from: Optional[str]
) -> ContactCreateResult:
if name and contact.name != name:
LOG.d(f"Setting {contact} name to {name}")
contact.name = name
Session.commit()
if mail_from and contact.mail_from is None:
LOG.d(f"Setting {contact} mail_from to {mail_from}")
contact.mail_from = mail_from
Session.commit()
return ContactCreateResult(contact, created=False, error=None)
def create_contact(
email: str,
alias: Alias,
name: Optional[str] = None,
mail_from: Optional[str] = None,
allow_empty_email: bool = False,
automatic_created: bool = False,
from_partner: bool = False,
) -> ContactCreateResult:
# If user cannot create contacts, they still need to be created when receiving an email for an alias
if not automatic_created and not alias.user.can_create_contacts():
return ContactCreateResult(
None, created=False, error=ContactCreateError.NotAllowed
)
# Parse emails with form 'name <email>'
try:
email_name, email = parse_full_address(email)
except ValueError:
email = ""
email_name = ""
# If no name is explicitly given try to get it from the parsed email
if name is None:
name = email_name[: Contact.MAX_NAME_LENGTH]
else:
name = name[: Contact.MAX_NAME_LENGTH]
# If still no name is there, make sure the name is None instead of empty string
if not name:
name = None
if name is not None and "\x00" in name:
LOG.w("Cannot use contact name because has \\x00")
name = ""
# Sanitize email and if it's not valid only allow to create a contact if it's explicitly allowed. Otherwise fail
email = sanitize_email(email, not_lower=True)
if not is_valid_email(email):
LOG.w(f"invalid contact email {email}")
if not allow_empty_email:
return ContactCreateResult(
None, created=False, error=ContactCreateError.InvalidEmail
)
LOG.d("Create a contact with invalid email for %s", alias)
# either reuse a contact with empty email or create a new contact with empty email
email = ""
# If contact exists, update name and mail_from if needed
contact = Contact.get_by(alias_id=alias.id, website_email=email)
if contact is not None:
return __update_contact_if_needed(contact, name, mail_from)
# Create the contact
reply_email = generate_reply_email(email, alias)
try:
flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=email,
name=name,
reply_email=reply_email,
mail_from=mail_from,
automatic_created=automatic_created,
flags=flags,
invalid_email=email == "",
commit=True,
)
if automatic_created:
trail = ". Automatically created"
else:
trail = ". Created by user action"
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.CreateContact,
message=f"Created contact {contact.id} ({contact.email}){trail}",
commit=True,
)
LOG.d(
f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}"
)
except IntegrityError:
Session.rollback()
LOG.info(
f"Contact with email {email} for alias_id {alias.id} already existed, fetching from DB"
)
contact = Contact.get_by(alias_id=alias.id, website_email=email)
return __update_contact_if_needed(contact, name, mail_from)
return ContactCreateResult(contact, created=True, error=None)
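
A usage sketch for the new module (editor's illustration, not part of the commits), assuming `alias` is an existing Alias and the contact address is arbitrary:

from app.contact_utils import ContactCreateError, create_contact

result = create_contact(email="Someone <someone@example.com>", alias=alias)
if result.error == ContactCreateError.NotAllowed:
    ...  # the user's plan does not allow creating contacts manually
elif result.error == ContactCreateError.InvalidEmail:
    ...  # address could not be parsed or validated
elif not result.created:
    ...  # contact already existed; result.contact is the existing row, possibly updated
else:
    reverse_alias = result.contact.reply_email  # generated via generate_reply_email()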

View File

@ -0,0 +1,206 @@
import arrow
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
from app.config import JOB_DELETE_DOMAIN
from app.db import Session
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import User, CustomDomain, SLDomain, Mailbox, Job, DomainMailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$")
_MAX_MAILBOXES_PER_DOMAIN = 20
@dataclass
class CreateCustomDomainResult:
message: str = ""
message_category: str = ""
success: bool = False
instance: Optional[CustomDomain] = None
redirect: Optional[str] = None
class CannotUseDomainReason(Enum):
InvalidDomain = 1
BuiltinDomain = 2
DomainAlreadyUsed = 3
DomainPartOfUserEmail = 4
DomainUserInMailbox = 5
def message(self, domain: str) -> str:
if self == CannotUseDomainReason.InvalidDomain:
return "This is not a valid domain"
elif self == CannotUseDomainReason.BuiltinDomain:
return "A custom domain cannot be a built-in domain."
elif self == CannotUseDomainReason.DomainAlreadyUsed:
return f"{domain} already used"
elif self == CannotUseDomainReason.DomainPartOfUserEmail:
return "You cannot add a domain that you are currently using for your personal email. Please change your personal email to your real email"
elif self == CannotUseDomainReason.DomainUserInMailbox:
return f"{domain} already used in a SimpleLogin mailbox"
else:
raise Exception("Invalid CannotUseDomainReason")
class CannotSetCustomDomainMailboxesCause(Enum):
InvalidMailbox = "Something went wrong, please retry"
NoMailboxes = "You must select at least 1 mailbox"
TooManyMailboxes = (
f"You can only set up to {_MAX_MAILBOXES_PER_DOMAIN} mailboxes per domain"
)
@dataclass
class SetCustomDomainMailboxesResult:
success: bool
reason: Optional[CannotSetCustomDomainMailboxesCause] = None
def is_valid_domain(domain: str) -> bool:
"""
Checks that a domain is valid according to RFC 1035
"""
if len(domain) > 255:
return False
if domain.endswith("."):
domain = domain[:-1] # Strip the trailing dot
labels = domain.split(".")
if not labels:
return False
for label in labels:
if not _ALLOWED_DOMAIN_REGEX.match(label):
return False
return True
def sanitize_domain(domain: str) -> str:
new_domain = domain.lower().strip()
if new_domain.startswith("http://"):
new_domain = new_domain[len("http://") :]
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
return new_domain
def can_domain_be_used(user: User, domain: str) -> Optional[CannotUseDomainReason]:
if not is_valid_domain(domain):
return CannotUseDomainReason.InvalidDomain
elif SLDomain.get_by(domain=domain):
return CannotUseDomainReason.BuiltinDomain
elif CustomDomain.get_by(domain=domain):
return CannotUseDomainReason.DomainAlreadyUsed
elif get_email_domain_part(user.email) == domain:
return CannotUseDomainReason.DomainPartOfUserEmail
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{domain}")
).first():
return CannotUseDomainReason.DomainUserInMailbox
else:
return None
def create_custom_domain(
user: User, domain: str, partner_id: Optional[int] = None
) -> CreateCustomDomainResult:
if not user.is_premium():
return CreateCustomDomainResult(
message="Only premium plan can add custom domain",
message_category="warning",
)
new_domain = sanitize_domain(domain)
domain_forbidden_cause = can_domain_be_used(user, new_domain)
if domain_forbidden_cause:
return CreateCustomDomainResult(
message=domain_forbidden_cause.message(new_domain), message_category="error"
)
new_custom_domain = CustomDomain.create(domain=new_domain, user_id=user.id)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in user.custom_domains:
if new_domain.endswith("." + root_cd.domain) and root_cd.ownership_verified:
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
# Add the partner_id in case it's passed
if partner_id is not None:
new_custom_domain.partner_id = partner_id
emit_user_audit_log(
user=user,
action=UserAuditLogAction.CreateCustomDomain,
message=f"Created custom domain {new_custom_domain.id} ({new_domain})",
)
Session.commit()
return CreateCustomDomainResult(
success=True,
instance=new_custom_domain,
)
def delete_custom_domain(domain: CustomDomain):
# Schedule delete domain job
LOG.w("schedule delete domain job for %s", domain)
domain.pending_deletion = True
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": domain.id},
run_at=arrow.now(),
commit=True,
)
def set_custom_domain_mailboxes(
user_id: int, custom_domain: CustomDomain, mailbox_ids: List[int]
) -> SetCustomDomainMailboxesResult:
if len(mailbox_ids) == 0:
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.NoMailboxes
)
elif len(mailbox_ids) > _MAX_MAILBOXES_PER_DOMAIN:
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.TooManyMailboxes
)
mailboxes = (
Session.query(Mailbox)
.filter(
Mailbox.id.in_(mailbox_ids),
Mailbox.user_id == user_id,
Mailbox.verified == True, # noqa: E712
)
.all()
)
if len(mailboxes) != len(mailbox_ids):
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.InvalidMailbox
)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
mailboxes_as_str = ",".join(map(str, mailbox_ids))
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Updated custom domain {custom_domain.id} mailboxes (domain={custom_domain.domain}) (mailboxes={mailboxes_as_str})",
)
Session.commit()
return SetCustomDomainMailboxesResult(success=True)
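
A hedged sketch of the flow the dashboard and API now delegate to (not part of the diff; `current_user` and `mailbox` are assumed to exist, and the domain value is invented):

from app.custom_domain_utils import create_custom_domain, set_custom_domain_mailboxes

res = create_custom_domain(user=current_user, domain="https://Mail.Example.org")
if not res.success:
    print(res.message, res.message_category)  # e.g. "This is not a valid domain" / "error"
else:
    # the input is sanitized to "mail.example.org" before any checks run
    outcome = set_custom_domain_mailboxes(current_user.id, res.instance, [mailbox.id])
    if not outcome.success:
        print(outcome.reason.value)  # e.g. "You must select at least 1 mailbox"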

View File

@ -1,37 +1,228 @@
from dataclasses import dataclass
from typing import List, Optional
from app import config
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import get_cname_record
from app.dns_utils import (
MxRecord,
DNSClient,
is_mx_equivalent,
get_network_dns_client,
)
from app.models import CustomDomain
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string
@dataclass
class DomainValidationResult:
success: bool
errors: [str]
class CustomDomainValidation:
def __init__(self, dkim_domain: str):
def __init__(
self,
dkim_domain: str,
dns_client: DNSClient = get_network_dns_client(),
partner_domains: Optional[dict[int, str]] = None,
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
):
self.dkim_domain = dkim_domain
self._dkim_records = {
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
self._dns_client = dns_client
self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes
or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
)
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = "sl"
if (
domain.partner_id is not None
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
if not domain.ownership_txt_token:
domain.ownership_txt_token = random_string(30)
Session.commit()
return f"{prefix}-verification={domain.ownership_txt_token}"
def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
records = []
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
domain = self._partner_domains[domain.partner_id]
records.append(MxRecord(10, f"mx1.{domain}."))
records.append(MxRecord(20, f"mx2.{domain}."))
else:
# Default ones
for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
records.append(MxRecord(priority, domain))
return records
def get_expected_spf_domain(self, domain: CustomDomain) -> str:
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
return self._partner_domains[domain.partner_id]
else:
return config.EMAIL_DOMAIN
def get_expected_spf_record(self, domain: CustomDomain) -> str:
spf_domain = self.get_expected_spf_domain(domain)
return f"v=spf1 include:{spf_domain} ~all"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
it will return the default ones or the partner ones.
"""
# By default use the default domain
dkim_domain = self.dkim_domain
if domain.partner_id is not None:
# Domain is from a partner. Retrieve the partner config and use that domain if exists
dkim_domain = self._partner_domains.get(domain.partner_id, dkim_domain)
return {
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}
def get_dkim_records(self) -> {str: str}:
"""
Get a list of dkim records to set up. It will be
"""
return self._dkim_records
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
"""
Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
correct_records = {}
invalid_records = {}
for prefix, expected_record in self.get_dkim_records():
expected_records = self.get_dkim_records(custom_domain)
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = get_cname_record(custom_record)
if dkim_record != expected_record:
dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record == expected_record:
correct_records[prefix] = custom_record
else:
invalid_records[custom_record] = dkim_record or "empty"
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
# HACK
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
# the domain validated to continue seeing it as validated (although showing them the missing records).
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
# we will remove the dkim_verified flag.
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
if custom_domain.dkim_verified:
return invalid_records
# Check if at least the original dkim is there
if correct_records.get("dkim._domainkey") is not None:
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
return invalid_records
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0
if custom_domain.dkim_verified:
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified DKIM records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit()
return invalid_records
def validate_domain_ownership(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
expected_verification_record = self.get_ownership_verification_record(
custom_domain
)
if expected_verification_record in txt_records:
custom_domain.ownership_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified ownership for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
return DomainValidationResult(success=False, errors=txt_records)
def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
expected_mx_records = self.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult(
success=False,
errors=[f"{record.priority} {record.domain}" for record in mx_domains],
)
else:
custom_domain.verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified MX records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit()
return DomainValidationResult(success=True, errors=[])
def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified SPF records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.spf_verified = False
Session.commit()
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
cleaned_records = self.__clean_spf_records(txt_records, custom_domain)
return DomainValidationResult(
success=False,
errors=cleaned_records,
)
def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
emit_user_audit_log(
user=custom_domain.user,
action=UserAuditLogAction.VerifyCustomDomain,
message=f"Verified DMARC records for custom domain {custom_domain.id} ({custom_domain.domain})",
)
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)
def __clean_spf_records(
self, txt_records: List[str], custom_domain: CustomDomain
) -> List[str]:
final_records = []
verification_record = self.get_ownership_verification_record(custom_domain)
for record in txt_records:
if record != verification_record:
final_records.append(record)
return final_records
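
The constructor now takes an injectable DNSClient plus partner overrides, which makes the validation logic testable without network access. A sketch under assumptions (StubDNSClient is a hypothetical stand-in for an app.dns_utils.DNSClient implementation, and the partner values are invented):

from app.custom_domain_validation import CustomDomainValidation

validator = CustomDomainValidation(
    dkim_domain="simplelogin.co",                         # example value
    dns_client=StubDNSClient(),                           # hypothetical stub returning canned records
    partner_domains={1: "mail.partner1.example"},         # made-up partner mapping
    partner_domains_validation_prefixes={1: "partner1"},
)
# For a CustomDomain with partner_id=1, the expected MX records become
# "10 mx1.mail.partner1.example." / "20 mx2.mail.partner1.example." and the ownership
# TXT record uses the "partner1-verification=" prefix instead of the default "sl-verification=".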

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass
from operator import or_
from typing import Optional
from flask import render_template, request, redirect, flash
from flask import url_for
@ -9,13 +10,11 @@ from sqlalchemy import and_, func, case
from wtforms import StringField, validators, ValidationError
# Need to import directly from config to allow modification from the tests
from app import config, parallel_limiter
from app import config, parallel_limiter, contact_utils
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.contact_utils import ContactCreateError
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import (
generate_reply_email,
parse_full_address,
)
from app.email_validation import is_valid_email
from app.errors import (
CannotCreateContactForReverseAlias,
@ -24,8 +23,8 @@ from app.errors import (
ErrContactAlreadyExists,
)
from app.log import LOG
from app.models import Alias, Contact, EmailLog, User
from app.utils import sanitize_email, CSRFValidationForm
from app.models import Alias, Contact, EmailLog
from app.utils import CSRFValidationForm
def email_validator():
@ -51,7 +50,7 @@ def email_validator():
return _check
def create_contact(user: User, alias: Alias, contact_address: str) -> Contact:
def create_contact(alias: Alias, contact_address: str) -> Contact:
"""
Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS.
Can throw exceptions:
@ -61,37 +60,23 @@ def create_contact(user: User, alias: Alias, contact_address: str) -> Contact:
"""
if not contact_address:
raise ErrAddressInvalid("Empty address")
try:
contact_name, contact_email = parse_full_address(contact_address)
except ValueError:
output = contact_utils.create_contact(email=contact_address, alias=alias)
if output.error == ContactCreateError.InvalidEmail:
raise ErrAddressInvalid(contact_address)
contact_email = sanitize_email(contact_email)
if not is_valid_email(contact_email):
raise ErrAddressInvalid(contact_email)
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
raise ErrContactAlreadyExists(contact)
if not user.can_create_contacts():
elif output.error == ContactCreateError.NotAllowed:
raise ErrContactErrorUpgradeNeeded()
elif output.error is not None:
raise ErrAddressInvalid("Invalid address")
elif not output.created:
raise ErrContactAlreadyExists(output.contact)
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_email,
name=contact_name,
reply_email=generate_reply_email(contact_email, alias),
)
contact = output.contact
LOG.d(
"create reverse-alias for %s %s, reverse alias:%s",
contact_address,
alias,
contact.reply_email,
)
Session.commit()
return contact
@ -207,7 +192,7 @@ def get_contact_infos(
def delete_contact(alias: Alias, contact_id: int):
contact = Contact.get(contact_id)
contact: Optional[Contact] = Contact.get(contact_id)
if not contact:
flash("Unknown error. Refresh the page", "warning")
@ -215,6 +200,11 @@ def delete_contact(alias: Alias, contact_id: int):
flash("You cannot delete reverse-alias", "warning")
else:
delete_contact_email = contact.website_email
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.DeleteContact,
message=f"Delete contact {contact_id} ({contact.email})",
)
Contact.delete(contact_id)
Session.commit()
@ -261,7 +251,7 @@ def alias_contact_manager(alias_id):
if new_contact_form.validate():
contact_address = new_contact_form.email.data.strip()
try:
contact = create_contact(current_user, alias, contact_address)
contact = create_contact(alias, contact_address)
except (
ErrContactErrorUpgradeNeeded,
ErrAddressInvalid,

View File

@ -7,6 +7,7 @@ from flask import render_template, redirect, url_for, flash, request
from flask_login import login_required, current_user
from app import config
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_utils import transfer_alias
from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required
@ -57,6 +58,12 @@ def alias_transfer_send_route(alias_id):
transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}"
alias.transfer_token = hmac_alias_transfer_token(transfer_token)
alias.transfer_token_expiration = arrow.utcnow().shift(hours=24)
emit_alias_audit_log(
alias,
AliasAuditLogAction.InitiateTransferAlias,
"Initiated alias transfer",
)
Session.commit()
alias_transfer_url = (
config.URL

View File

@ -1,8 +1,11 @@
from typing import Optional
from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from wtforms import StringField, validators
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.models import Contact
@ -20,7 +23,7 @@ class PGPContactForm(FlaskForm):
@dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"])
@login_required
def contact_detail_route(contact_id):
contact = Contact.get(contact_id)
contact: Optional[Contact] = Contact.get(contact_id)
if not contact or contact.user_id != current_user.id:
flash("You cannot see this page", "warning")
return redirect(url_for("dashboard.index"))
@ -50,6 +53,11 @@ def contact_detail_route(contact_id):
except PGPException:
flash("Cannot add the public key, please verify it", "error")
else:
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Added PGP key {contact.pgp_public_key} for contact {contact_id} ({contact.email})",
)
Session.commit()
flash(
f"PGP public key for {contact.email} is saved successfully",
@ -62,6 +70,11 @@ def contact_detail_route(contact_id):
)
elif pgp_form.action.data == "remove":
# Free user can decide to remove contact PGP key
emit_alias_audit_log(
alias=alias,
action=AliasAuditLogAction.UpdateContact,
message=f"Removed PGP key {contact.pgp_public_key} for contact {contact_id} ({contact.email})",
)
contact.pgp_public_key = None
contact.pgp_finger_print = None
Session.commit()

View File

@ -5,11 +5,9 @@ from wtforms import StringField, validators
from app import parallel_limiter
from app.config import EMAIL_SERVERS_WITH_PRIORITY
from app.custom_domain_utils import create_custom_domain
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain
from app.models import CustomDomain
class NewCustomDomainForm(FlaskForm):
@ -23,13 +21,12 @@ class NewCustomDomainForm(FlaskForm):
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def custom_domain():
custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False
user_id=current_user.id,
is_sl_subdomain=False,
pending_deletion=False,
).all()
mailboxes = current_user.mailboxes()
new_custom_domain_form = NewCustomDomainForm()
errors = {}
if request.method == "POST":
if request.form.get("form-name") == "create":
if not current_user.is_premium():
@ -37,87 +34,25 @@ def custom_domain():
return redirect(url_for("dashboard.custom_domain"))
if new_custom_domain_form.validate():
new_domain = new_custom_domain_form.domain.data.lower().strip()
if new_domain.startswith("http://"):
new_domain = new_domain[len("http://") :]
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
if SLDomain.get_by(domain=new_domain):
flash("A custom domain cannot be a built-in domain.", "error")
elif CustomDomain.get_by(domain=new_domain):
flash(f"{new_domain} already used", "error")
elif get_email_domain_part(current_user.email) == new_domain:
flash(
"You cannot add a domain that you are currently using for your personal email. "
"Please change your personal email to your real email",
"error",
)
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{new_domain}")
).first():
flash(
f"{new_domain} already used in a SimpleLogin mailbox", "error"
)
else:
new_custom_domain = CustomDomain.create(
domain=new_domain, user_id=current_user.id
)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in current_user.custom_domains:
if (
new_domain.endswith("." + root_cd.domain)
and root_cd.ownership_verified
):
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
Session.commit()
mailbox_ids = request.form.getlist("mailbox_ids")
if mailbox_ids:
# check if mailbox is not tempered with
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if (
not mailbox
or mailbox.user_id != current_user.id
or not mailbox.verified
):
flash("Something went wrong, please retry", "warning")
return redirect(url_for("dashboard.custom_domain"))
mailboxes.append(mailbox)
for mailbox in mailboxes:
DomainMailbox.create(
domain_id=new_custom_domain.id, mailbox_id=mailbox.id
)
Session.commit()
flash(
f"New domain {new_custom_domain.domain} is created", "success"
)
res = create_custom_domain(
user=current_user, domain=new_custom_domain_form.domain.data
)
if res.success:
flash(f"New domain {res.instance.domain} is created", "success")
return redirect(
url_for(
"dashboard.domain_detail_dns",
custom_domain_id=new_custom_domain.id,
custom_domain_id=res.instance.id,
)
)
else:
flash(res.message, res.message_category)
if res.redirect:
return redirect(url_for(res.redirect))
return render_template(
"dashboard/custom_domain.html",
custom_domains=custom_domains,
new_custom_domain_form=new_custom_domain_form,
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
errors=errors,
mailboxes=mailboxes,
)
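
For orientation, a minimal sketch (not part of the diff) of how the new create_custom_domain helper is consumed, mirroring the view above. The user id and domain are placeholders, and the result fields (success, instance, message, message_category, redirect) are inferred from the view code rather than documented here.

from app.custom_domain_utils import create_custom_domain
from app.models import User

user = User.get(1234)  # placeholder id
res = create_custom_domain(user=user, domain="mail.example.org")
if res.success:
    # res.instance is the newly created CustomDomain row
    print(f"created {res.instance.domain} (id={res.instance.id})")
else:
    # message/message_category feed flash(); redirect, if set, names an endpoint
    print(res.message_category, res.message, res.redirect)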

View File

@ -8,6 +8,7 @@ from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required
from app.log import LOG
from app.models import Subscription, Job
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class DeleteDirForm(FlaskForm):
@ -33,6 +34,11 @@ def delete_account():
# Schedule delete account job
LOG.w("schedule delete account job for %s", current_user)
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UserMarkedForDeletion,
message=f"User {current_user.id} ({current_user.email}) marked for deletion via webapp",
)
Job.create(
name=JOB_DELETE_ACCOUNT,
payload={"user_id": current_user.id},

View File

@ -1,3 +1,5 @@
from typing import Optional
from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user
from flask_wtf import FlaskForm
@ -20,6 +22,7 @@ from app.dashboard.base import dashboard_bp
from app.db import Session
from app.errors import DirectoryInTrashError
from app.models import Directory, Mailbox, DirectoryMailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class NewDirForm(FlaskForm):
@ -69,7 +72,9 @@ def directory():
if not delete_dir_form.validate():
flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory"))
dir_obj = Directory.get(delete_dir_form.directory_id.data)
dir_obj: Optional[Directory] = Directory.get(
delete_dir_form.directory_id.data
)
if not dir_obj:
flash("Unknown error. Refresh the page", "warning")
@ -79,6 +84,11 @@ def directory():
return redirect(url_for("dashboard.directory"))
name = dir_obj.name
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.DeleteDirectory,
message=f"Delete directory {dir_obj.id} ({dir_obj.name})",
)
Directory.delete(dir_obj.id)
Session.commit()
flash(f"Directory {name} has been deleted", "success")
@ -90,7 +100,7 @@ def directory():
flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory"))
dir_id = toggle_dir_form.directory_id.data
dir_obj = Directory.get(dir_id)
dir_obj: Optional[Directory] = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning")
@ -103,6 +113,11 @@ def directory():
dir_obj.disabled = True
flash(f"On-the-fly is disabled for {dir_obj.name}", "warning")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateDirectory,
message=f"Updated directory {dir_obj.id} ({dir_obj.name}) set disabled = {dir_obj.disabled}",
)
Session.commit()
return redirect(url_for("dashboard.directory"))
@ -112,7 +127,7 @@ def directory():
flash("Invalid request", "warning")
return redirect(url_for("dashboard.directory"))
dir_id = update_dir_form.directory_id.data
dir_obj = Directory.get(dir_id)
dir_obj: Optional[Directory] = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning")
@ -143,6 +158,12 @@ def directory():
for mailbox in mailboxes:
DirectoryMailbox.create(directory_id=dir_obj.id, mailbox_id=mailbox.id)
mailboxes_as_str = ",".join(map(str, mailbox_ids))
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateDirectory,
message=f"Updated directory {dir_obj.id} ({dir_obj.name}) mailboxes ({mailboxes_as_str})",
)
Session.commit()
flash(f"Directory {dir_obj.name} has been updated", "success")
@ -181,6 +202,11 @@ def directory():
new_dir = Directory.create(
name=new_dir_name, user_id=current_user.id
)
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.CreateDirectory,
message=f"New directory {new_dir.name} ({new_dir.name})",
)
except DirectoryInTrashError:
flash(
f"{new_dir_name} has been used before and cannot be reused",

View File

@ -1,33 +1,26 @@
import re
import arrow
from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN
from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.custom_domain_utils import delete_custom_domain, set_custom_domain_mailboxes
from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
is_mx_equivalent,
)
from app.log import LOG
from app.models import (
CustomDomain,
Alias,
DomainDeletedAlias,
Mailbox,
DomainMailbox,
AutoCreateRule,
AutoCreateRuleMailbox,
Job,
)
from app.regex_utils import regex_match
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string, CSRFValidationForm
@ -44,13 +37,9 @@ def domain_detail_dns(custom_domain_id):
custom_domain.ownership_txt_token = random_string(30)
Session.commit()
spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"
domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()
dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []
@ -59,15 +48,14 @@ def domain_detail_dns(custom_domain_id):
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "check-ownership":
txt_records = get_txt_record(custom_domain.domain)
if custom_domain.get_ownership_dns_txt_value() in txt_records:
ownership_validation_result = domain_validator.validate_domain_ownership(
custom_domain
)
if ownership_validation_result.success:
flash(
"Domain ownership is verified. Please proceed to the other records setup",
"success",
)
custom_domain.ownership_verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns",
@ -78,36 +66,28 @@ def domain_detail_dns(custom_domain_id):
else:
flash("We can't find the needed TXT record", "error")
ownership_ok = False
ownership_errors = txt_records
ownership_errors = ownership_validation_result.errors
elif request.form.get("form-name") == "check-mx":
mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
flash("The MX record is not correctly set", "warning")
mx_ok = False
# build mx_errors to show to user
mx_errors = [
f"{priority} {domain}" for (priority, domain) in mx_domains
]
else:
mx_validation_result = domain_validator.validate_mx_records(custom_domain)
if mx_validation_result.success:
flash(
"Your domain can start receiving emails. You can now use it to create alias",
"success",
)
custom_domain.verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
flash("The MX record is not correctly set", "warning")
mx_ok = False
mx_errors = mx_validation_result.errors
elif request.form.get("form-name") == "check-spf":
spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
spf_validation_result = domain_validator.validate_spf_records(custom_domain)
if spf_validation_result.success:
flash("SPF is setup correctly", "success")
return redirect(
url_for(
@ -115,14 +95,12 @@ def domain_detail_dns(custom_domain_id):
)
)
else:
custom_domain.spf_verified = False
Session.commit()
flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning",
)
spf_ok = False
spf_errors = get_txt_record(custom_domain.domain)
spf_errors = spf_validation_result.errors
elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain)
@ -138,10 +116,10 @@ def domain_detail_dns(custom_domain_id):
flash("DKIM: the CNAME record is not correctly set", "warning")
elif request.form.get("form-name") == "check-dmarc":
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if dmarc_record in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
dmarc_validation_result = domain_validator.validate_dmarc_records(
custom_domain
)
if dmarc_validation_result.success:
flash("DMARC is setup correctly", "success")
return redirect(
url_for(
@ -149,19 +127,23 @@ def domain_detail_dns(custom_domain_id):
)
)
else:
custom_domain.dmarc_verified = False
Session.commit()
flash(
"DMARC: The TXT record is not correctly set",
"warning",
)
dmarc_ok = False
dmarc_errors = txt_records
dmarc_errors = dmarc_validation_result.errors
return render_template(
"dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(),
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),
dkim_records=domain_validator.get_dkim_records(custom_domain),
spf_record=domain_validator.get_expected_spf_record(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)
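
A rough sketch of the validator pattern this view now relies on: CustomDomainValidation is built from EMAIL_DOMAIN, and each validate_* call returns a result exposing success and errors (DKIM is the exception above, returning its error dict directly). The helper below is hypothetical and only mirrors what the route does.

from app.config import EMAIL_DOMAIN
from app.custom_domain_validation import CustomDomainValidation
from app.models import CustomDomain

def failing_dns_checks(custom_domain: CustomDomain) -> dict:
    # Hypothetical helper: run the same checks as the dashboard view and
    # return only the failing ones with their reported errors.
    validator = CustomDomainValidation(EMAIL_DOMAIN)
    results = {
        "ownership": validator.validate_domain_ownership(custom_domain),
        "mx": validator.validate_mx_records(custom_domain),
        "spf": validator.validate_spf_records(custom_domain),
        "dmarc": validator.validate_dmarc_records(custom_domain),
    }
    return {name: res.errors for name, res in results.items() if not res.success}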
@ -183,6 +165,11 @@ def domain_detail(custom_domain_id):
return redirect(request.url)
if request.form.get("form-name") == "switch-catch-all":
custom_domain.catch_all = not custom_domain.catch_all
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) catch all to {custom_domain.catch_all}",
)
Session.commit()
if custom_domain.catch_all:
@ -201,6 +188,11 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "set-name":
if request.form.get("action") == "save":
custom_domain.name = request.form.get("alias-name").replace("\n", "")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) name",
)
Session.commit()
flash(
f"Default alias name for Domain {custom_domain.domain} has been set",
@ -208,6 +200,11 @@ def domain_detail(custom_domain_id):
)
else:
custom_domain.name = None
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Cleared custom domain {custom_domain.id} ({custom_domain.domain}) name",
)
Session.commit()
flash(
f"Default alias name for Domain {custom_domain.domain} has been removed",
@ -221,6 +218,11 @@ def domain_detail(custom_domain_id):
custom_domain.random_prefix_generation = (
not custom_domain.random_prefix_generation
)
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateCustomDomain,
message=f"Switched custom domain {custom_domain.id} ({custom_domain.domain}) random prefix generation to {custom_domain.random_prefix_generation}",
)
Session.commit()
if custom_domain.random_prefix_generation:
@ -238,40 +240,16 @@ def domain_detail(custom_domain_id):
)
elif request.form.get("form-name") == "update":
mailbox_ids = request.form.getlist("mailbox_ids")
# check if mailbox is not tempered with
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if (
not mailbox
or mailbox.user_id != current_user.id
or not mailbox.verified
):
flash("Something went wrong, please retry", "warning")
return redirect(
url_for(
"dashboard.domain_detail", custom_domain_id=custom_domain.id
)
)
mailboxes.append(mailbox)
result = set_custom_domain_mailboxes(
user_id=current_user.id,
custom_domain=custom_domain,
mailbox_ids=mailbox_ids,
)
if not mailboxes:
flash("You must select at least 1 mailbox", "warning")
return redirect(
url_for(
"dashboard.domain_detail", custom_domain_id=custom_domain.id
)
)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
Session.commit()
flash(f"{custom_domain.domain} mailboxes has been updated", "success")
if result.success:
flash(f"{custom_domain.domain} mailboxes has been updated", "success")
else:
flash(result.reason.value, "warning")
return redirect(
url_for("dashboard.domain_detail", custom_domain_id=custom_domain.id)
@ -279,16 +257,8 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "delete":
name = custom_domain.domain
LOG.d("Schedule deleting %s", custom_domain)
# Schedule delete domain job
LOG.w("schedule delete domain job for %s", custom_domain)
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": custom_domain.id},
run_at=arrow.now(),
commit=True,
)
delete_custom_domain(custom_domain)
flash(
f"{name} scheduled for deletion."

View File

@ -149,7 +149,9 @@ def index():
)
flash(f"Alias {email} has been deleted", "success")
elif request.form.get("form-name") == "disable-alias":
alias_utils.change_alias_status(alias, enabled=False)
alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False from dashboard"
)
Session.commit()
flash(f"Alias {alias.email} has been disabled", "success")

View File

@ -1,6 +1,7 @@
import base64
import binascii
import json
from typing import Optional
from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user
@ -15,6 +16,7 @@ from app.dashboard.base import dashboard_bp
from app.db import Session
from app.log import LOG
from app.models import Mailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import CSRFValidationForm
@ -123,7 +125,12 @@ def mailbox_verify():
if not code:
# Old way
return verify_with_signed_secret(mailbox_id)
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
try:
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
except mailbox_utils.MailboxError as e:
LOG.i(f"Cannot verify mailbox {mailbox_id} because of {e}")
flash(f"Cannot verify mailbox: {e.msg}", "error")
return redirect(url_for("dashboard.mailbox_route"))
LOG.d("Mailbox %s is verified", mailbox)
return render_template("dashboard/mailbox_validation.html", mailbox=mailbox)
@ -146,7 +153,7 @@ def verify_with_signed_secret(request: str):
flash("Invalid link. Please delete and re-add your mailbox", "error")
return redirect(url_for("dashboard.mailbox_route"))
mailbox_id = mailbox_data[0]
mailbox = Mailbox.get(mailbox_id)
mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox:
flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route"))
@ -156,6 +163,11 @@ def verify_with_signed_secret(request: str):
return redirect(url_for("dashboard.mailbox_route"))
mailbox.verified = True
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.VerifyMailbox,
message=f"Verified mailbox {mailbox.id} ({mailbox.email})",
)
Session.commit()
LOG.d("Mailbox %s is verified", mailbox)

View File

@ -16,10 +16,11 @@ from app.db import Session
from app.email_utils import email_can_be_used_as_mailbox
from app.email_utils import mailbox_already_used, render, send_email
from app.extensions import limiter
from app.log import LOG
from app.mailbox_utils import perform_mailbox_email_change, MailboxEmailChangeError
from app.models import Alias, AuthorizedAddress
from app.models import Mailbox
from app.pgp_utils import PGPException, load_public_key_and_check
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email, CSRFValidationForm
@ -88,8 +89,12 @@ def mailbox_detail_route(mailbox_id):
flash("SPF enforcement globally not enabled", "error")
return redirect(url_for("dashboard.index"))
mailbox.force_spf = (
True if request.form.get("spf-status") == "on" else False
force_spf_value = request.form.get("spf-status") == "on"
mailbox.force_spf = force_spf_value
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set force_spf to {force_spf_value} on mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit()
flash(
@ -113,6 +118,11 @@ def mailbox_detail_route(mailbox_id):
if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address):
flash(f"{address} already added", "error")
else:
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Add authorized address {address} to mailbox {mailbox_id} ({mailbox.email})",
)
AuthorizedAddress.create(
user_id=current_user.id,
mailbox_id=mailbox.id,
@ -133,6 +143,11 @@ def mailbox_detail_route(mailbox_id):
flash("Unknown error. Refresh the page", "warning")
else:
address = authorized_address.email
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove authorized address {address} from mailbox {mailbox_id} ({mailbox.email})",
)
AuthorizedAddress.delete(authorized_address_id)
Session.commit()
flash(f"{address} has been deleted", "success")
@ -165,6 +180,11 @@ def mailbox_detail_route(mailbox_id):
except PGPException:
flash("Cannot add the public key, please verify it", "error")
else:
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Add PGP Key {mailbox.pgp_finger_print} to mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit()
flash("Your PGP public key is saved successfully", "success")
return redirect(
@ -172,6 +192,11 @@ def mailbox_detail_route(mailbox_id):
)
elif request.form.get("action") == "remove":
# Free user can decide to remove their added PGP key
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove PGP Key {mailbox.pgp_finger_print} from mailbox {mailbox_id} ({mailbox.email})",
)
mailbox.pgp_public_key = None
mailbox.pgp_finger_print = None
mailbox.disable_pgp = False
@ -191,9 +216,19 @@ def mailbox_detail_route(mailbox_id):
)
else:
mailbox.disable_pgp = False
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Enabled PGP for mailbox {mailbox_id} ({mailbox.email})",
)
flash(f"PGP is enabled on {mailbox.email}", "info")
else:
mailbox.disable_pgp = True
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Disabled PGP for mailbox {mailbox_id} ({mailbox.email})",
)
flash(f"PGP is disabled on {mailbox.email}", "info")
Session.commit()
@ -203,6 +238,11 @@ def mailbox_detail_route(mailbox_id):
elif request.form.get("form-name") == "generic-subject":
if request.form.get("action") == "save":
mailbox.generic_subject = request.form.get("generic-subject")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set generic subject for mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit()
flash("Generic subject is enabled", "success")
return redirect(
@ -210,6 +250,11 @@ def mailbox_detail_route(mailbox_id):
)
elif request.form.get("action") == "remove":
mailbox.generic_subject = None
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Remove generic subject for mailbox {mailbox_id} ({mailbox.email})",
)
Session.commit()
flash("Generic subject is disabled", "success")
return redirect(
@ -272,7 +317,7 @@ def cancel_mailbox_change_route(mailbox_id):
@dashboard_bp.route("/mailbox/confirm_change")
def mailbox_confirm_change_route():
def mailbox_confirm_email_change_route():
s = TimestampSigner(MAILBOX_SECRET)
signed_mailbox_id = request.args.get("mailbox_id")
@ -281,30 +326,20 @@ def mailbox_confirm_change_route():
except Exception:
flash("Invalid link", "error")
return redirect(url_for("dashboard.index"))
else:
mailbox = Mailbox.get(mailbox_id)
# new_email can be None if user cancels change in the meantime
if mailbox and mailbox.new_email:
user = mailbox.user
if Mailbox.get_by(email=mailbox.new_email, user_id=user.id):
flash(f"{mailbox.new_email} is already used", "error")
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id)
)
res = perform_mailbox_email_change(mailbox_id)
mailbox.email = mailbox.new_email
mailbox.new_email = None
# mark mailbox as verified if the change request is sent from an unverified mailbox
mailbox.verified = True
Session.commit()
LOG.d("Mailbox change %s is verified", mailbox)
flash(f"The {mailbox.email} is updated", "success")
flash(res.message, res.message_category)
if res.error:
if res.error == MailboxEmailChangeError.EmailAlreadyUsed:
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id)
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)
else:
flash("Invalid link", "error")
elif res.error == MailboxEmailChangeError.InvalidId:
return redirect(url_for("dashboard.index"))
else:
raise Exception("Unhandled MailboxEmailChangeError")
else:
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)

View File

@ -11,6 +11,7 @@ from app.dashboard.base import dashboard_bp
from app.errors import SubdomainInTrashError
from app.log import LOG
from app.models import CustomDomain, Mailbox, SLDomain
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
# Only lowercase letters, numbers, dashes (-) are currently supported
_SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}"
@ -102,6 +103,12 @@ def subdomain_route():
ownership_verified=True,
commit=True,
)
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.CreateCustomDomain,
message=f"Create subdomain {new_custom_domain.id} ({full_domain})",
commit=True,
)
except SubdomainInTrashError:
flash(
f"{full_domain} has been used before and cannot be reused",

View File

@ -32,7 +32,9 @@ def unsubscribe(alias_id):
# automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058
if request.method == "POST":
alias_utils.change_alias_status(alias, False)
alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False from unsubscribe request"
)
flash(f"Alias {alias.email} has been blocked", "success")
Session.commit()

View File

@ -1,102 +1,22 @@
from app import config
from typing import Optional, List, Tuple
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
import dns.resolver
def _get_dns_resolver():
my_resolver = dns.resolver.Resolver()
my_resolver.nameservers = config.NAMESERVERS
return my_resolver
def get_ns(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
except Exception:
return []
return [a.to_text() for a in answers]
def get_cname_record(hostname) -> Optional[str]:
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
try:
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
except Exception:
return None
for a in answers:
ret = a.to_text()
return ret[:-1]
return None
def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
except Exception:
return []
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda prio_domain: prio_domain[0])
from app.config import NAMESERVERS
_include_spf = "include:"
def get_spf_domain(hostname) -> [str]:
"""return all domains listed in *include:*"""
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
return ret
def get_txt_record(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
ret.append(record)
return ret
@dataclass
class MxRecord:
priority: int
domain: str
def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
@ -105,16 +25,127 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0])
ref_mx_domains = sorted(
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
)
mx_domains = sorted(mx_domains, key=lambda x: x.priority)
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)
if len(mx_domains) < len(ref_mx_domains):
return False
for i in range(0, len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
for actual, expected in zip(mx_domains, ref_mx_domains):
if actual.domain != expected.domain:
return False
return True
class DNSClient(ABC):
@abstractmethod
def get_cname_record(self, hostname: str) -> Optional[str]:
pass
@abstractmethod
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
pass
def get_spf_domain(self, hostname: str) -> List[str]:
"""
return all domains listed in *include:*
"""
try:
records = self.get_txt_record(hostname)
ret = []
for record in records:
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []
@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
pass
class NetworkDNSClient(DNSClient):
def __init__(self, nameservers: List[str]):
self._resolver = dns.resolver.Resolver()
self._resolver.nameservers = nameservers
def get_cname_record(self, hostname: str) -> Optional[str]:
"""
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
"""
try:
answers = self._resolver.resolve(hostname, "CNAME", search=True)
for a in answers:
ret = a.to_text()
return ret[:-1]
except Exception:
return None
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = self._resolver.resolve(hostname, "MX", search=True)
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
return sorted(ret, key=lambda x: x.priority)
except Exception:
return []
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=False)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
ret.append(record.decode())
return ret
except Exception:
return []
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[MxRecord]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[MxRecord]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x.priority)
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> List[MxRecord]:
return get_network_dns_client().get_mx_domains(hostname)
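
A small usage sketch of the reworked DNS layer, assuming InMemoryDNSClient is meant for tests (it never touches the network) and that only relative priority order matters to is_mx_equivalent, as its docstring states; the hostnames below are illustrative.

from app.dns_utils import InMemoryDNSClient, MxRecord, is_mx_equivalent

dns_client = InMemoryDNSClient()
# Records are set explicitly instead of being resolved over the network.
dns_client.set_mx_records(
    "example.com",
    [
        MxRecord(priority=10, domain="mx1.example.net."),
        MxRecord(priority=20, domain="mx2.example.net."),
    ],
)

expected = [
    MxRecord(priority=1, domain="mx1.example.net."),
    MxRecord(priority=2, domain="mx2.example.net."),
]

# Priority numbers differ (10/20 vs 1/2) but the order and domains match,
# so the record sets are considered equivalent.
assert is_mx_equivalent(dns_client.get_mx_domains("example.com"), expected)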

View File

@ -592,7 +592,7 @@ def email_can_be_used_as_mailbox(email_address: str) -> bool:
from app.models import CustomDomain
if CustomDomain.get_by(domain=domain, verified=True):
if CustomDomain.get_by(domain=domain, is_sl_subdomain=True, verified=True):
LOG.d("domain %s is a SimpleLogin custom domain", domain)
return False
@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]:
"""
priority_domains = get_mx_domains(domain)
return [d[:-1] for _, d in priority_domains]
return [d.domain[:-1] for d in priority_domains]
def personal_email_already_used(email_address: str) -> bool:

View File

@ -30,14 +30,30 @@ class PostgresDispatcher(Dispatcher):
return PostgresDispatcher()
class GlobalDispatcher:
__dispatcher: Optional[Dispatcher] = None
@staticmethod
def get_dispatcher() -> Dispatcher:
if not GlobalDispatcher.__dispatcher:
GlobalDispatcher.__dispatcher = PostgresDispatcher.get()
return GlobalDispatcher.__dispatcher
@staticmethod
def set_dispatcher(dispatcher: Optional[Dispatcher]):
GlobalDispatcher.__dispatcher = dispatcher
class EventDispatcher:
@staticmethod
def send_event(
user: User,
content: event_pb2.EventContent,
dispatcher: Dispatcher = PostgresDispatcher.get(),
dispatcher: Optional[Dispatcher] = None,
skip_if_webhook_missing: bool = True,
):
if dispatcher is None:
dispatcher = GlobalDispatcher.get_dispatcher()
if config.EVENT_WEBHOOK_DISABLE:
LOG.i("Not sending events because webhook is disabled")
return
@ -48,15 +64,9 @@ class EventDispatcher:
)
return
if config.EVENT_WEBHOOK_ENABLED_USER_IDS is not None:
if user.id not in config.EVENT_WEBHOOK_ENABLED_USER_IDS:
return
partner_user = EventDispatcher.__partner_user(user.id)
if not partner_user:
LOG.i(
f"Not sending events because there's no partner user for user {user}"
)
LOG.i(f"Not sending events because there's no partner user for user {user}")
return
event = event_pb2.Event(
@ -68,7 +78,9 @@ class EventDispatcher:
serialized = event.SerializeToString()
dispatcher.send(serialized)
newrelic.agent.record_custom_metric("Custom/events_stored", 1)
event_type = content.WhichOneof("content")
newrelic.agent.record_custom_event("EventStoredToDb", {"type": event_type})
LOG.i("Sent event to the dispatcher")
@staticmethod
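
The default dispatcher is no longer bound at import time; GlobalDispatcher lets callers (typically tests) swap it. A minimal sketch, assuming Dispatcher exposes a send(bytes) method as the dispatcher.send(serialized) call above suggests; RecordingDispatcher is invented for illustration.

from typing import List
from app.events.event_dispatcher import Dispatcher, GlobalDispatcher

class RecordingDispatcher(Dispatcher):
    # Hypothetical in-memory dispatcher that just keeps serialized events.
    def __init__(self):
        self.events: List[bytes] = []

    def send(self, event: bytes):
        self.events.append(event)

recorder = RecordingDispatcher()
GlobalDispatcher.set_dispatcher(recorder)
# ... exercise code that calls EventDispatcher.send_event(user, content) ...
GlobalDispatcher.set_dispatcher(None)  # revert to the Postgres dispatcher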

View File

@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"Z\n\x0c\x41liasCreated\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x12\n\nalias_note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\"L\n\x12\x41liasStatusChanged\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\"5\n\x0c\x41liasDeleted\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"\\\n\x0c\x41liasCreated\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0c\n\x04note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\x12\x12\n\ncreated_at\x18\x05 \x01(\r\"T\n\x12\x41liasStatusChanged\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\x12\x12\n\ncreated_at\x18\x04 \x01(\r\")\n\x0c\x41liasDeleted\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -36,15 +36,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_USERDELETED']._serialized_start=77
_globals['_USERDELETED']._serialized_end=90
_globals['_ALIASCREATED']._serialized_start=92
_globals['_ALIASCREATED']._serialized_end=182
_globals['_ALIASSTATUSCHANGED']._serialized_start=184
_globals['_ALIASSTATUSCHANGED']._serialized_end=260
_globals['_ALIASDELETED']._serialized_start=262
_globals['_ALIASDELETED']._serialized_end=315
_globals['_ALIASCREATEDLIST']._serialized_start=317
_globals['_ALIASCREATEDLIST']._serialized_end=385
_globals['_EVENTCONTENT']._serialized_start=388
_globals['_EVENTCONTENT']._serialized_end=791
_globals['_EVENT']._serialized_start=793
_globals['_EVENT']._serialized_end=914
_globals['_ALIASCREATED']._serialized_end=184
_globals['_ALIASSTATUSCHANGED']._serialized_start=186
_globals['_ALIASSTATUSCHANGED']._serialized_end=270
_globals['_ALIASDELETED']._serialized_start=272
_globals['_ALIASDELETED']._serialized_end=313
_globals['_ALIASCREATEDLIST']._serialized_start=315
_globals['_ALIASCREATEDLIST']._serialized_end=383
_globals['_EVENTCONTENT']._serialized_start=386
_globals['_EVENTCONTENT']._serialized_end=789
_globals['_EVENT']._serialized_start=791
_globals['_EVENT']._serialized_end=912
# @@protoc_insertion_point(module_scope)

View File

@ -16,34 +16,38 @@ class UserDeleted(_message.Message):
def __init__(self) -> None: ...
class AliasCreated(_message.Message):
__slots__ = ("alias_id", "alias_email", "alias_note", "enabled")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
ALIAS_NOTE_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("id", "email", "note", "enabled", "created_at")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
NOTE_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int]
alias_id: int
alias_email: str
alias_note: str
CREATED_AT_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
note: str
enabled: bool
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., alias_note: _Optional[str] = ..., enabled: bool = ...) -> None: ...
created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., note: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasStatusChanged(_message.Message):
__slots__ = ("alias_id", "alias_email", "enabled")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("id", "email", "enabled", "created_at")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int]
alias_id: int
alias_email: str
CREATED_AT_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
enabled: bool
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., enabled: bool = ...) -> None: ...
created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasDeleted(_message.Message):
__slots__ = ("alias_id", "alias_email")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
alias_id: int
alias_email: str
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ...) -> None: ...
__slots__ = ("id", "email")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ...) -> None: ...
class AliasCreatedList(_message.Message):
__slots__ = ("events",)

View File

@ -103,7 +103,9 @@ class UnsubscribeHandler:
):
return status.E509
LOG.i(f"User disabled alias {alias} via unsubscribe header")
alias_utils.change_alias_status(alias, enabled=False)
alias_utils.change_alias_status(
alias, enabled=False, message="Set enabled=False via unsubscribe header"
)
Session.commit()
enable_alias_url = config.URL + f"/dashboard/?highlight_alias_id={alias.id}"
for mailbox in alias.mailboxes:

View File

@ -22,10 +22,11 @@ def send_alias_creation_events_for_user(
):
event_list.append(
AliasCreated(
alias_id=alias.id,
alias_email=alias.email,
alias_note=alias.note,
id=alias.id,
email=alias.email,
note=alias.note,
enabled=alias.enabled,
created_at=int(alias.created_at.timestamp),
)
)
if len(event_list) >= chunk_size:

View File

@ -1,6 +1,7 @@
import dataclasses
import secrets
import random
from enum import Enum
from typing import Optional
import arrow
@ -16,6 +17,7 @@ from app.email_utils import (
from app.email_validation import is_valid_email
from app.log import LOG
from app.models import User, Mailbox, Job, MailboxActivation
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@dataclasses.dataclass
@ -70,9 +72,15 @@ def create_mailbox(
f"User {user} has tried to create mailbox with {email} but email is invalid"
)
raise MailboxError("Invalid email")
new_mailbox = Mailbox.create(
new_mailbox: Mailbox = Mailbox.create(
email=email, user_id=user.id, verified=verified, commit=True
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.CreateMailbox,
message=f"Create mailbox {new_mailbox.id} ({new_mailbox.email}). Verified={verified}",
commit=True,
)
if verified:
LOG.i(f"User {user} has created a pre-verified mailbox with {email}")
@ -129,7 +137,7 @@ def delete_mailbox(
if not transfer_mailbox.verified:
LOG.i(f"User {user} has tried to transfer to a non verified mailbox")
MailboxError("Your new mailbox is not verified")
raise MailboxError("Your new mailbox is not verified")
# Schedule delete account job
LOG.i(
@ -204,6 +212,11 @@ def verify_mailbox_code(user: User, mailbox_id: int, code: str) -> Mailbox:
raise CannotVerifyError("Invalid activation code")
LOG.i(f"User {user} has verified mailbox {mailbox_id}")
mailbox.verified = True
emit_user_audit_log(
user=user,
action=UserAuditLogAction.VerifyMailbox,
message=f"Verify mailbox {mailbox_id} ({mailbox.email})",
)
clear_activation_codes_for_mailbox(mailbox)
return mailbox
@ -213,7 +226,10 @@ def generate_activation_code(
) -> MailboxActivation:
clear_activation_codes_for_mailbox(mailbox)
if use_digit_code:
code = "{:06d}".format(random.randint(1, 999999))
if config.MAILBOX_VERIFICATION_OVERRIDE_CODE:
code = config.MAILBOX_VERIFICATION_OVERRIDE_CODE
else:
code = "{:06d}".format(random.randint(1, 999999))
else:
code = secrets.token_urlsafe(16)
return MailboxActivation.create(
@ -258,3 +274,54 @@ def send_verification_email(
mailbox_email=mailbox.email,
),
)
class MailboxEmailChangeError(Enum):
InvalidId = 1
EmailAlreadyUsed = 2
@dataclasses.dataclass
class MailboxEmailChangeResult:
error: Optional[MailboxEmailChangeError]
message: str
message_category: str
def perform_mailbox_email_change(mailbox_id: int) -> MailboxEmailChangeResult:
mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
# new_email can be None if user cancels change in the meantime
if mailbox and mailbox.new_email:
user = mailbox.user
if Mailbox.get_by(email=mailbox.new_email, user_id=user.id):
return MailboxEmailChangeResult(
error=MailboxEmailChangeError.EmailAlreadyUsed,
message=f"{mailbox.new_email} is already used",
message_category="error",
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Change mailbox email for mailbox {mailbox_id} (old={mailbox.email} | new={mailbox.new_email})",
)
mailbox.email = mailbox.new_email
mailbox.new_email = None
# mark mailbox as verified if the change request is sent from an unverified mailbox
mailbox.verified = True
Session.commit()
LOG.d("Mailbox change %s is verified", mailbox)
return MailboxEmailChangeResult(
error=None,
message=f"The {mailbox.email} is updated",
message_category="success",
)
else:
return MailboxEmailChangeResult(
error=MailboxEmailChangeError.InvalidId,
message="Invalid link",
message_category="error",
)
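
perform_mailbox_email_change now carries the whole outcome in MailboxEmailChangeResult, so callers other than the dashboard route can follow the same pattern; a sketch with a placeholder mailbox id.

from app.mailbox_utils import (
    MailboxEmailChangeError,
    perform_mailbox_email_change,
)

res = perform_mailbox_email_change(mailbox_id=42)  # placeholder id
if res.error is None:
    print(res.message)  # e.g. "The <new email> is updated"
elif res.error == MailboxEmailChangeError.EmailAlreadyUsed:
    print("new address already in use:", res.message)
elif res.error == MailboxEmailChangeError.InvalidId:
    print("no pending email change for this mailbox")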

View File

@ -336,7 +336,7 @@ class Fido(Base, ModelMixin):
class User(Base, ModelMixin, UserMixin, PasswordOracle):
__tablename__ = "users"
FLAG_FREE_DISABLE_CREATE_ALIAS = 1 << 0
FLAG_DISABLE_CREATE_CONTACTS = 1 << 0
FLAG_CREATED_FROM_PARTNER = 1 << 1
FLAG_FREE_OLD_ALIAS_LIMIT = 1 << 2
FLAG_CREATED_ALIAS_FROM_PARTNER = 1 << 3
@ -543,7 +543,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
# bitwise flags. Allow for future expansion
flags = sa.Column(
sa.BigInteger,
default=FLAG_FREE_DISABLE_CREATE_ALIAS,
default=FLAG_DISABLE_CREATE_CONTACTS,
server_default="0",
nullable=False,
)
@ -616,6 +616,15 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
if "alternative_id" not in kwargs:
user.alternative_id = str(uuid.uuid4())
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
trail = ". Created from partner" if from_partner else ""
emit_user_audit_log(
user=user,
action=UserAuditLogAction.CreateUser,
message=f"Created user {email}{trail}",
)
# If the user is created from partner, do not notify
# nor give a trial
if from_partner:
@ -973,7 +982,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def has_custom_domain(self):
return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0
def custom_domains(self):
def custom_domains(self) -> List["CustomDomain"]:
return CustomDomain.filter_by(user_id=self.id, verified=True).all()
def available_domains_for_random_alias(
@ -1168,7 +1177,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def can_create_contacts(self) -> bool:
if self.is_premium():
return True
if self.flags & User.FLAG_FREE_DISABLE_CREATE_ALIAS == 0:
if self.flags & User.FLAG_DISABLE_CREATE_CONTACTS == 0:
return True
return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS
@ -1660,18 +1669,6 @@ class Alias(Base, ModelMixin):
Session.add(new_alias)
DailyMetric.get_or_create_today_metric().nb_alias += 1
# Internal import to avoid global import cycles
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
alias_id=new_alias.id,
alias_email=new_alias.email,
alias_note=new_alias.note,
enabled=True,
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
if (
new_alias.flags & cls.FLAG_PARTNER_CREATED > 0
and new_alias.user.flags & User.FLAG_CREATED_ALIAS_FROM_PARTNER == 0
@ -1684,6 +1681,23 @@ class Alias(Base, ModelMixin):
if flush:
Session.flush()
# Internal import to avoid global import cycles
from app.alias_audit_log_utils import AliasAuditLogAction, emit_alias_audit_log
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
id=new_alias.id,
email=new_alias.email,
note=new_alias.note,
enabled=True,
created_at=int(new_alias.created_at.timestamp),
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
emit_alias_audit_log(
new_alias, AliasAuditLogAction.CreateAlias, "New alias created"
)
return new_alias
@classmethod
@ -1862,6 +1876,8 @@ class Contact(Base, ModelMixin):
MAX_NAME_LENGTH = 512
FLAG_PARTNER_CREATED = 1 << 0
__tablename__ = "contact"
__table_args__ = (
@ -1920,6 +1936,9 @@ class Contact(Base, ModelMixin):
# whether contact is created automatically during the forward phase
automatic_created = sa.Column(sa.Boolean, nullable=True, default=False)
# contact flags
flags = sa.Column(sa.Integer, nullable=False, default=0, server_default="0")
@property
def email(self):
return self.website_email
@ -2418,6 +2437,18 @@ class CustomDomain(Base, ModelMixin):
sa.Boolean, nullable=False, default=False, server_default="0"
)
partner_id = sa.Column(
sa.Integer,
sa.ForeignKey("partner.id"),
nullable=True,
default=None,
server_default=None,
)
pending_deletion = sa.Column(
sa.Boolean, nullable=False, default=False, server_default="0"
)
__table_args__ = (
Index(
"ix_unique_domain", # Index name
@ -2425,6 +2456,8 @@ class CustomDomain(Base, ModelMixin):
unique=True,
postgresql_where=Column("ownership_verified"),
), # The condition
Index("ix_custom_domain_user_id", "user_id"),
Index("ix_custom_domain_pending_deletion", "pending_deletion"),
)
user = orm.relationship(User, foreign_keys=[user_id], backref="custom_domains")
@ -2442,9 +2475,6 @@ class CustomDomain(Base, ModelMixin):
def get_trash_url(self):
return config.URL + f"/dashboard/domains/{self.id}/trash"
def get_ownership_dns_txt_value(self):
return f"sl-verification={self.ownership_txt_token}"
@classmethod
def create(cls, **kwargs):
domain = kwargs.get("domain")
@ -2749,9 +2779,9 @@ class Mailbox(Base, ModelMixin):
from app.email_utils import get_email_local_part
mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email))
mx_domains = get_mx_domains(get_email_local_part(self.email))
# Proton is the first domain
if mx_domains and mx_domains[0][1] in (
if mx_domains and mx_domains[0].domain in (
"mail.protonmail.ch.",
"mailsec.protonmail.ch.",
):
@ -3750,15 +3780,14 @@ class SyncEvent(Base, ModelMixin):
sa.Index("ix_sync_event_taken_time", "taken_time"),
)
def mark_as_taken(self) -> bool:
sql = """
UPDATE sync_event
SET taken_time = :taken_time
WHERE id = :sync_event_id
AND taken_time IS NULL
"""
def mark_as_taken(self, allow_taken_older_than: Optional[Arrow] = None) -> bool:
taken_condition = ["taken_time IS NULL"]
args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id}
if allow_taken_older_than:
taken_condition.append("taken_time < :taken_older_than")
args["taken_older_than"] = allow_taken_older_than.datetime
sql_taken_condition = "({})".format(" OR ".join(taken_condition))
sql = f"UPDATE sync_event SET taken_time = :taken_time WHERE id = :sync_event_id AND {sql_taken_condition}"
res = Session.execute(sql, args)
Session.commit()
@ -3784,3 +3813,39 @@ class SyncEvent(Base, ModelMixin):
.limit(100)
.all()
)
class AliasAuditLog(Base, ModelMixin):
"""This model holds an audit log for all the actions performed to an alias"""
__tablename__ = "alias_audit_log"
user_id = sa.Column(sa.Integer, nullable=False)
alias_id = sa.Column(sa.Integer, nullable=False)
alias_email = sa.Column(sa.String(255), nullable=False)
action = sa.Column(sa.String(255), nullable=False)
message = sa.Column(sa.Text, default=None, nullable=True)
__table_args__ = (
sa.Index("ix_alias_audit_log_user_id", "user_id"),
sa.Index("ix_alias_audit_log_alias_id", "alias_id"),
sa.Index("ix_alias_audit_log_alias_email", "alias_email"),
sa.Index("ix_alias_audit_log_created_at", "created_at"),
)
class UserAuditLog(Base, ModelMixin):
"""This model holds an audit log for all the actions performed by a user"""
__tablename__ = "user_audit_log"
user_id = sa.Column(sa.Integer, nullable=False)
user_email = sa.Column(sa.String(255), nullable=False)
action = sa.Column(sa.String(255), nullable=False)
message = sa.Column(sa.Text, default=None, nullable=True)
__table_args__ = (
sa.Index("ix_user_audit_log_user_id", "user_id"),
sa.Index("ix_user_audit_log_user_email", "user_email"),
sa.Index("ix_user_audit_log_created_at", "created_at"),
)
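
Both audit tables are plain append-only models keyed by user/alias, so reading history back is a straightforward query; a sketch assuming created_at comes from ModelMixin, as the ix_*_created_at indexes imply.

from app.models import AliasAuditLog, UserAuditLog

def alias_history(alias_id: int, limit: int = 20):
    # Most recent actions performed on an alias, newest first.
    return (
        AliasAuditLog.filter_by(alias_id=alias_id)
        .order_by(AliasAuditLog.created_at.desc())
        .limit(limit)
        .all()
    )

def user_history(user_id: int, limit: int = 20):
    # Same pattern for account-level actions.
    return (
        UserAuditLog.filter_by(user_id=user_id)
        .order_by(UserAuditLog.created_at.desc())
        .limit(limit)
        .all()
    )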

View File

@ -0,0 +1,46 @@
from typing import Optional
from arrow import Arrow
from app.models import PartnerUser, PartnerSubscription, User
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
def create_partner_user(
user: User, partner_id: int, partner_email: str, external_user_id: str
) -> PartnerUser:
instance = PartnerUser.create(
user_id=user.id,
partner_id=partner_id,
partner_email=partner_email,
external_user_id=external_user_id,
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.LinkAccount,
message=f"Linked account to partner_id={partner_id} | partner_email={partner_email} | external_user_id={external_user_id}",
)
return instance
def create_partner_subscription(
partner_user: PartnerUser,
expiration: Optional[Arrow],
msg: Optional[str] = None,
) -> PartnerSubscription:
instance = PartnerSubscription.create(
partner_user_id=partner_user.id,
end_at=expiration,
)
message = "User upgraded through partner subscription"
if msg:
message += f" | {msg}"
emit_user_audit_log(
user=partner_user.user,
action=UserAuditLogAction.Upgrade,
message=message,
)
return instance
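
A usage sketch of the two new helpers, assuming they live in a module such as app.partner_user_utils (the file path is not shown in this diff); the ids and emails are placeholders.

import arrow
from app.models import User
from app.partner_user_utils import (  # assumed module path
    create_partner_subscription,
    create_partner_user,
)

user = User.get(1234)  # placeholder id
partner_user = create_partner_user(
    user=user,
    partner_id=1,
    partner_email="someone@partner.example",
    external_user_id="ext-abc-123",
)
# Grant a one-year subscription; both helpers also emit user audit log entries.
create_partner_subscription(
    partner_user=partner_user,
    expiration=arrow.now().shift(years=1),
    msg="Yearly plan granted by partner",
)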

View File


@ -0,0 +1,121 @@
from typing import Optional
import arrow
from coinbase_commerce.error import WebhookInvalidPayload, SignatureVerificationError
from coinbase_commerce.webhook import Webhook
from flask import Flask, request
from app.config import COINBASE_WEBHOOK_SECRET
from app.db import Session
from app.email_utils import send_email, render
from app.log import LOG
from app.models import CoinbaseSubscription, User
from app.subscription_webhook import execute_subscription_webhook
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
def setup_coinbase_commerce(app: Flask):
@app.route("/coinbase", methods=["POST"])
def coinbase_webhook():
# event payload
request_data = request.data.decode("utf-8")
# webhook signature
request_sig = request.headers.get("X-CC-Webhook-Signature", None)
try:
# signature verification and event object construction
event = Webhook.construct_event(
request_data, request_sig, COINBASE_WEBHOOK_SECRET
)
except (WebhookInvalidPayload, SignatureVerificationError) as e:
LOG.e("Invalid Coinbase webhook")
return str(e), 400
LOG.d("Coinbase event %s", event)
if event["type"] == "charge:confirmed":
if handle_coinbase_event(event):
return "success", 200
else:
return "error", 400
return "success", 200
def handle_coinbase_event(event) -> bool:
server_user_id = event["data"]["metadata"]["user_id"]
try:
user_id = int(server_user_id)
except ValueError:
user_id = int(float(server_user_id))
code = event["data"]["code"]
user: Optional[User] = User.get(user_id)
if not user:
LOG.e("User not found %s", user_id)
return False
coinbase_subscription: CoinbaseSubscription = CoinbaseSubscription.get_by(
user_id=user_id
)
if not coinbase_subscription:
LOG.d("Create a coinbase subscription for %s", user)
coinbase_subscription = CoinbaseSubscription.create(
user_id=user_id, end_at=arrow.now().shift(years=1), code=code, commit=True
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.Upgrade,
message="Upgraded through Coinbase",
commit=True,
)
send_email(
user.email,
"Your SimpleLogin account has been upgraded",
render(
"transactional/coinbase/new-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/new-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
else:
if coinbase_subscription.code != code:
LOG.d("Update code from %s to %s", coinbase_subscription.code, code)
coinbase_subscription.code = code
if coinbase_subscription.is_active():
coinbase_subscription.end_at = coinbase_subscription.end_at.shift(years=1)
else: # already expired subscription
coinbase_subscription.end_at = arrow.now().shift(years=1)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended coinbase subscription",
)
Session.commit()
send_email(
user.email,
"Your SimpleLogin account has been extended",
render(
"transactional/coinbase/extend-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/extend-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
execute_subscription_webhook(user)
return True

app/app/payments/paddle.py
View File

@ -0,0 +1,286 @@
import arrow
import json
from dateutil.relativedelta import relativedelta
from flask import Flask, request
from app import paddle_utils, paddle_callback
from app.config import (
PADDLE_MONTHLY_PRODUCT_ID,
PADDLE_MONTHLY_PRODUCT_IDS,
PADDLE_YEARLY_PRODUCT_IDS,
PADDLE_COUPON_ID,
)
from app.db import Session
from app.email_utils import send_email, render
from app.log import LOG
from app.models import Subscription, PlanEnum, User, Coupon
from app.subscription_webhook import execute_subscription_webhook
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import random_string
def setup_paddle_callback(app: Flask):
@app.route("/paddle", methods=["GET", "POST"])
def paddle():
LOG.d(f"paddle callback {request.form.get('alert_name')} {request.form}")
# make sure the request comes from Paddle
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
if (
request.form.get("alert_name") == "subscription_created"
): # new user subscribes
# the passthrough is json encoded, e.g.
# request.form.get("passthrough") = '{"user_id": 88 }'
passthrough = json.loads(request.form.get("passthrough"))
user_id = passthrough.get("user_id")
user = User.get(user_id)
subscription_plan_id = int(request.form.get("subscription_plan_id"))
if subscription_plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
plan = PlanEnum.monthly
elif subscription_plan_id in PADDLE_YEARLY_PRODUCT_IDS:
plan = PlanEnum.yearly
else:
LOG.e(
"Unknown subscription_plan_id %s %s",
subscription_plan_id,
request.form,
)
return "No such subscription", 400
sub = Subscription.get_by(user_id=user.id)
if not sub:
LOG.d(f"create a new Subscription for user {user}")
Subscription.create(
user_id=user.id,
cancel_url=request.form.get("cancel_url"),
update_url=request.form.get("update_url"),
subscription_id=request.form.get("subscription_id"),
event_time=arrow.now(),
next_bill_date=arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date(),
plan=plan,
)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.Upgrade,
message="Upgraded through Paddle",
)
else:
LOG.d(f"Update an existing Subscription for user {user}")
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.subscription_id = request.form.get("subscription_id")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
# in case user cancels a plan and subscribes a new plan
sub.cancelled = False
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended Paddle subscription",
)
execute_subscription_webhook(user)
LOG.d("User %s upgrades!", user)
Session.commit()
elif request.form.get("alert_name") == "subscription_payment_succeeded":
subscription_id = request.form.get("subscription_id")
LOG.d("Update subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
# when user subscribes, the "subscription_payment_succeeded" can arrive BEFORE "subscription_created"
# at that time, subscription object does not exist yet
if sub:
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
Session.commit()
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_cancelled":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
# cancellation_effective_date should be the same as next_bill_date
LOG.w(
"Cancel subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
sub.event_time = arrow.now()
sub.cancelled = True
emit_user_audit_log(
user=sub.user,
action=UserAuditLogAction.SubscriptionCancelled,
message="Cancelled Paddle subscription",
)
Session.commit()
user = sub.user
send_email(
user.email,
"SimpleLogin - your subscription is canceled",
render(
"transactional/subscription-cancel.txt",
user=user,
end_date=request.form.get("cancellation_effective_date"),
),
)
execute_subscription_webhook(sub.user)
else:
# user might have deleted their account
LOG.i(f"Cancel non-exist subscription {subscription_id}")
return "OK"
elif request.form.get("alert_name") == "subscription_updated":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
next_bill_date = request.form.get("next_bill_date")
if not next_bill_date:
paddle_callback.failed_payment(sub, subscription_id)
return "OK"
LOG.d(
"Update subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
if (
int(request.form.get("subscription_plan_id"))
== PADDLE_MONTHLY_PRODUCT_ID
):
plan = PlanEnum.monthly
else:
plan = PlanEnum.yearly
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
sub.cancelled = False
emit_user_audit_log(
user=sub.user,
action=UserAuditLogAction.SubscriptionExtended,
message="Extended Paddle subscription",
)
Session.commit()
execute_subscription_webhook(sub.user)
else:
LOG.w(
f"update non-exist subscription {subscription_id}. {request.form}"
)
return "No such subscription", 400
elif request.form.get("alert_name") == "payment_refunded":
subscription_id = request.form.get("subscription_id")
LOG.d("Refund request for subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
user = sub.user
Subscription.delete(sub.id)
emit_user_audit_log(
user=user,
action=UserAuditLogAction.SubscriptionCancelled,
message="Paddle subscription cancelled as user requested a refund",
)
Session.commit()
LOG.e("%s requests a refund", user)
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_payment_refunded":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
LOG.d(
"Handle subscription_payment_refunded for subscription %s",
subscription_id,
)
if not sub:
LOG.w(
"No such subscription for %s, payload %s",
subscription_id,
request.form,
)
return "No such subscription"
plan_id = int(request.form["subscription_plan_id"])
if request.form["refund_type"] == "full":
if plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
LOG.d("subtract 1 month from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(months=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
elif plan_id in PADDLE_YEARLY_PRODUCT_IDS:
LOG.d("subtract 1 year from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(years=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
else:
LOG.e("Unknown plan_id %s", plan_id)
else:
LOG.w("partial subscription_payment_refunded, not handled")
execute_subscription_webhook(sub.user)
return "OK"
@app.route("/paddle_coupon", methods=["GET", "POST"])
def paddle_coupon():
LOG.d("paddle coupon callback %s", request.form)
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
product_id = request.form.get("p_product_id")
if product_id != PADDLE_COUPON_ID:
LOG.e("product_id %s not match with %s", product_id, PADDLE_COUPON_ID)
return "KO", 400
email = request.form.get("email")
LOG.d("Paddle coupon request for %s", email)
coupon = Coupon.create(
code=random_string(30),
comment="For 1-year coupon",
expires_date=arrow.now().shift(years=1, days=-1),
commit=True,
)
return (
f"Your 1-year coupon is <b>{coupon.code}</b> <br> "
f"It's valid until <b>{coupon.expires_date.date().isoformat()}</b>"
)

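To make the subscription_created branch easier to follow, here is a compressed, illustrative sketch of how the Paddle form fields are interpreted; the ids are hypothetical and signature verification is omitted:

import json
from app.config import PADDLE_MONTHLY_PRODUCT_IDS, PADDLE_YEARLY_PRODUCT_IDS
from app.models import PlanEnum

form = {
    "alert_name": "subscription_created",
    "passthrough": '{"user_id": 88}',  # JSON-encoded by the checkout page
    "subscription_plan_id": "765432",
}
user_id = json.loads(form["passthrough"])["user_id"]
plan_id = int(form["subscription_plan_id"])
if plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
    plan = PlanEnum.monthly
elif plan_id in PADDLE_YEARLY_PRODUCT_IDS:
    plan = PlanEnum.yearly
else:
    plan = None  # the route answers "No such subscription", 400 in this case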
View File

@ -2,9 +2,11 @@ from dataclasses import dataclass
from enum import Enum
from flask import url_for
from typing import Optional
import arrow
from app import config
from app.errors import LinkException
from app.models import User, Partner
from app.models import User, Partner, Job
from app.proton.proton_client import ProtonClient, ProtonUser
from app.account_linking import (
process_login_case,
@ -41,12 +43,21 @@ class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client
def _initial_alias_sync(self, user: User):
Job.create(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
payload={"user_id": user.id},
run_at=arrow.now(),
commit=True,
)
def handle_login(self, partner: Partner) -> ProtonCallbackResult:
try:
user = self.__get_partner_user()
if user is None:
return generate_account_not_allowed_to_log_in()
res = process_login_case(user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
@ -75,6 +86,7 @@ class ProtonCallbackHandler:
if user is None:
return generate_account_not_allowed_to_log_in()
res = process_link_case(user, current_user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult(
redirect_to_login=False,
flash_message="Account successfully linked",

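Note that _initial_alias_sync does not send the alias-creation events inline; it only enqueues a JOB_SEND_ALIAS_CREATION_EVENTS job. The job runner, also touched in this changeset, already imports send_alias_creation_events_for_user; the handler below is a hedged sketch of how such a job is likely consumed, not a copy of the actual process_job branch:

from typing import Optional
from app.jobs.event_jobs import send_alias_creation_events_for_user
from app.models import Job, User

def handle_send_alias_creation_events(job: Job) -> None:
    # Hypothetical handler shape; the real dispatch lives in job_runner.process_job().
    user: Optional[User] = User.get(job.payload.get("user_id"))
    if user is None:
        return
    send_alias_creation_events_for_user(user)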
View File

@ -5,6 +5,7 @@ from app.db import Session
from app.log import LOG
from app.errors import ProtonPartnerNotSetUp
from app.models import Partner, PartnerUser, User
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None
@ -32,6 +33,11 @@ def perform_proton_account_unlink(current_user: User):
)
if partner_user is not None:
LOG.info(f"User {current_user} has unlinked the account from {partner_user}")
emit_user_audit_log(
user=current_user,
action=UserAuditLogAction.UnlinkAccount,
message=f"User has unlinked the account (email={partner_user.partner_email} | external_user_id={partner_user.external_user_id})",
)
PartnerUser.delete(partner_user.id)
Session.commit()
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})

app/app/sentry_utils.py (new file, 21 lines added)

@ -0,0 +1,21 @@
from typing import Optional
from sentry_sdk.types import Event, Hint
_HTTP_CODES_TO_IGNORE = [416]
def _should_send(_event: Event, hint: Hint) -> bool:
# Check if this is an HTTP Exception event
if "exc_info" in hint:
exc_type, exc_value, exc_traceback = hint["exc_info"]
# Check if it's a Werkzeug HTTPException (raised for HTTP status codes)
if hasattr(exc_value, "code") and exc_value.code in _HTTP_CODES_TO_IGNORE:
return False
return True
def sentry_before_send(event: Event, hint: Hint) -> Optional[Event]:
if _should_send(event, hint):
return event
return None

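sentry_before_send drops events whose underlying Werkzeug HTTPException carries a status code in _HTTP_CODES_TO_IGNORE (currently only 416) and lets everything else through. It is wired into the SDK via the before_send hook, as the server.py change further down in this diff shows; a minimal sketch:

import sentry_sdk
from sentry_sdk.integrations.flask import FlaskIntegration
from app.config import SENTRY_DSN
from app.sentry_utils import sentry_before_send

sentry_sdk.init(
    dsn=SENTRY_DSN,
    integrations=[FlaskIntegration()],
    before_send=sentry_before_send,  # returning None from the hook makes the SDK drop the event
)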
View File

@ -1,40 +1,16 @@
import requests
from requests import RequestException
from app import config
from app.db import Session
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserPlanChanged
from app.log import LOG
from app.models import User
def execute_subscription_webhook(user: User):
webhook_url = config.SUBSCRIPTION_CHANGE_WEBHOOK
if webhook_url is None:
return
subscription_end = user.get_active_subscription_end(
include_partner_subscription=False
)
sl_subscription_end = None
if subscription_end:
sl_subscription_end = subscription_end.timestamp
payload = {
"user_id": user.id,
"is_premium": user.is_premium(),
"active_subscription_end": sl_subscription_end,
}
try:
response = requests.post(webhook_url, json=payload, timeout=2)
if response.status_code == 200:
LOG.i("Sent request to subscription update webhook successfully")
else:
LOG.i(
f"Request to webhook failed with status {response.status_code}: {response.text}"
)
except RequestException as e:
LOG.error(f"Subscription request exception: {e}")
event = UserPlanChanged(plan_end_time=sl_subscription_end)
EventDispatcher.send_event(user, EventContent(user_plan_change=event))
Session.commit()

View File

@ -0,0 +1,44 @@
from enum import Enum
from app.models import User, UserAuditLog
class UserAuditLogAction(Enum):
CreateUser = "create_user"
ActivateUser = "activate_user"
ResetPassword = "reset_password"
Upgrade = "upgrade"
SubscriptionExtended = "subscription_extended"
SubscriptionCancelled = "subscription_cancelled"
LinkAccount = "link_account"
UnlinkAccount = "unlink_account"
CreateMailbox = "create_mailbox"
VerifyMailbox = "verify_mailbox"
UpdateMailbox = "update_mailbox"
DeleteMailbox = "delete_mailbox"
CreateCustomDomain = "create_custom_domain"
VerifyCustomDomain = "verify_custom_domain"
UpdateCustomDomain = "update_custom_domain"
DeleteCustomDomain = "delete_custom_domain"
CreateDirectory = "create_directory"
UpdateDirectory = "update_directory"
DeleteDirectory = "delete_directory"
UserMarkedForDeletion = "user_marked_for_deletion"
DeleteUser = "delete_user"
def emit_user_audit_log(
user: User, action: UserAuditLogAction, message: str, commit: bool = False
):
UserAuditLog.create(
user_id=user.id,
user_email=user.email,
action=action.value,
message=message,
commit=commit,
)

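emit_user_audit_log is a thin wrapper over UserAuditLog.create; call sites either pass commit=True or commit once as part of a larger transaction. A representative, illustrative usage mirroring the set_default_mailbox change shown in the next file diff:

from app.db import Session
from app.models import User, Mailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction

def set_default_mailbox_audited(user: User, mailbox: Mailbox) -> None:
    # Illustrative only; the real logic is in the set_default_mailbox change below.
    user.default_mailbox_id = mailbox.id
    emit_user_audit_log(
        user=user,
        action=UserAuditLogAction.UpdateMailbox,
        message=f"Set mailbox {mailbox.id} ({mailbox.email}) as default",
    )
    Session.commit()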
View File

@ -3,6 +3,7 @@ from typing import Optional
from app.db import Session
from app.log import LOG
from app.models import User, SLDomain, CustomDomain, Mailbox
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
class CannotSetAlias(Exception):
@ -54,7 +55,7 @@ def set_default_alias_domain(user: User, domain_name: Optional[str]):
def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox:
mailbox = Mailbox.get(mailbox_id)
mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != user.id:
raise CannotSetMailbox("Invalid mailbox")
@ -67,5 +68,11 @@ def set_default_mailbox(user: User, mailbox_id: int) -> Mailbox:
LOG.i(f"User {user} has set mailbox {mailbox} as his default one")
user.default_mailbox_id = mailbox.id
emit_user_audit_log(
user=user,
action=UserAuditLogAction.UpdateMailbox,
message=f"Set mailbox {mailbox.id} ({mailbox.email}) as default",
)
Session.commit()
return mailbox

View File

@ -14,6 +14,7 @@ from sqlalchemy.sql import Insert, text
from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import (
@ -59,8 +60,11 @@ from app.models import (
)
from app.pgp_utils import load_public_key_and_check, PGPException
from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email
from server import create_light_app
from tasks.clean_alias_audit_log import cleanup_alias_audit_log
from tasks.clean_user_audit_log import cleanup_user_audit_log
from tasks.cleanup_old_imports import cleanup_old_imports
from tasks.cleanup_old_jobs import cleanup_old_jobs
from tasks.cleanup_old_notifications import cleanup_old_notifications
@ -905,9 +909,11 @@ def check_custom_domain():
LOG.i("custom domain has been deleted")
def check_single_custom_domain(custom_domain):
def check_single_custom_domain(custom_domain: CustomDomain):
mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN)
expected_custom_domains = validator.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_custom_domains):
user = custom_domain.user
LOG.w(
"The MX record is not correctly set for %s %s %s",
@ -1215,7 +1221,7 @@ def notify_hibp():
def clear_users_scheduled_to_be_deleted(dry_run=False):
users = User.filter(
users: List[User] = User.filter(
and_(
User.delete_on.isnot(None),
User.delete_on <= arrow.now().shift(days=-DELETE_GRACE_DAYS),
@ -1227,6 +1233,11 @@ def clear_users_scheduled_to_be_deleted(dry_run=False):
)
if dry_run:
continue
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteUser,
message=f"Delete user {user.id} ({user.email})",
)
User.delete(user.id)
Session.commit()
@ -1238,6 +1249,16 @@ def delete_old_data():
cleanup_old_notifications(oldest_valid)
def clear_alias_audit_log():
oldest_valid = arrow.now().shift(days=-config.AUDIT_LOG_MAX_DAYS)
cleanup_alias_audit_log(oldest_valid)
def clear_user_audit_log():
oldest_valid = arrow.now().shift(days=-config.AUDIT_LOG_MAX_DAYS)
cleanup_user_audit_log(oldest_valid)
if __name__ == "__main__":
LOG.d("Start running cronjob")
parser = argparse.ArgumentParser()
@ -1246,22 +1267,6 @@ if __name__ == "__main__":
"--job",
help="Choose a cron job to run",
type=str,
choices=[
"stats",
"notify_trial_end",
"notify_manual_subscription_end",
"notify_premium_end",
"delete_logs",
"delete_old_data",
"poll_apple_subscription",
"sanity_check",
"delete_old_monitoring",
"check_custom_domain",
"check_hibp",
"notify_hibp",
"cleanup_tokens",
"send_undelivered_mails",
],
)
args = parser.parse_args()
# wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc
@ -1310,4 +1315,10 @@ if __name__ == "__main__":
load_unsent_mails_from_fs_and_resend()
elif args.job == "delete_scheduled_users":
LOG.d("Deleting users scheduled to be deleted")
clear_users_scheduled_to_be_deleted(dry_run=True)
clear_users_scheduled_to_be_deleted()
elif args.job == "clear_alias_audit_log":
LOG.d("Clearing alias audit log")
clear_alias_audit_log()
elif args.job == "clear_user_audit_log":
LOG.d("Clearing user audit log")
clear_user_audit_log()

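clear_alias_audit_log and clear_user_audit_log only compute a cutoff from config.AUDIT_LOG_MAX_DAYS and delegate to the tasks package, whose bodies are not part of this excerpt. A hedged sketch of what such a cleanup could look like, relying on the created_at indexes added by the migrations below:

import arrow
from app.db import Session
from app.log import LOG
from app.models import UserAuditLog

def cleanup_user_audit_log(oldest_allowed: arrow.Arrow) -> None:
    # Sketch only; the real implementation lives in tasks/clean_user_audit_log.py.
    LOG.i(f"Deleting user_audit_log entries older than {oldest_allowed}")
    count = UserAuditLog.filter(UserAuditLog.created_at < oldest_allowed).delete()
    Session.commit()
    LOG.i(f"Deleted {count} user_audit_log entries")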
View File

@ -80,3 +80,17 @@ jobs:
schedule: "*/5 * * * *"
captureStderr: true
concurrencyPolicy: Forbid
- name: SimpleLogin clear alias_audit_log old entries
command: python /code/cron.py -j clear_alias_audit_log
shell: /bin/bash
schedule: "0 * * * *" # Once every hour
captureStderr: true
concurrencyPolicy: Forbid
- name: SimpleLogin clear user_audit_log old entries
command: python /code/cron.py -j clear_user_audit_log
shell: /bin/bash
schedule: "0 * * * *" # Once every hour
captureStderr: true
concurrencyPolicy: Forbid

View File

@ -52,8 +52,12 @@ from flanker.addresslib import address
from flanker.addresslib.address import EmailAddress
from sqlalchemy.exc import IntegrityError
from app import pgp_utils, s3, config
from app.alias_utils import try_auto_create, change_alias_status
from app import pgp_utils, s3, config, contact_utils
from app.alias_utils import (
try_auto_create,
change_alias_status,
get_alias_recipient_name,
)
from app.config import (
EMAIL_DOMAIN,
URL,
@ -195,81 +199,16 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
mail_from,
)
contact_email = mail_from
if not is_valid_email(contact_email):
LOG.w(
"invalid contact email %s. Parse from %s %s",
contact_email,
from_header,
mail_from,
)
# either reuse a contact with empty email or create a new contact with empty email
contact_email = ""
contact_email = sanitize_email(contact_email, not_lower=True)
if contact_name and "\x00" in contact_name:
LOG.w("issue with contact name %s", contact_name)
contact_name = ""
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
if contact.name != contact_name:
LOG.d(
"Update contact %s name %s to %s",
contact,
contact.name,
contact_name,
)
contact.name = contact_name
Session.commit()
# contact created in the past does not have mail_from and from_header field
if not contact.mail_from and mail_from:
LOG.d(
"Set contact mail_from %s: %s to %s",
contact,
contact.mail_from,
mail_from,
)
contact.mail_from = mail_from
Session.commit()
else:
alias_id = alias.id
try:
contact_email_for_reply = (
contact_email if is_valid_email(contact_email) else ""
)
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias_id,
website_email=contact_email,
name=contact_name,
mail_from=mail_from,
reply_email=generate_reply_email(contact_email_for_reply, alias),
automatic_created=True,
)
if not contact_email:
LOG.d("Create a contact with invalid email for %s", alias)
contact.invalid_email = True
LOG.d(
"create contact %s for %s, reverse alias:%s",
contact_email,
alias,
contact.reply_email,
)
Session.commit()
except IntegrityError:
# If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh
Session.close()
LOG.info(
f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB"
)
contact = Contact.get_by(alias_id=alias_id, website_email=contact_email)
return contact
contact_result = contact_utils.create_contact(
email=contact_email,
alias=alias,
name=contact_name,
mail_from=mail_from,
allow_empty_email=True,
automatic_created=True,
from_partner=False,
)
return contact_result.contact
def get_or_create_reply_to_contact(
@ -294,33 +233,7 @@ def get_or_create_reply_to_contact(
)
return None
contact = Contact.get_by(alias_id=alias.id, website_email=contact_address)
if contact:
return contact
else:
LOG.d(
"create contact %s for alias %s via reply-to header %s",
contact_address,
alias,
reply_to_header,
)
try:
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_address,
name=contact_name,
reply_email=generate_reply_email(contact_address, alias),
automatic_created=True,
)
Session.commit()
except IntegrityError:
LOG.w("Contact %s %s already exist", alias, contact_address)
Session.rollback()
contact = Contact.get_by(alias_id=alias.id, website_email=contact_address)
return contact
return contact_utils.create_contact(contact_address, alias, contact_name).contact
def replace_header_when_forward(msg: Message, alias: Alias, header: str):
@ -818,7 +731,7 @@ def forward_email_to_mailbox(
email_log = EmailLog.create(
contact_id=contact.id,
user_id=user.id,
user_id=contact.user_id,
mailbox_id=mailbox.id,
alias_id=contact.alias_id,
message_id=str(msg[headers.MESSAGE_ID]),
@ -1252,23 +1165,11 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
Session.commit()
# make the email comes from alias
from_header = alias.email
# add alias name from alias
if alias.name:
LOG.d("Put alias name %s in from header", alias.name)
from_header = sl_formataddr((alias.name, alias.email))
elif alias.custom_domain:
# add alias name from domain
if alias.custom_domain.name:
LOG.d(
"Put domain default alias name %s in from header",
alias.custom_domain.name,
)
from_header = sl_formataddr((alias.custom_domain.name, alias.email))
LOG.d("From header is %s", from_header)
add_or_replace_header(msg, headers.FROM, from_header)
recipient_name = get_alias_recipient_name(alias)
if recipient_name.message:
LOG.d(recipient_name.message)
LOG.d("From header is %s", recipient_name.name)
add_or_replace_header(msg, headers.FROM, recipient_name.name)
try:
if str(msg[headers.TO]).lower() == "undisclosed-recipients:;":
@ -1601,7 +1502,9 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog):
LOG.w(
f"Disable alias {alias} because {reason}. {alias.mailboxes} {alias.user}. Last contact {contact}"
)
change_alias_status(alias, enabled=False)
change_alias_status(
alias, enabled=False, message=f"Set enabled=False due to {reason}"
)
Notification.create(
user_id=user.id,

View File

@ -9,7 +9,7 @@ from events.runner import Runner
from events.event_source import DeadLetterEventSource, PostgresEventSource
from events.event_sink import ConsoleEventSink, HttpEventSink
_DEFAULT_MAX_RETRIES = 100
_DEFAULT_MAX_RETRIES = 10
class Mode(Enum):

View File

@ -27,7 +27,9 @@ class HttpEventSink(EventSink):
headers={"Content-Type": "application/x-protobuf"},
verify=not EVENT_WEBHOOK_SKIP_VERIFY_SSL,
)
newrelic.agent.record_custom_event("event_sent", {"http_code": res.status_code})
newrelic.agent.record_custom_event(
"EventSentToPartner", {"http_code": res.status_code}
)
if res.status_code != 200:
LOG.warning(
f"Failed to send event to webhook: {res.status_code} {res.text}"

View File

@ -72,7 +72,9 @@ class PostgresEventSource(EventSource):
Session.close() # Ensure we get a new connection and we don't leave a dangling tx
def __connect(self):
self.__connection = psycopg2.connect(self.__connection_string)
self.__connection = psycopg2.connect(
self.__connection_string, application_name="sl-event-listen"
)
from app.db import Session
@ -83,24 +85,28 @@ class DeadLetterEventSource(EventSource):
def __init__(self, max_retries: int):
self.__max_retries = max_retries
def execute_loop(
self, on_event: Callable[[SyncEvent], NoReturn]
) -> list[SyncEvent]:
threshold = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES)
events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries
)
if events:
LOG.info(f"Got {len(events)} dead letter events")
newrelic.agent.record_custom_metric(
"Custom/dead_letter_events_to_process", len(events)
)
for event in events:
if event.mark_as_taken(allow_taken_older_than=threshold):
on_event(event)
return events
@newrelic.agent.background_task()
def run(self, on_event: Callable[[SyncEvent], NoReturn]):
while True:
try:
threshold = arrow.utcnow().shift(
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
)
events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries
)
if events:
LOG.info(f"Got {len(events)} dead letter events")
if events:
newrelic.agent.record_custom_metric(
"Custom/dead_letter_events_to_process", len(events)
)
for event in events:
on_event(event)
events = self.execute_loop(on_event)
Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
if not events:
LOG.debug("No dead letter events")

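Pulling the body of the loop into execute_loop() (and guarding each event with mark_as_taken) also makes a single dead-letter pass callable outside the infinite run() loop, for example from a test or a one-off maintenance script. A minimal sketch, assuming a reachable database:

from events.event_source import DeadLetterEventSource

processed = []
source = DeadLetterEventSource(max_retries=10)
handled = source.execute_loop(on_event=processed.append)
print(f"re-dispatched {len(handled)} dead-letter events")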
View File

@ -3,7 +3,7 @@ Run scheduled jobs.
Not meant for running job at precise time (+- 1h)
"""
import time
from typing import List
from typing import List, Optional
import arrow
from sqlalchemy.sql.expression import or_, and_
@ -20,6 +20,7 @@ from app.jobs.event_jobs import send_alias_creation_events_for_user
from app.jobs.export_user_data_job import ExportUserDataJob
from app.log import LOG
from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from server import create_light_app
@ -128,7 +129,7 @@ def welcome_proton(user):
def delete_mailbox_job(job: Job):
mailbox_id = job.payload.get("mailbox_id")
mailbox = Mailbox.get(mailbox_id)
mailbox: Optional[Mailbox] = Mailbox.get(mailbox_id)
if not mailbox:
return
@ -152,6 +153,12 @@ def delete_mailbox_job(job: Job):
mailbox_email = mailbox.email
user = mailbox.user
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteMailbox,
message=f"Delete mailbox {mailbox.id} ({mailbox.email})",
)
Mailbox.delete(mailbox_id)
Session.commit()
LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email)
@ -240,28 +247,41 @@ def process_job(job: Job):
elif job.name == config.JOB_DELETE_DOMAIN:
custom_domain_id = job.payload.get("custom_domain_id")
custom_domain = CustomDomain.get(custom_domain_id)
custom_domain: Optional[CustomDomain] = CustomDomain.get(custom_domain_id)
if not custom_domain:
return
is_subdomain = custom_domain.is_sl_subdomain
domain_name = custom_domain.domain
user = custom_domain.user
custom_domain_partner_id = custom_domain.partner_id
CustomDomain.delete(custom_domain.id)
Session.commit()
if is_subdomain:
message = f"Delete subdomain {custom_domain_id} ({domain_name})"
else:
message = f"Delete custom domain {custom_domain_id} ({domain_name})"
emit_user_audit_log(
user=user,
action=UserAuditLogAction.DeleteCustomDomain,
message=message,
)
LOG.d("Domain %s deleted", domain_name)
send_email(
user.email,
f"Your domain {domain_name} has been deleted",
f"""Domain {domain_name} along with its aliases are deleted successfully.
if custom_domain_partner_id is None:
send_email(
user.email,
f"Your domain {domain_name} has been deleted",
f"""Domain {domain_name} along with its aliases are deleted successfully.
Regards,
SimpleLogin team.
""",
retries=3,
)
Regards,
SimpleLogin team.
""",
retries=3,
)
elif job.name == config.JOB_SEND_USER_REPORT:
export_job = ExportUserDataJob.create_from_job(job)
if export_job:

View File

@ -0,0 +1,30 @@
"""Custom Domain partner id
Revision ID: 2441b7ff5da9
Revises: 1c14339aae90
Create Date: 2024-09-13 15:43:02.425964
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2441b7ff5da9'
down_revision = '1c14339aae90'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('custom_domain', sa.Column('partner_id', sa.Integer(), nullable=True, default=None, server_default=None))
op.create_foreign_key(None, 'custom_domain', 'partner', ['partner_id'], ['id'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'custom_domain', type_='foreignkey')
op.drop_column('custom_domain', 'partner_id')
# ### end Alembic commands ###

View File

@ -0,0 +1,31 @@
"""contact.flags and custom_domain.pending_deletion
Revision ID: 88dd7a0abf54
Revises: 2441b7ff5da9
Create Date: 2024-09-19 15:41:20.910374
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '88dd7a0abf54'
down_revision = '2441b7ff5da9'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('contact', sa.Column('flags', sa.Integer(), server_default='0', nullable=False))
op.add_column('custom_domain', sa.Column('pending_deletion', sa.Boolean(), server_default='0', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('custom_domain', 'pending_deletion')
op.drop_column('contact', 'flags')
# ### end Alembic commands ###

View File

@ -0,0 +1,27 @@
"""custom domain indices
Revision ID: 62afa3a10010
Revises: 88dd7a0abf54
Create Date: 2024-09-30 11:40:04.127791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '62afa3a10010'
down_revision = '88dd7a0abf54'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_custom_domain_pending_deletion', 'custom_domain', ['pending_deletion'], unique=False, postgresql_concurrently=True)
op.create_index('ix_custom_domain_user_id', 'custom_domain', ['user_id'], unique=False, postgresql_concurrently=True)
def downgrade():
with op.get_context().autocommit_block():
op.drop_index('ix_custom_domain_user_id', table_name='custom_domain', postgresql_concurrently=True)
op.drop_index('ix_custom_domain_pending_deletion', table_name='custom_domain', postgresql_concurrently=True)

View File

@ -0,0 +1,45 @@
"""alias_audit_log
Revision ID: 91ed7f46dc81
Revises: 62afa3a10010
Create Date: 2024-10-11 13:22:11.594054
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '91ed7f46dc81'
down_revision = '62afa3a10010'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('alias_audit_log',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias_email', sa.String(length=255), nullable=False),
sa.Column('action', sa.String(length=255), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_alias_audit_log_alias_email', 'alias_audit_log', ['alias_email'], unique=False)
op.create_index('ix_alias_audit_log_alias_id', 'alias_audit_log', ['alias_id'], unique=False)
op.create_index('ix_alias_audit_log_user_id', 'alias_audit_log', ['user_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_alias_audit_log_user_id', table_name='alias_audit_log')
op.drop_index('ix_alias_audit_log_alias_id', table_name='alias_audit_log')
op.drop_index('ix_alias_audit_log_alias_email', table_name='alias_audit_log')
op.drop_table('alias_audit_log')
# ### end Alembic commands ###

View File

@ -0,0 +1,44 @@
"""user_audit_log
Revision ID: 7d7b84779837
Revises: 91ed7f46dc81
Create Date: 2024-10-16 11:52:49.128644
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '7d7b84779837'
down_revision = '91ed7f46dc81'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('user_audit_log',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('user_email', sa.String(length=255), nullable=False),
sa.Column('action', sa.String(length=255), nullable=False),
sa.Column('message', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index('ix_user_audit_log_user_email', 'user_audit_log', ['user_email'], unique=False)
op.create_index('ix_user_audit_log_user_id', 'user_audit_log', ['user_id'], unique=False)
op.create_index('ix_user_audit_log_created_at', 'user_audit_log', ['created_at'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_user_audit_log_user_id', table_name='user_audit_log')
op.drop_index('ix_user_audit_log_user_email', table_name='user_audit_log')
op.drop_index('ix_user_audit_log_created_at', table_name='user_audit_log')
op.drop_table('user_audit_log')
# ### end Alembic commands ###

View File

@ -0,0 +1,27 @@
"""alias_audit_log_index_created_at
Revision ID: 32f25cbf12f6
Revises: 7d7b84779837
Create Date: 2024-10-16 16:45:36.827161
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '32f25cbf12f6'
down_revision = '7d7b84779837'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_alias_audit_log_created_at', 'alias_audit_log', ['created_at'], unique=False, postgresql_concurrently=True)
def downgrade():
with op.get_context().autocommit_block():
op.drop_index('ix_alias_audit_log_created_at', table_name='alias_audit_log', postgresql_concurrently=True)

View File

@ -94,6 +94,20 @@ def log_nb_db_connection():
newrelic.agent.record_custom_metric("Custom/nb_db_connections", nb_connection)
@newrelic.agent.background_task()
def log_nb_db_connection_by_app_name():
# get the number of connections to the DB
rows = Session.execute(
"SELECT application_name, count(datid) FROM pg_stat_activity group by application_name"
)
for row in rows:
if row[0].find("sl-") == 0:
LOG.d("number of db connections for app %s = %s", row[0], row[1])
newrelic.agent.record_custom_metric(
f"Custom/nb_db_app_connection/{row[0]}", row[1]
)
@newrelic.agent.background_task()
def log_pending_to_process_events():
r = Session.execute("select count(*) from sync_event WHERE taken_time IS NULL;")
@ -125,6 +139,21 @@ def log_events_pending_dead_letter():
)
@newrelic.agent.background_task()
def log_failed_events():
r = Session.execute(
"""
SELECT COUNT(*)
FROM sync_event
WHERE retry_count >= 10;
""",
)
failed_events = list(r)[0][0]
LOG.d("number of failed events %s", failed_events)
newrelic.agent.record_custom_metric("Custom/sync_events_failed", failed_events)
if __name__ == "__main__":
exporter = MetricExporter(get_newrelic_license())
while True:
@ -132,6 +161,8 @@ if __name__ == "__main__":
log_nb_db_connection()
log_pending_to_process_events()
log_events_pending_dead_letter()
log_failed_events()
log_nb_db_connection_by_app_name()
Session.close()
exporter.run()

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.models import Alias
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_alias_id", default=0, type=int, help="Initial alias_id"
)
parser.add_argument("-e", "--end_alias_id", default=0, type=int, help="Last alias_id")
args = parser.parse_args()
alias_id_start = args.start_alias_id
max_alias_id = args.end_alias_id
if max_alias_id == 0:
max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Checking alias {alias_id_start} to {max_alias_id}")
step = 10000
noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')"
alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}"
updated = 0
start_time = time.time()
for batch_start in range(alias_id_start, max_alias_id, step):
rows_done = Session.execute(
alias_query,
{
"start": batch_start,
"end": batch_start + step,
"flag": Alias.FLAG_PARTNER_CREATED,
},
)
updated += rows_done.rowcount
Session.commit()
elapsed = time.time() - start_time
last_batch_id = batch_start + step
time_per_alias = elapsed / (last_batch_id - alias_id_start)
remaining = max_alias_id - last_batch_id
time_remaining = remaining * time_per_alias
mins_remaining = time_remaining / 60.0
print(
f"\rAlias {batch_start}/{max_alias_id} {updated} {mins_remaining:.2f} mins remaining"
)
print("")

View File

@ -0,0 +1,63 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import UserPlanChanged, EventContent
from app.models import PartnerUser
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
)
parser.add_argument(
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
)
args = parser.parse_args()
pu_id_start = args.start_pu_id
max_pu_id = args.end_pu_id
if max_pu_id == 0:
max_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
print(f"Checking partner user {pu_id_start} to {max_pu_id}")
step = 100
updated = 0
start_time = time.time()
with_premium = 0
for batch_start in range(pu_id_start, max_pu_id, step):
partner_users = (
Session.query(PartnerUser).filter(
PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step
)
).all()
for partner_user in partner_users:
subscription_end = partner_user.user.get_active_subscription_end(
include_partner_subscription=False
)
end_timestamp = None
if subscription_end:
with_premium += 1
end_timestamp = subscription_end.timestamp
event = UserPlanChanged(plan_end_time=end_timestamp)
EventDispatcher.send_event(
partner_user.user, EventContent(user_plan_change=event)
)
Session.flush()
updated += 1
Session.commit()
elapsed = time.time() - start_time
last_batch_id = batch_start + step
time_per_user = elapsed / (last_batch_id - pu_id_start)
remaining = max_pu_id - last_batch_id
time_remaining = remaining * time_per_user
mins_remaining = time_remaining / 60.0
print(
f"\rPartnerUser {batch_start}/{max_pu_id} {updated} {mins_remaining:.2f} mins remaining"
)
print(f"With SL premium {with_premium}")

app/poetry.lock (generated, 98 lines changed)

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
[[package]]
name = "aiohttp"
@ -360,35 +360,41 @@ files = [
[[package]]
name = "boto3"
version = "1.15.9"
version = "1.35.37"
description = "The AWS SDK for Python"
optional = false
python-versions = "*"
python-versions = ">=3.8"
files = [
{file = "boto3-1.15.9-py2.py3-none-any.whl", hash = "sha256:e0a1dbc0a0e460dc6de2f4144b5015edad3ab5c17ee83c6194b1a010d815bc60"},
{file = "boto3-1.15.9.tar.gz", hash = "sha256:02f5f7a2b1349760b030c34f90a9cb4600bf8fe3cbc76b801d122bc4cecf3a7f"},
{file = "boto3-1.35.37-py3-none-any.whl", hash = "sha256:385ca77bf8ea4ab2d97f6e2435bdb29f77d9301e2f7ac796c2f465753c2adf3c"},
{file = "boto3-1.35.37.tar.gz", hash = "sha256:470d981583885859fed2fd1c185eeb01cc03e60272d499bafe41b12625b158c8"},
]
[package.dependencies]
botocore = ">=1.18.9,<1.19.0"
jmespath = ">=0.7.1,<1.0.0"
s3transfer = ">=0.3.0,<0.4.0"
botocore = ">=1.35.37,<1.36.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
[package.extras]
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.18.9"
version = "1.35.37"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = "*"
python-versions = ">=3.8"
files = [
{file = "botocore-1.18.9-py2.py3-none-any.whl", hash = "sha256:dc3244170254cbba7dfde00b0489f830069d93dd6a9e555178d989072d7ee7c2"},
{file = "botocore-1.18.9.tar.gz", hash = "sha256:35b06b8801eb2dd7e708de35581f9c0304740645874f3af5b8b0c1648f8d6365"},
{file = "botocore-1.35.37-py3-none-any.whl", hash = "sha256:64f965d4ba7adb8d79ce044c3aef7356e05dd74753cf7e9115b80f477845d920"},
{file = "botocore-1.35.37.tar.gz", hash = "sha256:b2b4d29bafd95b698344f2f0577bb67064adbf1735d8a0e3c7473daa59c23ba6"},
]
[package.dependencies]
jmespath = ">=0.7.1,<1.0.0"
jmespath = ">=0.7.1,<2.0.0"
python-dateutil = ">=2.1,<3.0.0"
urllib3 = {version = ">=1.20,<1.26", markers = "python_version != \"3.4\""}
urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}
[package.extras]
crt = ["awscrt (==0.22.0)"]
[[package]]
name = "cachetools"
@ -1372,6 +1378,7 @@ files = [
{file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"},
{file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"},
{file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"},
{file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"},
{file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"},
{file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"},
{file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"},
@ -1380,6 +1387,7 @@ files = [
{file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"},
{file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"},
{file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"},
{file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"},
{file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"},
{file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"},
{file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"},
@ -1409,6 +1417,7 @@ files = [
{file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"},
{file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"},
{file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"},
{file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"},
{file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"},
{file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"},
{file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"},
@ -1417,6 +1426,7 @@ files = [
{file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"},
{file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"},
{file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"},
{file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"},
{file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"},
{file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"},
{file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"},
@ -2891,50 +2901,72 @@ files = [
[[package]]
name = "s3transfer"
version = "0.3.3"
version = "0.10.3"
description = "An Amazon S3 Transfer Manager"
optional = false
python-versions = "*"
python-versions = ">=3.8"
files = [
{file = "s3transfer-0.3.3-py2.py3-none-any.whl", hash = "sha256:2482b4259524933a022d59da830f51bd746db62f047d6eb213f2f8855dcb8a13"},
{file = "s3transfer-0.3.3.tar.gz", hash = "sha256:921a37e2aefc64145e7b73d50c71bb4f26f46e4c9f414dc648c6245ff92cf7db"},
{file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
{file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
]
[package.dependencies]
botocore = ">=1.12.36,<2.0a.0"
botocore = ">=1.33.2,<2.0a.0"
[package.extras]
crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
[[package]]
name = "sentry-sdk"
version = "1.5.11"
version = "2.16.0"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
python-versions = ">=3.6"
files = [
{file = "sentry-sdk-1.5.11.tar.gz", hash = "sha256:6c01d9d0b65935fd275adc120194737d1df317dce811e642cbf0394d0d37a007"},
{file = "sentry_sdk-1.5.11-py2.py3-none-any.whl", hash = "sha256:c17179183cac614e900cbd048dab03f49a48e2820182ec686c25e7ce46f8548f"},
{file = "sentry_sdk-2.16.0-py2.py3-none-any.whl", hash = "sha256:49139c31ebcd398f4f6396b18910610a0c1602f6e67083240c33019d1f6aa30c"},
{file = "sentry_sdk-2.16.0.tar.gz", hash = "sha256:90f733b32e15dfc1999e6b7aca67a38688a567329de4d6e184154a73f96c6892"},
]
[package.dependencies]
certifi = "*"
urllib3 = ">=1.10.0"
urllib3 = ">=1.26.11"
[package.extras]
aiohttp = ["aiohttp (>=3.5)"]
anthropic = ["anthropic (>=0.16)"]
arq = ["arq (>=0.23)"]
asyncpg = ["asyncpg (>=0.23)"]
beam = ["apache-beam (>=2.12)"]
bottle = ["bottle (>=0.12.13)"]
celery = ["celery (>=3)"]
celery-redbeat = ["celery-redbeat (>=2)"]
chalice = ["chalice (>=1.16.0)"]
clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
django = ["django (>=1.8)"]
falcon = ["falcon (>=1.4)"]
flask = ["blinker (>=1.1)", "flask (>=0.11)"]
fastapi = ["fastapi (>=0.79.0)"]
flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
grpcio = ["grpcio (>=1.21.1)", "protobuf (>=3.8.0)"]
http2 = ["httpcore[http2] (==1.*)"]
httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
huggingface-hub = ["huggingface-hub (>=0.22)"]
langchain = ["langchain (>=0.0.210)"]
litestar = ["litestar (>=2.0.0)"]
loguru = ["loguru (>=0.5)"]
openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro"]
pure-eval = ["asttokens", "executing", "pure-eval"]
pymongo = ["pymongo (>=3.1)"]
pyspark = ["pyspark (>=2.4.4)"]
quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
rq = ["rq (>=0.6)"]
sanic = ["sanic (>=0.8)"]
sqlalchemy = ["sqlalchemy (>=1.2)"]
tornado = ["tornado (>=5)"]
starlette = ["starlette (>=0.19.1)"]
starlite = ["starlite (>=1.48)"]
tornado = ["tornado (>=6)"]
[[package]]
name = "setuptools"
@ -3295,18 +3327,18 @@ files = [
[[package]]
name = "urllib3"
version = "1.25.10"
version = "1.26.20"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
files = [
{file = "urllib3-1.25.10-py2.py3-none-any.whl", hash = "sha256:e7983572181f5e1522d9c98453462384ee92a0be7fac5f1413a1e35c56cc0461"},
{file = "urllib3-1.25.10.tar.gz", hash = "sha256:91056c15fa70756691db97756772bb1eb9678fa585d9184f24534b100dc60f4a"},
{file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"},
{file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"},
]
[package.extras]
brotli = ["brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)"]
brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
@ -3704,4 +3736,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "22b9a61e9999a215aacb889b3790ee1a6840ce249aea2e3d16c6113243d5c126"
content-hash = "314f199bd50ccbf636ce1c6c753f8c79a1f5a16aa7c1a330a2ec514a13dbad2d"

View File

@ -10,21 +10,23 @@ message UserDeleted {
}
message AliasCreated {
uint32 alias_id = 1;
string alias_email = 2;
string alias_note = 3;
uint32 id = 1;
string email = 2;
string note = 3;
bool enabled = 4;
uint32 created_at = 5;
}
message AliasStatusChanged {
uint32 alias_id = 1;
string alias_email = 2;
uint32 id = 1;
string email = 2;
bool enabled = 3;
uint32 created_at = 4;
}
message AliasDeleted {
uint32 alias_id = 1;
string alias_email = 2;
uint32 id = 1;
string email = 2;
}
message AliasCreatedList {

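The alias events drop the alias_ prefix (alias_id/alias_email become id/email), and AliasCreated and AliasStatusChanged gain a created_at timestamp. With the regenerated Python bindings, producing one of these messages looks like the following sketch (values are hypothetical):

from app.events.generated.event_pb2 import AliasCreated

created = AliasCreated(
    id=1234,
    email="amusing.duck123@example.com",
    note="",
    enabled=True,
    created_at=1729000000,  # unix timestamp for the new uint32 created_at field
)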
View File

@ -69,7 +69,7 @@ python-dotenv = "^0.14.0"
ipython = "^7.31.1"
sqlalchemy_utils = "^0.36.8"
psycopg2-binary = "^2.9.3"
sentry_sdk = "^1.5.11"
sentry_sdk = "^2.16.0"
blinker = "^1.4"
arrow = "^0.16.0"
Flask-WTF = "^0.14.3"

View File

@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres
sleep 3
# upgrade the DB to the latest stage and
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head
# generate the migration script.
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@
# remove the db
docker rm -f ${container_name}

View File

@ -1,4 +1,3 @@
import json
import os
import time
from datetime import timedelta
@ -7,10 +6,9 @@ import arrow
import click
import flask_limiter
import flask_profiler
import newrelic.agent
import sentry_sdk
from coinbase_commerce.error import WebhookInvalidPayload, SignatureVerificationError
from coinbase_commerce.webhook import Webhook
from dateutil.relativedelta import relativedelta
from flask import (
Flask,
redirect,
@ -29,7 +27,7 @@ from sentry_sdk.integrations.flask import FlaskIntegration
from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration
from werkzeug.middleware.proxy_fix import ProxyFix
from app import paddle_utils, config, paddle_callback, constants
from app import config, constants
from app.admin_model import (
SLAdminIndexView,
UserAdmin,
@ -55,7 +53,6 @@ from app.config import (
FLASK_SECRET,
SENTRY_DSN,
URL,
PADDLE_MONTHLY_PRODUCT_ID,
FLASK_PROFILER_PATH,
FLASK_PROFILER_PASSWORD,
SENTRY_FRONT_END_DSN,
@ -69,22 +66,16 @@ from app.config import (
LANDING_PAGE_URL,
STATUS_PAGE_URL,
SUPPORT_EMAIL,
PADDLE_MONTHLY_PRODUCT_IDS,
PADDLE_YEARLY_PRODUCT_IDS,
PGP_SIGNER,
COINBASE_WEBHOOK_SECRET,
PAGE_LIMIT,
PADDLE_COUPON_ID,
ZENDESK_ENABLED,
MAX_NB_EMAIL_FREE_PLAN,
MEM_STORE_URI,
)
from app.dashboard.base import dashboard_bp
from app.subscription_webhook import execute_subscription_webhook
from app.db import Session
from app.developer.base import developer_bp
from app.discover.base import discover_bp
from app.email_utils import send_email, render
from app.extensions import login_manager, limiter
from app.fake_data import fake_data
from app.internal.base import internal_bp
@ -93,11 +84,8 @@ from app.log import LOG
from app.models import (
User,
Alias,
Subscription,
PlanEnum,
CustomDomain,
Mailbox,
CoinbaseSubscription,
EmailLog,
Contact,
ManualSubscription,
@ -114,9 +102,11 @@ from app.monitor.base import monitor_bp
from app.newsletter_utils import send_newsletter_to_user
from app.oauth.base import oauth_bp
from app.onboarding.base import onboarding_bp
from app.payments.coinbase import setup_coinbase_commerce
from app.payments.paddle import setup_paddle_callback
from app.phone.base import phone_bp
from app.redis_services import initialize_redis_services
from app.utils import random_string
from app.sentry_utils import sentry_before_send
if SENTRY_DSN:
LOG.d("enable sentry")
@ -127,6 +117,7 @@ if SENTRY_DSN:
FlaskIntegration(),
SqlalchemyIntegration(),
],
before_send=sentry_before_send,
)
# the app is served behind nginx which uses http and not https
@ -299,7 +290,9 @@ def set_index_page(app):
res.status_code,
time.time() - start_time,
)
newrelic.agent.record_custom_event(
"HttpResponseStatus", {"code": res.status_code}
)
return res
@ -441,341 +434,6 @@ def jinja2_filter(app):
)
def setup_paddle_callback(app: Flask):
@app.route("/paddle", methods=["GET", "POST"])
def paddle():
LOG.d(f"paddle callback {request.form.get('alert_name')} {request.form}")
# make sure the request comes from Paddle
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
if (
request.form.get("alert_name") == "subscription_created"
): # new user subscribes
# the passthrough is json encoded, e.g.
# request.form.get("passthrough") = '{"user_id": 88 }'
passthrough = json.loads(request.form.get("passthrough"))
user_id = passthrough.get("user_id")
user = User.get(user_id)
subscription_plan_id = int(request.form.get("subscription_plan_id"))
if subscription_plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
plan = PlanEnum.monthly
elif subscription_plan_id in PADDLE_YEARLY_PRODUCT_IDS:
plan = PlanEnum.yearly
else:
LOG.e(
"Unknown subscription_plan_id %s %s",
subscription_plan_id,
request.form,
)
return "No such subscription", 400
sub = Subscription.get_by(user_id=user.id)
if not sub:
LOG.d(f"create a new Subscription for user {user}")
Subscription.create(
user_id=user.id,
cancel_url=request.form.get("cancel_url"),
update_url=request.form.get("update_url"),
subscription_id=request.form.get("subscription_id"),
event_time=arrow.now(),
next_bill_date=arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date(),
plan=plan,
)
else:
LOG.d(f"Update an existing Subscription for user {user}")
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.subscription_id = request.form.get("subscription_id")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
# in case user cancels a plan and subscribes a new plan
sub.cancelled = False
execute_subscription_webhook(user)
LOG.d("User %s upgrades!", user)
Session.commit()
elif request.form.get("alert_name") == "subscription_payment_succeeded":
subscription_id = request.form.get("subscription_id")
LOG.d("Update subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
# when user subscribes, the "subscription_payment_succeeded" can arrive BEFORE "subscription_created"
# at that time, subscription object does not exist yet
if sub:
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
Session.commit()
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_cancelled":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
# cancellation_effective_date should be the same as next_bill_date
LOG.w(
"Cancel subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
sub.event_time = arrow.now()
sub.cancelled = True
Session.commit()
user = sub.user
send_email(
user.email,
"SimpleLogin - your subscription is canceled",
render(
"transactional/subscription-cancel.txt",
user=user,
end_date=request.form.get("cancellation_effective_date"),
),
)
execute_subscription_webhook(sub.user)
else:
# user might have deleted their account
LOG.i(f"Cancel non-exist subscription {subscription_id}")
return "OK"
elif request.form.get("alert_name") == "subscription_updated":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
next_bill_date = request.form.get("next_bill_date")
if not next_bill_date:
paddle_callback.failed_payment(sub, subscription_id)
return "OK"
LOG.d(
"Update subscription %s %s on %s, next bill date %s",
subscription_id,
sub.user,
request.form.get("cancellation_effective_date"),
sub.next_bill_date,
)
if (
int(request.form.get("subscription_plan_id"))
== PADDLE_MONTHLY_PRODUCT_ID
):
plan = PlanEnum.monthly
else:
plan = PlanEnum.yearly
sub.cancel_url = request.form.get("cancel_url")
sub.update_url = request.form.get("update_url")
sub.event_time = arrow.now()
sub.next_bill_date = arrow.get(
request.form.get("next_bill_date"), "YYYY-MM-DD"
).date()
sub.plan = plan
# make sure to set the new plan as not-cancelled
sub.cancelled = False
Session.commit()
execute_subscription_webhook(sub.user)
else:
LOG.w(
f"update non-exist subscription {subscription_id}. {request.form}"
)
return "No such subscription", 400
elif request.form.get("alert_name") == "payment_refunded":
subscription_id = request.form.get("subscription_id")
LOG.d("Refund request for subscription %s", subscription_id)
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
if sub:
user = sub.user
Subscription.delete(sub.id)
Session.commit()
LOG.e("%s requests a refund", user)
execute_subscription_webhook(sub.user)
elif request.form.get("alert_name") == "subscription_payment_refunded":
subscription_id = request.form.get("subscription_id")
sub: Subscription = Subscription.get_by(subscription_id=subscription_id)
LOG.d(
"Handle subscription_payment_refunded for subscription %s",
subscription_id,
)
if not sub:
LOG.w(
"No such subscription for %s, payload %s",
subscription_id,
request.form,
)
return "No such subscription"
plan_id = int(request.form["subscription_plan_id"])
if request.form["refund_type"] == "full":
if plan_id in PADDLE_MONTHLY_PRODUCT_IDS:
LOG.d("subtract 1 month from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(months=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
elif plan_id in PADDLE_YEARLY_PRODUCT_IDS:
LOG.d("subtract 1 year from next_bill_date %s", sub.next_bill_date)
sub.next_bill_date = sub.next_bill_date - relativedelta(years=1)
LOG.d("next_bill_date is %s", sub.next_bill_date)
Session.commit()
else:
LOG.e("Unknown plan_id %s", plan_id)
else:
LOG.w("partial subscription_payment_refunded, not handled")
execute_subscription_webhook(sub.user)
return "OK"
@app.route("/paddle_coupon", methods=["GET", "POST"])
def paddle_coupon():
LOG.d("paddle coupon callback %s", request.form)
if not paddle_utils.verify_incoming_request(dict(request.form)):
LOG.e("request not coming from paddle. Request data:%s", dict(request.form))
return "KO", 400
product_id = request.form.get("p_product_id")
if product_id != PADDLE_COUPON_ID:
LOG.e("product_id %s not match with %s", product_id, PADDLE_COUPON_ID)
return "KO", 400
email = request.form.get("email")
LOG.d("Paddle coupon request for %s", email)
coupon = Coupon.create(
code=random_string(30),
comment="For 1-year coupon",
expires_date=arrow.now().shift(years=1, days=-1),
commit=True,
)
return (
f"Your 1-year coupon is <b>{coupon.code}</b> <br> "
f"It's valid until <b>{coupon.expires_date.date().isoformat()}</b>"
)
def setup_coinbase_commerce(app):
@app.route("/coinbase", methods=["POST"])
def coinbase_webhook():
# event payload
request_data = request.data.decode("utf-8")
# webhook signature
request_sig = request.headers.get("X-CC-Webhook-Signature", None)
try:
# signature verification and event object construction
event = Webhook.construct_event(
request_data, request_sig, COINBASE_WEBHOOK_SECRET
)
except (WebhookInvalidPayload, SignatureVerificationError) as e:
LOG.e("Invalid Coinbase webhook")
return str(e), 400
LOG.d("Coinbase event %s", event)
if event["type"] == "charge:confirmed":
if handle_coinbase_event(event):
return "success", 200
else:
return "error", 400
return "success", 200
def handle_coinbase_event(event) -> bool:
server_user_id = event["data"]["metadata"]["user_id"]
try:
user_id = int(server_user_id)
except ValueError:
user_id = int(float(server_user_id))
code = event["data"]["code"]
user = User.get(user_id)
if not user:
LOG.e("User not found %s", user_id)
return False
coinbase_subscription: CoinbaseSubscription = CoinbaseSubscription.get_by(
user_id=user_id
)
if not coinbase_subscription:
LOG.d("Create a coinbase subscription for %s", user)
coinbase_subscription = CoinbaseSubscription.create(
user_id=user_id, end_at=arrow.now().shift(years=1), code=code, commit=True
)
send_email(
user.email,
"Your SimpleLogin account has been upgraded",
render(
"transactional/coinbase/new-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/new-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
else:
if coinbase_subscription.code != code:
LOG.d("Update code from %s to %s", coinbase_subscription.code, code)
coinbase_subscription.code = code
if coinbase_subscription.is_active():
coinbase_subscription.end_at = coinbase_subscription.end_at.shift(years=1)
else: # already expired subscription
coinbase_subscription.end_at = arrow.now().shift(years=1)
Session.commit()
send_email(
user.email,
"Your SimpleLogin account has been extended",
render(
"transactional/coinbase/extend-subscription.txt",
user=user,
coinbase_subscription=coinbase_subscription,
),
render(
"transactional/coinbase/extend-subscription.html",
user=user,
coinbase_subscription=coinbase_subscription,
),
)
execute_subscription_webhook(user)
return True
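# Hedged sketch (illustrative values, not from the SimpleLogin source): the minimal Coinbase
# Commerce event shape consumed by handle_coinbase_event above; only the keys the handler
# reads are shown.
example_coinbase_event = {
    "type": "charge:confirmed",
    "data": {
        "code": "ABCD1234",  # charge code stored on the CoinbaseSubscription
        "metadata": {"user_id": "42"},  # may arrive as "42" or "42.0"; both parse paths are handled
    },
}
# handle_coinbase_event(example_coinbase_event) returns True only if user 42 exists.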
def init_extensions(app: Flask):
login_manager.init_app(app)
@ -784,10 +442,10 @@ def init_admin(app):
admin = Admin(name="SimpleLogin", template_mode="bootstrap4")
admin.init_app(app, index_view=SLAdminIndexView())
admin.add_view(EmailSearchAdmin(name="Email Search", endpoint="email_search"))
admin.add_view(UserAdmin(User, Session))
admin.add_view(AliasAdmin(Alias, Session))
admin.add_view(MailboxAdmin(Mailbox, Session))
admin.add_view(EmailSearchAdmin(name="Email Search", endpoint="email_search"))
admin.add_view(CouponAdmin(Coupon, Session))
admin.add_view(ManualSubscriptionAdmin(ManualSubscription, Session))
admin.add_view(CustomDomainAdmin(CustomDomain, Session))

View File

@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import AliasAuditLog
def cleanup_alias_audit_log(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting alias_audit_log older than {oldest_allowed}")
count = AliasAuditLog.filter(AliasAuditLog.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} alias_audit_log entries")

View File

@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import UserAuditLog
def cleanup_user_audit_log(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting user_audit_log older than {oldest_allowed}")
count = UserAuditLog.filter(UserAuditLog.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} user_audit_log entries")

View File

@ -1,215 +1,279 @@
{% extends 'admin/master.html' %}
{% macro show_user(user) -%}
<h4>User {{ user.email }} with ID {{ user.id }}.</h4>
{% set pu = helper.partner_user(user) %}
<table class="table">
<thead>
<tr>
<th scope="col">User ID</th>
<th scope="col">Email</th>
<th scope="col">Status</th>
<th scope="col">Paid</th>
<th>Subscription</th>
<th>Created At</th>
<th>Updated At</th>
<th>Connected with Proton account</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ user.id }}</td>
<td>{{ user.email }}</td>
{% if user.disabled %}
<td class="text-danger">Disabled</td>
{% else %}
<td class="text-success">Enabled</td>
{% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ user.get_active_subscription() }}</td>
<td>{{ user.created_at }}</td>
<td>{{ user.updated_at }}</td>
{% if pu %}
<td>{{ pu.partner_email }}</td>
{% else %}
<td>No</td>
{% endif %}
</tr>
</tbody>
</table>
{%- endmacro %}
{% macro list_mailboxes(mbox_count, mboxes) %}
<h4>
{{ mbox_count }} Mailboxes found.
{% if mbox_count>10 %}Showing only the first 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<th>Mailbox ID</th>
<th>Email</th>
<th>Verified</th>
<th>Created At</th>
</tr>
</thead>
<tbody>
{% for mailbox in mboxes %}
<h4>User {{ user.email }} with ID {{ user.id }}.</h4>
{% set pu = helper.partner_user(user) %}
<table class="table">
<thead>
<tr>
<td>{{ mailbox.id }}</td>
<td>{{ mailbox.email }}</td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td>
{{ mailbox.created_at }}
</td>
<th scope="col">User ID</th>
<th scope="col">Email</th>
<th scope="col">Status</th>
<th scope="col">Paid</th>
<th>Subscription</th>
<th>Created At</th>
<th>Updated At</th>
<th>Connected with Proton account</th>
</tr>
{% endfor %}
</tbody>
</table>
</thead>
<tbody>
<tr>
<td>{{ user.id }}</td>
<td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %}
<td class="text-danger">Disabled</td>
{% else %}
<td class="text-success">Enabled</td>
{% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ user.get_active_subscription() }}</td>
<td>{{ user.created_at }}</td>
<td>{{ user.updated_at }}</td>
{% if pu %}
<td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %}
<td>No</td>
{% endif %}
</tr>
</tbody>
</table>
{%- endmacro %}
{% macro list_mailboxes(message, mbox_count, mboxes) %}
<h4>
{{ mbox_count }} {{ message }}.
{% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<th>Mailbox ID</th>
<th>Email</th>
<th>Verified</th>
<th>Created At</th>
</tr>
</thead>
<tbody>
{% for mailbox in mboxes %}
<tr>
<td>{{ mailbox.id }}</td>
<td><a href="?email={{ mailbox.email }}">{{ mailbox.email }}</a></td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td>
{{ mailbox.created_at }}
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% macro list_alias(alias_count, aliases) %}
<h4>
{{ alias_count }} Aliases found.
{% if alias_count>10 %}Showing only the first 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<th>
Alias ID
</th>
<th>
Email
</th>
<th>
Verified
</th>
<th>
Created At
</th>
</tr>
</thead>
<tbody>
{% for alias in aliases %}
<h4>
{{ alias_count }} Aliases found.
{% if alias_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<td>{{ alias.id }}</td>
<td>{{ alias.email }}</td>
<td>{{ "Yes" if alias.verified else "No" }}</td>
<td>{{ alias.created_at }}</td>
<th>
Alias ID
</th>
<th>
Email
</th>
<th>
Enabled
</th>
<th>
Created At
</th>
</tr>
{% endfor %}
</tbody>
</table>
</thead>
<tbody>
{% for alias in aliases %}
<tr>
<td>{{ alias.id }}</td>
<td><a href="?email={{ alias.email }}">{{ alias.email }}</a></td>
<td>{{ "Yes" if alias.enabled else "No" }}</td>
<td>{{ alias.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% macro show_deleted_alias(deleted_alias) -%}
<h4>Deleted Alias {{ deleted_alias.email }} with ID {{ deleted_alias.id }}.</h4>
<table class="table">
<thead>
<tr>
<th scope="col">Deleted Alias ID</th>
<th scope="col">Email</th>
<th scope="col">Deleted At</th>
<th scope="col">Reason</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ deleted_alias.id }}</td>
<td>{{ deleted_alias.email }}</td>
<td>{{ deleted_alias.created_at }}</td>
<td>{{ deleted_alias.reason }}</td>
</tr>
</tbody>
</table>
<h4>Deleted Alias {{ deleted_alias.email }} with ID {{ deleted_alias.id }}.</h4>
<table class="table">
<thead>
<tr>
<th scope="col">Deleted Alias ID</th>
<th scope="col">Email</th>
<th scope="col">Deleted At</th>
<th scope="col">Reason</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ deleted_alias.id }}</td>
<td>{{ deleted_alias.email }}</td>
<td>{{ deleted_alias.created_at }}</td>
<td>{{ deleted_alias.reason }}</td>
</tr>
</tbody>
</table>
{%- endmacro %}
{% macro show_domain_deleted_alias(dom_deleted_alias) -%}
<h4>
Domain Deleted Alias {{ dom_deleted_alias.email }} with ID {{ dom_deleted_alias.id }} for domain {{ dom_deleted_alias.domain.domain }}
</h4>
<table class="table">
<thead>
<tr>
<th scope="col">Deleted Alias ID</th>
<th scope="col">Email</th>
<th scope="col">Domain</th>
<th scope="col">Domain ID</th>
<th scope="col">Domain owner user ID</th>
<th scope="col">Domain owner user email</th>
<th scope="col">Deleted At</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ dom_deleted_alias.id }}</td>
<td>{{ dom_deleted_alias.email }}</td>
<td>{{ dom_deleted_alias.domain.domain }}</td>
<td>{{ dom_deleted_alias.domain.id }}</td>
<td>{{ dom_deleted_alias.domain.user_id }}</td>
<td>{{ dom_deleted_alias.created_at }}</td>
</tr>
</tbody>
</table>
{{ show_user(data.domain_deleted_alias.domain.user) }}
<h4>
Domain Deleted Alias {{ dom_deleted_alias.email }} with ID {{ dom_deleted_alias.id }} for
domain {{ dom_deleted_alias.domain.domain }}
</h4>
<table class="table">
<thead>
<tr>
<th scope="col">Deleted Alias ID</th>
<th scope="col">Email</th>
<th scope="col">Domain</th>
<th scope="col">Domain ID</th>
<th scope="col">Domain owner user ID</th>
<th scope="col">Domain owner user email</th>
<th scope="col">Deleted At</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ dom_deleted_alias.id }}</td>
<td>{{ dom_deleted_alias.email }}</td>
<td>{{ dom_deleted_alias.domain.domain }}</td>
<td>{{ dom_deleted_alias.domain.id }}</td>
<td>{{ dom_deleted_alias.domain.user_id }}</td>
<td>{{ dom_deleted_alias.created_at }}</td>
</tr>
</tbody>
</table>
{{ show_user(dom_deleted_alias.domain.user) }}
{%- endmacro %}
{% macro list_alias_audit_log(alias_audit_log) %}
<h4>Alias Audit Log</h4>
<table class="table">
<thead>
<tr>
<th>User ID</th>
<th>Alias ID</th>
<th>Alias Email</th>
<th>Action</th>
<th>Message</th>
<th>Time</th>
</tr>
</thead>
<tbody>
{% for entry in alias_audit_log %}
<tr>
<td>{{ entry.user_id }}</td>
<td>{{ entry.alias_id }}</td>
<td><a href="?email={{ entry.alias_email }}">{{ entry.alias_email }}</a></td>
<td>{{ entry.action }}</td>
<td>{{ entry.message }}</td>
<td>{{ entry.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% macro list_user_audit_log(user_audit_log) %}
<h4>User Audit Log</h4>
<table class="table">
<thead>
<tr>
<th>User email</th>
<th>Action</th>
<th>Message</th>
<th>Time</th>
</tr>
</thead>
<tbody>
{% for entry in user_audit_log %}
<tr>
<td><a href="?email={{ entry.user_email }}">{{ entry.user_email }}</a></td>
<td>{{ entry.action }}</td>
<td>{{ entry.message }}</td>
<td>{{ entry.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %}
{% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="post">
<div class="form-group">
<label for="email">Email to search:</label>
<input type="text"
class="form-control"
name="email"
value="{{ email or '' }}" />
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
{% if no_match %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div>
{% endif %}
{% if data.alias %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found Alias {{ data.alias.email }}</h3>
{{ list_alias(1,[data.alias]) }}
{{ show_user(data.alias.user) }}
{{ list_mailboxes(helper.mailbox_count(data.alias.user) , helper.mailbox_list(data.alias.user) ) }}
<form method="get">
<div class="form-group">
<label for="email">Email to search:</label>
<input type="text"
class="form-control"
name="email"
value="{{ email or '' }}"/>
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
{% endif %}
{% if data.user %}
{% if data.no_match and email %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div>
{% endif %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found User {{ data.user.email }}</h3>
{{ show_user(data.user) }}
{{ list_mailboxes(helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }}
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div>
{% endif %}
{% if data.mailbox %}
{% if data.alias %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found Alias {{ data.alias.email }}</h3>
{{ list_alias(1,[data.alias]) }}
{{ list_alias_audit_log(data.alias_audit_log) }}
{{ list_mailboxes("Mailboxes for alias", helper.alias_mailbox_count(data.alias), helper.alias_mailboxes(data.alias)) }}
{{ show_user(data.alias.user) }}
</div>
{% endif %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found Mailbox {{ data.mailbox.email }}</h3>
{{ list_mailboxes(1, [data.mailbox]) }}
{{ show_user(data.mailbox.user) }}
</div>
{% endif %}
{% if data.deleted_alias %}
{% if data.user %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found User {{ data.user.email }}</h3>
{{ show_user(data.user) }}
{{ list_mailboxes("Mailboxes for user", helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }}
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div>
{% endif %}
{% if data.user_audit_log %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Audit log entries for user {{ data.query }}</h3>
{{ list_user_audit_log(data.user_audit_log) }}
</div>
{% endif %}
{% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DeletedAlias {{ data.deleted_alias.email }}</h3>
{{ show_deleted_alias(data.deleted_alias) }}
</div>
{% endif %}
{% if data.domain_deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3>
{{ list_mailboxes("Mailbox found", 1, [mailbox]) }}
{{ show_user(mailbox.user) }}
</div>
{% endfor %}
{% if data.deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DomainDeletedAlias {{ data.domain_deleted_alias.email }}</h3>
{{ show_domain_deleted_alias(data.domain_deleted_alias) }}
</div>
{% endif %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DeletedAlias {{ data.deleted_alias.email }}</h3>
{{ show_deleted_alias(data.deleted_alias) }}
{{ list_alias_audit_log(data.deleted_alias_audit_log) }}
</div>
{% endif %}
{% if data.domain_deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DomainDeletedAlias {{ data.domain_deleted_alias.email }}</h3>
{{ show_domain_deleted_alias(data.domain_deleted_alias) }}
{{ list_alias_audit_log(data.domain_deleted_alias_audit_log) }}
</div>
{% endif %}
{% endblock %}

View File

@ -94,4 +94,4 @@
</div>
</div>
{% endblock %}
{% block script %}<script>$('.mailbox-select').multipleSelect();</script>{% endblock %}

View File

@ -38,7 +38,7 @@
Value: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ custom_domain.get_ownership_dns_txt_value() }}">{{ custom_domain.get_ownership_dns_txt_value() }}</em>
data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
</div>
<form method="post" action="#ownership-form">
{{ csrf_form.csrf_token }}
@ -91,7 +91,8 @@
<br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div>
{% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %}
{% for record in expected_mx_records %}
<div class="mb-3 p-3 dns-record">
Record: MX
@ -99,14 +100,15 @@
Domain: {{ custom_domain.domain }} or
<b>@</b>
<br />
Priority: {{ priority }}
Priority: {{ record.priority }}
<br />
Target: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ email_server }}">{{ email_server }}</em>
data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em>
</div>
{% endfor %}
<form method="post" action="#mx-form">
{{ csrf_form.csrf_token }}
<input type="hidden" name="form-name" value="check-mx">
@ -237,7 +239,7 @@
folder.
</div>
<div class="mb-2">Add the following CNAME DNS records to your domain.</div>
{% for dkim_prefix, dkim_cname_value in dkim_records %}
{% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
<div class="mb-2 p-3 dns-record">
Record: CNAME

View File

@ -21,8 +21,8 @@
<div class="my-3">
<p>Alternatively you can use your Proton credentials to ensure it's you.</p>
</div>
<a class="btn btn-primary btn-block mt-2 proton-button w-25"
href="{{ url_for('auth.proton_login', next=next) }}">
<a class="btn btn-primary btn-block mt-2 proton-button"
href="{{ url_for('auth.proton_login', next=next) }}" style="max-width: 400px">
<img class="mr-2" src="/static/images/proton.svg" />
Authenticate with Proton
</a>
@ -38,4 +38,4 @@
{% endif %}
</div>
</div>
{% endblock %}
{% endblock %}

View File

@ -536,7 +536,7 @@ def test_create_contact_route_free_users(flask_client):
assert r.status_code == 201
# End trial and disallow for new free users. Config should allow it
user.flags = User.FLAG_FREE_DISABLE_CREATE_ALIAS
user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.commit()
r = flask_client.post(
url_for("api.create_contact_route", alias_id=alias.id),

View File

@ -4,7 +4,7 @@ from app.models import (
Alias,
Contact,
)
from tests.utils import login
from tests.utils import login, random_email
def test_add_contact_success(flask_client):
@ -13,26 +13,28 @@ def test_add_contact_success(flask_client):
assert Contact.filter_by(user_id=user.id).count() == 0
email = random_email()
# <<< Create a new contact >>>
flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={
"form-name": "create",
"email": "abcd@gmail.com",
"email": email,
},
follow_redirects=True,
)
# a new contact is added
assert Contact.filter_by(user_id=user.id).count() == 1
contact = Contact.filter_by(user_id=user.id).first()
assert contact.website_email == "abcd@gmail.com"
assert contact.website_email == email
# <<< Create a new contact using a full email format >>>
email = random_email()
flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={
"form-name": "create",
"email": "First Last <another@gmail.com>",
"email": f"First Last <{email}>",
},
follow_redirects=True,
)
@ -41,7 +43,7 @@ def test_add_contact_success(flask_client):
contact = (
Contact.filter_by(user_id=user.id).filter(Contact.id != contact.id).first()
)
assert contact.website_email == "another@gmail.com"
assert contact.website_email == email
assert contact.name == "First Last"
# <<< Create a new contact with invalid email address >>>

View File

@ -1,38 +1,72 @@
import app.alias_utils
from app import config
from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.models import (
Alias,
Mailbox,
User,
AliasMailbox,
)
from tests.events.event_test_utils import (
OnMemoryDispatcher,
_get_event_from_string,
_create_linked_user,
)
from tests.utils import login
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def test_alias_transfer(flask_client):
user = login(flask_client)
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com", commit=True)
(source_user, source_user_pu) = _create_linked_user()
source_user = login(flask_client, source_user)
mb = Mailbox.create(user_id=source_user.id, email="mb@gmail.com", commit=True)
alias = Alias.create_new_random(user)
alias = Alias.create_new_random(source_user)
Session.commit()
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True)
new_user = User.create(
email="hey@example.com",
password="password",
activated=True,
commit=True,
)
(target_user, target_user_pu) = _create_linked_user()
Mailbox.create(
user_id=new_user.id, email="hey2@example.com", verified=True, commit=True
user_id=target_user.id, email="hey2@example.com", verified=True, commit=True
)
app.alias_utils.transfer_alias(alias, new_user, new_user.mailboxes())
on_memory_dispatcher.clear()
app.alias_utils.transfer_alias(alias, target_user, target_user.mailboxes())
# refresh from db
alias = Alias.get(alias.id)
assert alias.user == new_user
assert set(alias.mailboxes) == set(new_user.mailboxes())
assert alias.user == target_user
assert set(alias.mailboxes) == set(target_user.mailboxes())
assert len(alias.mailboxes) == 2
# Check events
assert len(on_memory_dispatcher.memory) == 2
# 1st delete event
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, source_user, source_user_pu)
assert event_content.alias_deleted is not None
alias_deleted = event_content.alias_deleted
assert alias_deleted.id == alias.id
assert alias_deleted.email == alias.email
# 2nd create event
event_data = on_memory_dispatcher.memory[1]
event_content = _get_event_from_string(event_data, target_user, target_user_pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert alias.note or "" == alias_created.note
assert alias.enabled == alias_created.enabled

View File

@ -0,0 +1,44 @@
from app.events.event_dispatcher import Dispatcher
from app.events.generated import event_pb2
from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
from typing import Tuple
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def clear(self):
self.memory = []
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user
def _get_event_from_string(
data: str, user: User, pu: PartnerUser
) -> event_pb2.EventContent:
event = event_pb2.Event()
event.ParseFromString(data)
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id
return event.content

View File

@ -0,0 +1,54 @@
import arrow
from app.db import Session
from app.models import SyncEvent
from events.event_source import DeadLetterEventSource, _DEAD_LETTER_THRESHOLD_MINUTES
class EventCounter:
def __init__(self):
self.processed_events = 0
def on_event(self, event: SyncEvent):
self.processed_events += 1
def setup_function(func):
Session.query(SyncEvent).delete()
def test_dead_letter_does_not_take_untaken_events():
source = DeadLetterEventSource(1)
counter = EventCounter()
threshold_time = arrow.utcnow().shift(minutes=-(_DEAD_LETTER_THRESHOLD_MINUTES) + 1)
SyncEvent.create(
content="test".encode("utf-8"), created_at=threshold_time, flush=True
)
SyncEvent.create(
content="test".encode("utf-8"), taken_time=threshold_time, flush=True
)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 0
assert counter.processed_events == 0
def test_dead_letter_takes_untaken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_create = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), created_at=old_create, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_create
assert counter.processed_events == 1
def test_dead_letter_takes_taken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_taken = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), taken_time=old_taken, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_taken
assert counter.processed_events == 1
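# Summary of the behaviour exercised above: DeadLetterEventSource only re-takes SyncEvents
# whose created_at (never taken) or taken_time (taken but stalled) is older than
# _DEAD_LETTER_THRESHOLD_MINUTES; anything more recent is left for the regular event source.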

View File

@ -1,33 +1,10 @@
from app.events.event_dispatcher import EventDispatcher, Dispatcher
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserDeleted
from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
from typing import Tuple
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user
from .event_test_utils import (
_create_unlinked_user,
OnMemoryDispatcher,
_create_linked_user,
)
def test_event_dispatcher_stores_events():

View File

@ -0,0 +1,120 @@
import arrow
from app import config, alias_utils
from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.models import Alias, SyncEvent
from tests.utils import random_token
from .event_test_utils import (
OnMemoryDispatcher,
_create_linked_user,
_get_event_from_string,
)
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def setup_function(func):
on_memory_dispatcher.clear()
def test_event_taken_updates():
event = SyncEvent.create(content="test".encode("utf-8"), flush=True)
assert event.taken_time is None
assert event.mark_as_taken()
assert event.taken_time is not None
def test_event_mark_as_taken_does_nothing_for_taken_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
assert not event.mark_as_taken()
def test_event_mark_as_taken_does_nothing_for_not_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=-1)
assert not event.mark_as_taken(allow_taken_older_than=older_than)
def test_event_mark_as_taken_works_for_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=+1)
assert event.mark_as_taken(allow_taken_older_than=older_than)
def test_fire_event_on_alias_creation():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert "" == alias_created.note
assert alias.enabled == alias_created.enabled
assert int(alias.created_at.timestamp) == alias_created.created_at
def test_fire_event_on_alias_creation_with_note():
(user, pu) = _create_linked_user()
note = random_token(10)
alias = Alias.create_new_random(user, note=note)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert note == alias_created.note
assert alias.enabled == alias_created.enabled
def test_fire_event_on_alias_deletion():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
alias_id = alias.id
Session.flush()
on_memory_dispatcher.clear()
alias_utils.delete_alias(alias, user)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_deleted is not None
alias_deleted = event_content.alias_deleted
assert alias_id == alias_deleted.id
assert alias.email == alias_deleted.email
def test_fire_event_on_alias_status_change():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
on_memory_dispatcher.clear()
alias_utils.change_alias_status(alias, enabled=True)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_status_change is not None
event = event_content.alias_status_change
assert alias.id == event.id
assert alias.email == event.email
assert int(alias.created_at.timestamp) == event.created_at
assert event.enabled

View File

@ -0,0 +1,109 @@
import arrow
from app import config
from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated.event_pb2 import UserPlanChanged
from app.models import (
Subscription,
AppleSubscription,
CoinbaseSubscription,
ManualSubscription,
User,
PartnerUser,
)
from .event_test_utils import (
OnMemoryDispatcher,
_create_linked_user,
_get_event_from_string,
)
from tests.utils import random_token
from app.subscription_webhook import execute_subscription_webhook
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def setup_function(func):
on_memory_dispatcher.clear()
def check_event(user: User, pu: PartnerUser) -> UserPlanChanged:
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.user_plan_change is not None
plan_change = event_content.user_plan_change
return plan_change
def test_webhook_with_trial():
(user, pu) = _create_linked_user()
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == 0
def test_webhook_with_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)
Subscription.create(
user_id=user.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=end_at.date(),
plan="yearly",
flush=True,
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_apple_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=2).replace(hour=0, minute=0, second=0)
AppleSubscription.create(
user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id=random_token(10),
plan="yearly",
product_id="",
flush=True,
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_coinbase_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
CoinbaseSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp
def test_webhook_with_manual_subscription():
(user, pu) = _create_linked_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
ManualSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert check_event(user, pu).plan_end_time == end_at.timestamp

View File

@ -37,10 +37,14 @@ def test_send_alias_creation_events():
event_list = decoded_event.content.alias_create_list.events
assert len(event_list) == 2
# 0 is newsletter alias
assert event_list[1].alias_id == aliases[0].id
assert event_list[1].id == aliases[0].id
assert event_list[1].email == aliases[0].email
assert event_list[1].note == ""
assert event_list[1].enabled == aliases[0].enabled
assert event_list[1].created_at == int(aliases[0].created_at.timestamp)
decoded_event = event_pb2.Event.FromString(dispatcher.events[1])
assert decoded_event.user_id == user.id
assert decoded_event.external_user_id == partner_user.external_user_id
event_list = decoded_event.content.alias_create_list.events
assert len(event_list) == 1
assert event_list[0].alias_id == aliases[1].id
assert event_list[0].id == aliases[1].id

View File

@ -1,4 +1,6 @@
from arrow import Arrow
from app import config
from app.account_linking import (
SLPlan,
SLPlanType,
@ -8,7 +10,7 @@ from app.proton.proton_callback_handler import (
ProtonCallbackHandler,
generate_account_not_allowed_to_log_in,
)
from app.models import User, PartnerUser
from app.models import User, PartnerUser, Job, JobState
from app.proton.utils import get_proton_partner
from app.utils import random_string
from typing import Optional
@ -23,6 +25,17 @@ class MockProtonClient(ProtonClient):
return self.user
def check_initial_sync_job(user: User):
for job in Job.yield_per_query(10).filter_by(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
state=JobState.ready.value,
):
if job.payload.get("user_id") == user.id:
Job.delete(job.id)
return
assert False, "Initial sync job for user not found"
def test_proton_callback_handler_unexistant_sl_user():
email = random_email()
name = random_string()
@ -56,6 +69,7 @@ def test_proton_callback_handler_unexistant_sl_user():
)
assert partner_user is not None
assert partner_user.external_user_id == external_id
check_initial_sync_job(res.user)
def test_proton_callback_handler_existant_sl_user():
@ -84,6 +98,7 @@ def test_proton_callback_handler_existant_sl_user():
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id)
assert sa is not None
assert sa.partner_email == user.email
check_initial_sync_job(res.user)
def test_proton_callback_handler_none_user_login():

View File

@ -1,3 +1,5 @@
from typing import List
import pytest
from arrow import Arrow
@ -16,8 +18,9 @@ from app.account_linking import (
)
from app.db import Session
from app.errors import AccountAlreadyLinkedToAnotherPartnerException
from app.models import Partner, PartnerUser, User
from app.models import Partner, PartnerUser, User, UserAuditLog
from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import UserAuditLogAction
from app.utils import random_string, canonicalize_email
from tests.utils import random_email
@ -91,6 +94,13 @@ def test_login_case_from_partner():
)
assert res.user.activated is True
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=res.user.id,
action=UserAuditLogAction.LinkAccount.value,
).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
def test_login_case_from_partner_with_uppercase_email():
partner = get_proton_partner()
@ -125,6 +135,14 @@ def test_login_case_from_web():
assert 0 == (res.user.flags & User.FLAG_CREATED_FROM_PARTNER)
assert res.user.activated is True
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=res.user.id,
action=UserAuditLogAction.LinkAccount.value,
).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_get_strategy_existing_sl_user():
email = random_email()
@ -205,6 +223,13 @@ def test_link_account_with_proton_account_same_address(flask_client):
)
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=res.user.id,
action=UserAuditLogAction.LinkAccount.value,
).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_link_account_with_proton_account_different_address(flask_client):
@ -229,6 +254,14 @@ def test_link_account_with_proton_account_different_address(flask_client):
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=res.user.id,
action=UserAuditLogAction.LinkAccount.value,
).all()
assert len(audit_logs) == 1
assert audit_logs[0].user_id == res.user.id
assert audit_logs[0].action == UserAuditLogAction.LinkAccount.value
def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
flask_client,
@ -248,22 +281,54 @@ def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
partner_user_id, email=random_email()
) # User already linked with the proton account
# START Ensure sl_user_2 has a partner_user with the right data
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_2.id
)
assert partner_user is not None
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
assert partner_user.partner_email == sl_user_2.email
assert partner_user.user_id == sl_user_2.id
# END Ensure sl_user_2 has a partner_user with the right data
# Proceed to link sl_user_1
res = process_link_case(link_request, sl_user_1, get_proton_partner())
# Check that the result is linking sl_user_1
assert res.user.id == sl_user_1.id
assert res.user.email == partner_email
assert res.strategy == "Link"
# Ensure partner_user for sl_user_1 exists
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_1.id
)
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
# Ensure partner_user for sl_user_2 does not exist anymore
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_2.id
)
assert partner_user is None
# Ensure audit logs for sl_user_1 show the link action
sl_user_1_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_1.id,
action=UserAuditLogAction.LinkAccount.value,
).all()
assert len(sl_user_1_audit_logs) == 1
assert sl_user_1_audit_logs[0].user_id == sl_user_1.id
# Ensure audit logs for sl_user_2 show the unlink action
sl_user_2_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_2.id,
action=UserAuditLogAction.UnlinkAccount.value,
).all()
assert len(sl_user_2_audit_logs) == 1
assert sl_user_2_audit_logs[0].user_id == sl_user_2.id
def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
flask_client,
@ -300,6 +365,22 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_
)
assert partner_user_2 is None
# Ensure audit logs for sl_user_1 show the link action
sl_user_1_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_1.id,
action=UserAuditLogAction.LinkAccount.value,
).all()
assert len(sl_user_1_audit_logs) == 1
assert sl_user_1_audit_logs[0].user_id == sl_user_1.id
# Ensure audit logs for sl_user_2 show the unlink action
sl_user_2_audit_logs: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=sl_user_2.id,
action=UserAuditLogAction.UnlinkAccount.value,
).all()
assert len(sl_user_2_audit_logs) == 1
assert sl_user_2_audit_logs[0].user_id == sl_user_2.id
def test_cannot_create_instance_of_base_strategy():
with pytest.raises(Exception):

View File

@ -0,0 +1,95 @@
import random
from app.alias_audit_log_utils import emit_alias_audit_log, AliasAuditLogAction
from app.alias_utils import delete_alias, transfer_alias
from app.models import Alias, AliasAuditLog, AliasDeleteReason
from app.utils import random_string
from tests.utils import create_new_user, random_email
def test_emit_alias_audit_log_for_random_data():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
random_user_id = random.randint(1000, 2000)
message = random_string()
action = AliasAuditLogAction.ChangeAliasStatus
emit_alias_audit_log(
alias=alias,
user_id=random_user_id,
action=action,
message=message,
commit=True,
)
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 2
last_log = logs_for_alias[-1]
assert last_log.alias_id == alias.id
assert last_log.alias_email == alias.email
assert last_log.user_id == random_user_id
assert last_log.action == action.value
assert last_log.message == message
def test_emit_alias_audit_log_on_alias_creation():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
log_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(log_for_alias) == 1
assert log_for_alias[0].alias_id == alias.id
assert log_for_alias[0].alias_email == alias.email
assert log_for_alias[0].user_id == user.id
assert log_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
def test_alias_audit_log_exists_after_alias_deletion():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
)
alias_id = alias.id
emit_alias_audit_log(alias, AliasAuditLogAction.UpdateAlias, "")
emit_alias_audit_log(alias, AliasAuditLogAction.UpdateAlias, "")
delete_alias(alias, user, AliasDeleteReason.ManualAction, commit=True)
db_alias = Alias.get_by(id=alias_id)
assert db_alias is None
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 4
assert logs_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
assert logs_for_alias[1].action == AliasAuditLogAction.UpdateAlias.value
assert logs_for_alias[2].action == AliasAuditLogAction.UpdateAlias.value
assert logs_for_alias[3].action == AliasAuditLogAction.DeleteAlias.value
def test_alias_audit_log_for_transfer():
original_user = create_new_user()
new_user = create_new_user()
alias = Alias.create(
user_id=original_user.id,
email=random_email(),
mailbox_id=original_user.default_mailbox_id,
)
transfer_alias(alias, new_user, [new_user.default_mailbox])
logs_for_alias = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(logs_for_alias) == 3
assert logs_for_alias[0].action == AliasAuditLogAction.CreateAlias.value
assert logs_for_alias[1].action == AliasAuditLogAction.TransferredAlias.value
assert logs_for_alias[1].user_id == original_user.id
assert logs_for_alias[2].action == AliasAuditLogAction.AcceptTransferAlias.value
assert logs_for_alias[2].user_id == new_user.id

View File

@ -0,0 +1,70 @@
from typing import Tuple
from app.alias_audit_log_utils import AliasAuditLogAction
from app.alias_mailbox_utils import (
set_mailboxes_for_alias,
CannotSetMailboxesForAliasCause,
)
from app.models import Alias, Mailbox, User, AliasMailbox, AliasAuditLog
from tests.utils import create_new_user, random_email
def setup() -> Tuple[User, Alias]:
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
commit=True,
)
return user, alias
def test_set_mailboxes_for_alias_empty_list():
user, alias = setup()
err = set_mailboxes_for_alias(user.id, alias, [])
assert err is CannotSetMailboxesForAliasCause.EmptyMailboxes
def test_set_mailboxes_for_alias_mailbox_for_other_user():
user, alias = setup()
another_user = create_new_user()
err = set_mailboxes_for_alias(user.id, alias, [another_user.default_mailbox_id])
assert err is CannotSetMailboxesForAliasCause.Forbidden
def test_set_mailboxes_for_alias_mailbox_not_exists():
user, alias = setup()
err = set_mailboxes_for_alias(user.id, alias, [9999999])
assert err is CannotSetMailboxesForAliasCause.Forbidden
def test_set_mailboxes_for_alias_mailbox_success():
user, alias = setup()
mb1 = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
)
mb2 = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
commit=True,
)
err = set_mailboxes_for_alias(user.id, alias, [mb1.id, mb2.id])
assert err is None
db_alias = Alias.get_by(id=alias.id)
assert db_alias is not None
assert db_alias.mailbox_id == mb1.id
alias_mailboxes = AliasMailbox.filter_by(alias_id=alias.id).all()
assert len(alias_mailboxes) == 1
assert alias_mailboxes[0].mailbox_id == mb2.id
audit_logs = AliasAuditLog.filter_by(alias_id=alias.id).all()
assert len(audit_logs) == 2
assert audit_logs[0].action == AliasAuditLogAction.CreateAlias.value
assert audit_logs[1].action == AliasAuditLogAction.ChangedMailboxes.value
assert audit_logs[1].message == f"{mb1.id} ({mb1.email}),{mb2.id} ({mb2.email})"
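# Summary of the behaviour exercised above: set_mailboxes_for_alias promotes the first mailbox
# to the alias's primary mailbox_id, stores the remaining ones as AliasMailbox rows, and emits
# a ChangedMailboxes audit log entry listing the new mailbox set.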

View File

@ -4,6 +4,7 @@ from app.alias_utils import (
delete_alias,
check_alias_prefix,
get_user_if_alias_would_auto_create,
get_alias_recipient_name,
try_auto_create,
)
from app.config import ALIAS_DOMAINS
@ -18,7 +19,8 @@ from app.models import (
User,
DomainDeletedAlias,
)
from tests.utils import create_new_user, random_domain, random_token
from app.utils import random_string
from tests.utils import create_new_user, random_domain, random_token, random_email
def test_delete_alias(flask_client):
@ -131,3 +133,91 @@ def test_auto_create_alias(flask_client):
assert result, f"Case {test_id} - Failed address {address}"
else:
assert result is None, f"Case {test_id} - Failed address {address}"
# get_alias_recipient_name
def test_get_alias_recipient_name_no_overrides():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is None
assert res.name == alias.email
def test_get_alias_recipient_name_alias_name():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
name=random_string(),
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{alias.name} <{alias.email}>"
def test_get_alias_recipient_alias_with_name_and_custom_domain_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
name=random_string(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
name=random_string(),
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{alias.name} <{alias.email}>"
def test_get_alias_recipient_alias_without_name_and_custom_domain_without_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is None
assert res.name == alias.email
def test_get_alias_recipient_alias_without_name_and_custom_domain_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
name=random_string(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{custom_domain.name} <{alias.email}>"

View File

@ -0,0 +1,196 @@
from typing import Optional
import pytest
from app import config
from app.contact_utils import create_contact, ContactCreateError
from app.db import Session
from app.models import (
Alias,
Contact,
User,
)
from tests.utils import create_new_user, random_email, random_token
def setup_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = True
def teardown_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = False
def create_provider():
# name, mail_from, automatic_created, from_partner
yield ["name", "a@b.c", True, True]
yield [None, None, True, True]
yield [None, None, False, True]
yield [None, None, True, False]
yield [None, None, False, False]
@pytest.mark.parametrize(
"name, mail_from, automatic_created, from_partner", create_provider()
)
def test_create_contact(
name: Optional[str],
mail_from: Optional[str],
automatic_created: bool,
from_partner: bool,
):
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(
email,
alias,
name=name,
mail_from=mail_from,
automatic_created=automatic_created,
from_partner=from_partner,
)
assert contact_result.error is None
contact = contact_result.contact
assert contact.user_id == user.id
assert contact.alias_id == alias.id
assert contact.website_email == email
assert contact.name == name
assert contact.mail_from == mail_from
assert contact.automatic_created == automatic_created
assert not contact.invalid_email
expected_flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
assert contact.flags == expected_flags
def test_create_contact_email_email_not_allowed():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
def test_create_contact_email_email_allowed():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("", alias, allow_empty_email=True)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == ""
assert contact_result.contact.invalid_email
def test_create_contact_name_overrides_email_name():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"superseeded <{email}>", alias, name=name)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_name_taken_from_email():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"{name} <{email}>", alias)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_empty_name_is_none():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias, name="")
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name is None
def test_create_contact_free_user():
user = create_new_user()
user.trial_end = None
user.flags = 0
alias = Alias.create_new_random(user)
Session.flush()
# Free users without the FLAG_DISABLE_CREATE_CONTACTS flag can create contacts
result = create_contact(random_email(), alias)
assert result.error is None
assert result.created
assert result.contact is not None
assert not result.contact.automatic_created
# Free users with the flag should still be able to create automatic contacts
user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.flush()
result = create_contact(random_email(), alias, automatic_created=True)
assert result.error is None
assert result.created
assert result.contact is not None
assert result.contact.automatic_created
# Free users with the flag cannot create non-automatic contacts
result = create_contact(random_email(), alias)
assert result.error == ContactCreateError.NotAllowed
def test_do_not_allow_invalid_email():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("potato", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
contact_result = create_contact("asdf\x00@gmail.com", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
def test_update_name_for_existing():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias)
assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None
assert contact_result.contact.name is None
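# Creating the same contact again with a name should update the existing record instead of creating a new one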
name = random_token()
contact_result = create_contact(email, alias, name=name)
assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None
assert contact_result.contact.name == name
def test_update_mail_from_for_existing():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias)
assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None
assert contact_result.contact.mail_from is None
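# Creating the same contact again with a mail_from should update the existing record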
mail_from = random_email()
contact_result = create_contact(email, alias, mail_from=mail_from)
assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None
assert contact_result.contact.mail_from == mail_from

View File

@@ -0,0 +1,267 @@
from typing import Optional
from app import config
from app.config import ALIAS_DOMAINS
from app.custom_domain_utils import (
can_domain_be_used,
create_custom_domain,
is_valid_domain,
sanitize_domain,
set_custom_domain_mailboxes,
CannotUseDomainReason,
CannotSetCustomDomainMailboxesCause,
)
from app.db import Session
from app.models import User, CustomDomain, Mailbox, DomainMailbox
from tests.utils import get_proton_partner, random_email
from tests.utils import create_new_user, random_string, random_domain
user: Optional[User] = None
def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()
# is_valid_domain
def test_is_valid_domain():
assert is_valid_domain("example.com") is True
assert is_valid_domain("sub.example.com") is True
assert is_valid_domain("ex-ample.com") is True
assert is_valid_domain("-example.com") is False
assert is_valid_domain("example-.com") is False
assert is_valid_domain("exa_mple.com") is False
assert is_valid_domain("example..com") is False
assert is_valid_domain("") is False
assert is_valid_domain("a" * 64 + ".com") is False
assert is_valid_domain("a" * 63 + ".com") is True
assert is_valid_domain("example.com.") is True
assert is_valid_domain(".example.com") is False
assert is_valid_domain("example..com") is False
assert is_valid_domain("example.com-") is False
# can_domain_be_used
def test_can_domain_be_used():
domain = f"{random_string(10)}.com"
res = can_domain_be_used(user, domain)
assert res is None
def test_can_domain_be_used_existing_domain():
domain = random_domain()
CustomDomain.create(user_id=user.id, domain=domain, commit=True)
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainAlreadyUsed
def test_can_domain_be_used_sl_domain():
domain = ALIAS_DOMAINS[0]
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.BuiltinDomain
def test_can_domain_be_used_domain_of_user_email():
domain = user.email.split("@")[1]
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainPartOfUserEmail
def test_can_domain_be_used_domain_of_existing_mailbox():
domain = random_domain()
Mailbox.create(user_id=user.id, email=f"email@{domain}", verified=True, commit=True)
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainUserInMailbox
def test_can_domain_be_used_invalid_domain():
domain = f"{random_string(10)}@lol.com"
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.InvalidDomain
# sanitize_domain
def test_can_sanitize_domain_empty():
assert sanitize_domain("") == ""
def test_can_sanitize_domain_starting_with_http():
domain = "test.domain"
assert sanitize_domain(f"http://{domain}") == domain
def test_can_sanitize_domain_starting_with_https():
domain = "test.domain"
assert sanitize_domain(f"https://{domain}") == domain
def test_can_sanitize_domain_correct_domain():
domain = "test.domain"
assert sanitize_domain(domain) == domain
# create_custom_domain
def test_can_create_custom_domain():
domain = random_domain()
res = create_custom_domain(user=user, domain=domain)
assert res.success is True
assert res.redirect is None
assert res.message == ""
assert res.message_category == ""
assert res.instance is not None
assert res.instance.domain == domain
assert res.instance.user_id == user.id
def test_can_create_custom_domain_validates_if_parent_is_validated():
root_domain = random_domain()
subdomain = f"{random_string(10)}.{root_domain}"
# Create custom domain with the root domain
CustomDomain.create(
user_id=user.id,
domain=root_domain,
verified=True,
ownership_verified=True,
commit=True,
)
# Create custom domain with the subdomain. Its ownership should automatically be verified
res = create_custom_domain(user=user, domain=subdomain)
assert res.success is True
assert res.instance.domain == subdomain
assert res.instance.user_id == user.id
assert res.instance.ownership_verified is True
def test_creates_custom_domain_with_partner_id():
domain = random_domain()
proton_partner = get_proton_partner()
res = create_custom_domain(user=user, domain=domain, partner_id=proton_partner.id)
assert res.success is True
assert res.instance.domain == domain
assert res.instance.user_id == user.id
assert res.instance.partner_id == proton_partner.id
# set_custom_domain_mailboxes
def test_set_custom_domain_mailboxes_empty_list():
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [])
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.NoMailboxes
def test_set_custom_domain_mailboxes_mailbox_from_another_user():
other_user = create_new_user()
other_mailbox = Mailbox.create(
user_id=other_user.id, email=random_email(), verified=True
)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id])
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox
def test_set_custom_domain_mailboxes_mailbox_from_current_user_and_another_user():
other_user = create_new_user()
other_mailbox = Mailbox.create(
user_id=other_user.id, email=random_email(), verified=True
)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox
def test_set_custom_domain_mailboxes_success():
other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
assert domain_mailboxes[1].domain_id == domain.id
assert domain_mailboxes[1].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_set_twice():
other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
assert domain_mailboxes[1].domain_id == domain.id
assert domain_mailboxes[1].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_removes_old_association():
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [user.default_mailbox_id])
assert res.success is True
assert res.reason is None
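# Setting a different mailbox list should replace the previous association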
other_mailbox = Mailbox.create(
user_id=user.id, email=random_email(), verified=True, commit=True
)
res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id])
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 1
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_with_unverified_mailbox():
domain = CustomDomain.create(user_id=user.id, domain=random_domain())
verified_mailbox = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
)
unverified_mailbox = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=False,
)
res = set_custom_domain_mailboxes(
user.id, domain, [verified_mailbox.id, unverified_mailbox.id]
)
assert res.success is False
assert res.reason is CannotSetCustomDomainMailboxesCause.InvalidMailbox

View File

@@ -0,0 +1,545 @@
from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient, MxRecord
from app.proton.utils import get_proton_partner
from app.utils import random_string
from tests.utils import create_new_user, random_domain
user: Optional[User] = None
def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()
def create_custom_domain(domain: str) -> CustomDomain:
return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
def test_custom_domain_validation_get_dkim_records():
domain = random_domain()
custom_domain = create_custom_domain(domain)
validator = CustomDomainValidation(domain)
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
def test_custom_domain_validation_get_dkim_records_for_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
# get_expected_mx_records
def test_custom_domain_validation_get_expected_mx_records_regular_domain():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is not a partner_domain, default records should be used even if
# there is a config for the partner
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0]
assert records[i].domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(dkim_domain)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain but there is no custom config for partner, default records
# should be used
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0]
assert records[i].domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
expected_mx_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain, partner_domains={partner_id: expected_mx_domain}
)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain and there is a custom config for partner, partner records
# should be used
assert len(records) == 2
assert records[0].priority == 10
assert records[0].domain == f"mx1.{expected_mx_domain}."
assert records[1].priority == 20
assert records[1].domain == f"mx2.{expected_mx_domain}."
# get_expected_spf_records
def test_custom_domain_validation_get_expected_spf_record_regular_domain():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is not a partner_domain, default records should be used even if
# there is a config for the partner
assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(dkim_domain)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is a partner_domain but there is no custom config for partner, default records
# should be used
assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner_with_custom_config():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
expected_mx_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain, partner_domains={partner_id: expected_mx_domain}
)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is a partner_domain and there is a custom config for partner, partner records
# should be used
assert record == f"v=spf1 include:{expected_mx_domain} ~all"
# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dkim_records(domain)
assert len(res) == 3
for record_value in res.values():
assert record_value == "empty"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 2
for record_value in res.values():
assert record_value == "wrong"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, other domains missing
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
# DKIM is verified
domain.dkim_verified = True
Session.commit()
res = validator.validate_dkim_records(domain)
assert len(res) == 2
assert f"dkim02._domainkey.{user_domain}" in res
assert f"dkim03._domainkey.{user_domain}" in res
# Flag is not cleared
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
def test_custom_domain_validation_validate_dkim_records_success():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# All three DKIM CNAME records point to the expected domain
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
# validate_ownership
def test_custom_domain_validation_validate_ownership_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
def test_custom_domain_validation_validate_ownership_from_partner_success():
dns_client = InMemoryDNSClient()
partner_id = get_proton_partner().id
prefix = random_string()
validator = CustomDomainValidation(
random_domain(),
dns_client,
partner_domains_validation_prefixes={partner_id: prefix},
)
domain = create_custom_domain(random_domain())
domain.partner_id = partner_id
Session.commit()
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
# validate_mx_records
def test_custom_domain_validation_validate_mx_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_mx_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_record_1 = random_string()
wrong_record_2 = random_string()
wrong_records = [MxRecord(10, wrong_record_1), MxRecord(20, wrong_record_2)]
dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain)
assert res.success is False
assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_mx_records(domain.domain, validator.get_expected_mx_records(domain))
res = validator.validate_mx_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is True
# validate_spf_records
def test_custom_domain_validation_validate_spf_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_spf_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
res = validator.validate_spf_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True
def test_custom_domain_validation_validate_spf_records_partner_domain_success():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id
expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)
domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{expected_domain}"])
res = validator.validate_spf_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True
def test_custom_domain_validation_validate_spf_cleans_verification_record():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id
expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)
domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()
wrong_record = random_string()
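# The ownership verification TXT record must not be reported as an SPF error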
dns_client.set_txt_record(
hostname=domain.domain,
txt_list=[wrong_record, validator.get_ownership_verification_record(domain)],
)
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 1
assert res.errors[0] == wrong_record
# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
res = validator.validate_dmarc_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is True

View File

@@ -1,10 +1,13 @@
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
get_network_dns_client,
is_mx_equivalent,
InMemoryDNSClient,
MxRecord,
)
from tests.utils import random_domain
# use our own domain for test
_DOMAIN = "simplelogin.io"
@@ -15,34 +18,58 @@ def test_get_mx_domains():
assert len(r) > 0
for x in r:
assert x[0] > 0
assert x[1]
assert x.priority > 0
assert x.domain
def test_get_spf_domain():
r = get_spf_domain(_DOMAIN)
r = get_network_dns_client().get_spf_domain(_DOMAIN)
assert r == ["simplelogin.co"]
def test_get_txt_record():
r = get_txt_record(_DOMAIN)
r = get_network_dns_client().get_txt_record(_DOMAIN)
assert len(r) > 0
def test_is_mx_equivalent():
assert is_mx_equivalent([], [])
assert is_mx_equivalent([(1, "domain")], [(1, "domain")])
assert is_mx_equivalent(
[(10, "domain1"), (20, "domain2")], [(10, "domain1"), (20, "domain2")]
mx_domains=[MxRecord(1, "domain")], ref_mx_domains=[MxRecord(1, "domain")]
)
assert is_mx_equivalent(
[(5, "domain1"), (10, "domain2")], [(10, "domain1"), (20, "domain2")]
mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
)
assert is_mx_equivalent(
[(5, "domain1"), (10, "domain2"), (20, "domain3")],
[(10, "domain1"), (20, "domain2")],
mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")],
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
)
assert is_mx_equivalent(
mx_domains=[
MxRecord(5, "domain1"),
MxRecord(10, "domain2"),
MxRecord(20, "domain3"),
],
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
)
assert not is_mx_equivalent(
[(5, "domain1"), (10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")],
mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")],
ref_mx_domains=[
MxRecord(10, "domain1"),
MxRecord(20, "domain2"),
MxRecord(20, "domain3"),
],
)
def test_get_spf_record():
client = InMemoryDNSClient()
sl_domain = random_domain()
domain = random_domain()
spf_record = f"v=spf1 include:{sl_domain}"
client.set_txt_record(domain, [spf_record, "another record"])
res = client.get_spf_domain(domain)
assert res == [sl_domain]

View File

@@ -90,12 +90,19 @@ def test_can_be_used_as_personal_email(flask_client):
assert not email_can_be_used_as_mailbox("ab@sl.local")
assert not email_can_be_used_as_mailbox("hey@d1.test")
# custom domain
# custom domain as SL domain
domain = random_domain()
user = create_new_user()
CustomDomain.create(user_id=user.id, domain=domain, verified=True, commit=True)
domain_obj = CustomDomain.create(
user_id=user.id, domain=domain, verified=True, is_sl_subdomain=True, flush=True
)
assert not email_can_be_used_as_mailbox(f"hey@{domain}")
# custom domain is NOT SL domain
domain_obj.is_sl_subdomain = False
Session.flush()
assert email_can_be_used_as_mailbox(f"hey@{domain}")
# disposable domain
disposable_domain = random_domain()
InvalidMailboxDomain.create(domain=disposable_domain, commit=True)

View File

@@ -6,7 +6,9 @@ import pytest
from app import mailbox_utils, config
from app.db import Session
from app.mail_sender import mail_sender
from app.models import Mailbox, MailboxActivation, User, Job
from app.mailbox_utils import MailboxEmailChangeError
from app.models import Mailbox, MailboxActivation, User, Job, UserAuditLog
from app.user_audit_log_utils import UserAuditLogAction
from tests.utils import create_new_user, random_email
@@ -218,7 +220,11 @@ def test_delete_with_transfer():
user, random_email(), use_digit_codes=True, send_link=False
).mailbox
transfer_mailbox = mailbox_utils.create_mailbox(
user, random_email(), use_digit_codes=True, send_link=False
user,
random_email(),
use_digit_codes=True,
send_link=False,
verified=True,
).mailbox
mailbox_utils.delete_mailbox(
user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id
@@ -236,6 +242,28 @@ def test_delete_with_transfer():
assert job.payload["transfer_mailbox_id"] is None
def test_cannot_delete_with_transfer_to_unverified_mailbox():
mailbox = mailbox_utils.create_mailbox(
user, random_email(), use_digit_codes=True, send_link=False
).mailbox
transfer_mailbox = mailbox_utils.create_mailbox(
user,
random_email(),
use_digit_codes=True,
send_link=False,
verified=False,
).mailbox
with pytest.raises(mailbox_utils.MailboxError):
mailbox_utils.delete_mailbox(
user, mailbox.id, transfer_mailbox_id=transfer_mailbox.id
)
# Verify mailbox still exists
db_mailbox = Mailbox.get_by(id=mailbox.id)
assert db_mailbox is not None
def test_verify_non_existing_mailbox():
with pytest.raises(mailbox_utils.MailboxError):
mailbox_utils.verify_mailbox_code(user, 999999999, "9999999")
@@ -302,3 +330,79 @@ def test_verify_ok():
assert activation is None
mailbox = Mailbox.get(id=output.mailbox.id)
assert mailbox.verified
# perform_mailbox_email_change
def test_perform_mailbox_email_change_invalid_id():
res = mailbox_utils.perform_mailbox_email_change(99999)
assert res.error == MailboxEmailChangeError.InvalidId
assert res.message_category == "error"
def test_perform_mailbox_email_change_valid_id_not_new_email():
user = create_new_user()
mb = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=None,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb.id)
assert res.error == MailboxEmailChangeError.InvalidId
assert res.message_category == "error"
audit_log_entries = UserAuditLog.filter_by(
user_id=user.id, action=UserAuditLogAction.UpdateMailbox.value
).count()
assert audit_log_entries == 0
def test_perform_mailbox_email_change_valid_id_email_already_used():
user = create_new_user()
new_email = random_email()
# Create mailbox with that email
Mailbox.create(
user_id=user.id,
email=new_email,
verified=True,
)
mb_to_change = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=new_email,
verified=True,
commit=True,
)
res = mailbox_utils.perform_mailbox_email_change(mb_to_change.id)
assert res.error == MailboxEmailChangeError.EmailAlreadyUsed
assert res.message_category == "error"
audit_log_entries = UserAuditLog.filter_by(
user_id=user.id, action=UserAuditLogAction.UpdateMailbox.value
).count()
assert audit_log_entries == 0
def test_perform_mailbox_email_change_success():
user = create_new_user()
new_email = random_email()
mb = Mailbox.create(
user_id=user.id,
email=random_email(),
new_email=new_email,
verified=True,
commit=True,
)
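# The pending new_email should become the mailbox address and new_email should be cleared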
res = mailbox_utils.perform_mailbox_email_change(mb.id)
assert res.error is None
assert res.message_category == "success"
db_mailbox = Mailbox.get_by(id=mb.id)
assert db_mailbox is not None
assert db_mailbox.verified is True
assert db_mailbox.email == new_email
assert db_mailbox.new_email is None
audit_log_entries = UserAuditLog.filter_by(
user_id=user.id, action=UserAuditLogAction.UpdateMailbox.value
).count()
assert audit_log_entries == 1

View File

@@ -2,7 +2,7 @@ import arrow
from app.db import Session
from app.models import CoinbaseSubscription
from server import handle_coinbase_event
from app.payments.coinbase import handle_coinbase_event
from tests.utils import create_new_user

View File

@@ -1,113 +0,0 @@
import http.server
import json
import threading
import arrow
from app import config
from app.models import (
Subscription,
AppleSubscription,
CoinbaseSubscription,
ManualSubscription,
)
from tests.utils import create_new_user, random_token
from app.subscription_webhook import execute_subscription_webhook
http_server = None
last_http_request = None
def setup_module():
global http_server
http_server = http.server.ThreadingHTTPServer(("", 0), HTTPTestServer)
print(http_server.server_port)
threading.Thread(target=http_server.serve_forever, daemon=True).start()
config.SUBSCRIPTION_CHANGE_WEBHOOK = f"http://localhost:{http_server.server_port}"
def teardown_module():
global http_server
config.SUBSCRIPTION_CHANGE_WEBHOOK = None
http_server.shutdown()
class HTTPTestServer(http.server.BaseHTTPRequestHandler):
def do_POST(self):
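# Capture the JSON payload of the webhook call so tests can assert on it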
global last_http_request
content_len = int(self.headers.get("Content-Length"))
body_data = self.rfile.read(content_len)
last_http_request = json.loads(body_data)
self.send_response(200)
def test_webhook_with_trial():
user = create_new_user()
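# A newly created user is in its trial period, so the webhook reports premium with no subscription end date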
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] is None
def test_webhook_with_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)
Subscription.create(
user_id=user.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=end_at.date(),
plan="yearly",
flush=True,
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_apple_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=2).replace(hour=0, minute=0, second=0)
AppleSubscription.create(
user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id=random_token(10),
plan="yearly",
product_id="",
flush=True,
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_coinbase_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
CoinbaseSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp
def test_webhook_with_manual_subscription():
user = create_new_user()
end_at = arrow.utcnow().shift(days=3).replace(hour=0, minute=0, second=0)
ManualSubscription.create(
user_id=user.id, end_at=end_at.date().strftime("%Y-%m-%d"), flush=True
)
execute_subscription_webhook(user)
assert last_http_request["user_id"] == user.id
assert last_http_request["is_premium"]
assert last_http_request["active_subscription_end"] == end_at.timestamp

View File

@@ -0,0 +1,57 @@
from typing import List
from app import config, mailbox_utils
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.models import UserAuditLog
from app.utils import random_string
from tests.utils import create_new_user, random_email
def setup_module():
config.SKIP_MX_LOOKUP_ON_CHECK = True
def teardown_module():
config.SKIP_MX_LOOKUP_ON_CHECK = False
def test_emit_user_audit_log_for_random_data():
user = create_new_user()
message = random_string()
action = UserAuditLogAction.CreateMailbox
emit_user_audit_log(
user=user,
action=action,
message=message,
commit=True,
)
logs_for_user: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=user.id, action=action.value
).all()
assert len(logs_for_user) == 1
assert logs_for_user[0].user_id == user.id
assert logs_for_user[0].user_email == user.email
assert logs_for_user[0].action == action.value
assert logs_for_user[0].message == message
def test_emit_audit_log_on_mailbox_creation():
user = create_new_user()
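# Creating a mailbox should emit a CreateMailbox audit log entry for the user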
output = mailbox_utils.create_mailbox(
user=user, email=random_email(), verified=True
)
logs_for_user: List[UserAuditLog] = UserAuditLog.filter_by(
user_id=user.id,
action=UserAuditLogAction.CreateMailbox.value,
).all()
assert len(logs_for_user) == 1
assert logs_for_user[0].user_id == user.id
assert logs_for_user[0].user_email == user.email
assert logs_for_user[0].action == UserAuditLogAction.CreateMailbox.value
assert (
logs_for_user[0].message
== f"Create mailbox {output.mailbox.id} ({output.mailbox.email}). Verified=True"
)