Compare commits

...

2 Commits

Author SHA1 Message Date
a5801551d0 4.65.1
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 20m30s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2025-02-04 12:00:06 +00:00
9c2a35193c 4.64.4
All checks were successful
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 2m50s
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 20m44s
Build-Release-Image / Merge-Images (push) Successful in 26s
Build-Release-Image / Create-Release (push) Successful in 12s
Build-Release-Image / Notify (push) Successful in 17s
2025-01-28 12:00:06 +00:00
26 changed files with 730 additions and 305 deletions

View File

@ -1,21 +1,25 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, List from typing import Optional, List
import arrow import arrow
import sqlalchemy import sqlalchemy
from flask_admin import BaseView
from flask_admin.form import SecureForm
from flask_admin.model.template import EndpointLinkRowAction
from markupsafe import Markup
from app import models, s3, config
from flask import redirect, url_for, request, flash, Response from flask import redirect, url_for, request, flash, Response
from flask_admin import BaseView
from flask_admin import expose, AdminIndexView from flask_admin import expose, AdminIndexView
from flask_admin.actions import action from flask_admin.actions import action
from flask_admin.contrib import sqla from flask_admin.contrib import sqla
from flask_admin.form import SecureForm
from flask_admin.model.template import EndpointLinkRowAction
from flask_login import current_user from flask_login import current_user
from markupsafe import Markup
from app.custom_domain_validation import CustomDomainValidation, DomainValidationResult from app import models, s3, config
from app.custom_domain_validation import (
CustomDomainValidation,
DomainValidationResult,
ExpectedValidationRecords,
)
from app.db import Session from app.db import Session
from app.dns_utils import get_network_dns_client from app.dns_utils import get_network_dns_client
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
@ -929,13 +933,13 @@ class EmailSearchAdmin(BaseView):
class CustomDomainWithValidationData: class CustomDomainWithValidationData:
def __init__(self, domain: CustomDomain): def __init__(self, domain: CustomDomain):
self.domain: CustomDomain = domain self.domain: CustomDomain = domain
self.ownership_expected: Optional[str] = None self.ownership_expected: Optional[ExpectedValidationRecords] = None
self.ownership_validation: Optional[DomainValidationResult] = None self.ownership_validation: Optional[DomainValidationResult] = None
self.mx_expected: Optional[str] = None self.mx_expected: Optional[dict[int, ExpectedValidationRecords]] = None
self.mx_validation: Optional[DomainValidationResult] = None self.mx_validation: Optional[DomainValidationResult] = None
self.spf_expected: Optional[str] = None self.spf_expected: Optional[ExpectedValidationRecords] = None
self.spf_validation: Optional[DomainValidationResult] = None self.spf_validation: Optional[DomainValidationResult] = None
self.dkim_expected: {str: str} = {} self.dkim_expected: {str: ExpectedValidationRecords} = {}
self.dkim_validation: {str: str} = {} self.dkim_validation: {str: str} = {}
@ -990,7 +994,6 @@ class CustomDomainSearchResult:
custom_domain custom_domain
) )
out.domains.append(validation_data) out.domains.append(validation_data)
print(validation_data.dkim_expected, validation_data.dkim_validation)
return out return out
@ -1020,7 +1023,6 @@ class CustomDomainSearchAdmin(BaseView):
if cd is not None: if cd is not None:
user = cd.user user = cd.user
search = CustomDomainSearchResult.from_user(user) search = CustomDomainSearchResult.from_user(user)
print("NEW", search.domains)
return self.render( return self.render(
"admin/custom_domain_search.html", "admin/custom_domain_search.html",

View File

@ -36,6 +36,7 @@ def set_mailboxes_for_alias(
Mailbox.user_id == user_id, Mailbox.user_id == user_id,
Mailbox.verified == True, # noqa: E712 Mailbox.verified == True, # noqa: E712
) )
.order_by(Mailbox.id.asc())
.all() .all()
) )
if len(mailboxes) != len(mailbox_ids): if len(mailboxes) != len(mailbox_ids):

View File

