Coverage for dibbler / queries / user_balance.py: 82%
75 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 13:57 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 13:57 +0000
1from dataclasses import dataclass
2from datetime import datetime
3from typing import Tuple
5from sqlalchemy import (
6 CTE,
7 BindParameter,
8 Float,
9 Integer,
10 Select,
11 and_,
12 bindparam,
13 case,
14 cast,
15 column,
16 func,
17 or_,
18 select,
19)
20from sqlalchemy.orm import Session, aliased
21from sqlalchemy.sql.elements import KeyedColumnElement
23from dibbler.models import (
24 Transaction,
25 TransactionType,
26 User,
27)
28from dibbler.models.Transaction import (
29 DEFAULT_INTEREST_RATE_PERCENT,
30 DEFAULT_PENALTY_MULTIPLIER_PERCENT,
31 DEFAULT_PENALTY_THRESHOLD,
32)
33from dibbler.queries.product_price import _product_price_query
34from dibbler.queries.query_helpers import (
35 CONST_NONE,
36 CONST_ONE,
37 CONST_ZERO,
38 const,
39 until_filter,
40)
def _joint_transaction_query(
    user_id: BindParameter[int] | int,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> Select[Tuple[int, int, int]]:
    """
    The inner query for getting joint transactions relevant to a user.

    This scans for JOINT_BUY_PRODUCT transactions made by the user,
    then finds the corresponding JOINT transactions, and counts how many "shares"
    of the joint transaction the user has, as well as the total number of shares.

    Each result row is ``(id, user_shares, user_count)`` where ``id`` is the id of
    the main joint transaction, ``user_shares`` is the number of linked rows whose
    ``user_id`` equals ``user_id``, and ``user_count`` is the number of linked rows.

    ``use_cache`` is accepted for signature parity with the sibling query builders;
    it is not read by this function.

    Raises:
        ValueError: if ``until_transaction`` is a ``Transaction`` without an id.
    """

    until_transaction_id = None
    if isinstance(until_transaction, Transaction):
        if until_transaction.id is None:
            raise ValueError("until_transaction must be persisted in the database.")
        until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)

    # First, select all joint buy product transactions for the given user
    # and keep the distinct ids of the joint transactions they point at.
    sub_joint_transaction = (
        select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id"))
        .where(
            Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
            Transaction.user_id == user_id,
            until_filter(
                until_time=until_time,
                until_transaction_id=until_transaction_id,
                until_inclusive=until_inclusive,
                transaction_time=Transaction.time,
            ),
        )
        .subquery("sub_joint_transaction")
    )

    # Join those with their main joint transaction (plain ``Transaction``),
    # then join every row linked to that joint transaction in order to count shares.
    joint_transaction_count = aliased(Transaction, name="count_trx")

    joint_transaction = (
        select(
            Transaction.id,
            # Shares the user has in the transaction,
            func.sum(
                case(
                    (joint_transaction_count.user_id == user_id, CONST_ONE),
                    else_=CONST_ZERO,
                )
            ).label("user_shares"),
            # The total number of shares in the transaction,
            func.count(joint_transaction_count.id).label("user_count"),
        )
        .select_from(sub_joint_transaction)
        .join(
            Transaction,
            onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id,
        )
        .join(
            joint_transaction_count,
            onclause=joint_transaction_count.joint_transaction_id == Transaction.id,
        )
        # Group by the selected column itself. The join condition makes this equal to
        # count_trx.joint_transaction_id, but backends that enforce strict GROUP BY
        # rules reject a non-aggregated select column that is missing from GROUP BY.
        .group_by(Transaction.id)
    )

    return joint_transaction
def _non_joint_transaction_query(
    user_id: BindParameter[int] | int,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> Select[Tuple[int, None, None]]:
    """
    Build the select of non-joint transactions that matter for a user's balance.

    Rows are ``(id, user_shares, user_count)``; the two share columns are always
    the NULL constant so the result lines up with the joint-transaction query.

    A transaction is included when any of these hold:
      * it was made by the user and is an ADD_PRODUCT, ADJUST_BALANCE,
        BUY_PRODUCT or TRANSFER,
      * it is a TRANSFER whose ``transfer_user_id`` is the user,
      * it is a THROW_PRODUCT, ADJUST_INTEREST or ADJUST_PENALTY (any user).

    ``use_cache`` is not read by this function.

    Raises:
        ValueError: if ``until_transaction`` is a ``Transaction`` without an id.
    """

    until_transaction_id = None
    if isinstance(until_transaction, Transaction):
        if until_transaction.id is None:
            raise ValueError("until_transaction must be persisted in the database.")
        until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)

    # Transaction types that only count when the user made them.
    own_types = [
        TransactionType.ADD_PRODUCT.as_literal_column(),
        TransactionType.ADJUST_BALANCE.as_literal_column(),
        TransactionType.BUY_PRODUCT.as_literal_column(),
        TransactionType.TRANSFER.as_literal_column(),
    ]
    # Transaction types that are included regardless of who made them.
    shared_types = [
        TransactionType.THROW_PRODUCT.as_literal_column(),
        TransactionType.ADJUST_INTEREST.as_literal_column(),
        TransactionType.ADJUST_PENALTY.as_literal_column(),
    ]

    made_by_user = and_(
        Transaction.user_id == user_id,
        Transaction.type_.in_(own_types),
    )
    transfer_to_user = and_(
        Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
        Transaction.transfer_user_id == user_id,
    )
    affects_everyone = Transaction.type_.in_(shared_types)

    time_limit = until_filter(
        until_time=until_time,
        until_transaction_id=until_transaction_id,
        until_inclusive=until_inclusive,
    )

    return select(
        Transaction.id,
        CONST_NONE.label("user_shares"),
        CONST_NONE.label("user_count"),
    ).where(
        or_(made_by_user, transfer_to_user, affects_everyone),
        time_limit,
    )
