Compare commits

...

7 Commits

Author SHA1 Message Date
24ba25ab6a 4.42.2 2024-04-10 17:23:11 +01:00
78184eeae4 4.42.1 2024-03-26 12:00:08 +00:00
c111fbe8e1 4.42.0 2024-03-19 12:00:09 +00:00
d5981588e4 4.41.2 2024-03-15 12:00:08 +00:00
6af1c2ccf4 Merge pull request 'Correct docker package name' (#2) from fix-package-name-in-gitea-actions into main
Reviewed-on: #2
2024-03-14 15:47:01 +00:00
76664f6e4c Correct docker package name 2024-03-14 15:46:44 +00:00
f7125618c4 4.41.0 2024-03-14 12:00:08 +00:00
52 changed files with 1238 additions and 123 deletions

View File

@ -5,7 +5,7 @@ on:
- '*' - '*'
env: env:
CONTAINER_NAME: git.mrmeeb.stream/mrmeeb/simple-login-dev CONTAINER_NAME: git.mrmeeb.stream/mrmeeb/simple-login
TEA_VERSION: 0.9.2 TEA_VERSION: 0.9.2
jobs: jobs:

View File

@ -1,7 +1,6 @@
name: Test and lint name: Test and lint
on: on: [push, pull_request]
push:
jobs: jobs:
lint: lint:
@ -139,6 +138,12 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Create Sentry release - name: Create Sentry release
uses: getsentry/action-release@v1 uses: getsentry/action-release@v1
env: env:
@ -158,6 +163,7 @@ jobs:
uses: docker/build-push-action@v3 uses: docker/build-push-action@v3
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}

View File

@ -151,10 +151,10 @@ Here are the small sum-ups of the directory structures and their roles:
## Pull request ## Pull request
The code is formatted using https://github.com/psf/black, to format the code, simply run The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run
``` ```
poetry run black . poetry run ruff format .
``` ```
The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by

View File

@ -308,28 +308,29 @@ def delete_alias(alias: Alias, user: User):
Delete an alias and add it to either global or domain trash Delete an alias and add it to either global or domain trash
Should be used instead of Alias.delete, DomainDeletedAlias.create, DeletedAlias.create Should be used instead of Alias.delete, DomainDeletedAlias.create, DeletedAlias.create
""" """
# save deleted alias to either global or domain trash LOG.i(f"User {user} has deleted alias {alias}")
# save deleted alias to either global or domain tra
if alias.custom_domain_id: if alias.custom_domain_id:
if not DomainDeletedAlias.get_by( if not DomainDeletedAlias.get_by(
email=alias.email, domain_id=alias.custom_domain_id email=alias.email, domain_id=alias.custom_domain_id
): ):
LOG.d("add %s to domain %s trash", alias, alias.custom_domain_id) domain_deleted_alias = DomainDeletedAlias(
Session.add( user_id=user.id,
DomainDeletedAlias( email=alias.email,
user_id=user.id, domain_id=alias.custom_domain_id,
email=alias.email,
domain_id=alias.custom_domain_id,
)
) )
Session.add(domain_deleted_alias)
Session.commit() Session.commit()
LOG.i(
f"Moving {alias} to domain {alias.custom_domain_id} trash {domain_deleted_alias}"
)
else: else:
if not DeletedAlias.get_by(email=alias.email): if not DeletedAlias.get_by(email=alias.email):
LOG.d("add %s to global trash", alias) deleted_alias = DeletedAlias(email=alias.email)
Session.add(DeletedAlias(email=alias.email)) Session.add(deleted_alias)
Session.commit() Session.commit()
LOG.i(f"Moving {alias} to global trash {deleted_alias}")
LOG.i("delete alias %s", alias)
Alias.filter(Alias.id == alias.id).delete() Alias.filter(Alias.id == alias.id).delete()
Session.commit() Session.commit()

View File

@ -16,6 +16,7 @@ from .views import (
social, social,
recovery, recovery,
api_to_cookie, api_to_cookie,
oidc,
) )
__all__ = [ __all__ = [
@ -36,4 +37,5 @@ __all__ = [
"social", "social",
"recovery", "recovery",
"api_to_cookie", "api_to_cookie",
"oidc",
] ]

View File

@ -5,7 +5,7 @@ from wtforms import StringField, validators
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login from app.auth.views.login_utils import after_login
from app.config import CONNECT_WITH_PROTON from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON, OIDC_CLIENT_ID
from app.events.auth_event import LoginEvent from app.events.auth_event import LoginEvent
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
@ -77,4 +77,6 @@ def login():
next_url=next_url, next_url=next_url,
show_resend_activation=show_resend_activation, show_resend_activation=show_resend_activation,
connect_with_proton=CONNECT_WITH_PROTON, connect_with_proton=CONNECT_WITH_PROTON,
connect_with_oidc=OIDC_CLIENT_ID is not None,
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
) )

131
app/app/auth/views/oidc.py Normal file
View File

@ -0,0 +1,131 @@
from flask import request, session, redirect, flash, url_for
from requests_oauthlib import OAuth2Session
from app import config
from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login
from app.config import (
URL,
OIDC_AUTHORIZATION_URL,
OIDC_USER_INFO_URL,
OIDC_TOKEN_URL,
OIDC_SCOPES,
OIDC_NAME_FIELD,
)
from app.db import Session
from app.email_utils import send_welcome_email
from app.log import LOG
from app.models import User, SocialAuth
from app.utils import encode_url, sanitize_email, sanitize_next_url
# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri
# when served behind nginx, the redirect_uri is localhost... and not the real url
_redirect_uri = URL + "/auth/oidc/callback"
SESSION_STATE_KEY = "oauth_state"
@auth_bp.route("/oidc/login")
def oidc_login():
if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None:
return redirect(url_for("auth.login"))
next_url = sanitize_next_url(request.args.get("next"))
if next_url:
redirect_uri = _redirect_uri + "?next=" + encode_url(next_url)
else:
redirect_uri = _redirect_uri
oidc = OAuth2Session(
config.OIDC_CLIENT_ID, scope=[OIDC_SCOPES], redirect_uri=redirect_uri
)
authorization_url, state = oidc.authorization_url(OIDC_AUTHORIZATION_URL)
# State is used to prevent CSRF, keep this for later.
session[SESSION_STATE_KEY] = state
return redirect(authorization_url)
@auth_bp.route("/oidc/callback")
def oidc_callback():
if SESSION_STATE_KEY not in session:
flash("Invalid state, please retry", "error")
return redirect(url_for("auth.login"))
if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None:
return redirect(url_for("auth.login"))
# user clicks on cancel
if "error" in request.args:
flash("Please use another sign in method then", "warning")
return redirect("/")
oidc = OAuth2Session(
config.OIDC_CLIENT_ID,
state=session[SESSION_STATE_KEY],
scope=[OIDC_SCOPES],
redirect_uri=_redirect_uri,
)
oidc.fetch_token(
OIDC_TOKEN_URL,
client_secret=config.OIDC_CLIENT_SECRET,
authorization_response=request.url,
)
oidc_user_data = oidc.get(OIDC_USER_INFO_URL)
if oidc_user_data.status_code != 200:
LOG.e(
f"cannot get oidc user data {oidc_user_data.status_code} {oidc_user_data.text}"
)
flash(
"Cannot get user data from OIDC, please use another way to login/sign up",
"error",
)
return redirect(url_for("auth.login"))
oidc_user_data = oidc_user_data.json()
email = oidc_user_data.get("email")
if not email:
LOG.e(f"cannot get email for OIDC user {oidc_user_data} {email}")
flash(
"Cannot get a valid email from OIDC, please another way to login/sign up",
"error",
)
return redirect(url_for("auth.login"))
email = sanitize_email(email)
user = User.get_by(email=email)
if not user and config.DISABLE_REGISTRATION:
flash(
"Sorry you cannot sign up via the OIDC provider. Please sign-up first with your email.",
"error",
)
return redirect(url_for("auth.register"))
elif not user:
user = create_user(email, oidc_user_data)
if not SocialAuth.get_by(user_id=user.id, social="oidc"):
SocialAuth.create(user_id=user.id, social="oidc")
Session.commit()
# The activation link contains the original page, for ex authorize page
next_url = sanitize_next_url(request.args.get("next")) if request.args else None
return after_login(user, next_url)
def create_user(email, oidc_user_data):
new_user = User.create(
email=email,
name=oidc_user_data.get(OIDC_NAME_FIELD),
password="",
activated=True,
)
LOG.i(f"Created new user for login request from OIDC. New user {new_user.id}")
Session.commit()
send_welcome_email(new_user)
return new_user

