Compare commits

...

5 Commits

Author SHA1 Message Date
fd4c67c3d1 4.51.0
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m30s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m31s
Build-Release-Image / Merge-Images (push) Successful in 11s
Build-Release-Image / Create-Release (push) Successful in 9s
Build-Release-Image / Notify (push) Successful in 2s
2024-09-25 12:00:07 +01:00
edef254529 4.50.0
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m43s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m54s
Build-Release-Image / Merge-Images (push) Successful in 23s
Build-Release-Image / Create-Release (push) Successful in 14s
Build-Release-Image / Notify (push) Successful in 3s
2024-09-19 12:00:06 +01:00
357f0cca57 4.49.10
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m7s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-13 12:00:28 +01:00
8ce90e27f7 4.49.9
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m4s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-07 12:00:06 +01:00
3ecc8d36f9 4.49.8
Some checks failed
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m18s
Build-Release-Image / Build-Image (linux/amd64) (push) Has been cancelled
Build-Release-Image / Merge-Images (push) Has been cancelled
Build-Release-Image / Create-Release (push) Has been cancelled
Build-Release-Image / Notify (push) Has been cancelled
2024-09-04 12:00:07 +01:00
41 changed files with 1722 additions and 496 deletions

View File

@ -109,7 +109,7 @@ jobs:
GITHUB_ACTIONS_TEST: true
- name: Archive code coverage results
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: code-coverage-report
path: htmlcov

View File

@ -736,7 +736,8 @@ class InvalidMailboxDomainAdmin(SLModelView):
class EmailSearchResult:
no_match: bool = True
alias: Optional[Alias] = None
mailbox: Optional[Mailbox] = None
mailbox: list[Mailbox] = []
mailbox_count: int = 0
deleted_alias: Optional[DeletedAlias] = None
deleted_custom_alias: Optional[DomainDeletedAlias] = None
user: Optional[User] = None
@ -752,9 +753,12 @@ class EmailSearchResult:
if user:
output.user = user
output.no_match = False
mailbox = Mailbox.get_by(email=email)
if mailbox:
output.mailbox = mailbox
mailboxes = (
Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all()
)
if mailboxes:
output.mailbox = mailboxes
output.mailbox_count = Mailbox.filter_by(email=email).count()
output.no_match = False
deleted_alias = DeletedAlias.get_by(email=email)
if deleted_alias:
@ -779,11 +783,13 @@ class EmailSearchHelpers:
@staticmethod
def mailbox_count(user: User) -> int:
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.asc()).count()
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count()
@staticmethod
def alias_list(user: User) -> list[Alias]:
return Alias.filter_by(user_id=user.id).order_by(Alias.id.asc()).limit(10).all()
return (
Alias.filter_by(user_id=user.id).order_by(Alias.id.desc()).limit(10).all()
)
@staticmethod
def alias_count(user: User) -> int:
@ -806,9 +812,8 @@ class EmailSearchAdmin(BaseView):
@expose("/", methods=["GET", "POST"])
def index(self):
search = EmailSearchResult()
email = ""
if request.form and request.form["email"]:
email = request.form["email"]
email = request.args.get("email")
if email is not None and len(email) > 0:
email = email.strip()
search = EmailSearchResult.from_email(email)

View File

@ -363,11 +363,14 @@ def delete_alias(
Session.commit()
LOG.i(f"Moving {alias} to global trash {deleted_alias}")
alias_id = alias.id
alias_email = alias.email
Alias.filter(Alias.id == alias.id).delete()
Session.commit()
EventDispatcher.send_event(
user, EventContent(alias_deleted=AliasDeleted(alias_id=alias.id))
user,
EventContent(alias_deleted=AliasDeleted(id=alias_id, email=alias_email)),
)
if commit:
Session.commit()
@ -506,7 +509,10 @@ def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
alias.enabled = enabled
event = AliasStatusChanged(
alias_id=alias.id, alias_email=alias.email, enabled=enabled
id=alias.id,
email=alias.email,
enabled=enabled,
created_at=int(alias.created_at.timestamp),
)
EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event))

View File

@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None):
return literal_eval(value)
def get_env_dict(env_var: str) -> dict[str, str]:
"""
Get an env variable and convert it into a python dictionary with keys and values as strings.
Args:
env_var (str): env var, example: SL_DB
Syntax is: key1=value1;key2=value2
Components separated by ;
key and value separated by =
"""
value = os.getenv(env_var)
if not value:
return {}
components = value.split(";")
result = {}
for component in components:
if component == "":
continue
parts = component.split("=")
if len(parts) != 2:
raise Exception(f"Invalid config for env var {env_var}")
result[parts[0].strip()] = parts[1].strip()
return result
config_file = os.environ.get("CONFIG")
if config_file:
config_file = get_abs_path(config_file)
@ -609,3 +636,24 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_
# Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool
# It defaults to the regular DB_URI in case it's needed
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
def read_partner_dict(var: str) -> dict[int, str]:
partner_value = get_env_dict(var)
if len(partner_value) == 0:
return {}
res: dict[int, str] = {}
for partner_id in partner_value.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_value[partner_id]
except ValueError:
pass
return res
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
)

View File

@ -1 +1,2 @@
HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"

89
app/app/contact_utils.py Normal file
View File

@ -0,0 +1,89 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from sqlalchemy.exc import IntegrityError
from app.db import Session
from app.email_utils import generate_reply_email
from app.email_validation import is_valid_email
from app.log import LOG
from app.models import Contact, Alias
from app.utils import sanitize_email
class ContactCreateError(Enum):
InvalidEmail = "Invalid email"
@dataclass
class ContactCreateResult:
contact: Optional[Contact]
error: Optional[ContactCreateError]
def __update_contact_if_needed(
contact: Contact, name: Optional[str], mail_from: Optional[str]
) -> ContactCreateResult:
if name and contact.name != name:
LOG.d(f"Setting {contact} name to {name}")
contact.name = name
Session.commit()
if mail_from and contact.mail_from is None:
LOG.d(f"Setting {contact} mail_from to {mail_from}")
contact.mail_from = mail_from
Session.commit()
return ContactCreateResult(contact, None)
def create_contact(
email: str,
name: Optional[str],
alias: Alias,
mail_from: Optional[str] = None,
allow_empty_email: bool = False,
automatic_created: bool = False,
from_partner: bool = False,
) -> ContactCreateResult:
if name is not None:
name = name[: Contact.MAX_NAME_LENGTH]
if name is not None and "\x00" in name:
LOG.w("Cannot use contact name because has \\x00")
name = ""
if not is_valid_email(email):
LOG.w(f"invalid contact email {email}")
if not allow_empty_email:
return ContactCreateResult(None, ContactCreateError.InvalidEmail)
LOG.d("Create a contact with invalid email for %s", alias)
# either reuse a contact with empty email or create a new contact with empty email
email = ""
email = sanitize_email(email, not_lower=True)
contact = Contact.get_by(alias_id=alias.id, website_email=email)
if contact is not None:
return __update_contact_if_needed(contact, name, mail_from)
reply_email = generate_reply_email(email, alias)
try:
flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=email,
name=name,
reply_email=reply_email,
mail_from=mail_from,
automatic_created=automatic_created,
flags=flags,
invalid_email=email == "",
commit=True,
)
LOG.d(
f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}"
)
except IntegrityError:
Session.rollback()
LOG.info(
f"Contact with email {email} for alias_id {alias.id} already existed, fetching from DB"
)
contact = Contact.get_by(alias_id=alias.id, website_email=email)
return __update_contact_if_needed(contact, name, mail_from)
return ContactCreateResult(contact, None)

