Compare commits


No commits in common. "main" and "4.64.4" have entirely different histories.
main ... 4.64.4

111 changed files with 1000 additions and 1984 deletions


@@ -27,6 +27,11 @@ jobs:
           sudo apt update
           sudo apt install -y libre2-dev libpq-dev
+      - name: "Set up Python"
+        uses: actions/setup-python@v5
+        with:
+          python-version-file: "pyproject.toml"
       - name: Install dependencies
         if: steps.setup-uv.outputs.cache-hit != 'true'
         run: uv sync --locked --all-extras
@@ -81,6 +86,11 @@ jobs:
           sudo apt update
           sudo apt install -y libre2-dev libpq-dev
+      - name: "Set up Python"
+        uses: actions/setup-python@v5
+        with:
+          python-version-file: "pyproject.toml"
       - name: Install dependencies
         if: steps.setup-uv.outputs.cache-hit != 'true'
         run: uv sync --locked --all-extras
@@ -97,7 +107,7 @@ jobs:
       - name: Prepare version file
        run: |
-          scripts/generate-build-info.sh ${{ github.sha }} ${{ github.ref_name }}
+          scripts/generate-build-info.sh ${{ github.sha }}
           cat app/build_info.py
       - name: Test with pytest
@@ -154,7 +164,7 @@ jobs:
      - name: Prepare version file
       run: |
-          scripts/generate-build-info.sh ${{ github.sha }} ${{ github.ref_name }}
+          scripts/generate-build-info.sh ${{ github.sha }}
          cat app/build_info.py
      - name: Build image and publish to Docker Registry


@@ -1 +1 @@
-3.12.8
+3.10.16


@@ -215,7 +215,7 @@ python email_handler.py
 4) Send a test email
 ```bash
-swaks --to e1@sl.lan --from hey@google.com --server 127.0.0.1:20381
+swaks --to e1@sl.local --from hey@google.com --server 127.0.0.1:20381
 ```
 Now open http://localhost:1080/ (or http://localhost:1080/ for MailHog), you should see the forwarded email.
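If swaks is not available, the same test message can be sent from Python's standard library. This is only a sketch under the same assumptions as the command above: the local email handler is listening on 127.0.0.1:20381 and the alias already exists (e1@sl.lan on one side of this diff, e1@sl.local on the other).

```python
# Hypothetical equivalent of the swaks command above; only smtplib from the
# standard library is used, so nothing extra needs to be installed.
import smtplib
from email.message import EmailMessage

msg = EmailMessage()
msg["From"] = "hey@google.com"
msg["To"] = "e1@sl.lan"  # use e1@sl.local on the older configuration
msg["Subject"] = "local test"
msg.set_content("hello from a local test")

with smtplib.SMTP("127.0.0.1", 20381) as smtp:
    smtp.send_message(msg)
```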


@@ -1,25 +1,21 @@
 from __future__ import annotations
 from typing import Optional, List
 import arrow
 import sqlalchemy
-from flask import redirect, url_for, request, flash, Response
 from flask_admin import BaseView
-from flask_admin import expose, AdminIndexView
-from flask_admin.actions import action
-from flask_admin.contrib import sqla
 from flask_admin.form import SecureForm
 from flask_admin.model.template import EndpointLinkRowAction
-from flask_login import current_user
 from markupsafe import Markup
 from app import models, s3, config
-from app.custom_domain_validation import (
-    CustomDomainValidation,
-    DomainValidationResult,
-    ExpectedValidationRecords,
-)
+from flask import redirect, url_for, request, flash, Response
+from flask_admin import expose, AdminIndexView
+from flask_admin.actions import action
+from flask_admin.contrib import sqla
+from flask_login import current_user
+from app.custom_domain_validation import CustomDomainValidation, DomainValidationResult
 from app.db import Session
 from app.dns_utils import get_network_dns_client
 from app.events.event_dispatcher import EventDispatcher
@@ -48,7 +44,6 @@ from app.models import (
     CustomDomain,
 )
 from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address
-from app.proton.proton_unlink import perform_proton_account_unlink
 from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@@ -126,7 +121,7 @@ class SLAdminIndexView(AdminIndexView):
         if not current_user.is_authenticated or not current_user.is_admin:
             return redirect(url_for("auth.login", next=request.url))
-        return redirect(url_for("admin.email_search.index"))
+        return redirect("/admin/email_search")
 class UserAdmin(SLModelView):
@@ -918,7 +913,7 @@ class EmailSearchAdmin(BaseView):
     @expose("/", methods=["GET", "POST"])
     def index(self):
         search = EmailSearchResult()
-        email = request.args.get("query")
+        email = request.args.get("email")
         if email is not None and len(email) > 0:
             email = email.strip()
             search = EmailSearchResult.from_request_email(email)
@@ -930,48 +925,17 @@ class EmailSearchAdmin(BaseView):
             helper=EmailSearchHelpers,
         )
-    @expose("/partner_unlink", methods=["POST"])
-    def delete_partner_link(self):
-        user_id = request.form.get("user_id")
-        if not user_id:
-            flash("Missing user_id", "error")
-            return redirect(url_for("admin.email_search.index"))
-        try:
-            user_id = int(user_id)
-        except ValueError:
-            flash("Missing user_id", "error")
-            return redirect(url_for("admin.email_search.index", query=user_id))
-        user = User.get(user_id)
-        if user is None:
-            flash("User not found", "error")
-            return redirect(url_for("admin.email_search.index", query=user_id))
-        external_user_id = perform_proton_account_unlink(user, skip_check=True)
-        if not external_user_id:
-            flash("User unlinked", "success")
-            return redirect(url_for("admin.email_search.index", query=user_id))
-        AdminAuditLog.create(
-            admin_user_id=user.id,
-            model=User.__class__.__name__,
-            model_id=user.id,
-            action=AuditLogActionEnum.unlink_user.value,
-            data={"external_user_id": external_user_id},
-        )
-        Session.commit()
-        return redirect(url_for("admin.email_search.index", query=user_id))
 class CustomDomainWithValidationData:
     def __init__(self, domain: CustomDomain):
         self.domain: CustomDomain = domain
-        self.ownership_expected: Optional[ExpectedValidationRecords] = None
+        self.ownership_expected: Optional[str] = None
         self.ownership_validation: Optional[DomainValidationResult] = None
-        self.mx_expected: Optional[dict[int, ExpectedValidationRecords]] = None
+        self.mx_expected: Optional[str] = None
         self.mx_validation: Optional[DomainValidationResult] = None
-        self.spf_expected: Optional[ExpectedValidationRecords] = None
+        self.spf_expected: Optional[str] = None
         self.spf_validation: Optional[DomainValidationResult] = None
-        self.dkim_expected: {str: ExpectedValidationRecords} = {}
+        self.dkim_expected: {str: str} = {}
         self.dkim_validation: {str: str} = {}
@@ -1026,6 +990,7 @@ class CustomDomainSearchResult:
                 custom_domain
             )
             out.domains.append(validation_data)
+            print(validation_data.dkim_expected, validation_data.dkim_validation)
         return out
@@ -1055,6 +1020,7 @@ class CustomDomainSearchAdmin(BaseView):
         if cd is not None:
             user = cd.user
         search = CustomDomainSearchResult.from_user(user)
+        print("NEW", search.domains)
         return self.render(
             "admin/custom_domain_search.html",


@@ -36,7 +36,6 @@ def set_mailboxes_for_alias(
             Mailbox.user_id == user_id,
             Mailbox.verified == True,  # noqa: E712
         )
-        .order_by(Mailbox.id.asc())
         .all()
     )
     if len(mailboxes) != len(mailbox_ids):


@ -191,8 +191,15 @@ def get_alias_infos_with_pagination_v3(
q = q.order_by(Alias.email.desc()) q = q.order_by(Alias.email.desc())
else: else:
# default sorting # default sorting
latest_activity = case(
[
(Alias.created_at > EmailLog.created_at, Alias.created_at),
(Alias.created_at < EmailLog.created_at, EmailLog.created_at),
],
else_=Alias.created_at,
)
q = q.order_by(Alias.pinned.desc()) q = q.order_by(Alias.pinned.desc())
q = q.order_by(func.greatest(Alias.created_at, EmailLog.created_at).desc()) q = q.order_by(latest_activity.desc())
q = q.limit(page_limit).offset(page_id * page_size) q = q.limit(page_limit).offset(page_id * page_size)
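Both ordering expressions in the hunk above compute a "latest activity" timestamp: the `case()` form compares the two columns explicitly, while `func.greatest()` delegates to SQL's GREATEST. A minimal, self-contained sketch of the equivalence, assuming SQLAlchemy 1.4+ and a hypothetical table with two timestamp columns (NULL handling of GREATEST can differ by backend):

```python
from sqlalchemy import Column, DateTime, Integer, MetaData, Table, case, func, select

metadata = MetaData()
# Hypothetical stand-in for the Alias/EmailLog join used in the real query
item = Table(
    "item",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("created_at", DateTime),
    Column("last_email_at", DateTime),
)

# CASE-based "latest activity" ordering
latest_activity = case(
    (item.c.created_at > item.c.last_email_at, item.c.created_at),
    (item.c.created_at < item.c.last_email_at, item.c.last_email_at),
    else_=item.c.created_at,
)
q_case = select(item.c.id).order_by(latest_activity.desc())

# GREATEST-based equivalent
q_greatest = select(item.c.id).order_by(
    func.greatest(item.c.created_at, item.c.last_email_at).desc()
)

print(q_case)      # ... ORDER BY CASE WHEN ... END DESC
print(q_greatest)  # ... ORDER BY greatest(item.created_at, item.last_email_at) DESC
```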


@@ -1,6 +1,7 @@
 import secrets
 import string
+import facebook
 import google.oauth2.credentials
 import googleapiclient.discovery
 from flask import jsonify, request
@@ -260,8 +261,6 @@ def auth_facebook():
     }
     """
-    import facebook
     data = request.get_json()
     if not data:
         return jsonify(error="request body cannot be empty"), 400


@@ -62,17 +62,8 @@ def new_custom_alias_v2():
     if not data:
         return jsonify(error="request body cannot be empty"), 400
-    alias_prefix = data.get("alias_prefix", "")
-    if not isinstance(alias_prefix, str) or not alias_prefix:
-        return jsonify(error="invalid value for alias_prefix"), 400
-    alias_prefix = alias_prefix.strip().lower().replace(" ", "")
-    signed_suffix = data.get("signed_suffix", "")
-    if not isinstance(signed_suffix, str) or not signed_suffix:
-        return jsonify(error="invalid value for signed_suffix"), 400
-    signed_suffix = signed_suffix.strip()
+    alias_prefix = data.get("alias_prefix", "").strip().lower().replace(" ", "")
+    signed_suffix = data.get("signed_suffix", "").strip()
     note = data.get("note")
     alias_prefix = convert_to_id(alias_prefix)


@@ -12,7 +12,7 @@ from app.models import (
     SenderFormatEnum,
     AliasSuffixEnum,
 )
-from app.proton.proton_unlink import perform_proton_account_unlink
+from app.proton.utils import perform_proton_account_unlink
 def setting_to_dict(user: User):


@@ -2,7 +2,7 @@ from flask import jsonify, g
 from sqlalchemy_utils.types.arrow import arrow
 from app.api.base import api_bp, require_api_sudo, require_api_auth
-from app.constants import JobType
+from app import config
 from app.extensions import limiter
 from app.log import LOG
 from app.models import Job, ApiToCookieToken
@@ -24,7 +24,7 @@ def delete_user():
     )
     LOG.w("schedule delete account job for %s", g.user)
     Job.create(
-        name=JobType.DELETE_ACCOUNT.value,
+        name=config.JOB_DELETE_ACCOUNT,
         payload={"user_id": g.user.id},
         run_at=arrow.now(),
         commit=True,
@@ -44,8 +44,6 @@ def get_api_session_token():
         token: "asdli3ldq39h9hd3",
     }
     """
-    if not g.api_key:
-        return jsonify(ok=False), 401
     token = ApiToCookieToken.create(
         user=g.user,
         api_key_id=g.api_key.id,


@@ -12,7 +12,7 @@ from app.dashboard.views.index import get_stats
 from app.db import Session
 from app.image_validation import detect_image_format, ImageFormat
 from app.models import ApiKey, File, PartnerUser, User
-from app.proton.proton_partner import get_proton_partner
+from app.proton.utils import get_proton_partner
 from app.session import logout_session
 from app.utils import random_string


@@ -23,7 +23,7 @@ from app.proton.proton_callback_handler import (
     ProtonCallbackHandler,
     Action,
 )
-from app.proton.proton_partner import get_proton_partner
+from app.proton.utils import get_proton_partner
 from app.utils import sanitize_next_url, sanitize_scheme
 _authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"


@@ -1,3 +1,2 @@
 SHA1 = "dev"
 BUILD_TIME = "1652365083"
-VERSION = SHA1


@@ -62,17 +62,6 @@ def get_env_dict(env_var: str) -> dict[str, str]:
     return result
-def get_env_csv(env_var: str, default: Optional[str]) -> list[str]:
-    """
-    Get an env variable and convert it into a list of strings separated by,
-    Syntax is: val1,val2
-    """
-    value = os.getenv(env_var, default)
-    if not value:
-        return []
-    return [field.strip() for field in value.split(",") if field.strip()]
 config_file = os.environ.get("CONFIG")
 if config_file:
     config_file = get_abs_path(config_file)
@@ -182,14 +171,6 @@ FIRST_ALIAS_DOMAIN = os.environ.get("FIRST_ALIAS_DOMAIN") or EMAIL_DOMAIN
 # e.g. [(10, "mx1.hostname."), (10, "mx2.hostname.")]
 EMAIL_SERVERS_WITH_PRIORITY = sl_getenv("EMAIL_SERVERS_WITH_PRIORITY")
-PROTON_MX_SERVERS = get_env_csv(
-    "PROTON_MX_SERVERS", "mail.protonmail.ch., mailsec.protonmail.ch."
-)
-PROTON_EMAIL_DOMAINS = get_env_csv(
-    "PROTON_EMAIL_DOMAINS", "proton.me, protonmail.com, protonmail.ch, proton.ch, pm.me"
-)
 # disable the alias suffix, i.e. the ".random_word" part
 DISABLE_ALIAS_SUFFIX = "DISABLE_ALIAS_SUFFIX" in os.environ
@@ -316,6 +297,20 @@ MFA_USER_ID = "mfa_user_id"
 FLASK_PROFILER_PATH = os.environ.get("FLASK_PROFILER_PATH")
 FLASK_PROFILER_PASSWORD = os.environ.get("FLASK_PROFILER_PASSWORD")
+# Job names
+JOB_ONBOARDING_1 = "onboarding-1"
+JOB_ONBOARDING_2 = "onboarding-2"
+JOB_ONBOARDING_3 = "onboarding-3"
+JOB_ONBOARDING_4 = "onboarding-4"
+JOB_BATCH_IMPORT = "batch-import"
+JOB_DELETE_ACCOUNT = "delete-account"
+JOB_DELETE_MAILBOX = "delete-mailbox"
+JOB_DELETE_DOMAIN = "delete-domain"
+JOB_SEND_USER_REPORT = "send-user-report"
+JOB_SEND_PROTON_WELCOME_1 = "proton-welcome-1"
+JOB_SEND_ALIAS_CREATION_EVENTS = "send-alias-creation-events"
+JOB_SEND_EVENT_TO_WEBHOOK = "send-event-to-webhook"
 # for pagination
 PAGE_LIMIT = 20


@@ -1,18 +1,2 @@
-import enum
 HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
 DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
-class JobType(enum.Enum):
-    ONBOARDING_1 = "onboarding-1"
-    ONBOARDING_2 = "onboarding-2"
-    ONBOARDING_4 = "onboarding-4"
-    BATCH_IMPORT = "batch-import"
-    DELETE_ACCOUNT = "delete-account"
-    DELETE_MAILBOX = "delete-mailbox"
-    DELETE_DOMAIN = "delete-domain"
-    SEND_USER_REPORT = "send-user-report"
-    SEND_PROTON_WELCOME_1 = "proton-welcome-1"
-    SEND_ALIAS_CREATION_EVENTS = "send-alias-creation-events"
-    SEND_EVENT_TO_WEBHOOK = "send-event-to-webhook"
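A recurring pattern in this diff is the swap between these JobType enum members and the module-level JOB_* strings shown in the previous file: call sites change between `config.JOB_X` and `JobType.X.value`. A minimal sketch of the equivalence, using names taken from the hunks above:

```python
import enum

# config-style constant (one side of the diff)
JOB_DELETE_ACCOUNT = "delete-account"


# enum-style grouping (the other side)
class JobType(enum.Enum):
    DELETE_ACCOUNT = "delete-account"


# Both spellings feed the same string into Job.create(name=...)
assert JobType.DELETE_ACCOUNT.value == JOB_DELETE_ACCOUNT
```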


@@ -9,14 +9,7 @@ from app.email_utils import send_email
 from app.events.event_dispatcher import EventDispatcher
 from app.events.generated.event_pb2 import EventContent, UserPlanChanged
 from app.log import LOG
-from app.models import (
-    User,
-    ManualSubscription,
-    Coupon,
-    LifetimeCoupon,
-    PartnerSubscription,
-    PartnerUser,
-)
+from app.models import User, ManualSubscription, Coupon, LifetimeCoupon
 from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@@ -71,18 +64,13 @@ def redeem_coupon(coupon_code: str, user: User) -> Optional[Coupon]:
         else:
             sub.end_at = arrow.now().shift(years=coupon.nb_year, days=1)
     else:
-        # There may be an expired manual subscription
-        sub = ManualSubscription.get_by(user_id=user.id)
-        end_at = arrow.now().shift(years=coupon.nb_year, days=1)
-        if sub:
-            sub.end_at = end_at
-        else:
-            sub = ManualSubscription.create(
-                user_id=user.id,
-                end_at=end_at,
-                comment="using coupon code",
-                is_giveaway=coupon.is_giveaway,
-            )
+        sub = ManualSubscription.create(
+            user_id=user.id,
+            end_at=arrow.now().shift(years=coupon.nb_year, days=1),
+            comment="using coupon code",
+            is_giveaway=coupon.is_giveaway,
+            commit=True,
+        )
     emit_user_audit_log(
         user=user,
         action=UserAuditLogAction.Upgrade,
@@ -99,16 +87,6 @@ def redeem_coupon(coupon_code: str, user: User) -> Optional[Coupon]:
 def redeem_lifetime_coupon(coupon_code: str, user: User) -> Optional[Coupon]:
-    if user.lifetime:
-        return None
-    partner_sub = (
-        Session.query(PartnerSubscription)
-        .join(PartnerUser, PartnerUser.id == PartnerSubscription.partner_user_id)
-        .filter(PartnerUser.user_id == user.id, PartnerSubscription.lifetime == True)  # noqa: E712
-        .first()
-    )
-    if partner_sub is not None:
-        return None
     coupon: LifetimeCoupon = LifetimeCoupon.get_by(code=coupon_code)
     if not coupon:
         return None


@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import List, Optional
-from app.constants import JobType
+from app.config import JOB_DELETE_DOMAIN
 from app.db import Session
 from app.email_utils import get_email_domain_part
 from app.log import LOG
@@ -156,7 +156,7 @@ def delete_custom_domain(domain: CustomDomain):
     LOG.w("schedule delete domain job for %s", domain)
     domain.pending_deletion = True
     Job.create(
-        name=JobType.DELETE_DOMAIN.value,
+        name=JOB_DELETE_DOMAIN,
         payload={"custom_domain_id": domain.id},
         run_at=arrow.now(),
         commit=True,


@@ -5,7 +5,9 @@ from app import config
 from app.constants import DMARC_RECORD
 from app.db import Session
 from app.dns_utils import (
+    MxRecord,
     DNSClient,
+    is_mx_equivalent,
     get_network_dns_client,
 )
 from app.models import CustomDomain
@@ -19,39 +21,6 @@ class DomainValidationResult:
     errors: [str]
-@dataclass
-class ExpectedValidationRecords:
-    recommended: str
-    allowed: list[str]
-def is_mx_equivalent(
-    mx_domains: dict[int, list[str]],
-    expected_mx_domains: dict[int, ExpectedValidationRecords],
-) -> bool:
-    """
-    Compare mx_domains with ref_mx_domains to see if they are equivalent.
-    mx_domains and ref_mx_domains are list of (priority, domain)
-    The priority order is taken into account but not the priority number.
-    For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
-    """
-    expected_prios = []
-    for prio in expected_mx_domains:
-        expected_prios.append(prio)
-    if len(expected_prios) != len(mx_domains):
-        return False
-    for prio_position, prio_value in enumerate(sorted(mx_domains.keys())):
-        for domain in mx_domains[prio_value]:
-            if domain not in expected_mx_domains[expected_prios[prio_position]].allowed:
-                return False
-    return True
 class CustomDomainValidation:
     def __init__(
         self,
@@ -68,88 +37,59 @@ class CustomDomainValidation:
             or config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
         )
-    def get_ownership_verification_record(
-        self, domain: CustomDomain
-    ) -> ExpectedValidationRecords:
-        prefixes = ["sl"]
+    def get_ownership_verification_record(self, domain: CustomDomain) -> str:
+        prefix = "sl"
         if (
             domain.partner_id is not None
             and domain.partner_id in self._partner_domain_validation_prefixes
         ):
-            prefixes.insert(
-                0, self._partner_domain_validation_prefixes[domain.partner_id]
-            )
+            prefix = self._partner_domain_validation_prefixes[domain.partner_id]
         if not domain.ownership_txt_token:
             domain.ownership_txt_token = random_string(30)
             Session.commit()
-        valid = [
-            f"{prefix}-verification={domain.ownership_txt_token}" for prefix in prefixes
-        ]
-        return ExpectedValidationRecords(recommended=valid[0], allowed=valid)
-    def get_expected_mx_records(
-        self, domain: CustomDomain
-    ) -> dict[int, ExpectedValidationRecords]:
-        records = {}
-        if domain.partner_id is not None and domain.partner_id in self._partner_domains:
-            domain = self._partner_domains[domain.partner_id]
-            records[10] = [f"mx1.{domain}."]
-            records[20] = [f"mx2.{domain}."]
-        # Default ones
-        for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
-            if priority not in records:
-                records[priority] = []
-            records[priority].append(domain)
-        return {
-            priority: ExpectedValidationRecords(
-                recommended=records[priority][0], allowed=records[priority]
-            )
-            for priority in records
-        }
-    def get_expected_spf_domain(
-        self, domain: CustomDomain
-    ) -> ExpectedValidationRecords:
+        return f"{prefix}-verification={domain.ownership_txt_token}"
+    def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
         records = []
         if domain.partner_id is not None and domain.partner_id in self._partner_domains:
-            records.append(self._partner_domains[domain.partner_id])
+            domain = self._partner_domains[domain.partner_id]
+            records.append(MxRecord(10, f"mx1.{domain}."))
+            records.append(MxRecord(20, f"mx2.{domain}."))
         else:
-            records.append(config.EMAIL_DOMAIN)
-        return ExpectedValidationRecords(recommended=records[0], allowed=records)
+            # Default ones
+            for priority, domain in config.EMAIL_SERVERS_WITH_PRIORITY:
+                records.append(MxRecord(priority, domain))
+        return records
+    def get_expected_spf_domain(self, domain: CustomDomain) -> str:
+        if domain.partner_id is not None and domain.partner_id in self._partner_domains:
+            return self._partner_domains[domain.partner_id]
+        else:
+            return config.EMAIL_DOMAIN
     def get_expected_spf_record(self, domain: CustomDomain) -> str:
         spf_domain = self.get_expected_spf_domain(domain)
-        return f"v=spf1 include:{spf_domain.recommended} ~all"
-    def get_dkim_records(
-        self, domain: CustomDomain
-    ) -> {str: ExpectedValidationRecords}:
+        return f"v=spf1 include:{spf_domain} ~all"
+    def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
         """
         Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
         it will return the default ones or the partner ones.
         """
         # By default use the default domain
-        dkim_domains = [self.dkim_domain]
+        dkim_domain = self.dkim_domain
         if domain.partner_id is not None:
-            # Domain is from a partner. Retrieve the partner config and use that domain as preferred if it exists
-            partner_domain = self._partner_domains.get(domain.partner_id, None)
-            if partner_domain is not None:
-                dkim_domains.insert(0, partner_domain)
-        output = {}
-        for key in ("dkim", "dkim02", "dkim03"):
-            records = [
-                f"{key}._domainkey.{dkim_domain}" for dkim_domain in dkim_domains
-            ]
-            output[f"{key}._domainkey"] = ExpectedValidationRecords(
-                recommended=records[0], allowed=records
-            )
-        return output
+            # Domain is from a partner. Retrieve the partner config and use that domain if exists
+            dkim_domain = self._partner_domains.get(domain.partner_id, dkim_domain)
+        return {
+            f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
+            for key in ("dkim", "dkim02", "dkim03")
+        }
     def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
         """
@@ -162,7 +102,7 @@ class CustomDomainValidation:
         for prefix, expected_record in expected_records.items():
             custom_record = f"{prefix}.{custom_domain.domain}"
             dkim_record = self._dns_client.get_cname_record(custom_record)
-            if dkim_record in expected_record.allowed:
+            if dkim_record == expected_record:
                 correct_records[prefix] = custom_record
             else:
                 invalid_records[custom_record] = dkim_record or "empty"
@@ -198,15 +138,11 @@ class CustomDomainValidation:
         Check if the custom_domain has added the ownership verification records
         """
         txt_records = self._dns_client.get_txt_record(custom_domain.domain)
-        expected_verification_records = self.get_ownership_verification_record(
+        expected_verification_record = self.get_ownership_verification_record(
             custom_domain
         )
-        found = False
-        for verification_record in expected_verification_records.allowed:
-            if verification_record in txt_records:
-                found = True
-                break
-        if found:
+        if expected_verification_record in txt_records:
             custom_domain.ownership_verified = True
             emit_user_audit_log(
                 user=custom_domain.user,
@@ -225,11 +161,10 @@ class CustomDomainValidation:
         expected_mx_records = self.get_expected_mx_records(custom_domain)
         if not is_mx_equivalent(mx_domains, expected_mx_records):
-            errors = []
-            for prio in mx_domains:
-                for mx_domain in mx_domains[prio]:
-                    errors.append(f"{prio} {mx_domain}")
-            return DomainValidationResult(success=False, errors=errors)
+            return DomainValidationResult(
+                success=False,
+                errors=[f"{record.priority} {record.domain}" for record in mx_domains],
+            )
         else:
             custom_domain.verified = True
             emit_user_audit_log(
@@ -245,7 +180,7 @@ class CustomDomainValidation:
     ) -> DomainValidationResult:
         spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
         expected_spf_domain = self.get_expected_spf_domain(custom_domain)
-        if len(set(expected_spf_domain.allowed).intersection(set(spf_domains))) > 0:
+        if expected_spf_domain in spf_domains:
             custom_domain.spf_verified = True
             emit_user_audit_log(
                 user=custom_domain.user,
@@ -286,8 +221,8 @@ class CustomDomainValidation:
         self, txt_records: List[str], custom_domain: CustomDomain
     ) -> List[str]:
         final_records = []
-        verification_records = self.get_ownership_verification_record(custom_domain)
+        verification_record = self.get_ownership_verification_record(custom_domain)
         for record in txt_records:
-            if record not in verification_records.allowed:
+            if record != verification_record:
                 final_records.append(record)
         return final_records


@@ -39,7 +39,7 @@ from app.models import (
     SenderFormatEnum,
     UnsubscribeBehaviourEnum,
 )
-from app.proton.proton_unlink import perform_proton_account_unlink
+from app.proton.utils import perform_proton_account_unlink
 from app.utils import (
     random_string,
     CSRFValidationForm,


@@ -3,7 +3,7 @@ from flask import render_template, flash, request, redirect, url_for
 from flask_login import login_required, current_user
 from app import s3
-from app.constants import JobType
+from app.config import JOB_BATCH_IMPORT
 from app.dashboard.base import dashboard_bp
 from app.dashboard.views.enter_sudo import sudo_required
 from app.db import Session
@@ -64,7 +64,7 @@ def batch_import_route():
     # Schedule batch import job
     Job.create(
-        name=JobType.BATCH_IMPORT.value,
+        name=JOB_BATCH_IMPORT,
         payload={"batch_import_id": bi.id},
         run_at=arrow.now(),
     )


@@ -3,7 +3,7 @@ from flask import flash, redirect, url_for, request, render_template
 from flask_login import login_required, current_user
 from flask_wtf import FlaskForm
-from app.constants import JobType
+from app.config import JOB_DELETE_ACCOUNT
 from app.dashboard.base import dashboard_bp
 from app.dashboard.views.enter_sudo import sudo_required
 from app.log import LOG
@@ -40,7 +40,7 @@ def delete_account():
         message=f"User {current_user.id} ({current_user.email}) marked for deletion via webapp",
     )
     Job.create(
-        name=JobType.DELETE_ACCOUNT.value,
+        name=JOB_DELETE_ACCOUNT,
         payload={"user_id": current_user.id},
         run_at=arrow.now(),
         commit=True,


@@ -5,8 +5,8 @@ from flask_login import login_required, current_user
 from flask_wtf import FlaskForm
 from wtforms import StringField, validators, IntegerField
-from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
 from app.constants import DMARC_RECORD
+from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
 from app.custom_domain_utils import delete_custom_domain, set_custom_domain_mailboxes
 from app.custom_domain_validation import CustomDomainValidation
 from app.dashboard.base import dashboard_bp
@@ -137,7 +137,7 @@ def domain_detail_dns(custom_domain_id):
     return render_template(
         "dashboard/domain_detail/dns.html",
         EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
-        ownership_records=domain_validator.get_ownership_verification_record(
+        ownership_record=domain_validator.get_ownership_verification_record(
             custom_domain
         ),
         expected_mx_records=domain_validator.get_expected_mx_records(custom_domain),


@@ -11,7 +11,7 @@ from app.dashboard.base import dashboard_bp
 from app.extensions import limiter
 from app.log import LOG
 from app.models import PartnerUser, SocialAuth
-from app.proton.proton_partner import get_proton_partner
+from app.proton.utils import get_proton_partner
 from app.utils import sanitize_next_url
 _SUDO_GAP = 120


@@ -267,13 +267,12 @@ def cancel_mailbox_change_route(mailbox_id):
 @dashboard_bp.route("/mailbox/confirm_change")
-@login_required
-@limiter.limit("3/minute")
 def mailbox_confirm_email_change_route():
     mailbox_id = request.args.get("mailbox_id")
     code = request.args.get("code")
     if code:
+        print("HAS OCO", code)
         try:
             mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
             flash("Successfully changed mailbox email", "success")
@@ -281,6 +280,7 @@ def mailbox_confirm_email_change_route():
                 url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox.id)
             )
         except mailbox_utils.MailboxError as e:
+            print(e)
             flash(f"Cannot verify mailbox: {e.msg}", "error")
             return redirect(url_for("dashboard.mailbox_route"))
     else:


@@ -22,7 +22,7 @@ from app.models import (
     PartnerUser,
     PartnerSubscription,
 )
-from app.proton.proton_partner import get_proton_partner
+from app.proton.utils import get_proton_partner
 @dashboard_bp.route("/pricing", methods=["GET", "POST"])


@@ -41,8 +41,7 @@ from app.models import (
     PartnerSubscription,
     UnsubscribeBehaviourEnum,
 )
-from app.proton.proton_partner import get_proton_partner
-from app.proton.proton_unlink import can_unlink_proton_account
+from app.proton.utils import get_proton_partner, can_unlink_proton_account
 from app.utils import (
     random_string,
     CSRFValidationForm,


@@ -1,5 +1,4 @@
 """List of clients"""
 from flask import render_template
 from flask_login import current_user, login_required


@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import List, Optional
 import dns.resolver
@@ -8,13 +9,42 @@ from app.config import NAMESERVERS
 _include_spf = "include:"
+@dataclass
+class MxRecord:
+    priority: int
+    domain: str
+def is_mx_equivalent(
+    mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]
+) -> bool:
+    """
+    Compare mx_domains with ref_mx_domains to see if they are equivalent.
+    mx_domains and ref_mx_domains are list of (priority, domain)
+    The priority order is taken into account but not the priority number.
+    For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
+    """
+    mx_domains = sorted(mx_domains, key=lambda x: x.priority)
+    ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)
+    if len(mx_domains) < len(ref_mx_domains):
+        return False
+    for actual, expected in zip(mx_domains, ref_mx_domains):
+        if actual.domain != expected.domain:
+            return False
+    return True
 class DNSClient(ABC):
     @abstractmethod
     def get_cname_record(self, hostname: str) -> Optional[str]:
         pass
     @abstractmethod
-    def get_mx_domains(self, hostname: str) -> dict[int, list[str]]:
+    def get_mx_domains(self, hostname: str) -> List[MxRecord]:
         pass
     def get_spf_domain(self, hostname: str) -> List[str]:
@@ -58,24 +88,21 @@ class NetworkDNSClient(DNSClient):
         except Exception:
             return None
-    def get_mx_domains(self, hostname: str) -> dict[int, list[str]]:
+    def get_mx_domains(self, hostname: str) -> List[MxRecord]:
         """
         return list of (priority, domain name) sorted by priority (lowest priority first)
         domain name ends with a "." at the end.
         """
-        ret = {}
         try:
             answers = self._resolver.resolve(hostname, "MX", search=True)
+            ret = []
             for a in answers:
                 record = a.to_text()  # for ex '20 alt2.aspmx.l.google.com.'
                 parts = record.split(" ")
-                prio = int(parts[0])
-                if prio not in ret:
-                    ret[prio] = []
-                ret[prio].append(parts[1])
+                ret.append(MxRecord(priority=int(parts[0]), domain=parts[1]))
+            return sorted(ret, key=lambda x: x.priority)
         except Exception:
-            pass
-        return ret
+            return []
     def get_txt_record(self, hostname: str) -> List[str]:
         try:
@@ -92,14 +119,14 @@ class NetworkDNSClient(DNSClient):
 class InMemoryDNSClient(DNSClient):
     def __init__(self):
         self.cname_records: dict[str, Optional[str]] = {}
-        self.mx_records: dict[int, dict[int, list[str]]] = {}
+        self.mx_records: dict[str, List[MxRecord]] = {}
         self.spf_records: dict[str, List[str]] = {}
         self.txt_records: dict[str, List[str]] = {}
     def set_cname_record(self, hostname: str, cname: str):
         self.cname_records[hostname] = cname
-    def set_mx_records(self, hostname: str, mx_list: dict[int, list[str]]):
+    def set_mx_records(self, hostname: str, mx_list: List[MxRecord]):
         self.mx_records[hostname] = mx_list
     def set_txt_record(self, hostname: str, txt_list: List[str]):
@@ -108,27 +135,17 @@ class InMemoryDNSClient(DNSClient):
     def get_cname_record(self, hostname: str) -> Optional[str]:
         return self.cname_records.get(hostname)
-    def get_mx_domains(self, hostname: str) -> dict[int, list[str]]:
-        return self.mx_records.get(hostname, {})
+    def get_mx_domains(self, hostname: str) -> List[MxRecord]:
+        mx_list = self.mx_records.get(hostname, [])
+        return sorted(mx_list, key=lambda x: x.priority)
     def get_txt_record(self, hostname: str) -> List[str]:
         return self.txt_records.get(hostname, [])
-global_dns_client: Optional[DNSClient] = None
-def get_network_dns_client() -> DNSClient:
-    global global_dns_client
-    if global_dns_client is not None:
-        return global_dns_client
+def get_network_dns_client() -> NetworkDNSClient:
     return NetworkDNSClient(NAMESERVERS)
-def set_global_dns_client(dns_client: Optional[DNSClient]):
-    global global_dns_client
-    global_dns_client = dns_client
-def get_mx_domains(hostname: str) -> dict[int, list[str]]:
+def get_mx_domains(hostname: str) -> List[MxRecord]:
     return get_network_dns_client().get_mx_domains(hostname)
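The MxRecord dataclass and is_mx_equivalent helper added on one side of this file are self-contained, so their behaviour can be checked in isolation. A short sketch (the domain names are illustrative) showing that only the priority order matters, not the priority numbers:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MxRecord:
    priority: int
    domain: str


def is_mx_equivalent(mx_domains: List[MxRecord], ref_mx_domains: List[MxRecord]) -> bool:
    # Sort both sides by priority, then compare domains position by position
    mx_domains = sorted(mx_domains, key=lambda x: x.priority)
    ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x.priority)
    if len(mx_domains) < len(ref_mx_domains):
        return False
    for actual, expected in zip(mx_domains, ref_mx_domains):
        if actual.domain != expected.domain:
            return False
    return True


published = [MxRecord(20, "mx2.example.com."), MxRecord(10, "mx1.example.com.")]
expected = [MxRecord(1, "mx1.example.com."), MxRecord(2, "mx2.example.com.")]
assert is_mx_equivalent(published, expected)
```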


@@ -1,5 +1,4 @@
 """Email headers"""
 MESSAGE_ID = "Message-ID"
 IN_REPLY_TO = "In-Reply-To"
 REFERENCES = "References"


@@ -657,11 +657,7 @@ def get_mx_domain_list(domain) -> [str]:
     """
     priority_domains = get_mx_domains(domain)
-    mx_domains = []
-    for prio in priority_domains:
-        for domain in priority_domains[prio]:
-            mx_domains.append(domain[:-1])
-    return mx_domains
+    return [d.domain[:-1] for d in priority_domains]
 def personal_email_already_used(email_address: str) -> bool:
@@ -1355,9 +1351,7 @@ def get_queue_id(msg: Message) -> Optional[str]:
     search_result = re.search(r"with E?SMTP[AS]? id ([0-9a-zA-Z]{1,})", received_header)
     if search_result:
         return search_result.group(1)
-    search_result = re.search(
-        r"\(Postfix\)\r\n\tid ([a-zA-Z0-9]{1,});", received_header
-    )
+    search_result = re.search("\(Postfix\)\r\n\tid ([a-zA-Z0-9]{1,});", received_header)
     if search_result:
         return search_result.group(1)
     return None


@@ -8,7 +8,7 @@ from app.errors import ProtonPartnerNotSetUp
 from app.events.generated import event_pb2
 from app.log import LOG
 from app.models import User, PartnerUser, SyncEvent
-from app.proton.proton_partner import get_proton_partner
+from app.proton.utils import get_proton_partner
 from typing import Optional
 NOTIFICATION_CHANNEL = "simplelogin_sync_events"


@@ -24,7 +24,7 @@ _sym_db = _symbol_database.Default()
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\":\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\x12\x10\n\x08lifetime\x18\x02 \x01(\x08\"\r\n\x0bUserDeleted\"\\\n\x0c\x41liasCreated\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0c\n\x04note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\x12\x12\n\ncreated_at\x18\x05 \x01(\r\"T\n\x12\x41liasStatusChanged\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\x12\x12\n\ncreated_at\x18\x04 \x01(\r\")\n\x0c\x41liasDeleted\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x0e\n\x0cUserUnlinked\"\xce\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x12\x39\n\ruser_unlinked\x18\x07 \x01(\x0b\x32 .simplelogin_events.UserUnlinkedH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\":\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\x12\x10\n\x08lifetime\x18\x02 \x01(\x08\"\r\n\x0bUserDeleted\"\\\n\x0c\x41liasCreated\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0c\n\x04note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\x12\x12\n\ncreated_at\x18\x05 \x01(\r\"T\n\x12\x41liasStatusChanged\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\x12\x12\n\ncreated_at\x18\x04 \x01(\r\")\n\x0c\x41liasDeleted\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -43,10 +43,8 @@ if not _descriptor._USE_C_DESCRIPTORS:
   _globals['_ALIASDELETED']._serialized_end=331
   _globals['_ALIASCREATEDLIST']._serialized_start=333
   _globals['_ALIASCREATEDLIST']._serialized_end=401
-  _globals['_USERUNLINKED']._serialized_start=403
-  _globals['_USERUNLINKED']._serialized_end=417
-  _globals['_EVENTCONTENT']._serialized_start=420
-  _globals['_EVENTCONTENT']._serialized_end=882
-  _globals['_EVENT']._serialized_start=884
-  _globals['_EVENT']._serialized_end=1005
+  _globals['_EVENTCONTENT']._serialized_start=404
+  _globals['_EVENTCONTENT']._serialized_end=807
+  _globals['_EVENT']._serialized_start=809
+  _globals['_EVENT']._serialized_end=930
 # @@protoc_insertion_point(module_scope)


@@ -57,27 +57,21 @@ class AliasCreatedList(_message.Message):
     events: _containers.RepeatedCompositeFieldContainer[AliasCreated]
     def __init__(self, events: _Optional[_Iterable[_Union[AliasCreated, _Mapping]]] = ...) -> None: ...
-class UserUnlinked(_message.Message):
-    __slots__ = ()
-    def __init__(self) -> None: ...
 class EventContent(_message.Message):
-    __slots__ = ("user_plan_change", "user_deleted", "alias_created", "alias_status_change", "alias_deleted", "alias_create_list", "user_unlinked")
+    __slots__ = ("user_plan_change", "user_deleted", "alias_created", "alias_status_change", "alias_deleted", "alias_create_list")
     USER_PLAN_CHANGE_FIELD_NUMBER: _ClassVar[int]
     USER_DELETED_FIELD_NUMBER: _ClassVar[int]
     ALIAS_CREATED_FIELD_NUMBER: _ClassVar[int]
     ALIAS_STATUS_CHANGE_FIELD_NUMBER: _ClassVar[int]
     ALIAS_DELETED_FIELD_NUMBER: _ClassVar[int]
     ALIAS_CREATE_LIST_FIELD_NUMBER: _ClassVar[int]
-    USER_UNLINKED_FIELD_NUMBER: _ClassVar[int]
     user_plan_change: UserPlanChanged
     user_deleted: UserDeleted
     alias_created: AliasCreated
     alias_status_change: AliasStatusChanged
    alias_deleted: AliasDeleted
     alias_create_list: AliasCreatedList
-    user_unlinked: UserUnlinked
-    def __init__(self, user_plan_change: _Optional[_Union[UserPlanChanged, _Mapping]] = ..., user_deleted: _Optional[_Union[UserDeleted, _Mapping]] = ..., alias_created: _Optional[_Union[AliasCreated, _Mapping]] = ..., alias_status_change: _Optional[_Union[AliasStatusChanged, _Mapping]] = ..., alias_deleted: _Optional[_Union[AliasDeleted, _Mapping]] = ..., alias_create_list: _Optional[_Union[AliasCreatedList, _Mapping]] = ..., user_unlinked: _Optional[_Union[UserUnlinked, _Mapping]] = ...) -> None: ...
+    def __init__(self, user_plan_change: _Optional[_Union[UserPlanChanged, _Mapping]] = ..., user_deleted: _Optional[_Union[UserDeleted, _Mapping]] = ..., alias_created: _Optional[_Union[AliasCreated, _Mapping]] = ..., alias_status_change: _Optional[_Union[AliasStatusChanged, _Mapping]] = ..., alias_deleted: _Optional[_Union[AliasDeleted, _Mapping]] = ..., alias_create_list: _Optional[_Union[AliasCreatedList, _Mapping]] = ...) -> None: ...
 class Event(_message.Message):
     __slots__ = ("user_id", "external_user_id", "partner_id", "content")


@@ -37,7 +37,7 @@ from app.models import (
     PartnerSubscription,
 )
 from app.pgp_utils import load_public_key
-from app.proton.proton_partner import get_proton_partner
+from app.proton.utils import get_proton_partner
 def fake_data():
@@ -90,7 +90,7 @@ def fake_data():
         user_id=user.id,
         alias_id=alias.id,
         website_email="hey@google.com",
-        reply_email="rep@sl.lan",
+        reply_email="rep@sl.local",
         commit=True,
     )
     EmailLog.create(
@@ -166,7 +166,7 @@ def fake_data():
     #         user_id=user.id,
     #         alias_id=a.id,
     #         website_email=f"contact{i}@example.com",
-    #         reply_email=f"rep{i}@sl.lan",
+    #         reply_email=f"rep{i}@sl.local",
     #     )
     #     Session.commit()
     # for _ in range(3):


@@ -12,7 +12,6 @@ import arrow
 import sqlalchemy
 from app import config
-from app.constants import JobType
 from app.db import Session
 from app.email import headers
 from app.email_utils import (
@@ -175,7 +174,7 @@ class ExportUserDataJob:
         jobs_in_db = (
             Session.query(Job)
             .filter(
-                Job.name == JobType.SEND_USER_REPORT.value,
+                Job.name == config.JOB_SEND_USER_REPORT,
                 Job.payload.op("->")("user_id").cast(sqlalchemy.TEXT)
                 == str(self._user.id),
                 Job.taken.is_(False),
@@ -185,7 +184,7 @@ class ExportUserDataJob:
         if jobs_in_db > 0:
             return None
         return Job.create(
-            name=JobType.SEND_USER_REPORT.value,
+            name=config.JOB_SEND_USER_REPORT,
             payload={"user_id": self._user.id},
             run_at=arrow.now(),
             commit=True,


@@ -5,7 +5,7 @@ from typing import Optional
 import arrow
-from app.constants import JobType
+from app import config
 from app.errors import ProtonPartnerNotSetUp
 from app.events.generated import event_pb2
 from app.events.generated.event_pb2 import EventContent
@@ -14,7 +14,7 @@ from app.models import (
     Job,
     PartnerUser,
 )
-from app.proton.proton_partner import get_proton_partner
+from app.proton.utils import get_proton_partner
 from events.event_sink import EventSink
@@ -57,16 +57,14 @@ class SendEventToWebhookJob:
         return SendEventToWebhookJob(user=user, event=event)
-    def store_job_in_db(
-        self, run_at: Optional[arrow.Arrow], commit: bool = True
-    ) -> Job:
+    def store_job_in_db(self, run_at: Optional[arrow.Arrow]) -> Job:
         stub = self._event.SerializeToString()
         return Job.create(
-            name=JobType.SEND_EVENT_TO_WEBHOOK.value,
+            name=config.JOB_SEND_EVENT_TO_WEBHOOK,
             payload={
                 "user_id": self._user.id,
                 "event": base64.b64encode(stub).decode("utf-8"),
             },
             run_at=run_at if run_at is not None else arrow.now(),
-            commit=commit,
+            commit=True,
         )


@@ -10,7 +10,7 @@ from app.config import (
 # this format allows clickable link to code source in PyCharm
 _log_format = (
-    "%(asctime)s - %(name)s - %(levelname)s - %(process)d - %(request_id)s"
+    "%(asctime)s - %(name)s - %(levelname)s - %(process)d - "
     '"%(pathname)s:%(lineno)d" - %(funcName)s() - %(message_id)s - %(message)s'
 )
 _log_formatter = logging.Formatter(_log_format)
@@ -37,21 +37,6 @@ class EmailHandlerFilter(logging.Filter):
         return _MESSAGE_ID
-class RequestIdFilter(logging.Filter):
-    """automatically add request-id to keep track of a request"""
-    def filter(self, record):
-        from flask import g, has_request_context
-        request_id = ""
-        if has_request_context() and hasattr(g, "request_id"):
-            ctx_request_id = getattr(g, "request_id")
-            if ctx_request_id:
-                request_id = f"{ctx_request_id} - "
-        record.request_id = request_id
-        return True
 def _get_console_handler():
     console_handler = logging.StreamHandler(sys.stdout)
     console_handler.setFormatter(_log_formatter)
@@ -69,7 +54,6 @@ def _get_logger(name) -> logging.Logger:
     logger.addHandler(_get_console_handler())
     logger.addFilter(EmailHandlerFilter())
-    logger.addFilter(RequestIdFilter())
     # no propagation to avoid propagating to root logger
     logger.propagate = False


@ -2,12 +2,10 @@ import dataclasses
import secrets import secrets
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
import arrow import arrow
from sqlalchemy.exc import IntegrityError
from app import config from app import config
from app.constants import JobType from app.config import JOB_DELETE_MAILBOX
from app.db import Session from app.db import Session
from app.email_utils import ( from app.email_utils import (
mailbox_already_used, mailbox_already_used,
@ -156,7 +154,7 @@ def delete_mailbox(
f"User {user} has scheduled delete mailbox job for {mailbox.id} with transfer to mailbox {transfer_mailbox_id}" f"User {user} has scheduled delete mailbox job for {mailbox.id} with transfer to mailbox {transfer_mailbox_id}"
) )
Job.create( Job.create(
name=JobType.DELETE_MAILBOX.value, name=JOB_DELETE_MAILBOX,
payload={ payload={
"mailbox_id": mailbox.id, "mailbox_id": mailbox.id,
"transfer_mailbox_id": transfer_mailbox_id "transfer_mailbox_id": transfer_mailbox_id
@ -247,7 +245,7 @@ def verify_mailbox_code(user: User, mailbox_id: int, code: str) -> Mailbox:
message=f"Verify mailbox {mailbox_id} ({mailbox.email})", message=f"Verify mailbox {mailbox_id} ({mailbox.email})",
) )
if Mailbox.get_by(email=mailbox.new_email, user_id=user.id): if Mailbox.get_by(email=mailbox.new_email, user_id=user.id):
raise MailboxError("That address is already in use") raise MailboxError("That addres is already in use")
else: else:
LOG.i( LOG.i(
@ -353,8 +351,6 @@ def request_mailbox_email_change(
check_email_for_mailbox(new_email, user) check_email_for_mailbox(new_email, user)
if email_ownership_verified: if email_ownership_verified:
mailbox.email = new_email mailbox.email = new_email
mailbox.new_email = None
mailbox.verified = True
else: else:
mailbox.new_email = new_email mailbox.new_email = new_email
emit_user_audit_log( emit_user_audit_log(
@ -362,12 +358,7 @@ def request_mailbox_email_change(
action=UserAuditLogAction.UpdateMailbox, action=UserAuditLogAction.UpdateMailbox,
message=f"Updated mailbox {mailbox.id} email ({new_email}) pre-verified({email_ownership_verified}", message=f"Updated mailbox {mailbox.id} email ({new_email}) pre-verified({email_ownership_verified}",
) )
try: Session.commit()
Session.commit()
except IntegrityError:
LOG.i(f"This email {new_email} is already pending for some mailbox")
Session.rollback()
raise MailboxError("Email already in use")
if email_ownership_verified: if email_ownership_verified:
LOG.i(f"User {user} as created a pre-verified mailbox with {new_email}") LOG.i(f"User {user} as created a pre-verified mailbox with {new_email}")

View File

@ -30,9 +30,9 @@ from sqlalchemy_utils import ArrowType
from app import config, rate_limiter from app import config, rate_limiter
from app import s3 from app import s3
from app.constants import JobType
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains from app.dns_utils import get_mx_domains
from app.errors import ( from app.errors import (
AliasInTrashError, AliasInTrashError,
DirectoryInTrashError, DirectoryInTrashError,
@ -239,7 +239,6 @@ class AuditLogActionEnum(EnumE):
disable_user = 9 disable_user = 9
enable_user = 10 enable_user = 10
stop_trial = 11 stop_trial = 11
unlink_user = 12
class Phase(EnumE): class Phase(EnumE):
@ -276,12 +275,6 @@ class AliasDeleteReason(EnumE):
CustomDomainDeleted = 5 CustomDomainDeleted = 5
class JobPriority(EnumE):
Low = 1
Default = 50
High = 100
class IntEnumType(sa.types.TypeDecorator): class IntEnumType(sa.types.TypeDecorator):
impl = sa.Integer impl = sa.Integer
@ -369,7 +362,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
sa.Boolean, default=True, nullable=False, server_default="1" sa.Boolean, default=True, nullable=False, server_default="1"
) )
activated = sa.Column(sa.Boolean, default=False, nullable=False) activated = sa.Column(sa.Boolean, default=False, nullable=False, index=True)
# an account can be disabled if having harmful behavior # an account can be disabled if having harmful behavior
disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0")
@ -583,12 +576,6 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
"ix_users_default_alias_custom_domain_id", default_alias_custom_domain_id "ix_users_default_alias_custom_domain_id", default_alias_custom_domain_id
), ),
sa.Index("ix_users_profile_picture_id", profile_picture_id), sa.Index("ix_users_profile_picture_id", profile_picture_id),
sa.Index(
"idx_users_email_trgm",
"email",
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
),
) )
@property @property
@ -657,7 +644,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
user.notification = False user.notification = False
user.trial_end = None user.trial_end = None
Job.create( Job.create(
name=JobType.SEND_PROTON_WELCOME_1.value, name=config.JOB_SEND_PROTON_WELCOME_1,
payload={"user_id": user.id}, payload={"user_id": user.id},
run_at=arrow.now(), run_at=arrow.now(),
) )
@ -683,17 +670,17 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
# Schedule onboarding emails # Schedule onboarding emails
Job.create( Job.create(
name=JobType.ONBOARDING_1.value, name=config.JOB_ONBOARDING_1,
payload={"user_id": user.id}, payload={"user_id": user.id},
run_at=arrow.now().shift(days=1), run_at=arrow.now().shift(days=1),
) )
Job.create( Job.create(
name=JobType.ONBOARDING_2.value, name=config.JOB_ONBOARDING_2,
payload={"user_id": user.id}, payload={"user_id": user.id},
run_at=arrow.now().shift(days=2), run_at=arrow.now().shift(days=2),
) )
Job.create( Job.create(
name=JobType.ONBOARDING_4.value, name=config.JOB_ONBOARDING_4,
payload={"user_id": user.id}, payload={"user_id": user.id},
run_at=arrow.now().shift(days=3), run_at=arrow.now().shift(days=3),
) )
@ -1937,16 +1924,13 @@ class Contact(Base, ModelMixin):
__table_args__ = ( __table_args__ = (
sa.UniqueConstraint("alias_id", "website_email", name="uq_contact"), sa.UniqueConstraint("alias_id", "website_email", name="uq_contact"),
sa.Index("ix_contact_user_id_id", "user_id", "id"),
) )
user_id = sa.Column( user_id = sa.Column(
sa.ForeignKey(User.id, ondelete="cascade"), sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
nullable=False,
) )
alias_id = sa.Column( alias_id = sa.Column(
sa.ForeignKey(Alias.id, ondelete="cascade"), sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True
nullable=False,
) )
name = sa.Column( name = sa.Column(
@ -2131,10 +2115,11 @@ class EmailLog(Base, ModelMixin):
Index("ix_email_log_mailbox_id", "mailbox_id"), Index("ix_email_log_mailbox_id", "mailbox_id"),
Index("ix_email_log_bounced_mailbox_id", "bounced_mailbox_id"), Index("ix_email_log_bounced_mailbox_id", "bounced_mailbox_id"),
Index("ix_email_log_refused_email_id", "refused_email_id"), Index("ix_email_log_refused_email_id", "refused_email_id"),
Index("ix_email_log_user_id_email_log_id", "user_id", "id"),
) )
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) user_id = sa.Column(
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
)
contact_id = sa.Column( contact_id = sa.Column(
sa.ForeignKey(Contact.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(Contact.id, ondelete="cascade"), nullable=False, index=True
) )
@ -2410,8 +2395,7 @@ class AliasUsedOn(Base, ModelMixin):
) )
alias_id = sa.Column( alias_id = sa.Column(
sa.ForeignKey(Alias.id, ondelete="cascade"), sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True
nullable=False,
) )
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False)
@ -2434,7 +2418,10 @@ class ApiKey(Base, ModelMixin):
user = orm.relationship(User) user = orm.relationship(User)
__table_args__ = (sa.Index("ix_api_key_user_id", "user_id"),) __table_args__ = (
sa.Index("ix_api_key_code", "code"),
sa.Index("ix_api_key_user_id", "user_id"),
)
@classmethod @classmethod
def create(cls, user_id, name=None, **kwargs): def create(cls, user_id, name=None, **kwargs):
@ -2594,6 +2581,7 @@ class AutoCreateRule(Base, ModelMixin):
sa.UniqueConstraint( sa.UniqueConstraint(
"custom_domain_id", "order", name="uq_auto_create_rule_order" "custom_domain_id", "order", name="uq_auto_create_rule_order"
), ),
sa.Index("ix_auto_create_rule_custom_domain_id", "custom_domain_id"),
) )
custom_domain_id = sa.Column( custom_domain_id = sa.Column(
@ -2776,19 +2764,12 @@ class Job(Base, ModelMixin):
nullable=False, nullable=False,
server_default=str(JobState.ready.value), server_default=str(JobState.ready.value),
default=JobState.ready.value, default=JobState.ready.value,
index=True,
) )
attempts = sa.Column(sa.Integer, nullable=False, server_default="0", default=0) attempts = sa.Column(sa.Integer, nullable=False, server_default="0", default=0)
taken_at = sa.Column(ArrowType, nullable=True) taken_at = sa.Column(ArrowType, nullable=True)
priority = sa.Column(
IntEnumType(JobPriority),
default=JobPriority.Default,
server_default=str(JobPriority.Default.value),
nullable=False,
)
__table_args__ = ( __table_args__ = (Index("ix_state_run_at_taken_at", state, run_at, taken_at),)
Index("ix_state_run_at_taken_at_priority", state, run_at, taken_at, priority),
)
def __repr__(self): def __repr__(self):
return f"<Job {self.id} {self.name} {self.payload}>" return f"<Job {self.id} {self.name} {self.payload}>"
@ -2796,7 +2777,9 @@ class Job(Base, ModelMixin):
class Mailbox(Base, ModelMixin): class Mailbox(Base, ModelMixin):
__tablename__ = "mailbox" __tablename__ = "mailbox"
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) user_id = sa.Column(
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
)
email = sa.Column(sa.String(256), nullable=False, index=True) email = sa.Column(sa.String(256), nullable=False, index=True)
verified = sa.Column(sa.Boolean, default=False, nullable=False) verified = sa.Column(sa.Boolean, default=False, nullable=False)
force_spf = sa.Column(sa.Boolean, default=True, server_default="1", nullable=False) force_spf = sa.Column(sa.Boolean, default=True, server_default="1", nullable=False)
@ -2825,13 +2808,6 @@ class Mailbox(Base, ModelMixin):
__table_args__ = ( __table_args__ = (
sa.UniqueConstraint("user_id", "email", name="uq_mailbox_user"), sa.UniqueConstraint("user_id", "email", name="uq_mailbox_user"),
sa.Index("ix_mailbox_pgp_finger_print", "pgp_finger_print"), sa.Index("ix_mailbox_pgp_finger_print", "pgp_finger_print"),
# index on email column using pg_trgm
Index(
"ix_mailbox_email_trgm_idx",
"email",
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
),
) )
user = orm.relationship(User, foreign_keys=[user_id]) user = orm.relationship(User, foreign_keys=[user_id])
@ -2854,20 +2830,24 @@ class Mailbox(Base, ModelMixin):
return len(alias_ids) return len(alias_ids)
def is_proton(self) -> bool: def is_proton(self) -> bool:
for proton_email_domain in config.PROTON_EMAIL_DOMAINS: if (
if self.email.endswith(f"@{proton_email_domain}"): self.email.endswith("@proton.me")
return True or self.email.endswith("@protonmail.com")
or self.email.endswith("@protonmail.ch")
or self.email.endswith("@proton.ch")
or self.email.endswith("@pm.me")
):
return True
from app.email_utils import get_email_local_part from app.email_utils import get_email_local_part
mx_domains = get_mx_domains(get_email_local_part(self.email)) mx_domains = get_mx_domains(get_email_local_part(self.email))
proton_mx_domains = config.PROTON_MX_SERVERS
# Proton is the first domain # Proton is the first domain
for prio in mx_domains: if mx_domains and mx_domains[0].domain in (
for mx_domain in mx_domains[prio]: "mail.protonmail.ch.",
if mx_domain in proton_mx_domains: "mailsec.protonmail.ch.",
return True ):
return True
return False return False
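
A minimal sketch (not the project's code) of the MX check performed by the main side of the hunk above: it assumes, as the loop implies, that `get_mx_domains()` returns a mapping of MX priority to a list of server hostnames, and it assumes `config.PROTON_MX_SERVERS` defaults to the two hosts the 4.64.4 version hard-coded.

```python
# Sketch only, not the project's code: mirrors the priority-keyed MX walk in the
# main-side Mailbox.is_proton above. PROTON_MX_SERVERS is an assumed default,
# copied from the hosts the 4.64.4 version hard-coded.
PROTON_MX_SERVERS = ["mail.protonmail.ch.", "mailsec.protonmail.ch."]


def uses_proton_mx(mx_domains: dict[int, list[str]]) -> bool:
    # mx_domains maps MX priority -> server hostnames for that priority
    for prio in mx_domains:
        for mx_domain in mx_domains[prio]:
            if mx_domain in PROTON_MX_SERVERS:
                return True
    return False


if __name__ == "__main__":
    print(uses_proton_mx({10: ["mail.protonmail.ch."]}))                         # True
    print(uses_proton_mx({10: ["mx1.example.com."], 20: ["mx2.example.com."]}))  # False
```
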
@ -3030,11 +3010,7 @@ class SentAlert(Base, ModelMixin):
to_email = sa.Column(sa.String(256), nullable=False) to_email = sa.Column(sa.String(256), nullable=False)
alert_type = sa.Column(sa.String(256), nullable=False) alert_type = sa.Column(sa.String(256), nullable=False)
__table_args__ = ( __table_args__ = (sa.Index("ix_sent_alert_user_id", "user_id"),)
sa.Index("ix_sent_alert_user_id", "user_id"),
sa.Index("ix_sent_alert_to_email", "to_email"),
sa.Index("ix_sent_alert_alert_type", "alert_type"),
)
class AliasMailbox(Base, ModelMixin): class AliasMailbox(Base, ModelMixin):
@ -3044,8 +3020,7 @@ class AliasMailbox(Base, ModelMixin):
) )
alias_id = sa.Column( alias_id = sa.Column(
sa.ForeignKey(Alias.id, ondelete="cascade"), sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True
nullable=False,
) )
mailbox_id = sa.Column( mailbox_id = sa.Column(
sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False, index=True
@ -3060,8 +3035,7 @@ class AliasHibp(Base, ModelMixin):
__table_args__ = (sa.UniqueConstraint("alias_id", "hibp_id", name="uq_alias_hibp"),) __table_args__ = (sa.UniqueConstraint("alias_id", "hibp_id", name="uq_alias_hibp"),)
alias_id = sa.Column( alias_id = sa.Column(
sa.Integer(), sa.Integer(), sa.ForeignKey("alias.id", ondelete="cascade"), index=True
sa.ForeignKey("alias.id", ondelete="cascade"),
) )
hibp_id = sa.Column( hibp_id = sa.Column(
sa.Integer(), sa.ForeignKey("hibp.id", ondelete="cascade"), index=True sa.Integer(), sa.ForeignKey("hibp.id", ondelete="cascade"), index=True
@ -3777,8 +3751,7 @@ class PartnerUser(Base, ModelMixin):
index=True, index=True,
) )
partner_id = sa.Column( partner_id = sa.Column(
sa.ForeignKey("partner.id", ondelete="cascade"), sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True
nullable=False,
) )
external_user_id = sa.Column(sa.String(128), unique=False, nullable=False) external_user_id = sa.Column(sa.String(128), unique=False, nullable=False)
partner_email = sa.Column(sa.String(255), unique=False, nullable=True) partner_email = sa.Column(sa.String(255), unique=False, nullable=True)

View File

@ -1,4 +1,4 @@
from app.build_info import SHA1, VERSION from app.build_info import SHA1
from app.monitor.base import monitor_bp from app.monitor.base import monitor_bp
@ -7,11 +7,6 @@ def git_sha1():
return SHA1 return SHA1
@monitor_bp.route("/version")
def version():
return VERSION
@monitor_bp.route("/live") @monitor_bp.route("/live")
def live(): def live():
return "live" return "live"

View File

@ -1,8 +0,0 @@
from app.build_info import VERSION
import newrelic.agent
def send_version_event(service: str):
newrelic.agent.record_custom_event(
"ServiceVersion", {"service": service, "version": VERSION}
)

View File

@ -3,7 +3,7 @@ from typing import Optional
import arrow import arrow
from arrow import Arrow from arrow import Arrow
from app.constants import JobType from app import config
from app.models import PartnerUser, PartnerSubscription, User, Job from app.models import PartnerUser, PartnerSubscription, User, Job
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
@ -18,7 +18,7 @@ def create_partner_user(
external_user_id=external_user_id, external_user_id=external_user_id,
) )
Job.create( Job.create(
name=JobType.SEND_ALIAS_CREATION_EVENTS.value, name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
payload={"user_id": user.id}, payload={"user_id": user.id},
run_at=arrow.now(), run_at=arrow.now(),
) )

View File

@ -1,23 +0,0 @@
from typing import Optional
from app.db import Session
from app.errors import ProtonPartnerNotSetUp
from app.models import Partner
PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None
def get_proton_partner() -> Partner:
global _PROTON_PARTNER
if _PROTON_PARTNER is None:
partner = Partner.get_by(name=PROTON_PARTNER_NAME)
if partner is None:
raise ProtonPartnerNotSetUp
Session.expunge(partner)
_PROTON_PARTNER = partner
return _PROTON_PARTNER
def is_proton_partner(partner: Partner) -> bool:
return partner.name == PROTON_PARTNER_NAME

View File

@ -1,23 +1,39 @@
from typing import Optional
from newrelic import agent from newrelic import agent
from app.db import Session from app.db import Session
from app.events.event_dispatcher import EventDispatcher from app.errors import ProtonPartnerNotSetUp
from app.events.generated.event_pb2 import EventContent, UserUnlinked
from app.log import LOG from app.log import LOG
from app.models import User, PartnerUser from app.models import Partner, PartnerUser, User
from app.proton.proton_partner import get_proton_partner
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None
def get_proton_partner() -> Partner:
global _PROTON_PARTNER
if _PROTON_PARTNER is None:
partner = Partner.get_by(name=PROTON_PARTNER_NAME)
if partner is None:
raise ProtonPartnerNotSetUp
Session.expunge(partner)
_PROTON_PARTNER = partner
return _PROTON_PARTNER
def is_proton_partner(partner: Partner) -> bool:
return partner.name == PROTON_PARTNER_NAME
def can_unlink_proton_account(user: User) -> bool: def can_unlink_proton_account(user: User) -> bool:
return (user.flags & User.FLAG_CREATED_FROM_PARTNER) == 0 return (user.flags & User.FLAG_CREATED_FROM_PARTNER) == 0
def perform_proton_account_unlink( def perform_proton_account_unlink(current_user: User) -> bool:
current_user: User, skip_check: bool = False if not can_unlink_proton_account(current_user):
) -> None | str: return False
if not skip_check and not can_unlink_proton_account(current_user):
return None
proton_partner = get_proton_partner() proton_partner = get_proton_partner()
partner_user = PartnerUser.get_by( partner_user = PartnerUser.get_by(
user_id=current_user.id, partner_id=proton_partner.id user_id=current_user.id, partner_id=proton_partner.id
@ -29,11 +45,7 @@ def perform_proton_account_unlink(
action=UserAuditLogAction.UnlinkAccount, action=UserAuditLogAction.UnlinkAccount,
message=f"User has unlinked the account (email={partner_user.partner_email} | external_user_id={partner_user.external_user_id})", message=f"User has unlinked the account (email={partner_user.partner_email} | external_user_id={partner_user.external_user_id})",
) )
EventDispatcher.send_event(
partner_user.user, EventContent(user_unlinked=UserUnlinked())
)
PartnerUser.delete(partner_user.id) PartnerUser.delete(partner_user.id)
external_user_id = partner_user.external_user_id
Session.commit() Session.commit()
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name}) agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})
return external_user_id return True

View File

@ -1,6 +0,0 @@
from random import randbytes
from base64 import b64encode
def generate_request_id() -> str:
return b64encode(randbytes(6)).decode()

View File

@ -1,7 +1,6 @@
"""Inspired from """Inspired from
https://github.com/petermat/spamassassin_client https://github.com/petermat/spamassassin_client
""" """
import logging import logging
import socket import socket
from io import BytesIO from io import BytesIO

View File

@ -14,9 +14,9 @@ from sqlalchemy.sql import Insert, text
from app import s3, config from app import s3, config
from app.alias_utils import nb_email_log_for_mailbox from app.alias_utils import nb_email_log_for_mailbox
from app.api.views.apple import verify_receipt from app.api.views.apple import verify_receipt
from app.custom_domain_validation import CustomDomainValidation, is_mx_equivalent from app.custom_domain_validation import CustomDomainValidation
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains from app.dns_utils import get_mx_domains, is_mx_equivalent
from app.email_utils import ( from app.email_utils import (
send_email, send_email,
send_trial_end_soon_email, send_trial_end_soon_email,
@ -59,7 +59,7 @@ from app.models import (
ApiToCookieToken, ApiToCookieToken,
) )
from app.pgp_utils import load_public_key_and_check, PGPException from app.pgp_utils import load_public_key_and_check, PGPException
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from app.utils import sanitize_email from app.utils import sanitize_email
from server import create_light_app from server import create_light_app

View File

@ -369,8 +369,8 @@ For ex:
"is_premium": false "is_premium": false
}, },
{ {
"signed_suffix": ".yeah@sl.lan.X6_7OQ.i8XL4xsMsn7dxDEWU8eF-Zap0qo", "signed_suffix": ".yeah@sl.local.X6_7OQ.i8XL4xsMsn7dxDEWU8eF-Zap0qo",
"suffix": ".yeah@sl.lan", "suffix": ".yeah@sl.local",
"is_custom": true, "is_custom": true,
"is_premium": false "is_premium": false
} }
@ -465,7 +465,7 @@ Here's an example:
{ {
"creation_date": "2020-04-06 17:57:14+00:00", "creation_date": "2020-04-06 17:57:14+00:00",
"creation_timestamp": 1586195834, "creation_timestamp": 1586195834,
"email": "prefix1.cat@sl.lan", "email": "prefix1.cat@sl.local",
"name": "A Name", "name": "A Name",
"enabled": true, "enabled": true,
"id": 3, "id": 3,
@ -518,7 +518,7 @@ Alias info, use the same format as in /api/v2/aliases. For example:
{ {
"creation_date": "2020-04-06 17:57:14+00:00", "creation_date": "2020-04-06 17:57:14+00:00",
"creation_timestamp": 1586195834, "creation_timestamp": 1586195834,
"email": "prefix1.cat@sl.lan", "email": "prefix1.cat@sl.local",
"name": "A Name", "name": "A Name",
"enabled": true, "enabled": true,
"id": 3, "id": 3,
@ -608,7 +608,7 @@ If success, 200 with the list of activities, for example:
"activities": [ "activities": [
{ {
"action": "reply", "action": "reply",
"from": "yes_meo_chat@sl.lan", "from": "yes_meo_chat@sl.local",
"timestamp": 1580903760, "timestamp": 1580903760,
"to": "marketing@example.com", "to": "marketing@example.com",
"reverse_alias": "\"marketing at example.com\" <reply@a.b>", "reverse_alias": "\"marketing at example.com\" <reply@a.b>",
@ -703,7 +703,7 @@ Return 200 and `existed=true` if contact is already added.
"creation_timestamp": 1584186761, "creation_timestamp": 1584186761,
"last_email_sent_date": null, "last_email_sent_date": null,
"last_email_sent_timestamp": null, "last_email_sent_timestamp": null,
"reverse_alias": "First Last first@example.com <ra+qytyzjhrumrreuszrbjxqjlkh@sl.lan>", "reverse_alias": "First Last first@example.com <ra+qytyzjhrumrreuszrbjxqjlkh@sl.local>",
"reverse_alias_address": "reply+bzvpazcdedcgcpztehxzgjgzmxskqa@sl.co", "reverse_alias_address": "reply+bzvpazcdedcgcpztehxzgjgzmxskqa@sl.co",
"existed": false "existed": false
} }
@ -992,7 +992,7 @@ Return user setting.
{ {
"alias_generator": "word", "alias_generator": "word",
"notification": true, "notification": true,
"random_alias_default_domain": "sl.lan", "random_alias_default_domain": "sl.local",
"sender_format": "AT", "sender_format": "AT",
"random_alias_suffix": "random_string" "random_alias_suffix": "random_string"
} }
@ -1029,7 +1029,7 @@ Return domains that user can use to create random alias
"is_custom": false "is_custom": false
}, },
{ {
"domain": "sl.lan", "domain": "sl.local",
"is_custom": false "is_custom": false
}, },
{ {

View File

@ -30,7 +30,6 @@ It should contain the following info:
""" """
import argparse import argparse
import email import email
import time import time
@ -168,7 +167,6 @@ from app.models import (
VerpType, VerpType,
SLDomain, SLDomain,
) )
from app.monitor_utils import send_version_event
from app.pgp_utils import ( from app.pgp_utils import (
PGPException, PGPException,
sign_data_with_pgpy, sign_data_with_pgpy,
@ -1669,7 +1667,7 @@ def handle_bounce_reply_phase(envelope, msg: Message, email_log: EmailLog):
) )
Notification.create( Notification.create(
user_id=user.id, user_id=user.id,
title=f"Email cannot be sent to {contact.email} from your alias {alias.email}", title=f"Email cannot be sent to { contact.email } from your alias { alias.email }",
message=Notification.render( message=Notification.render(
"notification/bounce-reply-phase.html", "notification/bounce-reply-phase.html",
alias=alias, alias=alias,
@ -1682,7 +1680,7 @@ def handle_bounce_reply_phase(envelope, msg: Message, email_log: EmailLog):
user, user,
ALERT_BOUNCE_EMAIL_REPLY_PHASE, ALERT_BOUNCE_EMAIL_REPLY_PHASE,
mailbox.email, mailbox.email,
f"Email cannot be sent to {contact.email} from your alias {alias.email}", f"Email cannot be sent to { contact.email } from your alias { alias.email }",
render( render(
"transactional/bounce/bounce-email-reply-phase.txt", "transactional/bounce/bounce-email-reply-phase.txt",
user=user, user=user,
@ -2362,7 +2360,6 @@ class MailHandler:
"Custom/nb_rcpt_tos", len(envelope.rcpt_tos) "Custom/nb_rcpt_tos", len(envelope.rcpt_tos)
) )
send_version_event("email_handler")
with create_light_app().app_context(): with create_light_app().app_context():
return_status = handle(envelope, msg) return_status = handle(envelope, msg)
elapsed = time.time() - start elapsed = time.time() - start
@ -2398,7 +2395,6 @@ def main(port: int):
controller.start() controller.start()
LOG.d("Start mail controller %s %s", controller.hostname, controller.port) LOG.d("Start mail controller %s %s", controller.hostname, controller.port)
send_version_event("email_handler")
if LOAD_PGP_EMAIL_HANDLER: if LOAD_PGP_EMAIL_HANDLER:
LOG.w("LOAD PGP keys") LOG.w("LOAD PGP keys")

View File

@ -4,7 +4,6 @@ from sys import argv, exit
from app.config import EVENT_LISTENER_DB_URI from app.config import EVENT_LISTENER_DB_URI
from app.log import LOG from app.log import LOG
from app.monitor_utils import send_version_event
from events import event_debugger from events import event_debugger
from events.runner import Runner from events.runner import Runner
from events.event_source import DeadLetterEventSource, PostgresEventSource from events.event_source import DeadLetterEventSource, PostgresEventSource
@ -31,11 +30,9 @@ def main(mode: Mode, dry_run: bool, max_retries: int):
if mode == Mode.DEAD_LETTER: if mode == Mode.DEAD_LETTER:
LOG.i("Using DeadLetterEventSource") LOG.i("Using DeadLetterEventSource")
source = DeadLetterEventSource(max_retries) source = DeadLetterEventSource(max_retries)
service_name = "event_listener_dead_letter"
elif mode == Mode.LISTENER: elif mode == Mode.LISTENER:
LOG.i("Using PostgresEventSource") LOG.i("Using PostgresEventSource")
source = PostgresEventSource(EVENT_LISTENER_DB_URI) source = PostgresEventSource(EVENT_LISTENER_DB_URI)
service_name = "event_listener"
else: else:
raise ValueError(f"Invalid mode: {mode}") raise ValueError(f"Invalid mode: {mode}")
@ -46,8 +43,7 @@ def main(mode: Mode, dry_run: bool, max_retries: int):
LOG.i("Starting with HttpEventSink") LOG.i("Starting with HttpEventSink")
sink = HttpEventSink() sink = HttpEventSink()
send_version_event(service_name) runner = Runner(source=source, sink=sink)
runner = Runner(source=source, sink=sink, service_name=service_name)
runner.run() runner.run()

View File

@ -4,24 +4,20 @@ import newrelic.agent
from app.log import LOG from app.log import LOG
from app.db import Session from app.db import Session
from app.models import SyncEvent from app.models import SyncEvent
from app.monitor_utils import send_version_event
from events.event_sink import EventSink from events.event_sink import EventSink
from events.event_source import EventSource from events.event_source import EventSource
class Runner: class Runner:
def __init__(self, source: EventSource, sink: EventSink, service_name: str = ""): def __init__(self, source: EventSource, sink: EventSink):
self.__source = source self.__source = source
self.__sink = sink self.__sink = sink
self.__service_name = service_name
def run(self): def run(self):
self.__source.run(self.__on_event) self.__source.run(self.__on_event)
@newrelic.agent.background_task() @newrelic.agent.background_task()
def __on_event(self, event: SyncEvent): def __on_event(self, event: SyncEvent):
if self.__service_name:
send_version_event(self.__service_name)
try: try:
event_created_at = event.created_at event_created_at = event.created_at
start_time = arrow.now() start_time = arrow.now()

View File

@ -19,7 +19,7 @@ URL=http://localhost:7777
NOT_SEND_EMAIL=true NOT_SEND_EMAIL=true
# domain used to create alias # domain used to create alias
EMAIL_DOMAIN=sl.lan EMAIL_DOMAIN=sl.local
# Allow SimpleLogin to enforce SPF by using the extra headers from postfix # Allow SimpleLogin to enforce SPF by using the extra headers from postfix
# ENFORCE_SPF=true # ENFORCE_SPF=true
@ -37,18 +37,18 @@ EMAIL_DOMAIN=sl.lan
# FIRST_ALIAS_DOMAIN = another-domain.com # FIRST_ALIAS_DOMAIN = another-domain.com
# transactional email is sent from this email address # transactional email is sent from this email address
SUPPORT_EMAIL=support@sl.lan SUPPORT_EMAIL=support@sl.local
SUPPORT_NAME=Son from SimpleLogin SUPPORT_NAME=Son from SimpleLogin
# To use VERP # To use VERP
# prefix must end with + and suffix must start with + # prefix must end with + and suffix must start with +
# BOUNCE_PREFIX = "bounces+" # BOUNCE_PREFIX = "bounces+"
# BOUNCE_SUFFIX = "+@sl.lan" # BOUNCE_SUFFIX = "+@sl.local"
# same as BOUNCE_PREFIX but used for reply phase. Note it doesn't have the plus sign (+) at the end. # same as BOUNCE_PREFIX but used for reply phase. Note it doesn't have the plus sign (+) at the end.
# BOUNCE_PREFIX_FOR_REPLY_PHASE = "bounce_reply" # BOUNCE_PREFIX_FOR_REPLY_PHASE = "bounce_reply"
# to receive general stats. # to receive general stats.
# ADMIN_EMAIL=admin@sl.lan # ADMIN_EMAIL=admin@sl.local
# Max number emails user can generate for free plan # Max number emails user can generate for free plan
# Set to 5 by default # Set to 5 by default

View File

@ -6,7 +6,7 @@ from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import Mailbox, Contact, SLDomain, Partner from app.models import Mailbox, Contact, SLDomain, Partner
from app.pgp_utils import load_public_key from app.pgp_utils import load_public_key
from app.proton.proton_partner import PROTON_PARTNER_NAME from app.proton.utils import PROTON_PARTNER_NAME
from server import create_light_app from server import create_light_app

View File

@ -2,18 +2,13 @@
Run scheduled jobs. Run scheduled jobs.
Not meant for running job at precise time (+- 1h) Not meant for running job at precise time (+- 1h)
""" """
import time import time
from typing import List, Optional from typing import List, Optional
import arrow import arrow
import newrelic.agent
from sqlalchemy.orm import Query
from sqlalchemy.orm.exc import ObjectDeletedError
from sqlalchemy.sql.expression import or_, and_ from sqlalchemy.sql.expression import or_, and_
from app import config from app import config
from app.constants import JobType
from app.db import Session from app.db import Session
from app.email_utils import ( from app.email_utils import (
send_email, send_email,
@ -26,13 +21,9 @@ from app.jobs.export_user_data_job import ExportUserDataJob
from app.jobs.send_event_job import SendEventToWebhookJob from app.jobs.send_event_job import SendEventToWebhookJob
from app.log import LOG from app.log import LOG
from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState from app.models import User, Job, BatchImport, Mailbox, CustomDomain, JobState
from app.monitor_utils import send_version_event
from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction from app.user_audit_log_utils import emit_user_audit_log, UserAuditLogAction
from events.event_sink import HttpEventSink
from server import create_light_app from server import create_light_app
_MAX_JOBS_PER_BATCH = 50
def onboarding_send_from_alias(user): def onboarding_send_from_alias(user):
comm_email, unsubscribe_link, via_email = user.get_communication_email() comm_email, unsubscribe_link, via_email = user.get_communication_email()
@ -198,8 +189,7 @@ SimpleLogin team.
def process_job(job: Job): def process_job(job: Job):
send_version_event("job_runner") if job.name == config.JOB_ONBOARDING_1:
if job.name == JobType.ONBOARDING_1.value:
user_id = job.payload.get("user_id") user_id = job.payload.get("user_id")
user = User.get(user_id) user = User.get(user_id)
@ -208,7 +198,7 @@ def process_job(job: Job):
if user and user.notification and user.activated: if user and user.notification and user.activated:
LOG.d("send onboarding send-from-alias email to user %s", user) LOG.d("send onboarding send-from-alias email to user %s", user)
onboarding_send_from_alias(user) onboarding_send_from_alias(user)
elif job.name == JobType.ONBOARDING_2.value: elif job.name == config.JOB_ONBOARDING_2:
user_id = job.payload.get("user_id") user_id = job.payload.get("user_id")
user = User.get(user_id) user = User.get(user_id)
@ -217,7 +207,7 @@ def process_job(job: Job):
if user and user.notification and user.activated: if user and user.notification and user.activated:
LOG.d("send onboarding mailbox email to user %s", user) LOG.d("send onboarding mailbox email to user %s", user)
onboarding_mailbox(user) onboarding_mailbox(user)
elif job.name == JobType.ONBOARDING_4.value: elif job.name == config.JOB_ONBOARDING_4:
user_id = job.payload.get("user_id") user_id = job.payload.get("user_id")
user: User = User.get(user_id) user: User = User.get(user_id)
@ -232,11 +222,11 @@ def process_job(job: Job):
LOG.d("send onboarding pgp email to user %s", user) LOG.d("send onboarding pgp email to user %s", user)
onboarding_pgp(user) onboarding_pgp(user)
elif job.name == JobType.BATCH_IMPORT.value: elif job.name == config.JOB_BATCH_IMPORT:
batch_import_id = job.payload.get("batch_import_id") batch_import_id = job.payload.get("batch_import_id")
batch_import = BatchImport.get(batch_import_id) batch_import = BatchImport.get(batch_import_id)
handle_batch_import(batch_import) handle_batch_import(batch_import)
elif job.name == JobType.DELETE_ACCOUNT.value: elif job.name == config.JOB_DELETE_ACCOUNT:
user_id = job.payload.get("user_id") user_id = job.payload.get("user_id")
user = User.get(user_id) user = User.get(user_id)
@ -255,10 +245,10 @@ def process_job(job: Job):
) )
User.delete(user.id) User.delete(user.id)
Session.commit() Session.commit()
elif job.name == JobType.DELETE_MAILBOX.value: elif job.name == config.JOB_DELETE_MAILBOX:
delete_mailbox_job(job) delete_mailbox_job(job)
elif job.name == JobType.DELETE_DOMAIN.value: elif job.name == config.JOB_DELETE_DOMAIN:
custom_domain_id = job.payload.get("custom_domain_id") custom_domain_id = job.payload.get("custom_domain_id")
custom_domain: Optional[CustomDomain] = CustomDomain.get(custom_domain_id) custom_domain: Optional[CustomDomain] = CustomDomain.get(custom_domain_id)
if not custom_domain: if not custom_domain:
@ -295,17 +285,17 @@ def process_job(job: Job):
""", """,
retries=3, retries=3,
) )
elif job.name == JobType.SEND_USER_REPORT.value: elif job.name == config.JOB_SEND_USER_REPORT:
export_job = ExportUserDataJob.create_from_job(job) export_job = ExportUserDataJob.create_from_job(job)
if export_job: if export_job:
export_job.run() export_job.run()
elif job.name == JobType.SEND_PROTON_WELCOME_1.value: elif job.name == config.JOB_SEND_PROTON_WELCOME_1:
user_id = job.payload.get("user_id") user_id = job.payload.get("user_id")
user = User.get(user_id) user = User.get(user_id)
if user and user.activated: if user and user.activated:
LOG.d("Send proton welcome email to user %s", user) LOG.d("Send proton welcome email to user %s", user)
welcome_proton(user) welcome_proton(user)
elif job.name == JobType.SEND_ALIAS_CREATION_EVENTS.value: elif job.name == config.JOB_SEND_ALIAS_CREATION_EVENTS:
user_id = job.payload.get("user_id") user_id = job.payload.get("user_id")
user = User.get(user_id) user = User.get(user_id)
if user and user.activated: if user and user.activated:
@ -313,111 +303,52 @@ def process_job(job: Job):
send_alias_creation_events_for_user( send_alias_creation_events_for_user(
user, dispatcher=PostgresDispatcher.get() user, dispatcher=PostgresDispatcher.get()
) )
elif job.name == JobType.SEND_EVENT_TO_WEBHOOK.value: elif job.name == config.JOB_SEND_EVENT_TO_WEBHOOK:
send_job = SendEventToWebhookJob.create_from_job(job) send_job = SendEventToWebhookJob.create_from_job(job)
if send_job: if send_job:
send_job.run(HttpEventSink()) send_job.run()
else: else:
LOG.e("Unknown job name %s", job.name) LOG.e("Unknown job name %s", job.name)
def get_jobs_to_run_query(taken_before_time: arrow.Arrow) -> Query: def get_jobs_to_run() -> List[Job]:
# Get jobs that match all conditions: # Get jobs that match all conditions:
# - Job.state == ready OR (Job.state == taken AND Job.taken_at < now - 30 mins AND Job.attempts < 5) # - Job.state == ready OR (Job.state == taken AND Job.taken_at < now - 30 mins AND Job.attempts < 5)
# - Job.run_at is Null OR Job.run_at < now + 10 mins # - Job.run_at is Null OR Job.run_at < now + 10 mins
taken_at_earliest = arrow.now().shift(minutes=-config.JOB_TAKEN_RETRY_WAIT_MINS)
run_at_earliest = arrow.now().shift(minutes=+10) run_at_earliest = arrow.now().shift(minutes=+10)
return Job.filter( query = Job.filter(
and_( and_(
or_( or_(
Job.state == JobState.ready.value, Job.state == JobState.ready.value,
and_( and_(
Job.state == JobState.taken.value, Job.state == JobState.taken.value,
Job.taken_at < taken_before_time, Job.taken_at < taken_at_earliest,
Job.attempts < config.JOB_MAX_ATTEMPTS, Job.attempts < config.JOB_MAX_ATTEMPTS,
), ),
), ),
or_(Job.run_at.is_(None), and_(Job.run_at <= run_at_earliest)), or_(Job.run_at.is_(None), and_(Job.run_at <= run_at_earliest)),
) )
) )
return query.all()
def get_jobs_to_run(taken_before_time: arrow.Arrow) -> List[Job]:
query = get_jobs_to_run_query(taken_before_time)
return (
query.order_by(Job.priority.desc())
.order_by(Job.run_at.asc())
.limit(_MAX_JOBS_PER_BATCH)
.all()
)
def take_job(job: Job, taken_before_time: arrow.Arrow) -> bool:
sql = """
UPDATE job
SET
taken_at = :taken_time,
attempts = attempts + 1,
state = :taken_state
WHERE id = :job_id
AND (state = :ready_state OR (state=:taken_state AND taken_at < :taken_before_time))
"""
args = {
"taken_time": arrow.now().datetime,
"job_id": job.id,
"ready_state": JobState.ready.value,
"taken_state": JobState.taken.value,
"taken_before_time": taken_before_time.datetime,
}
try:
res = Session.execute(sql, args)
Session.commit()
except ObjectDeletedError:
return False
return res.rowcount > 0
if __name__ == "__main__": if __name__ == "__main__":
send_version_event("job_runner")
while True: while True:
# wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc # wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc
with create_light_app().app_context(): with create_light_app().app_context():
taken_before_time = arrow.now().shift( for job in get_jobs_to_run():
minutes=-config.JOB_TAKEN_RETRY_WAIT_MINS
)
jobs_done = 0
for job in get_jobs_to_run(taken_before_time):
if not take_job(job, taken_before_time):
continue
LOG.d("Take job %s", job) LOG.d("Take job %s", job)
try: # mark the job as taken, whether it will be executed successfully or not
newrelic.agent.record_custom_event("ProcessJob", {"job": job.name}) job.taken = True
process_job(job) job.taken_at = arrow.now()
job_result = "success" job.state = JobState.taken.value
job.attempts += 1
Session.commit()
process_job(job)
job.state = JobState.done.value job.state = JobState.done.value
jobs_done += 1
except Exception as e:
LOG.warn(f"Error processing job (id={job.id} name={job.name}): {e}")
# Increment manually, as the attempts increment is done by the take_job but not
# updated in our instance
job_attempts = job.attempts + 1
if job_attempts >= config.JOB_MAX_ATTEMPTS:
LOG.warn(
f"Marking job (id={job.id} name={job.name} attempts={job_attempts}) as ERROR"
)
job.state = JobState.error.value
job_result = "error"
else:
job_result = "retry"
newrelic.agent.record_custom_event(
"JobProcessed", {"job": job.name, "result": job_result}
)
Session.commit() Session.commit()
if jobs_done == 0: time.sleep(10)
time.sleep(10)
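
The main-side runner above only processes a job after `take_job` wins an atomic UPDATE, so several job-runner processes can poll the same table without double-executing work. A minimal sketch of that optimistic-claim pattern, using sqlite3 instead of Postgres/SQLAlchemy so it runs standalone (the table layout and values are illustrative assumptions):

```python
# Sketch of the optimistic job-claim pattern used by take_job() above
# (assumptions: a plain `job` table and sqlite3 instead of Postgres/SQLAlchemy).
import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE job (id INTEGER PRIMARY KEY, state TEXT, taken_at REAL, attempts INTEGER)"
)
conn.execute("INSERT INTO job VALUES (1, 'ready', NULL, 0)")
conn.commit()


def take_job(job_id: int, taken_before_time: float) -> bool:
    # The UPDATE only matches if the job is still ready, or was taken long enough
    # ago to be considered stale; rowcount tells us whether this worker won the claim.
    cur = conn.execute(
        """
        UPDATE job
        SET taken_at = :now, attempts = attempts + 1, state = 'taken'
        WHERE id = :job_id
          AND (state = 'ready'
               OR (state = 'taken' AND taken_at < :taken_before_time))
        """,
        {"now": time.time(), "job_id": job_id, "taken_before_time": taken_before_time},
    )
    conn.commit()
    return cur.rowcount > 0


taken_before = time.time() - 30 * 60  # jobs taken more than 30 minutes ago count as stale
print(take_job(1, taken_before))  # True: this worker claims the job
print(take_job(1, taken_before))  # False: already taken and not yet stale
```
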

View File

@ -1,91 +0,0 @@
"""index cleanup
Revision ID: d3ff8848c930
Revises: 085f77996ce3
Create Date: 2025-01-30 15:00:02.995813
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d3ff8848c930"
down_revision = "085f77996ce3"
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.drop_index("ix_alias_hibp_alias_id", table_name="alias_hibp")
op.drop_index("ix_alias_mailbox_alias_id", table_name="alias_mailbox")
op.drop_index("ix_alias_used_on_alias_id", table_name="alias_used_on")
op.drop_index("ix_api_key_code", table_name="api_key")
op.drop_index(
"ix_auto_create_rule_custom_domain_id", table_name="auto_create_rule"
)
op.drop_index("ix_contact_alias_id", table_name="contact")
op.create_index(
"ix_email_log_user_id_email_log_id",
"email_log",
["user_id", "id"],
unique=False,
)
op.drop_index("ix_job_state", table_name="job")
op.create_index(
"ix_mailbox_email_trgm_idx",
"mailbox",
["email"],
unique=False,
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
)
op.drop_index("ix_partner_user_partner_id", table_name="partner_user")
op.create_index(
"ix_sent_alert_alert_type", "sent_alert", ["alert_type"], unique=False
)
op.create_index(
"ix_sent_alert_to_email", "sent_alert", ["to_email"], unique=False
)
op.create_index(
"idx_users_email_trgm",
"users",
["email"],
unique=False,
postgresql_ops={"email": "gin_trgm_ops"},
postgresql_using="gin",
)
op.drop_index("ix_users_activated", table_name="users")
op.drop_index("ix_mailbox_user_id", table_name="users")
def downgrade():
with op.get_context().autocommit_block():
op.create_index("ix_users_activated", "users", ["activated"], unique=False)
op.drop_index("idx_users_email_trgm", table_name="users")
op.drop_index("ix_sent_alert_to_email", table_name="sent_alert")
op.drop_index("ix_sent_alert_alert_type", table_name="sent_alert")
op.create_index(
"ix_partner_user_partner_id", "partner_user", ["partner_id"], unique=False
)
op.drop_index("ix_mailbox_email_trgm_idx", table_name="mailbox")
op.create_index("ix_job_state", "job", ["state"], unique=False)
op.drop_index("ix_email_log_user_id_email_log_id", table_name="email_log")
op.create_index("ix_contact_alias_id", "contact", ["alias_id"], unique=False)
op.create_index(
"ix_auto_create_rule_custom_domain_id",
"auto_create_rule",
["custom_domain_id"],
unique=False,
)
op.create_index("ix_api_key_code", "api_key", ["code"], unique=False)
op.create_index(
"ix_alias_used_on_alias_id", "alias_used_on", ["alias_id"], unique=False
)
op.create_index(
"ix_alias_mailbox_alias_id", "alias_mailbox", ["alias_id"], unique=False
)
op.create_index(
"ix_alias_hibp_alias_id", "alias_hibp", ["alias_id"], unique=False
)
op.create_index("ix_mailbox_user_id", "users", ["user_id"], unique=False)

View File

@ -1,23 +0,0 @@
"""index cleanup
Revision ID: 97edba8794f8
Revises: d3ff8848c930
Create Date: 2025-01-31 14:42:22.590597
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '97edba8794f8'
down_revision = 'd3ff8848c930'
branch_labels = None
depends_on = None
def upgrade():
op.drop_index('ix_email_log_user_id', table_name='email_log')
def downgrade():
op.create_index('ix_email_log_user_id', 'email_log', ['user_id'], unique=False)

View File

@ -1,27 +0,0 @@
"""contact index
Revision ID: 20e7d3ca289a
Revises: 97edba8794f8
Create Date: 2025-02-03 16:52:06.775032
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = '20e7d3ca289a'
down_revision = '97edba8794f8'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.create_index('ix_contact_user_id_id', 'contact', ['user_id', 'id'], unique=False)
op.drop_index('ix_contact_user_id', table_name='contact')
def downgrade():
with op.get_context().autocommit_block():
op.create_index('ix_contact_user_id', 'contact', ['user_id'], unique=False)
op.drop_index('ix_contact_user_id_id', table_name='contact')

View File

@ -1,31 +0,0 @@
"""job priorities
Revision ID: fd79503179dd
Revises: 20e7d3ca289a
Create Date: 2025-02-25 15:39:24.833973
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fd79503179dd'
down_revision = '20e7d3ca289a'
branch_labels = None
depends_on = None
def upgrade():
with op.get_context().autocommit_block():
op.add_column('job', sa.Column('priority', sa.Integer(), server_default='50', nullable=False))
op.create_index('ix_state_run_at_taken_at_priority', 'job', ['state', 'run_at', 'taken_at', 'priority'], unique=False, postgresql_concurrently=True)
op.drop_index('ix_state_run_at_taken_at', table_name='job', postgresql_concurrently=True)
def downgrade():
with op.get_context().autocommit_block():
op.drop_index('ix_state_run_at_taken_at_priority', table_name='job', postgresql_concurrently=True)
op.create_index('ix_state_run_at_taken_at', 'job', ['state', 'run_at', 'taken_at'], unique=False, postgresql_concurrently=True)
op.drop_column('job', 'priority')

View File

@ -7,11 +7,8 @@ from typing import List, Dict
import arrow import arrow
import newrelic.agent import newrelic.agent
from app.models import JobState
from app.config import JOB_MAX_ATTEMPTS, JOB_TAKEN_RETRY_WAIT_MINS
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from job_runner import get_jobs_to_run_query
from monitor.metric_exporter import MetricExporter from monitor.metric_exporter import MetricExporter
# the number of consecutive fails # the number of consecutive fails
@ -157,38 +154,6 @@ def log_failed_events():
newrelic.agent.record_custom_metric("Custom/sync_events_failed", failed_events) newrelic.agent.record_custom_metric("Custom/sync_events_failed", failed_events)
@newrelic.agent.background_task()
def log_jobs_to_run():
taken_before_time = arrow.now().shift(minutes=-JOB_TAKEN_RETRY_WAIT_MINS)
query = get_jobs_to_run_query(taken_before_time)
count = query.count()
LOG.d(f"Pending jobs to run: {count}")
newrelic.agent.record_custom_metric("Custom/jobs_to_run", count)
@newrelic.agent.background_task()
def log_failed_jobs():
r = Session.execute(
"""
SELECT COUNT(*)
FROM job
WHERE (
state = :error_state
OR (state = :taken_state AND attempts >= :max_attempts)
)
""",
{
"error_state": JobState.error.value,
"taken_state": JobState.taken.value,
"max_attempts": JOB_MAX_ATTEMPTS,
},
)
failed_jobs = list(r)[0][0]
LOG.d(f"Failed jobs: {failed_jobs}")
newrelic.agent.record_custom_metric("Custom/failed_jobs", failed_jobs)
if __name__ == "__main__": if __name__ == "__main__":
exporter = MetricExporter(get_newrelic_license()) exporter = MetricExporter(get_newrelic_license())
while True: while True:
@ -198,8 +163,6 @@ if __name__ == "__main__":
log_events_pending_dead_letter() log_events_pending_dead_letter()
log_failed_events() log_failed_events()
log_nb_db_connection_by_app_name() log_nb_db_connection_by_app_name()
log_jobs_to_run()
log_failed_jobs()
Session.close() Session.close()
exporter.run() exporter.run()

View File

@ -5,7 +5,6 @@ The step-to-step guide can be found on https://simplelogin.io/docs/siwsl/app/
This example is based on This example is based on
https://requests-oauthlib.readthedocs.io/en/latest/examples/real_world_example.html https://requests-oauthlib.readthedocs.io/en/latest/examples/real_world_example.html
""" """
import os import os
from flask import Flask, request, redirect, session, url_for from flask import Flask, request, redirect, session, url_for

View File

@ -34,4 +34,4 @@ for i in range(tests):
end = time.time() end = time.time()
time_taken = end - start time_taken = end - start
print(f"Took {time_taken} -> {time_taken / tests} per test") print(f"Took {time_taken} -> {time_taken/tests} per test")

View File

@ -1,123 +0,0 @@
#!/usr/bin/env python3
import argparse
import sys
import time
from sqlalchemy import func
from typing import Optional
from app.jobs.send_event_job import SendEventToWebhookJob
from app.db import Session
from app.events.generated.event_pb2 import UserPlanChanged, EventContent
from app.models import PartnerUser, User
def process(start_pu_id: int, end_pu_id: int, step: int, only_lifetime: bool):
print(
f"Checking partner user {start_pu_id} to {end_pu_id} (step={step}) (only_lifetime={only_lifetime})"
)
start_time = time.time()
with_lifetime = 0
with_plan = 0
with_free = 0
for batch_start in range(start_pu_id, end_pu_id, step):
query = (
Session.query(User)
.join(PartnerUser, PartnerUser.user_id == User.id)
.filter(PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step)
)
if only_lifetime:
query = query.filter(
User.lifetime == True, # noqa :E712
)
users = query.all()
for user in users:
# Just in case the == True cond is wonky
if user.lifetime:
event = UserPlanChanged(lifetime=True)
with_lifetime += 1
else:
plan_end = user.get_active_subscription_end(
include_partner_subscription=False
)
if plan_end:
event = UserPlanChanged(plan_end_time=plan_end.timestamp)
with_plan += 1
else:
event = UserPlanChanged()
with_free += 1
job = SendEventToWebhookJob(
user=user, event=EventContent(user_plan_change=event)
)
job.store_job_in_db(run_at=None, commit=False)
Session.flush()
Session.commit()
elapsed = time.time() - start_time
last_batch_id = batch_start + step
time_per_user = elapsed / last_batch_id
remaining = end_pu_id - last_batch_id
time_remaining = remaining / time_per_user
hours_remaining = time_remaining / 60.0
print(
f"PartnerUser {batch_start}/{end_pu_id} lifetime {with_lifetime} paid {with_plan} free {with_free} {hours_remaining:.2f} mins remaining"
)
print(f"Sent lifetime {with_lifetime} paid {with_plan} free {with_free}")
def main():
parser = argparse.ArgumentParser(
prog="Schedule Sync User Jobs", description="Create jobs to sync users"
)
parser.add_argument(
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
)
parser.add_argument(
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
)
parser.add_argument("-t", "--step", default=10000, type=int, help="Step to use")
parser.add_argument("-u", "--user", default="", type=str, help="User to sync")
parser.add_argument(
"-l", "--lifetime", action="store_true", help="Only sync lifetime users"
)
args = parser.parse_args()
start_pu_id = args.start_pu_id
end_pu_id = args.end_pu_id
user_id = args.user
only_lifetime = args.lifetime
step = args.step
if not end_pu_id:
end_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
if user_id:
try:
user_id = int(user_id)
except ValueError:
user = User.get_by(email=user_id)
if not user:
print(f"User {user_id} not found")
sys.exit(1)
user_id = user.id
print(f"Limiting to user {user_id}")
partner_user: Optional[PartnerUser] = PartnerUser.get_by(user_id=user_id)
if not partner_user:
print(f"Could not find PartnerUser for user_id={user_id}")
sys.exit(1)
# So we only have one loop
step = 1
start_pu_id = partner_user.id
end_pu_id = partner_user.id
process(
start_pu_id=start_pu_id,
end_pu_id=end_pu_id,
step=step,
only_lifetime=only_lifetime,
)
if __name__ == "__main__":
main()

View File

@ -1,14 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import sys
import time import time
import arrow
from sqlalchemy import func from sqlalchemy import func
from app.db import Session
from app.events.event_dispatcher import EventDispatcher from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import UserPlanChanged, EventContent from app.events.generated.event_pb2 import UserPlanChanged, EventContent
from app.models import PartnerUser, User from app.models import PartnerUser, User
from app.db import Session
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="Backfill alias", description="Send lifetime users to proton" prog="Backfill alias", description="Send lifetime users to proton"
@ -19,69 +19,34 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id" "-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
) )
parser.add_argument("-t", "--step", default=10000, type=int, help="Step to use")
parser.add_argument("-u", "--user", default="", type=str, help="User to sync")
parser.add_argument(
"-l", "--lifetime", action="store_true", help="Only sync lifetime users"
)
args = parser.parse_args() args = parser.parse_args()
pu_id_start = args.start_pu_id pu_id_start = args.start_pu_id
max_pu_id = args.end_pu_id max_pu_id = args.end_pu_id
user_id = args.user
only_lifetime = args.lifetime
step = args.step
if max_pu_id == 0: if max_pu_id == 0:
max_pu_id = Session.query(func.max(PartnerUser.id)).scalar() max_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
if user_id:
try:
user_id = int(user_id)
except ValueError:
user = User.get_by(email=user_id)
if not user:
print(f"User {user_id} not found")
sys.exit(1)
print(f"Limiting to user {user_id}")
user_id = user.id
# So we only have one loop
step = max_pu_id
print(f"Checking partner user {pu_id_start} to {max_pu_id}") print(f"Checking partner user {pu_id_start} to {max_pu_id}")
step = 1000
done = 0 done = 0
start_time = time.time() start_time = time.time()
with_lifetime = 0 with_lifetime = 0
with_plan = 0
with_free = 0
for batch_start in range(pu_id_start, max_pu_id, step): for batch_start in range(pu_id_start, max_pu_id, step):
query = Session.query(User).join(PartnerUser, PartnerUser.user_id == User.id) users = (
if user_id: Session.query(User)
query = query.filter(User.id == user_id) .join(PartnerUser, PartnerUser.user_id == User.id)
else: .filter(
query = query.filter( PartnerUser.id >= batch_start,
PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step PartnerUser.id < batch_start + step,
)
if only_lifetime:
query = query.filter(
User.lifetime == True, # noqa :E712 User.lifetime == True, # noqa :E712
) )
users = query.all() ).all()
for user in users: for user in users:
# Just in case the == True cond is wonky # Just in case the == True cond is wonky
if user.lifetime: if not user.lifetime:
event = UserPlanChanged(lifetime=True) continue
with_lifetime += 1 with_lifetime += 1
else: event = UserPlanChanged(plan_end_time=arrow.get("2038-01-01").timestamp)
plan_end = user.get_active_subscription_end(
include_partner_subscription=False
)
if plan_end:
event = UserPlanChanged(plan_end_time=plan_end.timestamp)
with_plan += 1
else:
event = UserPlanChanged()
with_free += 1
EventDispatcher.send_event(user, EventContent(user_plan_change=event)) EventDispatcher.send_event(user, EventContent(user_plan_change=event))
Session.flush() Session.flush()
Session.commit() Session.commit()
@ -92,6 +57,6 @@ for batch_start in range(pu_id_start, max_pu_id, step):
time_remaining = remaining / time_per_alias time_remaining = remaining / time_per_alias
hours_remaining = time_remaining / 60.0 hours_remaining = time_remaining / 60.0
print( print(
f"artnerUser {batch_start}/{max_pu_id} lifetime {with_lifetime} paid {with_plan} free {with_free} {hours_remaining:.2f} mins remaining" f"\PartnerUser {batch_start}/{max_pu_id} {with_lifetime} {hours_remaining:.2f} mins remaining"
) )
print(f"Sent lifetime {with_lifetime} paid {with_plan} free {with_free}") print(f"With SL lifetime {with_lifetime}")

View File

@ -34,9 +34,6 @@ message AliasCreatedList {
repeated AliasCreated events = 1; repeated AliasCreated events = 1;
} }
message UserUnlinked {
}
message EventContent { message EventContent {
oneof content { oneof content {
UserPlanChanged user_plan_change = 1; UserPlanChanged user_plan_change = 1;
@ -45,7 +42,6 @@ message EventContent {
AliasStatusChanged alias_status_change = 4; AliasStatusChanged alias_status_change = 4;
AliasDeleted alias_deleted = 5; AliasDeleted alias_deleted = 5;
AliasCreatedList alias_create_list = 6; AliasCreatedList alias_create_list = 6;
UserUnlinked user_unlinked = 7;
} }
} }

View File

@ -12,7 +12,7 @@ packages = [
] ]
include = ["templates/*", "templates/**/*", "local_data/*.txt"] include = ["templates/*", "templates/**/*", "local_data/*.txt"]
requires-python = "~=3.12" requires-python = "~=3.10"
dependencies = [ dependencies = [
"flask ~= 1.1.2", "flask ~= 1.1.2",
@ -24,9 +24,9 @@ dependencies = [
"python-dotenv ~= 0.14.0", "python-dotenv ~= 0.14.0",
"ipython ~= 7.31.1", "ipython ~= 7.31.1",
"sqlalchemy_utils ~= 0.36.8", "sqlalchemy_utils ~= 0.36.8",
"psycopg2-binary ~= 2.9.10", "psycopg2-binary ~= 2.9.3",
"sentry_sdk ~= 2.20.0", "sentry_sdk ~= 2.20.0",
"blinker ~= 1.9.0", "blinker ~= 1.4",
"arrow ~= 0.16.0", "arrow ~= 0.16.0",
"Flask-WTF ~= 0.14.3", "Flask-WTF ~= 0.14.3",
"boto3 ~= 1.35.37", "boto3 ~= 1.35.37",
@ -36,16 +36,16 @@ dependencies = [
"watchtower ~= 0.8.0", "watchtower ~= 0.8.0",
"sqlalchemy-utils == 0.36.8", "sqlalchemy-utils == 0.36.8",
"jwcrypto ~= 0.8", "jwcrypto ~= 0.8",
"yacron~=0.19.0", "yacron~=0.11.2",
"flask-debugtoolbar ~= 0.11.0", "flask-debugtoolbar ~= 0.11.0",
"requests_oauthlib ~= 1.3.0", "requests_oauthlib ~= 1.3.0",
"pyopenssl ~= 19.1.0", "pyopenssl ~= 19.1.0",
"aiosmtpd ~= 1.2", "aiosmtpd ~= 1.2",
"dnspython ~= 2.7.0", "dnspython==2.0.0",
"coloredlogs ~= 14.0", "coloredlogs ~= 14.0",
"pycryptodome ~= 3.9.8", "pycryptodome ~= 3.9.8",
"phpserialize ~= 1.3", "phpserialize ~= 1.3",
"dkimpy == 1.0.5", "dkimpy ~= 1.0.5",
"pyotp ~= 2.4.0", "pyotp ~= 2.4.0",
"flask_profiler ~= 1.8.1", "flask_profiler ~= 1.8.1",
"facebook-sdk ~= 3.1.0", "facebook-sdk ~= 3.1.0",
@ -53,15 +53,11 @@ dependencies = [
"google-auth-httplib2 ~= 0.0.4", "google-auth-httplib2 ~= 0.0.4",
"python-gnupg ~= 0.4.6", "python-gnupg ~= 0.4.6",
"webauthn ~= 0.4.7", "webauthn ~= 0.4.7",
# Git dependency until pyspf creates a new release
#"pyspf @ git+https://github.com/sdgathman/pyspf.git@665a6df079485a9824be0829e7d71088453db7f6",
"pyspf ~= 2.0.14", "pyspf ~= 2.0.14",
"Flask-Limiter == 1.4",
"Flask-Limiter == 1.5",
"memory_profiler ~= 0.57.0", "memory_profiler ~= 0.57.0",
"gevent ~= 24.11.1", "gevent ~= 24.11.1",
"email-validator ~= 2.2.0", "email-validator ~= 1.1.3",
"PGPy == 0.5.4", "PGPy == 0.5.4",
"coinbase-commerce ~= 1.0.1", "coinbase-commerce ~= 1.0.1",
"requests ~= 2.25.1", "requests ~= 2.25.1",
@ -75,18 +71,16 @@ dependencies = [
"MarkupSafe~=1.1.1", "MarkupSafe~=1.1.1",
"cryptography ~= 37.0.1", "cryptography ~= 37.0.1",
"SQLAlchemy ~= 1.3.24", "SQLAlchemy ~= 1.3.24",
"redis==5.2.1", "redis==4.6.0",
"newrelic-telemetry-sdk ~= 0.5.0", "newrelic-telemetry-sdk ~= 0.5.0",
"aiospamc == 0.10", "aiospamc == 0.10",
"itsdangerous ~= 1.1.0", "itsdangerous ~= 1.1.0",
"werkzeug ~= 1.0.1", "werkzeug ~= 1.0.1",
"alembic ~= 1.4.3", "alembic ~= 1.4.3",
"limits ~= 4.0.1",
"strictyaml ~= 1.7.3",
] ]
[tool.black] [tool.black]
target-version = ['py312'] target-version = ['py310']
exclude = ''' exclude = '''
( (
/( /(
@ -107,9 +101,8 @@ exclude = '''
''' '''
[tool.ruff] [tool.ruff]
exclude = [".venv", "migrations", "app/events/generated"]
[tool.ruff.lint]
ignore-init-module-imports = true ignore-init-module-imports = true
exclude = [".venv", "migrations", "app/events/generated"]
[tool.djlint] [tool.djlint]
indent = 2 indent = 2
@ -134,7 +127,7 @@ ignore = "H006,H013,H016,H017,H019,H021,H025,H030,H031,T003,J004,J018,T001"
dev-dependencies = [ dev-dependencies = [
"pytest ~= 7.0.0", "pytest ~= 7.0.0",
"pytest-cov ~= 3.0.0", "pytest-cov ~= 3.0.0",
"pre-commit ~= 4.1.0", "pre-commit ~= 2.17.0",
"black ~= 22.1.0", "black ~= 22.1.0",
"djlint==1.34.1", "djlint==1.34.1",
"pylint ~= 2.14.4", "pylint ~= 2.14.4",

View File

@ -4,14 +4,12 @@ SCRIPT_DIR="$(cd "$(dirname "$0")" || exit 1; pwd -P)"
REPO_ROOT=$(echo "${SCRIPT_DIR}" | sed 's:scripts::g') REPO_ROOT=$(echo "${SCRIPT_DIR}" | sed 's:scripts::g')
BUILD_INFO_FILE="${REPO_ROOT}/app/build_info.py" BUILD_INFO_FILE="${REPO_ROOT}/app/build_info.py"
if [[ -z "$2" ]]; then if [[ -z "$1" ]]; then
echo "Invalid usage. Usage: $0 SHA VERSION" echo "This script needs to be invoked with the version as an argument"
exit 1 exit 1
fi fi
SHA="$1" VERSION="$1"
echo "SHA1 = \"${SHA}\"" > $BUILD_INFO_FILE echo "SHA1 = \"${VERSION}\"" > $BUILD_INFO_FILE
BUILD_TIME=$(date +%s) BUILD_TIME=$(date +%s)
echo "BUILD_TIME = \"${BUILD_TIME}\"" >> $BUILD_INFO_FILE echo "BUILD_TIME = \"${BUILD_TIME}\"" >> $BUILD_INFO_FILE
VERSION="$2"
echo "VERSION = \"${VERSION}\"" >> $BUILD_INFO_FILE

View File

@ -99,7 +99,6 @@ from app.models import (
InvalidMailboxDomain, InvalidMailboxDomain,
) )
from app.monitor.base import monitor_bp from app.monitor.base import monitor_bp
from app.monitor_utils import send_version_event
from app.newsletter_utils import send_newsletter_to_user from app.newsletter_utils import send_newsletter_to_user
from app.oauth.base import oauth_bp from app.oauth.base import oauth_bp
from app.onboarding.base import onboarding_bp from app.onboarding.base import onboarding_bp
@ -107,7 +106,6 @@ from app.payments.coinbase import setup_coinbase_commerce
from app.payments.paddle import setup_paddle_callback from app.payments.paddle import setup_paddle_callback
from app.phone.base import phone_bp from app.phone.base import phone_bp
from app.redis_services import initialize_redis_services from app.redis_services import initialize_redis_services
from app.request_utils import generate_request_id
from app.sentry_utils import sentry_before_send from app.sentry_utils import sentry_before_send
if SENTRY_DSN: if SENTRY_DSN:
@ -265,7 +263,6 @@ def set_index_page(app):
and not request.path.startswith("/_debug_toolbar") and not request.path.startswith("/_debug_toolbar")
): ):
g.start_time = time.time() g.start_time = time.time()
g.request_id = generate_request_id()
# to handle the referral url that has ?slref=code part # to handle the referral url that has ?slref=code part
ref_code = request.args.get("slref") ref_code = request.args.get("slref")
@ -296,7 +293,6 @@ def set_index_page(app):
newrelic.agent.record_custom_event( newrelic.agent.record_custom_event(
"HttpResponseStatus", {"code": res.status_code} "HttpResponseStatus", {"code": res.status_code}
) )
send_version_event("app")
return res return res
@ -446,10 +442,10 @@ def init_admin(app):
admin = Admin(name="SimpleLogin", template_mode="bootstrap4") admin = Admin(name="SimpleLogin", template_mode="bootstrap4")
admin.init_app(app, index_view=SLAdminIndexView()) admin.init_app(app, index_view=SLAdminIndexView())
admin.add_view(EmailSearchAdmin(name="Email Search", endpoint="admin.email_search")) admin.add_view(EmailSearchAdmin(name="Email Search", endpoint="email_search"))
admin.add_view( admin.add_view(
CustomDomainSearchAdmin( CustomDomainSearchAdmin(
name="Custom domain search", endpoint="admin.custom_domain_search" name="Custom domain search", endpoint="custom_domain_search"
) )
) )
admin.add_view(UserAdmin(User, Session)) admin.add_view(UserAdmin(User, Session))
@ -585,8 +581,7 @@ def local_main():
# enable flask toolbar # enable flask toolbar
from flask_debugtoolbar import DebugToolbarExtension from flask_debugtoolbar import DebugToolbarExtension
# Disabled in python 3.12 as it collides with the default CPython profiler app.config["DEBUG_TB_PROFILER_ENABLED"] = True
app.config["DEBUG_TB_PROFILER_ENABLED"] = False
app.config["DEBUG_TB_INTERCEPT_REDIRECTS"] = False app.config["DEBUG_TB_INTERCEPT_REDIRECTS"] = False
app.debug = True app.debug = True
DebugToolbarExtension(app) DebugToolbarExtension(app)

View File

@ -1,181 +1,118 @@
{% extends 'admin/master.html' %} {% extends 'admin/master.html' %}
{% block head_css %}
{{ super() }}
<style>
.card-shadow {
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
border-radius: 8px;
}
.domain-title {
background-color: #007bff;
color: white;
padding: 10px;
border-radius: 8px 8px 0 0;
}
.status-icon {
font-size: 1.2em;
}
</style>
{% endblock %}
{% macro show_user(user) -%} {% macro show_user(user) -%}
<h4> <h4>User <a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a> with ID {{ user.id }}.</h4>
User <a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a> with ID {{ user.id }}. <table class="table">
</h4> <thead>
<table class="table"> <tr>
<thead> <th scope="col">User ID</th>
<tr> <th scope="col">Email</th>
<th scope="col">User ID</th> <th scope="col">Verified</th>
<th scope="col">Email</th> <th scope="col">Status</th>
<th scope="col">Verified</th> <th scope="col">Paid</th>
<th scope="col">Status</th> <th scope="col">Premium</th>
<th scope="col">Paid</th> </tr>
<th scope="col">Premium</th> </thead>
</tr> <tbody>
</thead> <tr>
<tbody> <td>{{ user.id }}</td>
<tr> <td>
<td>{{ user.id }}</td> <a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a>
<td> </td>
<a href="/admin/email_search?email={{ user.email }}">{{ user.email }}</a> {% if user.activated %}
</td>
{% if user.activated %}
<td class="text-success">Activated</td> <td class="text-success">Activated</td>
{% else %} {% else %}
<td class="text-warning">Pending</td> <td class="text-warning">Pending</td>
{% endif %} {% endif %}
{% if user.disabled %} {% if user.disabled %}
<td class="text-danger">Disabled</td> <td class="text-danger">Disabled</td>
{% else %} {% else %}
<td class="text-success">Enabled</td> <td class="text-success">Enabled</td>
{% endif %} {% endif %}
<td>{{ "yes" if user.is_paid() else "No" }}</td> <td>{{ "yes" if user.is_paid() else "No" }}</td>
<td>{{ "yes" if user.is_premium() else "No" }}</td> <td>{{ "yes" if user.is_premium() else "No" }}</td>
</tr> </tr>
</tbody> </tbody>
</table> </table>
{%- endmacro %} {%- endmacro %}
{% macro show_verification(title, expected, errors) -%} {% macro show_verification(title, expected, errors) -%}
{% if not expected %} {% if not expected %}
<h4 class="mb-3">{{ title }} <span class="text-success">Verified</span></h4>
<li class="list-group-item d-flex justify-content-between align-items-center"> {% else %}
<h5>{{ title }}</h5> <h4 class="mb-3">{{ title }}</h4>
<span class="text-success status-icon"><i class="fa fa-check-circle"></i></span> <p>Expected</p>
</li> <p>{{expected}}</p>
{% else %} <p>Current response</p>
<li class="list-group-item">
<h5>{{ title }}</h5>
<p>
<strong>Expected:</strong> {{ expected.recommended }}
</p>
<p>
<strong>Allowed:</strong>
<ul>
{% for expected_record in expected.allowed %}<li>{{ expected_record }}</li>{% endfor %}
</ul>
</p>
<p>
<strong>Current response:</strong>
</p>
{% for error in errors %}
<ul class="list-group"> <ul class="list-group">
<li class="list-group-item">{{ error }}</li> {% for error in errors %}
</ul> <li class="list-group-item">{{ error }}</li>
{% endfor %}
</li>
{% endif %}
{%- endmacro %}
{% macro show_mx_verification(title, expected, errors) -%}
{% if not expected %}
<li class="list-group-item d-flex justify-content-between align-items-center">
<h5>{{ title }}</h5>
<span class="text-success status-icon"><i class="fa fa-check-circle"></i></span>
</li>
{% else %}
<li class="list-group-item">
<h5>{{ title }}</h5>
<ul>
<li class="list-group-item">
{% for prio in expected %}
<p>
<strong>Priority {{ prio }}:</strong> {{ expected[prio].recommended }}
</p>
<p>
<strong>Allowed:</strong>
<ul>
{% for expected_record in expected[prio].allowed %}<li>{{ expected_record }}</li>{% endfor %}
</ul>
</p>
<p>
<strong>Current response:</strong>
</p>
{% for error in errors %}
<ul class="list-group">
<li class="list-group-item">{{ error }}</li>
</ul>
{% endfor %}
</li>
{% endfor %} {% endfor %}
</ul>
</li>
{% endif %}
{%- endmacro %}
{% macro show_domain(domain_with_data) -%}
<div class="col-md-3 mb-4">
<div class="card card-shadow">
<div class="domain-title text-center">
<h4>Domain {{ domain_with_data.domain.domain }}</h4>
</div>
<div class="card-body">
{% set domain = domain_with_data.domain %}
<ul class="list-group">
{{ show_verification("Ownership", domain_with_data.ownership_expected, domain_with_data.ownership_validation.errors) }}
{{ show_mx_verification("MX", domain_with_data.mx_expected, domain_with_data.mx_validation.errors) }}
{{ show_verification("SPF", domain_with_data.spf_expected, domain_with_data.spf_validation.errors) }}
{% for dkim_domain in domain_with_data.dkim_expected %}
{{ show_verification("DKIM {}.{}".format(dkim_domain, domain.domain) , domain_with_data.dkim_expected[dkim_domain], [domain_with_data.dkim_validation.get(dkim_domain+"."+domain.domain,'')]) }}
{% endfor %}
</ul> </ul>
</div> {% endif %}
</div>
</div>
{%- endmacro %} {%- endmacro %}
{% macro show_domain(domain_with_data) -%}
<h3>Domain {{ domain_with_data.domain.domain }}</h3>
{% set domain = domain_with_data.domain %}
<ul class="list-group">
<li class="list-group-item">
{{ show_verification("Ownership", domain_with_data.ownership_expected, domain_with_data.ownership_validation.errors) }}
</li>
<li class="list-group-item">
{{ show_verification("MX", domain_with_data.mx_expected, domain_with_data.mx_validation.errors) }}
</li>
<li class="list-group-item">
{{ show_verification("SPF", domain_with_data.spf_expected, domain_with_data.spf_validation.errors) }}
</li>
{% for dkim_domain in domain_with_data.dkim_expected %}
<li class="list-group-item">
{{ show_verification("DKIM {}.{}".format(dkim_domain, domain.domain), domain_with_data.dkim_expected[dkim_domain], [domain_with_data.dkim_validation.get(dkim_domain+"."+domain.domain,'')]) }}
</li>
{% endfor %}
</ul>
{%- endmacro %}
{% block body %} {% block body %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="get">
<div class="form-group">
<label for="email">User or domain to search:</label>
<input type="text"
class="form-control"
name="user"
value="{{ query or '' }}" />
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div>
{% if data.no_match and query %}
<div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
role="alert">No user, alias or mailbox found for {{ query }}</div>
{% endif %}
{% if data.user %}
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found User {{ data.user.email }}</h3> <form method="get">
{{ show_user(data.user) }} <div class="form-group">
<label for="email">User or domain to search:</label>
<input type="text"
class="form-control"
name="user"
value="{{ query or '' }}" />
</div>
<button type="submit" class="btn btn-primary">Submit</button>
</form>
</div> </div>
{% endif %} {% if data.no_match and query %}
<div class="row mt-4"> <div class="border border-dark border-2 mt-1 mb-2 p-3 alert alert-warning"
{% for domain_with_data in data.domains %}{{ show_domain(domain_with_data) }}{% endfor %} role="alert">No user, alias or mailbox found for {{ query }}</div>
</div> {% endif %}
</div> {% if data.user %}
<div class="border border-dark border-2 mt-1 mb-2 p-3">
<h3 class="mb-3">Found User {{ data.user.email }}</h3>
{{ show_user(data.user) }}
</div>
{% endif %}
<div class="d-flex">
{% for domain_with_data in data.domains %}
<div class="card m-2 border-dark" style="width: 30rem;">
<div class="card-body">
{{ show_domain(domain_with_data) }}
</div>
</div>
{% endfor %}
</div>
{% endblock %} {% endblock %}
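
For readers following the template change above: a rough Python sketch of the record shapes the main-branch macros appear to consume. The class and field names are inferred from the template's `.recommended` / `.allowed` accesses and the per-priority MX mapping; they are illustrative assumptions, not the project's actual definitions:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ExpectedRecord:
    """Illustrative only: one expected DNS record as the admin template reads it."""

    recommended: str  # the single value the UI recommends copying
    allowed: List[str] = field(default_factory=list)  # all values accepted as valid


# Ownership, SPF and DKIM entries each pass one record to show_verification(),
# while show_mx_verification() receives a mapping keyed by MX priority:
expected_mx_records: Dict[int, ExpectedRecord] = {
    10: ExpectedRecord(recommended="mx1.example.test.", allowed=["mx1.example.test."]),
    20: ExpectedRecord(recommended="mx2.example.test.", allowed=["mx2.example.test."]),
}
```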

View File

@ -22,7 +22,7 @@
<tr> <tr>
<td>{{ user.id }}</td> <td>{{ user.id }}</td>
<td> <td>
<a href="?query={{ user.email }}">{{ user.email }}</a> <a href="?email={{ user.email }}">{{ user.email }}</a>
</td> </td>
{% if user.activated %} {% if user.activated %}
@ -43,16 +43,8 @@
<td>{{ user.updated_at }}</td> <td>{{ user.updated_at }}</td>
{% if pu %} {% if pu %}
<td class="flex"> <td>
<a href="?query={{ pu.partner_email }}">{{ pu.partner_email }}</a> <a href="?email={{ pu.partner_email }}">{{ pu.partner_email }}</a>
<form class="d-inline"
action="{{ url_for("admin.email_search.delete_partner_link") }}"
method="POST">
<input type="hidden" name="user_id" value="{{ user.id }}">
<button type="submit"
onclick="return confirm('Are you sure you would like to unlink the user?');"
class="btn btn-danger d-inline">Unlink</button>
</form>
</td> </td>
{% else %} {% else %}
<td>No</td> <td>No</td>
@ -80,7 +72,7 @@
<tr> <tr>
<td>{{ mailbox.id }}</td> <td>{{ mailbox.id }}</td>
<td> <td>
<a href="?query={{ mailbox.email }}">{{ mailbox.email }}</a> <a href="?email={{ mailbox.email }}">{{ mailbox.email }}</a>
</td> </td>
<td>{{ "Yes" if mailbox.verified else "No" }}</td> <td>{{ "Yes" if mailbox.verified else "No" }}</td>
<td>{{ mailbox.created_at }}</td> <td>{{ mailbox.created_at }}</td>
@ -109,7 +101,7 @@
<tr> <tr>
<td>{{ alias.id }}</td> <td>{{ alias.id }}</td>
<td> <td>
<a href="?query={{ alias.email }}">{{ alias.email }}</a> <a href="?email={{ alias.email }}">{{ alias.email }}</a>
</td> </td>
<td>{{ "Yes" if alias.enabled else "No" }}</td> <td>{{ "Yes" if alias.enabled else "No" }}</td>
<td>{{ alias.created_at }}</td> <td>{{ alias.created_at }}</td>
@ -189,7 +181,7 @@
<td>{{ entry.user_id }}</td> <td>{{ entry.user_id }}</td>
<td>{{ entry.alias_id }}</td> <td>{{ entry.alias_id }}</td>
<td> <td>
<a href="?query={{ entry.alias_email }}">{{ entry.alias_email }}</a> <a href="?email={{ entry.alias_email }}">{{ entry.alias_email }}</a>
</td> </td>
<td>{{ entry.action }}</td> <td>{{ entry.action }}</td>
<td>{{ entry.message }}</td> <td>{{ entry.message }}</td>
@ -215,7 +207,7 @@
<tr> <tr>
<td> <td>
<a href="?query={{ entry.user_email }}">{{ entry.user_email }}</a> <a href="?email={{ entry.user_email }}">{{ entry.user_email }}</a>
</td> </td>
<td>{{ entry.action }}</td> <td>{{ entry.action }}</td>
<td>{{ entry.message }}</td> <td>{{ entry.message }}</td>
@ -230,10 +222,10 @@
<div class="border border-dark border-2 mt-1 mb-2 p-3"> <div class="border border-dark border-2 mt-1 mb-2 p-3">
<form method="get"> <form method="get">
<div class="form-group"> <div class="form-group">
<label for="email">UserID or Email to search:</label> <label for="email">Email to search:</label>
<input type="text" <input type="text"
class="form-control" class="form-control"
name="query" name="email"
value="{{ email or '' }}" /> value="{{ email or '' }}" />
</div> </div>
<button type="submit" class="btn btn-primary">Submit</button> <button type="submit" class="btn btn-primary">Submit</button>

View File

@ -38,7 +38,7 @@
Value: <em data-toggle="tooltip" Value: <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ ownership_records.recommended }}">{{ ownership_records.recommended }}</em> data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
</div> </div>
<form method="post" action="#ownership-form"> <form method="post" action="#ownership-form">
{{ csrf_form.csrf_token }} {{ csrf_form.csrf_token }}
@ -91,7 +91,7 @@
<br /> <br />
Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain. Some domain registrars (Namecheap, CloudFlare, etc) might also use <em>@</em> for the root domain.
</div> </div>
{% for prio in expected_mx_records %} {% for record in expected_mx_records %}
<div class="mb-3 p-3 dns-record"> <div class="mb-3 p-3 dns-record">
Record: MX Record: MX
@ -99,12 +99,12 @@
Domain: {{ custom_domain.domain }} or Domain: {{ custom_domain.domain }} or
<b>@</b> <b>@</b>
<br /> <br />
Priority: {{ prio }} Priority: {{ record.priority }}
<br /> <br />
Target: <em data-toggle="tooltip" Target: <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ expected_mx_records[prio].recommended }}">{{ expected_mx_records[prio].recommended }}</em> data-clipboard-text="{{ record.domain }}">{{ record.domain }}</em>
</div> </div>
{% endfor %} {% endfor %}
<form method="post" action="#mx-form"> <form method="post" action="#mx-form">
@ -251,8 +251,8 @@
<em data-toggle="tooltip" <em data-toggle="tooltip"
title="Click to copy" title="Click to copy"
class="clipboard" class="clipboard"
data-clipboard-text="{{ dkim_cname_value.recommended }}." data-clipboard-text="{{ dkim_cname_value }}."
style="overflow-wrap: break-word">{{ dkim_cname_value.recommended }}.</em> style="overflow-wrap: break-word">{{ dkim_cname_value }}.</em>
</div> </div>
{% endfor %} {% endfor %}
<div class="alert alert-info"> <div class="alert alert-info">

View File

@ -647,8 +647,8 @@ def test_get_alias(flask_client):
def test_is_reverse_alias(flask_client): def test_is_reverse_alias(flask_client):
assert is_reverse_alias("ra+abcd@sl.lan") assert is_reverse_alias("ra+abcd@sl.local")
assert is_reverse_alias("reply+abcd@sl.lan") assert is_reverse_alias("reply+abcd@sl.local")
assert not is_reverse_alias("ra+abcd@test.org") assert not is_reverse_alias("ra+abcd@test.org")
assert not is_reverse_alias("reply+abcd@test.org") assert not is_reverse_alias("reply+abcd@test.org")

View File

@ -1,3 +1,5 @@
from flask import g
from app import config from app import config
from app.alias_suffix import signer from app.alias_suffix import signer
from app.alias_utils import delete_alias from app.alias_utils import delete_alias
@ -5,7 +7,7 @@ from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN
from app.db import Session from app.db import Session
from app.models import Alias, CustomDomain, Mailbox, AliasUsedOn from app.models import Alias, CustomDomain, Mailbox, AliasUsedOn
from app.utils import random_word from app.utils import random_word
from tests.utils import fix_rate_limit_after_request, login, random_domain, random_token from tests.utils import login, random_domain, random_token
def test_v2(flask_client): def test_v2(flask_client):
@ -274,7 +276,7 @@ def test_too_many_requests(flask_client):
# to make flask-limiter work with unit test # to make flask-limiter work with unit test
# https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820 # https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820
fix_rate_limit_after_request() g._rate_limiting_complete = False
else: else:
# last request # last request
assert r.status_code == 429 assert r.status_code == 429

View File

@ -1,12 +1,12 @@
import uuid import uuid
from flask import url_for from flask import url_for, g
from app import config from app import config
from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN
from app.db import Session from app.db import Session
from app.models import Alias, CustomDomain, AliasUsedOn from app.models import Alias, CustomDomain, AliasUsedOn
from tests.utils import fix_rate_limit_after_request, login, random_domain from tests.utils import login, random_domain
def test_with_hostname(flask_client): def test_with_hostname(flask_client):
@ -17,7 +17,7 @@ def test_with_hostname(flask_client):
) )
assert r.status_code == 201 assert r.status_code == 201
assert r.json["alias"].endswith("d1.lan") assert r.json["alias"].endswith("d1.test")
# make sure alias starts with the suggested prefix # make sure alias starts with the suggested prefix
assert r.json["alias"].startswith("test") assert r.json["alias"].startswith("test")
@ -133,7 +133,7 @@ def test_too_many_requests(flask_client):
) )
# to make flask-limiter work with unit test # to make flask-limiter work with unit test
# https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820 # https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820
fix_rate_limit_after_request() g._rate_limiting_complete = False
else: else:
# last request # last request
assert r.status_code == 429 assert r.status_code == 429

View File

@ -112,14 +112,14 @@ def test_get_alias_infos_with_pagination_v3_no_duplicate_when_empty_contact(
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.lan", reply_email="rep@sl.local",
) )
Contact.create( Contact.create(
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact2@example.com", website_email="contact2@example.com",
reply_email="rep2@sl.lan", reply_email="rep2@sl.local",
) )
alias_infos = get_alias_infos_with_pagination_v3(user) alias_infos = get_alias_infos_with_pagination_v3(user)

View File

@ -15,7 +15,7 @@ def test_get_setting(flask_client):
assert r.json == { assert r.json == {
"alias_generator": "word", "alias_generator": "word",
"notification": True, "notification": True,
"random_alias_default_domain": "sl.lan", "random_alias_default_domain": "sl.local",
"sender_format": "AT", "sender_format": "AT",
"random_alias_suffix": "word", "random_alias_suffix": "word",
} }
@ -47,7 +47,7 @@ def test_update_settings_random_alias_default_domain(flask_client):
custom_domain = CustomDomain.create( custom_domain = CustomDomain.create(
domain=random_domain(), verified=True, user_id=user.id, flush=True domain=random_domain(), verified=True, user_id=user.id, flush=True
) )
assert user.default_random_alias_domain() == "sl.lan" assert user.default_random_alias_domain() == "sl.local"
r = flask_client.patch( r = flask_client.patch(
"/api/setting", json={"random_alias_default_domain": "invalid"} "/api/setting", json={"random_alias_default_domain": "invalid"}
@ -55,10 +55,10 @@ def test_update_settings_random_alias_default_domain(flask_client):
assert r.status_code == 400 assert r.status_code == 400
r = flask_client.patch( r = flask_client.patch(
"/api/setting", json={"random_alias_default_domain": "d1.lan"} "/api/setting", json={"random_alias_default_domain": "d1.test"}
) )
assert r.status_code == 200 assert r.status_code == 200
assert user.default_random_alias_domain() == "d1.lan" assert user.default_random_alias_domain() == "d1.test"
r = flask_client.patch( r = flask_client.patch(
"/api/setting", json={"random_alias_default_domain": custom_domain.domain} "/api/setting", json={"random_alias_default_domain": custom_domain.domain}

View File

@ -2,7 +2,7 @@ from random import random
from flask import url_for from flask import url_for
from app.constants import JobType from app import config
from app.db import Session from app.db import Session
from app.models import Job, ApiToCookieToken from app.models import Job, ApiToCookieToken
from tests.api.utils import get_new_user_and_api_key from tests.api.utils import get_new_user_and_api_key
@ -48,7 +48,7 @@ def test_delete_with_sudo(flask_client):
jobs = Job.all() jobs = Job.all()
assert len(jobs) == 1 assert len(jobs) == 1
job = jobs[0] job = jobs[0]
assert job.name == JobType.DELETE_ACCOUNT.value assert job.name == config.JOB_DELETE_ACCOUNT
assert job.payload == {"user_id": user.id} assert job.payload == {"user_id": user.id}

View File

@ -3,7 +3,7 @@ from flask import url_for
from app import config from app import config
from app.db import Session from app.db import Session
from app.models import User, PartnerUser from app.models import User, PartnerUser
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from tests.api.utils import get_new_user_and_api_key from tests.api.utils import get_new_user_and_api_key
from tests.utils import login, random_token, random_email from tests.utils import login, random_token, random_email

View File

@ -23,7 +23,7 @@ from init_app import add_sl_domains, add_proton_partner
app = create_app() app = create_app()
app.config["TESTING"] = True app.config["TESTING"] = True
app.config["WTF_CSRF_ENABLED"] = False app.config["WTF_CSRF_ENABLED"] = False
app.config["SERVER_NAME"] = "sl.lan" app.config["SERVER_NAME"] = "sl.test"
# enable pg_trgm extension # enable pg_trgm extension
with engine.connect() as conn: with engine.connect() as conn:

View File

@ -14,7 +14,7 @@ from app.models import (
PartnerSubscription, PartnerSubscription,
User, User,
) )
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token from tests.utils import create_new_user, random_token

View File

@ -1,6 +1,6 @@
from random import random from random import random
from flask import url_for from flask import url_for, g
from app import config from app import config
from app.alias_suffix import ( from app.alias_suffix import (
@ -22,12 +22,7 @@ from app.models import (
DailyMetric, DailyMetric,
) )
from app.utils import random_word from app.utils import random_word
from tests.utils import ( from tests.utils import login, random_domain, create_new_user
fix_rate_limit_after_request,
login,
random_domain,
create_new_user,
)
def test_add_alias_success(flask_client): def test_add_alias_success(flask_client):
@ -393,7 +388,7 @@ def test_too_many_requests(flask_client):
# to make flask-limiter work with unit test # to make flask-limiter work with unit test
# https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820 # https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820
fix_rate_limit_after_request() g._rate_limiting_complete = False
else: else:
# last request # last request
assert r.status_code == 429 assert r.status_code == 429

View File

@ -1,10 +1,10 @@
from flask import url_for from flask import url_for, g
from app import config from app import config
from app.models import ( from app.models import (
Alias, Alias,
) )
from tests.utils import fix_rate_limit_after_request, login from tests.utils import login
def test_create_random_alias_success(flask_client): def test_create_random_alias_success(flask_client):
@ -34,7 +34,7 @@ def test_too_many_requests(flask_client):
# to make flask-limiter work with unit test # to make flask-limiter work with unit test
# https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820 # https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820
fix_rate_limit_after_request() g._rate_limiting_complete = False
else: else:
# last request # last request
assert r.status_code == 429 assert r.status_code == 429

View File

@ -28,7 +28,7 @@ def test_rate_limited_forward_phase_for_alias(flask_client):
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.lan", reply_email="rep@sl.local",
) )
Session.commit() Session.commit()
for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1): for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1):
@ -52,7 +52,7 @@ def test_rate_limited_forward_phase_for_mailbox(flask_client):
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.lan", reply_email="rep@sl.local",
) )
Session.commit() Session.commit()
for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX + 1): for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX + 1):
@ -90,7 +90,7 @@ def test_rate_limited_reply_phase(flask_client):
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
reply_email = f"reply-{random.random()}@sl.lan" reply_email = f"reply-{random.random()}@sl.local"
contact = Contact.create( contact = Contact.create(
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,

View File

@ -1,7 +1,7 @@
from app.events.event_dispatcher import Dispatcher from app.events.event_dispatcher import Dispatcher
from app.events.generated import event_pb2 from app.events.generated import event_pb2
from app.models import PartnerUser, User from app.models import PartnerUser, User
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token from tests.utils import create_new_user, random_token
from typing import Tuple from typing import Tuple

View File

@ -37,7 +37,7 @@ def prepare_complaint(
contact = Contact.create( contact = Contact.create(
user_id=alias.user.id, user_id=alias.user.id,
alias_id=alias.id, alias_id=alias.id,
website_email=f"contact{random.random()}@mailbox.lan", website_email=f"contact{random.random()}@mailbox.test",
reply_email="d@e.f", reply_email="d@e.f",
commit=True, commit=True,
) )

View File

@ -27,7 +27,7 @@ def generate_unsub_block_contact_data() -> Iterable:
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.lan", reply_email="rep@sl.local",
commit=True, commit=True,
) )
@ -86,7 +86,7 @@ def generate_unsub_disable_alias_data() -> Iterable:
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.lan", reply_email="rep@sl.local",
commit=True, commit=True,
) )
@ -145,7 +145,7 @@ def generate_unsub_preserve_original_data() -> Iterable:
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.lan", reply_email="rep@sl.local",
commit=True, commit=True,
) )
@ -215,7 +215,7 @@ def test_unsub_preserves_sl_unsubscriber():
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.lan", reply_email="rep@sl.local",
commit=True, commit=True,
) )
message = Message() message = Message()

View File

@ -49,7 +49,7 @@ def test_old_subject_block_contact():
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email=f"{random()}@sl.lan", reply_email=f"{random()}@sl.local",
block_forward=False, block_forward=False,
commit=True, commit=True,
) )
@ -92,7 +92,7 @@ def test_new_subject_block_contact():
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email=f"{random()}@sl.lan", reply_email=f"{random()}@sl.local",
block_forward=False, block_forward=False,
commit=True, commit=True,
) )
@ -172,7 +172,7 @@ def test_request_disable_contact(flask_client):
user_id=user.id, user_id=user.id,
alias_id=alias.id, alias_id=alias.id,
website_email="contact@example.com", website_email="contact@example.com",
reply_email=f"{random()}@sl.lan", reply_email=f"{random()}@sl.local",
block_forward=False, block_forward=False,
commit=True, commit=True,
) )

View File

@ -1,6 +1,6 @@
from sqlalchemy_utils.types.arrow import arrow from sqlalchemy_utils.types.arrow import arrow
from app.constants import JobType from app.config import JOB_DELETE_MAILBOX
from app.db import Session from app.db import Session
from app.mail_sender import mail_sender from app.mail_sender import mail_sender
from app.models import Alias, Mailbox, Job, AliasMailbox from app.models import Alias, Mailbox, Job, AliasMailbox
@ -21,7 +21,7 @@ def test_delete_mailbox_transfer_mailbox_primary(flask_client):
alias_id = Alias.create_new(user, "prefix", mailbox_id=m1.id).id alias_id = Alias.create_new(user, "prefix", mailbox_id=m1.id).id
AliasMailbox.create(alias_id=alias_id, mailbox_id=m2.id) AliasMailbox.create(alias_id=alias_id, mailbox_id=m2.id)
job = Job.create( job = Job.create(
name=JobType.DELETE_MAILBOX.value, name=JOB_DELETE_MAILBOX,
payload={"mailbox_id": m1.id, "transfer_mailbox_id": m2.id}, payload={"mailbox_id": m1.id, "transfer_mailbox_id": m2.id},
run_at=arrow.now(), run_at=arrow.now(),
commit=True, commit=True,
@ -43,7 +43,7 @@ def test_delete_mailbox_no_email(flask_client):
user_id=user.id, email=random_email(), verified=True, flush=True user_id=user.id, email=random_email(), verified=True, flush=True
) )
job = Job.create( job = Job.create(
name=JobType.DELETE_MAILBOX.value, name=JOB_DELETE_MAILBOX,
payload={"mailbox_id": m1.id, "transfer_mailbox_id": None, "send_mail": False}, payload={"mailbox_id": m1.id, "transfer_mailbox_id": None, "send_mail": False},
run_at=arrow.now(), run_at=arrow.now(),
commit=True, commit=True,
@ -70,7 +70,7 @@ def test_delete_mailbox_transfer_mailbox_in_list(flask_client):
alias_id = Alias.create_new(user, "prefix", mailbox_id=m1.id).id alias_id = Alias.create_new(user, "prefix", mailbox_id=m1.id).id
AliasMailbox.create(alias_id=alias_id, mailbox_id=m2.id) AliasMailbox.create(alias_id=alias_id, mailbox_id=m2.id)
job = Job.create( job = Job.create(
name=JobType.DELETE_MAILBOX.value, name=JOB_DELETE_MAILBOX,
payload={"mailbox_id": m2.id, "transfer_mailbox_id": m3.id}, payload={"mailbox_id": m2.id, "transfer_mailbox_id": m3.id},
run_at=arrow.now(), run_at=arrow.now(),
commit=True, commit=True,
@ -95,7 +95,7 @@ def test_delete_mailbox_no_transfer(flask_client):
alias_id = Alias.create_new(user, "prefix", mailbox_id=m1.id).id alias_id = Alias.create_new(user, "prefix", mailbox_id=m1.id).id
job = Job.create( job = Job.create(
name=JobType.DELETE_MAILBOX.value, name=JOB_DELETE_MAILBOX,
payload={"mailbox_id": m1.id}, payload={"mailbox_id": m1.id},
run_at=arrow.now(), run_at=arrow.now(),
commit=True, commit=True,

View File

@ -1,7 +1,7 @@
from app import config from app import config
from app.db import Session from app.db import Session
from job_runner import get_jobs_to_run from job_runner import get_jobs_to_run
from app.models import Job, JobPriority, JobState from app.models import Job, JobState
import arrow import arrow
@ -56,65 +56,18 @@ def test_get_jobs_to_run(flask_client):
run_at=now.shift(hours=3), run_at=now.shift(hours=3),
) )
# Job out of attempts # Job out of attempts
Job.create( (
name="", Job.create(
payload="", name="",
state=JobState.taken.value, payload="",
taken_at=now.shift(minutes=-(config.JOB_TAKEN_RETRY_WAIT_MINS + 10)), state=JobState.taken.value,
attempts=config.JOB_MAX_ATTEMPTS + 1, taken_at=now.shift(minutes=-(config.JOB_TAKEN_RETRY_WAIT_MINS + 10)),
attempts=config.JOB_MAX_ATTEMPTS + 1,
),
) )
# Job marked as error
Job.create(
name="",
payload="",
state=JobState.error.value,
taken_at=now.shift(minutes=-(config.JOB_TAKEN_RETRY_WAIT_MINS + 10)),
attempts=config.JOB_MAX_ATTEMPTS + 1,
)
Session.commit() Session.commit()
taken_before_time = arrow.now().shift(minutes=-config.JOB_TAKEN_RETRY_WAIT_MINS) jobs = get_jobs_to_run()
jobs = get_jobs_to_run(taken_before_time)
assert len(jobs) == len(expected_jobs_to_run) assert len(jobs) == len(expected_jobs_to_run)
job_ids = [job.id for job in jobs] job_ids = [job.id for job in jobs]
for job in expected_jobs_to_run: for job in expected_jobs_to_run:
assert job.id in job_ids assert job.id in job_ids
def test_get_jobs_to_run_respects_priority(flask_client):
now = arrow.now()
for job in Job.all():
Job.delete(job.id)
j1 = Job.create(
name="", payload="", run_at=now.shift(minutes=-1), priority=JobPriority.High
)
j2 = Job.create(
name="", payload="", run_at=now.shift(minutes=-2), priority=JobPriority.Default
)
j3 = Job.create(
name="", payload="", run_at=now.shift(minutes=-3), priority=JobPriority.Default
)
j4 = Job.create(
name="", payload="", run_at=now.shift(minutes=-4), priority=JobPriority.Low
)
j5 = Job.create(
name="", payload="", run_at=now.shift(minutes=-2), priority=JobPriority.High
)
Session.commit()
taken_before_time = arrow.now().shift(minutes=-config.JOB_TAKEN_RETRY_WAIT_MINS)
jobs = get_jobs_to_run(taken_before_time)
assert len(jobs) == 5
job_ids = [job.id for job in jobs]
# The expected outcome is:
# 1. j5 -> 2 mins ago and High
# 2. j1 -> 1 min ago and High
# --- The 2 above are high, so they should be the first ones. j5 is first as it's been pending for a longer time
# 3. j3 -> 3 mins ago and Default
# 4. j2 -> 2 mins ago and Default
# --- The 2 above are both default, and again, are sorted by run_at ascendingly
# 5. j4 -> 3 mins ago and Low. Even if it is the one that has been waiting the most, as it's Low, it's the last one
assert job_ids == [j5.id, j1.id, j3.id, j2.id, j4.id]

View File

@ -1,10 +1,10 @@
import arrow import arrow
from app.constants import JobType from app import config
from app.events.generated.event_pb2 import EventContent, AliasDeleted from app.events.generated.event_pb2 import EventContent, AliasDeleted
from app.jobs.send_event_job import SendEventToWebhookJob from app.jobs.send_event_job import SendEventToWebhookJob
from app.models import PartnerUser from app.models import PartnerUser
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from events.event_sink import ConsoleEventSink from events.event_sink import ConsoleEventSink
from tests.utils import create_new_user, random_token from tests.utils import create_new_user, random_token
@ -17,7 +17,7 @@ def test_serialize_and_deserialize_job():
run_at = arrow.now().shift(hours=10) run_at = arrow.now().shift(hours=10)
db_job = SendEventToWebhookJob(user, event).store_job_in_db(run_at=run_at) db_job = SendEventToWebhookJob(user, event).store_job_in_db(run_at=run_at)
assert db_job.run_at == run_at assert db_job.run_at == run_at
assert db_job.name == JobType.SEND_EVENT_TO_WEBHOOK.value assert db_job.name == config.JOB_SEND_EVENT_TO_WEBHOOK
job = SendEventToWebhookJob.create_from_job(db_job) job = SendEventToWebhookJob.create_from_job(db_job)
assert job._user.id == user.id assert job._user.id == user.id
assert job._event.alias_deleted.id == alias_id assert job._event.alias_deleted.id == alias_id

View File

@ -1,37 +0,0 @@
from app import config
from app.dns_utils import set_global_dns_client, InMemoryDNSClient
from app.email_utils import get_email_local_part
from app.models import Mailbox
from tests.utils import create_new_user, random_email
dns_client = InMemoryDNSClient()
def setup_module():
set_global_dns_client(dns_client)
def teardown_module():
set_global_dns_client(None)
def test_is_proton_with_email_domain():
user = create_new_user()
mailbox = Mailbox.create(
user_id=user.id, email=f"test@{config.PROTON_EMAIL_DOMAINS[0]}"
)
assert mailbox.is_proton()
mailbox = Mailbox.create(user_id=user.id, email="a@b.c")
assert not mailbox.is_proton()
def test_is_proton_with_mx_domain():
email = random_email()
dns_client.set_mx_records(
get_email_local_part(email), {10: config.PROTON_MX_SERVERS}
)
user = create_new_user()
mailbox = Mailbox.create(user_id=user.id, email=email)
assert mailbox.is_proton()
dns_client.set_mx_records(get_email_local_part(email), {10: ["nowhere.net"]})
assert not mailbox.is_proton()

View File

@ -1,9 +1,8 @@
import arrow import arrow
from app import config
from app.constants import JobType
from app.db import Session from app.db import Session
from app.models import User, Job, PartnerSubscription, PartnerUser, ManualSubscription from app.models import User, Job, PartnerSubscription, PartnerUser, ManualSubscription
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from tests.utils import random_email, random_token from tests.utils import random_email, random_token
@ -17,7 +16,7 @@ def test_create_from_partner(flask_client):
assert user.newsletter_alias_id is None assert user.newsletter_alias_id is None
job = Session.query(Job).order_by(Job.id.desc()).first() job = Session.query(Job).order_by(Job.id.desc()).first()
assert job is not None assert job is not None
assert job.name == JobType.SEND_PROTON_WELCOME_1.value assert job.name == config.JOB_SEND_PROTON_WELCOME_1
assert job.payload.get("user_id") == user.id assert job.payload.get("user_id") == user.id

View File

@ -140,13 +140,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:12:00Z", time="2022-01-21T13:12:00Z",
value=3.275132296130991, value=3.275132296130991,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:12:00Z", time="2022-01-21T13:12:00Z",
value=4.196249043309251, value=4.196249043309251,
), ),
@ -157,13 +157,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=5.654416415900109, value=5.654416415900109,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=5.58959125727556, value=5.58959125727556,
), ),
@ -174,13 +174,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=0, value=0,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=0, value=0,
), ),
@ -191,13 +191,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=4, value=4,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=3, value=3,
), ),
@ -208,13 +208,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=0.14, value=0.14,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=0.09, value=0.09,
), ),
@ -225,13 +225,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=11.488581675749048, value=11.488581675749048,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=12.272260458006759, value=12.272260458006759,
), ),
@ -242,13 +242,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=466, value=466,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=458, value=458,
), ),
@ -259,13 +259,13 @@ def test_get_metrics():
records=[ records=[
UpcloudRecord( UpcloudRecord(
db_role="master", db_role="master",
label="test-1 (master)", label="test-1 " "(master)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=694, value=694,
), ),
UpcloudRecord( UpcloudRecord(
db_role="standby", db_role="standby",
label="test-2 (standby)", label="test-2 " "(standby)",
time="2022-01-21T13:11:30Z", time="2022-01-21T13:11:30Z",
value=573, value=573,
), ),

View File

@ -7,7 +7,7 @@ from app.account_linking import (
) )
from app.db import Session from app.db import Session
from app.models import User, PartnerUser, PartnerSubscription from app.models import User, PartnerUser, PartnerSubscription
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import random_string from app.utils import random_string
from tests.utils import random_email from tests.utils import random_email

View File

@ -1,17 +1,17 @@
from arrow import Arrow from arrow import Arrow
from app import config
from app.account_linking import ( from app.account_linking import (
SLPlan, SLPlan,
SLPlanType, SLPlanType,
) )
from app.constants import JobType
from app.proton.proton_client import ProtonClient, UserInformation from app.proton.proton_client import ProtonClient, UserInformation
from app.proton.proton_callback_handler import ( from app.proton.proton_callback_handler import (
ProtonCallbackHandler, ProtonCallbackHandler,
generate_account_not_allowed_to_log_in, generate_account_not_allowed_to_log_in,
) )
from app.models import User, PartnerUser, Job, JobState from app.models import User, PartnerUser, Job, JobState
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import random_string from app.utils import random_string
from typing import Optional from typing import Optional
from tests.utils import random_email from tests.utils import random_email
@ -28,7 +28,7 @@ class MockProtonClient(ProtonClient):
def check_initial_sync_job(user: User, expected: bool): def check_initial_sync_job(user: User, expected: bool):
found = False found = False
for job in Job.yield_per_query(10).filter_by( for job in Job.yield_per_query(10).filter_by(
name=JobType.SEND_ALIAS_CREATION_EVENTS.value, name=config.JOB_SEND_ALIAS_CREATION_EVENTS,
state=JobState.ready.value, state=JobState.ready.value,
): ):
if job.payload.get("user_id") == user.id: if job.payload.get("user_id") == user.id:

View File

@ -5,9 +5,9 @@ LOCAL_FILE_UPLOAD=1
# Email related settings # Email related settings
# Only print email content, not sending it # Only print email content, not sending it
NOT_SEND_EMAIL=true NOT_SEND_EMAIL=true
EMAIL_DOMAIN=sl.lan EMAIL_DOMAIN=sl.local
OTHER_ALIAS_DOMAINS=["d1.lan", "d2.lan", "sl.lan"] OTHER_ALIAS_DOMAINS=["d1.test", "d2.test", "sl.local"]
SUPPORT_EMAIL=support@sl.lan SUPPORT_EMAIL=support@sl.local
ADMIN_EMAIL=to_fill ADMIN_EMAIL=to_fill
# Max number emails user can generate for free plan # Max number emails user can generate for free plan
MAX_NB_EMAIL_FREE_PLAN=3 MAX_NB_EMAIL_FREE_PLAN=3

View File

@ -19,7 +19,7 @@ from app.account_linking import (
from app.db import Session from app.db import Session
from app.errors import AccountAlreadyLinkedToAnotherPartnerException from app.errors import AccountAlreadyLinkedToAnotherPartnerException
from app.models import Partner, PartnerUser, User, UserAuditLog from app.models import Partner, PartnerUser, User, UserAuditLog
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from app.user_audit_log_utils import UserAuditLogAction from app.user_audit_log_utils import UserAuditLogAction
from app.utils import random_string, canonicalize_email from app.utils import random_string, canonicalize_email
from tests.utils import random_email from tests.utils import random_email

View File

@ -59,17 +59,11 @@ def test_set_mailboxes_for_alias_mailbox_success():
assert db_alias is not None assert db_alias is not None
assert db_alias.mailbox_id == mb1.id assert db_alias.mailbox_id == mb1.id
alias_mailboxes = ( alias_mailboxes = AliasMailbox.filter_by(alias_id=alias.id).all()
AliasMailbox.filter_by(alias_id=alias.id).order_by(AliasMailbox.id.asc()).all()
)
assert len(alias_mailboxes) == 1 assert len(alias_mailboxes) == 1
assert alias_mailboxes[0].mailbox_id == mb2.id assert alias_mailboxes[0].mailbox_id == mb2.id
audit_logs = ( audit_logs = AliasAuditLog.filter_by(alias_id=alias.id).all()
AliasAuditLog.filter_by(alias_id=alias.id)
.order_by(AliasAuditLog.id.asc())
.all()
)
assert len(audit_logs) == 2 assert len(audit_logs) == 2
assert audit_logs[0].action == AliasAuditLogAction.CreateAlias.value assert audit_logs[0].action == AliasAuditLogAction.CreateAlias.value
assert audit_logs[1].action == AliasAuditLogAction.ChangedMailboxes.value assert audit_logs[1].action == AliasAuditLogAction.ChangedMailboxes.value

View File

@ -3,7 +3,7 @@ import re
from app.alias_suffix import get_alias_suffixes from app.alias_suffix import get_alias_suffixes
from app.db import Session from app.db import Session
from app.models import SLDomain, PartnerUser, AliasOptions, CustomDomain from app.models import SLDomain, PartnerUser, AliasOptions, CustomDomain
from app.proton.proton_partner import get_proton_partner from app.proton.utils import get_proton_partner
from init_app import add_sl_domains from init_app import add_sl_domains
from tests.utils import create_new_user, random_token from tests.utils import create_new_user, random_token

View File

@ -6,7 +6,6 @@ from app.coupon_utils import (
CouponUserCannotRedeemError, CouponUserCannotRedeemError,
redeem_lifetime_coupon, redeem_lifetime_coupon,
) )
from app.db import Session
from app.models import ( from app.models import (
Coupon, Coupon,
Subscription, Subscription,
@ -15,11 +14,8 @@ from app.models import (
CoinbaseSubscription, CoinbaseSubscription,
LifetimeCoupon, LifetimeCoupon,
User, User,
PartnerSubscription,
PartnerUser,
) )
from app.proton.proton_partner import get_proton_partner from tests.utils import create_new_user, random_string
from tests.utils import create_new_user, random_string, random_email
def test_use_coupon(): def test_use_coupon():
@ -66,31 +62,6 @@ def test_use_coupon_extend_manual_sub():
assert left.days > 364 assert left.days > 364
def test_use_coupon_extend_expired_manual_sub():
user = create_new_user()
initial_end = arrow.now().shift(days=-15)
ManualSubscription.create(
user_id=user.id,
end_at=initial_end,
flush=True,
)
code = random_string(10)
Coupon.create(code=code, nb_year=1, commit=True)
coupon = redeem_coupon(code, user)
assert coupon
coupon = Coupon.get_by(code=code)
assert coupon
assert coupon.used
assert coupon.used_by_user_id == user.id
sub = user.get_active_subscription()
assert isinstance(sub, ManualSubscription)
left = sub.end_at - initial_end
assert left.days > 364
def test_coupon_with_subscription(): def test_coupon_with_subscription():
user = create_new_user() user = create_new_user()
end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0) end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)
@ -186,55 +157,3 @@ def test_used_lifetime_coupon():
user = User.get(user.id) user = User.get(user.id)
assert not user.lifetime assert not user.lifetime
assert not user.paid_lifetime assert not user.paid_lifetime
def test_used_lifetime_coupon_with_lifetime_user():
user = create_new_user()
user.lifetime = True
code = random_string(10)
LifetimeCoupon.create(code=code, nb_used=10, paid=True)
coupon = redeem_lifetime_coupon(code, user)
assert coupon is None
def test_used_lifetime_coupon_with_lifetime_partner():
email = random_email()
user = User.create(email=email)
pu = PartnerUser.create(
user_id=user.id,
partner_id=get_proton_partner().id,
partner_email=email,
external_user_id=random_string(10),
flush=True,
)
PartnerSubscription.create(
partner_user_id=pu.id, end_at=arrow.utcnow().shift(years=10), lifetime=True
)
Session.flush()
code = random_string(10)
LifetimeCoupon.create(code=code, nb_used=10, paid=True)
coupon = redeem_lifetime_coupon(code, user)
assert coupon is None
def test_used_lifetime_coupon_with_partner_sub():
email = random_email()
user = User.create(email=email)
pu = PartnerUser.create(
user_id=user.id,
partner_id=get_proton_partner().id,
partner_email=email,
external_user_id=random_string(10),
flush=True,
)
PartnerSubscription.create(
partner_user_id=pu.id, end_at=arrow.utcnow().shift(years=10)
)
Session.flush()
code = random_string(10)
LifetimeCoupon.create(code=code, nb_used=10, paid=True)
coupon = redeem_lifetime_coupon(code, user)
assert coupon
user = User.get(user.id)
assert user.lifetime
assert user.paid_lifetime

View File

@ -13,8 +13,8 @@ from app.custom_domain_utils import (
) )
from app.db import Session from app.db import Session
from app.models import User, CustomDomain, Mailbox, DomainMailbox from app.models import User, CustomDomain, Mailbox, DomainMailbox
from tests.utils import create_new_user, random_string, random_domain
from tests.utils import get_proton_partner, random_email from tests.utils import get_proton_partner, random_email
from tests.utils import create_new_user, random_string, random_domain
user: Optional[User] = None user: Optional[User] = None
@ -195,11 +195,7 @@ def test_set_custom_domain_mailboxes_success():
assert res.success is True assert res.success is True
assert res.reason is None assert res.reason is None
domain_mailboxes = ( domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
DomainMailbox.filter_by(domain_id=domain.id)
.order_by(DomainMailbox.mailbox_id.asc())
.all()
)
assert len(domain_mailboxes) == 2 assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id
@ -223,11 +219,7 @@ def test_set_custom_domain_mailboxes_set_twice():
assert res.success is True assert res.success is True
assert res.reason is None assert res.reason is None
domain_mailboxes = ( domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all()
DomainMailbox.filter_by(domain_id=domain.id)
.order_by(DomainMailbox.mailbox_id.asc())
.all()
)
assert len(domain_mailboxes) == 2 assert len(domain_mailboxes) == 2
assert domain_mailboxes[0].domain_id == domain.id assert domain_mailboxes[0].domain_id == domain.id
assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id

Some files were not shown because too many files have changed in this diff.