From dae6f644821aaee2accf855cb85f9aac0af0e9a9 Mon Sep 17 00:00:00 2001 From: MrMeeb Date: Tue, 30 Jul 2024 12:00:06 +0100 Subject: [PATCH] 4.47.2 --- app/.dockerignore | 2 +- app/.gitignore | 3 +- app/CONTRIBUTING.md | 18 +-- app/README.md | 2 +- app/app/admin_model.py | 16 +++ app/app/alias_suffix.py | 10 +- app/app/api/views/user_info.py | 18 +-- app/app/auth/views/register.py | 3 +- app/app/dashboard/views/setting.py | 41 +----- app/app/developer/views/client_detail.py | 17 ++- app/app/developer/views/new_client.py | 6 + app/app/models.py | 18 ++- app/app/user_settings.py | 47 +++++++ app/email_handler.py | 3 +- app/event_listener.py | 19 ++- app/events/event_source.py | 13 +- app/events/runner.py | 4 + ...08955fcab_add_retry_count_to_sync_event.py | 28 ++++ app/scripts/new-migration.sh | 4 +- app/scripts/reset_local_db.sh | 4 +- app/scripts/reset_test_db.sh | 2 +- app/server.py | 3 +- app/tests/api/test_setting.py | 9 ++ app/tests/test_extensions.py | 2 +- app/tests/test_models.py | 12 +- app/tests/user_settings/__init__.py | 0 .../test_set_default_alias_domain.py | 128 ++++++++++++++++++ 27 files changed, 348 insertions(+), 84 deletions(-) create mode 100644 app/app/user_settings.py create mode 100644 app/migrations/versions/2024_071908_56d08955fcab_add_retry_count_to_sync_event.py create mode 100644 app/tests/user_settings/__init__.py create mode 100644 app/tests/user_settings/test_set_default_alias_domain.py diff --git a/app/.dockerignore b/app/.dockerignore index f749263..9ae272b 100644 --- a/app/.dockerignore +++ b/app/.dockerignore @@ -14,4 +14,4 @@ venv/ .venv .coverage htmlcov -.git/ \ No newline at end of file +.git/ diff --git a/app/.gitignore b/app/.gitignore index e7cea4d..fafbd5e 100644 --- a/app/.gitignore +++ b/app/.gitignore @@ -11,8 +11,7 @@ db.sqlite-journal static/upload venv/ .venv -.python-version .coverage htmlcov adhoc -.env.* \ No newline at end of file +.env.* diff --git a/app/CONTRIBUTING.md b/app/CONTRIBUTING.md index 6dc29ea..89a283d 100644 --- a/app/CONTRIBUTING.md +++ b/app/CONTRIBUTING.md @@ -20,15 +20,15 @@ SimpleLogin backend consists of 2 main components: ## Install dependencies The project requires: -- Python 3.7+ and [poetry](https://python-poetry.org/) to manage dependencies +- Python 3.10 and [rye](https://github.com/astral-sh/rye) to manage dependencies - Node v10 for front-end. -- Postgres 12+ +- Postgres 13+ First, install all dependencies by running the following command. Feel free to use `virtualenv` or similar tools to isolate development environment. ```bash -poetry install +rye sync ``` On Mac, sometimes you might need to install some other packages via `brew`: @@ -55,7 +55,7 @@ brew install -s re2 pybind11 We use pre-commit to run all our linting and static analysis checks. Please run ```bash -poetry run pre-commit install +rye run pre-commit install ``` To install it in your development environment. @@ -160,25 +160,25 @@ Here are the small sum-ups of the directory structures and their roles: The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run ``` -poetry run ruff format . +rye run ruff format . ``` The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by ```bash -poetry run flake8 +rye run flake8 ``` For HTML templates, we use `djlint`. Before creating a pull request, please run ```bash -poetry run djlint --check templates +rye run djlint --check templates ``` If some files aren't properly formatted, you can format all files with ```bash -poetry run djlint --reformat . +rye run djlint --reformat . ``` ## Test sending email @@ -225,4 +225,4 @@ Now open http://localhost:1080/ (or http://localhost:1080/ for MailHog), you sho Some features require a job handler (such as GDPR data export). To test such feature you need to run the job_runner ```bash python job_runner.py -``` \ No newline at end of file +``` diff --git a/app/README.md b/app/README.md index 5ff9bcd..b98f317 100644 --- a/app/README.md +++ b/app/README.md @@ -541,7 +541,7 @@ exit Once you've created all your desired login accounts, add these lines to `/simplelogin.env` to disable further registrations: -``` +```.env DISABLE_REGISTRATION=1 DISABLE_ONBOARDING=true ``` diff --git a/app/app/admin_model.py b/app/app/admin_model.py index 8cf0a54..0aa7edd 100644 --- a/app/app/admin_model.py +++ b/app/app/admin_model.py @@ -2,6 +2,7 @@ from typing import Optional import arrow import sqlalchemy +from flask_admin.form import SecureForm from flask_admin.model.template import EndpointLinkRowAction from markupsafe import Markup @@ -100,6 +101,7 @@ def _user_upgrade_channel_formatter(view, context, model, name): class UserAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["email", "id"] column_exclude_list = [ "salt", @@ -344,6 +346,7 @@ def manual_upgrade(way: str, ids: [int], is_giveaway: bool): class EmailLogAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["id"] column_filters = ["id", "user.email", "mailbox.email", "contact.website_email"] @@ -352,6 +355,7 @@ class EmailLogAdmin(SLModelView): class AliasAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["id", "user.email", "email", "mailbox.email"] column_filters = ["id", "user.email", "email", "mailbox.email"] @@ -377,6 +381,7 @@ class AliasAdmin(SLModelView): class MailboxAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["id", "user.email", "email"] column_filters = ["id", "user.email", "email"] @@ -387,11 +392,13 @@ class MailboxAdmin(SLModelView): class CouponAdmin(SLModelView): + form_base_class = SecureForm can_edit = False can_create = True class ManualSubscriptionAdmin(SLModelView): + form_base_class = SecureForm can_edit = True column_searchable_list = ["id", "user.email"] @@ -433,12 +440,14 @@ class ManualSubscriptionAdmin(SLModelView): class CustomDomainAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["domain", "user.email", "user.id"] column_exclude_list = ["ownership_txt_token"] can_edit = False class ReferralAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["id", "user.email", "code", "name"] column_filters = ["id", "user.email", "code", "name"] @@ -467,6 +476,7 @@ def _admin_created_at_formatter(view, context, model, name): class AdminAuditLogAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["admin.id", "admin.email", "model_id", "created_at"] column_filters = ["admin.id", "admin.email", "model_id", "created_at"] column_exclude_list = ["id"] @@ -497,6 +507,7 @@ def _transactionalcomplaint_refused_email_id_formatter(view, context, model, nam class ProviderComplaintAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["id", "user.id", "created_at"] column_filters = ["user.id", "state"] column_hide_backrefs = False @@ -567,6 +578,7 @@ def _newsletter_html_formatter(view, context, model: Newsletter, name): class NewsletterAdmin(SLModelView): + form_base_class = SecureForm list_template = "admin/model/newsletter-list.html" edit_template = "admin/model/newsletter-edit.html" edit_modal = False @@ -648,6 +660,7 @@ class NewsletterAdmin(SLModelView): class NewsletterUserAdmin(SLModelView): + form_base_class = SecureForm column_searchable_list = ["id"] column_filters = ["id", "user.email", "newsletter.subject"] column_exclude_list = ["created_at", "updated_at", "id"] @@ -657,17 +670,20 @@ class NewsletterUserAdmin(SLModelView): class DailyMetricAdmin(SLModelView): + form_base_class = SecureForm column_exclude_list = ["created_at", "updated_at", "id"] can_export = True class MetricAdmin(SLModelView): + form_base_class = SecureForm column_exclude_list = ["created_at", "updated_at", "id"] can_export = True class InvalidMailboxDomainAdmin(SLModelView): + form_base_class = SecureForm can_create = True can_delete = True diff --git a/app/app/alias_suffix.py b/app/app/alias_suffix.py index 8cfc98a..fbcbff2 100644 --- a/app/app/alias_suffix.py +++ b/app/app/alias_suffix.py @@ -64,8 +64,12 @@ def verify_prefix_suffix( # SimpleLogin domain case: # 1) alias_suffix must start with "." and # 2) alias_domain_prefix must come from the word list + available_sl_domains = [ + sl_domain.domain + for sl_domain in user.get_sl_domains(alias_options=alias_options) + ] if ( - alias_domain in user.available_sl_domains(alias_options=alias_options) + alias_domain in available_sl_domains and alias_domain not in user_custom_domains # when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty and not config.DISABLE_ALIAS_SUFFIX @@ -80,9 +84,7 @@ def verify_prefix_suffix( LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) return False - if alias_domain not in user.available_sl_domains( - alias_options=alias_options - ): + if alias_domain not in available_sl_domains: LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) return False diff --git a/app/app/api/views/user_info.py b/app/app/api/views/user_info.py index fee3124..44c530d 100644 --- a/app/app/api/views/user_info.py +++ b/app/app/api/views/user_info.py @@ -10,6 +10,7 @@ from app.api.base import api_bp, require_api_auth from app.config import SESSION_COOKIE_NAME from app.dashboard.views.index import get_stats from app.db import Session +from app.image_validation import detect_image_format, ImageFormat from app.models import ApiKey, File, PartnerUser, User from app.proton.utils import get_proton_partner from app.session import logout_session @@ -78,17 +79,18 @@ def update_user_info(): data = request.get_json() or {} if "profile_picture" in data: - if data["profile_picture"] is None: - if user.profile_picture_id: - file = user.profile_picture - user.profile_picture_id = None + if user.profile_picture_id: + file = user.profile_picture + user.profile_picture_id = None + Session.flush() + if file: + File.delete(file.id) + s3.delete(file.path) Session.flush() - if file: - File.delete(file.id) - s3.delete(file.path) - Session.flush() else: raw_data = base64.decodebytes(data["profile_picture"].encode()) + if detect_image_format(raw_data) == ImageFormat.Unknown: + return jsonify(error="Unsupported image format"), 400 file_path = random_string(30) file = File.create(user_id=user.id, path=file_path) Session.flush() diff --git a/app/app/auth/views/register.py b/app/app/auth/views/register.py index 6740e57..1799a76 100644 --- a/app/app/auth/views/register.py +++ b/app/app/auth/views/register.py @@ -115,7 +115,8 @@ def register(): def send_activation_email(user, next_url): - # the activation code is valid for 1h + # the activation code is valid for 1h and delete all previous codes + Session.query(ActivationCode).filter(ActivationCode.user_id == user.id).delete() activation = ActivationCode.create(user_id=user.id, code=random_string(30)) Session.commit() diff --git a/app/app/dashboard/views/setting.py b/app/app/dashboard/views/setting.py index f07ebb1..3b140b8 100644 --- a/app/app/dashboard/views/setting.py +++ b/app/app/dashboard/views/setting.py @@ -14,7 +14,7 @@ from flask_wtf import FlaskForm from flask_wtf.file import FileField from wtforms import StringField, validators -from app import s3 +from app import s3, user_settings from app.config import ( FIRST_ALIAS_DOMAIN, ALIAS_RANDOM_SUFFIX_LENGTH, @@ -31,12 +31,10 @@ from app.models import ( PlanEnum, File, EmailChange, - CustomDomain, AliasGeneratorEnum, AliasSuffixEnum, ManualSubscription, SenderFormatEnum, - SLDomain, CoinbaseSubscription, AppleSubscription, PartnerUser, @@ -166,38 +164,11 @@ def setting(): return redirect(url_for("dashboard.setting")) elif request.form.get("form-name") == "change-random-alias-default-domain": default_domain = request.form.get("random-alias-default-domain") - - if default_domain: - sl_domain: SLDomain = SLDomain.get_by(domain=default_domain) - if sl_domain: - if sl_domain.premium_only and not current_user.is_premium(): - flash("You cannot use this domain", "error") - return redirect(url_for("dashboard.setting")) - - current_user.default_alias_public_domain_id = sl_domain.id - current_user.default_alias_custom_domain_id = None - else: - custom_domain = CustomDomain.get_by(domain=default_domain) - if custom_domain: - # sanity check - if ( - custom_domain.user_id != current_user.id - or not custom_domain.verified - ): - LOG.w( - "%s cannot use domain %s", current_user, custom_domain - ) - flash(f"Domain {default_domain} can't be used", "error") - return redirect(request.url) - else: - current_user.default_alias_custom_domain_id = ( - custom_domain.id - ) - current_user.default_alias_public_domain_id = None - - else: - current_user.default_alias_custom_domain_id = None - current_user.default_alias_public_domain_id = None + try: + user_settings.set_default_alias_domain(current_user, default_domain) + except user_settings.CannotSetAlias as e: + flash(e.msg, "error") + return redirect(url_for("dashboard.setting")) Session.commit() flash("Your preference has been updated", "success") diff --git a/app/app/developer/views/client_detail.py b/app/app/developer/views/client_detail.py index 26d08f8..1af95b3 100644 --- a/app/app/developer/views/client_detail.py +++ b/app/app/developer/views/client_detail.py @@ -1,4 +1,5 @@ from io import BytesIO +from urllib.parse import urlparse from flask import request, render_template, redirect, url_for, flash from flask_login import current_user, login_required @@ -11,6 +12,7 @@ from app.config import ADMIN_EMAIL from app.db import Session from app.developer.base import developer_bp from app.email_utils import send_email +from app.image_validation import detect_image_format, ImageFormat from app.log import LOG from app.models import Client, RedirectUri, File, Referral from app.utils import random_string @@ -46,16 +48,25 @@ def client_detail(client_id): approval_form.description.data = client.description if action == "edit" and form.validate_on_submit(): + parsed_url = urlparse(form.url.data) + if parsed_url.scheme != "https": + flash("Only https urls are allowed", "error") + return redirect(url_for("developer.index")) client.name = form.name.data client.home_url = form.url.data if form.icon.data: - # todo: remove current icon if any - # todo: handle remove icon + icon_data = form.icon.data.read(10240) + if detect_image_format(icon_data) == ImageFormat.Unknown: + flash("Unknown file format", "warning") + return redirect(url_for("developer.index")) + if client.icon: + s3.delete(client.icon_id) + File.delete(client.icon) file_path = random_string(30) file = File.create(path=file_path, user_id=client.user_id) - s3.upload_from_bytesio(file_path, BytesIO(form.icon.data.read())) + s3.upload_from_bytesio(file_path, BytesIO(icon_data)) Session.flush() LOG.d("upload file %s to s3", file) diff --git a/app/app/developer/views/new_client.py b/app/app/developer/views/new_client.py index 6942241..1db7219 100644 --- a/app/app/developer/views/new_client.py +++ b/app/app/developer/views/new_client.py @@ -1,3 +1,5 @@ +from urllib.parse import urlparse + from flask import render_template, redirect, url_for, flash from flask_login import current_user, login_required from flask_wtf import FlaskForm @@ -20,6 +22,10 @@ def new_client(): if form.validate_on_submit(): client = Client.create_new(form.name.data, current_user.id) + parsed_url = urlparse(form.url.data) + if parsed_url.scheme != "https": + flash("Only https urls are allowed", "error") + return redirect(url_for("developer.new_client")) client.home_url = form.url.data Session.commit() diff --git a/app/app/models.py b/app/app/models.py index bfbc837..e5297aa 100644 --- a/app/app/models.py +++ b/app/app/models.py @@ -985,8 +985,8 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): - the domain """ res = [] - for domain in self.available_sl_domains(alias_options=alias_options): - res.append((True, domain)) + for domain in self.get_sl_domains(alias_options=alias_options): + res.append((True, domain.domain)) for custom_domain in self.verified_custom_domains(): res.append((False, custom_domain.domain)) @@ -1128,7 +1128,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): - Verified custom domains """ - domains = self.available_sl_domains(alias_options=alias_options) + domains = [ + sl_domain.domain + for sl_domain in self.get_sl_domains(alias_options=alias_options) + ] for custom_domain in self.verified_custom_domains(): domains.append(custom_domain.domain) @@ -2483,7 +2486,7 @@ class CustomDomain(Base, ModelMixin): return sorted(self._auto_create_rules, key=lambda rule: rule.order) def __repr__(self): - return f"" + return f"" class AutoCreateRule(Base, ModelMixin): @@ -3114,7 +3117,7 @@ class SLDomain(Base, ModelMixin): ) def __repr__(self): - return f"" class Monitoring(Base, ModelMixin): @@ -3484,6 +3487,7 @@ class AdminAuditLog(Base): action=AuditLogActionEnum.stop_trial.value, model="User", model_id=user_id, + data={}, ) @classmethod @@ -3729,6 +3733,7 @@ class SyncEvent(Base, ModelMixin): taken_time = sa.Column( ArrowType, default=None, nullable=True, server_default=None, index=True ) + retry_count = sa.Column(sa.Integer, default=0, nullable=False, server_default="0") __table_args__ = ( sa.Index("ix_sync_event_created_at", "created_at"), @@ -3750,7 +3755,7 @@ class SyncEvent(Base, ModelMixin): return res.rowcount > 0 @classmethod - def get_dead_letter(cls, older_than: Arrow) -> [SyncEvent]: + def get_dead_letter(cls, older_than: Arrow, max_retries: int) -> [SyncEvent]: return ( SyncEvent.filter( ( @@ -3763,6 +3768,7 @@ class SyncEvent(Base, ModelMixin): & (SyncEvent.created_at < older_than) ) ) + & (SyncEvent.retry_count < max_retries) ) .order_by(SyncEvent.id) .limit(100) diff --git a/app/app/user_settings.py b/app/app/user_settings.py new file mode 100644 index 0000000..f9bbb96 --- /dev/null +++ b/app/app/user_settings.py @@ -0,0 +1,47 @@ +from typing import Optional + +from app.db import Session +from app.log import LOG +from app.models import User, SLDomain, CustomDomain + + +class CannotSetAlias(Exception): + def __init__(self, msg: str): + self.msg = msg + + +def set_default_alias_domain(user: User, domain_name: Optional[str]): + if domain_name is None: + LOG.i(f"User {user} has set no domain as default domain") + user.default_alias_public_domain_id = None + user.default_alias_custom_domain_id = None + Session.flush() + return + sl_domain: SLDomain = SLDomain.get_by(domain=domain_name) + if sl_domain: + if sl_domain.hidden: + LOG.i(f"User {user} has tried to set up a hidden domain as default domain") + raise CannotSetAlias("Domain does not exist") + if sl_domain.premium_only and not user.is_premium(): + LOG.i(f"User {user} has tried to set up a premium domain as default domain") + raise CannotSetAlias("You cannot use this domain") + LOG.i(f"User {user} has set public {sl_domain} as default domain") + user.default_alias_public_domain_id = sl_domain.id + user.default_alias_custom_domain_id = None + Session.flush() + return + custom_domain = CustomDomain.get_by(domain=domain_name) + if not custom_domain: + LOG.i( + f"User {user} has tried to set up an non existing domain as default domain" + ) + raise CannotSetAlias("Domain does not exist or it hasn't been verified") + if custom_domain.user_id != user.id or not custom_domain.verified: + LOG.i( + f"User {user} has tried to set domain {custom_domain} as default domain that does not belong to the user or that is not verified" + ) + raise CannotSetAlias("Domain does not exist or it hasn't been verified") + LOG.i(f"User {user} has set custom {custom_domain} as default domain") + user.default_alias_public_domain_id = None + user.default_alias_custom_domain_id = custom_domain.id + Session.flush() diff --git a/app/email_handler.py b/app/email_handler.py index 7257800..84301ac 100644 --- a/app/email_handler.py +++ b/app/email_handler.py @@ -262,7 +262,8 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con Session.commit() except IntegrityError: - # No need to manually rollback, as IntegrityError already rolls back + # If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh + Session.close() LOG.info( f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB" ) diff --git a/app/event_listener.py b/app/event_listener.py index 93e23ed..544220f 100644 --- a/app/event_listener.py +++ b/app/event_listener.py @@ -8,6 +8,8 @@ from events.runner import Runner from events.event_source import DeadLetterEventSource, PostgresEventSource from events.event_sink import ConsoleEventSink, HttpEventSink +_DEFAULT_MAX_RETRIES = 100 + class Mode(Enum): DEAD_LETTER = "dead_letter" @@ -23,10 +25,10 @@ class Mode(Enum): raise ValueError(f"Invalid mode: {value}") -def main(mode: Mode, dry_run: bool): +def main(mode: Mode, dry_run: bool, max_retries: int): if mode == Mode.DEAD_LETTER: LOG.i("Using DeadLetterEventSource") - source = DeadLetterEventSource() + source = DeadLetterEventSource(max_retries) elif mode == Mode.LISTENER: LOG.i("Using PostgresEventSource") source = PostgresEventSource(DB_URI) @@ -51,6 +53,13 @@ def args(): help="Mode to run", choices=[Mode.DEAD_LETTER.value, Mode.LISTENER.value], ) + parser.add_argument( + "max_retries", + help="Max retries to consider an event as error and not try to process it again", + type=int, + nargs="?", + default=_DEFAULT_MAX_RETRIES, + ) parser.add_argument("--dry-run", help="Dry run mode", action="store_true") return parser.parse_args() @@ -61,4 +70,8 @@ if __name__ == "__main__": exit(1) args = args() - main(Mode.from_str(args.mode), args.dry_run) + main( + mode=Mode.from_str(args.mode), + dry_run=args.dry_run, + max_retries=args.max_retries, + ) diff --git a/app/events/event_source.py b/app/events/event_source.py index f4f8937..89a8a4a 100644 --- a/app/events/event_source.py +++ b/app/events/event_source.py @@ -4,6 +4,8 @@ import psycopg2 import select from abc import ABC, abstractmethod + +from app.db import Session from app.log import LOG from app.models import SyncEvent from app.events.event_dispatcher import NOTIFICATION_CHANNEL @@ -66,6 +68,7 @@ class PostgresEventSource(EventSource): LOG.info(f"Could not find event with id={notify.payload}") except Exception as e: LOG.warn(f"Error getting event: {e}") + Session.close() # Ensure we get a new connection and we don't leave a dangling tx def __connect(self): self.__connection = psycopg2.connect(self.__connection_string) @@ -76,6 +79,9 @@ class PostgresEventSource(EventSource): class DeadLetterEventSource(EventSource): + def __init__(self, max_retries: int): + self.__max_retries = max_retries + @newrelic.agent.background_task() def run(self, on_event: Callable[[SyncEvent], NoReturn]): while True: @@ -83,7 +89,9 @@ class DeadLetterEventSource(EventSource): threshold = arrow.utcnow().shift( minutes=-_DEAD_LETTER_THRESHOLD_MINUTES ) - events = SyncEvent.get_dead_letter(older_than=threshold) + events = SyncEvent.get_dead_letter( + older_than=threshold, max_retries=self.__max_retries + ) if events: LOG.info(f"Got {len(events)} dead letter events") if events: @@ -92,7 +100,8 @@ class DeadLetterEventSource(EventSource): ) for event in events: on_event(event) - else: + Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock + if not events: LOG.debug("No dead letter events") sleep(_DEAD_LETTER_INTERVAL_SECONDS) except Exception as e: diff --git a/app/events/runner.py b/app/events/runner.py index d6f9c2e..0fe7bff 100644 --- a/app/events/runner.py +++ b/app/events/runner.py @@ -2,6 +2,7 @@ import arrow import newrelic.agent from app.log import LOG +from app.db import Session from app.models import SyncEvent from events.event_sink import EventSink from events.event_source import EventSource @@ -37,6 +38,9 @@ class Runner: "Custom/sync_event_elapsed_time", time_between_taken_and_created.total_seconds(), ) + else: + event.retry_count = event.retry_count + 1 + Session.commit() except Exception as e: LOG.warn(f"Exception processing event [id={event.id}]: {e}") newrelic.agent.record_custom_metric("Custom/sync_event_failed", 1) diff --git a/app/migrations/versions/2024_071908_56d08955fcab_add_retry_count_to_sync_event.py b/app/migrations/versions/2024_071908_56d08955fcab_add_retry_count_to_sync_event.py new file mode 100644 index 0000000..a283b87 --- /dev/null +++ b/app/migrations/versions/2024_071908_56d08955fcab_add_retry_count_to_sync_event.py @@ -0,0 +1,28 @@ +"""add retry count to sync event + +Revision ID: 56d08955fcab +Revises: d608b8e48082 +Create Date: 2024-07-19 08:21:19.979973 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '56d08955fcab' +down_revision = 'd608b8e48082' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('sync_event', sa.Column('retry_count', sa.Integer(), server_default='0', nullable=False, default=0)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('sync_event', 'retry_count') + # ### end Alembic commands ### diff --git a/app/scripts/new-migration.sh b/app/scripts/new-migration.sh index da11a75..5456857 100755 --- a/app/scripts/new-migration.sh +++ b/app/scripts/new-migration.sh @@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres sleep 3 # upgrade the DB to the latest stage and -env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head +env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head # generate the migration script. -env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@ +env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@ # remove the db docker rm -f ${container_name} diff --git a/app/scripts/reset_local_db.sh b/app/scripts/reset_local_db.sh index 422c2a8..42dae95 100755 --- a/app/scripts/reset_local_db.sh +++ b/app/scripts/reset_local_db.sh @@ -3,5 +3,5 @@ export DB_URI=postgresql://myuser:mypassword@localhost:15432/simplelogin echo 'drop schema public cascade; create schema public;' | psql $DB_URI -poetry run alembic upgrade head -poetry run flask dummy-data +rye run alembic upgrade head +rye run flask dummy-data diff --git a/app/scripts/reset_test_db.sh b/app/scripts/reset_test_db.sh index 2546601..234cef3 100755 --- a/app/scripts/reset_test_db.sh +++ b/app/scripts/reset_test_db.sh @@ -3,4 +3,4 @@ export DB_URI=postgresql://myuser:mypassword@localhost:15432/test echo 'drop schema public cascade; create schema public;' | psql $DB_URI -poetry run alembic upgrade head +rye run alembic upgrade head diff --git a/app/server.py b/app/server.py index 04278dc..ea74d1e 100644 --- a/app/server.py +++ b/app/server.py @@ -283,6 +283,7 @@ def set_index_page(app): and not request.path.startswith("/git") and not request.path.startswith("/favicon.ico") ): + start_time = g.start_time or time.time() LOG.d( "%s %s %s %s %s, takes %s", request.remote_addr, @@ -290,7 +291,7 @@ def set_index_page(app): request.path, request.args, res.status_code, - time.time() - g.start_time, + time.time() - start_time, ) return res diff --git a/app/tests/api/test_setting.py b/app/tests/api/test_setting.py index 199a840..698596e 100644 --- a/app/tests/api/test_setting.py +++ b/app/tests/api/test_setting.py @@ -44,6 +44,9 @@ def test_update_settings_alias_generator(flask_client): def test_update_settings_random_alias_default_domain(flask_client): user = login(flask_client) + custom_domain = CustomDomain.create( + domain=random_domain(), verified=True, user_id=user.id, flush=True + ) assert user.default_random_alias_domain() == "sl.local" r = flask_client.patch( @@ -57,6 +60,12 @@ def test_update_settings_random_alias_default_domain(flask_client): assert r.status_code == 200 assert user.default_random_alias_domain() == "d1.test" + r = flask_client.patch( + "/api/setting", json={"random_alias_default_domain": custom_domain.domain} + ) + assert r.status_code == 200 + assert user.default_random_alias_domain() == custom_domain.domain + def test_update_settings_sender_format(flask_client): user = login(flask_client) diff --git a/app/tests/test_extensions.py b/app/tests/test_extensions.py index a8f659f..e494301 100644 --- a/app/tests/test_extensions.py +++ b/app/tests/test_extensions.py @@ -23,7 +23,7 @@ _MAX_PER_MINUTE = 3 _ENDPOINT, methods=["GET"], ) -@limiter.limit(f"{_MAX_PER_MINUTE}/minute") +@limiter.limit(f"{_MAX_PER_MINUTE}/hour") def rate_limited_endpoint_1(): return "Working", HTTPStatus.OK diff --git a/app/tests/test_models.py b/app/tests/test_models.py index eaaacb1..b84488d 100644 --- a/app/tests/test_models.py +++ b/app/tests/test_models.py @@ -365,12 +365,22 @@ def test_sync_event_dead_letter(): commit=True, ) + # create event with too many retries + max_retries = 5 + e5 = SyncEvent.create( + content=b"content", + retry_count=max_retries + 1, + created_at=arrow.now(), + commit=True, + ) + # get dead letter events dead_letter_events = SyncEvent.get_dead_letter( - older_than=arrow.now().shift(minutes=-10) + older_than=arrow.now().shift(minutes=-10), max_retries=max_retries ) assert len(dead_letter_events) == 2 assert e1 in dead_letter_events assert e2 in dead_letter_events assert e3 not in dead_letter_events assert e4 not in dead_letter_events + assert e5 not in dead_letter_events diff --git a/app/tests/user_settings/__init__.py b/app/tests/user_settings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/tests/user_settings/test_set_default_alias_domain.py b/app/tests/user_settings/test_set_default_alias_domain.py new file mode 100644 index 0000000..668204a --- /dev/null +++ b/app/tests/user_settings/test_set_default_alias_domain.py @@ -0,0 +1,128 @@ +import pytest + +from app import user_settings +from app.db import Session +from app.models import User, CustomDomain, SLDomain +from tests.utils import random_token, create_new_user + +user_id: int = 0 +custom_domain_name: str = "" +sl_domain_name: str = "" + + +def setup_module(): + global user_id, custom_domain_name, sl_domain_name + user = create_new_user() + user.trial_end = None + user_id = user.id + custom_domain_name = CustomDomain.create( + user_id=user_id, + catch_all=True, + domain=random_token() + ".com", + verified=True, + flush=True, + ).domain + sl_domain_name = SLDomain.create( + domain=random_token() + ".com", + premium_only=False, + flush=True, + order=5, + hidden=False, + ).domain + + +def test_set_default_no_domain(): + user = User.get(user_id) + user.default_alias_public_domain_id = SLDomain.get_by(domain=sl_domain_name).id + user.default_alias_private_domain_id = CustomDomain.get_by( + domain=custom_domain_name + ).id + Session.flush() + user_settings.set_default_alias_domain(user, None) + assert user.default_alias_public_domain_id is None + assert user.default_alias_custom_domain_id is None + + +def test_set_premium_sl_domain_with_non_premium_user(): + user = User.get(user_id) + user.lifetime = False + domain = SLDomain.get_by(domain=sl_domain_name) + domain.premium_only = True + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_domain(user, sl_domain_name) + + +def test_set_hidden_sl_domain(): + user = User.get(user_id) + domain = SLDomain.get_by(domain=sl_domain_name) + domain.hidden = True + domain.premium_only = False + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_domain(user, sl_domain_name) + + +def test_set_sl_domain(): + user = User.get(user_id) + user.lifetime = False + domain = SLDomain.get_by(domain=sl_domain_name) + domain.hidden = False + domain.premium_only = False + Session.flush() + user_settings.set_default_alias_domain(user, sl_domain_name) + assert user.default_alias_public_domain_id == domain.id + assert user.default_alias_custom_domain_id is None + + +def test_set_sl_premium_domain(): + user = User.get(user_id) + user.lifetime = True + domain = SLDomain.get_by(domain=sl_domain_name) + domain.hidden = False + domain.premium_only = True + Session.flush() + user_settings.set_default_alias_domain(user, sl_domain_name) + assert user.default_alias_public_domain_id == domain.id + assert user.default_alias_custom_domain_id is None + + +def test_set_other_user_custom_domain(): + user = User.get(user_id) + user.lifetime = True + other_user_domain_name = CustomDomain.create( + user_id=create_new_user().id, + catch_all=True, + domain=random_token() + ".com", + verified=True, + ).domain + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_domain(user, other_user_domain_name) + + +def test_set_unverified_custom_domain(): + user = User.get(user_id) + user.lifetime = True + domain = CustomDomain.get_by(domain=custom_domain_name) + domain.verified = False + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_domain(user, custom_domain_name) + + +def test_set_custom_domain(): + user = User.get(user_id) + user.lifetime = True + domain = CustomDomain.get_by(domain=custom_domain_name) + domain.verified = True + Session.flush() + user_settings.set_default_alias_domain(user, custom_domain_name) + assert user.default_alias_public_domain_id is None + assert user.default_alias_custom_domain_id == domain.id + + +def test_set_invalid_custom_domain(): + user = User.get(user_id) + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_domain(user, "invalid_nop" + random_token())