4.47.2
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m33s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m39s
Build-Release-Image / Merge-Images (push) Successful in 12s
Build-Release-Image / Create-Release (push) Successful in 8s
Build-Release-Image / Notify (push) Successful in 5s

This commit is contained in:
MrMeeb 2024-07-30 12:00:06 +01:00
parent e7f0f81d85
commit dae6f64482
27 changed files with 348 additions and 84 deletions

1
app/.gitignore vendored
View File

@ -11,7 +11,6 @@ db.sqlite-journal
static/upload static/upload
venv/ venv/
.venv .venv
.python-version
.coverage .coverage
htmlcov htmlcov
adhoc adhoc

View File

@ -20,15 +20,15 @@ SimpleLogin backend consists of 2 main components:
## Install dependencies ## Install dependencies
The project requires: The project requires:
- Python 3.7+ and [poetry](https://python-poetry.org/) to manage dependencies - Python 3.10 and [rye](https://github.com/astral-sh/rye) to manage dependencies
- Node v10 for front-end. - Node v10 for front-end.
- Postgres 12+ - Postgres 13+
First, install all dependencies by running the following command. First, install all dependencies by running the following command.
Feel free to use `virtualenv` or similar tools to isolate development environment. Feel free to use `virtualenv` or similar tools to isolate development environment.
```bash ```bash
poetry install rye sync
``` ```
On Mac, sometimes you might need to install some other packages via `brew`: On Mac, sometimes you might need to install some other packages via `brew`:
@ -55,7 +55,7 @@ brew install -s re2 pybind11
We use pre-commit to run all our linting and static analysis checks. Please run We use pre-commit to run all our linting and static analysis checks. Please run
```bash ```bash
poetry run pre-commit install rye run pre-commit install
``` ```
To install it in your development environment. To install it in your development environment.
@ -160,25 +160,25 @@ Here are the small sum-ups of the directory structures and their roles:
The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run
``` ```
poetry run ruff format . rye run ruff format .
``` ```
The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by
```bash ```bash
poetry run flake8 rye run flake8
``` ```
For HTML templates, we use `djlint`. Before creating a pull request, please run For HTML templates, we use `djlint`. Before creating a pull request, please run
```bash ```bash
poetry run djlint --check templates rye run djlint --check templates
``` ```
If some files aren't properly formatted, you can format all files with If some files aren't properly formatted, you can format all files with
```bash ```bash
poetry run djlint --reformat . rye run djlint --reformat .
``` ```
## Test sending email ## Test sending email

View File

@ -541,7 +541,7 @@ exit
Once you've created all your desired login accounts, add these lines to `/simplelogin.env` to disable further registrations: Once you've created all your desired login accounts, add these lines to `/simplelogin.env` to disable further registrations:
``` ```.env
DISABLE_REGISTRATION=1 DISABLE_REGISTRATION=1
DISABLE_ONBOARDING=true DISABLE_ONBOARDING=true
``` ```

View File

@ -2,6 +2,7 @@ from typing import Optional
import arrow import arrow
import sqlalchemy import sqlalchemy
from flask_admin.form import SecureForm
from flask_admin.model.template import EndpointLinkRowAction from flask_admin.model.template import EndpointLinkRowAction
from markupsafe import Markup from markupsafe import Markup
@ -100,6 +101,7 @@ def _user_upgrade_channel_formatter(view, context, model, name):
class UserAdmin(SLModelView): class UserAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["email", "id"] column_searchable_list = ["email", "id"]
column_exclude_list = [ column_exclude_list = [
"salt", "salt",
@ -344,6 +346,7 @@ def manual_upgrade(way: str, ids: [int], is_giveaway: bool):
class EmailLogAdmin(SLModelView): class EmailLogAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id"] column_searchable_list = ["id"]
column_filters = ["id", "user.email", "mailbox.email", "contact.website_email"] column_filters = ["id", "user.email", "mailbox.email", "contact.website_email"]
@ -352,6 +355,7 @@ class EmailLogAdmin(SLModelView):
class AliasAdmin(SLModelView): class AliasAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.email", "email", "mailbox.email"] column_searchable_list = ["id", "user.email", "email", "mailbox.email"]
column_filters = ["id", "user.email", "email", "mailbox.email"] column_filters = ["id", "user.email", "email", "mailbox.email"]
@ -377,6 +381,7 @@ class AliasAdmin(SLModelView):
class MailboxAdmin(SLModelView): class MailboxAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.email", "email"] column_searchable_list = ["id", "user.email", "email"]
column_filters = ["id", "user.email", "email"] column_filters = ["id", "user.email", "email"]
@ -387,11 +392,13 @@ class MailboxAdmin(SLModelView):
class CouponAdmin(SLModelView): class CouponAdmin(SLModelView):
form_base_class = SecureForm
can_edit = False can_edit = False
can_create = True can_create = True
class ManualSubscriptionAdmin(SLModelView): class ManualSubscriptionAdmin(SLModelView):
form_base_class = SecureForm
can_edit = True can_edit = True
column_searchable_list = ["id", "user.email"] column_searchable_list = ["id", "user.email"]
@ -433,12 +440,14 @@ class ManualSubscriptionAdmin(SLModelView):
class CustomDomainAdmin(SLModelView): class CustomDomainAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["domain", "user.email", "user.id"] column_searchable_list = ["domain", "user.email", "user.id"]
column_exclude_list = ["ownership_txt_token"] column_exclude_list = ["ownership_txt_token"]
can_edit = False can_edit = False
class ReferralAdmin(SLModelView): class ReferralAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.email", "code", "name"] column_searchable_list = ["id", "user.email", "code", "name"]
column_filters = ["id", "user.email", "code", "name"] column_filters = ["id", "user.email", "code", "name"]
@ -467,6 +476,7 @@ def _admin_created_at_formatter(view, context, model, name):
class AdminAuditLogAdmin(SLModelView): class AdminAuditLogAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["admin.id", "admin.email", "model_id", "created_at"] column_searchable_list = ["admin.id", "admin.email", "model_id", "created_at"]
column_filters = ["admin.id", "admin.email", "model_id", "created_at"] column_filters = ["admin.id", "admin.email", "model_id", "created_at"]
column_exclude_list = ["id"] column_exclude_list = ["id"]
@ -497,6 +507,7 @@ def _transactionalcomplaint_refused_email_id_formatter(view, context, model, nam
class ProviderComplaintAdmin(SLModelView): class ProviderComplaintAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.id", "created_at"] column_searchable_list = ["id", "user.id", "created_at"]
column_filters = ["user.id", "state"] column_filters = ["user.id", "state"]
column_hide_backrefs = False column_hide_backrefs = False
@ -567,6 +578,7 @@ def _newsletter_html_formatter(view, context, model: Newsletter, name):
class NewsletterAdmin(SLModelView): class NewsletterAdmin(SLModelView):
form_base_class = SecureForm
list_template = "admin/model/newsletter-list.html" list_template = "admin/model/newsletter-list.html"
edit_template = "admin/model/newsletter-edit.html" edit_template = "admin/model/newsletter-edit.html"
edit_modal = False edit_modal = False
@ -648,6 +660,7 @@ class NewsletterAdmin(SLModelView):
class NewsletterUserAdmin(SLModelView): class NewsletterUserAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id"] column_searchable_list = ["id"]
column_filters = ["id", "user.email", "newsletter.subject"] column_filters = ["id", "user.email", "newsletter.subject"]
column_exclude_list = ["created_at", "updated_at", "id"] column_exclude_list = ["created_at", "updated_at", "id"]
@ -657,17 +670,20 @@ class NewsletterUserAdmin(SLModelView):
class DailyMetricAdmin(SLModelView): class DailyMetricAdmin(SLModelView):
form_base_class = SecureForm
column_exclude_list = ["created_at", "updated_at", "id"] column_exclude_list = ["created_at", "updated_at", "id"]
can_export = True can_export = True
class MetricAdmin(SLModelView): class MetricAdmin(SLModelView):
form_base_class = SecureForm
column_exclude_list = ["created_at", "updated_at", "id"] column_exclude_list = ["created_at", "updated_at", "id"]
can_export = True can_export = True
class InvalidMailboxDomainAdmin(SLModelView): class InvalidMailboxDomainAdmin(SLModelView):
form_base_class = SecureForm
can_create = True can_create = True
can_delete = True can_delete = True

View File

@ -64,8 +64,12 @@ def verify_prefix_suffix(
# SimpleLogin domain case: # SimpleLogin domain case:
# 1) alias_suffix must start with "." and # 1) alias_suffix must start with "." and
# 2) alias_domain_prefix must come from the word list # 2) alias_domain_prefix must come from the word list
available_sl_domains = [
sl_domain.domain
for sl_domain in user.get_sl_domains(alias_options=alias_options)
]
if ( if (
alias_domain in user.available_sl_domains(alias_options=alias_options) alias_domain in available_sl_domains
and alias_domain not in user_custom_domains and alias_domain not in user_custom_domains
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty # when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
and not config.DISABLE_ALIAS_SUFFIX and not config.DISABLE_ALIAS_SUFFIX
@ -80,9 +84,7 @@ def verify_prefix_suffix(
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False return False
if alias_domain not in user.available_sl_domains( if alias_domain not in available_sl_domains:
alias_options=alias_options
):
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False return False

View File

@ -10,6 +10,7 @@ from app.api.base import api_bp, require_api_auth
from app.config import SESSION_COOKIE_NAME from app.config import SESSION_COOKIE_NAME
from app.dashboard.views.index import get_stats from app.dashboard.views.index import get_stats
from app.db import Session from app.db import Session
from app.image_validation import detect_image_format, ImageFormat
from app.models import ApiKey, File, PartnerUser, User from app.models import ApiKey, File, PartnerUser, User
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.session import logout_session from app.session import logout_session
@ -78,17 +79,18 @@ def update_user_info():
data = request.get_json() or {} data = request.get_json() or {}
if "profile_picture" in data: if "profile_picture" in data:
if data["profile_picture"] is None: if user.profile_picture_id:
if user.profile_picture_id: file = user.profile_picture
file = user.profile_picture user.profile_picture_id = None
user.profile_picture_id = None Session.flush()
if file:
File.delete(file.id)
s3.delete(file.path)
Session.flush() Session.flush()
if file:
File.delete(file.id)
s3.delete(file.path)
Session.flush()
else: else:
raw_data = base64.decodebytes(data["profile_picture"].encode()) raw_data = base64.decodebytes(data["profile_picture"].encode())
if detect_image_format(raw_data) == ImageFormat.Unknown:
return jsonify(error="Unsupported image format"), 400
file_path = random_string(30) file_path = random_string(30)
file = File.create(user_id=user.id, path=file_path) file = File.create(user_id=user.id, path=file_path)
Session.flush() Session.flush()

View File

@ -115,7 +115,8 @@ def register():
def send_activation_email(user, next_url): def send_activation_email(user, next_url):
# the activation code is valid for 1h # the activation code is valid for 1h and delete all previous codes
Session.query(ActivationCode).filter(ActivationCode.user_id == user.id).delete()
activation = ActivationCode.create(user_id=user.id, code=random_string(30)) activation = ActivationCode.create(user_id=user.id, code=random_string(30))
Session.commit() Session.commit()

View File

@ -14,7 +14,7 @@ from flask_wtf import FlaskForm
from flask_wtf.file import FileField from flask_wtf.file import FileField
from wtforms import StringField, validators from wtforms import StringField, validators
from app import s3 from app import s3, user_settings
from app.config import ( from app.config import (
FIRST_ALIAS_DOMAIN, FIRST_ALIAS_DOMAIN,
ALIAS_RANDOM_SUFFIX_LENGTH, ALIAS_RANDOM_SUFFIX_LENGTH,
@ -31,12 +31,10 @@ from app.models import (
PlanEnum, PlanEnum,
File, File,
EmailChange, EmailChange,
CustomDomain,
AliasGeneratorEnum, AliasGeneratorEnum,
AliasSuffixEnum, AliasSuffixEnum,
ManualSubscription, ManualSubscription,
SenderFormatEnum, SenderFormatEnum,
SLDomain,
CoinbaseSubscription, CoinbaseSubscription,
AppleSubscription, AppleSubscription,
PartnerUser, PartnerUser,
@ -166,38 +164,11 @@ def setting():
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "change-random-alias-default-domain": elif request.form.get("form-name") == "change-random-alias-default-domain":
default_domain = request.form.get("random-alias-default-domain") default_domain = request.form.get("random-alias-default-domain")
try:
if default_domain: user_settings.set_default_alias_domain(current_user, default_domain)
sl_domain: SLDomain = SLDomain.get_by(domain=default_domain) except user_settings.CannotSetAlias as e:
if sl_domain: flash(e.msg, "error")
if sl_domain.premium_only and not current_user.is_premium(): return redirect(url_for("dashboard.setting"))
flash("You cannot use this domain", "error")
return redirect(url_for("dashboard.setting"))
current_user.default_alias_public_domain_id = sl_domain.id
current_user.default_alias_custom_domain_id = None
else:
custom_domain = CustomDomain.get_by(domain=default_domain)
if custom_domain:
# sanity check
if (
custom_domain.user_id != current_user.id
or not custom_domain.verified
):
LOG.w(
"%s cannot use domain %s", current_user, custom_domain
)
flash(f"Domain {default_domain} can't be used", "error")
return redirect(request.url)
else:
current_user.default_alias_custom_domain_id = (
custom_domain.id
)
current_user.default_alias_public_domain_id = None
else:
current_user.default_alias_custom_domain_id = None
current_user.default_alias_public_domain_id = None
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")

View File

@ -1,4 +1,5 @@
from io import BytesIO from io import BytesIO
from urllib.parse import urlparse
from flask import request, render_template, redirect, url_for, flash from flask import request, render_template, redirect, url_for, flash
from flask_login import current_user, login_required from flask_login import current_user, login_required
@ -11,6 +12,7 @@ from app.config import ADMIN_EMAIL
from app.db import Session from app.db import Session
from app.developer.base import developer_bp from app.developer.base import developer_bp
from app.email_utils import send_email from app.email_utils import send_email
from app.image_validation import detect_image_format, ImageFormat
from app.log import LOG from app.log import LOG
from app.models import Client, RedirectUri, File, Referral from app.models import Client, RedirectUri, File, Referral
from app.utils import random_string from app.utils import random_string
@ -46,16 +48,25 @@ def client_detail(client_id):
approval_form.description.data = client.description approval_form.description.data = client.description
if action == "edit" and form.validate_on_submit(): if action == "edit" and form.validate_on_submit():
parsed_url = urlparse(form.url.data)
if parsed_url.scheme != "https":
flash("Only https urls are allowed", "error")
return redirect(url_for("developer.index"))
client.name = form.name.data client.name = form.name.data
client.home_url = form.url.data client.home_url = form.url.data
if form.icon.data: if form.icon.data:
# todo: remove current icon if any icon_data = form.icon.data.read(10240)
# todo: handle remove icon if detect_image_format(icon_data) == ImageFormat.Unknown:
flash("Unknown file format", "warning")
return redirect(url_for("developer.index"))
if client.icon:
s3.delete(client.icon_id)
File.delete(client.icon)
file_path = random_string(30) file_path = random_string(30)
file = File.create(path=file_path, user_id=client.user_id) file = File.create(path=file_path, user_id=client.user_id)
s3.upload_from_bytesio(file_path, BytesIO(form.icon.data.read())) s3.upload_from_bytesio(file_path, BytesIO(icon_data))
Session.flush() Session.flush()
LOG.d("upload file %s to s3", file) LOG.d("upload file %s to s3", file)

View File

@ -1,3 +1,5 @@
from urllib.parse import urlparse
from flask import render_template, redirect, url_for, flash from flask import render_template, redirect, url_for, flash
from flask_login import current_user, login_required from flask_login import current_user, login_required
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
@ -20,6 +22,10 @@ def new_client():
if form.validate_on_submit(): if form.validate_on_submit():
client = Client.create_new(form.name.data, current_user.id) client = Client.create_new(form.name.data, current_user.id)
parsed_url = urlparse(form.url.data)
if parsed_url.scheme != "https":
flash("Only https urls are allowed", "error")
return redirect(url_for("developer.new_client"))
client.home_url = form.url.data client.home_url = form.url.data
Session.commit() Session.commit()

View File

@ -985,8 +985,8 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
- the domain - the domain
""" """
res = [] res = []
for domain in self.available_sl_domains(alias_options=alias_options): for domain in self.get_sl_domains(alias_options=alias_options):
res.append((True, domain)) res.append((True, domain.domain))
for custom_domain in self.verified_custom_domains(): for custom_domain in self.verified_custom_domains():
res.append((False, custom_domain.domain)) res.append((False, custom_domain.domain))
@ -1128,7 +1128,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
- Verified custom domains - Verified custom domains
""" """
domains = self.available_sl_domains(alias_options=alias_options) domains = [
sl_domain.domain
for sl_domain in self.get_sl_domains(alias_options=alias_options)
]
for custom_domain in self.verified_custom_domains(): for custom_domain in self.verified_custom_domains():
domains.append(custom_domain.domain) domains.append(custom_domain.domain)
@ -2483,7 +2486,7 @@ class CustomDomain(Base, ModelMixin):
return sorted(self._auto_create_rules, key=lambda rule: rule.order) return sorted(self._auto_create_rules, key=lambda rule: rule.order)
def __repr__(self): def __repr__(self):
return f"<Custom Domain {self.domain}>" return f"<Custom Domain {self.id} {self.domain}>"
class AutoCreateRule(Base, ModelMixin): class AutoCreateRule(Base, ModelMixin):
@ -3114,7 +3117,7 @@ class SLDomain(Base, ModelMixin):
) )
def __repr__(self): def __repr__(self):
return f"<SLDomain {self.domain} {'Premium' if self.premium_only else 'Free'}" return f"<SLDomain {self.id} {self.domain} {'Premium' if self.premium_only else 'Free'}>"
class Monitoring(Base, ModelMixin): class Monitoring(Base, ModelMixin):
@ -3484,6 +3487,7 @@ class AdminAuditLog(Base):
action=AuditLogActionEnum.stop_trial.value, action=AuditLogActionEnum.stop_trial.value,
model="User", model="User",
model_id=user_id, model_id=user_id,
data={},
) )
@classmethod @classmethod
@ -3729,6 +3733,7 @@ class SyncEvent(Base, ModelMixin):
taken_time = sa.Column( taken_time = sa.Column(
ArrowType, default=None, nullable=True, server_default=None, index=True ArrowType, default=None, nullable=True, server_default=None, index=True
) )
retry_count = sa.Column(sa.Integer, default=0, nullable=False, server_default="0")
__table_args__ = ( __table_args__ = (
sa.Index("ix_sync_event_created_at", "created_at"), sa.Index("ix_sync_event_created_at", "created_at"),
@ -3750,7 +3755,7 @@ class SyncEvent(Base, ModelMixin):
return res.rowcount > 0 return res.rowcount > 0
@classmethod @classmethod
def get_dead_letter(cls, older_than: Arrow) -> [SyncEvent]: def get_dead_letter(cls, older_than: Arrow, max_retries: int) -> [SyncEvent]:
return ( return (
SyncEvent.filter( SyncEvent.filter(
( (
@ -3763,6 +3768,7 @@ class SyncEvent(Base, ModelMixin):
& (SyncEvent.created_at < older_than) & (SyncEvent.created_at < older_than)
) )
) )
& (SyncEvent.retry_count < max_retries)
) )
.order_by(SyncEvent.id) .order_by(SyncEvent.id)
.limit(100) .limit(100)

47
app/app/user_settings.py Normal file
View File

@ -0,0 +1,47 @@
from typing import Optional
from app.db import Session
from app.log import LOG
from app.models import User, SLDomain, CustomDomain
class CannotSetAlias(Exception):
def __init__(self, msg: str):
self.msg = msg
def set_default_alias_domain(user: User, domain_name: Optional[str]):
if domain_name is None:
LOG.i(f"User {user} has set no domain as default domain")
user.default_alias_public_domain_id = None
user.default_alias_custom_domain_id = None
Session.flush()
return
sl_domain: SLDomain = SLDomain.get_by(domain=domain_name)
if sl_domain:
if sl_domain.hidden:
LOG.i(f"User {user} has tried to set up a hidden domain as default domain")
raise CannotSetAlias("Domain does not exist")
if sl_domain.premium_only and not user.is_premium():
LOG.i(f"User {user} has tried to set up a premium domain as default domain")
raise CannotSetAlias("You cannot use this domain")
LOG.i(f"User {user} has set public {sl_domain} as default domain")
user.default_alias_public_domain_id = sl_domain.id
user.default_alias_custom_domain_id = None
Session.flush()
return
custom_domain = CustomDomain.get_by(domain=domain_name)
if not custom_domain:
LOG.i(
f"User {user} has tried to set up an non existing domain as default domain"
)
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
if custom_domain.user_id != user.id or not custom_domain.verified:
LOG.i(
f"User {user} has tried to set domain {custom_domain} as default domain that does not belong to the user or that is not verified"
)
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
LOG.i(f"User {user} has set custom {custom_domain} as default domain")
user.default_alias_public_domain_id = None
user.default_alias_custom_domain_id = custom_domain.id
Session.flush()

View File

@ -262,7 +262,8 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
Session.commit() Session.commit()
except IntegrityError: except IntegrityError:
# No need to manually rollback, as IntegrityError already rolls back # If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh
Session.close()
LOG.info( LOG.info(
f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB" f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB"
) )

View File

@ -8,6 +8,8 @@ from events.runner import Runner
from events.event_source import DeadLetterEventSource, PostgresEventSource from events.event_source import DeadLetterEventSource, PostgresEventSource
from events.event_sink import ConsoleEventSink, HttpEventSink from events.event_sink import ConsoleEventSink, HttpEventSink
_DEFAULT_MAX_RETRIES = 100
class Mode(Enum): class Mode(Enum):
DEAD_LETTER = "dead_letter" DEAD_LETTER = "dead_letter"
@ -23,10 +25,10 @@ class Mode(Enum):
raise ValueError(f"Invalid mode: {value}") raise ValueError(f"Invalid mode: {value}")
def main(mode: Mode, dry_run: bool): def main(mode: Mode, dry_run: bool, max_retries: int):
if mode == Mode.DEAD_LETTER: if mode == Mode.DEAD_LETTER:
LOG.i("Using DeadLetterEventSource") LOG.i("Using DeadLetterEventSource")
source = DeadLetterEventSource() source = DeadLetterEventSource(max_retries)
elif mode == Mode.LISTENER: elif mode == Mode.LISTENER:
LOG.i("Using PostgresEventSource") LOG.i("Using PostgresEventSource")
source = PostgresEventSource(DB_URI) source = PostgresEventSource(DB_URI)
@ -51,6 +53,13 @@ def args():
help="Mode to run", help="Mode to run",
choices=[Mode.DEAD_LETTER.value, Mode.LISTENER.value], choices=[Mode.DEAD_LETTER.value, Mode.LISTENER.value],
) )
parser.add_argument(
"max_retries",
help="Max retries to consider an event as error and not try to process it again",
type=int,
nargs="?",
default=_DEFAULT_MAX_RETRIES,
)
parser.add_argument("--dry-run", help="Dry run mode", action="store_true") parser.add_argument("--dry-run", help="Dry run mode", action="store_true")
return parser.parse_args() return parser.parse_args()
@ -61,4 +70,8 @@ if __name__ == "__main__":
exit(1) exit(1)
args = args() args = args()
main(Mode.from_str(args.mode), args.dry_run) main(
mode=Mode.from_str(args.mode),
dry_run=args.dry_run,
max_retries=args.max_retries,
)

View File

@ -4,6 +4,8 @@ import psycopg2
import select import select
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import SyncEvent from app.models import SyncEvent
from app.events.event_dispatcher import NOTIFICATION_CHANNEL from app.events.event_dispatcher import NOTIFICATION_CHANNEL
@ -66,6 +68,7 @@ class PostgresEventSource(EventSource):
LOG.info(f"Could not find event with id={notify.payload}") LOG.info(f"Could not find event with id={notify.payload}")
except Exception as e: except Exception as e:
LOG.warn(f"Error getting event: {e}") LOG.warn(f"Error getting event: {e}")
Session.close() # Ensure we get a new connection and we don't leave a dangling tx
def __connect(self): def __connect(self):
self.__connection = psycopg2.connect(self.__connection_string) self.__connection = psycopg2.connect(self.__connection_string)
@ -76,6 +79,9 @@ class PostgresEventSource(EventSource):
class DeadLetterEventSource(EventSource): class DeadLetterEventSource(EventSource):
def __init__(self, max_retries: int):
self.__max_retries = max_retries
@newrelic.agent.background_task() @newrelic.agent.background_task()
def run(self, on_event: Callable[[SyncEvent], NoReturn]): def run(self, on_event: Callable[[SyncEvent], NoReturn]):
while True: while True:
@ -83,7 +89,9 @@ class DeadLetterEventSource(EventSource):
threshold = arrow.utcnow().shift( threshold = arrow.utcnow().shift(
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
) )
events = SyncEvent.get_dead_letter(older_than=threshold) events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries
)
if events: if events:
LOG.info(f"Got {len(events)} dead letter events") LOG.info(f"Got {len(events)} dead letter events")
if events: if events:
@ -92,7 +100,8 @@ class DeadLetterEventSource(EventSource):
) )
for event in events: for event in events:
on_event(event) on_event(event)
else: Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
if not events:
LOG.debug("No dead letter events") LOG.debug("No dead letter events")
sleep(_DEAD_LETTER_INTERVAL_SECONDS) sleep(_DEAD_LETTER_INTERVAL_SECONDS)
except Exception as e: except Exception as e:

View File

@ -2,6 +2,7 @@ import arrow
import newrelic.agent import newrelic.agent
from app.log import LOG from app.log import LOG
from app.db import Session
from app.models import SyncEvent from app.models import SyncEvent
from events.event_sink import EventSink from events.event_sink import EventSink
from events.event_source import EventSource from events.event_source import EventSource
@ -37,6 +38,9 @@ class Runner:
"Custom/sync_event_elapsed_time", "Custom/sync_event_elapsed_time",
time_between_taken_and_created.total_seconds(), time_between_taken_and_created.total_seconds(),
) )
else:
event.retry_count = event.retry_count + 1
Session.commit()
except Exception as e: except Exception as e:
LOG.warn(f"Exception processing event [id={event.id}]: {e}") LOG.warn(f"Exception processing event [id={event.id}]: {e}")
newrelic.agent.record_custom_metric("Custom/sync_event_failed", 1) newrelic.agent.record_custom_metric("Custom/sync_event_failed", 1)

View File

@ -0,0 +1,28 @@
"""add retry count to sync event
Revision ID: 56d08955fcab
Revises: d608b8e48082
Create Date: 2024-07-19 08:21:19.979973
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '56d08955fcab'
down_revision = 'd608b8e48082'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('sync_event', sa.Column('retry_count', sa.Integer(), server_default='0', nullable=False, default=0))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('sync_event', 'retry_count')
# ### end Alembic commands ###

View File

@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres
sleep 3 sleep 3
# upgrade the DB to the latest stage and # upgrade the DB to the latest stage and
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head
# generate the migration script. # generate the migration script.
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@ env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@
# remove the db # remove the db
docker rm -f ${container_name} docker rm -f ${container_name}

View File

@ -3,5 +3,5 @@
export DB_URI=postgresql://myuser:mypassword@localhost:15432/simplelogin export DB_URI=postgresql://myuser:mypassword@localhost:15432/simplelogin
echo 'drop schema public cascade; create schema public;' | psql $DB_URI echo 'drop schema public cascade; create schema public;' | psql $DB_URI
poetry run alembic upgrade head rye run alembic upgrade head
poetry run flask dummy-data rye run flask dummy-data

View File

@ -3,4 +3,4 @@
export DB_URI=postgresql://myuser:mypassword@localhost:15432/test export DB_URI=postgresql://myuser:mypassword@localhost:15432/test
echo 'drop schema public cascade; create schema public;' | psql $DB_URI echo 'drop schema public cascade; create schema public;' | psql $DB_URI
poetry run alembic upgrade head rye run alembic upgrade head

View File

@ -283,6 +283,7 @@ def set_index_page(app):
and not request.path.startswith("/git") and not request.path.startswith("/git")
and not request.path.startswith("/favicon.ico") and not request.path.startswith("/favicon.ico")
): ):
start_time = g.start_time or time.time()
LOG.d( LOG.d(
"%s %s %s %s %s, takes %s", "%s %s %s %s %s, takes %s",
request.remote_addr, request.remote_addr,
@ -290,7 +291,7 @@ def set_index_page(app):
request.path, request.path,
request.args, request.args,
res.status_code, res.status_code,
time.time() - g.start_time, time.time() - start_time,
) )
return res return res

View File

@ -44,6 +44,9 @@ def test_update_settings_alias_generator(flask_client):
def test_update_settings_random_alias_default_domain(flask_client): def test_update_settings_random_alias_default_domain(flask_client):
user = login(flask_client) user = login(flask_client)
custom_domain = CustomDomain.create(
domain=random_domain(), verified=True, user_id=user.id, flush=True
)
assert user.default_random_alias_domain() == "sl.local" assert user.default_random_alias_domain() == "sl.local"
r = flask_client.patch( r = flask_client.patch(
@ -57,6 +60,12 @@ def test_update_settings_random_alias_default_domain(flask_client):
assert r.status_code == 200 assert r.status_code == 200
assert user.default_random_alias_domain() == "d1.test" assert user.default_random_alias_domain() == "d1.test"
r = flask_client.patch(
"/api/setting", json={"random_alias_default_domain": custom_domain.domain}
)
assert r.status_code == 200
assert user.default_random_alias_domain() == custom_domain.domain
def test_update_settings_sender_format(flask_client): def test_update_settings_sender_format(flask_client):
user = login(flask_client) user = login(flask_client)

View File

@ -23,7 +23,7 @@ _MAX_PER_MINUTE = 3
_ENDPOINT, _ENDPOINT,
methods=["GET"], methods=["GET"],
) )
@limiter.limit(f"{_MAX_PER_MINUTE}/minute") @limiter.limit(f"{_MAX_PER_MINUTE}/hour")
def rate_limited_endpoint_1(): def rate_limited_endpoint_1():
return "Working", HTTPStatus.OK return "Working", HTTPStatus.OK

View File

@ -365,12 +365,22 @@ def test_sync_event_dead_letter():
commit=True, commit=True,
) )
# create event with too many retries
max_retries = 5
e5 = SyncEvent.create(
content=b"content",
retry_count=max_retries + 1,
created_at=arrow.now(),
commit=True,
)
# get dead letter events # get dead letter events
dead_letter_events = SyncEvent.get_dead_letter( dead_letter_events = SyncEvent.get_dead_letter(
older_than=arrow.now().shift(minutes=-10) older_than=arrow.now().shift(minutes=-10), max_retries=max_retries
) )
assert len(dead_letter_events) == 2 assert len(dead_letter_events) == 2
assert e1 in dead_letter_events assert e1 in dead_letter_events
assert e2 in dead_letter_events assert e2 in dead_letter_events
assert e3 not in dead_letter_events assert e3 not in dead_letter_events
assert e4 not in dead_letter_events assert e4 not in dead_letter_events
assert e5 not in dead_letter_events

View File

View File

@ -0,0 +1,128 @@
import pytest
from app import user_settings
from app.db import Session
from app.models import User, CustomDomain, SLDomain
from tests.utils import random_token, create_new_user
user_id: int = 0
custom_domain_name: str = ""
sl_domain_name: str = ""
def setup_module():
global user_id, custom_domain_name, sl_domain_name
user = create_new_user()
user.trial_end = None
user_id = user.id
custom_domain_name = CustomDomain.create(
user_id=user_id,
catch_all=True,
domain=random_token() + ".com",
verified=True,
flush=True,
).domain
sl_domain_name = SLDomain.create(
domain=random_token() + ".com",
premium_only=False,
flush=True,
order=5,
hidden=False,
).domain
def test_set_default_no_domain():
user = User.get(user_id)
user.default_alias_public_domain_id = SLDomain.get_by(domain=sl_domain_name).id
user.default_alias_private_domain_id = CustomDomain.get_by(
domain=custom_domain_name
).id
Session.flush()
user_settings.set_default_alias_domain(user, None)
assert user.default_alias_public_domain_id is None
assert user.default_alias_custom_domain_id is None
def test_set_premium_sl_domain_with_non_premium_user():
user = User.get(user_id)
user.lifetime = False
domain = SLDomain.get_by(domain=sl_domain_name)
domain.premium_only = True
Session.flush()
with pytest.raises(user_settings.CannotSetAlias):
user_settings.set_default_alias_domain(user, sl_domain_name)
def test_set_hidden_sl_domain():
user = User.get(user_id)
domain = SLDomain.get_by(domain=sl_domain_name)
domain.hidden = True
domain.premium_only = False
Session.flush()
with pytest.raises(user_settings.CannotSetAlias):
user_settings.set_default_alias_domain(user, sl_domain_name)
def test_set_sl_domain():
user = User.get(user_id)
user.lifetime = False
domain = SLDomain.get_by(domain=sl_domain_name)
domain.hidden = False
domain.premium_only = False
Session.flush()
user_settings.set_default_alias_domain(user, sl_domain_name)
assert user.default_alias_public_domain_id == domain.id
assert user.default_alias_custom_domain_id is None
def test_set_sl_premium_domain():
user = User.get(user_id)
user.lifetime = True
domain = SLDomain.get_by(domain=sl_domain_name)
domain.hidden = False
domain.premium_only = True
Session.flush()
user_settings.set_default_alias_domain(user, sl_domain_name)
assert user.default_alias_public_domain_id == domain.id
assert user.default_alias_custom_domain_id is None
def test_set_other_user_custom_domain():
user = User.get(user_id)
user.lifetime = True
other_user_domain_name = CustomDomain.create(
user_id=create_new_user().id,
catch_all=True,
domain=random_token() + ".com",
verified=True,
).domain
Session.flush()
with pytest.raises(user_settings.CannotSetAlias):
user_settings.set_default_alias_domain(user, other_user_domain_name)
def test_set_unverified_custom_domain():
user = User.get(user_id)
user.lifetime = True
domain = CustomDomain.get_by(domain=custom_domain_name)
domain.verified = False
Session.flush()
with pytest.raises(user_settings.CannotSetAlias):
user_settings.set_default_alias_domain(user, custom_domain_name)
def test_set_custom_domain():
user = User.get(user_id)
user.lifetime = True
domain = CustomDomain.get_by(domain=custom_domain_name)
domain.verified = True
Session.flush()
user_settings.set_default_alias_domain(user, custom_domain_name)
assert user.default_alias_public_domain_id is None
assert user.default_alias_custom_domain_id == domain.id
def test_set_invalid_custom_domain():
user = User.get(user_id)
with pytest.raises(user_settings.CannotSetAlias):
user_settings.set_default_alias_domain(user, "invalid_nop" + random_token())