def _product_cost_expression(
    product_count_column: KeyedColumnElement[int],
    product_id_column: KeyedColumnElement[int],
    interest_rate_percent_column: KeyedColumnElement[int],
    user_balance_column: KeyedColumnElement[int],
    penalty_threshold_column: KeyedColumnElement[int],
    penalty_multiplier_percent_column: KeyedColumnElement[int],
    joint_user_shares_column: KeyedColumnElement[int],
    joint_user_count_column: KeyedColumnElement[int],
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
    cte_name: str = "product_price_cte",
    trx_subset_name: str = "product_price_trx_subset",
):
    """
    Build a scalar subquery giving what a purchase costs the user.

    The cost is the sum of three terms, rounded up with ``ceil`` and cast to an
    integer:

      * share:    ``price * product_count * user_shares / user_count``
      * interest: ``share * (interest_rate_percent - 100) / 100``
      * penalty:  ``share * (penalty_multiplier_percent - 100) / 100``, counted only
        when ``user_balance < penalty_threshold``

    ``price`` is read from the row of ``_product_price_query`` with the highest
    ``i``; the until/cache/name arguments are forwarded to that query unchanged.
    """
    # TODO: This can get quite expensive real quick, so we should do some caching of the
    #       product prices somehow.

    def user_share_of_price():
        # Build a fresh expression on every call so each term of the sum
        # keeps its own copy of the sub-expression.
        full_price = column("price") * product_count_column * joint_user_shares_column
        return cast(full_price, Float) / joint_user_count_column

    base_term = user_share_of_price()

    interest_term = (
        user_share_of_price()
        * cast(interest_rate_percent_column - const(100), Float)
        / const(100.0)
    )

    # The boolean comparison is cast to an integer so it switches the penalty on or off.
    penalty_term = (
        user_share_of_price()
        * cast(penalty_multiplier_percent_column - const(100), Float)
        / const(100.0)
        * cast(user_balance_column < penalty_threshold_column, Integer)
    )

    total_cost = cast(func.ceil(base_term + interest_term + penalty_term), Integer)

    price_source = _product_price_query(
        product_id_column,
        use_cache=use_cache,
        until_time=until_time,
        until_transaction=until_transaction,
        until_inclusive=until_inclusive,
        cte_name=cte_name,
        trx_subset_name=trx_subset_name,
    )

    latest_cost = (
        select(total_cost)
        .select_from(price_source)
        .order_by(column("i").desc())
        .limit(CONST_ONE)
    )

    return latest_cost.scalar_subquery()
