"""Test doubles for ``t_tech.invest`` ``Services`` used when testing strategies.

``MockedClient`` replays real historical candles under a frozen clock and
simulates order fills locally; ``MockedSandboxClient`` routes orders and
portfolio queries to a freshly opened sandbox account.
"""
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta
from decimal import Decimal
from functools import cached_property
from typing import Any, Callable, Dict, Generator, List, Optional

from grpc import Channel
from pytest_freezegun import freeze_time

from t_tech.invest import (
    Candle,
    GetCandlesResponse,
    GetMarginAttributesResponse,
    HistoricCandle,
    MarketDataResponse,
    MoneyValue,
    OrderDirection,
    OrderType,
    PortfolioPosition,
    PortfolioResponse,
    PostOrderResponse,
    Quotation,
)
from t_tech.invest.channels import create_channel
from t_tech.invest.services import Services
from t_tech.invest.strategies.base.strategy_settings_base import StrategySettings
from t_tech.invest.typedefs import AccountId, ChannelArgumentType
from t_tech.invest.utils import (
    candle_interval_to_subscription_interval,
    decimal_to_quotation,
    now,
    quotation_to_decimal,
)

logger = logging.getLogger(__name__)


def _stub_margin_attributes_response() -> GetMarginAttributesResponse:
    """Build the fixed margin-attributes response returned by both mocks.

    All money fields are zero; only ``funds_sufficiency_level`` is set (322).
    """
    return GetMarginAttributesResponse(
        liquid_portfolio=MoneyValue(currency="", units=0, nano=0),
        starting_margin=MoneyValue(currency="", units=0, nano=0),
        minimal_margin=MoneyValue(currency="", units=0, nano=0),
        funds_sufficiency_level=Quotation(units=322, nano=0),
        amount_of_missing_funds=MoneyValue(currency="", units=0, nano=0),
    )


@contextmanager
def MockedClient(
    token: str,
    *,
    settings: StrategySettings,
    real_market_data_test_from: datetime,
    real_market_data_test_start: datetime,
    real_market_data_test_end: datetime,
    balance: MoneyValue,
    options: Optional[ChannelArgumentType] = None,
) -> Generator[Services, None, None]:
    """Yield a ``MockedServices`` that replays real candles under frozen time.

    A gRPC channel is opened (it is used once, to fetch the candles between
    ``real_market_data_test_from`` and ``real_market_data_test_end``) and the
    clock is frozen at ``real_market_data_test_start`` for the whole ``with``
    block.

    Args:
        token: API token passed to ``Services``.
        settings: Strategy settings; ``share_id`` and ``candle_interval`` are read.
        real_market_data_test_from: Start of the fetched candle range.
        real_market_data_test_start: Split point between "initial" candles
            (served by ``get_candles``) and replayed candles (streamed).
        real_market_data_test_end: End of the fetched candle range.
        balance: Starting cash balance; mutated in place by simulated orders.
        options: Optional gRPC channel arguments.
    """
    with create_channel(options=options) as channel:
        with freeze_time(real_market_data_test_start) as frozen_datetime:
            yield MockedServices(
                channel=channel,
                token=token,
                settings=settings,
                frozen_datetime=frozen_datetime,
                real_market_data_test_from=real_market_data_test_from,
                real_market_data_test_start=real_market_data_test_start,
                real_market_data_test_end=real_market_data_test_end,
                balance=balance,
            )


@contextmanager
def MockedSandboxClient(
    token: str,
    *,
    balance: MoneyValue,
    options: Optional[ChannelArgumentType] = None,
) -> Generator[Services, None, None]:
    """Yield a ``MockedSandboxServices`` backed by a fresh sandbox account.

    The sandbox account is opened and funded with ``balance`` on entry and is
    closed on exit, whether the ``with`` body finishes normally or raises.

    Args:
        token: API token passed to ``Services``.
        balance: Amount paid into the new sandbox account.
        options: Optional gRPC channel arguments.
    """
    with create_channel(options=options) as channel:
        services = MockedSandboxServices(
            channel=channel,
            token=token,
            balance=balance,
        )
        try:
            yield services
        finally:
            # Close explicitly while the channel is still open. Relying on
            # ``__del__`` is unsafe: it runs only once the last reference is
            # gone (the caller's ``as`` variable keeps one alive), by which
            # time the channel has already been closed.
            services._close_account()  # pylint: disable=protected-access