@ -191,15 +191,8 @@ def get_alias_infos_with_pagination_v3(
q = q.order_by(Alias.email.desc()) q = q.order_by(Alias.email.desc())
else: else:
# default sorting # default sorting
latest_activity = case(
[
(Alias.created_at > EmailLog.created_at, Alias.created_at),
(Alias.created_at < EmailLog.created_at, EmailLog.created_at),
],
else_=Alias.created_at,
)
q = q.order_by(Alias.pinned.desc()) q = q.order_by(Alias.pinned.desc())
q = q.order_by(latest_activity.desc()) q = q.order_by(func.greatest(Alias.created_at, EmailLog.created_at).desc())
q = q.limit(page_limit).offset(page_id * page_size) q = q.limit(page_limit).offset(page_id * page_size)

View File

@ -9,7 +9,14 @@ from app.email_utils import send_email
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserPlanChanged from app.events.generated.event_pb2 import EventContent, UserPlanChanged
from app.log import LOG from app.log import LOG
from app.models import User, ManualSubscription, Coupon, LifetimeCoupon from app.models import (
User,
ManualSubscription,
Coupon,
LifetimeCoupon,
PartnerSubscription,
PartnerUser,
)
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@ -87,6 +94,16 @@ def redeem_coupon(coupon_code: str, user: User) -> Optional[Coupon]:
def redeem_lifetime_coupon(coupon_code: str, user: User) -> Optional[Coupon]: def redeem_lifetime_coupon(coupon_code: str, user: User) -> Optional[Coupon]:
if user.lifetime:
return None
partner_sub = (
Session.query(PartnerSubscription)
.join(PartnerUser, PartnerUser.id == PartnerSubscription.partner_user_id)
.filter(PartnerUser.user_id == user.id, PartnerSubscription.lifetime == True) # noqa: E712
.first()
)
if partner_sub is not None:
return None
coupon: LifetimeCoupon = LifetimeCoupon.get_by(code=coupon_code) coupon: LifetimeCoupon = LifetimeCoupon.get_by(code=coupon_code)
if not coupon: if not coupon:
return None return None

View File

@ -5,9 +5,7 @@ from app import config
from app.constants import DMARC_RECORD from app.constants import DMARC_RECORD
from app.db import Session from app.db import Session
from app.dns_utils import ( from app.dns_utils import (
MxRecord,
DNSClient, DNSClient,
is_mx_equivalent,
get_network_dns_client, get_network_dns_client,
) )
from app.models import CustomDomain from app.models import CustomDomain
@ -21,6 +19,39 @@ class DomainValidationResult:
errors: [str] errors: [str]
@dataclass
class ExpectedValidationRecords:
recommended: str
allowed: list[str]
def is_mx_equivalent(
mx_domains: dict[int, list[str]],
expected_mx_domains: dict[int, ExpectedValidationRecords],
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
mx_domains and ref_mx_domains are list of (priority, domain)
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
expected_prios = []
for prio in expected_mx_domains:
expected_prios.append(prio)
if len(expected_prios) != len(mx_domains):
return False
for prio_position, prio_value in enumerate(sorted(mx_domains.keys())):
for domain in mx_domains[prio_value]:
if domain not in expected_mx_domains[expected_prios[prio_position]].allowed:
return False
return True
class CustomDomainValidation: class CustomDomainValidation:
def __init__( def __init__(
self, self,
@ -37,59 +68,88 @@ class CustomDomainValidation:
or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
) )
def get_ownership_verification_record(self, domain: CustomDomain) -> str: def get_ownership_verification_record(
prefix = "sl" self, domain: CustomDomain
) -> ExpectedValidationRecords:
prefixes = ["sl"]
if ( if (
domain.partner_id is not None domain.partner_id is not None
and domain.partner_id in self._partner_domain_validation_prefixes and domain.partner_id in self._partner_domain_validation_prefixes
): ):
prefix = self._partner_domain_validation_prefixes[domain.partner_id] prefixes.insert(
0, self._partner_domain_validation_prefixes[domain.partner_id]
)
if not domain.ownership_txt_token: if not domain.ownership_txt_token:
domain.ownership_txt_token = random_string(30) domain.ownership_txt_token = random_string(30)
Session.commit() Session.commit()
return f"{prefix}-verification={domain.ownership_txt_token}" valid = [
f"{prefix}-verification={domain.ownership_txt_token}" for prefix in prefixes
]
return ExpectedValidationRecords(recommended=valid[0], allowed=valid)
def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]: def get_expected_mx_records(
records = [] self, domain: CustomDomain
) -> dict[int, ExpectedValidationRecords]:
records = {}
if domain.partner_id is not None and domain.partner_id in self._partner_domains: if domain.partner_id is not None and domain.partner_id in self._partner_domains:
domain = self._partner_domains[domain.partner_id] domain = self._partner_domains[domain.partner_id]
records.append(MxRecord(10, f"mx1.{domain}.")) records[10] = [f"mx1.{domain}."]
records.append(MxRecord(20, f"mx2.{domain}.")) records[20] = [f"mx2.{domain}."]
else: # Default ones
# Default ones for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY: if priority not in records:
records.append(MxRecord(priority, domain)) records[priority] = []
records[priority].append(domain)
return records return {
priority: ExpectedValidationRecords(
recommended=records[priority][0], allowed=records[priority]
)
for priority in records
}
def get_expected_spf_domain(self, domain: CustomDomain) -> str: def get_expected_spf_domain(
self, domain: CustomDomain
) -> ExpectedValidationRecords:
records = []
if domain.partner_id is not None and domain.partner_id in self._partner_domains: if domain.partner_id is not None and domain.partner_id in self._partner_domains:
return self._partner_domains[domain.partner_id] records.append(self._partner_domains[domain.partner_id])
else: else:
return config.EMAIL_DOMAIN records.append(config.EMAIL_DOMAIN)
return ExpectedValidationRecords(recommended=records[0], allowed=records)
def get_expected_spf_record(self, domain: CustomDomain) -> str: def get_expected_spf_record(self, domain: CustomDomain) -> str:
spf_domain = self.get_expected_spf_domain(domain) spf_domain = self.get_expected_spf_domain(domain)
return f"v=spf1 include:{spf_domain} ~all" return f"v=spf1 include:{spf_domain.recommended} ~all"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}: def get_dkim_records(
self, domain: CustomDomain
) -> {str: ExpectedValidationRecords}:
""" """
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not, Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
it will return the default ones or the partner ones. it will return the default ones or the partner ones.
""" """
# By default use the default domain # By default use the default domain
dkim_domain = self.dkim_domain dkim_domains = [self.dkim_domain]
if domain.partner_id is not None: if domain.partner_id is not None:
# Domain is from a partner. Retrieve the partner config and use that domain if exists # Domain is from a partner. Retrieve the partner config and use that domain as preferred if it exists
dkim_domain = self._partner_domains.get(domain.partner_id, dkim_domain) partner_domain = self._partner_domains.get(domain.partner_id, None)
if partner_domain is not None:
dkim_domains.insert(0, partner_domain)
return { output = {}
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}" for key in ("dkim", "dkim02", "dkim03"):
for key in ("dkim", "dkim02", "dkim03") records = [
} f"{key}._domainkey.{dkim_domain}" for dkim_domain in dkim_domains
]
output[f"{key}._domainkey"] = ExpectedValidationRecords(
recommended=records[0], allowed=records
)
return output
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]: def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
""" """
@ -102,7 +162,7 @@ class CustomDomainValidation:
for prefix, expected_record in expected_records.items(): for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}" custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = self._dns_client.get_cname_record(custom_record) dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record == expected_record: if dkim_record in expected_record.allowed:
correct_records[prefix] = custom_record correct_records[prefix] = custom_record
else: else:
invalid_records[custom_record] = dkim_record or "empty" invalid_records[custom_record] = dkim_record or "empty"
@ -138,11 +198,15 @@ class CustomDomainValidation:
Check if the custom_domain has added the ownership verification records Check if the custom_domain has added the ownership verification records
""" """
txt_records = self._dns_client.get_txt_record(custom_domain.domain) txt_records = self._dns_client.get_txt_record(custom_domain.domain)
expected_verification_record = self.get_ownership_verification_record( expected_verification_records = self.get_ownership_verification_record(
custom_domain custom_domain
) )
found = False
if expected_verification_record in txt_records: for verification_record in expected_verification_records.allowed:
if verification_record in txt_records:
found = True
break
if found:
custom_domain.ownership_verified = True custom_domain.ownership_verified = True
emit_user_audit_log( emit_user_audit_log(
user=custom_domain.user, user=custom_domain.user,
@ -161,10 +225,11 @@ class CustomDomainValidation:
expected_mx_records = self.get_expected_mx_records(custom_domain) expected_mx_records = self.get_expected_mx_records(custom_domain)
if not is_mx_equivalent(mx_domains, expected_mx_records): if not is_mx_equivalent(mx_domains, expected_mx_records):
return DomainValidationResult( errors = []
success=False, for prio in mx_domains:
errors=[f"{record.priority} {record.domain}" for record in mx_domains], for mx_domain in mx_domains[prio]:
) errors.append(f"{prio} {mx_domain}")
return DomainValidationResult(success=False, errors=errors)
else: else:
custom_domain.verified = True custom_domain.verified = True
emit_user_audit_log( emit_user_audit_log(
@ -180,7 +245,7 @@ class CustomDomainValidation:
) -> DomainValidationResult: ) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
expected_spf_domain = self.get_expected_spf_domain(custom_domain) expected_spf_domain = self.get_expected_spf_domain(custom_domain)
if expected_spf_domain in spf_domains: if len(set(expected_spf_domain.allowed).intersection(set(spf_domains))) > 0:
custom_domain.spf_verified = True custom_domain.spf_verified = True
emit_user_audit_log( emit_user_audit_log(
user=custom_domain.user, user=custom_domain.user,
@ -221,8 +286,8 @@ class CustomDomainValidation:
self, txt_records: List[str], custom_domain: CustomDomain self, txt_records: List[str], custom_domain: CustomDomain
) -> List[str]: ) -> List[str]:
final_records = [] final_records = []
verification_record = self.get_ownership_verification_record(custom_domain) verification_records = self.get_ownership_verification_record(custom_domain)
for record in txt_records: for record in txt_records:
if record != verification_record: if record not in verification_records.allowed:
final_records.append(record) final_records.append(record)
return final_records return final_records

View File

@ -5,8 +5,8 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField from wtforms import StringField, validators, IntegerField
from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.constants import DMARC_RECORD
from app.custom_domain_utils import delete_custom_domain, set_custom_domain_mailboxes from app.custom_domain_utils import delete_custom_domain, set_custom_domain_mailboxes
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
@ -137,7 +137,7 @@ def domain_detail_dns(custom_domain_id):
return render_template( return render_template(
"dashboard/domain_detail/dns.html", "dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
ownership_record=domain_validator.get_ownership_verification_record( ownership_records=domain_validator.get_ownership_verification_record(
custom_domain custom_domain
), ),
expected_mx_records=domain_validator.get_expected_mx_records(custom_domain), expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),

View File

@ -267,12 +267,13 @@ def cancel_mailbox_change_route(mailbox_id):
@dashboard_bp.route("/mailbox/confirm_change") @dashboard_bp.route("/mailbox/confirm_change")
@login_required
@limiter.limit("3/minute")
def mailbox_confirm_email_change_route(): def mailbox_confirm_email_change_route():
mailbox_id = request.args.get("mailbox_id") mailbox_id = request.args.get("mailbox_id")
code = request.args.get("code") code = request.args.get("code")
if code: if code:
print("HAS OCO", code)
try: try:
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code) mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
flash("Successfully changed mailbox email", "success") flash("Successfully changed mailbox email", "success")
@ -280,7 +281,6 @@ def mailbox_confirm_email_change_route():
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id)
) )
except mailbox_utils.MailboxError as e: except mailbox_utils.MailboxError as e:
print(e)
flash(f"Cannot verify mailbox: {e.msg}", "error") flash(f"Cannot verify mailbox: {e.msg}", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
else: else:

View File

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional from typing import List, Optional
import dns.resolver import dns.resolver
@ -9,42 +8,13 @@ from app.config import NAMESERVERS
_include_spf = "include:" _include_spf = "include:"
@dataclass
class MxRecord:
priority: int
domain: str
def is_mx_equivalent(
mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
) -> bool:
"""
Compare mx_domains with ref_mx_domains to see if they are equivalent.
mx_domains and ref_mx_domains are list of (priority, domain)
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda x: x.priority)
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)
if len(mx_domains) < len(ref_mx_domains):
return False
for actual, expected in zip(mx_domains, ref_mx_domains):
if actual.domain != expected.domain:
return False
return True
class DNSClient(ABC): class DNSClient(ABC):
@abstractmethod @abstractmethod
def get_cname_record(self, hostname: str) -> Optional[str]: def get_cname_record(self, hostname: str) -> Optional[str]:
pass pass
@abstractmethod @abstractmethod
def get_mx_domains(self, hostname: str) -> List[MxRecord]: def get_mx_domains(self, hostname: str) -> dict[int, list[str]]:
pass pass
def get_spf_domain(self, hostname: str) -> List[str]: def get_spf_domain(self, hostname: str) -> List[str]:
@ -88,21 +58,24 @@ class NetworkDNSClient(DNSClient):
except Exception: except Exception:
return None return None
def get_mx_domains(self, hostname: str) -> List[MxRecord]: def get_mx_domains(self, hostname: str) -> dict[int, list[str]]:
""" """
return list of (priority, domain name) sorted by priority (lowest priority first) return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end. domain name ends with a "." at the end.
""" """
ret = {}
try: try:
answers = self._resolver.resolve(hostname, "MX", search=True) answers = self._resolver.resolve(hostname, "MX", search=True)
ret = []
for a in answers: for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ") parts = record.split(" ")
ret.append(MxRecord(priority=int(parts[0]), domain=parts[1])) prio = int(parts[0])
return sorted(ret, key=lambda x: x.priority) if prio not in ret:
ret[prio] = []
ret[prio].append(parts[1])
except Exception: except Exception:
return [] pass
return ret
def get_txt_record(self, hostname: str) -> List[str]: def get_txt_record(self, hostname: str) -> List[str]:
try: try:
@ -119,14 +92,14 @@ class NetworkDNSClient(DNSClient):
class InMemoryDNSClient(DNSClient): class InMemoryDNSClient(DNSClient):
def __init__(self): def __init__(self):
self.cname_records: dict[str, Optional[str]] = {} self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[MxRecord]] = {} self.mx_records: dict[int, dict[int, list[str]]] = {}
self.spf_records: dict[str, List[str]] = {} self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {} self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str): def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[MxRecord]): def set_mx_records(self, hostname: str, mx_list: dict[int, list[str]]):
self.mx_records[hostname] = mx_list self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]): def set_txt_record(self, hostname: str, txt_list: List[str]):
@ -135,9 +108,8 @@ class InMemoryDNSClient(DNSClient):
def get_cname_record(self, hostname: str) -> Optional[str]: def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname) return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[MxRecord]: def get_mx_domains(self, hostname: str) -> dict[int, list[str]]:
mx_list = self.mx_records.get(hostname, []) return self.mx_records.get(hostname, {})
return sorted(mx_list, key=lambda x: x.priority)
def get_txt_record(self, hostname: str) -> List[str]: def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, []) return self.txt_records.get(hostname, [])
@ -147,5 +119,5 @@ def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS) return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> List[MxRecord]: def get_mx_domains(hostname: str) -> dict[int, list[str]]:
return get_network_dns_client().get_mx_domains(hostname) return get_network_dns_client().get_mx_domains(hostname)

View File

@ -657,7 +657,11 @@ def get_mx_domain_list(domain) -> [str]:
""" """
priority_domains = get_mx_domains(domain) priority_domains = get_mx_domains(domain)
return [d.domain[:-1] for d in priority_domains] mx_domains = []
for prio in priority_domains:
for domain in priority_domains[prio]:
mx_domains.append(domain[:-1])
return mx_domains
def personal_email_already_used(email_address: str) -> bool: def personal_email_already_used(email_address: str) -> bool:

View File

@ -2,8 +2,8 @@ import urllib
from email.header import Header from email.header import Header
from email.message import Message from email.message import Message
from app.email import headers
from app import config from app import config
from app.email import headers
from app.email_utils import add_or_replace_header, delete_header from app.email_utils import add_or_replace_header, delete_header
from app.handler.unsubscribe_encoder import ( from app.handler.unsubscribe_encoder import (
UnsubscribeEncoder, UnsubscribeEncoder,
@ -46,7 +46,11 @@ class UnsubscribeGenerator:
if start == -1 or end == -1 or start >= end: if start == -1 or end == -1 or start >= end:
continue continue
method = raw_method[start + 1 : end] method = raw_method[start + 1 : end]
url_data = urllib.parse.urlparse(method) try:
url_data = urllib.parse.urlparse(method)
except ValueError:
LOG.debug(f"Unsub has invalid method {method}. Ignoring.")
continue
if url_data.scheme == "mailto": if url_data.scheme == "mailto":
if url_data.path == config.UNSUBSCRIBER: if url_data.path == config.UNSUBSCRIBER:
LOG.debug( LOG.debug(

View File

@ -10,7 +10,7 @@ from app.config import (
# this format allows clickable link to code source in PyCharm # this format allows clickable link to code source in PyCharm
_log_format = ( _log_format = (
"%(asctime)s - %(name)s - %(levelname)s - %(process)d - " "%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(request_id)s"
'"%(pathname)s:%(lineno)d" - %(funcName)s() - %(message_id)s - %(message)s' '"%(pathname)s:%(lineno)d" - %(funcName)s() - %(message_id)s - %(message)s'
) )
_log_formatter = logging.Formatter(_log_format) _log_formatter = logging.Formatter(_log_format)
@ -37,6 +37,21 @@ class EmailHandlerFilter(logging.Filter):
return _MESSAGE_ID return _MESSAGE_ID
class RequestIdFilter(logging.Filter):
"""automatically add request-id to keep track of a request"""
def filter(self, record):
from flask import g, has_request_context
request_id = ""
if has_request_context():
ctx_request_id = getattr(g, "request_id")
if ctx_request_id:
request_id = f"{ctx_request_id} - "
record.request_id = request_id
return True
def _get_console_handler(): def _get_console_handler():
console_handler = logging.StreamHandler(sys.stdout) console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(_log_formatter) console_handler.setFormatter(_log_formatter)
@ -54,6 +69,7 @@ def _get_logger(name) -> logging.Logger:
logger.addHandler(_get_console_handler()) logger.addHandler(_get_console_handler())
logger.addFilter(EmailHandlerFilter()) logger.addFilter(EmailHandlerFilter())
logger.addFilter(RequestIdFilter())
# no propagation to avoid propagating to root logger # no propagation to avoid propagating to root logger
logger.propagate = False logger.propagate = False

View File

@ -32,7 +32,6 @@ from app import config, rate_limiter
from app import s3 from app import s3
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains from app.dns_utils import get_mx_domains
from app.errors import ( from app.errors import (
AliasInTrashError, AliasInTrashError,
DirectoryInTrashError, DirectoryInTrashError,
@ -362,7 +361,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
sa.Boolean, default=True, nullable=False, server_default="1" sa.Boolean, default=True, nullable=False, server_default="1"
) )
activated = sa.Column(sa.Boolean, default=False, nullable=False, index=True) activated = sa.Column(sa.Boolean, default=False, nullable=False)
# an account can be disabled if having harmful behavior # an account can be disabled if having harmful behavior
disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0")
@ -576,6 +575,12 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
"ix_users_default_alias_custom_domain_id", default_alias_custom_domain_id "ix_users_default_alias_custom_domain_id", default_alias_custom_domain_id
), ),
sa.Index("ix_users_profile_picture_id", profile_picture_id), sa.Index("ix_users_profile_picture_id", profile_picture_id),
sa.Index(
"idx_users_email_trgm",
"email",
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
),
) )
@property @property
@ -1924,13 +1929,16 @@ class Contact(Base, ModelMixin):
__table_args__ = ( __table_args__ = (
sa.UniqueConstraint("alias_id", "website_email", name="uq_contact"), sa.UniqueConstraint("alias_id", "website_email", name="uq_contact"),
sa.Index("ix_contact_user_id_id", "user_id", "id"),
) )
user_id = sa.Column( user_id = sa.Column(
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(User.id, ondelete="cascade"),
nullable=False,
) )
alias_id = sa.Column( alias_id = sa.Column(
sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(Alias.id, ondelete="cascade"),
nullable=False,
) )
name = sa.Column( name = sa.Column(
@ -2115,11 +2123,10 @@ class EmailLog(Base, ModelMixin):
Index("ix_email_log_mailbox_id", "mailbox_id"), Index("ix_email_log_mailbox_id", "mailbox_id"),
Index("ix_email_log_bounced_mailbox_id", "bounced_mailbox_id"), Index("ix_email_log_bounced_mailbox_id", "bounced_mailbox_id"),
Index("ix_email_log_refused_email_id", "refused_email_id"), Index("ix_email_log_refused_email_id", "refused_email_id"),
Index("ix_email_log_user_id_email_log_id", "user_id", "id"),
) )
user_id = sa.Column( user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False)
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
)
contact_id = sa.Column( contact_id = sa.Column(
sa.ForeignKey(Contact.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(Contact.id, ondelete="cascade"), nullable=False, index=True
) )
@ -2395,7 +2402,8 @@ class AliasUsedOn(Base, ModelMixin):
) )
alias_id = sa.Column( alias_id = sa.Column(
sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(Alias.id, ondelete="cascade"),
nullable=False,
) )
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False)
@ -2418,10 +2426,7 @@ class ApiKey(Base, ModelMixin):
user = orm.relationship(User) user = orm.relationship(User)
__table_args__ = ( __table_args__ = (sa.Index("ix_api_key_user_id", "user_id"),)
sa.Index("ix_api_key_code", "code"),
sa.Index("ix_api_key_user_id", "user_id"),
)
@classmethod @classmethod
def create(cls, user_id, name=None, **kwargs): def create(cls, user_id, name=None, **kwargs):
@ -2581,7 +2586,6 @@ class AutoCreateRule(Base, ModelMixin):
sa.UniqueConstraint( sa.UniqueConstraint(
"custom_domain_id", "order", name="uq_auto_create_rule_order" "custom_domain_id", "order", name="uq_auto_create_rule_order"
), ),
sa.Index("ix_auto_create_rule_custom_domain_id", "custom_domain_id"),
) )
custom_domain_id = sa.Column( custom_domain_id = sa.Column(
@ -2764,7 +2768,6 @@ class Job(Base, ModelMixin):
nullable=False, nullable=False,
server_default=str(JobState.ready.value), server_default=str(JobState.ready.value),
default=JobState.ready.value, default=JobState.ready.value,
index=True,
) )
attempts = sa.Column(sa.Integer, nullable=False, server_default="0", default=0) attempts = sa.Column(sa.Integer, nullable=False, server_default="0", default=0)
taken_at = sa.Column(ArrowType, nullable=True) taken_at = sa.Column(ArrowType, nullable=True)
@ -2777,9 +2780,7 @@ class Job(Base, ModelMixin):
class Mailbox(Base, ModelMixin): class Mailbox(Base, ModelMixin):
__tablename__ = "mailbox" __tablename__ = "mailbox"
user_id = sa.Column( user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False)
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
)
email = sa.Column(sa.String(256), nullable=False, index=True) email = sa.Column(sa.String(256), nullable=False, index=True)
verified = sa.Column(sa.Boolean, default=False, nullable=False) verified = sa.Column(sa.Boolean, default=False, nullable=False)
force_spf = sa.Column(sa.Boolean, default=True, server_default="1", nullable=False) force_spf = sa.Column(sa.Boolean, default=True, server_default="1", nullable=False)
@ -2808,6 +2809,13 @@ class Mailbox(Base, ModelMixin):
__table_args__ = ( __table_args__ = (
sa.UniqueConstraint("user_id", "email", name="uq_mailbox_user"), sa.UniqueConstraint("user_id", "email", name="uq_mailbox_user"),
sa.Index("ix_mailbox_pgp_finger_print", "pgp_finger_print"), sa.Index("ix_mailbox_pgp_finger_print", "pgp_finger_print"),
# index on email column using pg_trgm
Index(
"ix_mailbox_email_trgm_idx",
"email",
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
),
) )
user = orm.relationship(User, foreign_keys=[user_id]) user = orm.relationship(User, foreign_keys=[user_id])
@ -3010,7 +3018,11 @@ class SentAlert(Base, ModelMixin):
to_email = sa.Column(sa.String(256), nullable=False) to_email = sa.Column(sa.String(256), nullable=False)
alert_type = sa.Column(sa.String(256), nullable=False) alert_type = sa.Column(sa.String(256), nullable=False)
__table_args__ = (sa.Index("ix_sent_alert_user_id", "user_id"),) __table_args__ = (
sa.Index("ix_sent_alert_user_id", "user_id"),
sa.Index("ix_sent_alert_to_email", "to_email"),
sa.Index("ix_sent_alert_alert_type", "alert_type"),
)
class AliasMailbox(Base, ModelMixin): class AliasMailbox(Base, ModelMixin):
@ -3020,7 +3032,8 @@ class AliasMailbox(Base, ModelMixin):
) )
alias_id = sa.Column( alias_id = sa.Column(
sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(Alias.id, ondelete="cascade"),
nullable=False,
) )
mailbox_id = sa.Column( mailbox_id = sa.Column(
sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False, index=True
@ -3035,7 +3048,8 @@ class AliasHibp(Base, ModelMixin):
__table_args__ = (sa.UniqueConstraint("alias_id", "hibp_id", name="uq_alias_hibp"),) __table_args__ = (sa.UniqueConstraint("alias_id", "hibp_id", name="uq_alias_hibp"),)
alias_id = sa.Column( alias_id = sa.Column(
sa.Integer(), sa.ForeignKey("alias.id", ondelete="cascade"), index=True sa.Integer(),
sa.ForeignKey("alias.id", ondelete="cascade"),
) )
hibp_id = sa.Column( hibp_id = sa.Column(
sa.Integer(), sa.ForeignKey("hibp.id", ondelete="cascade"), index=True sa.Integer(), sa.ForeignKey("hibp.id", ondelete="cascade"), index=True
@ -3751,7 +3765,8 @@ class PartnerUser(Base, ModelMixin):
index=True, index=True,
) )
partner_id = sa.Column( partner_id = sa.Column(
sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True sa.ForeignKey("partner.id", ondelete="cascade"),
nullable=False,
) )
external_user_id = sa.Column(sa.String(128), unique=False, nullable=False) external_user_id = sa.Column(sa.String(128), unique=False, nullable=False)
partner_email = sa.Column(sa.String(255), unique=False, nullable=True) partner_email = sa.Column(sa.String(255), unique=False, nullable=True)

6
app/app/request_utils.py Normal file
View File

@ -0,0 +1,6 @@
from random import randbytes
from base64 import b64encode
def generate_request_id() -> str:
return b64encode(randbytes(6)).decode()

View File

@ -14,9 +14,9 @@ from sqlalchemy.sql import Insert, text
from app import s3, config from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation, is_mx_equivalent
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains, is_mx_equivalent from app.dns_utils import get_mx_domains
from app.email_utils import ( from app.email_utils import (
send_email, send_email,
send_trial_end_soon_email, send_trial_end_soon_email,

View File

@ -47,6 +47,11 @@ from typing import List, Tuple, Optional
import newrelic.agent import newrelic.agent
from aiosmtpd.controller import Controller from aiosmtpd.controller import Controller
from aiosmtpd.smtp import Envelope from aiosmtpd.smtp import Envelope
from email_validator import validate_email, EmailNotValidError
from flanker.addresslib import address
from flanker.addresslib.address import EmailAddress
from sqlalchemy.exc import IntegrityError
from app import pgp_utils, s3, config, contact_utils from app import pgp_utils, s3, config, contact_utils
from app.alias_utils import ( from app.alias_utils import (
try_auto_create, try_auto_create,
@ -169,12 +174,8 @@ from app.pgp_utils import (
load_public_key_and_check, load_public_key_and_check,
) )
from app.utils import sanitize_email from app.utils import sanitize_email
from email_validator import validate_email, EmailNotValidError
from flanker.addresslib import address
from flanker.addresslib.address import EmailAddress
from init_app import load_pgp_public_keys from init_app import load_pgp_public_keys
from server import create_light_app from server import create_light_app
from sqlalchemy.exc import IntegrityError
def get_or_create_contact( def get_or_create_contact(
@ -601,7 +602,11 @@ def handle_forward(envelope, msg: Message, rcpt_to: str) -> List[Tuple[bool, str
for reply_to in reply_to_header_contents.split(",") for reply_to in reply_to_header_contents.split(",")
if reply_to.strip() if reply_to.strip()
]: ]:
reply_to_name, reply_to_email = parse_full_address(reply_to) try:
reply_to_name, reply_to_email = parse_full_address(reply_to)
except ValueError:
LOG.d(f"Could not parse reply-to address {reply_to}")
continue
if reply_to_email == alias.email: if reply_to_email == alias.email:
LOG.i("Reply-to same as alias %s", alias) LOG.i("Reply-to same as alias %s", alias)
else: else:

View File

@ -0,0 +1,91 @@
"""index cleanup
Revision ID: d3ff8848c930
Revises: 085f77996ce3
Create Date: 2025-01-30 15:00:02.995813
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d3ff8848c930"
down_revision = "085f77996ce3"
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.drop_index("ix_alias_hibp_alias_id", table_name="alias_hibp")
op.drop_index("ix_alias_mailbox_alias_id", table_name="alias_mailbox")
op.drop_index("ix_alias_used_on_alias_id", table_name="alias_used_on")
op.drop_index("ix_api_key_code", table_name="api_key")
op.drop_index(
"ix_auto_create_rule_custom_domain_id", table_name="auto_create_rule"
)
op.drop_index("ix_contact_alias_id", table_name="contact")
op.create_index(
"ix_email_log_user_id_email_log_id",
"email_log",
["user_id", "id"],
unique=False,
)
op.drop_index("ix_job_state", table_name="job")
op.create_index(
"ix_mailbox_email_trgm_idx",
"mailbox",
["email"],
unique=False,
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
)
op.drop_index("ix_partner_user_partner_id", table_name="partner_user")
op.create_index(
"ix_sent_alert_alert_type", "sent_alert", ["alert_type"], unique=False
)
op.create_index(
"ix_sent_alert_to_email", "sent_alert", ["to_email"], unique=False
)
op.create_index(
"idx_users_email_trgm",
"users",
["email"],
unique=False,
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
)
op.drop_index("ix_users_activated", table_name="users")
op.drop_index("ix_mailbox_user_id", table_name="users")
def downgrade():
with op.get_context().autocommit_block():
op.create_index("ix_users_activated", "users", ["activated"], unique=False)
op.drop_index("idx_users_email_trgm", table_name="users")
op.drop_index("ix_sent_alert_to_email", table_name="sent_alert")
op.drop_index("ix_sent_alert_alert_type", table_name="sent_alert")
op.create_index(
"ix_partner_user_partner_id", "partner_user", ["partner_id"], unique=False
)
op.drop_index("ix_mailbox_email_trgm_idx", table_name="mailbox")
op.create_index("ix_job_state", "job", ["state"], unique=False)
op.drop_index("ix_email_log_user_id_email_log_id", table_name="email_log")
op.create_index("ix_contact_alias_id", "contact", ["alias_id"], unique=False)
op.create_index(
"ix_auto_create_rule_custom_domain_id",
"auto_create_rule",
["custom_domain_id"],
unique=False,
)
op.create_index("ix_api_key_code", "api_key", ["code"], unique=False)
op.create_index(
"ix_alias_used_on_alias_id", "alias_used_on", ["alias_id"], unique=False
)
op.create_index(
"ix_alias_mailbox_alias_id", "alias_mailbox", ["alias_id"], unique=False
)
op.create_index(
"ix_alias_hibp_alias_id", "alias_hibp", ["alias_id"], unique=False
)
op.create_index("ix_mailbox_user_id", "users", ["user_id"], unique=False)

View File

@ -0,0 +1,23 @@
"""index cleanup
Revision ID: 97edba8794f8
Revises: d3ff8848c930
Create Date: 2025-01-31 14:42:22.590597
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '97edba8794f8'
down_revision = 'd3ff8848c930'
branch_labels = None
depends_on = None
def upgrade():
op.drop_index('ix_email_log_user_id', table_name='email_log')
def downgrade():
op.create_index('ix_email_log_user_id', 'email_log', ['user_id'], unique=False)

View File

@ -0,0 +1,27 @@
"""contact index
Revision ID: 20e7d3ca289a
Revises: 97edba8794f8
Create Date: 2025-02-03 16:52:06.775032
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '20e7d3ca289a'
down_revision = '97edba8794f8'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_contact_user_id_id', 'contact', ['user_id', 'id'], unique=False)
op.drop_index('ix_contact_user_id', table_name='contact')
def downgrade():
with op.get_context().autocommit_block():
op.create_index('ix_contact_user_id', 'contact', ['user_id'], unique=False)
op.drop_index('ix_contact_user_id_id', table_name='contact')

View File

@ -106,6 +106,7 @@ from app.payments.coinbase import setup_coinbase_commerce
from app.payments.paddle import setup_paddle_callback from app.payments.paddle import setup_paddle_callback
from app.phone.base import phone_bp from app.phone.base import phone_bp
from app.redis_services import initialize_redis_services from app.redis_services import initialize_redis_services
from app.request_utils import generate_request_id
from app.sentry_utils import sentry_before_send from app.sentry_utils import sentry_before_send
if SENTRY_DSN: if SENTRY_DSN:
@ -263,6 +264,7 @@ def set_index_page(app):
and not request.path.startswith("/_debug_toolbar") and not request.path.startswith("/_debug_toolbar")
): ):
g.start_time = time.time() g.start_time = time.time()
g.request_id = generate_request_id()
# to handle the referral url that has ?slref=code part # to handle the referral url that has ?slref=code part
ref_code = request.args.get("slref") ref_code = request.args.get("slref")

