diff --git a/app/app/auth/__init__.py b/app/app/auth/__init__.py index e8adcc6..7a6cdee 100644 --- a/app/app/auth/__init__.py +++ b/app/app/auth/__init__.py @@ -16,6 +16,7 @@ from .views import ( social, recovery, api_to_cookie, + oidc, ) __all__ = [ @@ -36,4 +37,5 @@ __all__ = [ "social", "recovery", "api_to_cookie", + "oidc", ] diff --git a/app/app/auth/views/login.py b/app/app/auth/views/login.py index 56b2ac3..261c1a9 100644 --- a/app/app/auth/views/login.py +++ b/app/app/auth/views/login.py @@ -5,7 +5,7 @@ from wtforms import StringField, validators from app.auth.base import auth_bp from app.auth.views.login_utils import after_login -from app.config import CONNECT_WITH_PROTON +from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON, OIDC_CLIENT_ID from app.events.auth_event import LoginEvent from app.extensions import limiter from app.log import LOG @@ -77,4 +77,6 @@ def login(): next_url=next_url, show_resend_activation=show_resend_activation, connect_with_proton=CONNECT_WITH_PROTON, + connect_with_oidc=OIDC_CLIENT_ID is not None, + connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON, ) diff --git a/app/app/auth/views/oidc.py b/app/app/auth/views/oidc.py new file mode 100644 index 0000000..12c4e49 --- /dev/null +++ b/app/app/auth/views/oidc.py @@ -0,0 +1,131 @@ +from flask import request, session, redirect, flash, url_for +from requests_oauthlib import OAuth2Session + +from app import config +from app.auth.base import auth_bp +from app.auth.views.login_utils import after_login +from app.config import ( + URL, + OIDC_AUTHORIZATION_URL, + OIDC_USER_INFO_URL, + OIDC_TOKEN_URL, + OIDC_SCOPES, + OIDC_NAME_FIELD, +) +from app.db import Session +from app.email_utils import send_welcome_email +from app.log import LOG +from app.models import User, SocialAuth +from app.utils import encode_url, sanitize_email, sanitize_next_url + + +# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri +# when served behind nginx, the redirect_uri is localhost... and not the real url +_redirect_uri = URL + "/auth/oidc/callback" + +SESSION_STATE_KEY = "oauth_state" + + +@auth_bp.route("/oidc/login") +def oidc_login(): + if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None: + return redirect(url_for("auth.login")) + + next_url = sanitize_next_url(request.args.get("next")) + if next_url: + redirect_uri = _redirect_uri + "?next=" + encode_url(next_url) + else: + redirect_uri = _redirect_uri + + oidc = OAuth2Session( + config.OIDC_CLIENT_ID, scope=[OIDC_SCOPES], redirect_uri=redirect_uri + ) + authorization_url, state = oidc.authorization_url(OIDC_AUTHORIZATION_URL) + + # State is used to prevent CSRF, keep this for later. + session[SESSION_STATE_KEY] = state + return redirect(authorization_url) + + +@auth_bp.route("/oidc/callback") +def oidc_callback(): + if SESSION_STATE_KEY not in session: + flash("Invalid state, please retry", "error") + return redirect(url_for("auth.login")) + if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None: + return redirect(url_for("auth.login")) + + # user clicks on cancel + if "error" in request.args: + flash("Please use another sign in method then", "warning") + return redirect("/") + + oidc = OAuth2Session( + config.OIDC_CLIENT_ID, + state=session[SESSION_STATE_KEY], + scope=[OIDC_SCOPES], + redirect_uri=_redirect_uri, + ) + oidc.fetch_token( + OIDC_TOKEN_URL, + client_secret=config.OIDC_CLIENT_SECRET, + authorization_response=request.url, + ) + + oidc_user_data = oidc.get(OIDC_USER_INFO_URL) + if oidc_user_data.status_code != 200: + LOG.e( + f"cannot get oidc user data {oidc_user_data.status_code} {oidc_user_data.text}" + ) + flash( + "Cannot get user data from OIDC, please use another way to login/sign up", + "error", + ) + return redirect(url_for("auth.login")) + oidc_user_data = oidc_user_data.json() + + email = oidc_user_data.get("email") + + if not email: + LOG.e(f"cannot get email for OIDC user {oidc_user_data} {email}") + flash( + "Cannot get a valid email from OIDC, please another way to login/sign up", + "error", + ) + return redirect(url_for("auth.login")) + + email = sanitize_email(email) + user = User.get_by(email=email) + + if not user and config.DISABLE_REGISTRATION: + flash( + "Sorry you cannot sign up via the OIDC provider. Please sign-up first with your email.", + "error", + ) + return redirect(url_for("auth.register")) + elif not user: + user = create_user(email, oidc_user_data) + + if not SocialAuth.get_by(user_id=user.id, social="oidc"): + SocialAuth.create(user_id=user.id, social="oidc") + Session.commit() + + # The activation link contains the original page, for ex authorize page + next_url = sanitize_next_url(request.args.get("next")) if request.args else None + + return after_login(user, next_url) + + +def create_user(email, oidc_user_data): + new_user = User.create( + email=email, + name=oidc_user_data.get(OIDC_NAME_FIELD), + password="", + activated=True, + ) + LOG.i(f"Created new user for login request from OIDC. New user {new_user.id}") + Session.commit() + + send_welcome_email(new_user) + + return new_user diff --git a/app/app/auth/views/register.py b/app/app/auth/views/register.py index f40a98a..7505303 100644 --- a/app/app/auth/views/register.py +++ b/app/app/auth/views/register.py @@ -6,7 +6,7 @@ from wtforms import StringField, validators from app import email_utils, config from app.auth.base import auth_bp -from app.config import CONNECT_WITH_PROTON +from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON from app.auth.views.login_utils import get_referral from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY from app.db import Session @@ -109,6 +109,8 @@ def register(): next_url=next_url, HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY, connect_with_proton=CONNECT_WITH_PROTON, + connect_with_oidc=config.OIDC_CLIENT_ID is not None, + connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON, ) diff --git a/app/app/config.py b/app/app/config.py index e8824f5..c55c555 100644 --- a/app/app/config.py +++ b/app/app/config.py @@ -234,7 +234,7 @@ else: print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME) -# Github, Google, Facebook client id and secrets +# Github, Google, Facebook, OIDC client id and secrets GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID") GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET") @@ -244,6 +244,15 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET") FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID") FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET") +CONNECT_WITH_OIDC_ICON = os.environ.get("CONNECT_WITH_OIDC_ICON") +OIDC_AUTHORIZATION_URL = os.environ.get("OIDC_AUTHORIZATION_URL") +OIDC_USER_INFO_URL = os.environ.get("OIDC_USER_INFO_URL") +OIDC_TOKEN_URL = os.environ.get("OIDC_TOKEN_URL") +OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID") +OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET") +OIDC_SCOPES = os.environ.get("OIDC_SCOPES") +OIDC_NAME_FIELD = os.environ.get("OIDC_NAME_FIELD", "name") + PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID") PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET") PROTON_BASE_URL = os.environ.get( @@ -423,6 +432,7 @@ except Exception: HIBP_API_KEYS = sl_getenv("HIBP_API_KEYS", list) or [] HIBP_MAX_ALIAS_CHECK = 10_000 HIBP_RPM = 100 +HIBP_SKIP_PARTNER_ALIAS = os.environ.get("HIBP_SKIP_PARTNER_ALIAS") POSTMASTER = os.environ.get("POSTMASTER") diff --git a/app/app/dashboard/views/enter_sudo.py b/app/app/dashboard/views/enter_sudo.py index d32deb8..3910873 100644 --- a/app/app/dashboard/views/enter_sudo.py +++ b/app/app/dashboard/views/enter_sudo.py @@ -6,11 +6,11 @@ from flask_login import login_required, current_user from flask_wtf import FlaskForm from wtforms import PasswordField, validators -from app.config import CONNECT_WITH_PROTON +from app.config import CONNECT_WITH_PROTON, OIDC_CLIENT_ID, CONNECT_WITH_OIDC_ICON from app.dashboard.base import dashboard_bp from app.extensions import limiter from app.log import LOG -from app.models import PartnerUser +from app.models import PartnerUser, SocialAuth from app.proton.utils import get_proton_partner from app.utils import sanitize_next_url @@ -51,11 +51,19 @@ def enter_sudo(): if not partner_user or partner_user.partner_id != get_proton_partner().id: proton_enabled = False + oidc_enabled = OIDC_CLIENT_ID is not None + if oidc_enabled: + oidc_enabled = ( + SocialAuth.get_by(user_id=current_user.id, social="oidc") is not None + ) + return render_template( "dashboard/enter_sudo.html", password_check_form=password_check_form, next=request.args.get("next"), connect_with_proton=proton_enabled, + connect_with_oidc=oidc_enabled, + connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON, ) diff --git a/app/app/models.py b/app/app/models.py index 06f585d..472b98a 100644 --- a/app/app/models.py +++ b/app/app/models.py @@ -1422,6 +1422,9 @@ def generate_random_alias_email( class Alias(Base, ModelMixin): __tablename__ = "alias" + + FLAG_PARTNER_CREATED = 1 << 0 + user_id = sa.Column( sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True ) @@ -1431,6 +1434,9 @@ class Alias(Base, ModelMixin): name = sa.Column(sa.String(128), nullable=True, default=None) enabled = sa.Column(sa.Boolean(), default=True, nullable=False) + flags = sa.Column( + sa.BigInteger(), default=0, server_default="0", nullable=False, index=True + ) custom_domain_id = sa.Column( sa.ForeignKey("custom_domain.id", ondelete="cascade"), nullable=True, index=True @@ -2586,10 +2592,13 @@ class Job(Base, ModelMixin): nullable=False, server_default=str(JobState.ready.value), default=JobState.ready.value, + index=True, ) attempts = sa.Column(sa.Integer, nullable=False, server_default="0", default=0) taken_at = sa.Column(ArrowType, nullable=True) + __table_args__ = (Index("ix_state_run_at_taken_at", state, run_at, taken_at),) + def __repr__(self): return f"" @@ -2937,7 +2946,9 @@ class RecoveryCode(Base, ModelMixin): class Notification(Base, ModelMixin): __tablename__ = "notification" - user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True + ) message = sa.Column(sa.Text, nullable=False) title = sa.Column(sa.String(512)) diff --git a/app/cron.py b/app/cron.py index 503046e..b344696 100644 --- a/app/cron.py +++ b/app/cron.py @@ -1059,13 +1059,15 @@ async def check_hibp(): LOG.d("Preparing list of aliases to check") queue = asyncio.Queue() max_date = arrow.now().shift(days=-config.HIBP_SCAN_INTERVAL_DAYS) + alias_query = Alias.filter( + or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date), + Alias.user_id.notin_(user_ids), + Alias.enabled, + ) + if config.HIBP_SKIP_PARTNER_ALIAS: + alias_query = alias_query(Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0) for alias in ( - Alias.filter( - or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date), - Alias.user_id.notin_(user_ids), - ) - .filter(Alias.enabled) - .order_by(nullsfirst(Alias.hibp_last_check.asc()), Alias.id.asc()) + alias_query.order_by(nullsfirst(Alias.hibp_last_check.asc()), Alias.id.asc()) .yield_per(500) .enable_eagerloads(False) ): diff --git a/app/example.env b/app/example.env index d2e22b6..4ee0951 100644 --- a/app/example.env +++ b/app/example.env @@ -116,6 +116,16 @@ WORDS_FILE_PATH=local_data/test_words.txt # CONNECT_WITH_PROTON=true # CONNECT_WITH_PROTON_COOKIE_NAME=to_fill +# Login with OIDC +# CONNECT_WITH_OIDC_ICON=fa-github +# OIDC_AUTHORIZATION_URL=to_fill +# OIDC_USER_INFO_URL=to_fill +# OIDC_TOKEN_URL=to_fill +# OIDC_SCOPES=openid email profile +# OIDC_NAME_FIELD=name +# OIDC_CLIENT_ID=to_fill +# OIDC_CLIENT_SECRET=to_fill + # Flask profiler # FLASK_PROFILER_PATH=/tmp/flask-profiler.sql # FLASK_PROFILER_PASSWORD=password diff --git a/app/migrations/versions/2024_031212_52510a633d6f_.py b/app/migrations/versions/2024_031212_52510a633d6f_.py new file mode 100644 index 0000000..3ebbeba --- /dev/null +++ b/app/migrations/versions/2024_031212_52510a633d6f_.py @@ -0,0 +1,48 @@ +"""empty message + +Revision ID: 52510a633d6f +Revises: 818b0a956205 +Create Date: 2024-03-12 12:46:24.161644 + +""" +import sqlalchemy_utils +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "52510a633d6f" +down_revision = "818b0a956205" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "alias", sa.Column("flags", sa.BigInteger(), server_default="0", nullable=False) + ) + with op.get_context().autocommit_block(): + op.create_index(op.f("ix_alias_flags"), "alias", ["flags"], unique=False) + op.create_index(op.f("ix_job_state"), "job", ["state"], unique=False) + op.create_index( + "ix_state_run_at_taken_at", + "job", + ["state", "run_at", "taken_at"], + unique=False, + ) + op.create_index( + op.f("ix_notification_user_id"), "notification", ["user_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index(op.f("ix_notification_user_id"), table_name="notification") + op.drop_index("ix_state_run_at_taken_at", table_name="job") + op.drop_index(op.f("ix_job_state"), table_name="job") + op.drop_index(op.f("ix_alias_flags"), table_name="alias") + op.drop_column("alias", "flags") + # ### end Alembic commands ### diff --git a/app/oneshot/mark_aliases_as_partner_created.py b/app/oneshot/mark_aliases_as_partner_created.py new file mode 100644 index 0000000..cf99db9 --- /dev/null +++ b/app/oneshot/mark_aliases_as_partner_created.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +import argparse + + +from app.log import LOG +from app.models import Alias, SLDomain +from app.db import Session + +parser = argparse.ArgumentParser( + prog="Mark partner created aliases with the PARTNER_CREATED flag", +) +args = parser.parse_args() + +domains = SLDomain.filter(SLDomain.partner_id.isnot(None)).all() + +for domain in domains: + LOG.i(f"Checking aliases for domain {domain.domain}") + for alias in ( + Alias.filter( + Alias.email.like(f"%{domain.domain}"), + Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0, + ) + .enable_eagerloads(False) + .yield_per(100) + .all() + ): + alias.flags = alias.flags | Alias.FLAG_PARTNER_CREATED + LOG.i(f" * Updating {alias.email} to {alias.flags}") + Session.commit() diff --git a/app/templates/auth/login.html b/app/templates/auth/login.html index 0451fd3..898e6d0 100644 --- a/app/templates/auth/login.html +++ b/app/templates/auth/login.html @@ -38,11 +38,21 @@ or + href="{{ url_for('auth.proton_login', next=next_url) }}"> Log in with Proton {% endif %} + {% if connect_with_oidc %} + +
+ or +
+ + Log in with SSO + + {% endif %}
diff --git a/app/templates/auth/register.html b/app/templates/auth/register.html index 6be4f4c..c8400e1 100644 --- a/app/templates/auth/register.html +++ b/app/templates/auth/register.html @@ -50,11 +50,21 @@ or
+ href="{{ url_for('auth.proton_login', next=next_url) }}"> Sign up with Proton {% endif %} + {% if connect_with_oidc %} + +
+ or +
+ + Sign up with SSO + + {% endif %}
diff --git a/app/templates/dashboard/enter_sudo.html b/app/templates/dashboard/enter_sudo.html index ed1f895..3175f37 100644 --- a/app/templates/dashboard/enter_sudo.html +++ b/app/templates/dashboard/enter_sudo.html @@ -22,11 +22,20 @@

