306 lines
10 KiB
Python
306 lines
10 KiB
Python
import logging
|
|
from datetime import datetime, timedelta
|
|
from decimal import Decimal
|
|
from typing import Callable, Iterable, List
|
|
|
|
import numpy as np
|
|
|
|
from t_tech.invest.strategies.base.account_manager import AccountManager
|
|
from t_tech.invest.strategies.base.errors import (
|
|
CandleEventForDateNotFound,
|
|
NotEnoughData,
|
|
OldCandleObservingError,
|
|
)
|
|
from t_tech.invest.strategies.base.models import CandleEvent
|
|
from t_tech.invest.strategies.base.signal import (
|
|
CloseLongMarketOrder,
|
|
CloseShortMarketOrder,
|
|
OpenLongMarketOrder,
|
|
OpenShortMarketOrder,
|
|
Signal,
|
|
)
|
|
from t_tech.invest.strategies.base.strategy_interface import InvestStrategy
|
|
from t_tech.invest.strategies.moving_average.strategy_settings import (
|
|
MovingAverageStrategySettings,
|
|
)
|
|
from t_tech.invest.strategies.moving_average.strategy_state import (
|
|
MovingAverageStrategyState,
|
|
)
|
|
from t_tech.invest.utils import (
|
|
candle_interval_to_timedelta,
|
|
ceil_datetime,
|
|
floor_datetime,
|
|
now,
|
|
)
|
|
|
|
# Module-level logger, named after this module so handlers can filter on it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class MovingAverageStrategy(InvestStrategy):
    """Moving-average strategy that turns observed candles into order signals.

    Candle events are accumulated through ``fit``/``observe``. At most one
    event per candle interval is kept; a later event for the same interval
    replaces the stored one. ``predict`` compares a short and a long moving
    average of close prices, a reference close price, the price's relative
    deviation from the long average and a standard deviation of close
    prices. It then yields open/close market-order signals for long and
    short positions, depending on the current strategy state.
    """

    def __init__(
        self,
        settings: MovingAverageStrategySettings,
        account_manager: AccountManager,
        state: MovingAverageStrategyState,
    ):
        # Observed candle events, oldest first, at most one per interval.
        self._data: List[CandleEvent] = []
        self._settings = settings
        self._account_manager = account_manager

        self._state = state
        # Declared only; assigned by _init_MA_LONG_START() on each predict().
        self._MA_LONG_START: Decimal
        self._candle_interval_timedelta = candle_interval_to_timedelta(
            self._settings.candle_interval
        )

    def _ensure_enough_candles(self) -> None:
        """Check that enough candles are stored for the configured periods.

        The requirement is ``(short_period + long_period) / candle interval``
        candles, rounded up.

        Raises:
            NotEnoughData: fewer candles are stored than required.
        """
        import math

        # Use the interval cached in __init__ so every method agrees on it.
        candles_needed = math.ceil(
            (self._settings.short_period + self._settings.long_period)
            / self._candle_interval_timedelta
        )
        if candles_needed > len(self._data):
            raise NotEnoughData(
                f"Got {len(self._data)} candles but needed {candles_needed}"
            )
        logger.info("Got enough data for strategy")

    def fit(self, candles: Iterable[CandleEvent]) -> None:
        """Observe historical candles, then verify there are enough of them.

        Raises:
            OldCandleObservingError: a candle is older than the interval of
                the newest stored candle.
            NotEnoughData: too few candles for the configured periods.
        """
        logger.debug("Strategy fitting with candles %s", candles)
        for candle in candles:
            self.observe(candle)
        self._ensure_enough_candles()

    def _append_candle_event(self, candle_event: CandleEvent) -> None:
        """Store ``candle_event`` relative to the newest stored event.

        An event in the same candle interval as the newest stored event
        replaces it. An event in a later interval is appended.

        Raises:
            OldCandleObservingError: the event is older than the start of the
                newest stored event's interval.
        """
        interval = self._candle_interval_timedelta
        last_interval_floor = floor_datetime(self._data[-1].time, interval)

        if candle_event.time < last_interval_floor:
            raise OldCandleObservingError()

        # Compare interval starts instead of testing `time < ceil(last.time)`.
        # When the last event lies exactly on a boundary, a ceiling presumably
        # equals the floor. A later event inside that same interval would then
        # be appended as a new candle instead of replacing the stored one.
        if floor_datetime(candle_event.time, interval) == last_interval_floor:
            self._data[-1] = candle_event
        else:
            self._data.append(candle_event)

    def observe(self, candle: CandleEvent) -> None:
        """Record one candle event; the very first event is stored as-is."""
        logger.debug("Observing candle event: %s", candle)

        if not self._data:
            self._data.append(candle)
            return
        self._append_candle_event(candle)

    @staticmethod
    def _get_newer_than_datetime_predicate(
        anchor: datetime,
    ) -> Callable[[CandleEvent], bool]:
        """Build a predicate that is True for events strictly after ``anchor``."""

        def _(event: CandleEvent) -> bool:
            return event.time > anchor

        return _

    def _filter_from_the_end_with_early_stop(
        self, predicate: Callable[[CandleEvent], bool]
    ) -> Iterable[CandleEvent]:
        """Yield stored events newest-first while ``predicate`` holds.

        Iteration stops at the first event that fails the predicate. Older
        events are never inspected.
        """
        for event in reversed(self._data):
            if not predicate(event):
                break
            yield event

    def _select_for_period(self, period: timedelta) -> Iterable[CandleEvent]:
        """Lazily select stored events newer than ``now() - period``.

        NOTE(review): the window is anchored at ``now()``, not at the newest
        candle's time. Replaying historical candles presumably requires
        ``now()`` to be patched; confirm.
        """
        predicate = self._get_newer_than_datetime_predicate(now() - period)
        return self._filter_from_the_end_with_early_stop(predicate)

    @staticmethod
    def _get_prices(events: Iterable[CandleEvent]) -> Iterable[Decimal]:
        """Yield the close price of each event."""
        for event in events:
            yield event.candle.close

    def _get_prices_for_period(self, period: timedelta) -> List[Decimal]:
        """Return close prices (newest first) of events within ``period``.

        Raises:
            NotEnoughData: no stored event falls into the period. Without this
                check numpy would return ``nan``, which later fails obscurely
                when compared or combined with Decimal prices.
        """
        prices = list(self._get_prices(self._select_for_period(period)))
        if not prices:
            raise NotEnoughData(f"No candles found for the last {period}")
        return prices

    def _calculate_moving_average(self, period: timedelta) -> Decimal:
        """Return the mean close price over ``period``."""
        prices = self._get_prices_for_period(period)
        logger.debug("Selected prices: %s", prices)
        return np.mean(prices, axis=0)  # type: ignore

    def _calculate_std(self, period: timedelta) -> Decimal:
        """Return the (population) standard deviation of close prices."""
        prices = self._get_prices_for_period(period)
        return np.std(prices, axis=0)  # type: ignore

    def _get_first_candle_before(self, date: datetime) -> CandleEvent:
        """Return the newest stored event whose time is not after ``date``.

        Raises:
            CandleEventForDateNotFound: every stored event is after ``date``.
        """
        predicate = self._get_newer_than_datetime_predicate(date)
        for event in reversed(self._data):
            if not predicate(event):
                return event
        raise CandleEventForDateNotFound()

    def _init_MA_LONG_START(self):
        """Set ``_MA_LONG_START`` to a reference close price.

        The value is the close price of the newest candle at or before
        ``now() - short_period``. Despite the name, it is a single close
        price, not an average.
        """
        date = now() - self._settings.short_period
        event = self._get_first_candle_before(date)
        self._MA_LONG_START = event.candle.close

    @staticmethod
    def _is_long_open_signal(
        MA_SHORT: Decimal,
        MA_LONG: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        MA_LONG_START: Decimal,
    ) -> bool:
        """Return True when the conditions to open a long position hold.

        The conditions are ``MA_SHORT > MA_LONG > MA_LONG_START`` and a
        relative deviation ``|PRICE - MA_LONG| / MA_LONG`` below ``STD``.
        """
        logger.debug("Try long opening")
        short_above_long = MA_SHORT > MA_LONG
        price_near_long = abs((PRICE - MA_LONG) / MA_LONG) < STD
        long_above_start = MA_LONG > MA_LONG_START
        result = short_above_long and long_above_start and price_near_long

        logger.debug("\tMA_SHORT > MA_LONG, %s", short_above_long)
        logger.debug(
            "\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
            price_near_long,
        )
        logger.debug("\tand MA_LONG > MA_LONG_START, %s", long_above_start)
        logger.debug("== %s", result)
        return result

    @staticmethod
    def _is_short_open_signal(
        MA_SHORT: Decimal,
        MA_LONG: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        MA_LONG_START: Decimal,
    ) -> bool:
        """Return True when the conditions to open a short position hold.

        The conditions are ``MA_SHORT < MA_LONG < MA_LONG_START`` and a
        relative deviation ``|PRICE - MA_LONG| / MA_LONG`` below ``STD``.
        """
        logger.debug("Try short opening")
        short_below_long = MA_SHORT < MA_LONG
        price_near_long = abs((PRICE - MA_LONG) / MA_LONG) < STD
        long_below_start = MA_LONG < MA_LONG_START
        result = short_below_long and long_below_start and price_near_long

        logger.debug("\tMA_SHORT < MA_LONG, %s", short_below_long)
        logger.debug(
            "\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
            price_near_long,
        )
        logger.debug("\tand MA_LONG < MA_LONG_START, %s", long_below_start)
        logger.debug("== %s", result)
        return result

    @staticmethod
    def _is_long_close_signal(
        MA_LONG: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        has_short_open_signal: bool,
    ) -> bool:
        """Return True when an open long position should be closed.

        Any one of these is enough:

        - ``PRICE > MA_LONG + 10 * STD``;
        - a short-open signal was raised this round;
        - ``PRICE < MA_LONG - 3 * STD``.
        """
        logger.debug("Try long closing")
        price_far_above = PRICE > MA_LONG + 10 * STD
        price_far_below = PRICE < MA_LONG - 3 * STD
        result = price_far_above or has_short_open_signal or price_far_below

        logger.debug("\tPRICE > MA_LONG + 10 * STD, %s", price_far_above)
        logger.debug("\tor has_short_open_signal, %s", has_short_open_signal)
        logger.debug("\tor PRICE < MA_LONG - 3 * STD, %s", price_far_below)
        logger.debug("== %s", result)
        return result

    @staticmethod
    def _is_short_close_signal(
        MA_LONG: Decimal,
        PRICE: Decimal,
        STD: Decimal,
        has_long_open_signal: bool,
    ) -> bool:
        """Return True when an open short position should be closed.

        Any one of these is enough:

        - ``PRICE < MA_LONG - 10 * STD``;
        - a long-open signal was raised this round;
        - ``PRICE > MA_LONG + 3 * STD``.
        """
        logger.debug("Try short closing")
        # NOTE(review): the original author left the remark "it seems that
        # closing does not work" on this condition; unverified.
        price_far_below = PRICE < MA_LONG - 10 * STD
        price_far_above = PRICE > MA_LONG + 3 * STD
        result = price_far_below or has_long_open_signal or price_far_above

        logger.debug("\tPRICE < MA_LONG - 10 * STD, %s", price_far_below)
        logger.debug("\tor has_long_open_signal, %s", has_long_open_signal)
        logger.debug("\tor PRICE > MA_LONG + 3 * STD, %s", price_far_above)
        logger.debug("== %s", result)
        return result

    def predict(self) -> Iterable[Signal]:
        """Yield order signals computed from the stored candles.

        The reference price ``_MA_LONG_START`` is refreshed first. Open
        signals are checked before close signals, because an open signal in
        the opposite direction also triggers a close.

        Raises:
            CandleEventForDateNotFound: no candle at or before
                ``now() - short_period``.
            NotEnoughData: a moving-average or std window contains no candles.
        """
        logger.info("Strategy predict")
        self._init_MA_LONG_START()
        MA_LONG_START = self._MA_LONG_START
        logger.debug("MA_LONG_START: %s", MA_LONG_START)
        PRICE = self._data[-1].candle.close
        logger.debug("PRICE: %s", PRICE)
        MA_LONG = self._calculate_moving_average(self._settings.long_period)
        logger.debug("MA_LONG: %s", MA_LONG)
        MA_SHORT = self._calculate_moving_average(self._settings.short_period)
        logger.debug("MA_SHORT: %s", MA_SHORT)
        STD = self._calculate_std(self._settings.std_period)
        logger.debug("STD: %s", STD)
        MONEY = self._account_manager.get_current_balance()
        logger.debug("MONEY: %s", MONEY)

        has_long_open_signal = False
        has_short_open_signal = False

        # NOTE(review): sizes orders as balance // price, i.e. one unit of
        # PRICE per lot; confirm this matches the instrument's lot size.
        possible_lots = int(MONEY // PRICE)

        if (
            not self._state.long_open
            and self._is_long_open_signal(
                MA_SHORT=MA_SHORT,
                MA_LONG=MA_LONG,
                PRICE=PRICE,
                STD=STD,
                MA_LONG_START=MA_LONG_START,
            )
            and possible_lots > 0
        ):
            has_long_open_signal = True
            yield OpenLongMarketOrder(lots=possible_lots)

        if (
            not self._state.short_open
            and self._is_short_open_signal(
                MA_SHORT=MA_SHORT,
                MA_LONG=MA_LONG,
                PRICE=PRICE,
                STD=STD,
                MA_LONG_START=MA_LONG_START,
            )
            and possible_lots > 0
        ):
            has_short_open_signal = True
            yield OpenShortMarketOrder(lots=possible_lots)

        if self._state.long_open and self._is_long_close_signal(
            MA_LONG=MA_LONG,
            PRICE=PRICE,
            STD=STD,
            has_short_open_signal=has_short_open_signal,
        ):
            yield CloseLongMarketOrder(lots=self._state.position)

        if self._state.short_open and self._is_short_close_signal(
            MA_LONG=MA_LONG,
            PRICE=PRICE,
            STD=STD,
            has_long_open_signal=has_long_open_signal,
        ):
            yield CloseShortMarketOrder(lots=self._state.position)
|