View File

@ -1,118 +1,181 @@
{% extends 'admin/master.html' %} {% extends 'admin/master.html' %}
{% block head_css %}
{{ super() }}
<style>
.card-shadow {
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
border-radius: 8px;
}
.domain-title {
background-color: #007bff;
color: white;
padding: 10px;
border-radius: 8px 8px 0 0;
}
.status-icon {
font-size: 1.2em;
}
</style>
{% endblock %}
{% macro show_user(user) -%} {% macro show_user(user) -%}
<h4>User <a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a> with ID {{ user.id }}.</h4> <h4>
<table class="table"> User <a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a> with ID {{ user.id }}.
<thead> </h4>
<tr> <table class="table">
<th scope="col">User ID</th> <thead>
<th scope="col">Email</th> <tr>
<th scope="col">Verified</th> <th scope="col">User ID</th>
<th scope="col">Status</th> <th scope="col">Email</th>
<th scope="col">Paid</th> <th scope="col">Verified</th>
<th scope="col">Premium</th> <th scope="col">Status</th>
</tr> <th scope="col">Paid</th>
</thead> <th scope="col">Premium</th>
<tbody> </tr>
<tr> </thead>
<td>{{ user.id }}</td> <tbody>
<td> <tr>
<a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a> <td>{{ user.id }}</td>
</td> <td>
{% if user.activated %} <a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a>
</td>
{% if user.activated %}
<td class="text-success">Activated</td> <td class="text-success">Activated</td>
{% else %} {% else %}
<td class="text-warning">Pending</td> <td class="text-warning">Pending</td>
{% endif %} {% endif %}
{% if user.disabled %} {% if user.disabled %}
<td class="text-danger">Disabled</td> <td class="text-danger">Disabled</td>
{% else %} {% else %}
<td class="text-success">Enabled</td> <td class="text-success">Enabled</td>
{% endif %} {% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td> <td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ "yes" if user.is_premium() else "No" }}</td> <td>{{ "yes" if user.is_premium() else "No" }}</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>
{%- endmacro %} {%- endmacro %}
{% macro show_verification(title, expected, errors) -%} {% macro show_verification(title, expected, errors) -%}
{% if not expected %} {% if not expected %}
<h4 class="mb-3">{{ title }} <span class="text-success">Verified</span></h4>
{% else %} <li class="list-group-item d-flex justify-content-between align-items-center">
<h4 class="mb-3">{{ title }}</h4> <h5>{{ title }}</h5>
<p>Expected</p> <span class="text-success status-icon"><i class="fa fa-check-circle"></i></span>
<p>{{expected}}</p> </li>
<p>Current response</p> {% else %}
<ul class="list-group"> <li class="list-group-item">
{% for error in errors %} <h5>{{ title }}</h5>
<li class="list-group-item">{{ error }}</li> <p>
{% endfor %} <strong>Expected:</strong> {{ expected.recommended }}
</p>
<p>
<strong>Allowed:</strong>
<ul>
{% for expected_record in expected.allowed %}<li>{{ expected_record }}</li>{% endfor %}
</ul> </ul>
{% endif %} </p>
<p>
<strong>Current response:</strong>
</p>
{% for error in errors %}
<ul class="list-group">
<li class="list-group-item">{{ error }}</li>
</ul>
{% endfor %}
</li>
{% endif %}
{%- endmacro %} {%- endmacro %}
{% macro show_mx_verification(title, expected, errors) -%}
{% if not expected %}
<li class="list-group-item d-flex justify-content-between align-items-center">
<h5>{{ title }}</h5>
<span class="text-success status-icon"><i class="fa fa-check-circle"></i></span>
</li>
{% else %}
<li class="list-group-item">
<h5>{{ title }}</h5>
<ul>
<li class="list-group-item">
{% for prio in expected %}
{% macro show_domain(domain_with_data) -%} <p>
<h3>Domain {{ domain_with_data.domain.domain }}</h3> <strong>Priority {{ prio }}:</strong> {{ expected[prio].recommended }}
{% set domain = domain_with_data.domain %} </p>
<ul class="list-group"> <p>
<li class="list-group-item"> <strong>Allowed:</strong>
{{ show_verification("Ownership", domain_with_data.ownership_expected, domain_with_data.ownership_validation.errors) }} <ul>
</li> {% for expected_record in expected[prio].allowed %}<li>{{ expected_record }}</li>{% endfor %}
<li class="list-group-item"> </ul>
{{ show_verification("MX", domain_with_data.mx_expected, domain_with_data.mx_validation.errors) }} </p>
</li> <p>
<li class="list-group-item"> <strong>Current response:</strong>
{{ show_verification("SPF", domain_with_data.spf_expected, domain_with_data.spf_validation.errors) }} </p>
</li> {% for error in errors %}
{% for dkim_domain in domain_with_data.dkim_expected %}
<li class="list-group-item"> <ul class="list-group">
{{ show_verification("DKIM {}.{}".format(dkim_domain, domain.domain), domain_with_data.dkim_expected[dkim_domain], [domain_with_data.dkim_validation.get(dkim_domain+"."+domain.domain,'')]) }} <li class="list-group-item">{{ error }}</li>
</li> </ul>
{% endfor %}
</li>
{% endfor %} {% endfor %}
</ul> </ul>
</li>
{% endif %}
{%- endmacro %} {%- endmacro %}
{% macro show_domain(domain_with_data) -%}
<div class="col-md-3 mb-4">
<div class="card card-shadow">
<div class="domain-title text-center">
<h4>Domain {{ domain_with_data.domain.domain }}</h4>
</div>
<div class="card-body">
{% set domain = domain_with_data.domain %}
<ul class="list-group">
{{ show_verification("Ownership", domain_with_data.ownership_expected, domain_with_data.ownership_validation.errors) }}
{{ show_mx_verification("MX", domain_with_data.mx_expected, domain_with_data.mx_validation.errors) }}
{{ show_verification("SPF", domain_with_data.spf_expected, domain_with_data.spf_validation.errors) }}
{% for dkim_domain in domain_with_data.dkim_expected %}
{{ show_verification("DKIM {}.{}".format(dkim_domain, domain.domain) , domain_with_data.dkim_expected[dkim_domain], [domain_with_data.dkim_validation.get(dkim_domain+"."+domain.domain,'')]) }}
{% endfor %}
</ul>
</div>
</div>
</div>
{%- endmacro %}
{% block body %} {% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="get">
<div class="form-group">
<label for="email">User or domain to search:</label>
<input type="text"
class="form-control"
name="user"
value="{{ query or '' }}" />
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
{% if data.no_match and query %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ query }}</div>
{% endif %}
{% if data.user %}
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="get"> <h3 class="mb-3">Found User {{ data.user.email }}</h3>
<div class="form-group"> {{ show_user(data.user) }}
<label for="email">User or domain to search:</label>
<input type="text"
class="form-control"
name="user"
value="{{ query or '' }}" />
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div> </div>
{% if data.no_match and query %} {% endif %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning" <div class="row mt-4">
role="alert">No user, alias or mailbox found for {{ query }}</div> {% for domain_with_data in data.domains %}{{ show_domain(domain_with_data) }}{% endfor %}
{% endif %} </div>
{% if data.user %} </div>
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found User {{ data.user.email }}</h3>
{{ show_user(data.user) }}
</div>
{% endif %}
<div class="d-flex">
{% for domain_with_data in data.domains %}
<div class="card m-2 border-dark" style="width: 30rem;">
<div class="card-body">
{{ show_domain(domain_with_data) }}
</div>
</div>
{% endfor %}
</div>
{% endblock %} {% endblock %}

View File

@ -38,7 +38,7 @@
Value: <em data-toggle="tooltip" Value: <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em> data-clipboard-text="{{ ownership_records.recommended }}">{{ ownership_records.recommended }}</em>
</div> </div>
<form method="post" action="#ownership-form"> <form method="post" action="#ownership-form">
{{ csrf_form.csrf_token }} {{ csrf_form.csrf_token }}
@ -91,7 +91,7 @@
<br /> <br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain. Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div> </div>
{% for record in expected_mx_records %} {% for prio in expected_mx_records %}
<div class="mb-3 p-3 dns-record"> <div class="mb-3 p-3 dns-record">
Record: MX Record: MX
@ -99,12 +99,12 @@
Domain: {{ custom_domain.domain }} or Domain: {{ custom_domain.domain }} or
<b>@</b> <b>@</b>
<br /> <br />
Priority: {{ record.priority }} Priority: {{ prio }}
<br /> <br />
Target: <em data-toggle="tooltip" Target: <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em> data-clipboard-text="{{ expected_mx_records[prio].recommended }}">{{ expected_mx_records[prio].recommended }}</em>
</div> </div>
{% endfor %} {% endfor %}
<form method="post" action="#mx-form"> <form method="post" action="#mx-form">
@ -251,8 +251,8 @@
<em data-toggle="tooltip" <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ dkim_cname_value }}." data-clipboard-text="{{ dkim_cname_value.recommended }}."
style="overflow-wrap: break-word">{{ dkim_cname_value }}.</em> style="overflow-wrap: break-word">{{ dkim_cname_value.recommended }}.</em>
</div> </div>
{% endfor %} {% endfor %}
<div class="alert alert-info"> <div class="alert alert-info">