View File

@ -0,0 +1,142 @@
import arrow
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from app.config import JOB_DELETE_DOMAIN
from app.db import Session
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import User, CustomDomain, SLDomain, Mailbox, Job
_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$")
@dataclass
class CreateCustomDomainResult:
message: str = ""
message_category: str = ""
success: bool = False
instance: Optional[CustomDomain] = None
redirect: Optional[str] = None
class CannotUseDomainReason(Enum):
InvalidDomain = 1
BuiltinDomain = 2
DomainAlreadyUsed = 3
DomainPartOfUserEmail = 4
DomainUserInMailbox = 5
def message(self, domain: str) -> str:
if self == CannotUseDomainReason.InvalidDomain:
return "This is not a valid domain"
elif self == CannotUseDomainReason.BuiltinDomain:
return "A custom domain cannot be a built-in domain."
elif self == CannotUseDomainReason.DomainAlreadyUsed:
return f"{domain} already used"
elif self == CannotUseDomainReason.DomainPartOfUserEmail:
return "You cannot add a domain that you are currently using for your personal email. Please change your personal email to your real email"
elif self == CannotUseDomainReason.DomainUserInMailbox:
return f"{domain} already used in a SimpleLogin mailbox"
else:
raise Exception("Invalid CannotUseDomainReason")
def is_valid_domain(domain: str) -> bool:
"""
Checks that a domain is valid according to RFC 1035
"""
if len(domain) > 255:
return False
if domain.endswith("."):
domain = domain[:-1] # Strip the trailing dot
labels = domain.split(".")
if not labels:
return False
for label in labels:
if not _ALLOWED_DOMAIN_REGEX.match(label):
return False
return True
def sanitize_domain(domain: str) -> str:
new_domain = domain.lower().strip()
if new_domain.startswith("http://"):
new_domain = new_domain[len("http://") :]
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
return new_domain
def can_domain_be_used(user: User, domain: str) -> Optional[CannotUseDomainReason]:
if not is_valid_domain(domain):
return CannotUseDomainReason.InvalidDomain
elif SLDomain.get_by(domain=domain):
return CannotUseDomainReason.BuiltinDomain
elif CustomDomain.get_by(domain=domain):
return CannotUseDomainReason.DomainAlreadyUsed
elif get_email_domain_part(user.email) == domain:
return CannotUseDomainReason.DomainPartOfUserEmail
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{domain}")
).first():
return CannotUseDomainReason.DomainUserInMailbox
else:
return None
def create_custom_domain(
user: User, domain: str, partner_id: Optional[int] = None
) -> CreateCustomDomainResult:
if not user.is_premium():
return CreateCustomDomainResult(
message="Only premium plan can add custom domain",
message_category="warning",
)
new_domain = sanitize_domain(domain)
domain_forbidden_cause = can_domain_be_used(user, new_domain)
if domain_forbidden_cause:
return CreateCustomDomainResult(
message=domain_forbidden_cause.message(new_domain), message_category="error"
)
new_custom_domain = CustomDomain.create(domain=new_domain, user_id=user.id)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in user.custom_domains:
if new_domain.endswith("." + root_cd.domain) and root_cd.ownership_verified:
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
# Add the partner_id in case it's passed
if partner_id is not None:
new_custom_domain.partner_id = partner_id
Session.commit()
return CreateCustomDomainResult(
success=True,
instance=new_custom_domain,
)
def delete_custom_domain(domain: CustomDomain):
# Schedule delete domain job
LOG.w("schedule delete domain job for %s", domain)
domain.pending_deletion = True
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": domain.id},
run_at=arrow.now(),
commit=True,
)

View File

@ -1,37 +1,157 @@
from dataclasses import dataclass
from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import get_cname_record
from app.dns_utils import (
DNSClient,
is_mx_equivalent,
get_network_dns_client,
)
from app.models import CustomDomain
@dataclass
class DomainValidationResult:
success: bool
errors: [str]
class CustomDomainValidation:
def __init__(self, dkim_domain: str):
def __init__(
self,
dkim_domain: str,
dns_client: DNSClient = get_network_dns_client(),
partner_domains: Optional[dict[int, str]] = None,
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
):
self.dkim_domain = dkim_domain
self._dkim_records = {
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
self._dns_client = dns_client
self._partner_domains = partner_domains or config.PARTNER_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes
or config.PARTNER_DOMAIN_VALIDATION_PREFIXES
)
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = "sl"
if (
domain.partner_id is not None
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}-verification={domain.ownership_txt_token}"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
it will return the default ones or the partner ones.
"""
# By default use the default domain
dkim_domain = self.dkim_domain
if domain.partner_id is not None:
# Domain is from a partner. Retrieve the partner config and use that domain if exists
dkim_domain = self._partner_domains.get(domain.partner_id, dkim_domain)
return {
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}
def get_dkim_records(self) -> {str: str}:
"""
Get a list of dkim records to set up. It will be
"""
return self._dkim_records
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
"""
Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
correct_records = {}
invalid_records = {}
for prefix, expected_record in self.get_dkim_records():
expected_records = self.get_dkim_records(custom_domain)
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = get_cname_record(custom_record)
if dkim_record != expected_record:
dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record == expected_record:
correct_records[prefix] = custom_record
else:
invalid_records[custom_record] = dkim_record or "empty"
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
# HACK
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
# the domain validated to continue seeing it as validated (although showing them the missing records).
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
# we will remove the dkim_verified flag.
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
if custom_domain.dkim_verified:
return invalid_records
# Check if at least the original dkim is there
if correct_records.get("dkim._domainkey") is not None:
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
return invalid_records
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit()
return invalid_records
def validate_domain_ownership(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
expected_verification_record = self.get_ownership_verification_record(
custom_domain
)
if expected_verification_record in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
return DomainValidationResult(success=False, errors=txt_records)
def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
)
else:
custom_domain.verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain),
)
def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)

View File

@ -5,11 +5,9 @@ from wtforms import StringField, validators
from app import parallel_limiter
from app.config import EMAIL_SERVERS_WITH_PRIORITY
from app.custom_domain_utils import create_custom_domain
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain
from app.models import CustomDomain
class NewCustomDomainForm(FlaskForm):
@ -25,11 +23,8 @@ def custom_domain():
custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False
).all()
mailboxes = current_user.mailboxes()
new_custom_domain_form = NewCustomDomainForm()
errors = {}
if request.method == "POST":
if request.form.get("form-name") == "create":
if not current_user.is_premium():
@ -37,87 +32,25 @@ def custom_domain():
return redirect(url_for("dashboard.custom_domain"))
if new_custom_domain_form.validate():
new_domain = new_custom_domain_form.domain.data.lower().strip()
if new_domain.startswith("http://"):
new_domain = new_domain[len("http://") :]
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
if SLDomain.get_by(domain=new_domain):
flash("A custom domain cannot be a built-in domain.", "error")
elif CustomDomain.get_by(domain=new_domain):
flash(f"{new_domain} already used", "error")
elif get_email_domain_part(current_user.email) == new_domain:
flash(
"You cannot add a domain that you are currently using for your personal email. "
"Please change your personal email to your real email",
"error",
)
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{new_domain}")
).first():
flash(
f"{new_domain} already used in a SimpleLogin mailbox", "error"
)
else:
new_custom_domain = CustomDomain.create(
domain=new_domain, user_id=current_user.id
)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in current_user.custom_domains:
if (
new_domain.endswith("." + root_cd.domain)
and root_cd.ownership_verified
):
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
Session.commit()
mailbox_ids = request.form.getlist("mailbox_ids")
if mailbox_ids:
# check if mailbox is not tempered with
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if (
not mailbox
or mailbox.user_id != current_user.id
or not mailbox.verified
):
flash("Something went wrong, please retry", "warning")
return redirect(url_for("dashboard.custom_domain"))
mailboxes.append(mailbox)
for mailbox in mailboxes:
DomainMailbox.create(
domain_id=new_custom_domain.id, mailbox_id=mailbox.id
)
Session.commit()
flash(
f"New domain {new_custom_domain.domain} is created", "success"
)
res = create_custom_domain(
user=current_user, domain=new_custom_domain_form.domain.data
)
if res.success:
flash(f"New domain {res.instance.domain} is created", "success")
return redirect(
url_for(
"dashboard.domain_detail_dns",
custom_domain_id=new_custom_domain.id,
custom_domain_id=res.instance.id,
)
)
else:
flash(res.message, res.message_category)
if res.redirect:
return redirect(url_for(res.redirect))
return render_template(
"dashboard/custom_domain.html",
custom_domains=custom_domains,
new_custom_domain_form=new_custom_domain_form,
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
errors=errors,
mailboxes=mailboxes,
)

