Files

306 lines
10 KiB
Python

import logging
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Callable, Iterable, List
import numpy as np
from t_tech.invest.strategies.base.account_manager import AccountManager
from t_tech.invest.strategies.base.errors import (
CandleEventForDateNotFound,
NotEnoughData,
OldCandleObservingError,
)
from t_tech.invest.strategies.base.models import CandleEvent
from t_tech.invest.strategies.base.signal import (
CloseLongMarketOrder,
CloseShortMarketOrder,
OpenLongMarketOrder,
OpenShortMarketOrder,
Signal,
)
from t_tech.invest.strategies.base.strategy_interface import InvestStrategy
from t_tech.invest.strategies.moving_average.strategy_settings import (
MovingAverageStrategySettings,
)
from t_tech.invest.strategies.moving_average.strategy_state import (
MovingAverageStrategyState,
)
from t_tech.invest.utils import (
candle_interval_to_timedelta,
ceil_datetime,
floor_datetime,
now,
)
logger = logging.getLogger(__name__)
class MovingAverageStrategy(InvestStrategy):
def __init__(
self,
settings: MovingAverageStrategySettings,
account_manager: AccountManager,
state: MovingAverageStrategyState,
):
self._data: List[CandleEvent] = []
self._settings = settings
self._account_manager = account_manager
self._state = state
self._MA_LONG_START: Decimal
self._candle_interval_timedelta = candle_interval_to_timedelta(
self._settings.candle_interval
)
def _ensure_enough_candles(self) -> None:
candles_needed = (
self._settings.short_period + self._settings.long_period
) / self._settings.candle_interval_timedelta
if candles_needed > len(self._data):
raise NotEnoughData(
f"Got {len(self._data)} candles but needed {candles_needed}"
)
logger.info("Got enough data for strategy")
def fit(self, candles: Iterable[CandleEvent]) -> None:
logger.debug("Strategy fitting with candles %s", candles)
for candle in candles:
self.observe(candle)
self._ensure_enough_candles()
def _append_candle_event(self, candle_event: CandleEvent) -> None:
last_candle_event = self._data[-1]
last_interval_floor = floor_datetime(
last_candle_event.time, self._candle_interval_timedelta
)
last_interval_ceil = ceil_datetime(
last_candle_event.time, self._candle_interval_timedelta
)
if candle_event.time < last_interval_floor:
raise OldCandleObservingError()
if (
candle_event.time < last_interval_ceil
or candle_event.time == last_interval_floor
):
self._data[-1] = candle_event
else:
self._data.append(candle_event)
def observe(self, candle: CandleEvent) -> None:
logger.debug("Observing candle event: %s", candle)
if len(self._data) > 0:
self._append_candle_event(candle)
else:
self._data.append(candle)
@staticmethod
def _get_newer_than_datetime_predicate(
anchor: datetime,
) -> Callable[[CandleEvent], bool]:
def _(event: CandleEvent) -> bool:
return event.time > anchor
return _
def _filter_from_the_end_with_early_stop(
self, predicate: Callable[[CandleEvent], bool]
) -> Iterable[CandleEvent]:
for event in reversed(self._data):
if not predicate(event):
break
yield event
def _select_for_period(self, period: timedelta):
predicate = self._get_newer_than_datetime_predicate(now() - period)
return self._filter_from_the_end_with_early_stop(predicate)
@staticmethod
def _get_prices(events: Iterable[CandleEvent]) -> Iterable[Decimal]:
for event in events:
yield event.candle.close
def _calculate_moving_average(self, period: timedelta) -> Decimal:
prices = list(self._get_prices(self._select_for_period(period)))
logger.debug("Selected prices: %s", prices)
return np.mean(prices, axis=0) # type: ignore
def _calculate_std(self, period: timedelta) -> Decimal:
prices = list(self._get_prices(self._select_for_period(period)))
return np.std(prices, axis=0) # type: ignore
def _get_first_candle_before(self, date: datetime) -> CandleEvent:
predicate = self._get_newer_than_datetime_predicate(date)
for event in reversed(self._data):
if not predicate(event):
return event
raise CandleEventForDateNotFound()
def _init_MA_LONG_START(self):
date = now() - self._settings.short_period
event = self._get_first_candle_before(date)
self._MA_LONG_START = event.candle.close
@staticmethod
def _is_long_open_signal(
MA_SHORT: Decimal,
MA_LONG: Decimal,
PRICE: Decimal,
STD: Decimal,
MA_LONG_START: Decimal,
) -> bool:
logger.debug("Try long opening")
logger.debug("\tMA_SHORT > MA_LONG, %s", MA_SHORT > MA_LONG)
logger.debug(
"\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
abs((PRICE - MA_LONG) / MA_LONG) < STD,
)
logger.debug("\tand MA_LONG < MA_LONG_START, %s", MA_LONG > MA_LONG_START)
logger.debug(
"== %s",
MA_SHORT > MA_LONG > MA_LONG_START
and abs((PRICE - MA_LONG) / MA_LONG) < STD,
)
return (
MA_SHORT > MA_LONG > MA_LONG_START
and abs((PRICE - MA_LONG) / MA_LONG) < STD
)
@staticmethod
def _is_short_open_signal(
MA_SHORT: Decimal,
MA_LONG: Decimal,
PRICE: Decimal,
STD: Decimal,
MA_LONG_START: Decimal,
) -> bool:
logger.debug("Try short opening")
logger.debug("\tMA_SHORT < MA_LONG, %s", MA_SHORT < MA_LONG)
logger.debug(
"\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
abs((PRICE - MA_LONG) / MA_LONG) < STD,
)
logger.debug("\tand MA_LONG > MA_LONG_START, %s", MA_LONG < MA_LONG_START)
logger.debug(
"== %s",
MA_SHORT < MA_LONG < MA_LONG_START
and abs((PRICE - MA_LONG) / MA_LONG) < STD,
)
return (
MA_SHORT < MA_LONG < MA_LONG_START
and abs((PRICE - MA_LONG) / MA_LONG) < STD
)
@staticmethod
def _is_long_close_signal(
MA_LONG: Decimal,
PRICE: Decimal,
STD: Decimal,
has_short_open_signal: bool,
) -> bool:
logger.debug("Try long closing")
logger.debug("\tPRICE > MA_LONG + 10 * STD, %s", PRICE > MA_LONG + 10 * STD)
logger.debug("\tor has_short_open_signal, %s", has_short_open_signal)
logger.debug("\tor PRICE < MA_LONG - 3 * STD, %s", PRICE < MA_LONG - 3 * STD)
logger.debug(
"== %s",
PRICE > MA_LONG + 10 * STD
or has_short_open_signal
or PRICE < MA_LONG - 3 * STD,
)
return (
PRICE > MA_LONG + 10 * STD
or has_short_open_signal
or PRICE < MA_LONG - 3 * STD
)
@staticmethod
def _is_short_close_signal(
MA_LONG: Decimal,
PRICE: Decimal,
STD: Decimal,
has_long_open_signal: bool,
) -> bool:
logger.debug("Try short closing")
logger.debug("\tPRICE < MA_LONG - 10 * STD, %s", PRICE < MA_LONG - 10 * STD)
logger.debug("\tor has_long_open_signal, %s", has_long_open_signal)
logger.debug("\tor PRICE > MA_LONG + 3 * STD, %s", PRICE > MA_LONG + 3 * STD)
logger.debug(
"== %s",
PRICE < MA_LONG - 10 * STD # кажется, что не работает закрытие
or has_long_open_signal
or PRICE > MA_LONG + 3 * STD,
)
return (
PRICE < MA_LONG - 10 * STD
or has_long_open_signal
or PRICE > MA_LONG + 3 * STD
)
def predict(self) -> Iterable[Signal]:
logger.info("Strategy predict")
self._init_MA_LONG_START()
MA_LONG_START = self._MA_LONG_START
logger.debug("MA_LONG_START: %s", MA_LONG_START)
PRICE = self._data[-1].candle.close
logger.debug("PRICE: %s", PRICE)
MA_LONG = self._calculate_moving_average(self._settings.long_period)
logger.debug("MA_LONG: %s", MA_LONG)
MA_SHORT = self._calculate_moving_average(self._settings.short_period)
logger.debug("MA_SHORT: %s", MA_SHORT)
STD = self._calculate_std(self._settings.std_period)
logger.debug("STD: %s", STD)
MONEY = self._account_manager.get_current_balance()
logger.debug("MONEY: %s", MONEY)
has_long_open_signal = False
has_short_open_signal = False
possible_lots = int(MONEY // PRICE)
if (
not self._state.long_open
and self._is_long_open_signal(
MA_SHORT=MA_SHORT,
MA_LONG=MA_LONG,
PRICE=PRICE,
STD=STD,
MA_LONG_START=MA_LONG_START,
)
and possible_lots > 0
):
has_long_open_signal = True
yield OpenLongMarketOrder(lots=possible_lots)
if (
not self._state.short_open
and self._is_short_open_signal(
MA_SHORT=MA_SHORT,
MA_LONG=MA_LONG,
PRICE=PRICE,
STD=STD,
MA_LONG_START=MA_LONG_START,
)
and possible_lots > 0
):
has_short_open_signal = True
yield OpenShortMarketOrder(lots=possible_lots)
if self._state.long_open and self._is_long_close_signal(
MA_LONG=MA_LONG,
PRICE=PRICE,
STD=STD,
has_short_open_signal=has_short_open_signal,
):
yield CloseLongMarketOrder(lots=self._state.position)
if self._state.short_open and self._is_short_close_signal(
MA_LONG=MA_LONG,
PRICE=PRICE,
STD=STD,
has_long_open_signal=has_long_open_signal,
):
yield CloseShortMarketOrder(lots=self._state.position)