4.53.1
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m13s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m54s
Build-Release-Image / Merge-Images (push) Successful in 16s
Build-Release-Image / Create-Release (push) Successful in 40s
Build-Release-Image / Notify (push) Successful in 5s

This commit is contained in:
MrMeeb 2024-10-09 12:00:06 +01:00
parent 294232a329
commit 3da6c983e1
23 changed files with 741 additions and 299 deletions

View File

@ -163,7 +163,7 @@ jobs:
uses: docker/build-push-action@v3 uses: docker/build-push-action@v3
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64 platforms: linux/amd64
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}

View File

@ -34,6 +34,7 @@ from app.models import (
DeletedAlias, DeletedAlias,
DomainDeletedAlias, DomainDeletedAlias,
PartnerUser, PartnerUser,
AliasMailbox,
) )
from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address
@ -785,6 +786,25 @@ class EmailSearchHelpers:
def mailbox_count(user: User) -> int: def mailbox_count(user: User) -> int:
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count() return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count()
@staticmethod
def alias_mailboxes(alias: Alias) -> list[Mailbox]:
return (
Session.query(Mailbox)
.filter(Mailbox.id == Alias.mailbox_id, Alias.id == alias.id)
.union(
Session.query(Mailbox)
.join(AliasMailbox, Mailbox.id == AliasMailbox.mailbox_id)
.filter(AliasMailbox.alias_id == alias.id)
)
.order_by(Mailbox.id)
.limit(10)
.all()
)
@staticmethod
def alias_mailbox_count(alias: Alias) -> int:
return len(alias.mailboxes)
@staticmethod @staticmethod
def alias_list(user: User) -> list[Alias]: def alias_list(user: User) -> list[Alias]:
return ( return (

View File

@ -1,6 +1,7 @@
import csv import csv
from io import StringIO from io import StringIO
import re import re
from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from email_validator import validate_email, EmailNotValidError from email_validator import validate_email, EmailNotValidError
@ -23,6 +24,7 @@ from app.email_utils import (
send_cannot_create_domain_alias, send_cannot_create_domain_alias,
send_email, send_email,
render, render,
sl_formataddr,
) )
from app.errors import AliasInTrashError from app.errors import AliasInTrashError
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
@ -30,6 +32,7 @@ from app.events.generated.event_pb2 import (
AliasDeleted, AliasDeleted,
AliasStatusChanged, AliasStatusChanged,
EventContent, EventContent,
AliasCreated,
) )
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
@ -501,6 +504,28 @@ def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
alias.disable_pgp = False alias.disable_pgp = False
alias.pinned = False alias.pinned = False
EventDispatcher.send_event(
old_user,
EventContent(
alias_deleted=AliasDeleted(
id=alias.id,
email=alias.email,
)
),
)
EventDispatcher.send_event(
new_user,
EventContent(
alias_created=AliasCreated(
id=alias.id,
email=alias.email,
note=alias.note,
enabled=alias.enabled,
created_at=int(alias.created_at.timestamp),
)
),
)
Session.commit() Session.commit()
@ -518,3 +543,30 @@ def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
if commit: if commit:
Session.commit() Session.commit()
@dataclass
class AliasRecipientName:
name: str
message: Optional[str] = None
def get_alias_recipient_name(alias: Alias) -> AliasRecipientName:
"""
Logic:
1. If alias has name, use it
2. If alias has custom domain, and custom domain has name, use it
3. Otherwise, use the alias email as the recipient
"""
if alias.name:
return AliasRecipientName(
name=sl_formataddr((alias.name, alias.email)),
message=f"Put alias name {alias.name} in from header",
)
elif alias.custom_domain:
if alias.custom_domain.name:
return AliasRecipientName(
name=sl_formataddr((alias.custom_domain.name, alias.email)),
message=f"Put domain default alias name {alias.custom_domain.name} in from header",
)
return AliasRecipientName(name=alias.email)

View File

@ -52,8 +52,12 @@ def auth_login():
password = data.get("password") password = data.get("password")
device = data.get("device") device = data.get("device")
email = sanitize_email(data.get("email")) email = data.get("email")
canonical_email = canonicalize_email(data.get("email")) if not email:
LoginEvent(LoginEvent.ActionType.failed, LoginEvent.Source.api).send()
return jsonify(error="Email or password incorrect"), 400
email = sanitize_email(email)
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email) user = User.get_by(email=email) or User.get_by(email=canonical_email)

View File

@ -87,7 +87,7 @@ def update_user_info():
File.delete(file.id) File.delete(file.id)
s3.delete(file.path) s3.delete(file.path)
Session.flush() Session.flush()
else: if data["profile_picture"] is not None:
raw_data = base64.decodebytes(data["profile_picture"].encode()) raw_data = base64.decodebytes(data["profile_picture"].encode())
if detect_image_format(raw_data) == ImageFormat.Unknown: if detect_image_format(raw_data) == ImageFormat.Unknown:
return jsonify(error="Unsupported image format"), 400 return jsonify(error="Unsupported image format"), 400

View File