View File

@ -59,11 +59,17 @@ def test_set_mailboxes_for_alias_mailbox_success():
assert db_alias is not None assert db_alias is not None
assert db_alias.mailbox_id == mb1.id assert db_alias.mailbox_id == mb1.id
alias_mailboxes = AliasMailbox.filter_by(alias_id=alias.id).all() alias_mailboxes = (
AliasMailbox.filter_by(alias_id=alias.id).order_by(AliasMailbox.id.asc()).all()
)
assert len(alias_mailboxes) == 1 assert len(alias_mailboxes) == 1
assert alias_mailboxes[0].mailbox_id == mb2.id assert alias_mailboxes[0].mailbox_id == mb2.id
audit_logs = AliasAuditLog.filter_by(alias_id=alias.id).all() audit_logs = (
AliasAuditLog.filter_by(alias_id=alias.id)
.order_by(AliasAuditLog.id.asc())
.all()
)
assert len(audit_logs) == 2 assert len(audit_logs) == 2
assert audit_logs[0].action == AliasAuditLogAction.CreateAlias.value assert audit_logs[0].action == AliasAuditLogAction.CreateAlias.value
assert audit_logs[1].action == AliasAuditLogAction.ChangedMailboxes.value assert audit_logs[1].action == AliasAuditLogAction.ChangedMailboxes.value

