4.47.2
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m33s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m39s
Build-Release-Image / Merge-Images (push) Successful in 12s
Build-Release-Image / Create-Release (push) Successful in 8s
Build-Release-Image / Notify (push) Successful in 5s
All checks were successful
Build-Release-Image / Build-Image (linux/arm64) (push) Successful in 3m33s
Build-Release-Image / Build-Image (linux/amd64) (push) Successful in 3m39s
Build-Release-Image / Merge-Images (push) Successful in 12s
Build-Release-Image / Create-Release (push) Successful in 8s
Build-Release-Image / Notify (push) Successful in 5s
This commit is contained in:
parent
e7f0f81d85
commit
dae6f64482
1
app/.gitignore
vendored
1
app/.gitignore
vendored
@ -11,7 +11,6 @@ db.sqlite-journal
|
|||||||
static/upload
|
static/upload
|
||||||
venv/
|
venv/
|
||||||
.venv
|
.venv
|
||||||
.python-version
|
|
||||||
.coverage
|
.coverage
|
||||||
htmlcov
|
htmlcov
|
||||||
adhoc
|
adhoc
|
||||||
|
@ -20,15 +20,15 @@ SimpleLogin backend consists of 2 main components:
|
|||||||
## Install dependencies
|
## Install dependencies
|
||||||
|
|
||||||
The project requires:
|
The project requires:
|
||||||
- Python 3.7+ and [poetry](https://python-poetry.org/) to manage dependencies
|
- Python 3.10 and [rye](https://github.com/astral-sh/rye) to manage dependencies
|
||||||
- Node v10 for front-end.
|
- Node v10 for front-end.
|
||||||
- Postgres 12+
|
- Postgres 13+
|
||||||
|
|
||||||
First, install all dependencies by running the following command.
|
First, install all dependencies by running the following command.
|
||||||
Feel free to use `virtualenv` or similar tools to isolate development environment.
|
Feel free to use `virtualenv` or similar tools to isolate development environment.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry install
|
rye sync
|
||||||
```
|
```
|
||||||
|
|
||||||
On Mac, sometimes you might need to install some other packages via `brew`:
|
On Mac, sometimes you might need to install some other packages via `brew`:
|
||||||
@ -55,7 +55,7 @@ brew install -s re2 pybind11
|
|||||||
We use pre-commit to run all our linting and static analysis checks. Please run
|
We use pre-commit to run all our linting and static analysis checks. Please run
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run pre-commit install
|
rye run pre-commit install
|
||||||
```
|
```
|
||||||
|
|
||||||
To install it in your development environment.
|
To install it in your development environment.
|
||||||
@ -160,25 +160,25 @@ Here are the small sum-ups of the directory structures and their roles:
|
|||||||
The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run
|
The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run
|
||||||
|
|
||||||
```
|
```
|
||||||
poetry run ruff format .
|
rye run ruff format .
|
||||||
```
|
```
|
||||||
|
|
||||||
The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by
|
The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run flake8
|
rye run flake8
|
||||||
```
|
```
|
||||||
|
|
||||||
For HTML templates, we use `djlint`. Before creating a pull request, please run
|
For HTML templates, we use `djlint`. Before creating a pull request, please run
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run djlint --check templates
|
rye run djlint --check templates
|
||||||
```
|
```
|
||||||
|
|
||||||
If some files aren't properly formatted, you can format all files with
|
If some files aren't properly formatted, you can format all files with
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
poetry run djlint --reformat .
|
rye run djlint --reformat .
|
||||||
```
|
```
|
||||||
|
|
||||||
## Test sending email
|
## Test sending email
|
||||||
|
@ -541,7 +541,7 @@ exit
|
|||||||
|
|
||||||
Once you've created all your desired login accounts, add these lines to `/simplelogin.env` to disable further registrations:
|
Once you've created all your desired login accounts, add these lines to `/simplelogin.env` to disable further registrations:
|
||||||
|
|
||||||
```
|
```.env
|
||||||
DISABLE_REGISTRATION=1
|
DISABLE_REGISTRATION=1
|
||||||
DISABLE_ONBOARDING=true
|
DISABLE_ONBOARDING=true
|
||||||
```
|
```
|
||||||
|
@ -2,6 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import arrow
|
import arrow
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
from flask_admin.form import SecureForm
|
||||||
from flask_admin.model.template import EndpointLinkRowAction
|
from flask_admin.model.template import EndpointLinkRowAction
|
||||||
from markupsafe import Markup
|
from markupsafe import Markup
|
||||||
|
|
||||||
@ -100,6 +101,7 @@ def _user_upgrade_channel_formatter(view, context, model, name):
|
|||||||
|
|
||||||
|
|
||||||
class UserAdmin(SLModelView):
|
class UserAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["email", "id"]
|
column_searchable_list = ["email", "id"]
|
||||||
column_exclude_list = [
|
column_exclude_list = [
|
||||||
"salt",
|
"salt",
|
||||||
@ -344,6 +346,7 @@ def manual_upgrade(way: str, ids: [int], is_giveaway: bool):
|
|||||||
|
|
||||||
|
|
||||||
class EmailLogAdmin(SLModelView):
|
class EmailLogAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["id"]
|
column_searchable_list = ["id"]
|
||||||
column_filters = ["id", "user.email", "mailbox.email", "contact.website_email"]
|
column_filters = ["id", "user.email", "mailbox.email", "contact.website_email"]
|
||||||
|
|
||||||
@ -352,6 +355,7 @@ class EmailLogAdmin(SLModelView):
|
|||||||
|
|
||||||
|
|
||||||
class AliasAdmin(SLModelView):
|
class AliasAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["id", "user.email", "email", "mailbox.email"]
|
column_searchable_list = ["id", "user.email", "email", "mailbox.email"]
|
||||||
column_filters = ["id", "user.email", "email", "mailbox.email"]
|
column_filters = ["id", "user.email", "email", "mailbox.email"]
|
||||||
|
|
||||||
@ -377,6 +381,7 @@ class AliasAdmin(SLModelView):
|
|||||||
|
|
||||||
|
|
||||||
class MailboxAdmin(SLModelView):
|
class MailboxAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["id", "user.email", "email"]
|
column_searchable_list = ["id", "user.email", "email"]
|
||||||
column_filters = ["id", "user.email", "email"]
|
column_filters = ["id", "user.email", "email"]
|
||||||
|
|
||||||
@ -387,11 +392,13 @@ class MailboxAdmin(SLModelView):
|
|||||||
|
|
||||||
|
|
||||||
class CouponAdmin(SLModelView):
|
class CouponAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
can_edit = False
|
can_edit = False
|
||||||
can_create = True
|
can_create = True
|
||||||
|
|
||||||
|
|
||||||
class ManualSubscriptionAdmin(SLModelView):
|
class ManualSubscriptionAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
can_edit = True
|
can_edit = True
|
||||||
column_searchable_list = ["id", "user.email"]
|
column_searchable_list = ["id", "user.email"]
|
||||||
|
|
||||||
@ -433,12 +440,14 @@ class ManualSubscriptionAdmin(SLModelView):
|
|||||||
|
|
||||||
|
|
||||||
class CustomDomainAdmin(SLModelView):
|
class CustomDomainAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["domain", "user.email", "user.id"]
|
column_searchable_list = ["domain", "user.email", "user.id"]
|
||||||
column_exclude_list = ["ownership_txt_token"]
|
column_exclude_list = ["ownership_txt_token"]
|
||||||
can_edit = False
|
can_edit = False
|
||||||
|
|
||||||
|
|
||||||
class ReferralAdmin(SLModelView):
|
class ReferralAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["id", "user.email", "code", "name"]
|
column_searchable_list = ["id", "user.email", "code", "name"]
|
||||||
column_filters = ["id", "user.email", "code", "name"]
|
column_filters = ["id", "user.email", "code", "name"]
|
||||||
|
|
||||||
@ -467,6 +476,7 @@ def _admin_created_at_formatter(view, context, model, name):
|
|||||||
|
|
||||||
|
|
||||||
class AdminAuditLogAdmin(SLModelView):
|
class AdminAuditLogAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["admin.id", "admin.email", "model_id", "created_at"]
|
column_searchable_list = ["admin.id", "admin.email", "model_id", "created_at"]
|
||||||
column_filters = ["admin.id", "admin.email", "model_id", "created_at"]
|
column_filters = ["admin.id", "admin.email", "model_id", "created_at"]
|
||||||
column_exclude_list = ["id"]
|
column_exclude_list = ["id"]
|
||||||
@ -497,6 +507,7 @@ def _transactionalcomplaint_refused_email_id_formatter(view, context, model, nam
|
|||||||
|
|
||||||
|
|
||||||
class ProviderComplaintAdmin(SLModelView):
|
class ProviderComplaintAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["id", "user.id", "created_at"]
|
column_searchable_list = ["id", "user.id", "created_at"]
|
||||||
column_filters = ["user.id", "state"]
|
column_filters = ["user.id", "state"]
|
||||||
column_hide_backrefs = False
|
column_hide_backrefs = False
|
||||||
@ -567,6 +578,7 @@ def _newsletter_html_formatter(view, context, model: Newsletter, name):
|
|||||||
|
|
||||||
|
|
||||||
class NewsletterAdmin(SLModelView):
|
class NewsletterAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
list_template = "admin/model/newsletter-list.html"
|
list_template = "admin/model/newsletter-list.html"
|
||||||
edit_template = "admin/model/newsletter-edit.html"
|
edit_template = "admin/model/newsletter-edit.html"
|
||||||
edit_modal = False
|
edit_modal = False
|
||||||
@ -648,6 +660,7 @@ class NewsletterAdmin(SLModelView):
|
|||||||
|
|
||||||
|
|
||||||
class NewsletterUserAdmin(SLModelView):
|
class NewsletterUserAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_searchable_list = ["id"]
|
column_searchable_list = ["id"]
|
||||||
column_filters = ["id", "user.email", "newsletter.subject"]
|
column_filters = ["id", "user.email", "newsletter.subject"]
|
||||||
column_exclude_list = ["created_at", "updated_at", "id"]
|
column_exclude_list = ["created_at", "updated_at", "id"]
|
||||||
@ -657,17 +670,20 @@ class NewsletterUserAdmin(SLModelView):
|
|||||||
|
|
||||||
|
|
||||||
class DailyMetricAdmin(SLModelView):
|
class DailyMetricAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_exclude_list = ["created_at", "updated_at", "id"]
|
column_exclude_list = ["created_at", "updated_at", "id"]
|
||||||
|
|
||||||
can_export = True
|
can_export = True
|
||||||
|
|
||||||
|
|
||||||
class MetricAdmin(SLModelView):
|
class MetricAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
column_exclude_list = ["created_at", "updated_at", "id"]
|
column_exclude_list = ["created_at", "updated_at", "id"]
|
||||||
|
|
||||||
can_export = True
|
can_export = True
|
||||||
|
|
||||||
|
|
||||||
class InvalidMailboxDomainAdmin(SLModelView):
|
class InvalidMailboxDomainAdmin(SLModelView):
|
||||||
|
form_base_class = SecureForm
|
||||||
can_create = True
|
can_create = True
|
||||||
can_delete = True
|
can_delete = True
|
||||||
|
@ -64,8 +64,12 @@ def verify_prefix_suffix(
|
|||||||
# SimpleLogin domain case:
|
# SimpleLogin domain case:
|
||||||
# 1) alias_suffix must start with "." and
|
# 1) alias_suffix must start with "." and
|
||||||
# 2) alias_domain_prefix must come from the word list
|
# 2) alias_domain_prefix must come from the word list
|
||||||
|
available_sl_domains = [
|
||||||
|
sl_domain.domain
|
||||||
|
for sl_domain in user.get_sl_domains(alias_options=alias_options)
|
||||||
|
]
|
||||||
if (
|
if (
|
||||||
alias_domain in user.available_sl_domains(alias_options=alias_options)
|
alias_domain in available_sl_domains
|
||||||
and alias_domain not in user_custom_domains
|
and alias_domain not in user_custom_domains
|
||||||
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
|
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
|
||||||
and not config.DISABLE_ALIAS_SUFFIX
|
and not config.DISABLE_ALIAS_SUFFIX
|
||||||
@ -80,9 +84,7 @@ def verify_prefix_suffix(
|
|||||||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if alias_domain not in user.available_sl_domains(
|
if alias_domain not in available_sl_domains:
|
||||||
alias_options=alias_options
|
|
||||||
):
|
|
||||||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from app.api.base import api_bp, require_api_auth
|
|||||||
from app.config import SESSION_COOKIE_NAME
|
from app.config import SESSION_COOKIE_NAME
|
||||||
from app.dashboard.views.index import get_stats
|
from app.dashboard.views.index import get_stats
|
||||||
from app.db import Session
|
from app.db import Session
|
||||||
|
from app.image_validation import detect_image_format, ImageFormat
|
||||||
from app.models import ApiKey, File, PartnerUser, User
|
from app.models import ApiKey, File, PartnerUser, User
|
||||||
from app.proton.utils import get_proton_partner
|
from app.proton.utils import get_proton_partner
|
||||||
from app.session import logout_session
|
from app.session import logout_session
|
||||||
@ -78,7 +79,6 @@ def update_user_info():
|
|||||||
data = request.get_json() or {}
|
data = request.get_json() or {}
|
||||||
|
|
||||||
if "profile_picture" in data:
|
if "profile_picture" in data:
|
||||||
if data["profile_picture"] is None:
|
|
||||||
if user.profile_picture_id:
|
if user.profile_picture_id:
|
||||||
file = user.profile_picture
|
file = user.profile_picture
|
||||||
user.profile_picture_id = None
|
user.profile_picture_id = None
|
||||||
@ -89,6 +89,8 @@ def update_user_info():
|
|||||||
Session.flush()
|
Session.flush()
|
||||||
else:
|
else:
|
||||||
raw_data = base64.decodebytes(data["profile_picture"].encode())
|
raw_data = base64.decodebytes(data["profile_picture"].encode())
|
||||||
|
if detect_image_format(raw_data) == ImageFormat.Unknown:
|
||||||
|
return jsonify(error="Unsupported image format"), 400
|
||||||
file_path = random_string(30)
|
file_path = random_string(30)
|
||||||
file = File.create(user_id=user.id, path=file_path)
|
file = File.create(user_id=user.id, path=file_path)
|
||||||
Session.flush()
|
Session.flush()
|
||||||
|
@ -115,7 +115,8 @@ def register():
|
|||||||
|
|
||||||
|
|
||||||
def send_activation_email(user, next_url):
|
def send_activation_email(user, next_url):
|
||||||
# the activation code is valid for 1h
|
# the activation code is valid for 1h and delete all previous codes
|
||||||
|
Session.query(ActivationCode).filter(ActivationCode.user_id == user.id).delete()
|
||||||
activation = ActivationCode.create(user_id=user.id, code=random_string(30))
|
activation = ActivationCode.create(user_id=user.id, code=random_string(30))
|
||||||
Session.commit()
|
Session.commit()
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from flask_wtf import FlaskForm
|
|||||||
from flask_wtf.file import FileField
|
from flask_wtf.file import FileField
|
||||||
from wtforms import StringField, validators
|
from wtforms import StringField, validators
|
||||||
|
|
||||||
from app import s3
|
from app import s3, user_settings
|
||||||
from app.config import (
|
from app.config import (
|
||||||
FIRST_ALIAS_DOMAIN,
|
FIRST_ALIAS_DOMAIN,
|
||||||
ALIAS_RANDOM_SUFFIX_LENGTH,
|
ALIAS_RANDOM_SUFFIX_LENGTH,
|
||||||
@ -31,12 +31,10 @@ from app.models import (
|
|||||||
PlanEnum,
|
PlanEnum,
|
||||||
File,
|
File,
|
||||||
EmailChange,
|
EmailChange,
|
||||||
CustomDomain,
|
|
||||||
AliasGeneratorEnum,
|
AliasGeneratorEnum,
|
||||||
AliasSuffixEnum,
|
AliasSuffixEnum,
|
||||||
ManualSubscription,
|
ManualSubscription,
|
||||||
SenderFormatEnum,
|
SenderFormatEnum,
|
||||||
SLDomain,
|
|
||||||
CoinbaseSubscription,
|
CoinbaseSubscription,
|
||||||
AppleSubscription,
|
AppleSubscription,
|
||||||
PartnerUser,
|
PartnerUser,
|
||||||
@ -166,39 +164,12 @@ def setting():
|
|||||||
return redirect(url_for("dashboard.setting"))
|
return redirect(url_for("dashboard.setting"))
|
||||||
elif request.form.get("form-name") == "change-random-alias-default-domain":
|
elif request.form.get("form-name") == "change-random-alias-default-domain":
|
||||||
default_domain = request.form.get("random-alias-default-domain")
|
default_domain = request.form.get("random-alias-default-domain")
|
||||||
|
try:
|
||||||
if default_domain:
|
user_settings.set_default_alias_domain(current_user, default_domain)
|
||||||
sl_domain: SLDomain = SLDomain.get_by(domain=default_domain)
|
except user_settings.CannotSetAlias as e:
|
||||||
if sl_domain:
|
flash(e.msg, "error")
|
||||||
if sl_domain.premium_only and not current_user.is_premium():
|
|
||||||
flash("You cannot use this domain", "error")
|
|
||||||
return redirect(url_for("dashboard.setting"))
|
return redirect(url_for("dashboard.setting"))
|
||||||
|
|
||||||
current_user.default_alias_public_domain_id = sl_domain.id
|
|
||||||
current_user.default_alias_custom_domain_id = None
|
|
||||||
else:
|
|
||||||
custom_domain = CustomDomain.get_by(domain=default_domain)
|
|
||||||
if custom_domain:
|
|
||||||
# sanity check
|
|
||||||
if (
|
|
||||||
custom_domain.user_id != current_user.id
|
|
||||||
or not custom_domain.verified
|
|
||||||
):
|
|
||||||
LOG.w(
|
|
||||||
"%s cannot use domain %s", current_user, custom_domain
|
|
||||||
)
|
|
||||||
flash(f"Domain {default_domain} can't be used", "error")
|
|
||||||
return redirect(request.url)
|
|
||||||
else:
|
|
||||||
current_user.default_alias_custom_domain_id = (
|
|
||||||
custom_domain.id
|
|
||||||
)
|
|
||||||
current_user.default_alias_public_domain_id = None
|
|
||||||
|
|
||||||
else:
|
|
||||||
current_user.default_alias_custom_domain_id = None
|
|
||||||
current_user.default_alias_public_domain_id = None
|
|
||||||
|
|
||||||
Session.commit()
|
Session.commit()
|
||||||
flash("Your preference has been updated", "success")
|
flash("Your preference has been updated", "success")
|
||||||
return redirect(url_for("dashboard.setting"))
|
return redirect(url_for("dashboard.setting"))
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from flask import request, render_template, redirect, url_for, flash
|
from flask import request, render_template, redirect, url_for, flash
|
||||||
from flask_login import current_user, login_required
|
from flask_login import current_user, login_required
|
||||||
@ -11,6 +12,7 @@ from app.config import ADMIN_EMAIL
|
|||||||
from app.db import Session
|
from app.db import Session
|
||||||
from app.developer.base import developer_bp
|
from app.developer.base import developer_bp
|
||||||
from app.email_utils import send_email
|
from app.email_utils import send_email
|
||||||
|
from app.image_validation import detect_image_format, ImageFormat
|
||||||
from app.log import LOG
|
from app.log import LOG
|
||||||
from app.models import Client, RedirectUri, File, Referral
|
from app.models import Client, RedirectUri, File, Referral
|
||||||
from app.utils import random_string
|
from app.utils import random_string
|
||||||
@ -46,16 +48,25 @@ def client_detail(client_id):
|
|||||||
approval_form.description.data = client.description
|
approval_form.description.data = client.description
|
||||||
|
|
||||||
if action == "edit" and form.validate_on_submit():
|
if action == "edit" and form.validate_on_submit():
|
||||||
|
parsed_url = urlparse(form.url.data)
|
||||||
|
if parsed_url.scheme != "https":
|
||||||
|
flash("Only https urls are allowed", "error")
|
||||||
|
return redirect(url_for("developer.index"))
|
||||||
client.name = form.name.data
|
client.name = form.name.data
|
||||||
client.home_url = form.url.data
|
client.home_url = form.url.data
|
||||||
|
|
||||||
if form.icon.data:
|
if form.icon.data:
|
||||||
# todo: remove current icon if any
|
icon_data = form.icon.data.read(10240)
|
||||||
# todo: handle remove icon
|
if detect_image_format(icon_data) == ImageFormat.Unknown:
|
||||||
|
flash("Unknown file format", "warning")
|
||||||
|
return redirect(url_for("developer.index"))
|
||||||
|
if client.icon:
|
||||||
|
s3.delete(client.icon_id)
|
||||||
|
File.delete(client.icon)
|
||||||
file_path = random_string(30)
|
file_path = random_string(30)
|
||||||
file = File.create(path=file_path, user_id=client.user_id)
|
file = File.create(path=file_path, user_id=client.user_id)
|
||||||
|
|
||||||
s3.upload_from_bytesio(file_path, BytesIO(form.icon.data.read()))
|
s3.upload_from_bytesio(file_path, BytesIO(icon_data))
|
||||||
|
|
||||||
Session.flush()
|
Session.flush()
|
||||||
LOG.d("upload file %s to s3", file)
|
LOG.d("upload file %s to s3", file)
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from flask import render_template, redirect, url_for, flash
|
from flask import render_template, redirect, url_for, flash
|
||||||
from flask_login import current_user, login_required
|
from flask_login import current_user, login_required
|
||||||
from flask_wtf import FlaskForm
|
from flask_wtf import FlaskForm
|
||||||
@ -20,6 +22,10 @@ def new_client():
|
|||||||
|
|
||||||
if form.validate_on_submit():
|
if form.validate_on_submit():
|
||||||
client = Client.create_new(form.name.data, current_user.id)
|
client = Client.create_new(form.name.data, current_user.id)
|
||||||
|
parsed_url = urlparse(form.url.data)
|
||||||
|
if parsed_url.scheme != "https":
|
||||||
|
flash("Only https urls are allowed", "error")
|
||||||
|
return redirect(url_for("developer.new_client"))
|
||||||
client.home_url = form.url.data
|
client.home_url = form.url.data
|
||||||
Session.commit()
|
Session.commit()
|
||||||
|
|
||||||
|
@ -985,8 +985,8 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||||||
- the domain
|
- the domain
|
||||||
"""
|
"""
|
||||||
res = []
|
res = []
|
||||||
for domain in self.available_sl_domains(alias_options=alias_options):
|
for domain in self.get_sl_domains(alias_options=alias_options):
|
||||||
res.append((True, domain))
|
res.append((True, domain.domain))
|
||||||
|
|
||||||
for custom_domain in self.verified_custom_domains():
|
for custom_domain in self.verified_custom_domains():
|
||||||
res.append((False, custom_domain.domain))
|
res.append((False, custom_domain.domain))
|
||||||
@ -1128,7 +1128,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||||||
- Verified custom domains
|
- Verified custom domains
|
||||||
|
|
||||||
"""
|
"""
|
||||||
domains = self.available_sl_domains(alias_options=alias_options)
|
domains = [
|
||||||
|
sl_domain.domain
|
||||||
|
for sl_domain in self.get_sl_domains(alias_options=alias_options)
|
||||||
|
]
|
||||||
|
|
||||||
for custom_domain in self.verified_custom_domains():
|
for custom_domain in self.verified_custom_domains():
|
||||||
domains.append(custom_domain.domain)
|
domains.append(custom_domain.domain)
|
||||||
@ -2483,7 +2486,7 @@ class CustomDomain(Base, ModelMixin):
|
|||||||
return sorted(self._auto_create_rules, key=lambda rule: rule.order)
|
return sorted(self._auto_create_rules, key=lambda rule: rule.order)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Custom Domain {self.domain}>"
|
return f"<Custom Domain {self.id} {self.domain}>"
|
||||||
|
|
||||||
|
|
||||||
class AutoCreateRule(Base, ModelMixin):
|
class AutoCreateRule(Base, ModelMixin):
|
||||||
@ -3114,7 +3117,7 @@ class SLDomain(Base, ModelMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<SLDomain {self.domain} {'Premium' if self.premium_only else 'Free'}"
|
return f"<SLDomain {self.id} {self.domain} {'Premium' if self.premium_only else 'Free'}>"
|
||||||
|
|
||||||
|
|
||||||
class Monitoring(Base, ModelMixin):
|
class Monitoring(Base, ModelMixin):
|
||||||
@ -3484,6 +3487,7 @@ class AdminAuditLog(Base):
|
|||||||
action=AuditLogActionEnum.stop_trial.value,
|
action=AuditLogActionEnum.stop_trial.value,
|
||||||
model="User",
|
model="User",
|
||||||
model_id=user_id,
|
model_id=user_id,
|
||||||
|
data={},
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -3729,6 +3733,7 @@ class SyncEvent(Base, ModelMixin):
|
|||||||
taken_time = sa.Column(
|
taken_time = sa.Column(
|
||||||
ArrowType, default=None, nullable=True, server_default=None, index=True
|
ArrowType, default=None, nullable=True, server_default=None, index=True
|
||||||
)
|
)
|
||||||
|
retry_count = sa.Column(sa.Integer, default=0, nullable=False, server_default="0")
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
sa.Index("ix_sync_event_created_at", "created_at"),
|
sa.Index("ix_sync_event_created_at", "created_at"),
|
||||||
@ -3750,7 +3755,7 @@ class SyncEvent(Base, ModelMixin):
|
|||||||
return res.rowcount > 0
|
return res.rowcount > 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_dead_letter(cls, older_than: Arrow) -> [SyncEvent]:
|
def get_dead_letter(cls, older_than: Arrow, max_retries: int) -> [SyncEvent]:
|
||||||
return (
|
return (
|
||||||
SyncEvent.filter(
|
SyncEvent.filter(
|
||||||
(
|
(
|
||||||
@ -3763,6 +3768,7 @@ class SyncEvent(Base, ModelMixin):
|
|||||||
& (SyncEvent.created_at < older_than)
|
& (SyncEvent.created_at < older_than)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
& (SyncEvent.retry_count < max_retries)
|
||||||
)
|
)
|
||||||
.order_by(SyncEvent.id)
|
.order_by(SyncEvent.id)
|
||||||
.limit(100)
|
.limit(100)
|
||||||
|
47
app/app/user_settings.py
Normal file
47
app/app/user_settings.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.db import Session
|
||||||
|
from app.log import LOG
|
||||||
|
from app.models import User, SLDomain, CustomDomain
|
||||||
|
|
||||||
|
|
||||||
|
class CannotSetAlias(Exception):
|
||||||
|
def __init__(self, msg: str):
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_alias_domain(user: User, domain_name: Optional[str]):
|
||||||
|
if domain_name is None:
|
||||||
|
LOG.i(f"User {user} has set no domain as default domain")
|
||||||
|
user.default_alias_public_domain_id = None
|
||||||
|
user.default_alias_custom_domain_id = None
|
||||||
|
Session.flush()
|
||||||
|
return
|
||||||
|
sl_domain: SLDomain = SLDomain.get_by(domain=domain_name)
|
||||||
|
if sl_domain:
|
||||||
|
if sl_domain.hidden:
|
||||||
|
LOG.i(f"User {user} has tried to set up a hidden domain as default domain")
|
||||||
|
raise CannotSetAlias("Domain does not exist")
|
||||||
|
if sl_domain.premium_only and not user.is_premium():
|
||||||
|
LOG.i(f"User {user} has tried to set up a premium domain as default domain")
|
||||||
|
raise CannotSetAlias("You cannot use this domain")
|
||||||
|
LOG.i(f"User {user} has set public {sl_domain} as default domain")
|
||||||
|
user.default_alias_public_domain_id = sl_domain.id
|
||||||
|
user.default_alias_custom_domain_id = None
|
||||||
|
Session.flush()
|
||||||
|
return
|
||||||
|
custom_domain = CustomDomain.get_by(domain=domain_name)
|
||||||
|
if not custom_domain:
|
||||||
|
LOG.i(
|
||||||
|
f"User {user} has tried to set up an non existing domain as default domain"
|
||||||
|
)
|
||||||
|
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
|
||||||
|
if custom_domain.user_id != user.id or not custom_domain.verified:
|
||||||
|
LOG.i(
|
||||||
|
f"User {user} has tried to set domain {custom_domain} as default domain that does not belong to the user or that is not verified"
|
||||||
|
)
|
||||||
|
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
|
||||||
|
LOG.i(f"User {user} has set custom {custom_domain} as default domain")
|
||||||
|
user.default_alias_public_domain_id = None
|
||||||
|
user.default_alias_custom_domain_id = custom_domain.id
|
||||||
|
Session.flush()
|
@ -262,7 +262,8 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
|
|||||||
|
|
||||||
Session.commit()
|
Session.commit()
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
# No need to manually rollback, as IntegrityError already rolls back
|
# If the tx has been rolled back, the connection is borked. Force close to try to get a new one and start fresh
|
||||||
|
Session.close()
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB"
|
f"Contact with email {contact_email} for alias_id {alias_id} already existed, fetching from DB"
|
||||||
)
|
)
|
||||||
|
@ -8,6 +8,8 @@ from events.runner import Runner
|
|||||||
from events.event_source import DeadLetterEventSource, PostgresEventSource
|
from events.event_source import DeadLetterEventSource, PostgresEventSource
|
||||||
from events.event_sink import ConsoleEventSink, HttpEventSink
|
from events.event_sink import ConsoleEventSink, HttpEventSink
|
||||||
|
|
||||||
|
_DEFAULT_MAX_RETRIES = 100
|
||||||
|
|
||||||
|
|
||||||
class Mode(Enum):
|
class Mode(Enum):
|
||||||
DEAD_LETTER = "dead_letter"
|
DEAD_LETTER = "dead_letter"
|
||||||
@ -23,10 +25,10 @@ class Mode(Enum):
|
|||||||
raise ValueError(f"Invalid mode: {value}")
|
raise ValueError(f"Invalid mode: {value}")
|
||||||
|
|
||||||
|
|
||||||
def main(mode: Mode, dry_run: bool):
|
def main(mode: Mode, dry_run: bool, max_retries: int):
|
||||||
if mode == Mode.DEAD_LETTER:
|
if mode == Mode.DEAD_LETTER:
|
||||||
LOG.i("Using DeadLetterEventSource")
|
LOG.i("Using DeadLetterEventSource")
|
||||||
source = DeadLetterEventSource()
|
source = DeadLetterEventSource(max_retries)
|
||||||
elif mode == Mode.LISTENER:
|
elif mode == Mode.LISTENER:
|
||||||
LOG.i("Using PostgresEventSource")
|
LOG.i("Using PostgresEventSource")
|
||||||
source = PostgresEventSource(DB_URI)
|
source = PostgresEventSource(DB_URI)
|
||||||
@ -51,6 +53,13 @@ def args():
|
|||||||
help="Mode to run",
|
help="Mode to run",
|
||||||
choices=[Mode.DEAD_LETTER.value, Mode.LISTENER.value],
|
choices=[Mode.DEAD_LETTER.value, Mode.LISTENER.value],
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"max_retries",
|
||||||
|
help="Max retries to consider an event as error and not try to process it again",
|
||||||
|
type=int,
|
||||||
|
nargs="?",
|
||||||
|
default=_DEFAULT_MAX_RETRIES,
|
||||||
|
)
|
||||||
parser.add_argument("--dry-run", help="Dry run mode", action="store_true")
|
parser.add_argument("--dry-run", help="Dry run mode", action="store_true")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -61,4 +70,8 @@ if __name__ == "__main__":
|
|||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
args = args()
|
args = args()
|
||||||
main(Mode.from_str(args.mode), args.dry_run)
|
main(
|
||||||
|
mode=Mode.from_str(args.mode),
|
||||||
|
dry_run=args.dry_run,
|
||||||
|
max_retries=args.max_retries,
|
||||||
|
)
|
||||||
|
@ -4,6 +4,8 @@ import psycopg2
|
|||||||
import select
|
import select
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from app.db import Session
|
||||||
from app.log import LOG
|
from app.log import LOG
|
||||||
from app.models import SyncEvent
|
from app.models import SyncEvent
|
||||||
from app.events.event_dispatcher import NOTIFICATION_CHANNEL
|
from app.events.event_dispatcher import NOTIFICATION_CHANNEL
|
||||||
@ -66,6 +68,7 @@ class PostgresEventSource(EventSource):
|
|||||||
LOG.info(f"Could not find event with id={notify.payload}")
|
LOG.info(f"Could not find event with id={notify.payload}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.warn(f"Error getting event: {e}")
|
LOG.warn(f"Error getting event: {e}")
|
||||||
|
Session.close() # Ensure we get a new connection and we don't leave a dangling tx
|
||||||
|
|
||||||
def __connect(self):
|
def __connect(self):
|
||||||
self.__connection = psycopg2.connect(self.__connection_string)
|
self.__connection = psycopg2.connect(self.__connection_string)
|
||||||
@ -76,6 +79,9 @@ class PostgresEventSource(EventSource):
|
|||||||
|
|
||||||
|
|
||||||
class DeadLetterEventSource(EventSource):
|
class DeadLetterEventSource(EventSource):
|
||||||
|
def __init__(self, max_retries: int):
|
||||||
|
self.__max_retries = max_retries
|
||||||
|
|
||||||
@newrelic.agent.background_task()
|
@newrelic.agent.background_task()
|
||||||
def run(self, on_event: Callable[[SyncEvent], NoReturn]):
|
def run(self, on_event: Callable[[SyncEvent], NoReturn]):
|
||||||
while True:
|
while True:
|
||||||
@ -83,7 +89,9 @@ class DeadLetterEventSource(EventSource):
|
|||||||
threshold = arrow.utcnow().shift(
|
threshold = arrow.utcnow().shift(
|
||||||
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
|
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
|
||||||
)
|
)
|
||||||
events = SyncEvent.get_dead_letter(older_than=threshold)
|
events = SyncEvent.get_dead_letter(
|
||||||
|
older_than=threshold, max_retries=self.__max_retries
|
||||||
|
)
|
||||||
if events:
|
if events:
|
||||||
LOG.info(f"Got {len(events)} dead letter events")
|
LOG.info(f"Got {len(events)} dead letter events")
|
||||||
if events:
|
if events:
|
||||||
@ -92,7 +100,8 @@ class DeadLetterEventSource(EventSource):
|
|||||||
)
|
)
|
||||||
for event in events:
|
for event in events:
|
||||||
on_event(event)
|
on_event(event)
|
||||||
else:
|
Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
|
||||||
|
if not events:
|
||||||
LOG.debug("No dead letter events")
|
LOG.debug("No dead letter events")
|
||||||
sleep(_DEAD_LETTER_INTERVAL_SECONDS)
|
sleep(_DEAD_LETTER_INTERVAL_SECONDS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2,6 +2,7 @@ import arrow
|
|||||||
import newrelic.agent
|
import newrelic.agent
|
||||||
|
|
||||||
from app.log import LOG
|
from app.log import LOG
|
||||||
|
from app.db import Session
|
||||||
from app.models import SyncEvent
|
from app.models import SyncEvent
|
||||||
from events.event_sink import EventSink
|
from events.event_sink import EventSink
|
||||||
from events.event_source import EventSource
|
from events.event_source import EventSource
|
||||||
@ -37,6 +38,9 @@ class Runner:
|
|||||||
"Custom/sync_event_elapsed_time",
|
"Custom/sync_event_elapsed_time",
|
||||||
time_between_taken_and_created.total_seconds(),
|
time_between_taken_and_created.total_seconds(),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
event.retry_count = event.retry_count + 1
|
||||||
|
Session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.warn(f"Exception processing event [id={event.id}]: {e}")
|
LOG.warn(f"Exception processing event [id={event.id}]: {e}")
|
||||||
newrelic.agent.record_custom_metric("Custom/sync_event_failed", 1)
|
newrelic.agent.record_custom_metric("Custom/sync_event_failed", 1)
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
"""add retry count to sync event
|
||||||
|
|
||||||
|
Revision ID: 56d08955fcab
|
||||||
|
Revises: d608b8e48082
|
||||||
|
Create Date: 2024-07-19 08:21:19.979973
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '56d08955fcab'
|
||||||
|
down_revision = 'd608b8e48082'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('sync_event', sa.Column('retry_count', sa.Integer(), server_default='0', nullable=False, default=0))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('sync_event', 'retry_count')
|
||||||
|
# ### end Alembic commands ###
|
@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres
|
|||||||
sleep 3
|
sleep 3
|
||||||
|
|
||||||
# upgrade the DB to the latest stage and
|
# upgrade the DB to the latest stage and
|
||||||
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head
|
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head
|
||||||
|
|
||||||
# generate the migration script.
|
# generate the migration script.
|
||||||
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@
|
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@
|
||||||
|
|
||||||
# remove the db
|
# remove the db
|
||||||
docker rm -f ${container_name}
|
docker rm -f ${container_name}
|
||||||
|
@ -3,5 +3,5 @@
|
|||||||
export DB_URI=postgresql://myuser:mypassword@localhost:15432/simplelogin
|
export DB_URI=postgresql://myuser:mypassword@localhost:15432/simplelogin
|
||||||
echo 'drop schema public cascade; create schema public;' | psql $DB_URI
|
echo 'drop schema public cascade; create schema public;' | psql $DB_URI
|
||||||
|
|
||||||
poetry run alembic upgrade head
|
rye run alembic upgrade head
|
||||||
poetry run flask dummy-data
|
rye run flask dummy-data
|
||||||
|
@ -3,4 +3,4 @@
|
|||||||
export DB_URI=postgresql://myuser:mypassword@localhost:15432/test
|
export DB_URI=postgresql://myuser:mypassword@localhost:15432/test
|
||||||
echo 'drop schema public cascade; create schema public;' | psql $DB_URI
|
echo 'drop schema public cascade; create schema public;' | psql $DB_URI
|
||||||
|
|
||||||
poetry run alembic upgrade head
|
rye run alembic upgrade head
|
||||||
|
@ -283,6 +283,7 @@ def set_index_page(app):
|
|||||||
and not request.path.startswith("/git")
|
and not request.path.startswith("/git")
|
||||||
and not request.path.startswith("/favicon.ico")
|
and not request.path.startswith("/favicon.ico")
|
||||||
):
|
):
|
||||||
|
start_time = g.start_time or time.time()
|
||||||
LOG.d(
|
LOG.d(
|
||||||
"%s %s %s %s %s, takes %s",
|
"%s %s %s %s %s, takes %s",
|
||||||
request.remote_addr,
|
request.remote_addr,
|
||||||
@ -290,7 +291,7 @@ def set_index_page(app):
|
|||||||
request.path,
|
request.path,
|
||||||
request.args,
|
request.args,
|
||||||
res.status_code,
|
res.status_code,
|
||||||
time.time() - g.start_time,
|
time.time() - start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
@ -44,6 +44,9 @@ def test_update_settings_alias_generator(flask_client):
|
|||||||
|
|
||||||
def test_update_settings_random_alias_default_domain(flask_client):
|
def test_update_settings_random_alias_default_domain(flask_client):
|
||||||
user = login(flask_client)
|
user = login(flask_client)
|
||||||
|
custom_domain = CustomDomain.create(
|
||||||
|
domain=random_domain(), verified=True, user_id=user.id, flush=True
|
||||||
|
)
|
||||||
assert user.default_random_alias_domain() == "sl.local"
|
assert user.default_random_alias_domain() == "sl.local"
|
||||||
|
|
||||||
r = flask_client.patch(
|
r = flask_client.patch(
|
||||||
@ -57,6 +60,12 @@ def test_update_settings_random_alias_default_domain(flask_client):
|
|||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert user.default_random_alias_domain() == "d1.test"
|
assert user.default_random_alias_domain() == "d1.test"
|
||||||
|
|
||||||
|
r = flask_client.patch(
|
||||||
|
"/api/setting", json={"random_alias_default_domain": custom_domain.domain}
|
||||||
|
)
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert user.default_random_alias_domain() == custom_domain.domain
|
||||||
|
|
||||||
|
|
||||||
def test_update_settings_sender_format(flask_client):
|
def test_update_settings_sender_format(flask_client):
|
||||||
user = login(flask_client)
|
user = login(flask_client)
|
||||||
|
@ -23,7 +23,7 @@ _MAX_PER_MINUTE = 3
|
|||||||
_ENDPOINT,
|
_ENDPOINT,
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
)
|
)
|
||||||
@limiter.limit(f"{_MAX_PER_MINUTE}/minute")
|
@limiter.limit(f"{_MAX_PER_MINUTE}/hour")
|
||||||
def rate_limited_endpoint_1():
|
def rate_limited_endpoint_1():
|
||||||
return "Working", HTTPStatus.OK
|
return "Working", HTTPStatus.OK
|
||||||
|
|
||||||
|
@ -365,12 +365,22 @@ def test_sync_event_dead_letter():
|
|||||||
commit=True,
|
commit=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# create event with too many retries
|
||||||
|
max_retries = 5
|
||||||
|
e5 = SyncEvent.create(
|
||||||
|
content=b"content",
|
||||||
|
retry_count=max_retries + 1,
|
||||||
|
created_at=arrow.now(),
|
||||||
|
commit=True,
|
||||||
|
)
|
||||||
|
|
||||||
# get dead letter events
|
# get dead letter events
|
||||||
dead_letter_events = SyncEvent.get_dead_letter(
|
dead_letter_events = SyncEvent.get_dead_letter(
|
||||||
older_than=arrow.now().shift(minutes=-10)
|
older_than=arrow.now().shift(minutes=-10), max_retries=max_retries
|
||||||
)
|
)
|
||||||
assert len(dead_letter_events) == 2
|
assert len(dead_letter_events) == 2
|
||||||
assert e1 in dead_letter_events
|
assert e1 in dead_letter_events
|
||||||
assert e2 in dead_letter_events
|
assert e2 in dead_letter_events
|
||||||
assert e3 not in dead_letter_events
|
assert e3 not in dead_letter_events
|
||||||
assert e4 not in dead_letter_events
|
assert e4 not in dead_letter_events
|
||||||
|
assert e5 not in dead_letter_events
|
||||||
|
0
app/tests/user_settings/__init__.py
Normal file
0
app/tests/user_settings/__init__.py
Normal file
128
app/tests/user_settings/test_set_default_alias_domain.py
Normal file
128
app/tests/user_settings/test_set_default_alias_domain.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from app import user_settings
|
||||||
|
from app.db import Session
|
||||||
|
from app.models import User, CustomDomain, SLDomain
|
||||||
|
from tests.utils import random_token, create_new_user
|
||||||
|
|
||||||
|
user_id: int = 0
|
||||||
|
custom_domain_name: str = ""
|
||||||
|
sl_domain_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def setup_module():
|
||||||
|
global user_id, custom_domain_name, sl_domain_name
|
||||||
|
user = create_new_user()
|
||||||
|
user.trial_end = None
|
||||||
|
user_id = user.id
|
||||||
|
custom_domain_name = CustomDomain.create(
|
||||||
|
user_id=user_id,
|
||||||
|
catch_all=True,
|
||||||
|
domain=random_token() + ".com",
|
||||||
|
verified=True,
|
||||||
|
flush=True,
|
||||||
|
).domain
|
||||||
|
sl_domain_name = SLDomain.create(
|
||||||
|
domain=random_token() + ".com",
|
||||||
|
premium_only=False,
|
||||||
|
flush=True,
|
||||||
|
order=5,
|
||||||
|
hidden=False,
|
||||||
|
).domain
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_default_no_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.default_alias_public_domain_id = SLDomain.get_by(domain=sl_domain_name).id
|
||||||
|
user.default_alias_private_domain_id = CustomDomain.get_by(
|
||||||
|
domain=custom_domain_name
|
||||||
|
).id
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_domain(user, None)
|
||||||
|
assert user.default_alias_public_domain_id is None
|
||||||
|
assert user.default_alias_custom_domain_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_premium_sl_domain_with_non_premium_user():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = False
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.premium_only = True
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_domain(user, sl_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_hidden_sl_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.hidden = True
|
||||||
|
domain.premium_only = False
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_domain(user, sl_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_sl_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = False
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.hidden = False
|
||||||
|
domain.premium_only = False
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_domain(user, sl_domain_name)
|
||||||
|
assert user.default_alias_public_domain_id == domain.id
|
||||||
|
assert user.default_alias_custom_domain_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_sl_premium_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.hidden = False
|
||||||
|
domain.premium_only = True
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_domain(user, sl_domain_name)
|
||||||
|
assert user.default_alias_public_domain_id == domain.id
|
||||||
|
assert user.default_alias_custom_domain_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_other_user_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
other_user_domain_name = CustomDomain.create(
|
||||||
|
user_id=create_new_user().id,
|
||||||
|
catch_all=True,
|
||||||
|
domain=random_token() + ".com",
|
||||||
|
verified=True,
|
||||||
|
).domain
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_domain(user, other_user_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_unverified_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
domain = CustomDomain.get_by(domain=custom_domain_name)
|
||||||
|
domain.verified = False
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_domain(user, custom_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
domain = CustomDomain.get_by(domain=custom_domain_name)
|
||||||
|
domain.verified = True
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_domain(user, custom_domain_name)
|
||||||
|
assert user.default_alias_public_domain_id is None
|
||||||
|
assert user.default_alias_custom_domain_id == domain.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_invalid_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_domain(user, "invalid_nop" + random_token())
|
Loading…
x
Reference in New Issue
Block a user