@ -653,9 +653,11 @@ def read_partner_dict(var: str) -> dict[int, str]:
return res return res
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS") PARTNER_DNS_CUSTOM_DOMAINS: dict[int, str] = read_partner_dict(
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( "PARTNER_DNS_CUSTOM_DOMAINS"
"PARTNER_DOMAIN_VALIDATION_PREFIXES" )
PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES"
) )
MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get( MAILBOX_VERIFICATION_OVERRIDE_CODE: Optional[str] = os.environ.get(

View File

@ -1,15 +1,17 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import List, Optional
from app import config from app import config
from app.constants import DMARC_RECORD from app.constants import DMARC_RECORD
from app.db import Session from app.db import Session
from app.dns_utils import ( from app.dns_utils import (
MxRecord,
DNSClient, DNSClient,
is_mx_equivalent, is_mx_equivalent,
get_network_dns_client, get_network_dns_client,
) )
from app.models import CustomDomain from app.models import CustomDomain
from app.utils import random_string
@dataclass @dataclass
@ -28,10 +30,10 @@ class CustomDomainValidation:
): ):
self.dkim_domain = dkim_domain self.dkim_domain = dkim_domain
self._dns_client = dns_client self._dns_client = dns_client
self._partner_domains = partner_domains or config.PARTNER_DOMAINS self._partner_domains = partner_domains or config.PARTNER_DNS_CUSTOM_DOMAINS
self._partner_domain_validation_prefixes = ( self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes partner_domains_validation_prefixes
or config.PARTNER_DOMAIN_VALIDATION_PREFIXES or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
) )
def get_ownership_verification_record(self, domain: CustomDomain) -> str: def get_ownership_verification_record(self, domain: CustomDomain) -> str:
@ -41,8 +43,36 @@ class CustomDomainValidation:
and domain.partner_id in self._partner_domain_validation_prefixes and domain.partner_id in self._partner_domain_validation_prefixes
): ):
prefix = self._partner_domain_validation_prefixes[domain.partner_id] prefix = self._partner_domain_validation_prefixes[domain.partner_id]
if not domain.ownership_txt_token:
domain.ownership_txt_token = random_string(30)
Session.commit()
return f"{prefix}-verification={domain.ownership_txt_token}" return f"{prefix}-verification={domain.ownership_txt_token}"
def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
records = []
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
domain = self._partner_domains[domain.partner_id]
records.append(MxRecord(10, f"mx1.{domain}."))
records.append(MxRecord(20, f"mx2.{domain}."))
else:
# Default ones
for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
records.append(MxRecord(priority, domain))
return records
def get_expected_spf_domain(self, domain: CustomDomain) -> str:
if domain.partner_id is not None and domain.partner_id in self._partner_domains:
return self._partner_domains[domain.partner_id]
else:
return config.EMAIL_DOMAIN
def get_expected_spf_record(self, domain: CustomDomain) -> str:
spf_domain = self.get_expected_spf_domain(domain)
return f"v=spf1 include:{spf_domain} ~all"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}: def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
""" """
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not, Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
@ -116,11 +146,12 @@ class CustomDomainValidation:
self, custom_domain: CustomDomain self, custom_domain: CustomDomain
) -> DomainValidationResult: ) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain) mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
expected_mx_records = self.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult( return DomainValidationResult(
success=False, success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains], errors=[f"{record.priority} {record.domain}" for record in mx_domains],
) )
else: else:
custom_domain.verified = True custom_domain.verified = True
@ -131,16 +162,19 @@ class CustomDomainValidation:
self, custom_domain: CustomDomain self, custom_domain: CustomDomain
) -> DomainValidationResult: ) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains: expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains:
custom_domain.spf_verified = True custom_domain.spf_verified = True
Session.commit() Session.commit()
return DomainValidationResult(success=True, errors=[]) return DomainValidationResult(success=True, errors=[])
else: else:
custom_domain.spf_verified = False custom_domain.spf_verified = False
Session.commit() Session.commit()
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
cleaned_records = self.__clean_spf_records(txt_records, custom_domain)
return DomainValidationResult( return DomainValidationResult(
success=False, success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain), errors=cleaned_records,
) )
def validate_dmarc_records( def validate_dmarc_records(
@ -155,3 +189,13 @@ class CustomDomainValidation:
custom_domain.dmarc_verified = False custom_domain.dmarc_verified = False
Session.commit() Session.commit()
return DomainValidationResult(success=False, errors=txt_records) return DomainValidationResult(success=False, errors=txt_records)
def __clean_spf_records(
self, txt_records: List[str], custom_domain: CustomDomain
) -> List[str]:
final_records = []
verification_record = self.get_ownership_verification_record(custom_domain)
for record in txt_records:
if record != verification_record:
final_records.append(record)
return final_records

View File

@ -21,7 +21,9 @@ class NewCustomDomainForm(FlaskForm):
@parallel_limiter.lock(only_when=lambda: request.method == "POST") @parallel_limiter.lock(only_when=lambda: request.method == "POST")
def custom_domain(): def custom_domain():
custom_domains = CustomDomain.filter_by( custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False user_id=current_user.id,
is_sl_subdomain=False,
pending_deletion=False,
).all() ).all()
new_custom_domain_form = NewCustomDomainForm() new_custom_domain_form = NewCustomDomainForm()

View File

@ -36,8 +36,6 @@ def domain_detail_dns(custom_domain_id):
custom_domain.ownership_txt_token = random_string(30) custom_domain.ownership_txt_token = random_string(30)
Session.commit() Session.commit()
spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"
domain_validator = CustomDomainValidation(EMAIL_DOMAIN) domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm() csrf_form = CSRFValidationForm()
@ -141,7 +139,9 @@ def domain_detail_dns(custom_domain_id):
ownership_record=domain_validator.get_ownership_verification_record( ownership_record=domain_validator.get_ownership_verification_record(
custom_domain custom_domain
), ),
expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),
dkim_records=domain_validator.get_dkim_records(custom_domain), dkim_records=domain_validator.get_dkim_records(custom_domain),
spf_record=domain_validator.get_expected_spf_record(custom_domain),
dmarc_record=DMARC_RECORD, dmarc_record=DMARC_RECORD,
**locals(), **locals(),
) )

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Optional from dataclasses import dataclass
from typing import List, Optional
import dns.resolver import dns.resolver
@ -8,8 +9,14 @@ from app.config import NAMESERVERS
_include_spf = "include:" _include_spf = "include:"
@dataclass
class MxRecord:
priority: int
domain: str
def is_mx_equivalent( def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool: ) -> bool:
""" """
Compare mx_domains with ref_mx_domains to see if they are equivalent. Compare mx_domains with ref_mx_domains to see if they are equivalent.
@ -18,14 +25,14 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number. The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
""" """
mx_domains = sorted(mx_domains, key=lambda x: x[0]) mx_domains = sorted(mx_domains, key=lambda x: x.priority)
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)
if len(mx_domains) < len(ref_mx_domains): if len(mx_domains) < len(ref_mx_domains):
return False return False
for i in range(len(ref_mx_domains)): for actual, expected in zip(mx_domains, ref_mx_domains):
if mx_domains[i][1] != ref_mx_domains[i][1]: if actual.domain != expected.domain:
return False return False
return True return True
@ -37,7 +44,7 @@ class DNSClient(ABC):
pass pass
@abstractmethod @abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: def get_mx_domains(self, hostname: str) -> List[MxRecord]:
pass pass
def get_spf_domain(self, hostname: str) -> List[str]: def get_spf_domain(self, hostname: str) -> List[str]:
@ -81,7 +88,7 @@ class NetworkDNSClient(DNSClient):
except Exception: except Exception:
return None return None
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: def get_mx_domains(self, hostname: str) -> List[MxRecord]:
""" """
return list of (priority, domain name) sorted by priority (lowest priority first) return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end. domain name ends with a "." at the end.
@ -92,14 +99,14 @@ class NetworkDNSClient(DNSClient):
for a in answers: for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ") parts = record.split(" ")
ret.append((int(parts[0]), parts[1])) ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
return sorted(ret, key=lambda x: x[0]) return sorted(ret, key=lambda x: x.priority)
except Exception: except Exception:
return [] return []
def get_txt_record(self, hostname: str) -> List[str]: def get_txt_record(self, hostname: str) -> List[str]:
try: try:
answers = self._resolver.resolve(hostname, "TXT", search=True) answers = self._resolver.resolve(hostname, "TXT", search=False)
ret = [] ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings: for record in a.strings:
@ -112,14 +119,14 @@ class NetworkDNSClient(DNSClient):
class InMemoryDNSClient(DNSClient): class InMemoryDNSClient(DNSClient):
def __init__(self): def __init__(self):
self.cname_records: dict[str, Optional[str]] = {} self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {} self.mx_records: dict[str, List[MxRecord]] = {}
self.spf_records: dict[str, List[str]] = {} self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {} self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str): def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
self.mx_records[hostname] = mx_list self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]): def set_txt_record(self, hostname: str, txt_list: List[str]):
@ -128,9 +135,9 @@ class InMemoryDNSClient(DNSClient):
def get_cname_record(self, hostname: str) -> Optional[str]: def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname) return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: def get_mx_domains(self, hostname: str) -> List[MxRecord]:
mx_list = self.mx_records.get(hostname, []) mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0]) return sorted(mx_list, key=lambda x: x.priority)
def get_txt_record(self, hostname: str) -> List[str]: def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, []) return self.txt_records.get(hostname, [])
@ -140,5 +147,5 @@ def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS) return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> [(int, str)]: def get_mx_domains(hostname: str) -> List[MxRecord]:
return get_network_dns_client().get_mx_domains(hostname) return get_network_dns_client().get_mx_domains(hostname)

