4.21.3
This commit is contained in:
0
app/app/proton/__init__.py
Normal file
0
app/app/proton/__init__.py
Normal file
110
app/app/proton/proton_callback_handler.py
Normal file
110
app/app/proton/proton_callback_handler.py
Normal file
@ -0,0 +1,110 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from flask import url_for
|
||||
from typing import Optional
|
||||
|
||||
from app.errors import LinkException
|
||||
from app.models import User, Partner
|
||||
from app.proton.proton_client import ProtonClient, ProtonUser
|
||||
from app.account_linking import (
|
||||
process_login_case,
|
||||
process_link_case,
|
||||
PartnerLinkRequest,
|
||||
)
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
Login = 1
|
||||
Link = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtonCallbackResult:
|
||||
redirect_to_login: bool
|
||||
flash_message: Optional[str]
|
||||
flash_category: Optional[str]
|
||||
redirect: Optional[str]
|
||||
user: Optional[User]
|
||||
|
||||
|
||||
def generate_account_not_allowed_to_log_in() -> ProtonCallbackResult:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=True,
|
||||
flash_message="This account is not allowed to log in with Proton. Please convert your account to a full Proton account",
|
||||
flash_category="error",
|
||||
redirect=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
|
||||
class ProtonCallbackHandler:
|
||||
def __init__(self, proton_client: ProtonClient):
|
||||
self.proton_client = proton_client
|
||||
|
||||
def handle_login(self, partner: Partner) -> ProtonCallbackResult:
|
||||
try:
|
||||
user = self.__get_partner_user()
|
||||
if user is None:
|
||||
return generate_account_not_allowed_to_log_in()
|
||||
res = process_login_case(user, partner)
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=res.user,
|
||||
)
|
||||
except LinkException as e:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=True,
|
||||
flash_message=e.message,
|
||||
flash_category="error",
|
||||
redirect=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
def handle_link(
|
||||
self,
|
||||
current_user: Optional[User],
|
||||
partner: Partner,
|
||||
) -> ProtonCallbackResult:
|
||||
if current_user is None:
|
||||
raise Exception("Cannot link account with current_user being None")
|
||||
try:
|
||||
user = self.__get_partner_user()
|
||||
if user is None:
|
||||
return generate_account_not_allowed_to_log_in()
|
||||
res = process_link_case(user, current_user, partner)
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message="Account successfully linked",
|
||||
flash_category="success",
|
||||
redirect=url_for("dashboard.setting"),
|
||||
user=res.user,
|
||||
)
|
||||
except LinkException as e:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=e.message,
|
||||
flash_category="error",
|
||||
redirect=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
def __get_partner_user(self) -> Optional[PartnerLinkRequest]:
|
||||
proton_user = self.__get_proton_user()
|
||||
if proton_user is None:
|
||||
return None
|
||||
return PartnerLinkRequest(
|
||||
email=proton_user.email,
|
||||
external_user_id=proton_user.id,
|
||||
name=proton_user.name,
|
||||
plan=proton_user.plan,
|
||||
from_partner=False, # The user has started this flow, so we don't mark it as created by a partner
|
||||
)
|
||||
|
||||
def __get_proton_user(self) -> Optional[ProtonUser]:
|
||||
user = self.proton_client.get_user()
|
||||
if user is None:
|
||||
return None
|
||||
return ProtonUser(email=user.email, plan=user.plan, name=user.name, id=user.id)
|
||||
137
app/app/proton/proton_client.py
Normal file
137
app/app/proton/proton_client.py
Normal file
@ -0,0 +1,137 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from arrow import Arrow
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from requests import Response, Session
|
||||
from typing import Optional
|
||||
|
||||
from app.account_linking import SLPlan, SLPlanType
|
||||
from app.config import PROTON_EXTRA_HEADER_NAME, PROTON_EXTRA_HEADER_VALUE
|
||||
from app.log import LOG
|
||||
|
||||
_APP_VERSION = "OauthClient_1.0.0"
|
||||
|
||||
PROTON_ERROR_CODE_NOT_EXISTS = 2501
|
||||
|
||||
PLAN_FREE = 1
|
||||
PLAN_PREMIUM = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInformation:
|
||||
email: str
|
||||
name: str
|
||||
id: str
|
||||
plan: SLPlan
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtonUser:
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
plan: SLPlan
|
||||
|
||||
|
||||
@dataclass
|
||||
class AccessCredentials:
|
||||
access_token: str
|
||||
session_id: str
|
||||
|
||||
|
||||
def convert_access_token(access_token_response: str) -> AccessCredentials:
|
||||
"""
|
||||
The Access token response contains both the Proton Session ID and the Access Token.
|
||||
The Session ID is necessary in order to use the Proton API. However, the OAuth response does not allow us to return
|
||||
extra content.
|
||||
This method takes the Access token response and extracts the session ID and the access token.
|
||||
"""
|
||||
parts = access_token_response.split("-")
|
||||
if len(parts) != 3:
|
||||
raise Exception("Invalid access token response")
|
||||
if parts[0] != "pt":
|
||||
raise Exception("Invalid access token response format")
|
||||
return AccessCredentials(
|
||||
session_id=parts[1],
|
||||
access_token=parts[2],
|
||||
)
|
||||
|
||||
|
||||
class ProtonClient(ABC):
|
||||
@abstractmethod
|
||||
def get_user(self) -> Optional[UserInformation]:
|
||||
pass
|
||||
|
||||
|
||||
class HttpProtonClient(ProtonClient):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
credentials: AccessCredentials,
|
||||
original_ip: Optional[str],
|
||||
verify: bool = True,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.access_token = credentials.access_token
|
||||
client = Session()
|
||||
client.verify = verify
|
||||
headers = {
|
||||
"x-pm-appversion": _APP_VERSION,
|
||||
"x-pm-apiversion": "3",
|
||||
"x-pm-uid": credentials.session_id,
|
||||
"authorization": f"Bearer {credentials.access_token}",
|
||||
"accept": "application/vnd.protonmail.v1+json",
|
||||
"user-agent": "ProtonOauthClient",
|
||||
}
|
||||
|
||||
if PROTON_EXTRA_HEADER_NAME and PROTON_EXTRA_HEADER_VALUE:
|
||||
headers[PROTON_EXTRA_HEADER_NAME] = PROTON_EXTRA_HEADER_VALUE
|
||||
|
||||
if original_ip is not None:
|
||||
headers["x-forwarded-for"] = original_ip
|
||||
client.headers.update(headers)
|
||||
self.client = client
|
||||
|
||||
def get_user(self) -> Optional[UserInformation]:
|
||||
info = self.__get("/simple_login/v1/subscription")["Subscription"]
|
||||
if not info["IsAllowed"]:
|
||||
LOG.debug("Account is not allowed to log into SL")
|
||||
return None
|
||||
|
||||
plan_value = info["Plan"]
|
||||
if plan_value == PLAN_FREE:
|
||||
plan = SLPlan(type=SLPlanType.Free, expiration=None)
|
||||
elif plan_value == PLAN_PREMIUM:
|
||||
plan = SLPlan(
|
||||
type=SLPlanType.Premium,
|
||||
expiration=Arrow.fromtimestamp(info["PlanExpiration"], tzinfo="utc"),
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Invalid value for plan: {plan_value}")
|
||||
|
||||
return UserInformation(
|
||||
email=info.get("Email"),
|
||||
name=info.get("DisplayName"),
|
||||
id=info.get("UserID"),
|
||||
plan=plan,
|
||||
)
|
||||
|
||||
def __get(self, route: str) -> dict:
|
||||
url = f"{self.base_url}{route}"
|
||||
res = self.client.get(url)
|
||||
return self.__validate_response(res)
|
||||
|
||||
@staticmethod
|
||||
def __validate_response(res: Response) -> dict:
|
||||
status = res.status_code
|
||||
if status != HTTPStatus.OK:
|
||||
raise Exception(
|
||||
f"Unexpected status code. Wanted 200 and got {status}: " + res.text
|
||||
)
|
||||
as_json = res.json()
|
||||
res_code = as_json.get("Code")
|
||||
if not res_code or res_code != 1000:
|
||||
raise Exception(
|
||||
f"Unexpected response code. Wanted 1000 and got {res_code}: " + res.text
|
||||
)
|
||||
return as_json
|
||||
35
app/app/proton/utils.py
Normal file
35
app/app/proton/utils.py
Normal file
@ -0,0 +1,35 @@
|
||||
from newrelic import agent
|
||||
from typing import Optional
|
||||
|
||||
from app.db import Session
|
||||
from app.errors import ProtonPartnerNotSetUp
|
||||
from app.models import Partner, PartnerUser, User
|
||||
|
||||
PROTON_PARTNER_NAME = "Proton"
|
||||
_PROTON_PARTNER: Optional[Partner] = None
|
||||
|
||||
|
||||
def get_proton_partner() -> Partner:
|
||||
global _PROTON_PARTNER
|
||||
if _PROTON_PARTNER is None:
|
||||
partner = Partner.get_by(name=PROTON_PARTNER_NAME)
|
||||
if partner is None:
|
||||
raise ProtonPartnerNotSetUp
|
||||
Session.expunge(partner)
|
||||
_PROTON_PARTNER = partner
|
||||
return _PROTON_PARTNER
|
||||
|
||||
|
||||
def is_proton_partner(partner: Partner) -> bool:
|
||||
return partner.name == PROTON_PARTNER_NAME
|
||||
|
||||
|
||||
def perform_proton_account_unlink(current_user: User):
|
||||
proton_partner = get_proton_partner()
|
||||
partner_user = PartnerUser.get_by(
|
||||
user_id=current_user.id, partner_id=proton_partner.id
|
||||
)
|
||||
if partner_user is not None:
|
||||
PartnerUser.delete(partner_user.id)
|
||||
Session.commit()
|
||||
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})
|
||||
Reference in New Issue
Block a user