173 lines
6.5 KiB
Python
173 lines
6.5 KiB
Python
import logging
|
|
import time
|
|
from typing import Iterator, List
|
|
|
|
import t_tech
|
|
from t_tech.invest import (
|
|
CandleInstrument,
|
|
InvestError,
|
|
MarketDataRequest,
|
|
MarketDataResponse,
|
|
SubscribeCandlesRequest,
|
|
SubscriptionAction,
|
|
SubscriptionInterval,
|
|
)
|
|
from t_tech.invest.services import Services
|
|
from t_tech.invest.strategies.base.account_manager import AccountManager
|
|
from t_tech.invest.strategies.base.errors import MarketDataNotAvailableError
|
|
from t_tech.invest.strategies.base.event import DataEvent, SignalEvent
|
|
from t_tech.invest.strategies.base.models import CandleEvent
|
|
from t_tech.invest.strategies.base.signal import CloseSignal, OpenSignal, Signal
|
|
from t_tech.invest.strategies.base.signal_executor_base import SignalExecutor
|
|
from t_tech.invest.strategies.base.trader_base import Trader
|
|
from t_tech.invest.strategies.moving_average.strategy import MovingAverageStrategy
|
|
from t_tech.invest.strategies.moving_average.strategy_settings import (
|
|
MovingAverageStrategySettings,
|
|
)
|
|
from t_tech.invest.strategies.moving_average.strategy_state import (
|
|
MovingAverageStrategyState,
|
|
)
|
|
from t_tech.invest.strategies.moving_average.supervisor import (
|
|
MovingAverageStrategySupervisor,
|
|
)
|
|
from t_tech.invest.utils import floor_datetime, now
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MovingAverageStrategyTrader(Trader):
    """Trader that feeds a moving-average strategy from a candle stream.

    On construction the trader loads historical candles, reports each of them
    to the supervisor, fits the strategy on them, asks the account manager to
    ensure margin trading and subscribes to one-minute candles.  Every call to
    :meth:`trade` then reads the stream until a fresh candle arrives and
    executes the signals predicted by the strategy.
    """

    def __init__(
        self,
        strategy: MovingAverageStrategy,
        settings: MovingAverageStrategySettings,
        services: Services,
        state: MovingAverageStrategyState,
        signal_executor: SignalExecutor,
        account_manager: AccountManager,
        supervisor: MovingAverageStrategySupervisor,
    ):
        """Wire up collaborators, warm up the strategy and open the stream.

        Args:
            strategy: Strategy that is fitted on history and later queried
                for signals.
            settings: Strategy settings (periods, instrument id, candle
                interval).
            services: API services; used here for the market data stream.
            state: Strategy state; its ``position`` is read in :meth:`trade`.
            signal_executor: Object that executes predicted signals.
            account_manager: Used to ensure margin trading and to read the
                current balance.
            supervisor: Receives data events and signal events.
        """
        super().__init__(strategy, services, settings)
        self._settings: MovingAverageStrategySettings = settings
        self._strategy = strategy
        self._services = services
        # Assigned in _subscribe(); declared here for type checkers only.
        self._market_data_stream: Iterator[MarketDataResponse]
        self._state = state
        self._signal_executor = signal_executor
        self._account_manager = account_manager
        self._supervisor = supervisor

        # Warm-up history: three times the sum of both averaging periods.
        # NOTE(review): _load_candles is not defined in this class;
        # presumably inherited from Trader.
        history_size = (self._settings.short_period + self._settings.long_period) * 3
        self._data: List[CandleEvent] = list(self._load_candles(history_size))
        for candle_event in self._data:
            self._supervisor.notify(self._convert_to_data_event(candle_event))
        self._strategy.fit(self._data)

        self._ensure_marginal_trade_active()
        self._subscribe()

    def _ensure_marginal_trade_active(self) -> None:
        """Delegate to the account manager to ensure margin trading."""
        self._account_manager.ensure_marginal_trade()

    def _subscribe(self) -> None:
        """Open the market data stream with a one-minute candle subscription.

        The resulting response iterator is stored in
        ``self._market_data_stream``.
        """
        current_instrument = CandleInstrument(
            figi=self._settings.share_id,
            interval=SubscriptionInterval.SUBSCRIPTION_INTERVAL_ONE_MINUTE,
        )
        candle_subscribe_request = MarketDataRequest(
            subscribe_candles_request=SubscribeCandlesRequest(
                subscription_action=SubscriptionAction.SUBSCRIPTION_ACTION_SUBSCRIBE,
                instruments=[current_instrument],
            )
        )

        def request_iterator() -> Iterator[MarketDataRequest]:
            """Yield the subscribe request once, then idle forever."""
            yield candle_subscribe_request
            # The generator never finishes, so the request side of the stream
            # is never exhausted.
            # NOTE(review): presumably this keeps the subscription alive;
            # confirm against the streaming API.
            while True:
                time.sleep(1)

        self._market_data_stream = self._services.market_data_stream.market_data_stream(
            request_iterator()
        )

    def _is_candle_fresh(self, candle: t_tech.invest.Candle) -> bool:
        """Return True if the candle is not older than the current interval.

        The border is the current time floored to the configured candle
        interval; a candle whose time is at or after the border is fresh.
        """
        is_fresh_border = floor_datetime(
            now(), delta=self._settings.candle_interval_timedelta
        )
        # Evaluate once so the logged verdict is exactly the returned one.
        is_fresh = candle.time >= is_fresh_border
        logger.debug(
            "Checking if candle is fresh: candle.time=%s >= is_fresh_border=%s (%s)",
            candle.time,
            is_fresh_border,
            is_fresh,
        )
        return is_fresh

    @staticmethod
    def _convert_to_data_event(candle_event: CandleEvent) -> DataEvent:
        """Wrap a candle event into a DataEvent stamped with the candle time."""
        return DataEvent(candle_event=candle_event, time=candle_event.time)

    def _make_observations(self) -> None:
        """Feed streamed candles to the strategy until a fresh one arrives.

        Every candle read from the stream is observed by the strategy and
        reported to the supervisor; responses without a candle are skipped.

        Raises:
            StopIteration: If the stream ends before a fresh candle is seen.
        """
        while True:
            # next() instead of a for-loop on purpose: an exhausted stream
            # must surface as StopIteration, which _refresh_data() translates.
            market_data_response: MarketDataResponse = next(self._market_data_stream)
            logger.debug("got market_data_response: %s", market_data_response)
            if market_data_response.candle is None:
                logger.debug("market_data_response didn't have candle")
                continue
            candle = market_data_response.candle
            logger.debug("candle extracted: %s", candle)
            # NOTE(review): _convert_candle is not defined in this class;
            # presumably inherited from Trader.
            candle_event = self._convert_candle(candle)
            self._strategy.observe(candle_event)
            self._supervisor.notify(self._convert_to_data_event(candle_event))
            if self._is_candle_fresh(candle):
                logger.info("Data refreshed")
                break

    def _refresh_data(self) -> None:
        """Bring the strategy up to date with the stream.

        Raises:
            MarketDataNotAvailableError: If the stream ended before a fresh
                candle was received.
        """
        logger.info("Refreshing data")
        try:
            self._make_observations()
        except StopIteration as e:
            logger.info("Fresh quotations not available")
            raise MarketDataNotAvailableError() from e

    def _filter_closing_signals(self, signals: List[Signal]) -> List[Signal]:
        """Return only the CloseSignal instances, in their original order."""
        return [signal for signal in signals if isinstance(signal, CloseSignal)]

    def _filter_opening_signals(self, signals: List[Signal]) -> List[Signal]:
        """Return only the OpenSignal instances, in their original order."""
        return [signal for signal in signals if isinstance(signal, OpenSignal)]

    def _execute(self, signal: Signal) -> None:
        """Execute a signal and report the outcome to the supervisor.

        An InvestError raised by the executor is logged and recorded as a
        failed execution; it is not propagated to the caller.
        """
        logger.info("Trying to execute signal %s", signal)
        try:
            self._signal_executor.execute(signal)  # type: ignore
        except InvestError:
            # Keep the failure reason in the log instead of dropping it.
            logger.exception("Failed to execute signal %s", signal)
            was_executed = False
        else:
            was_executed = True
        self._supervisor.notify(
            SignalEvent(signal=signal, was_executed=was_executed, time=now())
        )

    def _get_signals(self) -> List[Signal]:
        """Return predicted signals, closing signals before opening ones.

        Signals that are neither CloseSignal nor OpenSignal are dropped.
        """
        signals = list(self._strategy.predict())
        return [
            *self._filter_closing_signals(signals),
            *self._filter_opening_signals(signals),
        ]

    def trade(self) -> None:
        """Make one attempt to follow the strategy.

        Logs the current balance, refreshes market data, then executes every
        signal the strategy predicts.

        Raises:
            MarketDataNotAvailableError: If fresh market data could not be
                obtained.
        """
        logger.info("Balance: %s", self._account_manager.get_current_balance())
        self._refresh_data()

        signals = self._get_signals()
        if signals:
            logger.info("Got signals %s", signals)
        for signal in signals:
            self._execute(signal)
        if self._state.position == 0:
            logger.info("Trade try complete")
|