View File

@ -592,7 +592,7 @@ def email_can_be_used_as_mailbox(email_address: str) -> bool:
from app.models import CustomDomain from app.models import CustomDomain
if CustomDomain.get_by(domain=domain, verified=True): if CustomDomain.get_by(domain=domain, is_sl_subdomain=True, verified=True):
LOG.d("domain %s is a SimpleLogin custom domain", domain) LOG.d("domain %s is a SimpleLogin custom domain", domain)
return False return False
@ -657,7 +657,7 @@ def get_mx_domain_list(domain) -> [str]:
""" """
priority_domains = get_mx_domains(domain) priority_domains = get_mx_domains(domain)
return [d[:-1] for _, d in priority_domains] return [d.domain[:-1] for d in priority_domains]
def personal_email_already_used(email_address: str) -> bool: def personal_email_already_used(email_address: str) -> bool:

View File

@ -2766,9 +2766,9 @@ class Mailbox(Base, ModelMixin):
from app.email_utils import get_email_local_part from app.email_utils import get_email_local_part
mx_domains: [(int, str)] = get_mx_domains(get_email_local_part(self.email)) mx_domains = get_mx_domains(get_email_local_part(self.email))
# Proton is the first domain # Proton is the first domain
if mx_domains and mx_domains[0][1] in ( if mx_domains and mx_domains[0].domain in (
"mail.protonmail.ch.", "mail.protonmail.ch.",
"mailsec.protonmail.ch.", "mailsec.protonmail.ch.",
): ):

View File