Alternatively you can use your Proton credentials to ensure it's you.

+ href="{{ url_for('auth.proton_login', next=next) }}"> Authenticate with Proton {% endif %} + {% if connect_with_oidc %} + +
+

Alternatively you can use your SSO credentials to ensure it's you.

+ + Authenticate with SSO + + {% endif %} +
- -{% endblock %} + {% endblock %} diff --git a/app/templates/emails/com/newsletter/safari-extension.html b/app/templates/emails/com/newsletter/safari-extension.html index b9bf9b5..969db62 100644 --- a/app/templates/emails/com/newsletter/safari-extension.html +++ b/app/templates/emails/com/newsletter/safari-extension.html @@ -28,7 +28,7 @@ {{ render_text("Hi") }} {{ render_text("If you use Safari on a MacBook or iMac, you should check out our new Safari extension.") }} {{ render_text('It can be installed on - App Store + App Store . Its code is available on GitHub .') }} diff --git a/app/templates/emails/com/newsletter/safari-extension.txt b/app/templates/emails/com/newsletter/safari-extension.txt index 4056aa1..82d3bc9 100644 --- a/app/templates/emails/com/newsletter/safari-extension.txt +++ b/app/templates/emails/com/newsletter/safari-extension.txt @@ -8,7 +8,7 @@ If you use Safari on a MacBook or iMac, you should check out our new Safari exte It can be installed on: -https://apps.apple.com/app/id1494051017 +https://apps.apple.com/app/id6475835429 As usual, let me know if you have any question by replying to this email. diff --git a/app/templates/emails/com/onboarding/browser-extension.html b/app/templates/emails/com/onboarding/browser-extension.html index 9a43549..b2007af 100644 --- a/app/templates/emails/com/onboarding/browser-extension.html +++ b/app/templates/emails/com/onboarding/browser-extension.html @@ -12,7 +12,7 @@ If you want to quickly create aliases without going to SimpleLogin websit (or other Chromium-based browsers like Brave or Vivaldi), Firefox and -Safari +Safari extension. {% endcall %} diff --git a/app/templates/emails/com/onboarding/browser-extension.txt b/app/templates/emails/com/onboarding/browser-extension.txt index ebb9f8d..2407951 100644 --- a/app/templates/emails/com/onboarding/browser-extension.txt +++ b/app/templates/emails/com/onboarding/browser-extension.txt @@ -11,7 +11,7 @@ Chrome: https://chrome.google.com/webstore/detail/dphilobhebphkdjbpfohgikllaljmg Firefox: https://addons.mozilla.org/firefox/addon/simplelogin/ -Safari: https://apps.apple.com/app/id1494051017 +Safari: https://apps.apple.com/app/id6475835429 You can also manage your aliases using SimpleLogin mobile apps, available at - Play Store https://play.google.com/store/apps/details?id=io.simplelogin.android diff --git a/app/templates/footer.html b/app/templates/footer.html index bea517e..180b6c7 100644 --- a/app/templates/footer.html +++ b/app/templates/footer.html @@ -124,7 +124,7 @@
  • + href="https://apps.apple.com/app/id6475835429"> Safari Extension diff --git a/app/tests/auth/test_oidc.py b/app/tests/auth/test_oidc.py new file mode 100644 index 0000000..e35bb5e --- /dev/null +++ b/app/tests/auth/test_oidc.py @@ -0,0 +1,304 @@ +from app import config +from flask import url_for +from urllib.parse import parse_qs +from urllib3.util import parse_url +from app.auth.views.oidc import create_user +from app.utils import random_string +from unittest.mock import patch +from app.models import User + +from app.config import URL, OIDC_CLIENT_ID + + +def test_oidc_login(flask_client): + r = flask_client.get( + url_for("auth.oidc_login"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + query = parse_qs(parsed.query) + + expected_redirect_url = f"{URL}/auth/oidc/callback" + + assert "code" == query["response_type"][0] + assert OIDC_CLIENT_ID == query["client_id"][0] + assert expected_redirect_url == query["redirect_uri"][0] + + +def test_oidc_login_no_client_id(flask_client): + config.OIDC_CLIENT_ID = None + + r = flask_client.get( + url_for("auth.oidc_login"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_ID = "to_fill" + + +def test_oidc_login_no_client_secret(flask_client): + config.OIDC_CLIENT_SECRET = None + + r = flask_client.get( + url_for("auth.oidc_login"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_SECRET = "to_fill" + + +def test_oidc_callback_no_oauth_state(flask_client): + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is None + + +def test_oidc_callback_no_client_id(flask_client): + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + config.OIDC_CLIENT_ID = None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_ID = "to_fill" + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +def test_oidc_callback_no_client_secret(flask_client): + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + config.OIDC_CLIENT_SECRET = None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_SECRET = "to_fill" + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_invalid_user(mock_get, mock_fetch_token, flask_client): + mock_get.return_value = MockResponse(400, {}) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_no_email(mock_get, mock_fetch_token, flask_client): + mock_get.return_value = MockResponse(200, {}) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_disabled_registration(mock_get, mock_fetch_token, flask_client): + config.DISABLE_REGISTRATION = True + email = random_string() + mock_get.return_value = MockResponse(200, {"email": email}) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/register" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + config.DISABLE_REGISTRATION = False + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_registration(mock_get, mock_fetch_token, flask_client): + email = random_string() + mock_get.return_value = MockResponse( + 200, + { + "email": email, + config.OIDC_NAME_FIELD: "name", + }, + ) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + user = User.get_by(email=email) + assert user is None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/dashboard/" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + user = User.get_by(email=email) + assert user is not None + assert user.email == email + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_login(mock_get, mock_fetch_token, flask_client): + email = random_string() + mock_get.return_value = MockResponse( + 200, + { + "email": email, + }, + ) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + user = User.create( + email=email, + name="name", + password="", + activated=True, + ) + user = User.get_by(email=email) + assert user is not None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/dashboard/" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +def test_create_user(): + email = random_string() + user = create_user( + email, + { + config.OIDC_NAME_FIELD: "name", + }, + ) + assert user.email == email + assert user.name == "name" + assert user.activated + + +class MockResponse: + def __init__(self, status_code, json_data): + self.status_code = status_code + self.json_data = json_data + self.text = "error" + + def json(self): + return self.json_data diff --git a/app/tests/test.env b/app/tests/test.env index 86d383a..49941be 100644 --- a/app/tests/test.env +++ b/app/tests/test.env @@ -49,6 +49,16 @@ GOOGLE_CLIENT_SECRET=to_fill FACEBOOK_CLIENT_ID=to_fill FACEBOOK_CLIENT_SECRET=to_fill +# Login with OIDC +CONNECT_WITH_OIDC_ICON=fa-github +OIDC_AUTHORIZATION_URL=to_fill +OIDC_USER_INFO_URL=to_fill +OIDC_TOKEN_URL=to_fill +OIDC_SCOPES=openid email profile +OIDC_NAME_FIELD=name +OIDC_CLIENT_ID=to_fill +OIDC_CLIENT_SECRET=to_fill + PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc ALIAS_AUTOMATIC_DISABLE=true