diff --git a/app/app/coupon_utils.py b/app/app/coupon_utils.py index 3aded1f..c1bf71b 100644 --- a/app/app/coupon_utils.py +++ b/app/app/coupon_utils.py @@ -71,13 +71,18 @@ def redeem_coupon(coupon_code: str, user: User) -> Optional[Coupon]: else: sub.end_at = arrow.now().shift(years=coupon.nb_year, days=1) else: - sub = ManualSubscription.create( - user_id=user.id, - end_at=arrow.now().shift(years=coupon.nb_year, days=1), - comment="using coupon code", - is_giveaway=coupon.is_giveaway, - commit=True, - ) + # There may be an expired manual subscription + sub = ManualSubscription.get_by(user_id=user.id) + end_at = arrow.now().shift(years=coupon.nb_year, days=1) + if sub: + sub.end_at = end_at + else: + sub = ManualSubscription.create( + user_id=user.id, + end_at=end_at, + comment="using coupon code", + is_giveaway=coupon.is_giveaway, + ) emit_user_audit_log( user=user, action=UserAuditLogAction.Upgrade, diff --git a/app/oneshot/send_lifetime_user_events.py b/app/oneshot/send_lifetime_user_events.py index a1b0730..abd8621 100644 --- a/app/oneshot/send_lifetime_user_events.py +++ b/app/oneshot/send_lifetime_user_events.py @@ -1,14 +1,14 @@ #!/usr/bin/env python3 import argparse +import sys import time -import arrow from sqlalchemy import func +from app.db import Session from app.events.event_dispatcher import EventDispatcher from app.events.generated.event_pb2 import UserPlanChanged, EventContent from app.models import PartnerUser, User -from app.db import Session parser = argparse.ArgumentParser( prog="Backfill alias", description="Send lifetime users to proton" @@ -19,34 +19,69 @@ parser.add_argument( parser.add_argument( "-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id" ) +parser.add_argument("-t", "--step", default=10000, type=int, help="Step to use") +parser.add_argument("-u", "--user", default="", type=str, help="User to sync") +parser.add_argument( + "-l", "--lifetime", action="store_true", help="Only sync lifetime users" +) args = parser.parse_args() pu_id_start = args.start_pu_id max_pu_id = args.end_pu_id +user_id = args.user +only_lifetime = args.lifetime +step = args.step + if max_pu_id == 0: max_pu_id = Session.query(func.max(PartnerUser.id)).scalar() +if user_id: + try: + user_id = int(user_id) + except ValueError: + user = User.get_by(email=user_id) + if not user: + print(f"User {user_id} not found") + sys.exit(1) + print(f"Limiting to user {user_id}") + user_id = user.id + # So we only have one loop + step = max_pu_id + print(f"Checking partner user {pu_id_start} to {max_pu_id}") -step = 1000 done = 0 start_time = time.time() with_lifetime = 0 +with_plan = 0 +with_free = 0 for batch_start in range(pu_id_start, max_pu_id, step): - users = ( - Session.query(User) - .join(PartnerUser, PartnerUser.user_id == User.id) - .filter( - PartnerUser.id >= batch_start, - PartnerUser.id < batch_start + step, + query = Session.query(User).join(PartnerUser, PartnerUser.user_id == User.id) + if user_id: + query = query.filter(User.id == user_id) + else: + query = query.filter( + PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step + ) + if only_lifetime: + query = query.filter( User.lifetime == True, # noqa :E712 ) - ).all() + users = query.all() for user in users: # Just in case the == True cond is wonky - if not user.lifetime: - continue - with_lifetime += 1 - event = UserPlanChanged(plan_end_time=arrow.get("2038-01-01").timestamp) + if user.lifetime: + event = UserPlanChanged(lifetime=True) + with_lifetime += 1 + else: + plan_end = user.get_active_subscription_end( + include_partner_subscription=False + ) + if plan_end: + event = UserPlanChanged(plan_end_time=plan_end.timestamp) + with_plan += 1 + else: + event = UserPlanChanged() + with_free += 1 EventDispatcher.send_event(user, EventContent(user_plan_change=event)) Session.flush() Session.commit() @@ -57,6 +92,6 @@ for batch_start in range(pu_id_start, max_pu_id, step): time_remaining = remaining / time_per_alias hours_remaining = time_remaining / 60.0 print( - f"\PartnerUser {batch_start}/{max_pu_id} {with_lifetime} {hours_remaining:.2f} mins remaining" + f"artnerUser {batch_start}/{max_pu_id} lifetime {with_lifetime} paid {with_plan} free {with_free} {hours_remaining:.2f} mins remaining" ) -print(f"With SL lifetime {with_lifetime}") +print(f"Sent lifetime {with_lifetime} paid {with_plan} free {with_free}") diff --git a/app/tests/test_coupon_utils.py b/app/tests/test_coupon_utils.py index 743f4ef..12fc2d2 100644 --- a/app/tests/test_coupon_utils.py +++ b/app/tests/test_coupon_utils.py @@ -66,6 +66,31 @@ def test_use_coupon_extend_manual_sub(): assert left.days > 364 +def test_use_coupon_extend_expired_manual_sub(): + user = create_new_user() + initial_end = arrow.now().shift(days=-15) + ManualSubscription.create( + user_id=user.id, + end_at=initial_end, + flush=True, + ) + code = random_string(10) + Coupon.create(code=code, nb_year=1, commit=True) + + coupon = redeem_coupon(code, user) + assert coupon + + coupon = Coupon.get_by(code=code) + assert coupon + assert coupon.used + assert coupon.used_by_user_id == user.id + + sub = user.get_active_subscription() + assert isinstance(sub, ManualSubscription) + left = sub.end_at - initial_end + assert left.days > 364 + + def test_coupon_with_subscription(): user = create_new_user() end_at = arrow.utcnow().shift(days=1).replace(hour=0, minute=0, second=0)