RAPTOR v18.4: Исправлена отчетность, активированы выходные
This commit is contained in:
@@ -0,0 +1,305 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Callable, Iterable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from t_tech.invest.strategies.base.account_manager import AccountManager
|
||||
from t_tech.invest.strategies.base.errors import (
|
||||
CandleEventForDateNotFound,
|
||||
NotEnoughData,
|
||||
OldCandleObservingError,
|
||||
)
|
||||
from t_tech.invest.strategies.base.models import CandleEvent
|
||||
from t_tech.invest.strategies.base.signal import (
|
||||
CloseLongMarketOrder,
|
||||
CloseShortMarketOrder,
|
||||
OpenLongMarketOrder,
|
||||
OpenShortMarketOrder,
|
||||
Signal,
|
||||
)
|
||||
from t_tech.invest.strategies.base.strategy_interface import InvestStrategy
|
||||
from t_tech.invest.strategies.moving_average.strategy_settings import (
|
||||
MovingAverageStrategySettings,
|
||||
)
|
||||
from t_tech.invest.strategies.moving_average.strategy_state import (
|
||||
MovingAverageStrategyState,
|
||||
)
|
||||
from t_tech.invest.utils import (
|
||||
candle_interval_to_timedelta,
|
||||
ceil_datetime,
|
||||
floor_datetime,
|
||||
now,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MovingAverageStrategy(InvestStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
settings: MovingAverageStrategySettings,
|
||||
account_manager: AccountManager,
|
||||
state: MovingAverageStrategyState,
|
||||
):
|
||||
self._data: List[CandleEvent] = []
|
||||
self._settings = settings
|
||||
self._account_manager = account_manager
|
||||
|
||||
self._state = state
|
||||
self._MA_LONG_START: Decimal
|
||||
self._candle_interval_timedelta = candle_interval_to_timedelta(
|
||||
self._settings.candle_interval
|
||||
)
|
||||
|
||||
def _ensure_enough_candles(self) -> None:
|
||||
candles_needed = (
|
||||
self._settings.short_period + self._settings.long_period
|
||||
) / self._settings.candle_interval_timedelta
|
||||
if candles_needed > len(self._data):
|
||||
raise NotEnoughData(
|
||||
f"Got {len(self._data)} candles but needed {candles_needed}"
|
||||
)
|
||||
logger.info("Got enough data for strategy")
|
||||
|
||||
def fit(self, candles: Iterable[CandleEvent]) -> None:
|
||||
logger.debug("Strategy fitting with candles %s", candles)
|
||||
for candle in candles:
|
||||
self.observe(candle)
|
||||
self._ensure_enough_candles()
|
||||
|
||||
def _append_candle_event(self, candle_event: CandleEvent) -> None:
|
||||
last_candle_event = self._data[-1]
|
||||
last_interval_floor = floor_datetime(
|
||||
last_candle_event.time, self._candle_interval_timedelta
|
||||
)
|
||||
last_interval_ceil = ceil_datetime(
|
||||
last_candle_event.time, self._candle_interval_timedelta
|
||||
)
|
||||
|
||||
if candle_event.time < last_interval_floor:
|
||||
raise OldCandleObservingError()
|
||||
if (
|
||||
candle_event.time < last_interval_ceil
|
||||
or candle_event.time == last_interval_floor
|
||||
):
|
||||
self._data[-1] = candle_event
|
||||
else:
|
||||
self._data.append(candle_event)
|
||||
|
||||
def observe(self, candle: CandleEvent) -> None:
|
||||
logger.debug("Observing candle event: %s", candle)
|
||||
|
||||
if len(self._data) > 0:
|
||||
self._append_candle_event(candle)
|
||||
else:
|
||||
self._data.append(candle)
|
||||
|
||||
@staticmethod
|
||||
def _get_newer_than_datetime_predicate(
|
||||
anchor: datetime,
|
||||
) -> Callable[[CandleEvent], bool]:
|
||||
def _(event: CandleEvent) -> bool:
|
||||
return event.time > anchor
|
||||
|
||||
return _
|
||||
|
||||
def _filter_from_the_end_with_early_stop(
|
||||
self, predicate: Callable[[CandleEvent], bool]
|
||||
) -> Iterable[CandleEvent]:
|
||||
for event in reversed(self._data):
|
||||
if not predicate(event):
|
||||
break
|
||||
yield event
|
||||
|
||||
def _select_for_period(self, period: timedelta):
|
||||
predicate = self._get_newer_than_datetime_predicate(now() - period)
|
||||
return self._filter_from_the_end_with_early_stop(predicate)
|
||||
|
||||
@staticmethod
|
||||
def _get_prices(events: Iterable[CandleEvent]) -> Iterable[Decimal]:
|
||||
for event in events:
|
||||
yield event.candle.close
|
||||
|
||||
def _calculate_moving_average(self, period: timedelta) -> Decimal:
|
||||
prices = list(self._get_prices(self._select_for_period(period)))
|
||||
logger.debug("Selected prices: %s", prices)
|
||||
return np.mean(prices, axis=0) # type: ignore
|
||||
|
||||
def _calculate_std(self, period: timedelta) -> Decimal:
|
||||
prices = list(self._get_prices(self._select_for_period(period)))
|
||||
return np.std(prices, axis=0) # type: ignore
|
||||
|
||||
def _get_first_candle_before(self, date: datetime) -> CandleEvent:
|
||||
predicate = self._get_newer_than_datetime_predicate(date)
|
||||
for event in reversed(self._data):
|
||||
if not predicate(event):
|
||||
return event
|
||||
raise CandleEventForDateNotFound()
|
||||
|
||||
def _init_MA_LONG_START(self):
|
||||
date = now() - self._settings.short_period
|
||||
event = self._get_first_candle_before(date)
|
||||
self._MA_LONG_START = event.candle.close
|
||||
|
||||
@staticmethod
|
||||
def _is_long_open_signal(
|
||||
MA_SHORT: Decimal,
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
MA_LONG_START: Decimal,
|
||||
) -> bool:
|
||||
logger.debug("Try long opening")
|
||||
logger.debug("\tMA_SHORT > MA_LONG, %s", MA_SHORT > MA_LONG)
|
||||
logger.debug(
|
||||
"\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
|
||||
abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
logger.debug("\tand MA_LONG < MA_LONG_START, %s", MA_LONG > MA_LONG_START)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
MA_SHORT > MA_LONG > MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
return (
|
||||
MA_SHORT > MA_LONG > MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_short_open_signal(
|
||||
MA_SHORT: Decimal,
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
MA_LONG_START: Decimal,
|
||||
) -> bool:
|
||||
logger.debug("Try short opening")
|
||||
logger.debug("\tMA_SHORT < MA_LONG, %s", MA_SHORT < MA_LONG)
|
||||
logger.debug(
|
||||
"\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
|
||||
abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
logger.debug("\tand MA_LONG > MA_LONG_START, %s", MA_LONG < MA_LONG_START)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
MA_SHORT < MA_LONG < MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
return (
|
||||
MA_SHORT < MA_LONG < MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_long_close_signal(
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
has_short_open_signal: bool,
|
||||
) -> bool:
|
||||
logger.debug("Try long closing")
|
||||
logger.debug("\tPRICE > MA_LONG + 10 * STD, %s", PRICE > MA_LONG + 10 * STD)
|
||||
logger.debug("\tor has_short_open_signal, %s", has_short_open_signal)
|
||||
logger.debug("\tor PRICE < MA_LONG - 3 * STD, %s", PRICE < MA_LONG - 3 * STD)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
PRICE > MA_LONG + 10 * STD
|
||||
or has_short_open_signal
|
||||
or PRICE < MA_LONG - 3 * STD,
|
||||
)
|
||||
return (
|
||||
PRICE > MA_LONG + 10 * STD
|
||||
or has_short_open_signal
|
||||
or PRICE < MA_LONG - 3 * STD
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_short_close_signal(
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
has_long_open_signal: bool,
|
||||
) -> bool:
|
||||
logger.debug("Try short closing")
|
||||
logger.debug("\tPRICE < MA_LONG - 10 * STD, %s", PRICE < MA_LONG - 10 * STD)
|
||||
logger.debug("\tor has_long_open_signal, %s", has_long_open_signal)
|
||||
logger.debug("\tor PRICE > MA_LONG + 3 * STD, %s", PRICE > MA_LONG + 3 * STD)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
PRICE < MA_LONG - 10 * STD # кажется, что не работает закрытие
|
||||
or has_long_open_signal
|
||||
or PRICE > MA_LONG + 3 * STD,
|
||||
)
|
||||
return (
|
||||
PRICE < MA_LONG - 10 * STD
|
||||
or has_long_open_signal
|
||||
or PRICE > MA_LONG + 3 * STD
|
||||
)
|
||||
|
||||
def predict(self) -> Iterable[Signal]:
|
||||
logger.info("Strategy predict")
|
||||
self._init_MA_LONG_START()
|
||||
MA_LONG_START = self._MA_LONG_START
|
||||
logger.debug("MA_LONG_START: %s", MA_LONG_START)
|
||||
PRICE = self._data[-1].candle.close
|
||||
logger.debug("PRICE: %s", PRICE)
|
||||
MA_LONG = self._calculate_moving_average(self._settings.long_period)
|
||||
logger.debug("MA_LONG: %s", MA_LONG)
|
||||
MA_SHORT = self._calculate_moving_average(self._settings.short_period)
|
||||
logger.debug("MA_SHORT: %s", MA_SHORT)
|
||||
STD = self._calculate_std(self._settings.std_period)
|
||||
logger.debug("STD: %s", STD)
|
||||
MONEY = self._account_manager.get_current_balance()
|
||||
logger.debug("MONEY: %s", MONEY)
|
||||
|
||||
has_long_open_signal = False
|
||||
has_short_open_signal = False
|
||||
|
||||
possible_lots = int(MONEY // PRICE)
|
||||
|
||||
if (
|
||||
not self._state.long_open
|
||||
and self._is_long_open_signal(
|
||||
MA_SHORT=MA_SHORT,
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
MA_LONG_START=MA_LONG_START,
|
||||
)
|
||||
and possible_lots > 0
|
||||
):
|
||||
has_long_open_signal = True
|
||||
yield OpenLongMarketOrder(lots=possible_lots)
|
||||
|
||||
if (
|
||||
not self._state.short_open
|
||||
and self._is_short_open_signal(
|
||||
MA_SHORT=MA_SHORT,
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
MA_LONG_START=MA_LONG_START,
|
||||
)
|
||||
and possible_lots > 0
|
||||
):
|
||||
has_short_open_signal = True
|
||||
yield OpenShortMarketOrder(lots=possible_lots)
|
||||
|
||||
if self._state.long_open and self._is_long_close_signal(
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
has_short_open_signal=has_short_open_signal,
|
||||
):
|
||||
yield CloseLongMarketOrder(lots=self._state.position)
|
||||
|
||||
if self._state.short_open and self._is_short_close_signal(
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
has_long_open_signal=has_long_open_signal,
|
||||
):
|
||||
yield CloseShortMarketOrder(lots=self._state.position)
|
||||
Reference in New Issue
Block a user