class MockedServices(Services):
    """``Services`` whose trading and market-data calls are served in memory.

    Real candles are fetched once at construction through
    ``self.get_all_candles``. Afterwards these attributes are replaced by
    local fakes:

    * ``market_data.get_candles`` returns the candles before the test start;
    * ``market_data_stream.market_data_stream`` replays the remaining candles
      and advances the frozen clock;
    * ``orders.post_order`` fills instantly at the last streamed close price,
      updating positions and the cash balance;
    * ``operations.get_portfolio`` reports those positions and that balance;
    * ``users.get_margin_attributes`` returns a fixed stub.
    """

    def __init__(
        self,
        channel: Channel,
        token: str,
        settings: StrategySettings,
        frozen_datetime,
        real_market_data_test_from: datetime,
        real_market_data_test_start: datetime,
        real_market_data_test_end: datetime,
        balance: MoneyValue,
    ):
        """Fetch the real candles and install the in-memory fakes.

        Args:
            channel: Open gRPC channel passed to ``Services``.
            token: API token passed to ``Services``.
            settings: Strategy settings; ``share_id`` is the traded FIGI and
                ``candle_interval`` the candle interval.
            frozen_datetime: Handle yielded by ``freeze_time``; its
                ``move_to`` is used to advance the clock.
            real_market_data_test_from: Start of the fetched candle range.
            real_market_data_test_start: Split point between initial and
                streamed candles.
            real_market_data_test_end: End of the fetched candle range.
            balance: Cash balance; mutated in place by simulated orders.
        """
        super().__init__(channel, token)
        self._settings = settings
        self._figi = settings.share_id
        # Candles already delivered by the fake stream; the last one prices fills.
        self._current_market_data: List[Candle] = []
        self._portfolio_positions: Dict[str, PortfolioPosition] = {}
        self._real_market_data_test_from = real_market_data_test_from
        self._real_market_data_test_start = real_market_data_test_start
        self._real_market_data_test_end = real_market_data_test_end
        self._balance = balance
        self._frozen_datetime = frozen_datetime

        # Fetch (and cache) the real candles BEFORE replacing get_candles
        # below; presumably get_all_candles goes through
        # market_data.get_candles, and the fake depends on this data.
        _ = self._real_market_data

        self.market_data.get_candles = self._mocked_market_data_get_candles()
        self.orders.post_order = self._mocked_orders_post_order()
        self.operations.get_portfolio = self._mocked_operations_get_portfolio()
        self.market_data_stream.market_data_stream = self._mocked_market_data_stream()
        self.users.get_margin_attributes = self._mocked_users_get_margin_attributes()

    def _mocked_orders_post_order(self) -> Callable[..., Any]:
        """Return a fake ``post_order`` that fills at the last streamed close."""

        def _post_order(  # pylint: disable=too-many-locals
            *,
            figi: str = "",
            quantity: int = 0,
            price: Optional[Quotation] = None,  # pylint: disable=unused-argument
            direction: OrderDirection = OrderDirection(0),
            account_id: str = "",  # pylint: disable=unused-argument
            order_type: OrderType = OrderType(0),  # pylint: disable=unused-argument
            order_id: str = "",  # pylint: disable=unused-argument
        ):
            """Fill ``quantity`` of ``figi`` instantly and update the books.

            Raises:
                IndexError: if no candle has been streamed yet, so there is
                    no market price to fill at.
                AssertionError: if ``direction`` is neither BUY nor SELL.
            """
            last_candle = self._current_market_data[-1]
            last_market_price = quotation_to_decimal(last_candle.close)

            position = self._portfolio_positions.get(figi)
            if position is None:
                position = PortfolioPosition(
                    figi=figi,
                    quantity=decimal_to_quotation(Decimal(0)),
                )

            if direction == OrderDirection.ORDER_DIRECTION_SELL:
                quantity_delta = -quantity
                balance_delta = last_market_price * quantity
            elif direction == OrderDirection.ORDER_DIRECTION_BUY:
                quantity_delta = +quantity
                balance_delta = -(last_market_price * quantity)
            else:
                raise AssertionError("Incorrect direction")

            logger.warning("Operation: %s, %s", direction, balance_delta)

            # Update the position's quantity in place so previously returned
            # PortfolioPosition objects observe the change.
            old_quantity = quotation_to_decimal(position.quantity)
            new_quantity = decimal_to_quotation(old_quantity + quantity_delta)
            position.quantity.units = new_quantity.units
            position.quantity.nano = new_quantity.nano

            # Update the shared balance object in place for the same reason.
            old_balance = quotation_to_decimal(
                Quotation(units=self._balance.units, nano=self._balance.nano)
            )
            new_balance = decimal_to_quotation(old_balance + balance_delta)
            self._balance.units = new_balance.units
            self._balance.nano = new_balance.nano

            self._portfolio_positions[figi] = position

        return _post_order

    @property
    def _portfolio_response(self) -> PortfolioResponse:
        """Build a portfolio snapshot from the current positions and balance.

        Rebuilt on every access: caching it would freeze the ``positions``
        list and hide positions opened by orders placed after the first call.
        Share/expected-yield totals are fixed placeholder values.
        """
        return PortfolioResponse(
            total_amount_shares=MoneyValue(currency="rub", units=28691, nano=300000000),
            total_amount_bonds=MoneyValue(currency="rub", units=0, nano=0),
            total_amount_etf=MoneyValue(currency="rub", units=0, nano=0),
            total_amount_currencies=self._balance,
            total_amount_futures=MoneyValue(currency="rub", units=0, nano=0),
            expected_yield=Quotation(units=0, nano=-350000000),
            positions=list(self._portfolio_positions.values()),
        )

    def _mocked_operations_get_portfolio(self) -> Callable[..., PortfolioResponse]:
        """Return a fake ``get_portfolio`` that ignores its arguments."""

        def _get_portfolio(*args, **kwargs):  # pylint: disable=unused-argument
            return self._portfolio_response

        return _get_portfolio

    def _mocked_market_data_stream(self) -> Callable[..., Any]:
        """Return a fake ``market_data_stream`` replaying post-start candles.

        Also moves the frozen clock to ``real_market_data_test_start``
        immediately (this factory is invoked from ``__init__``).
        """
        self._frozen_datetime.move_to(self._real_market_data_test_start)

        def _market_data_stream(*args, **kwargs):  # pylint: disable=unused-argument
            # The first message deliberately carries no candle.
            yield MarketDataResponse(candle=None)  # type: ignore
            interval = candle_interval_to_subscription_interval(
                self._settings.candle_interval
            )
            for historic_candle in self._after_start_candles:
                candle = Candle(
                    figi=self._figi,
                    interval=interval,
                    open=historic_candle.open,
                    high=historic_candle.high,
                    low=historic_candle.low,
                    close=historic_candle.close,
                    volume=historic_candle.volume,
                    time=historic_candle.time,
                )
                # Record the candle so post_order can price fills at its close.
                self._current_market_data.append(candle)
                yield MarketDataResponse(candle=candle)
                # Advance the frozen clock by one minute per replayed candle.
                self._frozen_datetime.move_to(now() + timedelta(minutes=1))

        return _market_data_stream

    @cached_property
    def _real_market_data(self) -> List[HistoricCandle]:
        """All real candles in ``[test_from, test_end]``, fetched once."""
        return list(
            self.get_all_candles(
                figi=self._figi,
                from_=self._real_market_data_test_from,
                to=self._real_market_data_test_end,
                interval=self._settings.candle_interval,
            )
        )

    @cached_property
    def _initial_candles(self) -> List[HistoricCandle]:
        """Candles strictly before the test start (served by ``get_candles``)."""
        return [
            candle
            for candle in self._real_market_data
            if candle.time < self._real_market_data_test_start
        ]

    @cached_property
    def _after_start_candles(self) -> List[HistoricCandle]:
        """Candles at or after the test start (replayed by the stream)."""
        return [
            candle
            for candle in self._real_market_data
            if candle.time >= self._real_market_data_test_start
        ]

    def _mocked_market_data_get_candles(self) -> Callable[..., GetCandlesResponse]:
        """Return a fake ``get_candles`` that always serves the initial candles."""

        def _get_candles(*args, **kwargs):  # pylint: disable=unused-argument
            return GetCandlesResponse(candles=self._initial_candles)

        return _get_candles

    def _mocked_users_get_margin_attributes(
        self,
    ) -> Callable[..., GetMarginAttributesResponse]:
        """Return a fake ``get_margin_attributes`` returning a fixed stub."""

        def _get_margin_attributes(*args, **kwargs):  # pylint: disable=unused-argument
            return _stub_margin_attributes_response()

        return _get_margin_attributes


