Compare commits

...

9 Commits

Author SHA1 Message Date
294232a329 4.52.1
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m56s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 4m45s
Build-Release-Image / Merge-Images (push) Successful in 22s
Build-Release-Image / Create-Release (push) Successful in 8s
Build-Release-Image / Notify (push) Successful in 3s
2024-10-02 12:00:06 +01:00
fae9d7bc17 4.52.0
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 4m44s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 4m31s
Build-Release-Image / Merge-Images (push) Successful in 23s
Build-Release-Image / Create-Release (push) Successful in 23s
Build-Release-Image / Notify (push) Successful in 17s
2024-10-01 12:00:06 +01:00
d666f5af3f 4.51.2
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m33s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m35s
Build-Release-Image / Merge-Images (push) Successful in 25s
Build-Release-Image / Create-Release (push) Successful in 10s
Build-Release-Image / Notify (push) Successful in 3s
2024-09-28 12:00:06 +01:00
556fae02d5 4.51.1
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m21s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m40s
Build-Release-Image / Merge-Images (push) Successful in 23s
Build-Release-Image / Create-Release (push) Successful in 9s
Build-Release-Image / Notify (push) Successful in 4s
2024-09-26 12:00:06 +01:00
fd4c67c3d1 4.51.0
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m30s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m31s
Build-Release-Image / Merge-Images (push) Successful in 11s
Build-Release-Image / Create-Release (push) Successful in 9s
Build-Release-Image / Notify (push) Successful in 2s
2024-09-25 12:00:07 +01:00
edef254529 4.50.0
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m43s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m54s
Build-Release-Image / Merge-Images (push) Successful in 23s
Build-Release-Image / Create-Release (push) Successful in 14s
Build-Release-Image / Notify (push) Successful in 3s
2024-09-19 12:00:06 +01:00
357f0cca57 4.49.10
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m7s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-13 12:00:28 +01:00
8ce90e27f7 4.49.9
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m4s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-07 12:00:06 +01:00
3ecc8d36f9 4.49.8
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m18s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-04 12:00:07 +01:00
48 changed files with 2086 additions and 577 deletions

View File

@ -109,7 +109,7 @@ jobs:
GITHUB_ACTIONS_TEST: true GITHUB_ACTIONS_TEST: true
- name: Archive code coverage results - name: Archive code coverage results
uses: actions/upload-artifact@v2 uses: actions/upload-artifact@v4
with: with:
name: code-coverage-report name: code-coverage-report
path: htmlcov path: htmlcov

View File

@ -736,7 +736,8 @@ class InvalidMailboxDomainAdmin(SLModelView):
class EmailSearchResult: class EmailSearchResult:
no_match: bool = True no_match: bool = True
alias: Optional[Alias] = None alias: Optional[Alias] = None
mailbox: Optional[Mailbox] = None mailbox: list[Mailbox] = []
mailbox_count: int = 0
deleted_alias: Optional[DeletedAlias] = None deleted_alias: Optional[DeletedAlias] = None
deleted_custom_alias: Optional[DomainDeletedAlias] = None deleted_custom_alias: Optional[DomainDeletedAlias] = None
user: Optional[User] = None user: Optional[User] = None
@ -752,9 +753,12 @@ class EmailSearchResult:
if user: if user:
output.user = user output.user = user
output.no_match = False output.no_match = False
mailbox = Mailbox.get_by(email=email) mailboxes = (
if mailbox: Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all()
output.mailbox = mailbox )
if mailboxes:
output.mailbox = mailboxes
output.mailbox_count = Mailbox.filter_by(email=email).count()
output.no_match = False output.no_match = False
deleted_alias = DeletedAlias.get_by(email=email) deleted_alias = DeletedAlias.get_by(email=email)
if deleted_alias: if deleted_alias:
@ -779,11 +783,13 @@ class EmailSearchHelpers:
@staticmethod @staticmethod
def mailbox_count(user: User) -> int: def mailbox_count(user: User) -> int:
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.asc()).count() return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count()
@staticmethod @staticmethod
def alias_list(user: User) -> list[Alias]: def alias_list(user: User) -> list[Alias]:
return Alias.filter_by(user_id=user.id).order_by(Alias.id.asc()).limit(10).all() return (
Alias.filter_by(user_id=user.id).order_by(Alias.id.desc()).limit(10).all()
)
@staticmethod @staticmethod
def alias_count(user: User) -> int: def alias_count(user: User) -> int:
@ -806,9 +812,8 @@ class EmailSearchAdmin(BaseView):
@expose("/", methods=["GET", "POST"]) @expose("/", methods=["GET", "POST"])
def index(self): def index(self):
search = EmailSearchResult() search = EmailSearchResult()
email = "" email = request.args.get("email")
if request.form and request.form["email"]: if email is not None and len(email) > 0:
email = request.form["email"]
email = email.strip() email = email.strip()
search = EmailSearchResult.from_email(email) search = EmailSearchResult.from_email(email)

View File

@ -363,11 +363,14 @@ def delete_alias(
Session.commit() Session.commit()
LOG.i(f"Moving {alias} to global trash {deleted_alias}") LOG.i(f"Moving {alias} to global trash {deleted_alias}")
alias_id = alias.id
alias_email = alias.email
Alias.filter(Alias.id == alias.id).delete() Alias.filter(Alias.id == alias.id).delete()
Session.commit() Session.commit()
EventDispatcher.send_event( EventDispatcher.send_event(
user, EventContent(alias_deleted=AliasDeleted(alias_id=alias.id)) user,
EventContent(alias_deleted=AliasDeleted(id=alias_id, email=alias_email)),
) )
if commit: if commit:
Session.commit() Session.commit()
@ -506,7 +509,10 @@ def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
alias.enabled = enabled alias.enabled = enabled
event = AliasStatusChanged( event = AliasStatusChanged(
alias_id=alias.id, alias_email=alias.email, enabled=enabled id=alias.id,
email=alias.email,
enabled=enabled,
created_at=int(alias.created_at.timestamp),
) )
EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event)) EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event))

View File

@ -424,7 +424,7 @@ def create_contact_route(alias_id):
contact_address = data.get("contact") contact_address = data.get("contact")
try: try:
contact = create_contact(g.user, alias, contact_address) contact = create_contact(alias, contact_address)
except ErrContactErrorUpgradeNeeded as err: except ErrContactErrorUpgradeNeeded as err:
return jsonify(error=err.error_for_user()), 403 return jsonify(error=err.error_for_user()), 403
except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err: except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err:

View File

@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None):
return literal_eval(value) return literal_eval(value)
def get_env_dict(env_var: str) -> dict[str, str]:
"""
Get an env variable and convert it into a python dictionary with keys and values as strings.
Args:
env_var (str): env var, example: SL_DB
Syntax is: key1=value1;key2=value2
Components separated by ;
key and value separated by =
"""
value = os.getenv(env_var)
if not value:
return {}
components = value.split(";")
result = {}
for component in components:
if component == "":
continue
parts = component.split("=")
if len(parts) != 2:
raise Exception(f"Invalid config for env var {env_var}")
result[parts[0].strip()] = parts[1].strip()
return result
config_file = os.environ.get("CONFIG") config_file = os.environ.get("CONFIG")
if config_file: if config_file:
config_file = get_abs_path(config_file) config_file = get_abs_path(config_file)
@ -609,3 +636,28 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_
# Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool # Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool
# It defaults to the regular DB_URI in case it's needed # It defaults to the regular DB_URI in case it's needed
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI) EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
def read_partner_dict(var: str) -> dict[int, str]:
partner_value = get_env_dict(var)
if len(partner_value) == 0:
return {}
res: dict[int, str] = {}
for partner_id in partner_value.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_value[partner_id]
except ValueError:
pass
return res
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
)
MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(
"MAILBOX_VERIFICATION_OVERRIDE_CODE", None
)

View File

@ -1 +1,2 @@
HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies" HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"

113
app/app/contact_utils.py Normal file
View File

@ -0,0 +1,113 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from sqlalchemy.exc import IntegrityError
from app.db import Session
from app.email_utils import generate_reply_email, parse_full_address
from app.email_validation import is_valid_email
from app.log import LOG
from app.models import Contact, Alias
from app.utils import sanitize_email
class ContactCreateError(Enum):
InvalidEmail = "Invalid email"
NotAllowed = "Your plan does not allow to create contacts"
@dataclass
class ContactCreateResult:
contact: Optional[Contact]
created: bool
error: Optional[ContactCreateError]
def __update_contact_if_needed(
contact: Contact, name: Optional[str], mail_from: Optional[str]
) -> ContactCreateResult:
if name and contact.name != name:
LOG.d(f"Setting {contact} name to {name}")
contact.name = name
Session.commit()
if mail_from and contact.mail_from is None:
LOG.d(f"Setting {contact} mail_from to {mail_from}")
contact.mail_from = mail_from
Session.commit()
return ContactCreateResult(contact, created=False, error=None)
def create_contact(
email: str,
alias: Alias,
name: Optional[str] = None,
mail_from: Optional[str] = None,
allow_empty_email: bool = False,
automatic_created: bool = False,
from_partner: bool = False,
) -> ContactCreateResult:
# If user cannot create contacts, they still need to be created when receiving an email for an alias
if not automatic_created and not alias.user.can_create_contacts():
return ContactCreateResult(
None, created=False, error=ContactCreateError.NotAllowed
)
# Parse emails with form 'name <email>'
try:
email_name, email = parse_full_address(email)
except ValueError:
email = ""
email_name = ""
# If no name is explicitly given try to get it from the parsed email
if name is None:
name = email_name[: Contact.MAX_NAME_LENGTH]
else:
name = name[: Contact.MAX_NAME_LENGTH]
# If still no name is there, make sure the name is None instead of empty string
if not name:
name = None
if name is not None and "\x00" in name:
LOG.w("Cannot use contact name because has \\x00")
name = ""
# Sanitize email and if it's not valid only allow to create a contact if it's explicitly allowed. Otherwise fail
email = sanitize_email(email, not_lower=True)
if not is_valid_email(email):
LOG.w(f"invalid contact email {email}")
if not allow_empty_email:
return ContactCreateResult(
None, created=False, error=ContactCreateError.InvalidEmail
)
LOG.d("Create a contact with invalid email for %s", alias)
# either reuse a contact with empty email or create a new contact with empty email
email = ""
# If contact exists, update name and mail_from if needed
contact = Contact.get_by(alias_id=alias.id, website_email=email)
if contact is not None:
return __update_contact_if_needed(contact, name, mail_from)
# Create the contact
reply_email = generate_reply_email(email, alias)
try:
flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=email,
name=name,
reply_email=reply_email,
mail_from=mail_from,
automatic_created=automatic_created,
flags=flags,
invalid_email=email == "",
commit=True,
)
LOG.d(
f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}"
)
except IntegrityError:
Session.rollback()
LOG.info(
f"Contact with email {email} for alias_id {alias.id} already existed, fetching from DB"
)
contact = Contact.get_by(alias_id=alias.id, website_email=email)
return __update_contact_if_needed(contact, name, mail_from)
return ContactCreateResult(contact, created=True, error=None)