@ -14,6 +14,7 @@ from sqlalchemy.sql import Insert, text
from app import s3, config from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import ( from app.email_utils import (
@ -905,9 +906,11 @@ def check_custom_domain():
LOG.i("custom domain has been deleted") LOG.i("custom domain has been deleted")
def check_single_custom_domain(custom_domain): def check_single_custom_domain(custom_domain: CustomDomain):
mx_domains = get_mx_domains(custom_domain.domain) mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): validator = CustomDomainValidation(dkim_domain=config.EMAIL_DOMAIN)
expected_custom_domains = validator.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_custom_domains):
user = custom_domain.user user = custom_domain.user
LOG.w( LOG.w(
"The MX record is not correctly set for %s %s %s", "The MX record is not correctly set for %s %s %s",

View File

@ -53,7 +53,11 @@ from flanker.addresslib.address import EmailAddress
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app import pgp_utils, s3, config, contact_utils from app import pgp_utils, s3, config, contact_utils
from app.alias_utils import try_auto_create, change_alias_status from app.alias_utils import (
try_auto_create,
change_alias_status,
get_alias_recipient_name,
)
from app.config import ( from app.config import (
EMAIL_DOMAIN, EMAIL_DOMAIN,
URL, URL,
@ -1161,23 +1165,11 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
Session.commit() Session.commit()
# make the email comes from alias recipient_name = get_alias_recipient_name(alias)
from_header = alias.email if recipient_name.message:
# add alias name from alias LOG.d(recipient_name.message)
if alias.name: LOG.d("From header is %s", recipient_name.name)
LOG.d("Put alias name %s in from header", alias.name) add_or_replace_header(msg, headers.FROM, recipient_name.name)
from_header = sl_formataddr((alias.name, alias.email))
elif alias.custom_domain:
# add alias name from domain
if alias.custom_domain.name:
LOG.d(
"Put domain default alias name %s in from header",
alias.custom_domain.name,
)
from_header = sl_formataddr((alias.custom_domain.name, alias.email))
LOG.d("From header is %s", from_header)
add_or_replace_header(msg, headers.FROM, from_header)
try: try:
if str(msg[headers.TO]).lower() == "undisclosed-recipients:;": if str(msg[headers.TO]).lower() == "undisclosed-recipients:;":

View File

@ -1,220 +1,220 @@
{% extends 'admin/master.html' %} {% extends 'admin/master.html' %}
{% macro show_user(user) -%} {% macro show_user(user) -%}
<h4>User {{ user.email }} with ID {{ user.id }}.</h4> <h4>User {{ user.email }} with ID {{ user.id }}.</h4>
{% set pu = helper.partner_user(user) %} {% set pu = helper.partner_user(user) %}
<table class="table"> <table class="table">
<thead> <thead>
<tr>
<th scope="col">User ID</th>
<th scope="col">Email</th>
<th scope="col">Status</th>
<th scope="col">Paid</th>
<th>Subscription</th>
<th>Created At</th>
<th>Updated At</th>
<th>Connected with Proton account</th>
</tr>
</thead>
<tbody>
<tr>
<td>{{ user.id }}</td>
<td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %}
<td class="text-danger">Disabled</td>
{% else %}
<td class="text-success">Enabled</td>
{% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ user.get_active_subscription() }}</td>
<td>{{ user.created_at }}</td>
<td>{{ user.updated_at }}</td>
{% if pu %}
<td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %}
<td>No</td>
{% endif %}
</tr>
</tbody>
</table>
{%- endmacro %}
{% macro list_mailboxes(mbox_count, mboxes) %}
<h4>
{{ mbox_count }} Mailboxes found.
{% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<th>Mailbox ID</th>
<th>Email</th>
<th>Verified</th>
<th>Created At</th>
</tr>
</thead>
<tbody>
{% for mailbox in mboxes %}
<tr> <tr>
<td>{{ mailbox.id }}</td> <th scope="col">User ID</th>
<td><a href="?email={{mailbox.email}}">{{mailbox.email}}</a></td> <th scope="col">Email</th>
<td>{{ "Yes" if mailbox.verified else "No" }}</td> <th scope="col">Status</th>
<td> <th scope="col">Paid</th>
{{ mailbox.created_at }} <th>Subscription</th>
</td> <th>Created At</th>
<th>Updated At</th>
<th>Connected with Proton account</th>
</tr> </tr>
{% endfor %} </thead>
</tbody> <tbody>
</table> <tr>
<td>{{ user.id }}</td>
<td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %}
<td class="text-danger">Disabled</td>
{% else %}
<td class="text-success">Enabled</td>
{% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ user.get_active_subscription() }}</td>
<td>{{ user.created_at }}</td>
<td>{{ user.updated_at }}</td>
{% if pu %}
<td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %}
<td>No</td>
{% endif %}
</tr>
</tbody>
</table>
{%- endmacro %}
{% macro list_mailboxes(message, mbox_count, mboxes) %}
<h4>
{{ mbox_count }} {{ message }}.
{% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
<tr>
<th>Mailbox ID</th>
<th>Email</th>
<th>Verified</th>
<th>Created At</th>
</tr>
</thead>
<tbody>
{% for mailbox in mboxes %}
<tr>
<td>{{ mailbox.id }}</td>
<td><a href="?email={{ mailbox.email }}">{{ mailbox.email }}</a></td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td>
{{ mailbox.created_at }}
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %} {% endmacro %}
{% macro list_alias(alias_count, aliases) %} {% macro list_alias(alias_count, aliases) %}
<h4> <h4>
{{ alias_count }} Aliases found. {{ alias_count }} Aliases found.
{% if alias_count>10 %}Showing only the last 10.{% endif %} {% if alias_count>10 %}Showing only the last 10.{% endif %}
</h4> </h4>
<table class="table"> <table class="table">
<thead> <thead>
<tr>
<th>
Alias ID
</th>
<th>
Email
</th>
<th>
Verified
</th>
<th>
Created At
</th>
</tr>
</thead>
<tbody>
{% for alias in aliases %}
<tr> <tr>
<td>{{ alias.id }}</td> <th>
<td><a href="?email={{alias.email}}">{{alias.email}}</a></td> Alias ID
<td>{{ "Yes" if alias.verified else "No" }}</td> </th>
<td>{{ alias.created_at }}</td> <th>
Email
</th>
<th>
Verified
</th>
<th>
Created At
</th>
</tr> </tr>
{% endfor %} </thead>
</tbody> <tbody>
</table> {% for alias in aliases %}
<tr>
<td>{{ alias.id }}</td>
<td><a href="?email={{ alias.email }}">{{ alias.email }}</a></td>
<td>{{ "Yes" if alias.verified else "No" }}</td>
<td>{{ alias.created_at }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% endmacro %} {% endmacro %}
{% macro show_deleted_alias(deleted_alias) -%} {% macro show_deleted_alias(deleted_alias) -%}
<h4>Deleted Alias {{ deleted_alias.email }} with ID {{ deleted_alias.id }}.</h4> <h4>Deleted Alias {{ deleted_alias.email }} with ID {{ deleted_alias.id }}.</h4>
<table class="table"> <table class="table">
<thead> <thead>
<tr> <tr>
<th scope="col">Deleted Alias ID</th> <th scope="col">Deleted Alias ID</th>
<th scope="col">Email</th> <th scope="col">Email</th>
<th scope="col">Deleted At</th> <th scope="col">Deleted At</th>
<th scope="col">Reason</th> <th scope="col">Reason</th>
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
<tr> <tr>
<td>{{ deleted_alias.id }}</td> <td>{{ deleted_alias.id }}</td>
<td>{{ deleted_alias.email }}</td> <td>{{ deleted_alias.email }}</td>
<td>{{ deleted_alias.created_at }}</td> <td>{{ deleted_alias.created_at }}</td>
<td>{{ deleted_alias.reason }}</td> <td>{{ deleted_alias.reason }}</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>
{%- endmacro %} {%- endmacro %}
{% macro show_domain_deleted_alias(dom_deleted_alias) -%} {% macro show_domain_deleted_alias(dom_deleted_alias) -%}
<h4> <h4>
Domain Deleted Alias {{ dom_deleted_alias.email }} with ID {{ dom_deleted_alias.id }} for domain {{ dom_deleted_alias.domain.domain }} Domain Deleted Alias {{ dom_deleted_alias.email }} with ID {{ dom_deleted_alias.id }} for
</h4> domain {{ dom_deleted_alias.domain.domain }}
<table class="table"> </h4>
<thead> <table class="table">
<tr> <thead>
<th scope="col">Deleted Alias ID</th> <tr>
<th scope="col">Email</th> <th scope="col">Deleted Alias ID</th>
<th scope="col">Domain</th> <th scope="col">Email</th>
<th scope="col">Domain ID</th> <th scope="col">Domain</th>
<th scope="col">Domain owner user ID</th> <th scope="col">Domain ID</th>
<th scope="col">Domain owner user email</th> <th scope="col">Domain owner user ID</th>
<th scope="col">Deleted At</th> <th scope="col">Domain owner user email</th>
</tr> <th scope="col">Deleted At</th>
</thead> </tr>
<tbody> </thead>
<tr> <tbody>
<td>{{ dom_deleted_alias.id }}</td> <tr>
<td>{{ dom_deleted_alias.email }}</td> <td>{{ dom_deleted_alias.id }}</td>
<td>{{ dom_deleted_alias.domain.domain }}</td> <td>{{ dom_deleted_alias.email }}</td>
<td>{{ dom_deleted_alias.domain.id }}</td> <td>{{ dom_deleted_alias.domain.domain }}</td>
<td>{{ dom_deleted_alias.domain.user_id }}</td> <td>{{ dom_deleted_alias.domain.id }}</td>
<td>{{ dom_deleted_alias.created_at }}</td> <td>{{ dom_deleted_alias.domain.user_id }}</td>
</tr> <td>{{ dom_deleted_alias.created_at }}</td>
</tbody> </tr>
</table> </tbody>
{{ show_user(data.domain_deleted_alias.domain.user) }} </table>
{{ show_user(data.domain_deleted_alias.domain.user) }}
{%- endmacro %} {%- endmacro %}
{% block body %} {% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="get">
<div class="form-group">
<label for="email">Email to search:</label>
<input type="text"
class="form-control"
name="email"
value="{{ email or '' }}" />
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
{% if data.no_match and email %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div>
{% endif %}
{% if data.alias %}
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found Alias {{ data.alias.email }}</h3> <form method="get">
{{ list_alias(1,[data.alias]) }} <div class="form-group">
{{ show_user(data.alias.user) }} <label for="email">Email to search:</label>
{{ list_mailboxes(helper.mailbox_count(data.alias.user) , helper.mailbox_list(data.alias.user) ) }} <input type="text"
class="form-control"
name="email"
value="{{ email or '' }}"/>
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div> </div>
{% endif %} {% if data.no_match and email %}
{% if data.user %} <div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div>
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found User {{ data.user.email }}</h3>
{{ show_user(data.user) }}
{{ list_mailboxes(helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }}
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div>
{% endif %}
{% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %} {% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3"> {% if data.alias %}
<h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3> <div class="border border-dark border-2 mt-1 mb-2 p-3">
{{ list_mailboxes(1, [mailbox]) }} <h3 class="mb-3">Found Alias {{ data.alias.email }}</h3>
{{ show_user(mailbox.user) }} {{ list_alias(1,[data.alias]) }}
</div> {{ list_mailboxes("Mailboxes for alias", helper.alias_mailbox_count(data.alias), helper.alias_mailboxes(data.alias)) }}
{% endfor %} {{ show_user(data.alias.user) }}
{% if data.deleted_alias %} </div>
{% endif %}
<div class="border border-dark mt-1 mb-2 p-3"> {% if data.user %}
<h3 class="mb-3">Found DeletedAlias {{ data.deleted_alias.email }}</h3> <div class="border border-dark border-2 mt-1 mb-2 p-3">
{{ show_deleted_alias(data.deleted_alias) }} <h3 class="mb-3">Found User {{ data.user.email }}</h3>
</div> {{ show_user(data.user) }}
{% endif %} {{ list_mailboxes("Mailboxes for user", helper.mailbox_count(data.user) , helper.mailbox_list(data.user) ) }}
{% if data.domain_deleted_alias %} {{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div>
{% endif %}
{% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3"> <div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DomainDeletedAlias {{ data.domain_deleted_alias.email }}</h3> <h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3>
{{ show_domain_deleted_alias(data.domain_deleted_alias) }} {{ list_mailboxes("Mailbox found", 1, [mailbox]) }}
</div> {{ show_user(mailbox.user) }}
{% endif %} </div>
{% endfor %}
{% if data.deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DeletedAlias {{ data.deleted_alias.email }}</h3>
{{ show_deleted_alias(data.deleted_alias) }}
</div>
{% endif %}
{% if data.domain_deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found DomainDeletedAlias {{ data.domain_deleted_alias.email }}</h3>
{{ show_domain_deleted_alias(data.domain_deleted_alias) }}
</div>
{% endif %}
{% endblock %} {% endblock %}

View File

@ -91,7 +91,8 @@
<br /> <br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain. Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div> </div>
{% for priority, email_server in EMAIL_SERVERS_WITH_PRIORITY %}
{% for record in expected_mx_records %}
<div class="mb-3 p-3 dns-record"> <div class="mb-3 p-3 dns-record">
Record: MX Record: MX
@ -99,14 +100,15 @@
Domain: {{ custom_domain.domain }} or Domain: {{ custom_domain.domain }} or
<b>@</b> <b>@</b>
<br /> <br />
Priority: {{ priority }} Priority: {{ record.priority }}
<br /> <br />
Target: <em data-toggle="tooltip" Target: <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ email_server }}">{{ email_server }}</em> data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em>
</div> </div>
{% endfor %} {% endfor %}
<form method="post" action="#mx-form"> <form method="post" action="#mx-form">
{{ csrf_form.csrf_token }} {{ csrf_form.csrf_token }}
<input type="hidden" name="form-name" value="check-mx"> <input type="hidden" name="form-name" value="check-mx">

View File

@ -1,38 +1,72 @@
import app.alias_utils import app.alias_utils
from app import config
from app.db import Session from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.models import ( from app.models import (
Alias, Alias,
Mailbox, Mailbox,
User,
AliasMailbox, AliasMailbox,
) )
from tests.events.event_test_utils import (
OnMemoryDispatcher,
_get_event_from_string,
_create_linked_user,
)
from tests.utils import login from tests.utils import login
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def test_alias_transfer(flask_client): def test_alias_transfer(flask_client):
user = login(flask_client) (source_user, source_user_pu) = _create_linked_user()
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com", commit=True) source_user = login(flask_client, source_user)
mb = Mailbox.create(user_id=source_user.id, email="mb@gmail.com", commit=True)
alias = Alias.create_new_random(user) alias = Alias.create_new_random(source_user)
Session.commit() Session.commit()
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True) AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True)
new_user = User.create( (target_user, target_user_pu) = _create_linked_user()
email="hey@example.com",
password="password",
activated=True,
commit=True,
)
Mailbox.create( Mailbox.create(
user_id=new_user.id, email="hey2@example.com", verified=True, commit=True user_id=target_user.id, email="hey2@example.com", verified=True, commit=True
) )
app.alias_utils.transfer_alias(alias, new_user, new_user.mailboxes()) on_memory_dispatcher.clear()
app.alias_utils.transfer_alias(alias, target_user, target_user.mailboxes())
# refresh from db # refresh from db
alias = Alias.get(alias.id) alias = Alias.get(alias.id)
assert alias.user == new_user assert alias.user == target_user
assert set(alias.mailboxes) == set(new_user.mailboxes()) assert set(alias.mailboxes) == set(target_user.mailboxes())
assert len(alias.mailboxes) == 2 assert len(alias.mailboxes) == 2
# Check events
assert len(on_memory_dispatcher.memory) == 2
# 1st delete event
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, source_user, source_user_pu)
assert event_content.alias_deleted is not None
alias_deleted = event_content.alias_deleted
assert alias_deleted.id == alias.id
assert alias_deleted.email == alias.email
# 2nd create event
event_data = on_memory_dispatcher.memory[1]
event_content = _get_event_from_string(event_data, target_user, target_user_pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert alias.note or "" == alias_created.note
assert alias.enabled == alias_created.enabled

View File

@ -1,4 +1,5 @@
from app.events.event_dispatcher import Dispatcher from app.events.event_dispatcher import Dispatcher
from app.events.generated import event_pb2
from app.models import PartnerUser, User from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token from tests.utils import create_new_user, random_token
@ -30,3 +31,14 @@ def _create_linked_user() -> Tuple[User, PartnerUser]:
) )
return user, partner_user return user, partner_user
def _get_event_from_string(
data: str, user: User, pu: PartnerUser
) -> event_pb2.EventContent:
event = event_pb2.Event()
event.ParseFromString(data)
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id
return event.content

View File

@ -1,12 +1,12 @@
from app import config, alias_utils from app import config, alias_utils
from app.db import Session from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated import event_pb2 from app.models import Alias
from app.models import Alias, User, PartnerUser
from tests.utils import random_token from tests.utils import random_token
from .event_test_utils import ( from .event_test_utils import (
OnMemoryDispatcher, OnMemoryDispatcher,
_create_linked_user, _create_linked_user,
_get_event_from_string,
) )
on_memory_dispatcher = OnMemoryDispatcher() on_memory_dispatcher = OnMemoryDispatcher()
@ -26,17 +26,6 @@ def setup_function(func):
on_memory_dispatcher.clear() on_memory_dispatcher.clear()
def _get_event_from_string(
data: str, user: User, pu: PartnerUser
) -> event_pb2.EventContent:
event = event_pb2.Event()
event.ParseFromString(data)
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id
return event.content
def test_fire_event_on_alias_creation(): def test_fire_event_on_alias_creation():
(user, pu) = _create_linked_user() (user, pu) = _create_linked_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)

View File

@ -4,6 +4,7 @@ from app.alias_utils import (
delete_alias, delete_alias,
check_alias_prefix, check_alias_prefix,
get_user_if_alias_would_auto_create, get_user_if_alias_would_auto_create,
get_alias_recipient_name,
try_auto_create, try_auto_create,
) )
from app.config import ALIAS_DOMAINS from app.config import ALIAS_DOMAINS
@ -18,7 +19,8 @@ from app.models import (
User, User,
DomainDeletedAlias, DomainDeletedAlias,
) )
from tests.utils import create_new_user, random_domain, random_token from app.utils import random_string
from tests.utils import create_new_user, random_domain, random_token, random_email
def test_delete_alias(flask_client): def test_delete_alias(flask_client):
@ -131,3 +133,91 @@ def test_auto_create_alias(flask_client):
assert result, f"Case {test_id} - Failed address {address}" assert result, f"Case {test_id} - Failed address {address}"
else: else:
assert result is None, f"Case {test_id} - Failed address {address}" assert result is None, f"Case {test_id} - Failed address {address}"
# get_alias_recipient_name
def test_get_alias_recipient_name_no_overrides():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is None
assert res.name == alias.email
def test_get_alias_recipient_name_alias_name():
user = create_new_user()
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
name=random_string(),
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{alias.name} <{alias.email}>"
def test_get_alias_recipient_alias_with_name_and_custom_domain_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
name=random_string(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
name=random_string(),
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{alias.name} <{alias.email}>"
def test_get_alias_recipient_alias_without_name_and_custom_domain_without_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is None
assert res.name == alias.email
def test_get_alias_recipient_alias_without_name_and_custom_domain_name():
user = create_new_user()
custom_domain = CustomDomain.create(
user_id=user.id,
domain=random_domain(),
name=random_string(),
verified=True,
)
alias = Alias.create(
user_id=user.id,
email=random_email(),
mailbox_id=user.default_mailbox_id,
custom_domain_id=custom_domain.id,
commit=True,
)
res = get_alias_recipient_name(alias)
assert res.message is not None
assert res.name == f"{custom_domain.name} <{alias.email}>"

View File

@ -5,7 +5,7 @@ from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation
from app.db import Session from app.db import Session
from app.models import CustomDomain, User from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient from app.dns_utils import InMemoryDNSClient, MxRecord
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import random_string from app.utils import random_string
from tests.utils import create_new_user, random_domain from tests.utils import create_new_user, random_domain
@ -58,6 +58,123 @@ def test_custom_domain_validation_get_dkim_records_for_partner():
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}" assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
# get_expected_mx_records
def test_custom_domain_validation_get_expected_mx_records_regular_domain():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is not a partner_domain,default records should be used even if
# there is a config for the partner
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0]
assert records[i].domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(dkim_domain)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain but there is no custom config for partner, default records
# should be used
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0]
assert records[i].domain == config_record[1]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
expected_mx_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain, partner_domains={partner_id: expected_mx_domain}
)
records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain and there is a custom config for partner, partner records
# should be used
assert len(records) == 2
assert records[0].priority == 10
assert records[0].domain == f"mx1.{expected_mx_domain}."
assert records[1].priority == 20
assert records[1].domain == f"mx2.{expected_mx_domain}."
# get_expected_spf_records
def test_custom_domain_validation_get_expected_spf_record_regular_domain():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is not a partner_domain, default records should be used even if
# there is a config for the partner
assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(dkim_domain)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is a partner_domain but there is no custom config for partner, default records
# should be used
assert record == f"v=spf1 include:{config.EMAIL_DOMAIN} ~all"
def test_custom_domain_validation_get_expected_spf_record_domain_from_partner_with_custom_config():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
expected_mx_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain, partner_domains={partner_id: expected_mx_domain}
)
record = validator.get_expected_spf_record(custom_domain)
# As the domain is a partner_domain and there is a custom config for partner, partner records
# should be used
assert record == f"v=spf1 include:{expected_mx_domain} ~all"
# validate_dkim_records # validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure(): def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient() dns_client = InMemoryDNSClient()
@ -253,7 +370,7 @@ def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
wrong_record_1 = random_string() wrong_record_1 = random_string()
wrong_record_2 = random_string() wrong_record_2 = random_string()
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)] wrong_records = [MxRecord(10, wrong_record_1), MxRecord(20, wrong_record_2)]
dns_client.set_mx_records(domain.domain, wrong_records) dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain) res = validator.validate_mx_records(domain)
@ -270,7 +387,7 @@ def test_custom_domain_validation_validate_mx_records_success():
domain = create_custom_domain(random_domain()) domain = create_custom_domain(random_domain())
dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY) dns_client.set_mx_records(domain.domain, validator.get_expected_mx_records(domain))
res = validator.validate_mx_records(domain) res = validator.validate_mx_records(domain)
assert res.success is True assert res.success is True
@ -328,6 +445,58 @@ def test_custom_domain_validation_validate_spf_records_success():
assert db_domain.spf_verified is True assert db_domain.spf_verified is True
def test_custom_domain_validation_validate_spf_records_partner_domain_success():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id
expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)
domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{expected_domain}"])
res = validator.validate_spf_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True
def test_custom_domain_validation_validate_spf_cleans_verification_record():
dns_client = InMemoryDNSClient()
proton_partner_id = get_proton_partner().id
expected_domain = random_domain()
validator = CustomDomainValidation(
dkim_domain=random_domain(),
dns_client=dns_client,
partner_domains={proton_partner_id: expected_domain},
)
domain = create_custom_domain(random_domain())
domain.partner_id = proton_partner_id
Session.commit()
wrong_record = random_string()
dns_client.set_txt_record(
hostname=domain.domain,
txt_list=[wrong_record, validator.get_ownership_verification_record(domain)],
)
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 1
assert res.errors[0] == wrong_record
# validate_dmarc_records # validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure(): def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient() dns_client = InMemoryDNSClient()