View File

@ -6,6 +6,7 @@ from app.coupon_utils import (
CouponUserCannotRedeemError, CouponUserCannotRedeemError,
redeem_lifetime_coupon, redeem_lifetime_coupon,
) )
from app.db import Session
from app.models import ( from app.models import (
Coupon, Coupon,
Subscription, Subscription,
@ -14,8 +15,11 @@ from app.models import (
CoinbaseSubscription, CoinbaseSubscription,
LifetimeCoupon, LifetimeCoupon,
User, User,
PartnerSubscription,
PartnerUser,
) )
from tests.utils import create_new_user, random_string from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_string, random_email
def test_use_coupon(): def test_use_coupon():
@ -157,3 +161,55 @@ def test_used_lifetime_coupon():
user = User.get(user.id) user = User.get(user.id)
assert not user.lifetime assert not user.lifetime
assert not user.paid_lifetime assert not user.paid_lifetime
def test_used_lifetime_coupon_with_lifetime_user():
user = create_new_user()
user.lifetime = True
code = random_string(10)
LifetimeCoupon.create(code=code, nb_used=10, paid=True)
coupon = redeem_lifetime_coupon(code, user)
assert coupon is None
def test_used_lifetime_coupon_with_lifetime_partner():
email = random_email()
user = User.create(email=email)
pu = PartnerUser.create(
user_id=user.id,
partner_id=get_proton_partner().id,
partner_email=email,
external_user_id=random_string(10),
flush=True,
)
PartnerSubscription.create(
partner_user_id=pu.id, end_at=arrow.utcnow().shift(years=10), lifetime=True
)
Session.flush()
code = random_string(10)
LifetimeCoupon.create(code=code, nb_used=10, paid=True)
coupon = redeem_lifetime_coupon(code, user)
assert coupon is None
def test_used_lifetime_coupon_with_partner_sub():
email = random_email()
user = User.create(email=email)
pu = PartnerUser.create(
user_id=user.id,
partner_id=get_proton_partner().id,
partner_email=email,
external_user_id=random_string(10),
flush=True,
)
PartnerSubscription.create(
partner_user_id=pu.id, end_at=arrow.utcnow().shift(years=10)
)
Session.flush()
code = random_string(10)
LifetimeCoupon.create(code=code, nb_used=10, paid=True)
coupon = redeem_lifetime_coupon(code, user)
assert coupon
user = User.get(user.id)
assert user.lifetime
assert user.paid_lifetime

View File

@ -13,8 +13,8 @@ from app.custom_domain_utils import (
) )
from app.db import Session from app.db import Session
from app.models import User, CustomDomain, Mailbox, DomainMailbox from app.models import User, CustomDomain, Mailbox, DomainMailbox
from tests.utils import get_proton_partner, random_email
from tests.utils import create_new_user, random_string, random_domain from tests.utils import create_new_user, random_string, random_domain
from tests.utils import get_proton_partner, random_email
user: Optional[User] = None user: Optional[User] = None
@ -195,7 +195,11 @@ def test_set_custom_domain_mailboxes_success():
assert res.success is True assert res.success is True
assert res.reason is None assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all() domain_mailboxes = (
DomainMailbox.filter_by(domain_id=domain.id)
.order_by(DomainMailbox.mailbox_id.asc())
.all()
)
assert len(domain_mailboxes) == 2 assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
@ -219,7 +223,11 @@ def test_set_custom_domain_mailboxes_set_twice():
assert res.success is True assert res.success is True
assert res.reason is None assert res.reason is None
domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all() domain_mailboxes = (
DomainMailbox.filter_by(domain_id=domain.id)
.order_by(DomainMailbox.mailbox_id.asc())
.all()
)
assert len(domain_mailboxes) == 2 assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id