def _user_balance_query(
    user_id: BindParameter[int] | int,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
    cte_name: str = "rec_cte",
    trx_subset_name: str = "trx_subset",
) -> CTE:
    """
    Build the recursive CTE that replays a user's transactions into a balance.

    The CTE starts from a seed row (``i = 0``, balance 0, default interest and
    penalty settings) and adds one row per relevant transaction in ascending
    time order. Each row carries ``i``, ``time``, ``transaction_id``, ``balance``,
    ``interest_rate_percent``, ``penalty_threshold`` and
    ``penalty_multiplier_percent`` as they stand after that transaction.

    A plain ``int`` user id is wrapped in a bind parameter named ``user_id``.
    Caching is not implemented; a warning is printed whenever ``use_cache`` is true.
    """

    if use_cache:
        print("WARNING: Using cache for user balance query is not implemented yet.")

    if isinstance(user_id, int):
        user_id = BindParameter("user_id", value=user_id)

    # Arguments every helper query below receives unchanged.
    shared_kwargs = {
        "use_cache": use_cache,
        "until_time": until_time,
        "until_transaction": until_transaction,
        "until_inclusive": until_inclusive,
    }

    # Seed row of the recursion: nothing has happened yet.
    seed_row = select(
        CONST_ZERO.label("i"),
        CONST_ZERO.label("time"),
        CONST_NONE.label("transaction_id"),
        CONST_ZERO.label("balance"),
        const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"),
        const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
        const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"),
    )

    recursive_cte = seed_row.cte(name=cte_name, recursive=True)
    previous = recursive_cte.c

    # Ids (plus share info) of all transactions relevant to the user.
    relevant_ids = (
        _non_joint_transaction_query(user_id=user_id, **shared_kwargs)
        .union_all(_joint_transaction_query(user_id=user_id, **shared_kwargs))
        .subquery(f"{trx_subset_name}_subset")
    )

    # Subset of transactions that we'll want to iterate over,
    # numbered 1..n by ascending time so the recursion can step through them.
    trx_subset = (
        select(
            func.row_number().over(order_by=Transaction.time.asc()).label("i"),
            Transaction.id,
            Transaction.amount,
            Transaction.interest_rate_percent,
            Transaction.penalty_multiplier_percent,
            Transaction.penalty_threshold,
            Transaction.product_count,
            Transaction.product_id,
            Transaction.time,
            Transaction.transfer_user_id,
            Transaction.type_,
            relevant_ids.c.user_shares,
            relevant_ids.c.user_count,
        )
        .select_from(relevant_ids)
        .join(
            Transaction,
            onclause=Transaction.id == relevant_ids.c.id,
        )
        .order_by(Transaction.time.asc())
        .subquery(trx_subset_name)
    )
    current = trx_subset.c

    def is_type(transaction_type):
        # Predicate: the current transaction has the given type.
        return current.type_ == transaction_type.as_literal_column()

    def purchase_cost(shares_column, count_column, name_suffix, label):
        # Cost of the current purchase, priced with the settings of the previous row.
        return _product_cost_expression(
            product_count_column=current.product_count,
            product_id_column=current.product_id,
            interest_rate_percent_column=previous.interest_rate_percent,
            user_balance_column=previous.balance,
            penalty_threshold_column=previous.penalty_threshold,
            penalty_multiplier_percent_column=previous.penalty_multiplier_percent,
            joint_user_shares_column=shares_column,
            joint_user_count_column=count_column,
            cte_name=f"{cte_name}_{name_suffix}",
            trx_subset_name=f"{trx_subset_name}_{name_suffix}",
            **shared_kwargs,
        ).label(label)

    new_balance = case(
        # Adjusts balance -> balance gets adjusted
        (
            is_type(TransactionType.ADJUST_BALANCE),
            previous.balance + current.amount,
        ),
        # Adds a product -> balance increases
        (
            is_type(TransactionType.ADD_PRODUCT),
            previous.balance + current.amount,
        ),
        # Buys a product -> balance decreases
        (
            is_type(TransactionType.BUY_PRODUCT),
            previous.balance - purchase_cost(CONST_ONE, CONST_ONE, "price", "product_cost"),
        ),
        # Joint transaction -> balance decreases proportionally
        (
            is_type(TransactionType.JOINT),
            previous.balance
            - purchase_cost(
                current.user_shares,
                current.user_count,
                "joint_price",
                "joint_product_cost",
            ),
        ),
        # Transfers money to self -> balance increases
        (
            and_(
                is_type(TransactionType.TRANSFER),
                current.transfer_user_id == user_id,
            ),
            previous.balance + current.amount,
        ),
        # Transfers money from self -> balance decreases
        (
            and_(
                is_type(TransactionType.TRANSFER),
                current.transfer_user_id != user_id,
            ),
            previous.balance - current.amount,
        ),
        # TODO: Throws a product -> if the user is considered to have bought it,
        #       the balance should increase. Not handled yet.
        # Interest adjustment -> balance stays the same
        # Penalty adjustment -> balance stays the same
        else_=previous.balance,
    ).label("balance")

    new_interest_rate = case(
        (
            is_type(TransactionType.ADJUST_INTEREST),
            current.interest_rate_percent,
        ),
        else_=previous.interest_rate_percent,
    ).label("interest_rate_percent")

    new_penalty_threshold = case(
        (
            is_type(TransactionType.ADJUST_PENALTY),
            current.penalty_threshold,
        ),
        else_=previous.penalty_threshold,
    ).label("penalty_threshold")

    new_penalty_multiplier = case(
        (
            is_type(TransactionType.ADJUST_PENALTY),
            current.penalty_multiplier_percent,
        ),
        else_=previous.penalty_multiplier_percent,
    ).label("penalty_multiplier_percent")

    # Recursive step: take the transaction numbered one past the previous row.
    recursive_step = (
        select(
            current.i,
            current.time,
            current.id.label("transaction_id"),
            new_balance,
            new_interest_rate,
            new_penalty_threshold,
            new_penalty_multiplier,
        )
        .select_from(trx_subset)
        .where(current.i == previous.i + CONST_ONE)
    )

    return recursive_cte.union_all(recursive_step)