View File

@ -1,22 +1,16 @@
import re
import arrow
from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN
from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.custom_domain_utils import delete_custom_domain
from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp
from app.db import Session
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
is_mx_equivalent,
)
from app.log import LOG
from app.models import (
CustomDomain,
Alias,
@ -25,7 +19,6 @@ from app.models import (
DomainMailbox,
AutoCreateRule,
AutoCreateRuleMailbox,
Job,
)
from app.regex_utils import regex_match
from app.utils import random_string, CSRFValidationForm
@ -49,8 +42,6 @@ def domain_detail_dns(custom_domain_id):
domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm()
dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []
@ -59,15 +50,14 @@ def domain_detail_dns(custom_domain_id):
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "check-ownership":
txt_records = get_txt_record(custom_domain.domain)
if custom_domain.get_ownership_dns_txt_value() in txt_records:
ownership_validation_result = domain_validator.validate_domain_ownership(
custom_domain
)
if ownership_validation_result.success:
flash(
"Domain ownership is verified. Please proceed to the other records setup",
"success",
)
custom_domain.ownership_verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns",
@ -78,36 +68,28 @@ def domain_detail_dns(custom_domain_id):
else:
flash("We can't find the needed TXT record", "error")
ownership_ok = False
ownership_errors = txt_records
ownership_errors = ownership_validation_result.errors
elif request.form.get("form-name") == "check-mx":
mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
flash("The MX record is not correctly set", "warning")
mx_ok = False
# build mx_errors to show to user
mx_errors = [
f"{priority} {domain}" for (priority, domain) in mx_domains
]
else:
mx_validation_result = domain_validator.validate_mx_records(custom_domain)
if mx_validation_result.success:
flash(
"Your domain can start receiving emails. You can now use it to create alias",
"success",
)
custom_domain.verified = True
Session.commit()
return redirect(
url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
)
)
else:
flash("The MX record is not correctly set", "warning")
mx_ok = False
mx_errors = mx_validation_result.errors
elif request.form.get("form-name") == "check-spf":
spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
spf_validation_result = domain_validator.validate_spf_records(custom_domain)
if spf_validation_result.success:
flash("SPF is setup correctly", "success")
return redirect(
url_for(
@ -115,14 +97,12 @@ def domain_detail_dns(custom_domain_id):
)
)
else:
custom_domain.spf_verified = False
Session.commit()
flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning",
)
spf_ok = False
spf_errors = get_txt_record(custom_domain.domain)
spf_errors = spf_validation_result.errors
elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain)
@ -138,10 +118,10 @@ def domain_detail_dns(custom_domain_id):
flash("DKIM: the CNAME record is not correctly set", "warning")
elif request.form.get("form-name") == "check-dmarc":
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if dmarc_record in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
dmarc_validation_result = domain_validator.validate_dmarc_records(
custom_domain
)
if dmarc_validation_result.success:
flash("DMARC is setup correctly", "success")
return redirect(
url_for(
@ -149,19 +129,21 @@ def domain_detail_dns(custom_domain_id):
)
)
else:
custom_domain.dmarc_verified = False
Session.commit()
flash(
"DMARC: The TXT record is not correctly set",
"warning",
)
dmarc_ok = False
dmarc_errors = txt_records
dmarc_errors = dmarc_validation_result.errors
return render_template(
"dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(),
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
dkim_records=domain_validator.get_dkim_records(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)
@ -279,16 +261,8 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "delete":
name = custom_domain.domain
LOG.d("Schedule deleting %s", custom_domain)
# Schedule delete domain job
LOG.w("schedule delete domain job for %s", custom_domain)
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": custom_domain.id},
run_at=arrow.now(),
commit=True,
)
delete_custom_domain(custom_domain)
flash(
f"{name} scheduled for deletion."

View File

@ -123,7 +123,12 @@ def mailbox_verify():
if not code:
# Old way
return verify_with_signed_secret(mailbox_id)
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
try:
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
except mailbox_utils.MailboxError as e:
LOG.i(f"Cannot verify mailbox {mailbox_id} because of {e}")
flash(f"Cannot verify mailbox: {e.msg}", "error")
return redirect(url_for("dashboard.mailbox_route"))
LOG.d("Mailbox %s is verified", mailbox)
return render_template("dashboard/mailbox_validation.html", mailbox=mailbox)

View File

@ -1,100 +1,13 @@
from app import config
from typing import Optional, List, Tuple
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional
import dns.resolver
def _get_dns_resolver():
my_resolver = dns.resolver.Resolver()
my_resolver.nameservers = config.NAMESERVERS
return my_resolver
def get_ns(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
except Exception:
return []
return [a.to_text() for a in answers]
def get_cname_record(hostname) -> Optional[str]:
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
try:
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
except Exception:
return None
for a in answers:
ret = a.to_text()
return ret[:-1]
return None
def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
except Exception:
return []
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda prio_domain: prio_domain[0])
from app.config import NAMESERVERS
_include_spf = "include:"
def get_spf_domain(hostname) -> [str]:
"""return all domains listed in *include:*"""
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
return ret
def get_txt_record(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
ret.append(record)
return ret
def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
) -> bool:
@ -105,16 +18,127 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0])
ref_mx_domains = sorted(
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
)
mx_domains = sorted(mx_domains, key=lambda x: x[0])
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
if len(mx_domains) < len(ref_mx_domains):
return False
for i in range(0, len(ref_mx_domains)):
for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
return False
return True
class DNSClient(ABC):
@abstractmethod
def get_cname_record(self, hostname: str) -> Optional[str]:
pass
@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
pass
def get_spf_domain(self, hostname: str) -> List[str]:
"""
return all domains listed in *include:*
"""
try:
records = self.get_txt_record(hostname)
ret = []
for record in records:
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []
@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
pass
class NetworkDNSClient(DNSClient):
def __init__(self, nameservers: List[str]):
self._resolver = dns.resolver.Resolver()
self._resolver.nameservers = nameservers
def get_cname_record(self, hostname: str) -> Optional[str]:
"""
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
"""
try:
answers = self._resolver.resolve(hostname, "CNAME", search=True)
for a in answers:
ret = a.to_text()
return ret[:-1]
except Exception:
return None
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = self._resolver.resolve(hostname, "MX", search=True)
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
except Exception:
return []
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
ret.append(record.decode())
return ret
except Exception:
return []
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> [(int, str)]:
return get_network_dns_client().get_mx_domains(hostname)