View File

@ -0,0 +1,194 @@
import arrow
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
from app.config import JOB_DELETE_DOMAIN
from app.db import Session
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import User, CustomDomain, SLDomain, Mailbox, Job, DomainMailbox
_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$")
_MAX_MAILBOXES_PER_DOMAIN = 20
@dataclass
class CreateCustomDomainResult:
message: str = ""
message_category: str = ""
success: bool = False
instance: Optional[CustomDomain] = None
redirect: Optional[str] = None
class CannotUseDomainReason(Enum):
InvalidDomain = 1
BuiltinDomain = 2
DomainAlreadyUsed = 3
DomainPartOfUserEmail = 4
DomainUserInMailbox = 5
def message(self, domain: str) -> str:
if self == CannotUseDomainReason.InvalidDomain:
return "This is not a valid domain"
elif self == CannotUseDomainReason.BuiltinDomain:
return "A custom domain cannot be a built-in domain."
elif self == CannotUseDomainReason.DomainAlreadyUsed:
return f"{domain} already used"
elif self == CannotUseDomainReason.DomainPartOfUserEmail:
return "You cannot add a domain that you are currently using for your personal email. Please change your personal email to your real email"
elif self == CannotUseDomainReason.DomainUserInMailbox:
return f"{domain} already used in a SimpleLogin mailbox"
else:
raise Exception("Invalid CannotUseDomainReason")
class CannotSetCustomDomainMailboxesCause(Enum):
InvalidMailbox = "Something went wrong, please retry"
NoMailboxes = "You must select at least 1 mailbox"
TooManyMailboxes = (
f"You can only set up to {_MAX_MAILBOXES_PER_DOMAIN} mailboxes per domain"
)
@dataclass
class SetCustomDomainMailboxesResult:
success: bool
reason: Optional[CannotSetCustomDomainMailboxesCause] = None
def is_valid_domain(domain: str) -> bool:
"""
Checks that a domain is valid according to RFC 1035
"""
if len(domain) > 255:
return False
if domain.endswith("."):
domain = domain[:-1] # Strip the trailing dot
labels = domain.split(".")
if not labels:
return False
for label in labels:
if not _ALLOWED_DOMAIN_REGEX.match(label):
return False
return True
def sanitize_domain(domain: str) -> str:
new_domain = domain.lower().strip()
if new_domain.startswith("http://"):
new_domain = new_domain[len("http://") :]
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
return new_domain
def can_domain_be_used(user: User, domain: str) -> Optional[CannotUseDomainReason]:
if not is_valid_domain(domain):
return CannotUseDomainReason.InvalidDomain
elif SLDomain.get_by(domain=domain):
return CannotUseDomainReason.BuiltinDomain
elif CustomDomain.get_by(domain=domain):
return CannotUseDomainReason.DomainAlreadyUsed
elif get_email_domain_part(user.email) == domain:
return CannotUseDomainReason.DomainPartOfUserEmail
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{domain}")
).first():
return CannotUseDomainReason.DomainUserInMailbox
else:
return None
def create_custom_domain(
user: User, domain: str, partner_id: Optional[int] = None
) -> CreateCustomDomainResult:
if not user.is_premium():
return CreateCustomDomainResult(
message="Only premium plan can add custom domain",
message_category="warning",
)
new_domain = sanitize_domain(domain)
domain_forbidden_cause = can_domain_be_used(user, new_domain)
if domain_forbidden_cause:
return CreateCustomDomainResult(
message=domain_forbidden_cause.message(new_domain), message_category="error"
)
new_custom_domain = CustomDomain.create(domain=new_domain, user_id=user.id)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in user.custom_domains:
if new_domain.endswith("." + root_cd.domain) and root_cd.ownership_verified:
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
# Add the partner_id in case it's passed
if partner_id is not None:
new_custom_domain.partner_id = partner_id
Session.commit()
return CreateCustomDomainResult(
success=True,
instance=new_custom_domain,
)
def delete_custom_domain(domain: CustomDomain):
# Schedule delete domain job
LOG.w("schedule delete domain job for %s", domain)
domain.pending_deletion = True
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": domain.id},
run_at=arrow.now(),
commit=True,
)
def set_custom_domain_mailboxes(
user_id: int, custom_domain: CustomDomain, mailbox_ids: List[int]
) -> SetCustomDomainMailboxesResult:
if len(mailbox_ids) == 0:
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.NoMailboxes
)
elif len(mailbox_ids) > _MAX_MAILBOXES_PER_DOMAIN:
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.TooManyMailboxes
)
mailboxes = (
Session.query(Mailbox)
.filter(
Mailbox.id.in_(mailbox_ids),
Mailbox.user_id == user_id,
Mailbox.verified == True, # noqa: E712
)
.all()
)
if len(mailboxes) != len(mailbox_ids):
return SetCustomDomainMailboxesResult(
success=False, reason=CannotSetCustomDomainMailboxesCause.InvalidMailbox
)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
Session.commit()
return SetCustomDomainMailboxesResult(success=True)

View File