View File

@ -6,7 +6,7 @@ from wtforms import StringField, validators
from app import email_utils, config from app import email_utils, config
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import CONNECT_WITH_PROTON from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON
from app.auth.views.login_utils import get_referral from app.auth.views.login_utils import get_referral
from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY
from app.db import Session from app.db import Session
@ -109,6 +109,8 @@ def register():
next_url=next_url, next_url=next_url,
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY, HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
connect_with_proton=CONNECT_WITH_PROTON, connect_with_proton=CONNECT_WITH_PROTON,
connect_with_oidc=config.OIDC_CLIENT_ID is not None,
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
) )

View File

@ -234,7 +234,7 @@ else:
print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME) print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME)
# Github, Google, Facebook client id and secrets # Github, Google, Facebook, OIDC client id and secrets
GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID") GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET") GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET")
@ -244,6 +244,15 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID") FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID")
FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET") FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET")
CONNECT_WITH_OIDC_ICON = os.environ.get("CONNECT_WITH_OIDC_ICON")
OIDC_AUTHORIZATION_URL = os.environ.get("OIDC_AUTHORIZATION_URL")
OIDC_USER_INFO_URL = os.environ.get("OIDC_USER_INFO_URL")
OIDC_TOKEN_URL = os.environ.get("OIDC_TOKEN_URL")
OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID")
OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET")
OIDC_SCOPES = os.environ.get("OIDC_SCOPES")
OIDC_NAME_FIELD = os.environ.get("OIDC_NAME_FIELD", "name")
PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID") PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID")
PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET") PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET")
PROTON_BASE_URL = os.environ.get( PROTON_BASE_URL = os.environ.get(
@ -423,6 +432,9 @@ except Exception:
HIBP_API_KEYS = sl_getenv("HIBP_API_KEYS", list) or [] HIBP_API_KEYS = sl_getenv("HIBP_API_KEYS", list) or []
HIBP_MAX_ALIAS_CHECK = 10_000 HIBP_MAX_ALIAS_CHECK = 10_000
HIBP_RPM = 100 HIBP_RPM = 100
HIBP_SKIP_PARTNER_ALIAS = os.environ.get("HIBP_SKIP_PARTNER_ALIAS")
KEEP_OLD_DATA_DAYS = 30
POSTMASTER = os.environ.get("POSTMASTER") POSTMASTER = os.environ.get("POSTMASTER")

View File

@ -2,10 +2,12 @@ from app.dashboard.base import dashboard_bp
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app.alias_utils import alias_export_csv from app.alias_utils import alias_export_csv
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.extensions import limiter
@dashboard_bp.route("/alias_export", methods=["GET"]) @dashboard_bp.route("/alias_export", methods=["GET"])
@login_required @login_required
@sudo_required @sudo_required
@limiter.limit("2/minute")
def alias_export_route(): def alias_export_route():
return alias_export_csv(current_user) return alias_export_csv(current_user)

View File

@ -7,6 +7,7 @@ from app.config import JOB_BATCH_IMPORT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import File, BatchImport, Job from app.models import File, BatchImport, Job
from app.utils import random_string, CSRFValidationForm from app.utils import random_string, CSRFValidationForm
@ -15,6 +16,7 @@ from app.utils import random_string, CSRFValidationForm
@dashboard_bp.route("/batch_import", methods=["GET", "POST"]) @dashboard_bp.route("/batch_import", methods=["GET", "POST"])
@login_required @login_required
@sudo_required @sudo_required
@limiter.limit("10/minute", methods=["POST"])
def batch_import_route(): def batch_import_route():
# only for users who have custom domains # only for users who have custom domains
if not current_user.verified_custom_domains(): if not current_user.verified_custom_domains():
@ -39,7 +41,7 @@ def batch_import_route():
return redirect(request.url) return redirect(request.url)
if len(batch_imports) > 10: if len(batch_imports) > 10:
flash( flash(
"You have too many imports already. Wait until some get cleaned up", "You have too many imports already. Please wait until some get cleaned up",
"error", "error",
) )
return render_template( return render_template(

View File

@ -6,15 +6,15 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import PasswordField, validators from wtforms import PasswordField, validators
from app.config import CONNECT_WITH_PROTON from app.config import CONNECT_WITH_PROTON, OIDC_CLIENT_ID, CONNECT_WITH_OIDC_ICON
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import PartnerUser from app.models import PartnerUser, SocialAuth
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import sanitize_next_url from app.utils import sanitize_next_url
_SUDO_GAP = 900 _SUDO_GAP = 120
class LoginForm(FlaskForm): class LoginForm(FlaskForm):
@ -51,11 +51,19 @@ def enter_sudo():
if not partner_user or partner_user.partner_id != get_proton_partner().id: if not partner_user or partner_user.partner_id != get_proton_partner().id:
proton_enabled = False proton_enabled = False
oidc_enabled = OIDC_CLIENT_ID is not None
if oidc_enabled:
oidc_enabled = (
SocialAuth.get_by(user_id=current_user.id, social="oidc") is not None
)
return render_template( return render_template(
"dashboard/enter_sudo.html", "dashboard/enter_sudo.html",
password_check_form=password_check_form, password_check_form=password_check_form,
next=request.args.get("next"), next=request.args.get("next"),
connect_with_proton=proton_enabled, connect_with_proton=proton_enabled,
connect_with_oidc=oidc_enabled,
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
) )

View File

@ -141,7 +141,7 @@ def index():
) )
if request.form.get("form-name") == "delete-alias": if request.form.get("form-name") == "delete-alias":
LOG.d("delete alias %s", alias) LOG.i(f"User {current_user} requested deletion of alias {alias}")
email = alias.email email = alias.email
alias_utils.delete_alias(alias, current_user) alias_utils.delete_alias(alias, current_user)
flash(f"Alias {email} has been deleted", "success") flash(f"Alias {email} has been deleted", "success")

View File

@ -21,6 +21,7 @@ LIST_UNSUBSCRIBE = "List-Unsubscribe"
LIST_UNSUBSCRIBE_POST = "List-Unsubscribe-Post" LIST_UNSUBSCRIBE_POST = "List-Unsubscribe-Post"
RETURN_PATH = "Return-Path" RETURN_PATH = "Return-Path"
AUTHENTICATION_RESULTS = "Authentication-Results" AUTHENTICATION_RESULTS = "Authentication-Results"
SL_QUEUE_ID = "X-SL-Queue-Id"
# headers used to DKIM sign in order of preference # headers used to DKIM sign in order of preference
DKIM_HEADERS = [ DKIM_HEADERS = [

View File

@ -494,9 +494,10 @@ def delete_header(msg: Message, header: str):
def sanitize_header(msg: Message, header: str): def sanitize_header(msg: Message, header: str):
"""remove trailing space and remove linebreak from a header""" """remove trailing space and remove linebreak from a header"""
header_lowercase = header.lower()
for i in reversed(range(len(msg._headers))): for i in reversed(range(len(msg._headers))):
header_name = msg._headers[i][0].lower() header_name = msg._headers[i][0].lower()
if header_name == header.lower(): if header_name == header_lowercase:
# msg._headers[i] is a tuple like ('From', 'hey@google.com') # msg._headers[i] is a tuple like ('From', 'hey@google.com')
if msg._headers[i][1]: if msg._headers[i][1]:
msg._headers[i] = ( msg._headers[i] = (

View File

@ -30,7 +30,9 @@ def apply_dmarc_policy_for_forward_phase(
) -> Tuple[Message, Optional[str]]: ) -> Tuple[Message, Optional[str]]:
spam_result = SpamdResult.extract_from_headers(msg, Phase.forward) spam_result = SpamdResult.extract_from_headers(msg, Phase.forward)
if not DMARC_CHECK_ENABLED or not spam_result: if not DMARC_CHECK_ENABLED or not spam_result:
LOG.i("DMARC check disabled")
return msg, None return msg, None
LOG.i(f"Spam check result in {spam_result}")
from_header = get_header_unicode(msg[headers.FROM]) from_header = get_header_unicode(msg[headers.FROM])
@ -150,8 +152,10 @@ def apply_dmarc_policy_for_reply_phase(
) -> Optional[str]: ) -> Optional[str]:
spam_result = SpamdResult.extract_from_headers(msg, Phase.reply) spam_result = SpamdResult.extract_from_headers(msg, Phase.reply)
if not DMARC_CHECK_ENABLED or not spam_result: if not DMARC_CHECK_ENABLED or not spam_result:
LOG.i("DMARC check disabled")
return None return None
LOG.i(f"Spam check result is {spam_result}")
if spam_result.dmarc not in ( if spam_result.dmarc not in (
DmarcCheckResult.quarantine, DmarcCheckResult.quarantine,
DmarcCheckResult.reject, DmarcCheckResult.reject,

View File

@ -30,7 +30,10 @@ def handle_batch_import(batch_import: BatchImport):
LOG.d("Download file %s from %s", batch_import.file, file_url) LOG.d("Download file %s from %s", batch_import.file, file_url)
r = requests.get(file_url) r = requests.get(file_url)
lines = [line.decode("utf-8") for line in r.iter_lines()] # Replace invisible character
lines = [
line.decode("utf-8").replace("\ufeff", "").strip() for line in r.iter_lines()
]
import_from_csv(batch_import, user, lines) import_from_csv(batch_import, user, lines)

View File

@ -1422,6 +1422,9 @@ def generate_random_alias_email(
class Alias(Base, ModelMixin): class Alias(Base, ModelMixin):
__tablename__ = "alias" __tablename__ = "alias"
FLAG_PARTNER_CREATED = 1 << 0
user_id = sa.Column( user_id = sa.Column(
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
) )
@ -1431,6 +1434,9 @@ class Alias(Base, ModelMixin):
name = sa.Column(sa.String(128), nullable=True, default=None) name = sa.Column(sa.String(128), nullable=True, default=None)
enabled = sa.Column(sa.Boolean(), default=True, nullable=False) enabled = sa.Column(sa.Boolean(), default=True, nullable=False)
flags = sa.Column(
sa.BigInteger(), default=0, server_default="0", nullable=False, index=True
)
custom_domain_id = sa.Column( custom_domain_id = sa.Column(
sa.ForeignKey("custom_domain.id", ondelete="cascade"), nullable=True, index=True sa.ForeignKey("custom_domain.id", ondelete="cascade"), nullable=True, index=True
@ -2586,10 +2592,13 @@ class Job(Base, ModelMixin):
nullable=False, nullable=False,
server_default=str(JobState.ready.value), server_default=str(JobState.ready.value),
default=JobState.ready.value, default=JobState.ready.value,
index=True,
) )
attempts = sa.Column(sa.Integer, nullable=False, server_default="0", default=0) attempts = sa.Column(sa.Integer, nullable=False, server_default="0", default=0)
taken_at = sa.Column(ArrowType, nullable=True) taken_at = sa.Column(ArrowType, nullable=True)
__table_args__ = (Index("ix_state_run_at_taken_at", state, run_at, taken_at),)
def __repr__(self): def __repr__(self):
return f"<Job {self.id} {self.name} {self.payload}>" return f"<Job {self.id} {self.name} {self.payload}>"
@ -2937,7 +2946,9 @@ class RecoveryCode(Base, ModelMixin):
class Notification(Base, ModelMixin): class Notification(Base, ModelMixin):
__tablename__ = "notification" __tablename__ = "notification"
user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) user_id = sa.Column(
sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True
)
message = sa.Column(sa.Text, nullable=False) message = sa.Column(sa.Text, nullable=False)
title = sa.Column(sa.String(512)) title = sa.Column(sa.String(512))

View File

@ -30,7 +30,9 @@ def check_bucket_limit(
try: try:
value = lock_redis.incr(bucket_lock_name, bucket_seconds) value = lock_redis.incr(bucket_lock_name, bucket_seconds)
if value > max_hits: if value > max_hits:
LOG.i(f"Rate limit hit for {bucket_lock_name} -> {value}/{max_hits}") LOG.i(
f"Rate limit hit for {lock_name} (bucket id {bucket_id}) -> {value}/{max_hits}"
)
newrelic.agent.record_custom_event( newrelic.agent.record_custom_event(
"BucketRateLimit", "BucketRateLimit",
{"lock_name": lock_name, "bucket_seconds": bucket_seconds}, {"lock_name": lock_name, "bucket_seconds": bucket_seconds},

View File

@ -5,19 +5,9 @@ from typing import Optional
import boto3 import boto3
import requests import requests
from app.config import ( from app import config
AWS_REGION,
BUCKET,
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
LOCAL_FILE_UPLOAD,
UPLOAD_DIR,
URL,
AWS_ENDPOINT_URL,
)
from app.log import LOG from app.log import LOG
_s3_client = None _s3_client = None
@ -25,12 +15,12 @@ def _get_s3client():
global _s3_client global _s3_client
if _s3_client is None: if _s3_client is None:
args = { args = {
"aws_access_key_id": AWS_ACCESS_KEY_ID, "aws_access_key_id": config.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": AWS_SECRET_ACCESS_KEY, "aws_secret_access_key": config.AWS_SECRET_ACCESS_KEY,
"region_name": AWS_REGION, "region_name": config.AWS_REGION,
} }
if AWS_ENDPOINT_URL: if config.AWS_ENDPOINT_URL:
args["endpoint_url"] = AWS_ENDPOINT_URL args["endpoint_url"] = config.AWS_ENDPOINT_URL
_s3_client = boto3.client("s3", **args) _s3_client = boto3.client("s3", **args)
return _s3_client return _s3_client
@ -38,8 +28,8 @@ def _get_s3client():
def upload_from_bytesio(key: str, bs: BytesIO, content_type="application/octet-stream"): def upload_from_bytesio(key: str, bs: BytesIO, content_type="application/octet-stream"):
bs.seek(0) bs.seek(0)
if LOCAL_FILE_UPLOAD: if config.LOCAL_FILE_UPLOAD:
file_path = os.path.join(UPLOAD_DIR, key) file_path = os.path.join(config.UPLOAD_DIR, key)
file_dir = os.path.dirname(file_path) file_dir = os.path.dirname(file_path)
os.makedirs(file_dir, exist_ok=True) os.makedirs(file_dir, exist_ok=True)
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
@ -47,7 +37,7 @@ def upload_from_bytesio(key: str, bs: BytesIO, content_type="application/octet-s
else: else:
_get_s3client().put_object( _get_s3client().put_object(
Bucket=BUCKET, Bucket=config.BUCKET,
Key=key, Key=key,
Body=bs, Body=bs,
ContentType=content_type, ContentType=content_type,
@ -57,8 +47,8 @@ def upload_from_bytesio(key: str, bs: BytesIO, content_type="application/octet-s
def upload_email_from_bytesio(path: str, bs: BytesIO, filename): def upload_email_from_bytesio(path: str, bs: BytesIO, filename):
bs.seek(0) bs.seek(0)
if LOCAL_FILE_UPLOAD: if config.LOCAL_FILE_UPLOAD:
file_path = os.path.join(UPLOAD_DIR, path) file_path = os.path.join(config.UPLOAD_DIR, path)
file_dir = os.path.dirname(file_path) file_dir = os.path.dirname(file_path)
os.makedirs(file_dir, exist_ok=True) os.makedirs(file_dir, exist_ok=True)
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
@ -66,7 +56,7 @@ def upload_email_from_bytesio(path: str, bs: BytesIO, filename):
else: else:
_get_s3client().put_object( _get_s3client().put_object(
Bucket=BUCKET, Bucket=config.BUCKET,
Key=path, Key=path,
Body=bs, Body=bs,
# Support saving a remote file using Http header # Support saving a remote file using Http header
@ -77,12 +67,12 @@ def upload_email_from_bytesio(path: str, bs: BytesIO, filename):
def download_email(path: str) -> Optional[str]: def download_email(path: str) -> Optional[str]:
if LOCAL_FILE_UPLOAD: if config.LOCAL_FILE_UPLOAD:
file_path = os.path.join(UPLOAD_DIR, path) file_path = os.path.join(config.UPLOAD_DIR, path)
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
return f.read() return f.read()
resp = _get_s3client().get_object( resp = _get_s3client().get_object(
Bucket=BUCKET, Bucket=config.BUCKET,
Key=path, Key=path,
) )
if not resp or "Body" not in resp: if not resp or "Body" not in resp:
@ -96,29 +86,30 @@ def upload_from_url(url: str, upload_path):
def get_url(key: str, expires_in=3600) -> str: def get_url(key: str, expires_in=3600) -> str:
if LOCAL_FILE_UPLOAD: if config.LOCAL_FILE_UPLOAD:
return URL + "/static/upload/" + key return config.URL + "/static/upload/" + key
else: else:
return _get_s3client().generate_presigned_url( return _get_s3client().generate_presigned_url(
ExpiresIn=expires_in, ExpiresIn=expires_in,
ClientMethod="get_object", ClientMethod="get_object",
Params={"Bucket": BUCKET, "Key": key}, Params={"Bucket": config.BUCKET, "Key": key},
) )
def delete(path: str): def delete(path: str):
if LOCAL_FILE_UPLOAD: if config.LOCAL_FILE_UPLOAD:
os.remove(os.path.join(UPLOAD_DIR, path)) file_path = os.path.join(config.UPLOAD_DIR, path)
os.remove(file_path)
else: else:
_get_s3client().delete_object(Bucket=BUCKET, Key=path) _get_s3client().delete_object(Bucket=config.BUCKET, Key=path)
def create_bucket_if_not_exists(): def create_bucket_if_not_exists():
s3client = _get_s3client() s3client = _get_s3client()
buckets = s3client.list_buckets() buckets = s3client.list_buckets()
for bucket in buckets["Buckets"]: for bucket in buckets["Buckets"]:
if bucket["Name"] == BUCKET: if bucket["Name"] == config.BUCKET:
LOG.i("Bucket already exists") LOG.i("Bucket already exists")
return return
s3client.create_bucket(Bucket=BUCKET) s3client.create_bucket(Bucket=config.BUCKET)
LOG.i(f"Bucket {BUCKET} created") LOG.i(f"Bucket {config.BUCKET} created")

View File

@ -5,7 +5,7 @@ from typing import List, Tuple
import arrow import arrow
import requests import requests
from sqlalchemy import func, desc, or_, and_, nullsfirst from sqlalchemy import func, desc, or_, and_
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.orm.exc import ObjectDeletedError
@ -61,6 +61,9 @@ from app.pgp_utils import load_public_key_and_check, PGPException
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import sanitize_email from app.utils import sanitize_email
from server import create_light_app from server import create_light_app
from tasks.cleanup_old_imports import cleanup_old_imports
from tasks.cleanup_old_jobs import cleanup_old_jobs
from tasks.cleanup_old_notifications import cleanup_old_notifications
DELETE_GRACE_DAYS = 30 DELETE_GRACE_DAYS = 30
@ -976,6 +979,9 @@ async def _hibp_check(api_key, queue):
continue continue
user = alias.user user = alias.user
if user.disabled or not user.is_paid(): if user.disabled or not user.is_paid():
# Mark it as hibp done to skip it as if it had been checked
alias.hibp_last_check = arrow.utcnow()
Session.commit()
continue continue
LOG.d("Checking HIBP for %s", alias) LOG.d("Checking HIBP for %s", alias)
@ -1030,6 +1036,60 @@ async def _hibp_check(api_key, queue):
await asyncio.sleep(rate_sleep) await asyncio.sleep(rate_sleep)
def get_alias_to_check_hibp(
oldest_hibp_allowed: arrow.Arrow,
user_ids_to_skip: list[int],
min_alias_id: int,
max_alias_id: int,
):
now = arrow.now()
alias_query = (
Session.query(Alias)
.join(User, User.id == Alias.user_id)
.join(Subscription, User.id == Subscription.user_id, isouter=True)
.join(ManualSubscription, User.id == ManualSubscription.user_id, isouter=True)
.join(AppleSubscription, User.id == AppleSubscription.user_id, isouter=True)
.join(
CoinbaseSubscription,
User.id == CoinbaseSubscription.user_id,
isouter=True,
)
.join(PartnerUser, User.id == PartnerUser.user_id, isouter=True)
.join(
PartnerSubscription,
PartnerSubscription.partner_user_id == PartnerUser.id,
isouter=True,
)
.filter(
or_(
Alias.hibp_last_check.is_(None),
Alias.hibp_last_check < oldest_hibp_allowed,
),
Alias.user_id.notin_(user_ids_to_skip),
Alias.enabled,
Alias.id >= min_alias_id,
Alias.id < max_alias_id,
User.disabled == False, # noqa: E712
or_(
User.lifetime,
ManualSubscription.end_at > now,
Subscription.next_bill_date > now.date(),
AppleSubscription.expires_date > now,
CoinbaseSubscription.end_at > now,
PartnerSubscription.end_at > now,
),
)
)
if config.HIBP_SKIP_PARTNER_ALIAS:
alias_query = alias_query.filter(
Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0
)
for alias in (
alias_query.order_by(Alias.id.asc()).enable_eagerloads(False).yield_per(500)
):
yield alias
async def check_hibp(): async def check_hibp():
""" """
Check all aliases on the HIBP (Have I Been Pwned) API Check all aliases on the HIBP (Have I Been Pwned) API
@ -1056,41 +1116,43 @@ async def check_hibp():
user_ids = [row[0] for row in rows] user_ids = [row[0] for row in rows]
LOG.d("Got %d users to skip" % len(user_ids)) LOG.d("Got %d users to skip" % len(user_ids))
LOG.d("Preparing list of aliases to check") LOG.d("Checking aliases")
queue = asyncio.Queue() queue = asyncio.Queue()
max_date = arrow.now().shift(days=-config.HIBP_SCAN_INTERVAL_DAYS) min_alias_id = 0
for alias in ( max_alias_id = Session.query(func.max(Alias.id)).scalar()
Alias.filter( step = 10000
or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date), now = arrow.now()
Alias.user_id.notin_(user_ids), oldest_hibp_allowed = now.shift(days=-config.HIBP_SCAN_INTERVAL_DAYS)
alias_checked = 0
for alias_batch_id in range(min_alias_id, max_alias_id, step):
for alias in get_alias_to_check_hibp(
oldest_hibp_allowed, user_ids, alias_batch_id, alias_batch_id + step
):
await queue.put(alias.id)
alias_checked += queue.qsize()
LOG.d(
f"Need to check about {queue.qsize()} aliases in this loop {alias_batch_id}/{max_alias_id}"
) )
.filter(Alias.enabled)
.order_by(nullsfirst(Alias.hibp_last_check.asc()), Alias.id.asc())
.yield_per(500)
.enable_eagerloads(False)
):
await queue.put(alias.id)
LOG.d("Need to check about %s aliases", queue.qsize()) # Start one checking process per API key
# Each checking process will take one alias from the queue, get the info
# Start one checking process per API key # and then sleep for 1.5 seconds (due to HIBP API request limits)
# Each checking process will take one alias from the queue, get the info checkers = []
# and then sleep for 1.5 seconds (due to HIBP API request limits) for i in range(len(config.HIBP_API_KEYS)):
checkers = [] checker = asyncio.create_task(
for i in range(len(config.HIBP_API_KEYS)): _hibp_check(
checker = asyncio.create_task( config.HIBP_API_KEYS[i],
_hibp_check( queue,
config.HIBP_API_KEYS[i], )
queue,
) )
) checkers.append(checker)
checkers.append(checker)
# Wait until all checking processes are done # Wait until all checking processes are done
for checker in checkers: for checker in checkers:
await checker await checker
LOG.d("Done checking HIBP API for aliases in breaches") LOG.d(f"Done checking {alias_checked} HIBP API for aliases in breaches")
def notify_hibp(): def notify_hibp():
@ -1162,6 +1224,13 @@ def clear_users_scheduled_to_be_deleted(dry_run=False):
Session.commit() Session.commit()
def delete_old_data():
oldest_valid = arrow.now().shift(days=-config.KEEP_OLD_DATA_DAYS)
cleanup_old_imports(oldest_valid)
cleanup_old_jobs(oldest_valid)
cleanup_old_notifications(oldest_valid)
if __name__ == "__main__": if __name__ == "__main__":
LOG.d("Start running cronjob") LOG.d("Start running cronjob")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -1176,6 +1245,7 @@ if __name__ == "__main__":
"notify_manual_subscription_end", "notify_manual_subscription_end",
"notify_premium_end", "notify_premium_end",
"delete_logs", "delete_logs",
"delete_old_data",
"poll_apple_subscription", "poll_apple_subscription",
"sanity_check", "sanity_check",
"delete_old_monitoring", "delete_old_monitoring",
@ -1204,6 +1274,9 @@ if __name__ == "__main__":
elif args.job == "delete_logs": elif args.job == "delete_logs":
LOG.d("Deleted Logs") LOG.d("Deleted Logs")
delete_logs() delete_logs()
elif args.job == "delete_old_data":
LOG.d("Delete old data")
delete_old_data()
elif args.job == "poll_apple_subscription": elif args.job == "poll_apple_subscription":
LOG.d("Poll Apple Subscriptions") LOG.d("Poll Apple Subscriptions")
poll_apple_subscription() poll_apple_subscription()

View File

@ -37,6 +37,12 @@ jobs:
schedule: "15 5 * * *" schedule: "15 5 * * *"
captureStderr: true captureStderr: true
- name: SimpleLogin Delete Old data
command: python /code/cron.py -j delete_old_data
shell: /bin/bash
schedule: "30 5 * * *"
captureStderr: true
- name: SimpleLogin Poll Apple Subscriptions - name: SimpleLogin Poll Apple Subscriptions
command: python /code/cron.py -j poll_apple_subscription command: python /code/cron.py -j poll_apple_subscription
shell: /bin/bash shell: /bin/bash

View File

@ -875,6 +875,7 @@ def forward_email_to_mailbox(
# References and In-Reply-To are used for keeping the email thread # References and In-Reply-To are used for keeping the email thread
headers.REFERENCES, headers.REFERENCES,
headers.IN_REPLY_TO, headers.IN_REPLY_TO,
headers.SL_QUEUE_ID,
headers.LIST_UNSUBSCRIBE, headers.LIST_UNSUBSCRIBE,
headers.LIST_UNSUBSCRIBE_POST, headers.LIST_UNSUBSCRIBE_POST,
] + headers.MIME_HEADERS ] + headers.MIME_HEADERS
@ -2040,10 +2041,11 @@ def handle(envelope: Envelope, msg: Message) -> str:
return status.E204 return status.E204
# sanitize email headers # sanitize email headers
sanitize_header(msg, "from") sanitize_header(msg, headers.FROM)
sanitize_header(msg, "to") sanitize_header(msg, headers.TO)
sanitize_header(msg, "cc") sanitize_header(msg, headers.CC)
sanitize_header(msg, "reply-to") sanitize_header(msg, headers.REPLY_TO)
sanitize_header(msg, headers.MESSAGE_ID)
LOG.d( LOG.d(
"==>> Handle mail_from:%s, rcpt_tos:%s, header_from:%s, header_to:%s, " "==>> Handle mail_from:%s, rcpt_tos:%s, header_from:%s, header_to:%s, "

View File

@ -116,6 +116,16 @@ WORDS_FILE_PATH=local_data/test_words.txt
# CONNECT_WITH_PROTON=true # CONNECT_WITH_PROTON=true
# CONNECT_WITH_PROTON_COOKIE_NAME=to_fill # CONNECT_WITH_PROTON_COOKIE_NAME=to_fill
# Login with OIDC
# CONNECT_WITH_OIDC_ICON=fa-github
# OIDC_AUTHORIZATION_URL=to_fill
# OIDC_USER_INFO_URL=to_fill
# OIDC_TOKEN_URL=to_fill
# OIDC_SCOPES=openid email profile
# OIDC_NAME_FIELD=name
# OIDC_CLIENT_ID=to_fill
# OIDC_CLIENT_SECRET=to_fill
# Flask profiler # Flask profiler
# FLASK_PROFILER_PATH=/tmp/flask-profiler.sql # FLASK_PROFILER_PATH=/tmp/flask-profiler.sql
# FLASK_PROFILER_PASSWORD=password # FLASK_PROFILER_PASSWORD=password

View File

@ -0,0 +1,48 @@
"""empty message
Revision ID: 52510a633d6f
Revises: 818b0a956205
Create Date: 2024-03-12 12:46:24.161644
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "52510a633d6f"
down_revision = "818b0a956205"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"alias", sa.Column("flags", sa.BigInteger(), server_default="0", nullable=False)
)
with op.get_context().autocommit_block():
op.create_index(op.f("ix_alias_flags"), "alias", ["flags"], unique=False)
op.create_index(op.f("ix_job_state"), "job", ["state"], unique=False)
op.create_index(
"ix_state_run_at_taken_at",
"job",
["state", "run_at", "taken_at"],
unique=False,
)
op.create_index(
op.f("ix_notification_user_id"), "notification", ["user_id"], unique=False
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.get_context().autocommit_block():
op.drop_index(op.f("ix_notification_user_id"), table_name="notification")
op.drop_index("ix_state_run_at_taken_at", table_name="job")
op.drop_index(op.f("ix_job_state"), table_name="job")
op.drop_index(op.f("ix_alias_flags"), table_name="alias")
op.drop_column("alias", "flags")
# ### end Alembic commands ###

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
import argparse
import random
import time
from sqlalchemy import func
from app import config
from app.models import Alias, Contact
from app.db import Session
parser = argparse.ArgumentParser(
prog=f"Replace {config.NOREPLY}",
description=f"Replace {config.NOREPLY} from contacts reply email",
)
args = parser.parse_args()
max_alias_id: int = Session.query(func.max(Alias.id)).scalar()
start = time.time()
tests = 1000
for i in range(tests):
alias = (
Alias.filter(Alias.id > int(random.random() * max_alias_id))
.order_by(Alias.id.asc())
.limit(1)
.first()
)
contact = Contact.filter_by(alias_id=alias.id).order_by(Contact.id.asc()).first()
mailboxes = alias.mailboxes
user = alias.user
if i % 10:
print("{i} -> {alias.id}")
end = time.time()
time_taken = end - start
print(f"Took {time_taken} -> {time_taken/tests} per test")

View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.models import Alias, SLDomain
from app.db import Session
parser = argparse.ArgumentParser(
prog="Mark partner created aliases with the PARTNER_CREATED flag",
)
parser.add_argument(
"-s", "--start_alias_id", default=0, type=int, help="Initial alias_id"
)
parser.add_argument("-e", "--end_alias_id", default=0, type=int, help="Last alias_id")
args = parser.parse_args()
alias_id_start = args.start_alias_id
max_alias_id = args.end_alias_id
if max_alias_id == 0:
max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Updating aliases from {alias_id_start} to {max_alias_id}")
domains = SLDomain.filter(SLDomain.partner_id.isnot(None)).all()
cond = [f"email like '%{domain.domain}'" for domain in domains]
sql_or_cond = " OR ".join(cond)
sql = f"UPDATE alias set flags = (flags | :flag) WHERE id >= :start and id<:end and flags & :flag = 0 and ({sql_or_cond})"
print(sql)
step = 1000
updated = 0
start_time = time.time()
for batch_start in range(alias_id_start, max_alias_id, step):
updated += Session.execute(
sql,
{
"start": batch_start,
"end": batch_start + step,
"flag": Alias.FLAG_PARTNER_CREATED,
},
).rowcount
elapsed = time.time() - start_time
time_per_alias = elapsed / (batch_start - alias_id_start + step)
last_batch_id = batch_start + step
remaining = max_alias_id - last_batch_id
time_remaining = (max_alias_id - last_batch_id) * time_per_alias
hours_remaining = time_remaining / 3600.0
percent = int(
((batch_start - alias_id_start) * 100) / (max_alias_id - alias_id_start)
)
print(
f"\rAlias {batch_start}/{max_alias_id} {percent}% {updated} updated {hours_remaining:.2f}hrs remaining"
)
print(f"Updated aliases up to {max_alias_id}")

0
app/tasks/__init__.py Normal file
View File

View File

@ -0,0 +1,19 @@
import arrow
from app import s3
from app.log import LOG
from app.models import BatchImport
def cleanup_old_imports(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting imports older than {oldest_allowed}")
for batch_import in (
BatchImport.filter(BatchImport.created_at < oldest_allowed).yield_per(500).all()
):
LOG.i(
f"Deleting batch import {batch_import} with file {batch_import.file.path}"
)
file = batch_import.file
if file is not None:
s3.delete(file.path)
BatchImport.delete(batch_import.id, commit=True)

View File

@ -0,0 +1,24 @@
import arrow
from sqlalchemy import or_, and_
from app import config
from app.db import Session
from app.log import LOG
from app.models import Job, JobState
def cleanup_old_jobs(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting jobs older than {oldest_allowed}")
count = Job.filter(
or_(
Job.state == JobState.done.value,
Job.state == JobState.error.value,
and_(
Job.state == JobState.taken.value,
Job.attempts >= config.JOB_MAX_ATTEMPTS,
),
),
Job.updated_at < oldest_allowed,
).delete()
Session.commit()
LOG.i(f"Deleted {count} jobs")

View File

@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import Notification
def cleanup_old_notifications(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting notifications older than {oldest_allowed}")
count = Notification.filter(Notification.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} notifications")

View File

@ -38,11 +38,21 @@
<span>or</span> <span>or</span>
</div> </div>
<a class="btn btn-primary btn-block mt-2 proton-button" <a class="btn btn-primary btn-block mt-2 proton-button"
href="{{ url_for("auth.proton_login", next=next_url) }}"> href="{{ url_for('auth.proton_login', next=next_url) }}">
<img class="mr-2" src="/static/images/proton.svg" /> <img class="mr-2" src="/static/images/proton.svg" />
Log in with Proton Log in with Proton
</a> </a>
{% endif %} {% endif %}
{% if connect_with_oidc %}
<div class="text-center my-2 text-gray">
<span>or</span>
</div>
<a class="btn btn-primary btn-block mt-2 btn-social"
href="{{ url_for('auth.oidc_login', next=next_url) }}">
<i class="fa {{ connect_with_oidc_icon }}"></i> Log in with SSO
</a>
{% endif %}
</div> </div>
</div> </div>
<div class="text-center text-muted mt-2"> <div class="text-center text-muted mt-2">

View File

@ -50,11 +50,21 @@
<span>or</span> <span>or</span>
</div> </div>
<a class="btn btn-primary btn-block mt-2 proton-button" <a class="btn btn-primary btn-block mt-2 proton-button"
href="{{ url_for("auth.proton_login", next=next_url) }}"> href="{{ url_for('auth.proton_login', next=next_url) }}">
<img class="mr-2" src="/static/images/proton.svg" /> <img class="mr-2" src="/static/images/proton.svg" />
Sign up with Proton Sign up with Proton
</a> </a>
{% endif %} {% endif %}
{% if connect_with_oidc %}
<div class="text-center my-2 text-gray">
<span>or</span>
</div>
<a class="btn btn-primary btn-block mt-2 btn-social"
href="{{ url_for('auth.oidc_login', next=next_url) }}">
<i class="fa {{ connect_with_oidc_icon }}"></i> Sign up with SSO
</a>
{% endif %}
</div> </div>
</form> </form>
<div class="text-center text-muted mb-6"> <div class="text-center text-muted mb-6">

View File

@ -120,21 +120,6 @@
</div> </div>
</div> </div>
<!-- END WebAuthn --> <!-- END WebAuthn -->
<!-- Alias import/export -->
<div class="card">
<div class="card-body">
<div class="card-title">Alias import/export</div>
<div class="mb-3">
You can import your aliases created on other platforms into SimpleLogin.
You can also export your aliases to a readable csv format for a future batch import.
</div>
<a href="{{ url_for('dashboard.batch_import_route') }}"
class="btn btn-outline-primary">Batch Import</a>
<a href="{{ url_for('dashboard.alias_export_route') }}"
class="btn btn-outline-secondary">Export Aliases</a>
</div>
</div>
<!-- END Alias import/export -->
<!-- data export --> <!-- data export -->
<div class="card"> <div class="card">
<div class="card-body"> <div class="card-body">

View File

@ -22,11 +22,20 @@
<p>Alternatively you can use your Proton credentials to ensure it's you.</p> <p>Alternatively you can use your Proton credentials to ensure it's you.</p>
</div> </div>
<a class="btn btn-primary btn-block mt-2 proton-button w-25" <a class="btn btn-primary btn-block mt-2 proton-button w-25"
href="{{ url_for("auth.proton_login", next=next) }}"> href="{{ url_for('auth.proton_login', next=next) }}">
<img class="mr-2" src="/static/images/proton.svg" /> <img class="mr-2" src="/static/images/proton.svg" />
Authenticate with Proton Authenticate with Proton
</a> </a>
{% endif %} {% endif %}
{% if connect_with_oidc %}
<div class="my-3">
<p>Alternatively you can use your SSO credentials to ensure it's you.</p>
<a class="btn btn-primary btn-block mt-2 btn-social w-25"
href="{{ url_for('auth.oidc_login', next=next) }}">
<i class="fa {{ connect_with_oidc_icon }}"></i> Authenticate with SSO
</a>
{% endif %}
</div>
</div> </div>
</div> {% endblock %}
{% endblock %}

View File

@ -559,7 +559,7 @@
sender address. sender address.
<br /> <br />
If this option is enabled, the original sender addresses is stored in the email header <b>X-SimpleLogin-Envelope-From</b> If this option is enabled, the original sender addresses is stored in the email header <b>X-SimpleLogin-Envelope-From</b>
and the original From header is stored in <b>X-SimpleLogin-Original-From<b>. and the original From header is stored in <b>X-SimpleLogin-Original-From</b>.
You can choose to display this header in your email client. You can choose to display this header in your email client.
<br /> <br />
As email headers aren't encrypted, your mailbox service can know the sender address via this header. As email headers aren't encrypted, your mailbox service can know the sender address via this header.
@ -583,6 +583,23 @@
</form> </form>
</div> </div>
</div> </div>
<!-- Alias import/export -->
<div class="card">
<div class="card-body">
<div class="card-title">
Alias import/export
</div>
<div class="mb-3">
You can import your aliases created on other platforms into SimpleLogin.
You can also export your aliases to a readable csv format for a future batch import.
</div>
<a href="{{ url_for('dashboard.batch_import_route') }}"
class="btn btn-outline-primary">Batch Import</a>
<a href="{{ url_for('dashboard.alias_export_route') }}"
class="btn btn-outline-secondary">Export Aliases</a>
</div>
</div>
<!-- END Alias import/export -->
</div> </div>
{% endblock %} {% endblock %}
{% block script %} {% block script %}

View File

@ -28,7 +28,7 @@
{{ render_text("Hi") }} {{ render_text("Hi") }}
{{ render_text("If you use Safari on a MacBook or iMac, you should check out our new Safari extension.") }} {{ render_text("If you use Safari on a MacBook or iMac, you should check out our new Safari extension.") }}
{{ render_text('It can be installed on {{ render_text('It can be installed on
<a href="https://apps.apple.com/app/id1494051017">App Store</a> <a href="https://apps.apple.com/app/id6475835429">App Store</a>
. Its code is available on . Its code is available on
<a href="https://github.com/simple-login/mac-app">GitHub</a> <a href="https://github.com/simple-login/mac-app">GitHub</a>
.') }} .') }}

View File

@ -8,7 +8,7 @@ If you use Safari on a MacBook or iMac, you should check out our new Safari exte
It can be installed on: It can be installed on:
https://apps.apple.com/app/id1494051017 https://apps.apple.com/app/id6475835429
As usual, let me know if you have any question by replying to this email. As usual, let me know if you have any question by replying to this email.

View File

@ -12,7 +12,7 @@ If you want to quickly create aliases <b>without</b> going to SimpleLogin websit
(or other Chromium-based browsers like Brave or Vivaldi), (or other Chromium-based browsers like Brave or Vivaldi),
<a href="https://addons.mozilla.org/firefox/addon/simplelogin/">Firefox</a> <a href="https://addons.mozilla.org/firefox/addon/simplelogin/">Firefox</a>
and and
<a href="https://apps.apple.com/app/id1494051017 ">Safari</a> <a href="https://apps.apple.com/app/id6475835429 ">Safari</a>
extension. extension.
{% endcall %} {% endcall %}

View File

@ -11,7 +11,7 @@ Chrome: https://chrome.google.com/webstore/detail/dphilobhebphkdjbpfohgikllaljmg
Firefox: https://addons.mozilla.org/firefox/addon/simplelogin/ Firefox: https://addons.mozilla.org/firefox/addon/simplelogin/
Safari: https://apps.apple.com/app/id1494051017 Safari: https://apps.apple.com/app/id6475835429
You can also manage your aliases using SimpleLogin mobile apps, available at You can also manage your aliases using SimpleLogin mobile apps, available at
- Play Store https://play.google.com/store/apps/details?id=io.simplelogin.android - Play Store https://play.google.com/store/apps/details?id=io.simplelogin.android

View File

@ -124,7 +124,7 @@
<li> <li>
<a class="list-group-item text-white footer-item " <a class="list-group-item text-white footer-item "
rel="noopener noreferrer" rel="noopener noreferrer"
href="https://apps.apple.com/app/id1494051017"> href="https://apps.apple.com/app/id6475835429">
Safari Safari
Extension Extension
</a> </a>

304
app/tests/auth/test_oidc.py Normal file
View File

@ -0,0 +1,304 @@
from app import config
from flask import url_for
from urllib.parse import parse_qs
from urllib3.util import parse_url
from app.auth.views.oidc import create_user
from app.utils import random_string
from unittest.mock import patch
from app.models import User
from app.config import URL, OIDC_CLIENT_ID
def test_oidc_login(flask_client):
r = flask_client.get(
url_for("auth.oidc_login"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
query = parse_qs(parsed.query)
expected_redirect_url = f"{URL}/auth/oidc/callback"
assert "code" == query["response_type"][0]
assert OIDC_CLIENT_ID == query["client_id"][0]
assert expected_redirect_url == query["redirect_uri"][0]
def test_oidc_login_no_client_id(flask_client):
config.OIDC_CLIENT_ID = None
r = flask_client.get(
url_for("auth.oidc_login"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
config.OIDC_CLIENT_ID = "to_fill"
def test_oidc_login_no_client_secret(flask_client):
config.OIDC_CLIENT_SECRET = None
r = flask_client.get(
url_for("auth.oidc_login"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
config.OIDC_CLIENT_SECRET = "to_fill"
def test_oidc_callback_no_oauth_state(flask_client):
with flask_client.session_transaction() as session:
session["oauth_state"] = None
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is None
def test_oidc_callback_no_client_id(flask_client):
with flask_client.session_transaction() as session:
session["oauth_state"] = "state"
config.OIDC_CLIENT_ID = None
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
config.OIDC_CLIENT_ID = "to_fill"
with flask_client.session_transaction() as session:
session["oauth_state"] = None
def test_oidc_callback_no_client_secret(flask_client):
with flask_client.session_transaction() as session:
session["oauth_state"] = "state"
config.OIDC_CLIENT_SECRET = None
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
config.OIDC_CLIENT_SECRET = "to_fill"
with flask_client.session_transaction() as session:
session["oauth_state"] = None
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
def test_oidc_callback_invalid_user(mock_get, mock_fetch_token, flask_client):
mock_get.return_value = MockResponse(400, {})
with flask_client.session_transaction() as session:
session["oauth_state"] = "state"
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
assert mock_get.called
with flask_client.session_transaction() as session:
session["oauth_state"] = None
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
def test_oidc_callback_no_email(mock_get, mock_fetch_token, flask_client):
mock_get.return_value = MockResponse(200, {})
with flask_client.session_transaction() as session:
session["oauth_state"] = "state"
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
assert mock_get.called
with flask_client.session_transaction() as session:
session["oauth_state"] = None
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
def test_oidc_callback_disabled_registration(mock_get, mock_fetch_token, flask_client):
config.DISABLE_REGISTRATION = True
email = random_string()
mock_get.return_value = MockResponse(200, {"email": email})
with flask_client.session_transaction() as session:
session["oauth_state"] = "state"
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/auth/register"
assert expected_redirect_url == parsed.path
assert mock_get.called
config.DISABLE_REGISTRATION = False
with flask_client.session_transaction() as session:
session["oauth_state"] = None
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
def test_oidc_callback_registration(mock_get, mock_fetch_token, flask_client):
email = random_string()
mock_get.return_value = MockResponse(
200,
{
"email": email,
config.OIDC_NAME_FIELD: "name",
},
)
with flask_client.session_transaction() as session:
session["oauth_state"] = "state"
user = User.get_by(email=email)
assert user is None
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/dashboard/"
assert expected_redirect_url == parsed.path
assert mock_get.called
user = User.get_by(email=email)
assert user is not None
assert user.email == email
with flask_client.session_transaction() as session:
session["oauth_state"] = None
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
def test_oidc_callback_login(mock_get, mock_fetch_token, flask_client):
email = random_string()
mock_get.return_value = MockResponse(
200,
{
"email": email,
},
)
with flask_client.session_transaction() as session:
session["oauth_state"] = "state"
user = User.create(
email=email,
name="name",
password="",
activated=True,
)
user = User.get_by(email=email)
assert user is not None
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
expected_redirect_url = "/dashboard/"
assert expected_redirect_url == parsed.path
assert mock_get.called
with flask_client.session_transaction() as session:
session["oauth_state"] = None
def test_create_user():
email = random_string()
user = create_user(
email,
{
config.OIDC_NAME_FIELD: "name",
},
)
assert user.email == email
assert user.name == "name"
assert user.activated
class MockResponse:
def __init__(self, status_code, json_data):
self.status_code = status_code
self.json_data = json_data
self.text = "error"
def json(self):
return self.json_data

View File

View File

@ -0,0 +1,142 @@
import arrow
import pytest
import cron
from app.db import Session
from app.models import (
Alias,
AppleSubscription,
PlanEnum,
CoinbaseSubscription,
ManualSubscription,
Subscription,
PartnerUser,
PartnerSubscription,
User,
)
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
def test_get_alias_for_free_user_has_no_alias():
user = create_new_user()
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert len(aliases) == 0
def test_get_alias_for_lifetime_with_null_hibp_date():
user = create_new_user()
user.lifetime = True
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert alias_id == aliases[0].id
def test_get_alias_for_lifetime_with_old_hibp_date():
user = create_new_user()
user.lifetime = True
alias = Alias.create_new_random(user)
alias.hibp_last_check = arrow.now().shift(days=-1)
alias_id = alias.id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert alias_id == aliases[0].id
def create_partner_sub(user: User):
pu = PartnerUser.create(
partner_id=get_proton_partner().id,
partner_email=user.email,
external_user_id=random_token(10),
user_id=user.id,
flush=True,
)
PartnerSubscription.create(
partner_user_id=pu.id, end_at=arrow.utcnow().shift(days=15)
)
sub_generator_list = [
lambda u: AppleSubscription.create(
user_id=u.id,
expires_date=arrow.now().shift(days=15),
original_transaction_id=random_token(10),
receipt_data=random_token(10),
plan=PlanEnum.monthly,
),
lambda u: CoinbaseSubscription.create(
user_id=u.id,
end_at=arrow.now().shift(days=15),
),
lambda u: ManualSubscription.create(
user_id=u.id,
end_at=arrow.now().shift(days=15),
),
lambda u: Subscription.create(
user_id=u.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=arrow.now().shift(days=15).date(),
plan=PlanEnum.monthly,
),
create_partner_sub,
]
@pytest.mark.parametrize("sub_generator", sub_generator_list)
def test_get_alias_for_sub(sub_generator):
user = create_new_user()
sub_generator(user)
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert alias_id == aliases[0].id
def test_disabled_user_is_not_checked():
user = create_new_user()
user.lifetime = True
user.disabled = True
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert len(aliases) == 0
def test_skipped_user_is_not_checked():
user = create_new_user()
user.lifetime = True
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [user.id], alias_id, alias_id + 1)
)
assert len(aliases) == 0
def test_already_checked_is_not_checked():
user = create_new_user()
user.lifetime = True
alias = Alias.create_new_random(user)
alias.hibp_last_check = arrow.now().shift(days=1)
alias_id = alias.id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [user.id], alias_id, alias_id + 1)
)
assert len(aliases) == 0

View File

View File

@ -0,0 +1,35 @@
import tempfile
from io import BytesIO
import arrow
from app import s3, config
from app.models import File, BatchImport
from tasks.cleanup_old_imports import cleanup_old_imports
from tests.utils import random_token, create_new_user
def test_cleanup_old_imports():
BatchImport.filter().delete()
with tempfile.TemporaryDirectory() as tmpdir:
config.UPLOAD_DIR = tmpdir
user = create_new_user()
path = random_token()
s3.upload_from_bytesio(path, BytesIO("data".encode("utf-8")))
file = File.create(path=path, commit=True) # noqa: F821
now = arrow.now()
delete_batch_import_id = BatchImport.create(
user_id=user.id,
file_id=file.id,
created_at=now.shift(minutes=-1),
flush=True,
).id
keep_batch_import_id = BatchImport.create(
user_id=user.id,
file_id=file.id,
created_at=now.shift(minutes=+1),
commit=True,
).id
cleanup_old_imports(now)
assert BatchImport.get(id=delete_batch_import_id) is None
assert BatchImport.get(id=keep_batch_import_id) is not None

View File

@ -0,0 +1,72 @@
import arrow
from app import config
from app.models import Job, JobState
from tasks.cleanup_old_jobs import cleanup_old_jobs
def test_cleanup_old_jobs():
Job.filter().delete()
now = arrow.now()
delete_ids = [
Job.create(
updated_at=now.shift(minutes=-1),
state=JobState.done.value,
name="",
payload="",
flush=True,
).id,
Job.create(
updated_at=now.shift(minutes=-1),
state=JobState.error.value,
name="",
payload="",
flush=True,
).id,
Job.create(
updated_at=now.shift(minutes=-1),
state=JobState.taken.value,
attempts=config.JOB_MAX_ATTEMPTS,
name="",
payload="",
flush=True,
).id,
]
keep_ids = [
Job.create(
updated_at=now.shift(minutes=+1),
state=JobState.done.value,
name="",
payload="",
flush=True,
).id,
Job.create(
updated_at=now.shift(minutes=+1),
state=JobState.error.value,
name="",
payload="",
flush=True,
).id,
Job.create(
updated_at=now.shift(minutes=+1),
state=JobState.taken.value,
attempts=config.JOB_MAX_ATTEMPTS,
name="",
payload="",
flush=True,
).id,
Job.create(
updated_at=now.shift(minutes=-1),
state=JobState.taken.value,
attempts=config.JOB_MAX_ATTEMPTS - 1,
name="",
payload="",
flush=True,
).id,
]
cleanup_old_jobs(now)
for delete_id in delete_ids:
assert Job.get(id=delete_id) is None
for keep_id in keep_ids:
assert Job.get(id=keep_id) is not None

View File

@ -0,0 +1,26 @@
import arrow
from app.models import Notification
from tasks.cleanup_old_notifications import cleanup_old_notifications
from tests.utils import create_new_user
def test_cleanup_old_notifications():
Notification.filter().delete()
user = create_new_user()
now = arrow.now()
delete_id = Notification.create(
user_id=user.id,
created_at=now.shift(minutes=-1),
message="",
flush=True,
).id
keep_id = Notification.create(
user_id=user.id,
created_at=now.shift(minutes=+1),
message="",
flush=True,
).id
cleanup_old_notifications(now)
assert Notification.get(id=delete_id) is None
assert Notification.get(id=keep_id) is not None

View File

@ -49,6 +49,16 @@ GOOGLE_CLIENT_SECRET=to_fill
FACEBOOK_CLIENT_ID=to_fill FACEBOOK_CLIENT_ID=to_fill
FACEBOOK_CLIENT_SECRET=to_fill FACEBOOK_CLIENT_SECRET=to_fill
# Login with OIDC
CONNECT_WITH_OIDC_ICON=fa-github
OIDC_AUTHORIZATION_URL=to_fill
OIDC_USER_INFO_URL=to_fill
OIDC_TOKEN_URL=to_fill
OIDC_SCOPES=openid email profile
OIDC_NAME_FIELD=name
OIDC_CLIENT_ID=to_fill
OIDC_CLIENT_SECRET=to_fill
PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc
ALIAS_AUTOMATIC_DISABLE=true ALIAS_AUTOMATIC_DISABLE=true

View File

@ -384,3 +384,30 @@ def test_break_loop_alias_as_mailbox(flask_client):
msg[headers.SUBJECT] = random_string() msg[headers.SUBJECT] = random_string()
result = email_handler.handle(envelope, msg) result = email_handler.handle(envelope, msg)
assert result == status.E525 assert result == status.E525
@mail_sender.store_emails_test_decorator
def test_preserve_headers(flask_client):
headers_to_keep = [
headers.SUBJECT,
headers.DATE,
headers.MESSAGE_ID,
headers.REFERENCES,
headers.IN_REPLY_TO,
headers.SL_QUEUE_ID,
] + headers.MIME_HEADERS
user = create_new_user()
alias = Alias.create_new_random(user)
envelope = Envelope()
envelope.mail_from = "somewhere@lo.cal"
envelope.rcpt_tos = [alias.email]
msg = EmailMessage()
for header in headers_to_keep:
msg[header] = header + "keep"
result = email_handler.handle(envelope, msg)
assert result == status.E200
sent_mails = mail_sender.get_stored_emails()
assert len(sent_mails) == 1
msg = sent_mails[0].msg
for header in headers_to_keep:
assert msg[header] == header + "keep"