View File

@ -30,14 +30,30 @@ class PostgresDispatcher(Dispatcher):
return PostgresDispatcher()
class GlobalDispatcher:
__dispatcher: Optional[Dispatcher] = None
@staticmethod
def get_dispatcher() -> Dispatcher:
if not GlobalDispatcher.__dispatcher:
GlobalDispatcher.__dispatcher = PostgresDispatcher.get()
return GlobalDispatcher.__dispatcher
@staticmethod
def set_dispatcher(dispatcher: Optional[Dispatcher]):
GlobalDispatcher.__dispatcher = dispatcher
class EventDispatcher:
@staticmethod
def send_event(
user: User,
content: event_pb2.EventContent,
dispatcher: Dispatcher = PostgresDispatcher.get(),
dispatcher: Optional[Dispatcher] = None,
skip_if_webhook_missing: bool = True,
):
if dispatcher is None:
dispatcher = GlobalDispatcher.get_dispatcher()
if config.EVENT_WEBHOOK_DISABLE:
LOG.i("Not sending events because webhook is disabled")
return
@ -54,9 +70,7 @@ class EventDispatcher:
partner_user = EventDispatcher.__partner_user(user.id)
if not partner_user:
LOG.i(
f"Not sending events because there's no partner user for user {user}"
)
LOG.i(f"Not sending events because there's no partner user for user {user}")
return
event = event_pb2.Event(
@ -68,7 +82,9 @@ class EventDispatcher:
serialized = event.SerializeToString()
dispatcher.send(serialized)
newrelic.agent.record_custom_metric("Custom/events_stored", 1)
event_type = content.WhichOneof("content")
newrelic.agent.record_custom_event("EventStoredToDb", {"type": event_type})
LOG.i("Sent event to the dispatcher")
@staticmethod

View File

@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"Z\n\x0c\x41liasCreated\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x12\n\nalias_note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\"L\n\x12\x41liasStatusChanged\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\"5\n\x0c\x41liasDeleted\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"\\\n\x0c\x41liasCreated\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0c\n\x04note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\x12\x12\n\ncreated_at\x18\x05 \x01(\r\"T\n\x12\x41liasStatusChanged\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\x12\x12\n\ncreated_at\x18\x04 \x01(\r\")\n\x0c\x41liasDeleted\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@ -36,15 +36,15 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_USERDELETED']._serialized_start=77
_globals['_USERDELETED']._serialized_end=90
_globals['_ALIASCREATED']._serialized_start=92
_globals['_ALIASCREATED']._serialized_end=182
_globals['_ALIASSTATUSCHANGED']._serialized_start=184
_globals['_ALIASSTATUSCHANGED']._serialized_end=260
_globals['_ALIASDELETED']._serialized_start=262
_globals['_ALIASDELETED']._serialized_end=315
_globals['_ALIASCREATEDLIST']._serialized_start=317
_globals['_ALIASCREATEDLIST']._serialized_end=385
_globals['_EVENTCONTENT']._serialized_start=388
_globals['_EVENTCONTENT']._serialized_end=791
_globals['_EVENT']._serialized_start=793
_globals['_EVENT']._serialized_end=914
_globals['_ALIASCREATED']._serialized_end=184
_globals['_ALIASSTATUSCHANGED']._serialized_start=186
_globals['_ALIASSTATUSCHANGED']._serialized_end=270
_globals['_ALIASDELETED']._serialized_start=272
_globals['_ALIASDELETED']._serialized_end=313
_globals['_ALIASCREATEDLIST']._serialized_start=315
_globals['_ALIASCREATEDLIST']._serialized_end=383
_globals['_EVENTCONTENT']._serialized_start=386
_globals['_EVENTCONTENT']._serialized_end=789
_globals['_EVENT']._serialized_start=791
_globals['_EVENT']._serialized_end=912
# @@protoc_insertion_point(module_scope)

View File

@ -16,34 +16,38 @@ class UserDeleted(_message.Message):
def __init__(self) -> None: ...
class AliasCreated(_message.Message):
__slots__ = ("alias_id", "alias_email", "alias_note", "enabled")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
ALIAS_NOTE_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("id", "email", "note", "enabled", "created_at")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
NOTE_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int]
alias_id: int
alias_email: str
alias_note: str
CREATED_AT_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
note: str
enabled: bool
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., alias_note: _Optional[str] = ..., enabled: bool = ...) -> None: ...
created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., note: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasStatusChanged(_message.Message):
__slots__ = ("alias_id", "alias_email", "enabled")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
__slots__ = ("id", "email", "enabled", "created_at")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int]
alias_id: int
alias_email: str
CREATED_AT_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
enabled: bool
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., enabled: bool = ...) -> None: ...
created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasDeleted(_message.Message):
__slots__ = ("alias_id", "alias_email")
ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
alias_id: int
alias_email: str
def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ...) -> None: ...
__slots__ = ("id", "email")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ...) -> None: ...
class AliasCreatedList(_message.Message):
__slots__ = ("events",)

View File

@ -22,10 +22,11 @@ def send_alias_creation_events_for_user(
):
event_list.append(
AliasCreated(
alias_id=alias.id,
alias_email=alias.email,
alias_note=alias.note,
id=alias.id,
email=alias.email,
note=alias.note,
enabled=alias.enabled,
created_at=int(alias.created_at.timestamp),
)
)
if len(event_list) >= chunk_size:

View File