@ -1,37 +1,157 @@
from dataclasses import dataclass
from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.db import Session from app.db import Session
from app.dns_utils import get_cname_record from app.dns_utils import (
DNSClient,
is_mx_equivalent,
get_network_dns_client,
)
from app.models import CustomDomain from app.models import CustomDomain
@dataclass
class DomainValidationResult:
success: bool
errors: [str]
class CustomDomainValidation: class CustomDomainValidation:
def __init__(self, dkim_domain: str): def __init__(
self,
dkim_domain: str,
dns_client: DNSClient = get_network_dns_client(),
partner_domains: Optional[dict[int, str]] = None,
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
):
self.dkim_domain = dkim_domain self.dkim_domain = dkim_domain
self._dkim_records = { self._dns_client = dns_client
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}") self._partner_domains = partner_domains or config.PARTNER_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes
or config.PARTNER_DOMAIN_VALIDATION_PREFIXES
)
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = "sl"
if (
domain.partner_id is not None
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}-verification={domain.ownership_txt_token}"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
it will return the default ones or the partner ones.
"""
# By default use the default domain
dkim_domain = self.dkim_domain
if domain.partner_id is not None:
# Domain is from a partner. Retrieve the partner config and use that domain if exists
dkim_domain = self._partner_domains.get(domain.partner_id, dkim_domain)
return {
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
for key in ("dkim", "dkim02", "dkim03") for key in ("dkim", "dkim02", "dkim03")
} }
def get_dkim_records(self) -> {str: str}:
"""
Get a list of dkim records to set up. It will be
"""
return self._dkim_records
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]: def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
""" """
Check if dkim records are properly set for this custom domain. Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
""" """
correct_records = {}
invalid_records = {} invalid_records = {}
for prefix, expected_record in self.get_dkim_records(): expected_records = self.get_dkim_records(custom_domain)
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}" custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = get_cname_record(custom_record) dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record != expected_record: if dkim_record == expected_record:
correct_records[prefix] = custom_record
else:
invalid_records[custom_record] = dkim_record or "empty" invalid_records[custom_record] = dkim_record or "empty"
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
# HACK
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
# the domain validated to continue seeing it as validated (although showing them the missing records).
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
# we will remove the dkim_verified flag.
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
if custom_domain.dkim_verified: if custom_domain.dkim_verified:
return invalid_records # Check if at least the original dkim is there
if correct_records.get("dkim._domainkey") is not None:
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
return invalid_records
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0 custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit() Session.commit()
return invalid_records return invalid_records
def validate_domain_ownership(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
expected_verification_record = self.get_ownership_verification_record(
custom_domain
)
if expected_verification_record in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
return DomainValidationResult(success=False, errors=txt_records)
def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
)
else:
custom_domain.verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain),
)
def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)

View File

@ -9,13 +9,10 @@ from sqlalchemy import and_, func, case
from wtforms import StringField, validators, ValidationError from wtforms import StringField, validators, ValidationError
# Need to import directly from config to allow modification from the tests # Need to import directly from config to allow modification from the tests
from app import config, parallel_limiter from app import config, parallel_limiter, contact_utils
from app.contact_utils import ContactCreateError
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.email_utils import (
generate_reply_email,
parse_full_address,
)
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
from app.errors import ( from app.errors import (
CannotCreateContactForReverseAlias, CannotCreateContactForReverseAlias,
@ -24,8 +21,8 @@ from app.errors import (
ErrContactAlreadyExists, ErrContactAlreadyExists,
) )
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, EmailLog, User from app.models import Alias, Contact, EmailLog
from app.utils import sanitize_email, CSRFValidationForm from app.utils import CSRFValidationForm
def email_validator(): def email_validator():
@ -51,7 +48,7 @@ def email_validator():
return _check return _check
def create_contact(user: User, alias: Alias, contact_address: str) -> Contact: def create_contact(alias: Alias, contact_address: str) -> Contact:
""" """
Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS. Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS.
Can throw exceptions: Can throw exceptions:
@ -61,37 +58,23 @@ def create_contact(user: User, alias: Alias, contact_address: str) -> Contact:
""" """
if not contact_address: if not contact_address:
raise ErrAddressInvalid("Empty address") raise ErrAddressInvalid("Empty address")
try: output = contact_utils.create_contact(email=contact_address, alias=alias)
contact_name, contact_email = parse_full_address(contact_address) if output.error == ContactCreateError.InvalidEmail:
except ValueError:
raise ErrAddressInvalid(contact_address) raise ErrAddressInvalid(contact_address)
elif output.error == ContactCreateError.NotAllowed:
contact_email = sanitize_email(contact_email)
if not is_valid_email(contact_email):
raise ErrAddressInvalid(contact_email)
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
raise ErrContactAlreadyExists(contact)
if not user.can_create_contacts():
raise ErrContactErrorUpgradeNeeded() raise ErrContactErrorUpgradeNeeded()
elif output.error is not None:
raise ErrAddressInvalid("Invalid address")
elif not output.created:
raise ErrContactAlreadyExists(output.contact)
contact = Contact.create( contact = output.contact
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_email,
name=contact_name,
reply_email=generate_reply_email(contact_email, alias),
)
LOG.d( LOG.d(
"create reverse-alias for %s %s, reverse alias:%s", "create reverse-alias for %s %s, reverse alias:%s",
contact_address, contact_address,
alias, alias,
contact.reply_email, contact.reply_email,
) )
Session.commit()
return contact return contact
@ -261,7 +244,7 @@ def alias_contact_manager(alias_id):
if new_contact_form.validate(): if new_contact_form.validate():
contact_address = new_contact_form.email.data.strip() contact_address = new_contact_form.email.data.strip()
try: try:
contact = create_contact(current_user, alias, contact_address) contact = create_contact(alias, contact_address)
except ( except (
ErrContactErrorUpgradeNeeded, ErrContactErrorUpgradeNeeded,
ErrAddressInvalid, ErrAddressInvalid,

View File

@ -5,11 +5,9 @@ from wtforms import StringField, validators
from app import parallel_limiter from app import parallel_limiter
from app.config import EMAIL_SERVERS_WITH_PRIORITY from app.config import EMAIL_SERVERS_WITH_PRIORITY
from app.custom_domain_utils import create_custom_domain
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.models import CustomDomain
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain
class NewCustomDomainForm(FlaskForm): class NewCustomDomainForm(FlaskForm):
@ -25,11 +23,8 @@ def custom_domain():
custom_domains = CustomDomain.filter_by( custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False user_id=current_user.id, is_sl_subdomain=False
).all() ).all()
mailboxes = current_user.mailboxes()
new_custom_domain_form = NewCustomDomainForm() new_custom_domain_form = NewCustomDomainForm()
errors = {}
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
if not current_user.is_premium(): if not current_user.is_premium():
@ -37,87 +32,25 @@ def custom_domain():
return redirect(url_for("dashboard.custom_domain")) return redirect(url_for("dashboard.custom_domain"))
if new_custom_domain_form.validate(): if new_custom_domain_form.validate():
new_domain = new_custom_domain_form.domain.data.lower().strip() res = create_custom_domain(
user=current_user, domain=new_custom_domain_form.domain.data
if new_domain.startswith("http://"): )
new_domain = new_domain[len("http://") :] if res.success:
flash(f"New domain {res.instance.domain} is created", "success")
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
if SLDomain.get_by(domain=new_domain):
flash("A custom domain cannot be a built-in domain.", "error")
elif CustomDomain.get_by(domain=new_domain):
flash(f"{new_domain} already used", "error")
elif get_email_domain_part(current_user.email) == new_domain:
flash(
"You cannot add a domain that you are currently using for your personal email. "
"Please change your personal email to your real email",
"error",
)
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{new_domain}")
).first():
flash(
f"{new_domain} already used in a SimpleLogin mailbox", "error"
)
else:
new_custom_domain = CustomDomain.create(
domain=new_domain, user_id=current_user.id
)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in current_user.custom_domains:
if (
new_domain.endswith("." + root_cd.domain)
and root_cd.ownership_verified
):
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
Session.commit()
mailbox_ids = request.form.getlist("mailbox_ids")
if mailbox_ids:
# check if mailbox is not tempered with
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if (
not mailbox
or mailbox.user_id != current_user.id
or not mailbox.verified
):
flash("Something went wrong, please retry", "warning")
return redirect(url_for("dashboard.custom_domain"))
mailboxes.append(mailbox)
for mailbox in mailboxes:
DomainMailbox.create(
domain_id=new_custom_domain.id, mailbox_id=mailbox.id
)
Session.commit()
flash(
f"New domain {new_custom_domain.domain} is created", "success"
)
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", "dashboard.domain_detail_dns",
custom_domain_id=new_custom_domain.id, custom_domain_id=res.instance.id,
) )
) )
else:
flash(res.message, res.message_category)
if res.redirect:
return redirect(url_for(res.redirect))
return render_template( return render_template(
"dashboard/custom_domain.html", "dashboard/custom_domain.html",
custom_domains=custom_domains, custom_domains=custom_domains,
new_custom_domain_form=new_custom_domain_form, new_custom_domain_form=new_custom_domain_form,
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
errors=errors,
mailboxes=mailboxes,
) )

View File

@ -1,31 +1,23 @@
import re import re
import arrow
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField from wtforms import StringField, validators, IntegerField
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.custom_domain_utils import delete_custom_domain, set_custom_domain_mailboxes
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
is_mx_equivalent,
)
from app.log import LOG
from app.models import ( from app.models import (
CustomDomain, CustomDomain,
Alias, Alias,
DomainDeletedAlias, DomainDeletedAlias,
Mailbox, Mailbox,
DomainMailbox,
AutoCreateRule, AutoCreateRule,
AutoCreateRuleMailbox, AutoCreateRuleMailbox,
Job,
) )
from app.regex_utils import regex_match from app.regex_utils import regex_match
from app.utils import random_string, CSRFValidationForm from app.utils import random_string, CSRFValidationForm
@ -49,8 +41,6 @@ def domain_detail_dns(custom_domain_id):
domain_validator = CustomDomainValidation(EMAIL_DOMAIN) domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm() csrf_form = CSRFValidationForm()
dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = [] mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []
@ -59,15 +49,14 @@ def domain_detail_dns(custom_domain_id):
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(request.url) return redirect(request.url)
if request.form.get("form-name") == "check-ownership": if request.form.get("form-name") == "check-ownership":
txt_records = get_txt_record(custom_domain.domain) ownership_validation_result = domain_validator.validate_domain_ownership(
custom_domain
if custom_domain.get_ownership_dns_txt_value() in txt_records: )
if ownership_validation_result.success:
flash( flash(
"Domain ownership is verified. Please proceed to the other records setup", "Domain ownership is verified. Please proceed to the other records setup",
"success", "success",
) )
custom_domain.ownership_verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", "dashboard.domain_detail_dns",
@ -78,36 +67,28 @@ def domain_detail_dns(custom_domain_id):
else: else:
flash("We can't find the needed TXT record", "error") flash("We can't find the needed TXT record", "error")
ownership_ok = False ownership_ok = False
ownership_errors = txt_records ownership_errors = ownership_validation_result.errors
elif request.form.get("form-name") == "check-mx": elif request.form.get("form-name") == "check-mx":
mx_domains = get_mx_domains(custom_domain.domain) mx_validation_result = domain_validator.validate_mx_records(custom_domain)
if mx_validation_result.success:
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
flash("The MX record is not correctly set", "warning")
mx_ok = False
# build mx_errors to show to user
mx_errors = [
f"{priority} {domain}" for (priority, domain) in mx_domains
]
else:
flash( flash(
"Your domain can start receiving emails. You can now use it to create alias", "Your domain can start receiving emails. You can now use it to create alias",
"success", "success",
) )
custom_domain.verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
) )
) )
else:
flash("The MX record is not correctly set", "warning")
mx_ok = False
mx_errors = mx_validation_result.errors
elif request.form.get("form-name") == "check-spf": elif request.form.get("form-name") == "check-spf":
spf_domains = get_spf_domain(custom_domain.domain) spf_validation_result = domain_validator.validate_spf_records(custom_domain)
if EMAIL_DOMAIN in spf_domains: if spf_validation_result.success:
custom_domain.spf_verified = True
Session.commit()
flash("SPF is setup correctly", "success") flash("SPF is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -115,14 +96,12 @@ def domain_detail_dns(custom_domain_id):
) )
) )
else: else:
custom_domain.spf_verified = False
Session.commit()
flash( flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.", f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning", "warning",
) )
spf_ok = False spf_ok = False
spf_errors = get_txt_record(custom_domain.domain) spf_errors = spf_validation_result.errors
elif request.form.get("form-name") == "check-dkim": elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain) dkim_errors = domain_validator.validate_dkim_records(custom_domain)
@ -138,10 +117,10 @@ def domain_detail_dns(custom_domain_id):
flash("DKIM: the CNAME record is not correctly set", "warning") flash("DKIM: the CNAME record is not correctly set", "warning")
elif request.form.get("form-name") == "check-dmarc": elif request.form.get("form-name") == "check-dmarc":
txt_records = get_txt_record("_dmarc." + custom_domain.domain) dmarc_validation_result = domain_validator.validate_dmarc_records(
if dmarc_record in txt_records: custom_domain
custom_domain.dmarc_verified = True )
Session.commit() if dmarc_validation_result.success:
flash("DMARC is setup correctly", "success") flash("DMARC is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -149,19 +128,21 @@ def domain_detail_dns(custom_domain_id):
) )
) )
else: else:
custom_domain.dmarc_verified = False
Session.commit()
flash( flash(
"DMARC: The TXT record is not correctly set", "DMARC: The TXT record is not correctly set",
"warning", "warning",
) )
dmarc_ok = False dmarc_ok = False
dmarc_errors = txt_records dmarc_errors = dmarc_validation_result.errors
return render_template( return render_template(
"dashboard/domain_detail/dns.html", "dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(), ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
dkim_records=domain_validator.get_dkim_records(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(), **locals(),
) )
@ -238,40 +219,16 @@ def domain_detail(custom_domain_id):
) )
elif request.form.get("form-name") == "update": elif request.form.get("form-name") == "update":
mailbox_ids = request.form.getlist("mailbox_ids") mailbox_ids = request.form.getlist("mailbox_ids")
# check if mailbox is not tempered with result = set_custom_domain_mailboxes(
mailboxes = [] user_id=current_user.id,
for mailbox_id in mailbox_ids: custom_domain=custom_domain,
mailbox = Mailbox.get(mailbox_id) mailbox_ids=mailbox_ids,
if ( )
not mailbox
or mailbox.user_id != current_user.id
or not mailbox.verified
):
flash("Something went wrong, please retry", "warning")
return redirect(
url_for(
"dashboard.domain_detail", custom_domain_id=custom_domain.id
)
)
mailboxes.append(mailbox)
if not mailboxes: if result.success:
flash("You must select at least 1 mailbox", "warning") flash(f"{custom_domain.domain} mailboxes has been updated", "success")
return redirect( else:
url_for( flash(result.reason.value, "warning")
"dashboard.domain_detail", custom_domain_id=custom_domain.id
)
)
# first remove all existing domain-mailboxes links
DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
Session.flush()
for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
Session.commit()
flash(f"{custom_domain.domain} mailboxes has been updated", "success")
return redirect( return redirect(
url_for("dashboard.domain_detail", custom_domain_id=custom_domain.id) url_for("dashboard.domain_detail", custom_domain_id=custom_domain.id)
@ -279,16 +236,8 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "delete": elif request.form.get("form-name") == "delete":
name = custom_domain.domain name = custom_domain.domain
LOG.d("Schedule deleting %s", custom_domain)
# Schedule delete domain job delete_custom_domain(custom_domain)
LOG.w("schedule delete domain job for %s", custom_domain)
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": custom_domain.id},
run_at=arrow.now(),
commit=True,
)
flash( flash(
f"{name} scheduled for deletion." f"{name} scheduled for deletion."

View File

@ -123,7 +123,12 @@ def mailbox_verify():
if not code: if not code:
# Old way # Old way
return verify_with_signed_secret(mailbox_id) return verify_with_signed_secret(mailbox_id)
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code) try:
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
except mailbox_utils.MailboxError as e:
LOG.i(f"Cannot verify mailbox {mailbox_id} because of {e}")
flash(f"Cannot verify mailbox: {e.msg}", "error")
return redirect(url_for("dashboard.mailbox_route"))
LOG.d("Mailbox %s is verified", mailbox) LOG.d("Mailbox %s is verified", mailbox)
return render_template("dashboard/mailbox_validation.html", mailbox=mailbox) return render_template("dashboard/mailbox_validation.html", mailbox=mailbox)

View File

@ -1,100 +1,13 @@
from app import config from abc import ABC, abstractmethod
from typing import Optional, List, Tuple from typing import List, Tuple, Optional
import dns.resolver import dns.resolver
from app.config import NAMESERVERS
def _get_dns_resolver():
my_resolver = dns.resolver.Resolver()
my_resolver.nameservers = config.NAMESERVERS
return my_resolver
def get_ns(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
except Exception:
return []
return [a.to_text() for a in answers]
def get_cname_record(hostname) -> Optional[str]:
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
try:
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
except Exception:
return None
for a in answers:
ret = a.to_text()
return ret[:-1]
return None
def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
except Exception:
return []
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda prio_domain: prio_domain[0])
_include_spf = "include:" _include_spf = "include:"
def get_spf_domain(hostname) -> [str]:
"""return all domains listed in *include:*"""
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
return ret
def get_txt_record(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
ret.append(record)
return ret
def is_mx_equivalent( def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
) -> bool: ) -> bool:
@ -105,16 +18,127 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number. The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
""" """
mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0]) mx_domains = sorted(mx_domains, key=lambda x: x[0])
ref_mx_domains = sorted( ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
)
if len(mx_domains) < len(ref_mx_domains): if len(mx_domains) < len(ref_mx_domains):
return False return False
for i in range(0, len(ref_mx_domains)): for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]: if mx_domains[i][1] != ref_mx_domains[i][1]:
return False return False
return True return True
class DNSClient(ABC):
@abstractmethod
def get_cname_record(self, hostname: str) -> Optional[str]:
pass
@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
pass
def get_spf_domain(self, hostname: str) -> List[str]:
"""
return all domains listed in *include:*
"""
try:
records = self.get_txt_record(hostname)
ret = []
for record in records:
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []
@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
pass
class NetworkDNSClient(DNSClient):
def __init__(self, nameservers: List[str]):
self._resolver = dns.resolver.Resolver()
self._resolver.nameservers = nameservers
def get_cname_record(self, hostname: str) -> Optional[str]:
"""
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
"""
try:
answers = self._resolver.resolve(hostname, "CNAME", search=True)
for a in answers:
ret = a.to_text()
return ret[:-1]
except Exception:
return None
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = self._resolver.resolve(hostname, "MX", search=True)
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
except Exception:
return []
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
ret.append(record.decode())
return ret
except Exception:
return []
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> [(int, str)]:
return get_network_dns_client().get_mx_domains(hostname)

View File

@ -30,14 +30,30 @@ class PostgresDispatcher(Dispatcher):
return PostgresDispatcher() return PostgresDispatcher()
class GlobalDispatcher:
__dispatcher: Optional[Dispatcher] = None
@staticmethod
def get_dispatcher() -> Dispatcher:
if not GlobalDispatcher.__dispatcher:
GlobalDispatcher.__dispatcher = PostgresDispatcher.get()
return GlobalDispatcher.__dispatcher
@staticmethod
def set_dispatcher(dispatcher: Optional[Dispatcher]):
GlobalDispatcher.__dispatcher = dispatcher
class EventDispatcher: class EventDispatcher:
@staticmethod @staticmethod
def send_event( def send_event(
user: User, user: User,
content: event_pb2.EventContent, content: event_pb2.EventContent,
dispatcher: Dispatcher = PostgresDispatcher.get(), dispatcher: Optional[Dispatcher] = None,
skip_if_webhook_missing: bool = True, skip_if_webhook_missing: bool = True,
): ):
if dispatcher is None:
dispatcher = GlobalDispatcher.get_dispatcher()
if config.EVENT_WEBHOOK_DISABLE: if config.EVENT_WEBHOOK_DISABLE:
LOG.i("Not sending events because webhook is disabled") LOG.i("Not sending events because webhook is disabled")
return return
@ -48,15 +64,9 @@ class EventDispatcher:
) )
return return
if config.EVENT_WEBHOOK_ENABLED_USER_IDS is not None:
if user.id not in config.EVENT_WEBHOOK_ENABLED_USER_IDS:
return
partner_user = EventDispatcher.__partner_user(user.id) partner_user = EventDispatcher.__partner_user(user.id)
if not partner_user: if not partner_user:
LOG.i( LOG.i(f"Not sending events because there's no partner user for user {user}")
f"Not sending events because there's no partner user for user {user}"
)
return return
event = event_pb2.Event( event = event_pb2.Event(
@ -68,7 +78,9 @@ class EventDispatcher:
serialized = event.SerializeToString() serialized = event.SerializeToString()
dispatcher.send(serialized) dispatcher.send(serialized)
newrelic.agent.record_custom_metric("Custom/events_stored", 1)
event_type = content.WhichOneof("content")
newrelic.agent.record_custom_event("EventStoredToDb", {"type": event_type})
LOG.i("Sent event to the dispatcher") LOG.i("Sent event to the dispatcher")
@staticmethod @staticmethod

View File

@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"Z\n\x0c\x41liasCreated\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x12\n\nalias_note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\"L\n\x12\x41liasStatusChanged\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\"5\n\x0c\x41liasDeleted\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"\\\n\x0c\x41liasCreated\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0c\n\x04note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\x12\x12\n\ncreated_at\x18\x05 \x01(\r\"T\n\x12\x41liasStatusChanged\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\x12\x12\n\ncreated_at\x18\x04 \x01(\r\")\n\x0c\x41liasDeleted\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
_globals = globals() _globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -36,15 +36,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_USERDELETED']._serialized_start=77 _globals['_USERDELETED']._serialized_start=77
_globals['_USERDELETED']._serialized_end=90 _globals['_USERDELETED']._serialized_end=90
_globals['_ALIASCREATED']._serialized_start=92 _globals['_ALIASCREATED']._serialized_start=92
_globals['_ALIASCREATED']._serialized_end=182 _globals['_ALIASCREATED']._serialized_end=184
_globals['_ALIASSTATUSCHANGED']._serialized_start=184 _globals['_ALIASSTATUSCHANGED']._serialized_start=186
_globals['_ALIASSTATUSCHANGED']._serialized_end=260 _globals['_ALIASSTATUSCHANGED']._serialized_end=270
_globals['_ALIASDELETED']._serialized_start=262 _globals['_ALIASDELETED']._serialized_start=272
_globals['_ALIASDELETED']._serialized_end=315 _globals['_ALIASDELETED']._serialized_end=313
_globals['_ALIASCREATEDLIST']._serialized_start=317 _globals['_ALIASCREATEDLIST']._serialized_start=315
_globals['_ALIASCREATEDLIST']._serialized_end=385 _globals['_ALIASCREATEDLIST']._serialized_end=383
_globals['_EVENTCONTENT']._serialized_start=388 _globals['_EVENTCONTENT']._serialized_start=386
_globals['_EVENTCONTENT']._serialized_end=791 _globals['_EVENTCONTENT']._serialized_end=789
_globals['_EVENT']._serialized_start=793 _globals['_EVENT']._serialized_start=791
_globals['_EVENT']._serialized_end=914 _globals['_EVENT']._serialized_end=912
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

View File

@ -16,34 +16,38 @@ class UserDeleted(_message.Message):
def __init__(self) -> None: ... def __init__(self) -> None: ...
class AliasCreated(_message.Message): class AliasCreated(_message.Message):
__slots__ = ("alias_id", "alias_email", "alias_note", "enabled") __slots__ = ("id", "email", "note", "enabled", "created_at")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int] ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int] EMAIL_FIELD_NUMBER: _ClassVar[int]
ALIAS_NOTE_FIELD_NUMBER: _ClassVar[int] NOTE_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int] ENABLED_FIELD_NUMBER: _ClassVar[int]
alias_id: int CREATED_AT_FIELD_NUMBER: _ClassVar[int]
alias_email: str id: int
alias_note: str email: str
note: str
enabled: bool enabled: bool
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., alias_note: _Optional[str] = ..., enabled: bool = ...) -> None: ... created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., note: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasStatusChanged(_message.Message): class AliasStatusChanged(_message.Message):
__slots__ = ("alias_id", "alias_email", "enabled") __slots__ = ("id", "email", "enabled", "created_at")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int] ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int] EMAIL_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int] ENABLED_FIELD_NUMBER: _ClassVar[int]
alias_id: int CREATED_AT_FIELD_NUMBER: _ClassVar[int]
alias_email: str id: int
email: str
enabled: bool enabled: bool
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., enabled: bool = ...) -> None: ... created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasDeleted(_message.Message): class AliasDeleted(_message.Message):
__slots__ = ("alias_id", "alias_email") __slots__ = ("id", "email")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int] ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int] EMAIL_FIELD_NUMBER: _ClassVar[int]
alias_id: int id: int
alias_email: str email: str
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ...) -> None: ... def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ...) -> None: ...
class AliasCreatedList(_message.Message): class AliasCreatedList(_message.Message):
__slots__ = ("events",) __slots__ = ("events",)

View File

@ -22,10 +22,11 @@ def send_alias_creation_events_for_user(
): ):
event_list.append( event_list.append(
AliasCreated( AliasCreated(
alias_id=alias.id, id=alias.id,
alias_email=alias.email, email=alias.email,
alias_note=alias.note, note=alias.note,
enabled=alias.enabled, enabled=alias.enabled,
created_at=int(alias.created_at.timestamp),
) )
) )
if len(event_list) >= chunk_size: if len(event_list) >= chunk_size:

View File

@ -213,7 +213,10 @@ def generate_activation_code(
) -> MailboxActivation: ) -> MailboxActivation:
clear_activation_codes_for_mailbox(mailbox) clear_activation_codes_for_mailbox(mailbox)
if use_digit_code: if use_digit_code:
code = "{:06d}".format(random.randint(1, 999999)) if config.MAILBOX_VERIFICATION_OVERRIDE_CODE:
code = config.MAILBOX_VERIFICATION_OVERRIDE_CODE
else:
code = "{:06d}".format(random.randint(1, 999999))
else: else:
code = secrets.token_urlsafe(16) code = secrets.token_urlsafe(16)
return MailboxActivation.create( return MailboxActivation.create(

View File

@ -336,7 +336,7 @@ class Fido(Base, ModelMixin):
class User(Base, ModelMixin, UserMixin, PasswordOracle): class User(Base, ModelMixin, UserMixin, PasswordOracle):
__tablename__ = "users" __tablename__ = "users"
FLAG_FREE_DISABLE_CREATE_ALIAS = 1 << 0 FLAG_DISABLE_CREATE_CONTACTS = 1 << 0
FLAG_CREATED_FROM_PARTNER = 1 << 1 FLAG_CREATED_FROM_PARTNER = 1 << 1
FLAG_FREE_OLD_ALIAS_LIMIT = 1 << 2 FLAG_FREE_OLD_ALIAS_LIMIT = 1 << 2
FLAG_CREATED_ALIAS_FROM_PARTNER = 1 << 3 FLAG_CREATED_ALIAS_FROM_PARTNER = 1 << 3
@ -543,7 +543,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
# bitwise flags. Allow for future expansion # bitwise flags. Allow for future expansion
flags = sa.Column( flags = sa.Column(
sa.BigInteger, sa.BigInteger,
default=FLAG_FREE_DISABLE_CREATE_ALIAS, default=FLAG_DISABLE_CREATE_CONTACTS,
server_default="0", server_default="0",
nullable=False, nullable=False,
) )
@ -973,7 +973,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def has_custom_domain(self): def has_custom_domain(self):
return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0 return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0
def custom_domains(self): def custom_domains(self) -> List["CustomDomain"]:
return CustomDomain.filter_by(user_id=self.id, verified=True).all() return CustomDomain.filter_by(user_id=self.id, verified=True).all()
def available_domains_for_random_alias( def available_domains_for_random_alias(
@ -1168,7 +1168,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def can_create_contacts(self) -> bool: def can_create_contacts(self) -> bool:
if self.is_premium(): if self.is_premium():
return True return True
if self.flags & User.FLAG_FREE_DISABLE_CREATE_ALIAS == 0: if self.flags & User.FLAG_DISABLE_CREATE_CONTACTS == 0:
return True return True
return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS
@ -1660,18 +1660,6 @@ class Alias(Base, ModelMixin):
Session.add(new_alias) Session.add(new_alias)
DailyMetric.get_or_create_today_metric().nb_alias += 1 DailyMetric.get_or_create_today_metric().nb_alias += 1
# Internal import to avoid global import cycles
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
alias_id=new_alias.id,
alias_email=new_alias.email,
alias_note=new_alias.note,
enabled=True,
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
if ( if (
new_alias.flags & cls.FLAG_PARTNER_CREATED > 0 new_alias.flags & cls.FLAG_PARTNER_CREATED > 0
and new_alias.user.flags & User.FLAG_CREATED_ALIAS_FROM_PARTNER == 0 and new_alias.user.flags & User.FLAG_CREATED_ALIAS_FROM_PARTNER == 0
@ -1684,6 +1672,19 @@ class Alias(Base, ModelMixin):
if flush: if flush:
Session.flush() Session.flush()
# Internal import to avoid global import cycles
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
id=new_alias.id,
email=new_alias.email,
note=new_alias.note,
enabled=True,
created_at=int(new_alias.created_at.timestamp),
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
return new_alias return new_alias
@classmethod @classmethod
@ -1862,6 +1863,8 @@ class Contact(Base, ModelMixin):
MAX_NAME_LENGTH = 512 MAX_NAME_LENGTH = 512
FLAG_PARTNER_CREATED = 1 << 0
__tablename__ = "contact" __tablename__ = "contact"
__table_args__ = ( __table_args__ = (
@ -1920,6 +1923,9 @@ class Contact(Base, ModelMixin):
# whether contact is created automatically during the forward phase # whether contact is created automatically during the forward phase
automatic_created = sa.Column(sa.Boolean, nullable=True, default=False) automatic_created = sa.Column(sa.Boolean, nullable=True, default=False)
# contact flags
flags = sa.Column(sa.Integer, nullable=False, default=0, server_default="0")
@property @property
def email(self): def email(self):
return self.website_email return self.website_email
@ -2418,6 +2424,18 @@ class CustomDomain(Base, ModelMixin):
sa.Boolean, nullable=False, default=False, server_default="0" sa.Boolean, nullable=False, default=False, server_default="0"
) )
partner_id = sa.Column(
sa.Integer,
sa.ForeignKey("partner.id"),
nullable=True,
default=None,
server_default=None,
)
pending_deletion = sa.Column(
sa.Boolean, nullable=False, default=False, server_default="0"
)
__table_args__ = ( __table_args__ = (
Index( Index(
"ix_unique_domain", # Index name "ix_unique_domain", # Index name
@ -2425,6 +2443,8 @@ class CustomDomain(Base, ModelMixin):
unique=True, unique=True,
postgresql_where=Column("ownership_verified"), postgresql_where=Column("ownership_verified"),
), # The condition ), # The condition
Index("ix_custom_domain_user_id", "user_id"),
Index("ix_custom_domain_pending_deletion", "pending_deletion"),
) )
user = orm.relationship(User, foreign_keys=[user_id], backref="custom_domains") user = orm.relationship(User, foreign_keys=[user_id], backref="custom_domains")
@ -2442,9 +2462,6 @@ class CustomDomain(Base, ModelMixin):
def get_trash_url(self): def get_trash_url(self):
return config.URL + f"/dashboard/domains/{self.id}/trash" return config.URL + f"/dashboard/domains/{self.id}/trash"
def get_ownership_dns_txt_value(self):
return f"sl-verification={self.ownership_txt_token}"
@classmethod @classmethod
def create(cls, **kwargs): def create(cls, **kwargs):
domain = kwargs.get("domain") domain = kwargs.get("domain")

View File

@ -2,9 +2,11 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from flask import url_for from flask import url_for
from typing import Optional from typing import Optional
import arrow
from app import config
from app.errors import LinkException from app.errors import LinkException
from app.models import User, Partner from app.models import User, Partner, Job
from app.proton.proton_client import ProtonClient, ProtonUser from app.proton.proton_client import ProtonClient, ProtonUser
from app.account_linking import ( from app.account_linking import (
process_login_case, process_login_case,
@ -41,12 +43,21 @@ class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient): def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client self.proton_client = proton_client
def _initial_alias_sync(self, user: User):
Job.create(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
payload={"user_id": user.id},
run_at=arrow.now(),
commit=True,
)
def handle_login(self, partner: Partner) -> ProtonCallbackResult: def handle_login(self, partner: Partner) -> ProtonCallbackResult:
try: try:
user = self.__get_partner_user() user = self.__get_partner_user()
if user is None: if user is None:
return generate_account_not_allowed_to_log_in() return generate_account_not_allowed_to_log_in()
res = process_login_case(user, partner) res = process_login_case(user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult( return ProtonCallbackResult(
redirect_to_login=False, redirect_to_login=False,
flash_message=None, flash_message=None,
@ -75,6 +86,7 @@ class ProtonCallbackHandler:
if user is None: if user is None:
return generate_account_not_allowed_to_log_in() return generate_account_not_allowed_to_log_in()
res = process_link_case(user, current_user, partner) res = process_link_case(user, current_user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult( return ProtonCallbackResult(
redirect_to_login=False, redirect_to_login=False,
flash_message="Account successfully linked", flash_message="Account successfully linked",

View File

@ -52,7 +52,7 @@ from flanker.addresslib import address
from flanker.addresslib.address import EmailAddress from flanker.addresslib.address import EmailAddress
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app import pgp_utils, s3, config from app import pgp_utils, s3, config, contact_utils
from app.alias_utils import try_auto_create, change_alias_status from app.alias_utils import try_auto_create, change_alias_status
from app.config import ( from app.config import (
EMAIL_DOMAIN, EMAIL_DOMAIN,
@ -195,81 +195,16 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
mail_from, mail_from,
) )
contact_email = mail_from contact_email = mail_from
contact_result = contact_utils.create_contact(
if not is_valid_email(contact_email): email=contact_email,
LOG.w( alias=alias,
"invalid contact email %s. Parse from %s %s", name=contact_name,
contact_email, mail_from=mail_from,
from_header, allow_empty_email=True,
mail_from, automatic_created=True,
) from_partner=False,
# either reuse a contact with empty email or create a new contact with empty email )
contact_email = "" return contact_result.contact
contact_email = sanitize_email(contact_email, not_lower=True)
if contact_name and "\x00" in contact_name:
LOG.w("issue with contact name %s", contact_name)
contact_name = ""
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
if contact.name != contact_name:
LOG.d(
"Update contact %s name %s to %s",
contact,
contact.name,
contact_name,
)
contact.name = contact_name
Session.commit()
# contact created in the past does not have mail_from and from_header field
if not contact.mail_from and mail_from:
LOG.d(
"Set contact mail_from %s: %s to %s",
contact,
contact.mail_from,
mail_from,
)
contact.mail_from = mail_from
Session.commit()
else:
alias_id = alias.id
try:
contact_email_for_reply = (
contact_email if is_valid_email(contact_email) else ""
)
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias_id,
website_email=contact_email,
name=contact_name,
mail_from=mail_from,
reply_email=generate_reply_email(contact_email_for_reply, alias),
automatic_created=True,
)
if not contact_email:
LOG.d("Create a contact with invalid email for %s", alias)
contact.invalid_email = True
LOG.d(
"create contact %s for %s, reverse alias:%s",
contact_email,
alias,
contact.reply_email,
)
Session.commit()
except IntegrityError:
# If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh
Session.close()
LOG.info(
f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB"
)
contact = Contact.get_by(alias_id=alias_id, website_email=contact_email)
return contact
def get_or_create_reply_to_contact( def get_or_create_reply_to_contact(
@ -294,33 +229,7 @@ def get_or_create_reply_to_contact(
) )
return None return None
contact = Contact.get_by(alias_id=alias.id, website_email=contact_address) return contact_utils.create_contact(contact_address, alias, contact_name).contact
if contact:
return contact
else:
LOG.d(
"create contact %s for alias %s via reply-to header %s",
contact_address,
alias,
reply_to_header,
)
try:
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_address,
name=contact_name,
reply_email=generate_reply_email(contact_address, alias),
automatic_created=True,
)
Session.commit()
except IntegrityError:
LOG.w("Contact %s %s already exist", alias, contact_address)
Session.rollback()
contact = Contact.get_by(alias_id=alias.id, website_email=contact_address)
return contact
def replace_header_when_forward(msg: Message, alias: Alias, header: str): def replace_header_when_forward(msg: Message, alias: Alias, header: str):
@ -818,7 +727,7 @@ def forward_email_to_mailbox(
email_log = EmailLog.create( email_log = EmailLog.create(
contact_id=contact.id, contact_id=contact.id,
user_id=user.id, user_id=contact.user_id,
mailbox_id=mailbox.id, mailbox_id=mailbox.id,
alias_id=contact.alias_id, alias_id=contact.alias_id,
message_id=str(msg[headers.MESSAGE_ID]), message_id=str(msg[headers.MESSAGE_ID]),

View File

@ -9,7 +9,7 @@ from events.runner import Runner
from events.event_source import DeadLetterEventSource, PostgresEventSource from events.event_source import DeadLetterEventSource, PostgresEventSource
from events.event_sink import ConsoleEventSink, HttpEventSink from events.event_sink import ConsoleEventSink, HttpEventSink
_DEFAULT_MAX_RETRIES = 100 _DEFAULT_MAX_RETRIES = 10
class Mode(Enum): class Mode(Enum):

View File

@ -27,7 +27,9 @@ class HttpEventSink(EventSink):
headers={"Content-Type": "application/x-protobuf"}, headers={"Content-Type": "application/x-protobuf"},
verify=not EVENT_WEBHOOK_SKIP_VERIFY_SSL, verify=not EVENT_WEBHOOK_SKIP_VERIFY_SSL,
) )
newrelic.agent.record_custom_event("event_sent", {"http_code": res.status_code}) newrelic.agent.record_custom_event(
"EventSentToPartner", {"http_code": res.status_code}
)
if res.status_code != 200: if res.status_code != 200:
LOG.warning( LOG.warning(
f"Failed to send event to webhook: {res.status_code} {res.text}" f"Failed to send event to webhook: {res.status_code} {res.text}"

View File

@ -72,7 +72,9 @@ class PostgresEventSource(EventSource):
Session.close() # Ensure we get a new connection and we don't leave a dangling tx Session.close() # Ensure we get a new connection and we don't leave a dangling tx
def __connect(self): def __connect(self):
self.__connection = psycopg2.connect(self.__connection_string) self.__connection = psycopg2.connect(
self.__connection_string, application_name="sl-event-listen"
)
from app.db import Session from app.db import Session

View File

@ -3,7 +3,7 @@ Run scheduled jobs.
Not meant for running job at precise time (+- 1h) Not meant for running job at precise time (+- 1h)
""" """
import time import time
from typing import List from typing import List, Optional
import arrow import arrow
from sqlalchemy.sql.expression import or_, and_ from sqlalchemy.sql.expression import or_, and_
@ -240,28 +240,30 @@ def process_job(job: Job):
elif job.name == config.JOB_DELETE_DOMAIN: elif job.name == config.JOB_DELETE_DOMAIN:
custom_domain_id = job.payload.get("custom_domain_id") custom_domain_id = job.payload.get("custom_domain_id")
custom_domain = CustomDomain.get(custom_domain_id) custom_domain: Optional[CustomDomain] = CustomDomain.get(custom_domain_id)
if not custom_domain: if not custom_domain:
return return
domain_name = custom_domain.domain domain_name = custom_domain.domain
user = custom_domain.user user = custom_domain.user
custom_domain_partner_id = custom_domain.partner_id
CustomDomain.delete(custom_domain.id) CustomDomain.delete(custom_domain.id)
Session.commit() Session.commit()
LOG.d("Domain %s deleted", domain_name) LOG.d("Domain %s deleted", domain_name)
send_email( if custom_domain_partner_id is None:
user.email, send_email(
f"Your domain {domain_name} has been deleted", user.email,
f"""Domain {domain_name} along with its aliases are deleted successfully. f"Your domain {domain_name} has been deleted",
f"""Domain {domain_name} along with its aliases are deleted successfully.
Regards, Regards,
SimpleLogin team. SimpleLogin team.
""", """,
retries=3, retries=3,
) )
elif job.name == config.JOB_SEND_USER_REPORT: elif job.name == config.JOB_SEND_USER_REPORT:
export_job = ExportUserDataJob.create_from_job(job) export_job = ExportUserDataJob.create_from_job(job)
if export_job: if export_job:

View File

@ -0,0 +1,30 @@
"""Custom Domain partner id
Revision ID: 2441b7ff5da9
Revises: 1c14339aae90
Create Date: 2024-09-13 15:43:02.425964
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2441b7ff5da9'
down_revision = '1c14339aae90'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('custom_domain', sa.Column('partner_id', sa.Integer(), nullable=True, default=None, server_default=None))
op.create_foreign_key(None, 'custom_domain', 'partner', ['partner_id'], ['id'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'custom_domain', type_='foreignkey')
op.drop_column('custom_domain', 'partner_id')
# ### end Alembic commands ###

View File

@ -0,0 +1,31 @@
"""contact.flags and custom_domain.pending_deletion
Revision ID: 88dd7a0abf54
Revises: 2441b7ff5da9
Create Date: 2024-09-19 15:41:20.910374
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '88dd7a0abf54'
down_revision = '2441b7ff5da9'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('contact', sa.Column('flags', sa.Integer(), server_default='0', nullable=False))
op.add_column('custom_domain', sa.Column('pending_deletion', sa.Boolean(), server_default='0', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('custom_domain', 'pending_deletion')
op.drop_column('contact', 'flags')
# ### end Alembic commands ###

View File

@ -0,0 +1,27 @@
"""custom domain indices
Revision ID: 62afa3a10010
Revises: 88dd7a0abf54
Create Date: 2024-09-30 11:40:04.127791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '62afa3a10010'
down_revision = '88dd7a0abf54'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_custom_domain_pending_deletion', 'custom_domain', ['pending_deletion'], unique=False, postgresql_concurrently=True)
op.create_index('ix_custom_domain_user_id', 'custom_domain', ['user_id'], unique=False, postgresql_concurrently=True)
def downgrade():
with op.get_context().autocommit_block():
op.drop_index('ix_custom_domain_user_id', table_name='custom_domain', postgresql_concurrently=True)
op.drop_index('ix_custom_domain_pending_deletion', table_name='custom_domain', postgresql_concurrently=True)

View File

@ -94,6 +94,20 @@ def log_nb_db_connection():
newrelic.agent.record_custom_metric("Custom/nb_db_connections", nb_connection) newrelic.agent.record_custom_metric("Custom/nb_db_connections", nb_connection)
@newrelic.agent.background_task()
def log_nb_db_connection_by_app_name():
# get the number of connections to the DB
rows = Session.execute(
"SELECT application_name, count(datid) FROM pg_stat_activity group by application_name"
)
for row in rows:
if row[0].find("sl-") == 0:
LOG.d("number of db connections for app %s = %s", row[0], row[1])
newrelic.agent.record_custom_metric(
f"Custom/nb_db_app_connection/{row[0]}", row[1]
)
@newrelic.agent.background_task() @newrelic.agent.background_task()
def log_pending_to_process_events(): def log_pending_to_process_events():
r = Session.execute("select count(*) from sync_event WHERE taken_time IS NULL;") r = Session.execute("select count(*) from sync_event WHERE taken_time IS NULL;")
@ -125,6 +139,21 @@ def log_events_pending_dead_letter():
) )
@newrelic.agent.background_task()
def log_failed_events():
r = Session.execute(
"""
SELECT COUNT(*)
FROM sync_event
WHERE retry_count >= 10;
""",
)
failed_events = list(r)[0][0]
LOG.d("number of failed events %s", failed_events)
newrelic.agent.record_custom_metric("Custom/sync_events_failed", failed_events)
if __name__ == "__main__": if __name__ == "__main__":
exporter = MetricExporter(get_newrelic_license()) exporter = MetricExporter(get_newrelic_license())
while True: while True:
@ -132,6 +161,8 @@ if __name__ == "__main__":
log_nb_db_connection() log_nb_db_connection()
log_pending_to_process_events() log_pending_to_process_events()
log_events_pending_dead_letter() log_events_pending_dead_letter()
log_failed_events()
log_nb_db_connection_by_app_name()
Session.close() Session.close()
exporter.run() exporter.run()

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.models import Alias
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_alias_id", default=0, type=int, help="Initial alias_id"
)
parser.add_argument("-e", "--end_alias_id", default=0, type=int, help="Last alias_id")
args = parser.parse_args()
alias_id_start = args.start_alias_id
max_alias_id = args.end_alias_id
if max_alias_id == 0:
max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Checking alias {alias_id_start} to {max_alias_id}")
step = 1000
noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')"
alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}"
updated = 0
start_time = time.time()
for batch_start in range(alias_id_start, max_alias_id, step):
rows_done = Session.execute(
alias_query,
{
"start": batch_start,
"end": batch_start + step,
"flag": Alias.FLAG_PARTNER_CREATED,
},
)
updated += rows_done.rowcount
Session.commit()
elapsed = time.time() - start_time
time_per_alias = elapsed / (updated + 1)
last_batch_id = batch_start + step
remaining = max_alias_id - last_batch_id
time_remaining = (max_alias_id - last_batch_id) * time_per_alias
hours_remaining = time_remaining / 3600.0
print(
f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f}hrs remaining"
)
print("")