View File

@ -3,6 +3,7 @@ from app.dns_utils import (
get_network_dns_client, get_network_dns_client,
is_mx_equivalent, is_mx_equivalent,
InMemoryDNSClient, InMemoryDNSClient,
MxRecord,
) )
from tests.utils import random_domain from tests.utils import random_domain
@ -17,8 +18,8 @@ def test_get_mx_domains():
assert len(r) > 0 assert len(r) > 0
for x in r: for x in r:
assert x[0] > 0 assert x.priority > 0
assert x[1] assert x.domain
def test_get_spf_domain(): def test_get_spf_domain():
@ -33,20 +34,32 @@ def test_get_txt_record():
def test_is_mx_equivalent(): def test_is_mx_equivalent():
assert is_mx_equivalent([], []) assert is_mx_equivalent([], [])
assert is_mx_equivalent([(1, "domain")], [(1, "domain")])
assert is_mx_equivalent( assert is_mx_equivalent(
[(10, "domain1"), (20, "domain2")], [(10, "domain1"), (20, "domain2")] mx_domains=[MxRecord(1, "domain")], ref_mx_domains=[MxRecord(1, "domain")]
) )
assert is_mx_equivalent( assert is_mx_equivalent(
[(5, "domain1"), (10, "domain2")], [(10, "domain1"), (20, "domain2")] mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
) )
assert is_mx_equivalent( assert is_mx_equivalent(
[(5, "domain1"), (10, "domain2"), (20, "domain3")], mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")],
[(10, "domain1"), (20, "domain2")], ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
)
assert is_mx_equivalent(
mx_domains=[
MxRecord(5, "domain1"),
MxRecord(10, "domain2"),
MxRecord(20, "domain3"),
],
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")],
) )
assert not is_mx_equivalent( assert not is_mx_equivalent(
[(5, "domain1"), (10, "domain2")], mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")], ref_mx_domains=[
MxRecord(10, "domain1"),
MxRecord(20, "domain2"),
MxRecord(20, "domain3"),
],
) )

View File

@ -90,12 +90,19 @@ def test_can_be_used_as_personal_email(flask_client):
assert not email_can_be_used_as_mailbox("ab@sl.local") assert not email_can_be_used_as_mailbox("ab@sl.local")
assert not email_can_be_used_as_mailbox("hey@d1.test") assert not email_can_be_used_as_mailbox("hey@d1.test")
# custom domain # custom domain as SL domain
domain = random_domain() domain = random_domain()
user = create_new_user() user = create_new_user()
CustomDomain.create(user_id=user.id, domain=domain, verified=True, commit=True) domain_obj = CustomDomain.create(
user_id=user.id, domain=domain, verified=True, is_sl_subdomain=True, flush=True
)
assert not email_can_be_used_as_mailbox(f"hey@{domain}") assert not email_can_be_used_as_mailbox(f"hey@{domain}")
# custom domain is NOT SL domain
domain_obj.is_sl_subdomain = False
Session.flush()
assert email_can_be_used_as_mailbox(f"hey@{domain}")
# disposable domain # disposable domain
disposable_domain = random_domain() disposable_domain = random_domain()
InvalidMailboxDomain.create(domain=disposable_domain, commit=True) InvalidMailboxDomain.create(domain=disposable_domain, commit=True)