class MockedSandboxServices(Services):
    """``Services`` that redirects orders and portfolio to a sandbox account.

    A sandbox account is opened and funded on construction.
    ``orders.post_order`` and ``operations.get_portfolio`` are forwarded to the
    sandbox service for that account, whatever ``account_id`` the caller
    passes. ``users.get_margin_attributes`` returns a fixed stub.
    """

    def __init__(
        self,
        channel: Channel,
        token: str,
        balance: MoneyValue,
    ):
        """Install the fakes, then open a sandbox account and pay in ``balance``.

        If funding fails, the freshly opened account is closed before the
        error propagates.
        """
        super().__init__(channel, token)
        # Flipped to False once an account has actually been opened.
        self._account_closed = True
        self.orders.post_order = self._mocked_orders_post_order()
        self.operations.get_portfolio = self._mocked_operations_get_portfolio()
        self.users.get_margin_attributes = self._mocked_users_get_margin_attributes()
        self._account_id = self._open_account()
        try:
            self._pay_in(balance)
        except Exception:
            self._close_account()
            raise

    def _pay_in(self, amount: MoneyValue):
        """Deposit ``amount`` into this mock's sandbox account."""
        logger.info("Pay in %s for %s", amount, self._account_id)
        self.sandbox.sandbox_pay_in(account_id=self._account_id, amount=amount)

    def _open_account(self) -> AccountId:
        """Open a new sandbox account, remember it and return its id."""
        response = self.sandbox.open_sandbox_account()
        self._account_id = response.account_id
        self._account_closed = False
        logger.info("Opened sandbox account %s", self._account_id)
        return self._account_id

    def __del__(self):
        # Safety net for instances not managed by MockedSandboxClient. The
        # getattr default covers a failed __init__ before the flag was set;
        # closing is skipped when the account is already closed.
        if not getattr(self, "_account_closed", True):
            self._close_account()

    def _close_account(self) -> None:
        """Close the sandbox account; a no-op if it is already closed."""
        if self._account_closed:
            return
        logger.info("Closing sandbox account %s", self._account_id)
        self.sandbox.close_sandbox_account(account_id=self._account_id)
        self._account_closed = True

    def _mocked_orders_post_order(
        self,
    ) -> Callable[..., PostOrderResponse]:
        """Return a ``post_order`` that places sandbox orders on our account."""

        def _post_order(
            *,
            figi: str = "",
            quantity: int = 0,
            price: Optional[Quotation] = None,
            direction: OrderDirection = OrderDirection(0),
            account_id: str = "",  # pylint: disable=unused-argument
            order_type: OrderType = OrderType(0),
            order_id: str = "",
        ) -> PostOrderResponse:
            # ``account_id`` is accepted for signature compatibility with the
            # real post_order, but the mock's own sandbox account is used.
            return self.sandbox.post_sandbox_order(
                figi=figi,
                quantity=quantity,
                price=price,
                direction=direction,
                account_id=self._account_id,
                order_type=order_type,
                order_id=order_id,
            )

        return _post_order

    def _mocked_operations_get_portfolio(self) -> Callable[..., PortfolioResponse]:
        """Return a ``get_portfolio`` that reads our sandbox portfolio."""

        def _get_sandbox_portfolio(
            *, account_id: str = ""  # pylint: disable=unused-argument
        ) -> PortfolioResponse:
            return self.sandbox.get_sandbox_portfolio(account_id=self._account_id)

        return _get_sandbox_portfolio

    def _mocked_users_get_margin_attributes(
        self,
    ) -> Callable[..., GetMarginAttributesResponse]:
        """Return a fake ``get_margin_attributes`` returning a fixed stub."""

        def _get_margin_attributes(*args, **kwargs):  # pylint: disable=unused-argument
            return _stub_margin_attributes_response()

        return _get_margin_attributes