View File

@ -10,21 +10,23 @@ message UserDeleted {
} }
message AliasCreated { message AliasCreated {
uint32 alias_id = 1; uint32 id = 1;
string alias_email = 2; string email = 2;
string alias_note = 3; string note = 3;
bool enabled = 4; bool enabled = 4;
uint32 created_at = 5;
} }
message AliasStatusChanged { message AliasStatusChanged {
uint32 alias_id = 1; uint32 id = 1;
string alias_email = 2; string email = 2;
bool enabled = 3; bool enabled = 3;
uint32 created_at = 4;
} }
message AliasDeleted { message AliasDeleted {
uint32 alias_id = 1; uint32 id = 1;
string alias_email = 2; string email = 2;
} }
message AliasCreatedList { message AliasCreatedList {

View File

@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres
sleep 3 sleep 3
# upgrade the DB to the latest stage and # upgrade the DB to the latest stage and
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head
# generate the migration script. # generate the migration script.
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@ env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@
# remove the db # remove the db
docker rm -f ${container_name} docker rm -f ${container_name}

View File

@ -19,7 +19,7 @@
<tbody> <tbody>
<tr> <tr>
<td>{{ user.id }}</td> <td>{{ user.id }}</td>
<td>{{ user.email }}</td> <td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %} {% if user.disabled %}
<td class="text-danger">Disabled</td> <td class="text-danger">Disabled</td>
@ -32,7 +32,7 @@
<td>{{ user.updated_at }}</td> <td>{{ user.updated_at }}</td>
{% if pu %} {% if pu %}
<td>{{ pu.partner_email }}</td> <td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %} {% else %}
<td>No</td> <td>No</td>
{% endif %} {% endif %}
@ -43,7 +43,7 @@
{% macro list_mailboxes(mbox_count, mboxes) %} {% macro list_mailboxes(mbox_count, mboxes) %}
<h4> <h4>
{{ mbox_count }} Mailboxes found. {{ mbox_count }} Mailboxes found.
{% if mbox_count>10 %}Showing only the first 10.{% endif %} {% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4> </h4>
<table class="table"> <table class="table">
<thead> <thead>
@ -59,7 +59,7 @@
<tr> <tr>
<td>{{ mailbox.id }}</td> <td>{{ mailbox.id }}</td>
<td>{{ mailbox.email }}</td> <td><a href="?email={{mailbox.email}}">{{mailbox.email}}</a></td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td> <td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td> <td>
{{ mailbox.created_at }} {{ mailbox.created_at }}
@ -72,7 +72,7 @@
{% macro list_alias(alias_count, aliases) %} {% macro list_alias(alias_count, aliases) %}
<h4> <h4>
{{ alias_count }} Aliases found. {{ alias_count }} Aliases found.
{% if alias_count>10 %}Showing only the first 10.{% endif %} {% if alias_count>10 %}Showing only the last 10.{% endif %}
</h4> </h4>
<table class="table"> <table class="table">
<thead> <thead>
@ -95,7 +95,7 @@
{% for alias in aliases %} {% for alias in aliases %}
<tr> <tr>
<td>{{ alias.id }}</td> <td>{{ alias.id }}</td>
<td>{{ alias.email }}</td> <td><a href="?email={{alias.email}}">{{alias.email}}</a></td>
<td>{{ "Yes" if alias.verified else "No" }}</td> <td>{{ "Yes" if alias.verified else "No" }}</td>
<td>{{ alias.created_at }}</td> <td>{{ alias.created_at }}</td>
</tr> </tr>
@ -156,7 +156,7 @@
{% block body %} {% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="post"> <form method="get">
<div class="form-group"> <div class="form-group">
<label for="email">Email to search:</label> <label for="email">Email to search:</label>
<input type="text" <input type="text"
@ -167,7 +167,7 @@
<button type="submit" class="btn btn-primary">Submit</button> <button type="submit" class="btn btn-primary">Submit</button>
</form> </form>
</div> </div>
{% if no_match %} {% if data.no_match and email %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning" <div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div> role="alert">No user, alias or mailbox found for {{ email }}</div>
@ -190,14 +190,19 @@
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }} {{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div> </div>
{% endif %} {% endif %}
{% if data.mailbox %} {% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3"> <div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found Mailbox {{ data.mailbox.email }}</h3> <h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3>
{{ list_mailboxes(1, [data.mailbox]) }} {{ list_mailboxes(1, [mailbox]) }}
{{ show_user(data.mailbox.user) }} {{ show_user(mailbox.user) }}
</div> </div>
{% endif %} {% endfor %}
{% if data.deleted_alias %} {% if data.deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3"> <div class="border border-dark mt-1 mb-2 p-3">

View File

@ -94,4 +94,4 @@
</div> </div>
</div> </div>
{% endblock %} {% endblock %}
{% block script %}<script>$('.mailbox-select').multipleSelect();</script>{% endblock %}

View File

@ -38,7 +38,7 @@
Value: <em data-toggle="tooltip" Value: <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ custom_domain.get_ownership_dns_txt_value() }}">{{ custom_domain.get_ownership_dns_txt_value() }}</em> data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
</div> </div>
<form method="post" action="#ownership-form"> <form method="post" action="#ownership-form">
{{ csrf_form.csrf_token }} {{ csrf_form.csrf_token }}
@ -237,7 +237,7 @@
folder. folder.
</div> </div>
<div class="mb-2">Add the following CNAME DNS records to your domain.</div> <div class="mb-2">Add the following CNAME DNS records to your domain.</div>
{% for dkim_prefix, dkim_cname_value in dkim_records %} {% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
<div class="mb-2 p-3 dns-record"> <div class="mb-2 p-3 dns-record">
Record: CNAME Record: CNAME

View File

@ -21,8 +21,8 @@
<div class="my-3"> <div class="my-3">
<p>Alternatively you can use your Proton credentials to ensure it's you.</p> <p>Alternatively you can use your Proton credentials to ensure it's you.</p>
</div> </div>
<a class="btn btn-primary btn-block mt-2 proton-button w-25" <a class="btn btn-primary btn-block mt-2 proton-button"
href="{{ url_for('auth.proton_login', next=next) }}"> href="{{ url_for('auth.proton_login', next=next) }}" style="max-width: 400px">
<img class="mr-2" src="/static/images/proton.svg" /> <img class="mr-2" src="/static/images/proton.svg" />
Authenticate with Proton Authenticate with Proton
</a> </a>
@ -38,4 +38,4 @@
{% endif %} {% endif %}
</div> </div>
</div> </div>
{% endblock %} {% endblock %}

View File

@ -536,7 +536,7 @@ def test_create_contact_route_free_users(flask_client):
assert r.status_code == 201 assert r.status_code == 201
# End trial and disallow for new free users. Config should allow it # End trial and disallow for new free users. Config should allow it
user.flags = User.FLAG_FREE_DISABLE_CREATE_ALIAS user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.create_contact_route", alias_id=alias.id), url_for("api.create_contact_route", alias_id=alias.id),

View File

@ -4,7 +4,7 @@ from app.models import (
Alias, Alias,
Contact, Contact,
) )
from tests.utils import login from tests.utils import login, random_email
def test_add_contact_success(flask_client): def test_add_contact_success(flask_client):
@ -13,26 +13,28 @@ def test_add_contact_success(flask_client):
assert Contact.filter_by(user_id=user.id).count() == 0 assert Contact.filter_by(user_id=user.id).count() == 0
email = random_email()
# <<< Create a new contact >>> # <<< Create a new contact >>>
flask_client.post( flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id), url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={ data={
"form-name": "create", "form-name": "create",
"email": "abcd@gmail.com", "email": email,
}, },
follow_redirects=True, follow_redirects=True,
) )
# a new contact is added # a new contact is added
assert Contact.filter_by(user_id=user.id).count() == 1 assert Contact.filter_by(user_id=user.id).count() == 1
contact = Contact.filter_by(user_id=user.id).first() contact = Contact.filter_by(user_id=user.id).first()
assert contact.website_email == "abcd@gmail.com" assert contact.website_email == email
# <<< Create a new contact using a full email format >>> # <<< Create a new contact using a full email format >>>
email = random_email()
flask_client.post( flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id), url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={ data={
"form-name": "create", "form-name": "create",
"email": "First Last <another@gmail.com>", "email": f"First Last <{email}>",
}, },
follow_redirects=True, follow_redirects=True,
) )
@ -41,7 +43,7 @@ def test_add_contact_success(flask_client):
contact = ( contact = (
Contact.filter_by(user_id=user.id).filter(Contact.id != contact.id).first() Contact.filter_by(user_id=user.id).filter(Contact.id != contact.id).first()
) )
assert contact.website_email == "another@gmail.com" assert contact.website_email == email
assert contact.name == "First Last" assert contact.name == "First Last"
# <<< Create a new contact with invalid email address >>> # <<< Create a new contact with invalid email address >>>