438# TODO: create a function for the log that pretty prints the log entries
439# for debugging purposes
@dataclass
class UserBalanceLogEntry:
    """
    One row of a user's balance history.

    Holds a transaction together with the balance and the interest/penalty
    settings computed for the row belonging to that transaction.
    """

    # The transaction this entry belongs to.
    transaction: Transaction
    # Balance computed for this transaction's row.
    balance: int
    # Interest rate (in percent) in effect on this row.
    interest_rate_percent: int
    # Balance threshold below which the penalty applies, as of this row.
    penalty_threshold: int
    # Penalty multiplier (in percent) in effect on this row.
    penalty_multiplier_percent: int

    def is_penalized(self) -> bool:
        """
        Tell whether this exact transaction is penalized.

        Not implemented yet; always raises ``NotImplementedError``.
        """

        raise NotImplementedError("is_penalized is not implemented yet.")
def user_balance_log(
    sql_session: Session,
    user: User,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | datetime | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> list[UserBalanceLogEntry]:
    """
    Returns a log of the user's balance over time, including interest and penalty adjustments.

    The log has one entry per transaction that the balance query iterates over,
    ordered by the query's row number. It is empty when no transaction matches.

    At most one of ``until_time`` and ``until_transaction`` may be given; it limits
    which transactions are considered, and ``until_inclusive`` is forwarded to the
    balance query together with it. A plain ``datetime`` passed as ``until_time`` is
    wrapped in a bind parameter named ``until_time``.

    Raises:
        ValueError: if ``user`` has no id, or if both ``until_time`` and
            ``until_transaction`` are given.
    """

    if user.id is None:
        raise ValueError("User must be persisted in the database.")

    if until_time is not None and until_transaction is not None:
        raise ValueError("Cannot filter by both until_time and until_transaction.")

    if isinstance(until_time, datetime):
        until_time = BindParameter("until_time", value=until_time)

    recursive_cte = _user_balance_query(
        user.id,
        use_cache=use_cache,
        until_time=until_time,
        until_transaction=until_transaction,
        until_inclusive=until_inclusive,
    )

    # The inner join on transaction_id drops the CTE's seed row, whose
    # transaction_id is NULL, so only real transactions are returned.
    rows = sql_session.execute(
        select(
            Transaction,
            recursive_cte.c.balance,
            recursive_cte.c.interest_rate_percent,
            recursive_cte.c.penalty_threshold,
            recursive_cte.c.penalty_multiplier_percent,
        )
        .select_from(recursive_cte)
        .join(
            Transaction,
            onclause=Transaction.id == recursive_cte.c.transaction_id,
        )
        .order_by(recursive_cte.c.i.asc())
    ).all()

    # ``Result.all()`` always returns a list (possibly empty), never None,
    # so no "missing result" check is needed here.
    return [
        UserBalanceLogEntry(
            transaction=row[0],
            balance=row.balance,
            interest_rate_percent=row.interest_rate_percent,
            penalty_threshold=row.penalty_threshold,
            penalty_multiplier_percent=row.penalty_multiplier_percent,
        )
        for row in rows
    ]
def user_balance(
    sql_session: Session,
    user: User,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | datetime | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> int:
    """
    Compute a user's balance.

    The balance is read from the last row (highest ``i``) of the recursive
    balance query. At most one of ``until_time`` and ``until_transaction`` may be
    given to limit which transactions are considered. A plain ``datetime`` passed
    as ``until_time`` is wrapped in a bind parameter named ``until_time``.

    Raises:
        ValueError: if ``user`` has no id, or if both ``until_time`` and
            ``until_transaction`` are given.
        RuntimeError: if the query yields no balance at all.
    """

    if user.id is None:
        raise ValueError("User must be persisted in the database.")

    both_limits_given = until_time is not None and until_transaction is not None
    if both_limits_given:
        raise ValueError("Cannot filter by both until_time and until_transaction.")

    if isinstance(until_time, datetime):
        until_time = BindParameter("until_time", value=until_time)

    balance_cte = _user_balance_query(
        user.id,
        use_cache=use_cache,
        until_time=until_time,
        until_transaction=until_transaction,
        until_inclusive=until_inclusive,
    )

    # Pick the balance of the row with the highest row number.
    latest_balance_stmt = (
        select(balance_cte.c.balance)
        .order_by(balance_cte.c.i.desc())
        .limit(CONST_ONE)
        .offset(CONST_ZERO)
    )
    result = sql_session.scalar(latest_balance_stmt)

    if result is None:
        # If there are no transactions for this user, the query should return 0, not None.
        raise RuntimeError(
            f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
        )

    return result