diff --git a/app/app/alias_utils.py b/app/app/alias_utils.py index 49aade1..8e51c4e 100644 --- a/app/app/alias_utils.py +++ b/app/app/alias_utils.py @@ -370,9 +370,7 @@ def delete_alias( EventDispatcher.send_event( user, - EventContent( - alias_deleted=AliasDeleted(alias_id=alias_id, alias_email=alias_email) - ), + EventContent(alias_deleted=AliasDeleted(id=alias_id, email=alias_email)), ) if commit: Session.commit() @@ -511,7 +509,10 @@ def change_alias_status(alias: Alias, enabled: bool, commit: bool = False): alias.enabled = enabled event = AliasStatusChanged( - alias_id=alias.id, alias_email=alias.email, enabled=enabled + id=alias.id, + email=alias.email, + enabled=enabled, + created_at=int(alias.created_at.timestamp), ) EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event)) diff --git a/app/app/config.py b/app/app/config.py index 38f6e40..e55d1d8 100644 --- a/app/app/config.py +++ b/app/app/config.py @@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None): return literal_eval(value) +def get_env_dict(env_var: str) -> dict[str, str]: + """ + Get an env variable and convert it into a python dictionary with keys and values as strings. + Args: + env_var (str): env var, example: SL_DB + + Syntax is: key1=value1;key2=value2 + Components separated by ; + key and value separated by = + """ + value = os.getenv(env_var) + if not value: + return {} + + components = value.split(";") + result = {} + for component in components: + if component == "": + continue + parts = component.split("=") + if len(parts) != 2: + raise Exception(f"Invalid config for env var {env_var}") + result[parts[0].strip()] = parts[1].strip() + + return result + + config_file = os.environ.get("CONFIG") if config_file: config_file = get_abs_path(config_file) @@ -609,3 +636,24 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_ # Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool # It defaults to the regular DB_URI in case it's needed EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI) + + +def read_partner_dict(var: str) -> dict[int, str]: + partner_value = get_env_dict(var) + if len(partner_value) == 0: + return {} + + res: dict[int, str] = {} + for partner_id in partner_value.keys(): + try: + partner_id_int = int(partner_id.strip()) + res[partner_id_int] = partner_value[partner_id] + except ValueError: + pass + return res + + +PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS") +PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( + "PARTNER_DOMAIN_VALIDATION_PREFIXES" +) diff --git a/app/app/constants.py b/app/app/constants.py index b20bc6a..837e7fa 100644 --- a/app/app/constants.py +++ b/app/app/constants.py @@ -1 +1,2 @@ HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies" +DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s" diff --git a/app/app/custom_domain_utils.py b/app/app/custom_domain_utils.py new file mode 100644 index 0000000..9275622 --- /dev/null +++ b/app/app/custom_domain_utils.py @@ -0,0 +1,128 @@ +import re + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from app.db import Session +from app.email_utils import get_email_domain_part +from app.log import LOG +from app.models import User, CustomDomain, SLDomain, Mailbox + +_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(? str: + if self == CannotUseDomainReason.InvalidDomain: + return "This is not a valid domain" + elif self == CannotUseDomainReason.BuiltinDomain: + return "A custom domain cannot be a built-in domain." + elif self == CannotUseDomainReason.DomainAlreadyUsed: + return f"{domain} already used" + elif self == CannotUseDomainReason.DomainPartOfUserEmail: + return "You cannot add a domain that you are currently using for your personal email. Please change your personal email to your real email" + elif self == CannotUseDomainReason.DomainUserInMailbox: + return f"{domain} already used in a SimpleLogin mailbox" + else: + raise Exception("Invalid CannotUseDomainReason") + + +def is_valid_domain(domain: str) -> bool: + """ + Checks that a domain is valid according to RFC 1035 + """ + if len(domain) > 255: + return False + if domain.endswith("."): + domain = domain[:-1] # Strip the trailing dot + labels = domain.split(".") + if not labels: + return False + for label in labels: + if not _ALLOWED_DOMAIN_REGEX.match(label): + return False + return True + + +def sanitize_domain(domain: str) -> str: + new_domain = domain.lower().strip() + if new_domain.startswith("http://"): + new_domain = new_domain[len("http://") :] + + if new_domain.startswith("https://"): + new_domain = new_domain[len("https://") :] + + return new_domain + + +def can_domain_be_used(user: User, domain: str) -> Optional[CannotUseDomainReason]: + if not is_valid_domain(domain): + return CannotUseDomainReason.InvalidDomain + elif SLDomain.get_by(domain=domain): + return CannotUseDomainReason.BuiltinDomain + elif CustomDomain.get_by(domain=domain): + return CannotUseDomainReason.DomainAlreadyUsed + elif get_email_domain_part(user.email) == domain: + return CannotUseDomainReason.DomainPartOfUserEmail + elif Mailbox.filter( + Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{domain}") + ).first(): + return CannotUseDomainReason.DomainUserInMailbox + else: + return None + + +def create_custom_domain( + user: User, domain: str, partner_id: Optional[int] = None +) -> CreateCustomDomainResult: + if not user.is_premium(): + return CreateCustomDomainResult( + message="Only premium plan can add custom domain", + message_category="warning", + ) + + new_domain = sanitize_domain(domain) + domain_forbidden_cause = can_domain_be_used(user, new_domain) + if domain_forbidden_cause: + return CreateCustomDomainResult( + message=domain_forbidden_cause.message(new_domain), message_category="error" + ) + + new_custom_domain = CustomDomain.create(domain=new_domain, user_id=user.id) + + # new domain has ownership verified if its parent has the ownership verified + for root_cd in user.custom_domains: + if new_domain.endswith("." + root_cd.domain) and root_cd.ownership_verified: + LOG.i( + "%s ownership verified thanks to %s", + new_custom_domain, + root_cd, + ) + new_custom_domain.ownership_verified = True + + # Add the partner_id in case it's passed + if partner_id is not None: + new_custom_domain.partner_id = partner_id + + Session.commit() + + return CreateCustomDomainResult( + success=True, + instance=new_custom_domain, + ) diff --git a/app/app/custom_domain_validation.py b/app/app/custom_domain_validation.py index 3a2145a..4dbf201 100644 --- a/app/app/custom_domain_validation.py +++ b/app/app/custom_domain_validation.py @@ -1,37 +1,157 @@ +from dataclasses import dataclass +from typing import Optional + +from app import config +from app.constants import DMARC_RECORD from app.db import Session -from app.dns_utils import get_cname_record +from app.dns_utils import ( + DNSClient, + is_mx_equivalent, + get_network_dns_client, +) from app.models import CustomDomain +@dataclass +class DomainValidationResult: + success: bool + errors: [str] + + class CustomDomainValidation: - def __init__(self, dkim_domain: str): + def __init__( + self, + dkim_domain: str, + dns_client: DNSClient = get_network_dns_client(), + partner_domains: Optional[dict[int, str]] = None, + partner_domains_validation_prefixes: Optional[dict[int, str]] = None, + ): self.dkim_domain = dkim_domain - self._dkim_records = { - (f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}") + self._dns_client = dns_client + self._partner_domains = partner_domains or config.PARTNER_DOMAINS + self._partner_domain_validation_prefixes = ( + partner_domains_validation_prefixes + or config.PARTNER_DOMAIN_VALIDATION_PREFIXES + ) + + def get_ownership_verification_record(self, domain: CustomDomain) -> str: + prefix = "sl" + if ( + domain.partner_id is not None + and domain.partner_id in self._partner_domain_validation_prefixes + ): + prefix = self._partner_domain_validation_prefixes[domain.partner_id] + return f"{prefix}-verification={domain.ownership_txt_token}" + + def get_dkim_records(self, domain: CustomDomain) -> {str: str}: + """ + Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not, + it will return the default ones or the partner ones. + """ + + # By default use the default domain + dkim_domain = self.dkim_domain + if domain.partner_id is not None: + # Domain is from a partner. Retrieve the partner config and use that domain if exists + dkim_domain = self._partner_domains.get(domain.partner_id, dkim_domain) + + return { + f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}" for key in ("dkim", "dkim02", "dkim03") } - def get_dkim_records(self) -> {str: str}: - """ - Get a list of dkim records to set up. It will be - - """ - return self._dkim_records - def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]: """ Check if dkim records are properly set for this custom domain. Returns empty list if all records are ok. Other-wise return the records that aren't properly configured """ + correct_records = {} invalid_records = {} - for prefix, expected_record in self.get_dkim_records(): + expected_records = self.get_dkim_records(custom_domain) + for prefix, expected_record in expected_records.items(): custom_record = f"{prefix}.{custom_domain.domain}" - dkim_record = get_cname_record(custom_record) - if dkim_record != expected_record: + dkim_record = self._dns_client.get_cname_record(custom_record) + if dkim_record == expected_record: + correct_records[prefix] = custom_record + else: invalid_records[custom_record] = dkim_record or "empty" - # HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES + + # HACK + # As initially we only had one dkim record, we want to allow users that had only the original dkim record and + # the domain validated to continue seeing it as validated (although showing them the missing records). + # However, if not even the original dkim record is right, even if the domain was dkim_verified in the past, + # we will remove the dkim_verified flag. + # This is done in order to give users with the old dkim config (only one) to update their CNAMEs if custom_domain.dkim_verified: - return invalid_records + # Check if at least the original dkim is there + if correct_records.get("dkim._domainkey") is not None: + # Original dkim record is there. Return the missing records (if any) and don't clear the flag + return invalid_records + + # Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the + # rest of the code path, returning the invalid records and clearing the flag custom_domain.dkim_verified = len(invalid_records) == 0 Session.commit() return invalid_records + + def validate_domain_ownership( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + """ + Check if the custom_domain has added the ownership verification records + """ + txt_records = self._dns_client.get_txt_record(custom_domain.domain) + expected_verification_record = self.get_ownership_verification_record( + custom_domain + ) + + if expected_verification_record in txt_records: + custom_domain.ownership_verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + else: + return DomainValidationResult(success=False, errors=txt_records) + + def validate_mx_records( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + mx_domains = self._dns_client.get_mx_domains(custom_domain.domain) + + if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY): + return DomainValidationResult( + success=False, + errors=[f"{priority} {domain}" for (priority, domain) in mx_domains], + ) + else: + custom_domain.verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + + def validate_spf_records( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) + if config.EMAIL_DOMAIN in spf_domains: + custom_domain.spf_verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + else: + custom_domain.spf_verified = False + Session.commit() + return DomainValidationResult( + success=False, + errors=self._dns_client.get_txt_record(custom_domain.domain), + ) + + def validate_dmarc_records( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain) + if DMARC_RECORD in txt_records: + custom_domain.dmarc_verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + else: + custom_domain.dmarc_verified = False + Session.commit() + return DomainValidationResult(success=False, errors=txt_records) diff --git a/app/app/dashboard/views/custom_domain.py b/app/app/dashboard/views/custom_domain.py index 875b5db..b410b30 100644 --- a/app/app/dashboard/views/custom_domain.py +++ b/app/app/dashboard/views/custom_domain.py @@ -5,11 +5,9 @@ from wtforms import StringField, validators from app import parallel_limiter from app.config import EMAIL_SERVERS_WITH_PRIORITY +from app.custom_domain_utils import create_custom_domain from app.dashboard.base import dashboard_bp -from app.db import Session -from app.email_utils import get_email_domain_part -from app.log import LOG -from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain +from app.models import CustomDomain class NewCustomDomainForm(FlaskForm): @@ -25,11 +23,8 @@ def custom_domain(): custom_domains = CustomDomain.filter_by( user_id=current_user.id, is_sl_subdomain=False ).all() - mailboxes = current_user.mailboxes() new_custom_domain_form = NewCustomDomainForm() - errors = {} - if request.method == "POST": if request.form.get("form-name") == "create": if not current_user.is_premium(): @@ -37,87 +32,25 @@ def custom_domain(): return redirect(url_for("dashboard.custom_domain")) if new_custom_domain_form.validate(): - new_domain = new_custom_domain_form.domain.data.lower().strip() - - if new_domain.startswith("http://"): - new_domain = new_domain[len("http://") :] - - if new_domain.startswith("https://"): - new_domain = new_domain[len("https://") :] - - if SLDomain.get_by(domain=new_domain): - flash("A custom domain cannot be a built-in domain.", "error") - elif CustomDomain.get_by(domain=new_domain): - flash(f"{new_domain} already used", "error") - elif get_email_domain_part(current_user.email) == new_domain: - flash( - "You cannot add a domain that you are currently using for your personal email. " - "Please change your personal email to your real email", - "error", - ) - elif Mailbox.filter( - Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{new_domain}") - ).first(): - flash( - f"{new_domain} already used in a SimpleLogin mailbox", "error" - ) - else: - new_custom_domain = CustomDomain.create( - domain=new_domain, user_id=current_user.id - ) - # new domain has ownership verified if its parent has the ownership verified - for root_cd in current_user.custom_domains: - if ( - new_domain.endswith("." + root_cd.domain) - and root_cd.ownership_verified - ): - LOG.i( - "%s ownership verified thanks to %s", - new_custom_domain, - root_cd, - ) - new_custom_domain.ownership_verified = True - - Session.commit() - - mailbox_ids = request.form.getlist("mailbox_ids") - if mailbox_ids: - # check if mailbox is not tempered with - mailboxes = [] - for mailbox_id in mailbox_ids: - mailbox = Mailbox.get(mailbox_id) - if ( - not mailbox - or mailbox.user_id != current_user.id - or not mailbox.verified - ): - flash("Something went wrong, please retry", "warning") - return redirect(url_for("dashboard.custom_domain")) - mailboxes.append(mailbox) - - for mailbox in mailboxes: - DomainMailbox.create( - domain_id=new_custom_domain.id, mailbox_id=mailbox.id - ) - - Session.commit() - - flash( - f"New domain {new_custom_domain.domain} is created", "success" - ) - + res = create_custom_domain( + user=current_user, domain=new_custom_domain_form.domain.data + ) + if res.success: + flash(f"New domain {res.instance.domain} is created", "success") return redirect( url_for( "dashboard.domain_detail_dns", - custom_domain_id=new_custom_domain.id, + custom_domain_id=res.instance.id, ) ) + else: + flash(res.message, res.message_category) + if res.redirect: + return redirect(url_for(res.redirect)) return render_template( "dashboard/custom_domain.html", custom_domains=custom_domains, new_custom_domain_form=new_custom_domain_form, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, - errors=errors, - mailboxes=mailboxes, ) diff --git a/app/app/dashboard/views/domain_detail.py b/app/app/dashboard/views/domain_detail.py index 29089a3..9c714c9 100644 --- a/app/app/dashboard/views/domain_detail.py +++ b/app/app/dashboard/views/domain_detail.py @@ -6,16 +6,11 @@ from flask_login import login_required, current_user from flask_wtf import FlaskForm from wtforms import StringField, validators, IntegerField +from app.constants import DMARC_RECORD from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN from app.custom_domain_validation import CustomDomainValidation from app.dashboard.base import dashboard_bp from app.db import Session -from app.dns_utils import ( - get_mx_domains, - get_spf_domain, - get_txt_record, - is_mx_equivalent, -) from app.log import LOG from app.models import ( CustomDomain, @@ -49,8 +44,6 @@ def domain_detail_dns(custom_domain_id): domain_validator = CustomDomainValidation(EMAIL_DOMAIN) csrf_form = CSRFValidationForm() - dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s" - mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = [] @@ -59,15 +52,14 @@ def domain_detail_dns(custom_domain_id): flash("Invalid request", "warning") return redirect(request.url) if request.form.get("form-name") == "check-ownership": - txt_records = get_txt_record(custom_domain.domain) - - if custom_domain.get_ownership_dns_txt_value() in txt_records: + ownership_validation_result = domain_validator.validate_domain_ownership( + custom_domain + ) + if ownership_validation_result.success: flash( "Domain ownership is verified. Please proceed to the other records setup", "success", ) - custom_domain.ownership_verified = True - Session.commit() return redirect( url_for( "dashboard.domain_detail_dns", @@ -78,36 +70,28 @@ def domain_detail_dns(custom_domain_id): else: flash("We can't find the needed TXT record", "error") ownership_ok = False - ownership_errors = txt_records + ownership_errors = ownership_validation_result.errors elif request.form.get("form-name") == "check-mx": - mx_domains = get_mx_domains(custom_domain.domain) - - if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY): - flash("The MX record is not correctly set", "warning") - - mx_ok = False - # build mx_errors to show to user - mx_errors = [ - f"{priority} {domain}" for (priority, domain) in mx_domains - ] - else: + mx_validation_result = domain_validator.validate_mx_records(custom_domain) + if mx_validation_result.success: flash( "Your domain can start receiving emails. You can now use it to create alias", "success", ) - custom_domain.verified = True - Session.commit() return redirect( url_for( "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id ) ) + else: + flash("The MX record is not correctly set", "warning") + mx_ok = False + mx_errors = mx_validation_result.errors + elif request.form.get("form-name") == "check-spf": - spf_domains = get_spf_domain(custom_domain.domain) - if EMAIL_DOMAIN in spf_domains: - custom_domain.spf_verified = True - Session.commit() + spf_validation_result = domain_validator.validate_spf_records(custom_domain) + if spf_validation_result.success: flash("SPF is setup correctly", "success") return redirect( url_for( @@ -115,14 +99,12 @@ def domain_detail_dns(custom_domain_id): ) ) else: - custom_domain.spf_verified = False - Session.commit() flash( f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.", "warning", ) spf_ok = False - spf_errors = get_txt_record(custom_domain.domain) + spf_errors = spf_validation_result.errors elif request.form.get("form-name") == "check-dkim": dkim_errors = domain_validator.validate_dkim_records(custom_domain) @@ -138,10 +120,10 @@ def domain_detail_dns(custom_domain_id): flash("DKIM: the CNAME record is not correctly set", "warning") elif request.form.get("form-name") == "check-dmarc": - txt_records = get_txt_record("_dmarc." + custom_domain.domain) - if dmarc_record in txt_records: - custom_domain.dmarc_verified = True - Session.commit() + dmarc_validation_result = domain_validator.validate_dmarc_records( + custom_domain + ) + if dmarc_validation_result.success: flash("DMARC is setup correctly", "success") return redirect( url_for( @@ -149,19 +131,21 @@ def domain_detail_dns(custom_domain_id): ) ) else: - custom_domain.dmarc_verified = False - Session.commit() flash( "DMARC: The TXT record is not correctly set", "warning", ) dmarc_ok = False - dmarc_errors = txt_records + dmarc_errors = dmarc_validation_result.errors return render_template( "dashboard/domain_detail/dns.html", EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, - dkim_records=domain_validator.get_dkim_records(), + ownership_record=domain_validator.get_ownership_verification_record( + custom_domain + ), + dkim_records=domain_validator.get_dkim_records(custom_domain), + dmarc_record=DMARC_RECORD, **locals(), ) diff --git a/app/app/dns_utils.py b/app/app/dns_utils.py index 429d0aa..2ce6993 100644 --- a/app/app/dns_utils.py +++ b/app/app/dns_utils.py @@ -1,100 +1,13 @@ -from app import config -from typing import Optional, List, Tuple +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional import dns.resolver - -def _get_dns_resolver(): - my_resolver = dns.resolver.Resolver() - my_resolver.nameservers = config.NAMESERVERS - - return my_resolver - - -def get_ns(hostname) -> [str]: - try: - answers = _get_dns_resolver().resolve(hostname, "NS", search=True) - except Exception: - return [] - return [a.to_text() for a in answers] - - -def get_cname_record(hostname) -> Optional[str]: - """Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end""" - try: - answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True) - except Exception: - return None - - for a in answers: - ret = a.to_text() - return ret[:-1] - - return None - - -def get_mx_domains(hostname) -> [(int, str)]: - """return list of (priority, domain name) sorted by priority (lowest priority first) - domain name ends with a "." at the end. - """ - try: - answers = _get_dns_resolver().resolve(hostname, "MX", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: - record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' - parts = record.split(" ") - - ret.append((int(parts[0]), parts[1])) - - return sorted(ret, key=lambda prio_domain: prio_domain[0]) - +from app.config import NAMESERVERS _include_spf = "include:" -def get_spf_domain(hostname) -> [str]: - """return all domains listed in *include:*""" - try: - answers = _get_dns_resolver().resolve(hostname, "TXT", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: # type: dns.rdtypes.ANY.TXT.TXT - for record in a.strings: - record = record.decode() # record is bytes - - if record.startswith("v=spf1"): - parts = record.split(" ") - for part in parts: - if part.startswith(_include_spf): - ret.append(part[part.find(_include_spf) + len(_include_spf) :]) - - return ret - - -def get_txt_record(hostname) -> [str]: - try: - answers = _get_dns_resolver().resolve(hostname, "TXT", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: # type: dns.rdtypes.ANY.TXT.TXT - for record in a.strings: - record = record.decode() # record is bytes - - ret.append(record) - - return ret - - def is_mx_equivalent( mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] ) -> bool: @@ -105,16 +18,127 @@ def is_mx_equivalent( The priority order is taken into account but not the priority number. For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] """ - mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0]) - ref_mx_domains = sorted( - ref_mx_domains, key=lambda priority_domain: priority_domain[0] - ) + mx_domains = sorted(mx_domains, key=lambda x: x[0]) + ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) if len(mx_domains) < len(ref_mx_domains): return False - for i in range(0, len(ref_mx_domains)): + for i in range(len(ref_mx_domains)): if mx_domains[i][1] != ref_mx_domains[i][1]: return False return True + + +class DNSClient(ABC): + @abstractmethod + def get_cname_record(self, hostname: str) -> Optional[str]: + pass + + @abstractmethod + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + pass + + def get_spf_domain(self, hostname: str) -> List[str]: + """ + return all domains listed in *include:* + """ + try: + records = self.get_txt_record(hostname) + ret = [] + for record in records: + if record.startswith("v=spf1"): + parts = record.split(" ") + for part in parts: + if part.startswith(_include_spf): + ret.append( + part[part.find(_include_spf) + len(_include_spf) :] + ) + return ret + except Exception: + return [] + + @abstractmethod + def get_txt_record(self, hostname: str) -> List[str]: + pass + + +class NetworkDNSClient(DNSClient): + def __init__(self, nameservers: List[str]): + self._resolver = dns.resolver.Resolver() + self._resolver.nameservers = nameservers + + def get_cname_record(self, hostname: str) -> Optional[str]: + """ + Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end + """ + try: + answers = self._resolver.resolve(hostname, "CNAME", search=True) + for a in answers: + ret = a.to_text() + return ret[:-1] + except Exception: + return None + + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + """ + return list of (priority, domain name) sorted by priority (lowest priority first) + domain name ends with a "." at the end. + """ + try: + answers = self._resolver.resolve(hostname, "MX", search=True) + ret = [] + for a in answers: + record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' + parts = record.split(" ") + ret.append((int(parts[0]), parts[1])) + return sorted(ret, key=lambda x: x[0]) + except Exception: + return [] + + def get_txt_record(self, hostname: str) -> List[str]: + try: + answers = self._resolver.resolve(hostname, "TXT", search=True) + ret = [] + for a in answers: # type: dns.rdtypes.ANY.TXT.TXT + for record in a.strings: + ret.append(record.decode()) + return ret + except Exception: + return [] + + +class InMemoryDNSClient(DNSClient): + def __init__(self): + self.cname_records: dict[str, Optional[str]] = {} + self.mx_records: dict[str, List[Tuple[int, str]]] = {} + self.spf_records: dict[str, List[str]] = {} + self.txt_records: dict[str, List[str]] = {} + + def set_cname_record(self, hostname: str, cname: str): + self.cname_records[hostname] = cname + + def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): + self.mx_records[hostname] = mx_list + + def set_txt_record(self, hostname: str, txt_list: List[str]): + self.txt_records[hostname] = txt_list + + def get_cname_record(self, hostname: str) -> Optional[str]: + return self.cname_records.get(hostname) + + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + mx_list = self.mx_records.get(hostname, []) + return sorted(mx_list, key=lambda x: x[0]) + + def get_txt_record(self, hostname: str) -> List[str]: + return self.txt_records.get(hostname, []) + + +def get_network_dns_client() -> NetworkDNSClient: + return NetworkDNSClient(NAMESERVERS) + + +def get_mx_domains(hostname: str) -> [(int, str)]: + return get_network_dns_client().get_mx_domains(hostname) diff --git a/app/app/events/generated/event_pb2.py b/app/app/events/generated/event_pb2.py index a805c62..8132660 100644 --- a/app/app/events/generated/event_pb2.py +++ b/app/app/events/generated/event_pb2.py @@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"Z\n\x0c\x41liasCreated\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x12\n\nalias_note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\"L\n\x12\x41liasStatusChanged\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\"5\n\x0c\x41liasDeleted\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"\\\n\x0c\x41liasCreated\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0c\n\x04note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\x12\x12\n\ncreated_at\x18\x05 \x01(\r\"T\n\x12\x41liasStatusChanged\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\x12\x12\n\ncreated_at\x18\x04 \x01(\r\")\n\x0c\x41liasDeleted\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -36,15 +36,15 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['_USERDELETED']._serialized_start=77 _globals['_USERDELETED']._serialized_end=90 _globals['_ALIASCREATED']._serialized_start=92 - _globals['_ALIASCREATED']._serialized_end=182 - _globals['_ALIASSTATUSCHANGED']._serialized_start=184 - _globals['_ALIASSTATUSCHANGED']._serialized_end=260 - _globals['_ALIASDELETED']._serialized_start=262 - _globals['_ALIASDELETED']._serialized_end=315 - _globals['_ALIASCREATEDLIST']._serialized_start=317 - _globals['_ALIASCREATEDLIST']._serialized_end=385 - _globals['_EVENTCONTENT']._serialized_start=388 - _globals['_EVENTCONTENT']._serialized_end=791 - _globals['_EVENT']._serialized_start=793 - _globals['_EVENT']._serialized_end=914 + _globals['_ALIASCREATED']._serialized_end=184 + _globals['_ALIASSTATUSCHANGED']._serialized_start=186 + _globals['_ALIASSTATUSCHANGED']._serialized_end=270 + _globals['_ALIASDELETED']._serialized_start=272 + _globals['_ALIASDELETED']._serialized_end=313 + _globals['_ALIASCREATEDLIST']._serialized_start=315 + _globals['_ALIASCREATEDLIST']._serialized_end=383 + _globals['_EVENTCONTENT']._serialized_start=386 + _globals['_EVENTCONTENT']._serialized_end=789 + _globals['_EVENT']._serialized_start=791 + _globals['_EVENT']._serialized_end=912 # @@protoc_insertion_point(module_scope) diff --git a/app/app/events/generated/event_pb2.pyi b/app/app/events/generated/event_pb2.pyi index d1793ee..9bc43b3 100644 --- a/app/app/events/generated/event_pb2.pyi +++ b/app/app/events/generated/event_pb2.pyi @@ -16,34 +16,38 @@ class UserDeleted(_message.Message): def __init__(self) -> None: ... class AliasCreated(_message.Message): - __slots__ = ("alias_id", "alias_email", "alias_note", "enabled") - ALIAS_ID_FIELD_NUMBER: _ClassVar[int] - ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int] - ALIAS_NOTE_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("id", "email", "note", "enabled", "created_at") + ID_FIELD_NUMBER: _ClassVar[int] + EMAIL_FIELD_NUMBER: _ClassVar[int] + NOTE_FIELD_NUMBER: _ClassVar[int] ENABLED_FIELD_NUMBER: _ClassVar[int] - alias_id: int - alias_email: str - alias_note: str + CREATED_AT_FIELD_NUMBER: _ClassVar[int] + id: int + email: str + note: str enabled: bool - def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., alias_note: _Optional[str] = ..., enabled: bool = ...) -> None: ... + created_at: int + def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., note: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ... class AliasStatusChanged(_message.Message): - __slots__ = ("alias_id", "alias_email", "enabled") - ALIAS_ID_FIELD_NUMBER: _ClassVar[int] - ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("id", "email", "enabled", "created_at") + ID_FIELD_NUMBER: _ClassVar[int] + EMAIL_FIELD_NUMBER: _ClassVar[int] ENABLED_FIELD_NUMBER: _ClassVar[int] - alias_id: int - alias_email: str + CREATED_AT_FIELD_NUMBER: _ClassVar[int] + id: int + email: str enabled: bool - def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., enabled: bool = ...) -> None: ... + created_at: int + def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ... class AliasDeleted(_message.Message): - __slots__ = ("alias_id", "alias_email") - ALIAS_ID_FIELD_NUMBER: _ClassVar[int] - ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int] - alias_id: int - alias_email: str - def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ...) -> None: ... + __slots__ = ("id", "email") + ID_FIELD_NUMBER: _ClassVar[int] + EMAIL_FIELD_NUMBER: _ClassVar[int] + id: int + email: str + def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ...) -> None: ... class AliasCreatedList(_message.Message): __slots__ = ("events",) diff --git a/app/app/jobs/event_jobs.py b/app/app/jobs/event_jobs.py index cd990b9..ffe4fba 100644 --- a/app/app/jobs/event_jobs.py +++ b/app/app/jobs/event_jobs.py @@ -22,10 +22,11 @@ def send_alias_creation_events_for_user( ): event_list.append( AliasCreated( - alias_id=alias.id, - alias_email=alias.email, - alias_note=alias.note, + id=alias.id, + email=alias.email, + note=alias.note, enabled=alias.enabled, + created_at=int(alias.created_at.timestamp), ) ) if len(event_list) >= chunk_size: diff --git a/app/app/models.py b/app/app/models.py index 9ad0ed1..31ce24d 100644 --- a/app/app/models.py +++ b/app/app/models.py @@ -973,7 +973,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): def has_custom_domain(self): return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0 - def custom_domains(self): + def custom_domains(self) -> List["CustomDomain"]: return CustomDomain.filter_by(user_id=self.id, verified=True).all() def available_domains_for_random_alias( @@ -1677,10 +1677,11 @@ class Alias(Base, ModelMixin): from app.events.generated.event_pb2 import AliasCreated, EventContent event = AliasCreated( - alias_id=new_alias.id, - alias_email=new_alias.email, - alias_note=new_alias.note, + id=new_alias.id, + email=new_alias.email, + note=new_alias.note, enabled=True, + created_at=int(new_alias.created_at.timestamp), ) EventDispatcher.send_event(user, EventContent(alias_created=event)) @@ -2418,6 +2419,14 @@ class CustomDomain(Base, ModelMixin): sa.Boolean, nullable=False, default=False, server_default="0" ) + partner_id = sa.Column( + sa.Integer, + sa.ForeignKey("partner.id"), + nullable=True, + default=None, + server_default=None, + ) + __table_args__ = ( Index( "ix_unique_domain", # Index name @@ -2442,9 +2451,6 @@ class CustomDomain(Base, ModelMixin): def get_trash_url(self): return config.URL + f"/dashboard/domains/{self.id}/trash" - def get_ownership_dns_txt_value(self): - return f"sl-verification={self.ownership_txt_token}" - @classmethod def create(cls, **kwargs): domain = kwargs.get("domain") diff --git a/app/migrations/versions/2024_091315_2441b7ff5da9_custom_domain_partner_id.py b/app/migrations/versions/2024_091315_2441b7ff5da9_custom_domain_partner_id.py new file mode 100644 index 0000000..ba33bea --- /dev/null +++ b/app/migrations/versions/2024_091315_2441b7ff5da9_custom_domain_partner_id.py @@ -0,0 +1,30 @@ +"""Custom Domain partner id + +Revision ID: 2441b7ff5da9 +Revises: 1c14339aae90 +Create Date: 2024-09-13 15:43:02.425964 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2441b7ff5da9' +down_revision = '1c14339aae90' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('custom_domain', sa.Column('partner_id', sa.Integer(), nullable=True, default=None, server_default=None)) + op.create_foreign_key(None, 'custom_domain', 'partner', ['partner_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'custom_domain', type_='foreignkey') + op.drop_column('custom_domain', 'partner_id') + # ### end Alembic commands ### diff --git a/app/proto/event.proto b/app/proto/event.proto index 694849e..2fd0158 100644 --- a/app/proto/event.proto +++ b/app/proto/event.proto @@ -10,21 +10,23 @@ message UserDeleted { } message AliasCreated { - uint32 alias_id = 1; - string alias_email = 2; - string alias_note = 3; + uint32 id = 1; + string email = 2; + string note = 3; bool enabled = 4; + uint32 created_at = 5; } message AliasStatusChanged { - uint32 alias_id = 1; - string alias_email = 2; + uint32 id = 1; + string email = 2; bool enabled = 3; + uint32 created_at = 4; } message AliasDeleted { - uint32 alias_id = 1; - string alias_email = 2; + uint32 id = 1; + string email = 2; } message AliasCreatedList { diff --git a/app/scripts/new-migration.sh b/app/scripts/new-migration.sh index 5456857..da11a75 100755 --- a/app/scripts/new-migration.sh +++ b/app/scripts/new-migration.sh @@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres sleep 3 # upgrade the DB to the latest stage and -env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head +env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head # generate the migration script. -env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@ +env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@ # remove the db docker rm -f ${container_name} diff --git a/app/templates/dashboard/custom_domain.html b/app/templates/dashboard/custom_domain.html index 62a12b0..2982867 100644 --- a/app/templates/dashboard/custom_domain.html +++ b/app/templates/dashboard/custom_domain.html @@ -94,4 +94,4 @@ {% endblock %} -{% block script %}{% endblock %} + diff --git a/app/templates/dashboard/domain_detail/dns.html b/app/templates/dashboard/domain_detail/dns.html index 15ef346..4058f5e 100644 --- a/app/templates/dashboard/domain_detail/dns.html +++ b/app/templates/dashboard/domain_detail/dns.html @@ -38,7 +38,7 @@ Value: {{ custom_domain.get_ownership_dns_txt_value() }} + data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}