View File

@ -0,0 +1,32 @@
from app.events.event_dispatcher import Dispatcher
from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
from typing import Tuple
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def clear(self):
self.memory = []
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user

View File

@ -1,33 +1,10 @@
from app.events.event_dispatcher import EventDispatcher, Dispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserDeleted from app.events.generated.event_pb2 import EventContent, UserDeleted
from app.models import PartnerUser, User from .event_test_utils import (
from app.proton.utils import get_proton_partner _create_unlinked_user,
from tests.utils import create_new_user, random_token OnMemoryDispatcher,
from typing import Tuple _create_linked_user,
)
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user
def test_event_dispatcher_stores_events(): def test_event_dispatcher_stores_events():

View File

@ -0,0 +1,102 @@
from app import config, alias_utils
from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated import event_pb2
from app.models import Alias, User, PartnerUser
from tests.utils import random_token
from .event_test_utils import (
OnMemoryDispatcher,
_create_linked_user,
)
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def setup_function(func):
on_memory_dispatcher.clear()
def _get_event_from_string(
data: str, user: User, pu: PartnerUser
) -> event_pb2.EventContent:
event = event_pb2.Event()
event.ParseFromString(data)
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id
return event.content
def test_fire_event_on_alias_creation():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert "" == alias_created.note
assert alias.enabled == alias_created.enabled
assert int(alias.created_at.timestamp) == alias_created.created_at
def test_fire_event_on_alias_creation_with_note():
(user, pu) = _create_linked_user()
note = random_token(10)
alias = Alias.create_new_random(user, note=note)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert note == alias_created.note
assert alias.enabled == alias_created.enabled
def test_fire_event_on_alias_deletion():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
alias_id = alias.id
Session.flush()
on_memory_dispatcher.clear()
alias_utils.delete_alias(alias, user)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_deleted is not None
alias_deleted = event_content.alias_deleted
assert alias_id == alias_deleted.id
assert alias.email == alias_deleted.email
def test_fire_event_on_alias_status_change():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
on_memory_dispatcher.clear()
alias_utils.change_alias_status(alias, True)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_status_change is not None
event = event_content.alias_status_change
assert alias.id == event.id
assert alias.email == event.email
assert int(alias.created_at.timestamp) == event.created_at
assert event.enabled

