101 lines
2.9 KiB
Python
101 lines
2.9 KiB
Python
import arrow
|
|
|
|
from app.account_linking import (
|
|
SLPlan,
|
|
SLPlanType,
|
|
set_plan_for_partner_user,
|
|
)
|
|
from app.db import Session
|
|
from app.models import User, PartnerUser, PartnerSubscription
|
|
from app.proton.proton_partner import get_proton_partner
|
|
from app.utils import random_string
|
|
from tests.utils import random_email
|
|
|
|
# Id of the PartnerUser row created once in setup_module(); every test in this
# module reloads that same partner user through it. Stays 0 until
# setup_module() has run.
partner_user_id: int = 0
|
|
|
|
|
|
def setup_module():
    """Create, once per module, a user linked to the Proton partner.

    The new PartnerUser's id is stored in the module-level ``partner_user_id``
    so that each test can reload the same partner user.
    """
    global partner_user_id
    shared_email = random_email()
    proton_external_id = random_string()
    linked_user = User.create(shared_email, commit=True)
    proton_partner_id = get_proton_partner().id
    new_partner_user = PartnerUser.create(
        user_id=linked_user.id,
        partner_id=proton_partner_id,
        external_user_id=proton_external_id,
        partner_email=shared_email,
        commit=True,
    )
    partner_user_id = new_partner_user.id
|
|
|
|
|
|
def setup_function(func):
    """Start every test with no PartnerSubscription rows.

    ``func`` is the test about to run, as passed by pytest's xunit-style
    hook; it is not used. No commit is issued here.
    """
    all_subscriptions = Session.query(PartnerSubscription)
    all_subscriptions.delete()
|
|
|
|
|
|
def test_free_plan_removes_sub():
    """Moving a partner user to the Free plan deletes their subscription."""
    partner_user = PartnerUser.get(partner_user_id)
    existing_sub_id = PartnerSubscription.create(
        partner_user_id=partner_user_id,
        end_at=arrow.utcnow(),
        lifetime=False,
        commit=True,
    ).id

    free_plan = SLPlan(type=SLPlanType.Free, expiration=None)
    set_plan_for_partner_user(partner_user, plan=free_plan)

    assert PartnerSubscription.get(existing_sub_id) is None
|
|
|
|
|
|
def test_premium_plan_updates_expiration():
    """A Premium plan updates the end date of an existing subscription.

    The subscription created here must survive the call with its ``end_at``
    set to the plan's expiration.
    """
    pu = PartnerUser.get(partner_user_id)
    sub_id = PartnerSubscription.create(
        partner_user_id=partner_user_id,
        end_at=arrow.utcnow(),
        lifetime=False,
        commit=True,
    ).id
    new_expiration = arrow.utcnow().shift(days=+10)
    set_plan_for_partner_user(
        pu, plan=SLPlan(type=SLPlanType.Premium, expiration=new_expiration)
    )
    sub = PartnerSubscription.get(sub_id)
    # Fail with a clear assertion, not an AttributeError on None, if the
    # subscription was removed instead of updated.
    assert sub is not None
    assert sub.end_at == new_expiration
|
|
|
|
|
|
def test_premium_plan_creates_sub():
    """A Premium plan creates a subscription when the partner user has none.

    setup_function() has cleared all subscriptions, so the one found
    afterwards must have been created by set_plan_for_partner_user.
    """
    pu = PartnerUser.get(partner_user_id)
    new_expiration = arrow.utcnow().shift(days=+10)
    set_plan_for_partner_user(
        pu, plan=SLPlan(type=SLPlanType.Premium, expiration=new_expiration)
    )
    sub = PartnerSubscription.get_by(partner_user_id=partner_user_id)
    # Assert existence first so a missing row reports a clear failure
    # instead of an AttributeError on None.
    assert sub is not None
    assert sub.end_at == new_expiration
|
|
|
|
|
|
def test_lifetime_creates_sub():
    """A PremiumLifetime plan creates a lifetime subscription with no end date.

    The test supplies an expiration on purpose; the resulting subscription
    must still have ``end_at`` set to None.
    """
    partner_user = PartnerUser.get(partner_user_id)
    ignored_expiration = arrow.utcnow().shift(days=+10)
    lifetime_plan = SLPlan(
        type=SLPlanType.PremiumLifetime, expiration=ignored_expiration
    )
    set_plan_for_partner_user(partner_user, plan=lifetime_plan)

    created = PartnerSubscription.get_by(partner_user_id=partner_user_id)
    assert created is not None
    assert created.end_at is None
    assert created.lifetime
|
|
|
|
|
|
def test_lifetime_updates_sub():
    """A PremiumLifetime plan turns an existing subscription into a lifetime one.

    The same row (looked up by its id) must end up with ``lifetime`` set and
    no end date.
    """
    partner_user = PartnerUser.get(partner_user_id)
    existing_sub_id = PartnerSubscription.create(
        partner_user_id=partner_user_id,
        end_at=arrow.utcnow(),
        lifetime=False,
        commit=True,
    ).id

    lifetime_plan = SLPlan(
        type=SLPlanType.PremiumLifetime, expiration=arrow.utcnow()
    )
    set_plan_for_partner_user(partner_user, plan=lifetime_plan)

    updated = PartnerSubscription.get(existing_sub_id)
    assert updated is not None
    assert updated.end_at is None
    assert updated.lifetime
|