View File

@ -4,8 +4,8 @@ from app import config
from app.constants import DMARC_RECORD from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation
from app.db import Session from app.db import Session
from app.dns_utils import InMemoryDNSClient
from app.models import CustomDomain, User from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient, MxRecord
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import random_string from app.utils import random_string
from tests.utils import create_new_user, random_domain from tests.utils import create_new_user, random_domain
@ -33,9 +33,12 @@ def test_custom_domain_validation_get_dkim_records():
records = validator.get_dkim_records(custom_domain) records = validator.get_dkim_records(custom_domain)
assert len(records) == 3 assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}" assert records["dkim02._domainkey"].recommended == f"dkim02._domainkey.{domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}" assert records["dkim02._domainkey"].allowed == [f"dkim02._domainkey.{domain}"]
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}" assert records["dkim03._domainkey"].recommended == f"dkim03._domainkey.{domain}"
assert records["dkim03._domainkey"].allowed == [f"dkim03._domainkey.{domain}"]
assert records["dkim._domainkey"].recommended == f"dkim._domainkey.{domain}"
assert records["dkim._domainkey"].allowed == [f"dkim._domainkey.{domain}"]
def test_custom_domain_validation_get_dkim_records_for_partner(): def test_custom_domain_validation_get_dkim_records_for_partner():
@ -53,9 +56,25 @@ def test_custom_domain_validation_get_dkim_records_for_partner():
records = validator.get_dkim_records(custom_domain) records = validator.get_dkim_records(custom_domain)
assert len(records) == 3 assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}" assert (
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}" records["dkim02._domainkey"].recommended == f"dkim02._domainkey.{dkim_domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}" )
assert records["dkim02._domainkey"].allowed == [
f"dkim02._domainkey.{dkim_domain}",
f"dkim02._domainkey.{domain}",
]
assert (
records["dkim03._domainkey"].recommended == f"dkim03._domainkey.{dkim_domain}"
)
assert records["dkim03._domainkey"].allowed == [
f"dkim03._domainkey.{dkim_domain}",
f"dkim03._domainkey.{domain}",
]
assert records["dkim._domainkey"].recommended == f"dkim._domainkey.{dkim_domain}"
assert records["dkim._domainkey"].allowed == [
f"dkim._domainkey.{dkim_domain}",
f"dkim._domainkey.{domain}",
]
# get_expected_mx_records # get_expected_mx_records
@ -75,8 +94,8 @@ def test_custom_domain_validation_get_expected_mx_records_regular_domain():
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY) assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)): for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i] config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0] assert records[config_record[0]].recommended == config_record[1]
assert records[i].domain == config_record[1] assert records[config_record[0]].allowed == [config_record[1]]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner(): def test_custom_domain_validation_get_expected_mx_records_domain_from_partner():
@ -89,14 +108,15 @@ def test_custom_domain_validation_get_expected_mx_records_domain_from_partner():
dkim_domain = random_domain() dkim_domain = random_domain()
validator = CustomDomainValidation(dkim_domain) validator = CustomDomainValidation(dkim_domain)
records = validator.get_expected_mx_records(custom_domain) expected_records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain but there is no custom config for partner, default records # As the domain is a partner_domain but there is no custom config for partner, default records
# should be used # should be used
assert len(records) == len(config.EMAIL_SERVERS_WITH_PRIORITY) assert len(expected_records) == len(config.EMAIL_SERVERS_WITH_PRIORITY)
for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)): for i in range(len(config.EMAIL_SERVERS_WITH_PRIORITY)):
config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i] config_record = config.EMAIL_SERVERS_WITH_PRIORITY[i]
assert records[i].priority == config_record[0] expected = expected_records[config_record[0]]
assert records[i].domain == config_record[1] assert expected.recommended == config_record[1]
assert expected.allowed == [config_record[1]]
def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config(): def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_with_custom_config():
@ -112,15 +132,21 @@ def test_custom_domain_validation_get_expected_mx_records_domain_from_partner_wi
validator = CustomDomainValidation( validator = CustomDomainValidation(
dkim_domain, partner_domains={partner_id: expected_mx_domain} dkim_domain, partner_domains={partner_id: expected_mx_domain}
) )
records = validator.get_expected_mx_records(custom_domain) expected_records = validator.get_expected_mx_records(custom_domain)
# As the domain is a partner_domain and there is a custom config for partner, partner records # As the domain is a partner_domain and there is a custom config for partner, partner records
# should be used # should be used
assert len(records) == 2 assert len(expected_records) == 2
sl_domains = config.EMAIL_SERVERS_WITH_PRIORITY
assert records[0].priority == 10 assert expected_records[10].recommended == f"mx1.{expected_mx_domain}."
assert records[0].domain == f"mx1.{expected_mx_domain}." expected = [f"mx1.{expected_mx_domain}."]
assert records[1].priority == 20 expected.extend([sl_dom[1] for sl_dom in sl_domains if sl_dom[0] == 10])
assert records[1].domain == f"mx2.{expected_mx_domain}." assert expected_records[10].allowed == expected
assert expected_records[20].recommended == f"mx2.{expected_mx_domain}."
expected = [f"mx2.{expected_mx_domain}."]
expected.extend([sl_dom[1] for sl_dom in sl_domains if sl_dom[0] == 20])
assert expected_records[20].allowed == expected
# get_expected_spf_records # get_expected_spf_records
@ -309,7 +335,7 @@ def test_custom_domain_validation_validate_ownership_success():
domain = create_custom_domain(random_domain()) domain = create_custom_domain(random_domain())
dns_client.set_txt_record( dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)] domain.domain, validator.get_ownership_verification_record(domain).allowed
) )
res = validator.validate_domain_ownership(domain) res = validator.validate_domain_ownership(domain)
@ -336,7 +362,7 @@ def test_custom_domain_validation_validate_ownership_from_partner_success():
Session.commit() Session.commit()
dns_client.set_txt_record( dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)] domain.domain, validator.get_ownership_verification_record(domain).allowed
) )
res = validator.validate_domain_ownership(domain) res = validator.validate_domain_ownership(domain)
@ -370,7 +396,7 @@ def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
wrong_record_1 = random_string() wrong_record_1 = random_string()
wrong_record_2 = random_string() wrong_record_2 = random_string()
wrong_records = [MxRecord(10, wrong_record_1), MxRecord(20, wrong_record_2)] wrong_records = {10: [wrong_record_1], 20: [wrong_record_2]}
dns_client.set_mx_records(domain.domain, wrong_records) dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain) res = validator.validate_mx_records(domain)
@ -387,7 +413,12 @@ def test_custom_domain_validation_validate_mx_records_success():
domain = create_custom_domain(random_domain()) domain = create_custom_domain(random_domain())
dns_client.set_mx_records(domain.domain, validator.get_expected_mx_records(domain)) mx_records_by_prio = validator.get_expected_mx_records(domain)
dns_records = {
priority: mx_records_by_prio[priority].allowed
for priority in mx_records_by_prio
}
dns_client.set_mx_records(domain.domain, dns_records)
res = validator.validate_mx_records(domain) res = validator.validate_mx_records(domain)
assert res.success is True assert res.success is True
@ -485,16 +516,19 @@ def test_custom_domain_validation_validate_spf_cleans_verification_record():
domain.partner_id = proton_partner_id domain.partner_id = proton_partner_id
Session.commit() Session.commit()
wrong_record = random_string() ownership_records = validator.get_ownership_verification_record(domain)
dns_client.set_txt_record(
hostname=domain.domain,
txt_list=[wrong_record, validator.get_ownership_verification_record(domain)],
)
res = validator.validate_spf_records(domain)
assert res.success is False for ownership_record in ownership_records.allowed:
assert len(res.errors) == 1 wrong_record = random_string()
assert res.errors[0] == wrong_record dns_client.set_txt_record(
hostname=domain.domain,
txt_list=[wrong_record, ownership_record],
)
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 1
assert res.errors[0] == wrong_record
# validate_dmarc_records # validate_dmarc_records