View File

@ -37,10 +37,14 @@ def test_send_alias_creation_events():
event_list = decoded_event.content.alias_create_list.events event_list = decoded_event.content.alias_create_list.events
assert len(event_list) == 2 assert len(event_list) == 2
# 0 is newsletter alias # 0 is newsletter alias
assert event_list[1].alias_id == aliases[0].id assert event_list[1].id == aliases[0].id
assert event_list[1].email == aliases[0].email
assert event_list[1].note == ""
assert event_list[1].enabled == aliases[0].enabled
assert event_list[1].created_at == int(aliases[0].created_at.timestamp)
decoded_event = event_pb2.Event.FromString(dispatcher.events[1]) decoded_event = event_pb2.Event.FromString(dispatcher.events[1])
assert decoded_event.user_id == user.id assert decoded_event.user_id == user.id
assert decoded_event.external_user_id == partner_user.external_user_id assert decoded_event.external_user_id == partner_user.external_user_id
event_list = decoded_event.content.alias_create_list.events event_list = decoded_event.content.alias_create_list.events
assert len(event_list) == 1 assert len(event_list) == 1
assert event_list[0].alias_id == aliases[1].id assert event_list[0].id == aliases[1].id

View File

@ -1,4 +1,6 @@
from arrow import Arrow from arrow import Arrow
from app import config
from app.account_linking import ( from app.account_linking import (
SLPlan, SLPlan,
SLPlanType, SLPlanType,
@ -8,7 +10,7 @@ from app.proton.proton_callback_handler import (
ProtonCallbackHandler, ProtonCallbackHandler,
generate_account_not_allowed_to_log_in, generate_account_not_allowed_to_log_in,
) )
from app.models import User, PartnerUser from app.models import User, PartnerUser, Job, JobState
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import random_string from app.utils import random_string
from typing import Optional from typing import Optional
@ -23,6 +25,17 @@ class MockProtonClient(ProtonClient):
return self.user return self.user
def check_initial_sync_job(user: User):
for job in Job.yield_per_query(10).filter_by(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
state=JobState.ready.value,
):
if job.payload.get("user_id") == user.id:
Job.delete(job.id)
return
assert False
def test_proton_callback_handler_unexistant_sl_user(): def test_proton_callback_handler_unexistant_sl_user():
email = random_email() email = random_email()
name = random_string() name = random_string()
@ -56,6 +69,7 @@ def test_proton_callback_handler_unexistant_sl_user():
) )
assert partner_user is not None assert partner_user is not None
assert partner_user.external_user_id == external_id assert partner_user.external_user_id == external_id
check_initial_sync_job(res.user)
def test_proton_callback_handler_existant_sl_user(): def test_proton_callback_handler_existant_sl_user():
@ -84,6 +98,7 @@ def test_proton_callback_handler_existant_sl_user():
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id) sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id)
assert sa is not None assert sa is not None
assert sa.partner_email == user.email assert sa.partner_email == user.email
check_initial_sync_job(res.user)
def test_proton_callback_handler_none_user_login(): def test_proton_callback_handler_none_user_login():

