Coverage for dibbler / queries / user_balance.py: 82%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-12 13:57 +0000

1from dataclasses import dataclass 

2from datetime import datetime 

3from typing import Tuple 

4 

5from sqlalchemy import ( 

6 CTE, 

7 BindParameter, 

8 Float, 

9 Integer, 

10 Select, 

11 and_, 

12 bindparam, 

13 case, 

14 cast, 

15 column, 

16 func, 

17 or_, 

18 select, 

19) 

20from sqlalchemy.orm import Session, aliased 

21from sqlalchemy.sql.elements import KeyedColumnElement 

22 

23from dibbler.models import ( 

24 Transaction, 

25 TransactionType, 

26 User, 

27) 

28from dibbler.models.Transaction import ( 

29 DEFAULT_INTEREST_RATE_PERCENT, 

30 DEFAULT_PENALTY_MULTIPLIER_PERCENT, 

31 DEFAULT_PENALTY_THRESHOLD, 

32) 

33from dibbler.queries.product_price import _product_price_query 

34from dibbler.queries.query_helpers import ( 

35 CONST_NONE, 

36 CONST_ONE, 

37 CONST_ZERO, 

38 const, 

39 until_filter, 

40) 

41 

42 

def _joint_transaction_query(
    user_id: BindParameter[int] | int,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> Select[Tuple[int, int, int]]:
    """
    The inner query for getting joint transactions relevant to a user.

    This scans for JOINT_BUY_PRODUCT transactions made by the user (restricted by
    `until_filter`), then finds the corresponding JOINT transactions, and counts how
    many "shares" of each joint transaction the user has, as well as the total
    number of shares.

    Args:
        user_id: The user to collect joint transactions for.
        use_cache: Currently unused; kept for signature parity with the other
            query builders in this module.
        until_time: Forwarded to `until_filter` for the user's JOINT_BUY_PRODUCT rows.
        until_transaction: If given, its id is forwarded to `until_filter`.
        until_inclusive: Forwarded to `until_filter`.

    Returns:
        A select yielding one `(id, user_shares, user_count)` row per JOINT
        transaction. `id` is the JOINT transaction's id. `user_shares` is the
        number of transactions referencing it (via `joint_transaction_id`) that
        belong to the user. `user_count` is the total number of transactions
        referencing it.

    Raises:
        ValueError: If `until_transaction` has not been persisted (has no id).
    """

    if isinstance(until_transaction, Transaction):
        if until_transaction.id is None:
            raise ValueError("until_transaction must be persisted in the database.")
        until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
    else:
        until_transaction_id = None

    # First, select the distinct JOINT transaction ids referenced by the user's
    # JOINT_BUY_PRODUCT transactions.
    sub_joint_transaction = (
        select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id"))
        .where(
            Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
            Transaction.user_id == user_id,
            until_filter(
                until_time=until_time,
                until_transaction_id=until_transaction_id,
                until_inclusive=until_inclusive,
                transaction_time=Transaction.time,
            ),
        )
        .subquery("sub_joint_transaction")
    )

    # Then join each id with its main JOINT transaction (plain `Transaction`), and
    # with every transaction referencing it (`count_trx`) to count the shares.
    joint_transaction_count = aliased(Transaction, name="count_trx")

    joint_transaction = (
        select(
            Transaction.id,
            # Shares the user has in the transaction.
            func.sum(
                case(
                    (joint_transaction_count.user_id == user_id, CONST_ONE),
                    else_=CONST_ZERO,
                )
            ).label("user_shares"),
            # The total number of shares in the transaction.
            func.count(joint_transaction_count.id).label("user_count"),
        )
        .select_from(sub_joint_transaction)
        .join(
            Transaction,
            onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id,
        )
        .join(
            joint_transaction_count,
            onclause=joint_transaction_count.joint_transaction_id == Transaction.id,
        )
        # Group by the selected column itself. Grouping by the joined (and equal)
        # count_trx.joint_transaction_id is rejected by stricter databases such as
        # PostgreSQL, which require non-aggregated selected columns in GROUP BY.
        .group_by(Transaction.id)
    )

    return joint_transaction

114 

115 

def _non_joint_transaction_query(
    user_id: BindParameter[int] | int,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> Select[Tuple[int, None, None]]:
    """
    The inner query for getting non-joint transactions relevant to a user.

    A transaction is selected when any of the following holds:
      * it was made by the user and is an ADD_PRODUCT, ADJUST_BALANCE,
        BUY_PRODUCT or TRANSFER transaction;
      * it is a TRANSFER whose `transfer_user_id` is the user;
      * it is a THROW_PRODUCT, ADJUST_INTEREST or ADJUST_PENALTY transaction,
        regardless of who made it.
    Every row is additionally restricted by `until_filter`.

    The `user_shares` and `user_count` columns are always NULL. They exist so the
    rows line up with `_joint_transaction_query` in a UNION ALL.
    `use_cache` is currently unused.

    Raises:
        ValueError: If `until_transaction` has not been persisted (has no id).
    """

    until_transaction_id = None
    if isinstance(until_transaction, Transaction):
        if until_transaction.id is None:
            raise ValueError("until_transaction must be persisted in the database.")
        until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)

    # Types that only matter when the user made the transaction.
    own_transaction_types = [
        TransactionType.ADD_PRODUCT.as_literal_column(),
        TransactionType.ADJUST_BALANCE.as_literal_column(),
        TransactionType.BUY_PRODUCT.as_literal_column(),
        TransactionType.TRANSFER.as_literal_column(),
    ]
    # Types that are relevant no matter who made the transaction.
    global_transaction_types = [
        TransactionType.THROW_PRODUCT.as_literal_column(),
        TransactionType.ADJUST_INTEREST.as_literal_column(),
        TransactionType.ADJUST_PENALTY.as_literal_column(),
    ]

    made_by_user = and_(
        Transaction.user_id == user_id,
        Transaction.type_.in_(own_transaction_types),
    )
    transferred_to_user = and_(
        Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
        Transaction.transfer_user_id == user_id,
    )
    affects_everyone = Transaction.type_.in_(global_transaction_types)

    time_window = until_filter(
        until_time=until_time,
        until_transaction_id=until_transaction_id,
        until_inclusive=until_inclusive,
    )

    return select(
        Transaction.id,
        CONST_NONE.label("user_shares"),
        CONST_NONE.label("user_count"),
    ).where(
        or_(made_by_user, transferred_to_user, affects_everyone),
        time_window,
    )

171 

172 

def _product_cost_expression(
    product_count_column: KeyedColumnElement[int],
    product_id_column: KeyedColumnElement[int],
    interest_rate_percent_column: KeyedColumnElement[int],
    user_balance_column: KeyedColumnElement[int],
    penalty_threshold_column: KeyedColumnElement[int],
    penalty_multiplier_percent_column: KeyedColumnElement[int],
    joint_user_shares_column: KeyedColumnElement[int],
    joint_user_count_column: KeyedColumnElement[int],
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
    cte_name: str = "product_price_cte",
    trx_subset_name: str = "product_price_trx_subset",
):
    """
    Build a scalar subquery computing what a product purchase costs the user.

    `price` is taken from the `price` column of the row with the highest `i`
    produced by `_product_price_query` for `product_id_column`. The cost is:

        share    = price * product_count * user_shares / user_count
        interest = share * (interest_rate_percent - 100) / 100
        penalty  = share * (penalty_multiplier_percent - 100) / 100
                   * (1 if user_balance < penalty_threshold else 0)
        cost     = CEIL(share + interest + penalty), cast to an integer

    For a non-joint purchase, callers pass the constant one for both share columns.
    The `use_cache`, `until_*`, `cte_name` and `trx_subset_name` arguments are
    forwarded unchanged to `_product_price_query`.
    """
    # TODO: This can get quite expensive real quick, so we should do some caching of the
    # product prices somehow.

    def share_cost():
        # The user's part of the base price. A fresh expression is built for every
        # term, just as if it were spelled out in each place.
        numerator = column("price") * product_count_column * joint_user_shares_column
        return cast(numerator, Float) / joint_user_count_column

    def surcharge(percent_column):
        # share * (percent - 100) / 100, evaluated strictly left to right.
        return share_cost() * cast(percent_column - const(100), Float) / const(100.0)

    below_threshold_flag = cast(user_balance_column < penalty_threshold_column, Integer)

    total_cost = (
        share_cost()
        + surcharge(interest_rate_percent_column)
        + surcharge(penalty_multiplier_percent_column) * below_threshold_flag
    )

    price_rows = _product_price_query(
        product_id_column,
        use_cache=use_cache,
        until_time=until_time,
        until_transaction=until_transaction,
        until_inclusive=until_inclusive,
        cte_name=cte_name,
        trx_subset_name=trx_subset_name,
    )

    # Only the price from the last row (highest `i`) is used.
    return (
        select(cast(func.ceil(total_cost), Integer))
        .select_from(price_rows)
        .order_by(column("i").desc())
        .limit(CONST_ONE)
        .scalar_subquery()
    )

247 

248 

def _user_balance_query(
    user_id: BindParameter[int] | int,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
    cte_name: str = "rec_cte",
    trx_subset_name: str = "trx_subset",
) -> CTE:
    """
    The inner query for calculating the user's balance.

    Builds a recursive CTE that steps through every transaction relevant to the
    user: the union of `_non_joint_transaction_query` and
    `_joint_transaction_query`, numbered by time with the id as a tiebreak. Each
    step carries the running balance and the interest/penalty settings in effect
    after that transaction:

    - ADJUST_BALANCE and ADD_PRODUCT add `amount` to the balance.
    - BUY_PRODUCT subtracts the product cost (see `_product_cost_expression`).
    - JOINT subtracts the user's share of the product cost.
    - TRANSFER adds `amount` when `transfer_user_id` is the user, and subtracts it
      when `transfer_user_id` is someone else.
    - ADJUST_INTEREST replaces the interest rate. ADJUST_PENALTY replaces both
      the penalty threshold and the penalty multiplier.
    - Anything else (currently including THROW_PRODUCT) leaves the balance as is.

    The seed row has `i = 0`, no transaction, balance 0 and the default interest
    and penalty settings.

    Args:
        user_id: The user whose balance is computed. A plain int is wrapped in a
            bind parameter named "user_id".
        use_cache: Forwarded to the sub-queries. Caching is not implemented yet.
        until_time: Forwarded to the sub-queries to limit the transactions considered.
        until_transaction: Forwarded to the sub-queries to limit the transactions considered.
        until_inclusive: Forwarded to the sub-queries.
        cte_name: Name of the recursive CTE. Also used as a prefix for the product
            price CTEs.
        trx_subset_name: Name of the numbered transaction subquery. Also used as a
            prefix for nested subqueries.

    Returns:
        The recursive CTE with columns `i`, `time`, `transaction_id`, `balance`,
        `interest_rate_percent`, `penalty_threshold` and `penalty_multiplier_percent`.
    """
    import logging

    if use_cache:
        # Go through logging instead of print(). This is library code and use_cache
        # defaults to True, so a print() wrote to stdout on every call.
        logging.getLogger(__name__).warning(
            "Using cache for user balance query is not implemented yet."
        )

    if isinstance(user_id, int):
        user_id = BindParameter("user_id", value=user_id)

    initial_element = select(
        CONST_ZERO.label("i"),
        CONST_ZERO.label("time"),
        CONST_NONE.label("transaction_id"),
        CONST_ZERO.label("balance"),
        const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"),
        const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
        const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"),
    )

    recursive_cte = initial_element.cte(name=cte_name, recursive=True)

    trx_subset_subset = (
        _non_joint_transaction_query(
            user_id=user_id,
            use_cache=use_cache,
            until_time=until_time,
            until_transaction=until_transaction,
            until_inclusive=until_inclusive,
        )
        .union_all(
            _joint_transaction_query(
                user_id=user_id,
                use_cache=use_cache,
                until_time=until_time,
                until_transaction=until_transaction,
                until_inclusive=until_inclusive,
            )
        )
        .subquery(f"{trx_subset_name}_subset")
    )

    # Transactions with the same timestamp must still get a deterministic step
    # number. Otherwise the order in which they are applied (and thus, e.g.,
    # whether a penalty kicks in) is up to the database.
    step_order = [Transaction.time.asc(), Transaction.id.asc()]

    # Subset of transactions that we'll want to iterate over.
    trx_subset = (
        select(
            func.row_number().over(order_by=step_order).label("i"),
            Transaction.id,
            Transaction.amount,
            Transaction.interest_rate_percent,
            Transaction.penalty_multiplier_percent,
            Transaction.penalty_threshold,
            Transaction.product_count,
            Transaction.product_id,
            Transaction.time,
            Transaction.transfer_user_id,
            Transaction.type_,
            trx_subset_subset.c.user_shares,
            trx_subset_subset.c.user_count,
        )
        .select_from(trx_subset_subset)
        .join(
            Transaction,
            onclause=Transaction.id == trx_subset_subset.c.id,
        )
        .order_by(*step_order)
        .subquery(trx_subset_name)
    )

    # NOTE(review): the product price below is looked up with the query-wide
    # until_* filters, not with each purchase's own time. This appears to charge
    # the price at the end of the window rather than at purchase time. Confirm
    # against `_product_price_query`.
    recursive_elements = (
        select(
            trx_subset.c.i,
            trx_subset.c.time,
            trx_subset.c.id.label("transaction_id"),
            case(
                # Adjusts balance -> balance gets adjusted
                (
                    trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(),
                    recursive_cte.c.balance + trx_subset.c.amount,
                ),
                # Adds a product -> balance increases
                (
                    trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
                    recursive_cte.c.balance + trx_subset.c.amount,
                ),
                # Buys a product -> balance decreases
                (
                    trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
                    recursive_cte.c.balance
                    - _product_cost_expression(
                        product_count_column=trx_subset.c.product_count,
                        product_id_column=trx_subset.c.product_id,
                        interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
                        user_balance_column=recursive_cte.c.balance,
                        penalty_threshold_column=recursive_cte.c.penalty_threshold,
                        penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
                        joint_user_shares_column=CONST_ONE,
                        joint_user_count_column=CONST_ONE,
                        use_cache=use_cache,
                        until_time=until_time,
                        until_transaction=until_transaction,
                        until_inclusive=until_inclusive,
                        cte_name=f"{cte_name}_price",
                        trx_subset_name=f"{trx_subset_name}_price",
                    ).label("product_cost"),
                ),
                # Joint transaction -> balance decreases proportionally
                (
                    trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
                    recursive_cte.c.balance
                    - _product_cost_expression(
                        product_count_column=trx_subset.c.product_count,
                        product_id_column=trx_subset.c.product_id,
                        interest_rate_percent_column=recursive_cte.c.interest_rate_percent,
                        user_balance_column=recursive_cte.c.balance,
                        penalty_threshold_column=recursive_cte.c.penalty_threshold,
                        penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent,
                        joint_user_shares_column=trx_subset.c.user_shares,
                        joint_user_count_column=trx_subset.c.user_count,
                        use_cache=use_cache,
                        until_time=until_time,
                        until_transaction=until_transaction,
                        until_inclusive=until_inclusive,
                        cte_name=f"{cte_name}_joint_price",
                        trx_subset_name=f"{trx_subset_name}_joint_price",
                    ).label("joint_product_cost"),
                ),
                # Transfers money to self -> balance increases
                (
                    and_(
                        trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
                        trx_subset.c.transfer_user_id == user_id,
                    ),
                    recursive_cte.c.balance + trx_subset.c.amount,
                ),
                # Transfers money from self -> balance decreases
                (
                    and_(
                        trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
                        trx_subset.c.transfer_user_id != user_id,
                    ),
                    recursive_cte.c.balance - trx_subset.c.amount,
                ),
                # TODO: Throws a product -> if the user is considered to have bought
                # it, balance increases.
                # Interest adjustment -> balance stays the same
                # Penalty adjustment -> balance stays the same
                else_=recursive_cte.c.balance,
            ).label("balance"),
            case(
                (
                    trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(),
                    trx_subset.c.interest_rate_percent,
                ),
                else_=recursive_cte.c.interest_rate_percent,
            ).label("interest_rate_percent"),
            case(
                (
                    trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
                    trx_subset.c.penalty_threshold,
                ),
                else_=recursive_cte.c.penalty_threshold,
            ).label("penalty_threshold"),
            case(
                (
                    trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
                    trx_subset.c.penalty_multiplier_percent,
                ),
                else_=recursive_cte.c.penalty_multiplier_percent,
            ).label("penalty_multiplier_percent"),
        )
        .select_from(trx_subset)
        # Each recursion step consumes exactly the next numbered transaction.
        .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
    )

    return recursive_cte.union_all(recursive_elements)

436 

437 

438# TODO: create a function for the log that pretty prints the log entries 

439# for debugging purposes 

440 

441 

@dataclass
class UserBalanceLogEntry:
    """
    One entry of a user's balance history, as produced by `user_balance_log`.

    Pairs a transaction with the user's balance and the interest/penalty
    settings as they stood right after that transaction was applied.
    """

    # The transaction this entry describes.
    transaction: Transaction
    # The user's balance after applying `transaction`.
    balance: int
    # Interest rate (in percent) in effect after `transaction`.
    interest_rate_percent: int
    # Balance below which purchases are penalized, in effect after `transaction`.
    penalty_threshold: int
    # Penalty multiplier (in percent) in effect after `transaction`.
    penalty_multiplier_percent: int

    def is_penalized(self) -> bool:
        """
        Tell whether this particular transaction was penalized.

        Not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("is_penalized is not implemented yet.")

456 

457 

def user_balance_log(
    sql_session: Session,
    user: User,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | datetime | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> list[UserBalanceLogEntry]:
    """
    Returns a log of the user's balance over time, including interest and penalty adjustments.

    Each entry pairs a transaction relevant to the user with the balance, interest
    rate and penalty settings right after that transaction was applied. Entries are
    in the order in which the balance was computed.

    Args:
        sql_session: Session used to run the query.
        user: The user whose balance history is computed. Must be persisted.
        use_cache: Forwarded to the balance query.
        until_time: If given, only transactions up to this time are considered.
            A plain `datetime` is wrapped in a bind parameter.
        until_transaction: If given, only transactions up to this one are
            considered. Mutually exclusive with `until_time`.
        until_inclusive: Forwarded to the balance query. Presumably controls
            whether the `until_*` boundary itself is included.

    Returns:
        The log entries. The list is empty if the user has no relevant transactions.

    Raises:
        ValueError: If the user is not persisted, or if both `until_time` and
            `until_transaction` are given.
    """

    if user.id is None:
        raise ValueError("User must be persisted in the database.")

    if not (until_time is None or until_transaction is None):
        raise ValueError("Cannot filter by both until_time and until_transaction.")

    if isinstance(until_time, datetime):
        until_time = BindParameter("until_time", value=until_time)

    recursive_cte = _user_balance_query(
        user.id,
        use_cache=use_cache,
        until_time=until_time,
        until_transaction=until_transaction,
        until_inclusive=until_inclusive,
    )

    # The seed row of the recursive CTE has no transaction (transaction_id is NULL),
    # so the inner join drops it and only real transactions remain.
    # `.all()` always returns a list (possibly empty) and never None, so no
    # further result validation is needed.
    rows = sql_session.execute(
        select(
            Transaction,
            recursive_cte.c.balance,
            recursive_cte.c.interest_rate_percent,
            recursive_cte.c.penalty_threshold,
            recursive_cte.c.penalty_multiplier_percent,
        )
        .select_from(recursive_cte)
        .join(
            Transaction,
            onclause=Transaction.id == recursive_cte.c.transaction_id,
        )
        .order_by(recursive_cte.c.i.asc())
    ).all()

    return [
        UserBalanceLogEntry(
            transaction=row[0],
            balance=row.balance,
            interest_rate_percent=row.interest_rate_percent,
            penalty_threshold=row.penalty_threshold,
            penalty_multiplier_percent=row.penalty_multiplier_percent,
        )
        for row in rows
    ]

521 

522 

def user_balance(
    sql_session: Session,
    user: User,
    use_cache: bool = True,
    until_time: BindParameter[datetime] | datetime | None = None,
    until_transaction: Transaction | None = None,
    until_inclusive: bool = True,
) -> int:
    """
    Calculates the balance of a user.

    The balance is read from the last step (highest `i`) of the recursive
    balance query.

    Args:
        sql_session: Session used to run the query.
        user: The user whose balance is computed. Must be persisted.
        use_cache: Forwarded to the balance query.
        until_time: If given, only transactions up to this time are considered.
            A plain `datetime` is wrapped in a bind parameter.
        until_transaction: If given, only transactions up to this one are
            considered. Mutually exclusive with `until_time`.
        until_inclusive: Forwarded to the balance query.

    Returns:
        The user's balance.

    Raises:
        ValueError: If the user is not persisted, or if both `until_time` and
            `until_transaction` are given.
        RuntimeError: If the balance query unexpectedly yields no row.
    """
    if user.id is None:
        raise ValueError("User must be persisted in the database.")
    if until_time is not None and until_transaction is not None:
        raise ValueError("Cannot filter by both until_time and until_transaction.")

    time_limit = (
        BindParameter("until_time", value=until_time)
        if isinstance(until_time, datetime)
        else until_time
    )

    balance_steps = _user_balance_query(
        user.id,
        use_cache=use_cache,
        until_time=time_limit,
        until_transaction=until_transaction,
        until_inclusive=until_inclusive,
    )

    # Only the final step of the recursion holds the resulting balance.
    latest_balance_stmt = (
        select(balance_steps.c.balance)
        .order_by(balance_steps.c.i.desc())
        .limit(CONST_ONE)
        .offset(CONST_ZERO)
    )
    latest_balance = sql_session.scalar(latest_balance_stmt)

    # The recursive query always contains its seed row (balance 0), so a missing
    # value signals a broken query rather than "no transactions".
    if latest_balance is None:
        raise RuntimeError(
            f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
        )

    return latest_balance