View File

@ -1,9 +1,8 @@
from app.custom_domain_validation import is_mx_equivalent, ExpectedValidationRecords
from app.dns_utils import ( from app.dns_utils import (
get_mx_domains, get_mx_domains,
get_network_dns_client, get_network_dns_client,
is_mx_equivalent,
InMemoryDNSClient, InMemoryDNSClient,
MxRecord,
) )
from tests.utils import random_domain from tests.utils import random_domain
@ -17,9 +16,9 @@ def test_get_mx_domains():
assert len(r) > 0 assert len(r) > 0
for x in r: for prio in r:
assert x.priority > 0 assert prio > 0
assert x.domain assert len(r[prio]) > 0
def test_get_spf_domain(): def test_get_spf_domain():
@ -33,33 +32,49 @@ def test_get_txt_record():
def test_is_mx_equivalent(): def test_is_mx_equivalent():
assert is_mx_equivalent([], []) assert is_mx_equivalent({}, {})
assert is_mx_equivalent( assert is_mx_equivalent(
mx_domains=[MxRecord(1, "domain")], ref_mx_domains=[MxRecord(1, "domain")] mx_domains={1: ["domain"]},
expected_mx_domains={
1: ExpectedValidationRecords(recommended="nop", allowed=["domain"])
},
) )
assert is_mx_equivalent( assert is_mx_equivalent(
mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], mx_domains={10: ["domain10"], 20: ["domain20"]},
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], expected_mx_domains={
10: ExpectedValidationRecords(recommended="nop", allowed=["domain10"]),
20: ExpectedValidationRecords(recommended="nop", allowed=["domain20"]),
},
) )
assert is_mx_equivalent( assert is_mx_equivalent(
mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")], mx_domains={5: ["domain1"], 10: ["domain2"]},
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], expected_mx_domains={
10: ExpectedValidationRecords(recommended="nop", allowed=["domain1"]),
20: ExpectedValidationRecords(recommended="nop", allowed=["domain2"]),
},
) )
assert is_mx_equivalent(
mx_domains=[ assert not is_mx_equivalent(
MxRecord(5, "domain1"), mx_domains={10: ["domain10", "domain11"], 20: ["domain20"]},
MxRecord(10, "domain2"), expected_mx_domains={
MxRecord(20, "domain3"), 10: ExpectedValidationRecords(recommended="nop", allowed=["domain10"]),
], 20: ExpectedValidationRecords(recommended="nop", allowed=["domain20"]),
ref_mx_domains=[MxRecord(10, "domain1"), MxRecord(20, "domain2")], },
) )
assert not is_mx_equivalent( assert not is_mx_equivalent(
mx_domains=[MxRecord(5, "domain1"), MxRecord(10, "domain2")], mx_domains={5: ["domain1"], 10: ["domain2"], 20: ["domain3"]},
ref_mx_domains=[ expected_mx_domains={
MxRecord(10, "domain1"), 10: ExpectedValidationRecords(recommended="nop", allowed=["domain1"]),
MxRecord(20, "domain2"), 20: ExpectedValidationRecords(recommended="nop", allowed=["domain2"]),
MxRecord(20, "domain3"), },
], )
assert not is_mx_equivalent(
mx_domains={10: ["domain1"]},
expected_mx_domains={
10: ExpectedValidationRecords(recommended="nop", allowed=["domain1"]),
20: ExpectedValidationRecords(recommended="nop", allowed=["domain2"]),
},
) )