View File

@ -0,0 +1,196 @@
from typing import Optional
import pytest
from app import config
from app.contact_utils import create_contact, ContactCreateError
from app.db import Session
from app.models import (
Alias,
Contact,
User,
)
from tests.utils import create_new_user, random_email, random_token
def setup_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = True
def teardown_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = False
def create_provider():
# name auto_created from_partner
yield ["name", "a@b.c", True, True]
yield [None, None, True, True]
yield [None, None, False, True]
yield [None, None, True, False]
yield [None, None, False, False]
@pytest.mark.parametrize(
"name, mail_from, automatic_created, from_partner", create_provider()
)
def test_create_contact(
name: Optional[str],
mail_from: Optional[str],
automatic_created: bool,
from_partner: bool,
):
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(
email,
alias,
name=name,
mail_from=mail_from,
automatic_created=automatic_created,
from_partner=from_partner,
)
assert contact_result.error is None
contact = contact_result.contact
assert contact.user_id == user.id
assert contact.alias_id == alias.id
assert contact.website_email == email
assert contact.name == name
assert contact.mail_from == mail_from
assert contact.automatic_created == automatic_created
assert not contact.invalid_email
expected_flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
assert contact.flags == expected_flags
def test_create_contact_email_email_not_allowed():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
def test_create_contact_email_email_allowed():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("", alias, allow_empty_email=True)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == ""
assert contact_result.contact.invalid_email
def test_create_contact_name_overrides_email_name():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"superseeded <{email}>", alias, name=name)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_name_taken_from_email():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"{name} <{email}>", alias)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_empty_name_is_none():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias, name="")
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name is None
def test_create_contact_free_user():
user = create_new_user()
user.trial_end = None
user.flags = 0
alias = Alias.create_new_random(user)
Session.flush()
# Free users without the FREE_DISABLE_CREATE_CONTACTS
result = create_contact(random_email(), alias)
assert result.error is None
assert result.created
assert result.contact is not None
assert not result.contact.automatic_created
# Free users with the flag should be able to still create automatic emails
user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.flush()
result = create_contact(random_email(), alias, automatic_created=True)
assert result.error is None
assert result.created
assert result.contact is not None
assert result.contact.automatic_created
# Free users with the flag cannot create non-automatic emails
result = create_contact(random_email(), alias)
assert result.error == ContactCreateError.NotAllowed
def test_do_not_allow_invalid_email():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("potato", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
contact_result = create_contact("asdf\x00@gmail.com", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
def test_update_name_for_existing():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias)
assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None
assert contact_result.contact.name is None
name = random_token()
contact_result = create_contact(email, alias, name=name)
assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None
assert contact_result.contact.name == name
def test_update_mail_from_for_existing():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias)
assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None
assert contact_result.contact is not None
assert contact_result.contact.mail_from is None
mail_from = random_email()
contact_result = create_contact(email, alias, mail_from=mail_from)
assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None
assert contact_result.contact.mail_from == mail_from