@ -973,7 +973,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def has_custom_domain(self):
return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0
def custom_domains(self):
def custom_domains(self) -> List["CustomDomain"]:
return CustomDomain.filter_by(user_id=self.id, verified=True).all()
def available_domains_for_random_alias(
@ -1660,18 +1660,6 @@ class Alias(Base, ModelMixin):
Session.add(new_alias)
DailyMetric.get_or_create_today_metric().nb_alias += 1
# Internal import to avoid global import cycles
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
alias_id=new_alias.id,
alias_email=new_alias.email,
alias_note=new_alias.note,
enabled=True,
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
if (
new_alias.flags & cls.FLAG_PARTNER_CREATED > 0
and new_alias.user.flags & User.FLAG_CREATED_ALIAS_FROM_PARTNER == 0
@ -1684,6 +1672,19 @@ class Alias(Base, ModelMixin):
if flush:
Session.flush()
# Internal import to avoid global import cycles
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import AliasCreated, EventContent
event = AliasCreated(
id=new_alias.id,
email=new_alias.email,
note=new_alias.note,
enabled=True,
created_at=int(new_alias.created_at.timestamp),
)
EventDispatcher.send_event(user, EventContent(alias_created=event))
return new_alias
@classmethod
@ -1862,6 +1863,8 @@ class Contact(Base, ModelMixin):
MAX_NAME_LENGTH = 512
FLAG_PARTNER_CREATED = 1 << 0
__tablename__ = "contact"
__table_args__ = (
@ -1920,6 +1923,9 @@ class Contact(Base, ModelMixin):
# whether contact is created automatically during the forward phase
automatic_created = sa.Column(sa.Boolean, nullable=True, default=False)
# contact flags
flags = sa.Column(sa.Integer, nullable=False, default=0, server_default="0")
@property
def email(self):
return self.website_email
@ -2418,6 +2424,18 @@ class CustomDomain(Base, ModelMixin):
sa.Boolean, nullable=False, default=False, server_default="0"
)
partner_id = sa.Column(
sa.Integer,
sa.ForeignKey("partner.id"),
nullable=True,
default=None,
server_default=None,
)
pending_deletion = sa.Column(
sa.Boolean, nullable=False, default=False, server_default="0"
)
__table_args__ = (
Index(
"ix_unique_domain", # Index name
@ -2442,9 +2460,6 @@ class CustomDomain(Base, ModelMixin):
def get_trash_url(self):
return config.URL + f"/dashboard/domains/{self.id}/trash"
def get_ownership_dns_txt_value(self):
return f"sl-verification={self.ownership_txt_token}"
@classmethod
def create(cls, **kwargs):
domain = kwargs.get("domain")

View File

@ -2,9 +2,11 @@ from dataclasses import dataclass
from enum import Enum
from flask import url_for
from typing import Optional
import arrow
from app import config
from app.errors import LinkException
from app.models import User, Partner
from app.models import User, Partner, Job
from app.proton.proton_client import ProtonClient, ProtonUser
from app.account_linking import (
process_login_case,
@ -41,12 +43,21 @@ class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client
def _initial_alias_sync(self, user: User):
Job.create(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
payload={"user_id": user.id},
run_at=arrow.now(),
commit=True,
)
def handle_login(self, partner: Partner) -> ProtonCallbackResult:
try:
user = self.__get_partner_user()
if user is None:
return generate_account_not_allowed_to_log_in()
res = process_login_case(user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
@ -75,6 +86,7 @@ class ProtonCallbackHandler:
if user is None:
return generate_account_not_allowed_to_log_in()
res = process_link_case(user, current_user, partner)
self._initial_alias_sync(res.user)
return ProtonCallbackResult(
redirect_to_login=False,
flash_message="Account successfully linked",

View File

@ -52,7 +52,7 @@ from flanker.addresslib import address
from flanker.addresslib.address import EmailAddress
from sqlalchemy.exc import IntegrityError
from app import pgp_utils, s3, config
from app import pgp_utils, s3, config, contact_utils
from app.alias_utils import try_auto_create, change_alias_status
from app.config import (
EMAIL_DOMAIN,
@ -195,81 +195,16 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
mail_from,
)
contact_email = mail_from
if not is_valid_email(contact_email):
LOG.w(
"invalid contact email %s. Parse from %s %s",
contact_email,
from_header,
mail_from,
)
# either reuse a contact with empty email or create a new contact with empty email
contact_email = ""
contact_email = sanitize_email(contact_email, not_lower=True)
if contact_name and "\x00" in contact_name:
LOG.w("issue with contact name %s", contact_name)
contact_name = ""
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
if contact.name != contact_name:
LOG.d(
"Update contact %s name %s to %s",
contact,
contact.name,
contact_name,
)
contact.name = contact_name
Session.commit()
# contact created in the past does not have mail_from and from_header field
if not contact.mail_from and mail_from:
LOG.d(
"Set contact mail_from %s: %s to %s",
contact,
contact.mail_from,
mail_from,
)
contact.mail_from = mail_from
Session.commit()
else:
alias_id = alias.id
try:
contact_email_for_reply = (
contact_email if is_valid_email(contact_email) else ""
)
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias_id,
website_email=contact_email,
name=contact_name,
mail_from=mail_from,
reply_email=generate_reply_email(contact_email_for_reply, alias),
automatic_created=True,
)
if not contact_email:
LOG.d("Create a contact with invalid email for %s", alias)
contact.invalid_email = True
LOG.d(
"create contact %s for %s, reverse alias:%s",
contact_email,
alias,
contact.reply_email,
)
Session.commit()
except IntegrityError:
# If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh
Session.close()
LOG.info(
f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB"
)
contact = Contact.get_by(alias_id=alias_id, website_email=contact_email)
return contact
contact_result = contact_utils.create_contact(
email=contact_email,
name=contact_name,
alias=alias,
mail_from=mail_from,
allow_empty_email=True,
automatic_created=True,
from_partner=False,
)
return contact_result.contact
def get_or_create_reply_to_contact(
@ -294,33 +229,7 @@ def get_or_create_reply_to_contact(
)
return None
contact = Contact.get_by(alias_id=alias.id, website_email=contact_address)
if contact:
return contact
else:
LOG.d(
"create contact %s for alias %s via reply-to header %s",
contact_address,
alias,
reply_to_header,
)
try:
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_address,
name=contact_name,
reply_email=generate_reply_email(contact_address, alias),
automatic_created=True,
)
Session.commit()
except IntegrityError:
LOG.w("Contact %s %s already exist", alias, contact_address)
Session.rollback()
contact = Contact.get_by(alias_id=alias.id, website_email=contact_address)
return contact
return contact_utils.create_contact(contact_address, contact_name, alias).contact
def replace_header_when_forward(msg: Message, alias: Alias, header: str):
@ -818,7 +727,7 @@ def forward_email_to_mailbox(
email_log = EmailLog.create(
contact_id=contact.id,
user_id=user.id,
user_id=contact.user_id,
mailbox_id=mailbox.id,
alias_id=contact.alias_id,
message_id=str(msg[headers.MESSAGE_ID]),

View File

@ -9,7 +9,7 @@ from events.runner import Runner
from events.event_source import DeadLetterEventSource, PostgresEventSource
from events.event_sink import ConsoleEventSink, HttpEventSink
_DEFAULT_MAX_RETRIES = 100
_DEFAULT_MAX_RETRIES = 10
class Mode(Enum):

View File

@ -27,7 +27,9 @@ class HttpEventSink(EventSink):
headers={"Content-Type": "application/x-protobuf"},
verify=not EVENT_WEBHOOK_SKIP_VERIFY_SSL,
)
newrelic.agent.record_custom_event("event_sent", {"http_code": res.status_code})
newrelic.agent.record_custom_event(
"EventSentToPartner", {"http_code": res.status_code}
)
if res.status_code != 200:
LOG.warning(
f"Failed to send event to webhook: {res.status_code} {res.text}"

View File

@ -3,7 +3,7 @@ Run scheduled jobs.
Not meant for running job at precise time (+- 1h)
"""
import time
from typing import List
from typing import List, Optional
import arrow
from sqlalchemy.sql.expression import or_, and_
@ -240,7 +240,7 @@ def process_job(job: Job):
elif job.name == config.JOB_DELETE_DOMAIN:
custom_domain_id = job.payload.get("custom_domain_id")
custom_domain = CustomDomain.get(custom_domain_id)
custom_domain: Optional[CustomDomain] = CustomDomain.get(custom_domain_id)
if not custom_domain:
return
@ -252,16 +252,17 @@ def process_job(job: Job):
LOG.d("Domain %s deleted", domain_name)
send_email(
user.email,
f"Your domain {domain_name} has been deleted",
f"""Domain {domain_name} along with its aliases are deleted successfully.
if custom_domain.partner_id is None:
send_email(
user.email,
f"Your domain {domain_name} has been deleted",
f"""Domain {domain_name} along with its aliases are deleted successfully.
Regards,
SimpleLogin team.
""",
retries=3,
)
Regards,
SimpleLogin team.
""",
retries=3,
)
elif job.name == config.JOB_SEND_USER_REPORT:
export_job = ExportUserDataJob.create_from_job(job)
if export_job:

View File

@ -0,0 +1,30 @@
"""Custom Domain partner id
Revision ID: 2441b7ff5da9
Revises: 1c14339aae90
Create Date: 2024-09-13 15:43:02.425964
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '2441b7ff5da9'
down_revision = '1c14339aae90'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('custom_domain', sa.Column('partner_id', sa.Integer(), nullable=True, default=None, server_default=None))
op.create_foreign_key(None, 'custom_domain', 'partner', ['partner_id'], ['id'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'custom_domain', type_='foreignkey')
op.drop_column('custom_domain', 'partner_id')
# ### end Alembic commands ###

View File

@ -0,0 +1,31 @@
"""contact.flags and custom_domain.pending_deletion
Revision ID: 88dd7a0abf54
Revises: 2441b7ff5da9
Create Date: 2024-09-19 15:41:20.910374
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '88dd7a0abf54'
down_revision = '2441b7ff5da9'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('contact', sa.Column('flags', sa.Integer(), server_default='0', nullable=False))
op.add_column('custom_domain', sa.Column('pending_deletion', sa.Boolean(), server_default='0', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('custom_domain', 'pending_deletion')
op.drop_column('contact', 'flags')
# ### end Alembic commands ###

View File

@ -125,6 +125,21 @@ def log_events_pending_dead_letter():
)
@newrelic.agent.background_task()
def log_failed_events():
r = Session.execute(
"""
SELECT COUNT(*)
FROM sync_event
WHERE retries >= 10;
""",
)
failed_events = list(r)[0][0]
LOG.d("number of failed events %s", failed_events)
newrelic.agent.record_custom_metric("Custom/sync_events_failed", failed_events)
if __name__ == "__main__":
exporter = MetricExporter(get_newrelic_license())
while True:
@ -132,6 +147,7 @@ if __name__ == "__main__":
log_nb_db_connection()
log_pending_to_process_events()
log_events_pending_dead_letter()
log_failed_events()
Session.close()
exporter.run()

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.models import Alias
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_alias_id", default=0, type=int, help="Initial alias_id"
)
parser.add_argument("-e", "--end_alias_id", default=0, type=int, help="Last alias_id")
args = parser.parse_args()
alias_id_start = args.start_alias_id
max_alias_id = args.end_alias_id
if max_alias_id == 0:
max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Checking alias {alias_id_start} to {max_alias_id}")
step = 1000
noteSql = "(note = 'Created through Proton' or note = 'Created through partner Proton')"
alias_query = f"UPDATE alias set note = NULL, flags = flags | :flag where id>=:start AND id<:end and {noteSql}"
updated = 0
start_time = time.time()
for batch_start in range(alias_id_start, max_alias_id, step):
rows_done = Session.execute(
alias_query,
{
"start": batch_start,
"end": batch_start + step,
"flag": Alias.FLAG_PARTNER_CREATED,
},
)
updated += rows_done.rowcount
Session.commit()
elapsed = time.time() - start_time
time_per_alias = elapsed / (updated + 1)
last_batch_id = batch_start + step
remaining = max_alias_id - last_batch_id
time_remaining = (max_alias_id - last_batch_id) * time_per_alias
hours_remaining = time_remaining / 3600.0
print(
f"\rAlias {batch_start}/{max_alias_id} {updated} {hours_remaining:.2f}hrs remaining"
)
print("")

View File

@ -10,21 +10,23 @@ message UserDeleted {
}
message AliasCreated {
uint32 alias_id = 1;
string alias_email = 2;
string alias_note = 3;
uint32 id = 1;
string email = 2;
string note = 3;
bool enabled = 4;
uint32 created_at = 5;
}
message AliasStatusChanged {
uint32 alias_id = 1;
string alias_email = 2;
uint32 id = 1;
string email = 2;
bool enabled = 3;
uint32 created_at = 4;
}
message AliasDeleted {
uint32 alias_id = 1;
string alias_email = 2;
uint32 id = 1;
string email = 2;
}
message AliasCreatedList {

View File

@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres
sleep 3
# upgrade the DB to the latest stage and
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head
# generate the migration script.
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@
# remove the db
docker rm -f ${container_name}

View File

@ -19,7 +19,7 @@
<tbody>
<tr>
<td>{{ user.id }}</td>
<td>{{ user.email }}</td>
<td><a href="?email={{ user.email }}">{{ user.email }}</a></td>
{% if user.disabled %}
<td class="text-danger">Disabled</td>
@ -32,7 +32,7 @@
<td>{{ user.updated_at }}</td>
{% if pu %}
<td>{{ pu.partner_email }}</td>
<td><a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a></td>
{% else %}
<td>No</td>
{% endif %}
@ -43,7 +43,7 @@
{% macro list_mailboxes(mbox_count, mboxes) %}
<h4>
{{ mbox_count }} Mailboxes found.
{% if mbox_count>10 %}Showing only the first 10.{% endif %}
{% if mbox_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
@ -59,7 +59,7 @@
<tr>
<td>{{ mailbox.id }}</td>
<td>{{ mailbox.email }}</td>
<td><a href="?email={{mailbox.email}}">{{mailbox.email}}</a></td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td>
{{ mailbox.created_at }}
@ -72,7 +72,7 @@
{% macro list_alias(alias_count, aliases) %}
<h4>
{{ alias_count }} Aliases found.
{% if alias_count>10 %}Showing only the first 10.{% endif %}
{% if alias_count>10 %}Showing only the last 10.{% endif %}
</h4>
<table class="table">
<thead>
@ -95,7 +95,7 @@
{% for alias in aliases %}
<tr>
<td>{{ alias.id }}</td>
<td>{{ alias.email }}</td>
<td><a href="?email={{alias.email}}">{{alias.email}}</a></td>
<td>{{ "Yes" if alias.verified else "No" }}</td>
<td>{{ alias.created_at }}</td>
</tr>
@ -156,7 +156,7 @@
{% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="post">
<form method="get">
<div class="form-group">
<label for="email">Email to search:</label>
<input type="text"
@ -167,7 +167,7 @@
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
{% if no_match %}
{% if data.no_match and email %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ email }}</div>
@ -190,14 +190,19 @@
{{ list_alias(helper.alias_count(data.user) ,helper.alias_list(data.user)) }}
</div>
{% endif %}
{% if data.mailbox %}
{% if data.mailbox_count > 10 %}
<h3>Found more than 10 mailboxes for {{ email }}. Showing the last 10</h3>
{% elif data.mailbox_count > 0 %}
<h3>Found {{ data.mailbox_count }} mailbox(es) for {{ email }}</h3>
{% endif %}
{% for mailbox in data.mailbox %}
<div class="border border-dark mt-1 mb-2 p-3">
<h3 class="mb-3">Found Mailbox {{ data.mailbox.email }}</h3>
{{ list_mailboxes(1, [data.mailbox]) }}
{{ show_user(data.mailbox.user) }}
<h3 class="mb-3">Found Mailbox {{ mailbox.email }}</h3>
{{ list_mailboxes(1, [mailbox]) }}
{{ show_user(mailbox.user) }}
</div>
{% endif %}
{% endfor %}
{% if data.deleted_alias %}
<div class="border border-dark mt-1 mb-2 p-3">

View File

@ -94,4 +94,4 @@
</div>
</div>
{% endblock %}
{% block script %}<script>$('.mailbox-select').multipleSelect();</script>{% endblock %}

View File

@ -38,7 +38,7 @@
Value: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ custom_domain.get_ownership_dns_txt_value() }}">{{ custom_domain.get_ownership_dns_txt_value() }}</em>
data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
</div>
<form method="post" action="#ownership-form">
{{ csrf_form.csrf_token }}
@ -237,7 +237,7 @@
folder.
</div>
<div class="mb-2">Add the following CNAME DNS records to your domain.</div>
{% for dkim_prefix, dkim_cname_value in dkim_records %}
{% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
<div class="mb-2 p-3 dns-record">
Record: CNAME

View File

@ -21,8 +21,8 @@
<div class="my-3">
<p>Alternatively you can use your Proton credentials to ensure it's you.</p>
</div>
<a class="btn btn-primary btn-block mt-2 proton-button w-25"
href="{{ url_for('auth.proton_login', next=next) }}">
<a class="btn btn-primary btn-block mt-2 proton-button"
href="{{ url_for('auth.proton_login', next=next) }}" style="max-width: 400px">
<img class="mr-2" src="/static/images/proton.svg" />
Authenticate with Proton
</a>
@ -38,4 +38,4 @@
{% endif %}
</div>
</div>
{% endblock %}
{% endblock %}

View File

@ -0,0 +1,32 @@
from app.events.event_dispatcher import Dispatcher
from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
from typing import Tuple
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def clear(self):
self.memory = []
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user

View File

@ -1,33 +1,10 @@
from app.events.event_dispatcher import EventDispatcher, Dispatcher
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import EventContent, UserDeleted
from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
from typing import Tuple
class OnMemoryDispatcher(Dispatcher):
def __init__(self):
self.memory = []
def send(self, event: bytes):
self.memory.append(event)
def _create_unlinked_user() -> User:
return create_new_user()
def _create_linked_user() -> Tuple[User, PartnerUser]:
user = _create_unlinked_user()
partner_user = PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
return user, partner_user
from .event_test_utils import (
_create_unlinked_user,
OnMemoryDispatcher,
_create_linked_user,
)
def test_event_dispatcher_stores_events():

View File

@ -0,0 +1,102 @@
from app import config, alias_utils
from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.events.generated import event_pb2
from app.models import Alias, User, PartnerUser
from tests.utils import random_token
from .event_test_utils import (
OnMemoryDispatcher,
_create_linked_user,
)
on_memory_dispatcher = OnMemoryDispatcher()
def setup_module():
GlobalDispatcher.set_dispatcher(on_memory_dispatcher)
config.EVENT_WEBHOOK = "http://test"
def teardown_module():
GlobalDispatcher.set_dispatcher(None)
config.EVENT_WEBHOOK = None
def setup_function(func):
on_memory_dispatcher.clear()
def _get_event_from_string(
data: str, user: User, pu: PartnerUser
) -> event_pb2.EventContent:
event = event_pb2.Event()
event.ParseFromString(data)
assert user.id == event.user_id
assert pu.external_user_id == event.external_user_id
assert pu.partner_id == event.partner_id
return event.content
def test_fire_event_on_alias_creation():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert "" == alias_created.note
assert alias.enabled == alias_created.enabled
assert int(alias.created_at.timestamp) == alias_created.created_at
def test_fire_event_on_alias_creation_with_note():
(user, pu) = _create_linked_user()
note = random_token(10)
alias = Alias.create_new_random(user, note=note)
Session.flush()
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_created is not None
alias_created = event_content.alias_created
assert alias.id == alias_created.id
assert alias.email == alias_created.email
assert note == alias_created.note
assert alias.enabled == alias_created.enabled
def test_fire_event_on_alias_deletion():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
alias_id = alias.id
Session.flush()
on_memory_dispatcher.clear()
alias_utils.delete_alias(alias, user)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_deleted is not None
alias_deleted = event_content.alias_deleted
assert alias_id == alias_deleted.id
assert alias.email == alias_deleted.email
def test_fire_event_on_alias_status_change():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)
Session.flush()
on_memory_dispatcher.clear()
alias_utils.change_alias_status(alias, True)
assert len(on_memory_dispatcher.memory) == 1
event_data = on_memory_dispatcher.memory[0]
event_content = _get_event_from_string(event_data, user, pu)
assert event_content.alias_status_change is not None
event = event_content.alias_status_change
assert alias.id == event.id
assert alias.email == event.email
assert int(alias.created_at.timestamp) == event.created_at
assert event.enabled

View File

@ -37,10 +37,14 @@ def test_send_alias_creation_events():
event_list = decoded_event.content.alias_create_list.events
assert len(event_list) == 2
# 0 is newsletter alias
assert event_list[1].alias_id == aliases[0].id
assert event_list[1].id == aliases[0].id
assert event_list[1].email == aliases[0].email
assert event_list[1].note == ""
assert event_list[1].enabled == aliases[0].enabled
assert event_list[1].created_at == int(aliases[0].created_at.timestamp)
decoded_event = event_pb2.Event.FromString(dispatcher.events[1])
assert decoded_event.user_id == user.id
assert decoded_event.external_user_id == partner_user.external_user_id
event_list = decoded_event.content.alias_create_list.events
assert len(event_list) == 1
assert event_list[0].alias_id == aliases[1].id
assert event_list[0].id == aliases[1].id

View File

@ -1,4 +1,6 @@
from arrow import Arrow
from app import config
from app.account_linking import (
SLPlan,
SLPlanType,
@ -8,7 +10,7 @@ from app.proton.proton_callback_handler import (
ProtonCallbackHandler,
generate_account_not_allowed_to_log_in,
)
from app.models import User, PartnerUser
from app.models import User, PartnerUser, Job, JobState
from app.proton.utils import get_proton_partner
from app.utils import random_string
from typing import Optional
@ -23,6 +25,17 @@ class MockProtonClient(ProtonClient):
return self.user
def check_initial_sync_job(user: User):
for job in Job.yield_per_query(10).filter_by(
name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
state=JobState.ready.value,
):
if job.payload.get("user_id") == user.id:
Job.delete(job.id)
return
assert False
def test_proton_callback_handler_unexistant_sl_user():
email = random_email()
name = random_string()
@ -56,6 +69,7 @@ def test_proton_callback_handler_unexistant_sl_user():
)
assert partner_user is not None
assert partner_user.external_user_id == external_id
check_initial_sync_job(res.user)
def test_proton_callback_handler_existant_sl_user():
@ -84,6 +98,7 @@ def test_proton_callback_handler_existant_sl_user():
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id)
assert sa is not None
assert sa.partner_email == user.email
check_initial_sync_job(res.user)
def test_proton_callback_handler_none_user_login():

View File

@ -0,0 +1,117 @@
from typing import Optional
import pytest
from app.contact_utils import create_contact, ContactCreateError
from app.db import Session
from app.models import (
Alias,
Contact,
)
from tests.utils import create_new_user, random_email, random_token
def create_provider():
# name auto_created from_partner
yield ["name", "a@b.c", True, True]
yield [None, None, True, True]
yield [None, None, False, True]
yield [None, None, True, False]
yield [None, None, False, False]
@pytest.mark.parametrize(
"name, mail_from, automatic_created, from_partner", create_provider()
)
def test_create_contact(
name: Optional[str],
mail_from: Optional[str],
automatic_created: bool,
from_partner: bool,
):
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(
email,
name,
alias,
mail_from=mail_from,
automatic_created=automatic_created,
from_partner=from_partner,
)
assert contact_result.error is None
contact = contact_result.contact
assert contact.user_id == user.id
assert contact.alias_id == alias.id
assert contact.website_email == email
assert contact.name == name
assert contact.mail_from == mail_from
assert contact.automatic_created == automatic_created
assert not contact.invalid_email
expected_flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
assert contact.flags == expected_flags
def test_create_contact_email_email_not_allowed():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("", "", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
def test_create_contact_email_email_allowed():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("", "", alias, allow_empty_email=True)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == ""
assert contact_result.contact.invalid_email
def test_do_not_allow_invalid_email():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
contact_result = create_contact("potato", "", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
contact_result = create_contact("asdf\x00@gmail.com", "", alias)
assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail
def test_update_name_for_existing():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, "", alias)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.name == ""
name = random_token()
contact_result = create_contact(email, name, alias)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.name == name
def test_update_mail_from_for_existing():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, "", alias)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.mail_from is None
mail_from = random_email()
contact_result = create_contact(email, "", alias, mail_from=mail_from)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.mail_from == mail_from

View File

@ -0,0 +1,149 @@
from typing import Optional
from app import config
from app.config import ALIAS_DOMAINS
from app.custom_domain_utils import (
can_domain_be_used,
create_custom_domain,
is_valid_domain,
sanitize_domain,
CannotUseDomainReason,
)
from app.db import Session
from app.models import User, CustomDomain, Mailbox
from tests.utils import get_proton_partner
from tests.utils import create_new_user, random_string, random_domain
user: Optional[User] = None
def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()
# is_valid_domain
def test_is_valid_domain():
assert is_valid_domain("example.com") is True
assert is_valid_domain("sub.example.com") is True
assert is_valid_domain("ex-ample.com") is True
assert is_valid_domain("-example.com") is False
assert is_valid_domain("example-.com") is False
assert is_valid_domain("exa_mple.com") is False
assert is_valid_domain("example..com") is False
assert is_valid_domain("") is False
assert is_valid_domain("a" * 64 + ".com") is False
assert is_valid_domain("a" * 63 + ".com") is True
assert is_valid_domain("example.com.") is True
assert is_valid_domain(".example.com") is False
assert is_valid_domain("example..com") is False
assert is_valid_domain("example.com-") is False
# can_domain_be_used
def test_can_domain_be_used():
domain = f"{random_string(10)}.com"
res = can_domain_be_used(user, domain)
assert res is None
def test_can_domain_be_used_existing_domain():
domain = random_domain()
CustomDomain.create(user_id=user.id, domain=domain, commit=True)
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainAlreadyUsed
def test_can_domain_be_used_sl_domain():
domain = ALIAS_DOMAINS[0]
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.BuiltinDomain
def test_can_domain_be_used_domain_of_user_email():
domain = user.email.split("@")[1]
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainPartOfUserEmail
def test_can_domain_be_used_domain_of_existing_mailbox():
domain = random_domain()
Mailbox.create(user_id=user.id, email=f"email@{domain}", verified=True, commit=True)
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.DomainUserInMailbox
def test_can_domain_be_used_invalid_domain():
domain = f"{random_string(10)}@lol.com"
res = can_domain_be_used(user, domain)
assert res is CannotUseDomainReason.InvalidDomain
# sanitize_domain
def test_can_sanitize_domain_empty():
assert sanitize_domain("") == ""
def test_can_sanitize_domain_starting_with_http():
domain = "test.domain"
assert sanitize_domain(f"http://{domain}") == domain
def test_can_sanitize_domain_starting_with_https():
domain = "test.domain"
assert sanitize_domain(f"https://{domain}") == domain
def test_can_sanitize_domain_correct_domain():
domain = "test.domain"
assert sanitize_domain(domain) == domain
# create_custom_domain
def test_can_create_custom_domain():
domain = random_domain()
res = create_custom_domain(user=user, domain=domain)
assert res.success is True
assert res.redirect is None
assert res.message == ""
assert res.message_category == ""
assert res.instance is not None
assert res.instance.domain == domain
assert res.instance.user_id == user.id
def test_can_create_custom_domain_validates_if_parent_is_validated():
root_domain = random_domain()
subdomain = f"{random_string(10)}.{root_domain}"
# Create custom domain with the root domain
CustomDomain.create(
user_id=user.id,
domain=root_domain,
verified=True,
ownership_verified=True,
commit=True,
)
# Create custom domain with subdomain. Should automatically be verified
res = create_custom_domain(user=user, domain=subdomain)
assert res.success is True
assert res.instance.domain == subdomain
assert res.instance.user_id == user.id
assert res.instance.ownership_verified is True
def test_creates_custom_domain_with_partner_id():
domain = random_domain()
proton_partner = get_proton_partner()
res = create_custom_domain(user=user, domain=domain, partner_id=proton_partner.id)
assert res.success is True
assert res.instance.domain == domain
assert res.instance.user_id == user.id
assert res.instance.partner_id == proton_partner.id

View File

@ -0,0 +1,376 @@
from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient
from app.proton.utils import get_proton_partner
from app.utils import random_string
from tests.utils import create_new_user, random_domain
user: Optional[User] = None
def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()
def create_custom_domain(domain: str) -> CustomDomain:
return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
def test_custom_domain_validation_get_dkim_records():
domain = random_domain()
custom_domain = create_custom_domain(domain)
validator = CustomDomainValidation(domain)
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
def test_custom_domain_validation_get_dkim_records_for_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dkim_records(domain)
assert len(res) == 3
for record_value in res.values():
assert record_value == "empty"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 2
for record_value in res.values():
assert record_value == "wrong"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, other domains missing
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
# DKIM is verified
domain.dkim_verified = True
Session.commit()
res = validator.validate_dkim_records(domain)
assert len(res) == 2
assert f"dkim02._domainkey.{user_domain}" in res
assert f"dkim03._domainkey.{user_domain}" in res
# Flag is not cleared
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
def test_custom_domain_validation_validate_dkim_records_success():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
# validate_ownership
def test_custom_domain_validation_validate_ownership_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
def test_custom_domain_validation_validate_ownership_from_partner_success():
dns_client = InMemoryDNSClient()
partner_id = get_proton_partner().id
prefix = random_string()
validator = CustomDomainValidation(
random_domain(),
dns_client,
partner_domains_validation_prefixes={partner_id: prefix},
)
domain = create_custom_domain(random_domain())
domain.partner_id = partner_id
Session.commit()
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
# validate_mx_records
def test_custom_domain_validation_validate_mx_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_mx_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_record_1 = random_string()
wrong_record_2 = random_string()
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain)
assert res.success is False
assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
res = validator.validate_mx_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is True
# validate_spf_records
def test_custom_domain_validation_validate_spf_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_spf_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
res = validator.validate_spf_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True
# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
res = validator.validate_dmarc_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is True

View File

@ -1,10 +1,12 @@
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
get_network_dns_client,
is_mx_equivalent,
InMemoryDNSClient,
)
from tests.utils import random_domain
# use our own domain for test
_DOMAIN = "simplelogin.io"
@ -20,12 +22,12 @@ def test_get_mx_domains():
def test_get_spf_domain():
r = get_spf_domain(_DOMAIN)
r = get_network_dns_client().get_spf_domain(_DOMAIN)
assert r == ["simplelogin.co"]
def test_get_txt_record():
r = get_txt_record(_DOMAIN)
r = get_network_dns_client().get_txt_record(_DOMAIN)
assert len(r) > 0
@ -46,3 +48,15 @@ def test_is_mx_equivalent():
[(5, "domain1"), (10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")],
)
def test_get_spf_record():
client = InMemoryDNSClient()
sl_domain = random_domain()
domain = random_domain()
spf_record = f"v=spf1 include:{sl_domain}"
client.set_txt_record(domain, [spf_record, "another record"])
res = client.get_spf_domain(domain)
assert res == [sl_domain]