View File

@ -0,0 +1,267 @@
from typing import Optional
from app import config
from app.config import ALIAS_DOMAINS
from app.custom_domain_utils import (
can_domain_be_used,
create_custom_domain,
is_valid_domain,
sanitize_domain,
set_custom_domain_mailboxes,
CannotUseDomainReason,
CannotSetCustomDomainMailboxesCause,
)
from app.db import Session
from app.models import User, CustomDomain, Mailbox, DomainMailbox
from tests.utils import get_proton_partner, random_email
from tests.utils import create_new_user, random_string, random_domain
user: Optional[User] = None
def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()
# is_valid_domain
def test_is_valid_domain():
assert is_valid_domain("example.com") is True
assert is_valid_domain("sub.example.com") is True
assert is_valid_domain("ex-ample.com") is True
assert is_valid_domain("-example.com") is False
assert is_valid_domain("example-.com") is False
assert is_valid_domain("exa_mple.com") is False
assert is_valid_domain("example..com") is False
assert is_valid_domain("") is False
assert is_valid_domain("a" * 64 + ".com") is False
assert is_valid_domain("a" * 63 + ".com") is True
assert is_valid_domain("example.com.") is True
assert is_valid_domain(".example.com") is False
assert is_valid_domain("example..com") is False
assert is_valid_domain("example.com-") is False
# can_domain_be_used
def test_can_domain_be_used():
domain = f"{random_string(10)}.com"
res = can_domain_be_used(user, domain)
assert res is None
def test_can_domain_be_used_existing_domain():
domain = random_domain()
CustomDomain.create(user_id=user.id, domain=domain, commit=True)
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainAlreadyUsed
def test_can_domain_be_used_sl_domain():
domain = ALIAS_DOMAINS[0]
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.BuiltinDomain
def test_can_domain_be_used_domain_of_user_email():
domain = user.email.split("@")[1]
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainPartOfUserEmail
def test_can_domain_be_used_domain_of_existing_mailbox():
domain = random_domain()
Mailbox.create(user_id=user.id, email=f"email@{domain}", verified=True, commit=True)
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainUserInMailbox
def test_can_domain_be_used_invalid_domain():
domain = f"{random_string(10)}@lol.com"
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.InvalidDomain
# sanitize_domain
def test_can_sanitize_domain_empty():
assert sanitize_domain("") == ""
def test_can_sanitize_domain_starting_with_http():
domain = "test.domain"
assert sanitize_domain(f"http://{domain}") == domain
def test_can_sanitize_domain_starting_with_https():
domain = "test.domain"
assert sanitize_domain(f"https://{domain}") == domain
def test_can_sanitize_domain_correct_domain():
domain = "test.domain"
assert sanitize_domain(domain) == domain
# create_custom_domain
def test_can_create_custom_domain():
domain = random_domain()
res = create_custom_domain(user=user, domain=domain)
assert res.success is True
assert res.redirect is None
assert res.message == ""
assert res.message_category == ""
assert res.instance is not None
assert res.instance.domain == domain
assert res.instance.user_id == user.id
def test_can_create_custom_domain_validates_if_parent_is_validated():
root_domain = random_domain()
subdomain = f"{random_string(10)}.{root_domain}"
# Create custom domain with the root domain
CustomDomain.create(
user_id=user.id,
domain=root_domain,
verified=True,
ownership_verified=True,
commit=True,
)
# Create custom domain with subdomain. Should automatically be verified
res = create_custom_domain(user=user, domain=subdomain)
assert res.success is True
assert res.instance.domain == subdomain
assert res.instance.user_id == user.id
assert res.instance.ownership_verified is True
def test_creates_custom_domain_with_partner_id():
domain = random_domain()
proton_partner = get_proton_partner()
res = create_custom_domain(user=user, domain=domain, partner_id=proton_partner.id)
assert res.success is True
assert res.instance.domain == domain
assert res.instance.user_id == user.id
assert res.instance.partner_id == proton_partner.id
# set_custom_domain_mailboxes
def test_set_custom_domain_mailboxes_empty_list():
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [])
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.NoMailboxes
def test_set_custom_domain_mailboxes_mailbox_from_another_user():
other_user = create_new_user()
other_mailbox = Mailbox.create(
user_id=other_user.id, email=random_email(), verified=True
)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id])
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox
def test_set_custom_domain_mailboxes_mailbox_from_current_user_and_another_user():
other_user = create_new_user()
other_mailbox = Mailbox.create(
user_id=other_user.id, email=random_email(), verified=True
)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is False
assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox
def test_set_custom_domain_mailboxes_success():
other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
assert domain_mailboxes[1].domain_id == domain.id
assert domain_mailboxes[1].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_set_twice():
other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True)
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
res = set_custom_domain_mailboxes(
user.id, domain, [user.default_mailbox_id, other_mailbox.id]
)
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
assert domain_mailboxes[1].domain_id == domain.id
assert domain_mailboxes[1].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_removes_old_association():
domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True)
res = set_custom_domain_mailboxes(user.id, domain, [user.default_mailbox_id])
assert res.success is True
assert res.reason is None
other_mailbox = Mailbox.create(
user_id=user.id, email=random_email(), verified=True, commit=True
)
res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id])
assert res.success is True
assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
assert len(domain_mailboxes) == 1
assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == other_mailbox.id
def test_set_custom_domain_mailboxes_with_unverified_mailbox():
domain = CustomDomain.create(user_id=user.id, domain=random_domain())
verified_mailbox = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=True,
)
unverified_mailbox = Mailbox.create(
user_id=user.id,
email=random_email(),
verified=False,
)
res = set_custom_domain_mailboxes(
user.id, domain, [verified_mailbox.id, unverified_mailbox.id]
)
assert res.success is False
assert res.reason is CannotSetCustomDomainMailboxesCause.InvalidMailbox

View File

@ -0,0 +1,376 @@
from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient
from app.proton.utils import get_proton_partner
from app.utils import random_string
from tests.utils import create_new_user, random_domain
user: Optional[User] = None
def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()
def create_custom_domain(domain: str) -> CustomDomain:
return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
def test_custom_domain_validation_get_dkim_records():
domain = random_domain()
custom_domain = create_custom_domain(domain)
validator = CustomDomainValidation(domain)
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
def test_custom_domain_validation_get_dkim_records_for_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dkim_records(domain)
assert len(res) == 3
for record_value in res.values():
assert record_value == "empty"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 2
for record_value in res.values():
assert record_value == "wrong"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, other domains missing
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
# DKIM is verified
domain.dkim_verified = True
Session.commit()
res = validator.validate_dkim_records(domain)
assert len(res) == 2
assert f"dkim02._domainkey.{user_domain}" in res
assert f"dkim03._domainkey.{user_domain}" in res
# Flag is not cleared
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
def test_custom_domain_validation_validate_dkim_records_success():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
# validate_ownership
def test_custom_domain_validation_validate_ownership_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
def test_custom_domain_validation_validate_ownership_from_partner_success():
dns_client = InMemoryDNSClient()
partner_id = get_proton_partner().id
prefix = random_string()
validator = CustomDomainValidation(
random_domain(),
dns_client,
partner_domains_validation_prefixes={partner_id: prefix},
)
domain = create_custom_domain(random_domain())
domain.partner_id = partner_id
Session.commit()
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
# validate_mx_records
def test_custom_domain_validation_validate_mx_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_mx_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_record_1 = random_string()
wrong_record_2 = random_string()
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain)
assert res.success is False
assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
res = validator.validate_mx_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is True
# validate_spf_records
def test_custom_domain_validation_validate_spf_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_spf_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
res = validator.validate_spf_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True
# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
res = validator.validate_dmarc_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is True

View File

@ -1,10 +1,12 @@
from app.dns_utils import ( from app.dns_utils import (
get_mx_domains, get_mx_domains,
get_spf_domain, get_network_dns_client,
get_txt_record,
is_mx_equivalent, is_mx_equivalent,
InMemoryDNSClient,
) )
from tests.utils import random_domain
# use our own domain for test # use our own domain for test
_DOMAIN = "simplelogin.io" _DOMAIN = "simplelogin.io"
@ -20,12 +22,12 @@ def test_get_mx_domains():
def test_get_spf_domain(): def test_get_spf_domain():
r = get_spf_domain(_DOMAIN) r = get_network_dns_client().get_spf_domain(_DOMAIN)
assert r == ["simplelogin.co"] assert r == ["simplelogin.co"]
def test_get_txt_record(): def test_get_txt_record():
r = get_txt_record(_DOMAIN) r = get_network_dns_client().get_txt_record(_DOMAIN)
assert len(r) > 0 assert len(r) > 0
@ -46,3 +48,15 @@ def test_is_mx_equivalent():
[(5, "domain1"), (10, "domain2")], [(5, "domain1"), (10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")], [(10, "domain1"), (20, "domain2"), (20, "domain3")],
) )
def test_get_spf_record():
client = InMemoryDNSClient()
sl_domain = random_domain()
domain = random_domain()
spf_record = f"v=spf1 include:{sl_domain}"
client.set_txt_record(domain, [spf_record, "another record"])
res = client.get_spf_domain(domain)
assert res == [sl_domain]