RAPTOR v18.4: Исправлена отчетность, активированы выходные
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
|
||||
from t_tech.invest import Quotation
|
||||
from t_tech.invest.services import Services
|
||||
from t_tech.invest.strategies.base.errors import (
|
||||
InsufficientMarginalTradeFunds,
|
||||
MarginalTradeIsNotActive,
|
||||
)
|
||||
from t_tech.invest.strategies.base.strategy_settings_base import StrategySettings
|
||||
from t_tech.invest.utils import quotation_to_decimal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountManager:
|
||||
def __init__(self, services: Services, strategy_settings: StrategySettings):
|
||||
self._services = services
|
||||
self._strategy_settings = strategy_settings
|
||||
|
||||
def get_current_balance(self) -> Decimal:
|
||||
account_id = self._strategy_settings.account_id
|
||||
portfolio_response = self._services.operations.get_portfolio(
|
||||
account_id=account_id
|
||||
)
|
||||
balance = portfolio_response.total_amount_currencies
|
||||
return quotation_to_decimal(Quotation(units=balance.units, nano=balance.nano))
|
||||
|
||||
def ensure_marginal_trade(self) -> None:
|
||||
account_id = self._strategy_settings.account_id
|
||||
try:
|
||||
response = self._services.users.get_margin_attributes(account_id=account_id)
|
||||
except Exception as e:
|
||||
raise MarginalTradeIsNotActive() from e
|
||||
value = quotation_to_decimal(response.funds_sufficiency_level)
|
||||
if value <= 1:
|
||||
raise InsufficientMarginalTradeFunds()
|
||||
logger.info("Marginal trade is active")
|
||||
38
invest-python-master/t_tech/invest/strategies/base/errors.py
Normal file
38
invest-python-master/t_tech/invest/strategies/base/errors.py
Normal file
@@ -0,0 +1,38 @@
|
||||
class StrategyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NotEnoughData(StrategyError):
|
||||
pass
|
||||
|
||||
|
||||
class MarginalTradeIsNotActive(StrategyError):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientMarginalTradeFunds(StrategyError):
|
||||
pass
|
||||
|
||||
|
||||
class CandleEventForDateNotFound(StrategyError):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownSignal(StrategyError):
|
||||
pass
|
||||
|
||||
|
||||
class OldCandleObservingError(StrategyError):
|
||||
pass
|
||||
|
||||
|
||||
class MarketDataNotAvailableError(StrategyError):
|
||||
pass
|
||||
|
||||
|
||||
class StrategySupervisorError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class EventsWereNotSupervised(StrategySupervisorError):
|
||||
pass
|
||||
21
invest-python-master/t_tech/invest/strategies/base/event.py
Normal file
21
invest-python-master/t_tech/invest/strategies/base/event.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from t_tech.invest.strategies.base.models import CandleEvent
|
||||
from t_tech.invest.strategies.base.signal import Signal
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyEvent:
|
||||
time: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataEvent(StrategyEvent):
|
||||
candle_event: CandleEvent
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalEvent(StrategyEvent):
|
||||
signal: Signal
|
||||
was_executed: bool
|
||||
19
invest-python-master/t_tech/invest/strategies/base/models.py
Normal file
19
invest-python-master/t_tech/invest/strategies/base/models.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=True)
|
||||
class Candle:
|
||||
open: Decimal
|
||||
high: Decimal
|
||||
low: Decimal
|
||||
close: Decimal
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=True)
|
||||
class CandleEvent:
|
||||
candle: Candle
|
||||
volume: int
|
||||
time: datetime
|
||||
is_complete: bool
|
||||
48
invest-python-master/t_tech/invest/strategies/base/signal.py
Normal file
48
invest-python-master/t_tech/invest/strategies/base/signal.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
class SignalDirection(enum.Enum):
|
||||
LONG = "LONG"
|
||||
SHORT = "SHORT"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Signal:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderSignal(Signal):
|
||||
lots: int
|
||||
direction: SignalDirection
|
||||
|
||||
|
||||
@dataclass
|
||||
class CloseSignal(OrderSignal):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenSignal(OrderSignal):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenLongMarketOrder(OpenSignal):
|
||||
direction: SignalDirection = field(default=SignalDirection.LONG)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CloseLongMarketOrder(CloseSignal):
|
||||
direction: SignalDirection = field(default=SignalDirection.LONG)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenShortMarketOrder(OpenSignal):
|
||||
direction: SignalDirection = field(default=SignalDirection.SHORT)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CloseShortMarketOrder(CloseSignal):
|
||||
direction: SignalDirection = field(default=SignalDirection.SHORT)
|
||||
@@ -0,0 +1,55 @@
|
||||
from t_tech.invest import OrderDirection, OrderType
|
||||
from t_tech.invest.services import Services
|
||||
from t_tech.invest.strategies.base.signal import (
|
||||
CloseLongMarketOrder,
|
||||
CloseShortMarketOrder,
|
||||
OpenLongMarketOrder,
|
||||
OpenShortMarketOrder,
|
||||
)
|
||||
from t_tech.invest.strategies.base.strategy_settings_base import StrategySettings
|
||||
|
||||
|
||||
class SignalExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
services: Services,
|
||||
settings: StrategySettings,
|
||||
):
|
||||
self._services = services
|
||||
self._settings = settings
|
||||
|
||||
def execute_open_long_market_order(self, signal: OpenLongMarketOrder) -> None:
|
||||
self._services.orders.post_order(
|
||||
figi=self._settings.share_id,
|
||||
quantity=signal.lots,
|
||||
direction=OrderDirection.ORDER_DIRECTION_BUY,
|
||||
account_id=self._settings.account_id,
|
||||
order_type=OrderType.ORDER_TYPE_MARKET,
|
||||
)
|
||||
|
||||
def execute_close_long_market_order(self, signal: CloseLongMarketOrder) -> None:
|
||||
self._services.orders.post_order(
|
||||
figi=self._settings.share_id,
|
||||
quantity=signal.lots,
|
||||
direction=OrderDirection.ORDER_DIRECTION_SELL,
|
||||
account_id=self._settings.account_id,
|
||||
order_type=OrderType.ORDER_TYPE_MARKET,
|
||||
)
|
||||
|
||||
def execute_open_short_market_order(self, signal: OpenShortMarketOrder) -> None:
|
||||
self._services.orders.post_order(
|
||||
figi=self._settings.share_id,
|
||||
quantity=signal.lots,
|
||||
direction=OrderDirection.ORDER_DIRECTION_SELL,
|
||||
account_id=self._settings.account_id,
|
||||
order_type=OrderType.ORDER_TYPE_MARKET,
|
||||
)
|
||||
|
||||
def execute_close_short_market_order(self, signal: CloseShortMarketOrder) -> None:
|
||||
self._services.orders.post_order(
|
||||
figi=self._settings.share_id,
|
||||
quantity=signal.lots,
|
||||
direction=OrderDirection.ORDER_DIRECTION_BUY,
|
||||
account_id=self._settings.account_id,
|
||||
order_type=OrderType.ORDER_TYPE_MARKET,
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from typing import Iterable, Protocol
|
||||
|
||||
from t_tech.invest.strategies.base.models import CandleEvent
|
||||
from t_tech.invest.strategies.base.signal import Signal
|
||||
|
||||
|
||||
class InvestStrategy(Protocol):
|
||||
def fit(self, candles: Iterable[CandleEvent]) -> None:
|
||||
pass
|
||||
|
||||
def observe(self, candle: CandleEvent) -> None:
|
||||
pass
|
||||
|
||||
def predict(self) -> Iterable[Signal]:
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
import dataclasses
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
from t_tech.invest import CandleInterval
|
||||
from t_tech.invest.typedefs import AccountId, ShareId
|
||||
from t_tech.invest.utils import candle_interval_to_timedelta
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StrategySettings:
|
||||
share_id: ShareId
|
||||
account_id: AccountId
|
||||
max_transaction_price: Decimal
|
||||
candle_interval: CandleInterval
|
||||
|
||||
@property
|
||||
def candle_interval_timedelta(self) -> timedelta:
|
||||
return candle_interval_to_timedelta(self.candle_interval)
|
||||
@@ -0,0 +1,29 @@
|
||||
import abc
|
||||
from typing import Iterable, Protocol, Type
|
||||
|
||||
from t_tech.invest.strategies.base.event import StrategyEvent
|
||||
|
||||
|
||||
class IStrategySupervisor(Protocol):
|
||||
def notify(self, event: StrategyEvent) -> None:
|
||||
pass
|
||||
|
||||
def get_events(self) -> Iterable[StrategyEvent]:
|
||||
pass
|
||||
|
||||
def get_events_of_type(self, cls: Type[StrategyEvent]) -> Iterable[StrategyEvent]:
|
||||
pass
|
||||
|
||||
|
||||
class StrategySupervisor(abc.ABC, IStrategySupervisor):
|
||||
@abc.abstractmethod
|
||||
def notify(self, event: StrategyEvent) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_events(self) -> Iterable[StrategyEvent]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_events_of_type(self, cls: Type[StrategyEvent]) -> Iterable[StrategyEvent]:
|
||||
pass
|
||||
@@ -0,0 +1,69 @@
|
||||
import abc
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Iterable
|
||||
|
||||
import t_tech
|
||||
from t_tech.invest import HistoricCandle
|
||||
from t_tech.invest.services import Services
|
||||
from t_tech.invest.strategies.base.models import Candle, CandleEvent
|
||||
from t_tech.invest.strategies.base.strategy_interface import InvestStrategy
|
||||
from t_tech.invest.strategies.base.strategy_settings_base import StrategySettings
|
||||
from t_tech.invest.strategies.base.trader_interface import ITrader
|
||||
from t_tech.invest.utils import now, quotation_to_decimal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Trader(ITrader, abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
strategy: InvestStrategy,
|
||||
services: Services,
|
||||
settings: StrategySettings,
|
||||
):
|
||||
self._strategy = strategy
|
||||
self._services = services
|
||||
self._settings = settings
|
||||
|
||||
@staticmethod
|
||||
def _convert_historic_candles_into_candle_events(
|
||||
historic_candles: Iterable[HistoricCandle],
|
||||
) -> Iterable[CandleEvent]:
|
||||
for candle in historic_candles:
|
||||
yield CandleEvent(
|
||||
candle=Candle(
|
||||
open=quotation_to_decimal(candle.open),
|
||||
close=quotation_to_decimal(candle.close),
|
||||
high=quotation_to_decimal(candle.high),
|
||||
low=quotation_to_decimal(candle.low),
|
||||
),
|
||||
volume=candle.volume,
|
||||
time=candle.time,
|
||||
is_complete=candle.is_complete,
|
||||
)
|
||||
|
||||
def _load_candles(self, period: timedelta) -> Iterable[CandleEvent]:
|
||||
logger.info("Loading candles for period %s from %s", period, now())
|
||||
|
||||
yield from self._convert_historic_candles_into_candle_events(
|
||||
self._services.get_all_candles(
|
||||
figi=self._settings.share_id,
|
||||
from_=now() - period,
|
||||
interval=self._settings.candle_interval,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_candle(candle: t_tech.invest.schemas.Candle) -> CandleEvent:
|
||||
return CandleEvent(
|
||||
candle=Candle(
|
||||
open=quotation_to_decimal(candle.open),
|
||||
close=quotation_to_decimal(candle.close),
|
||||
high=quotation_to_decimal(candle.high),
|
||||
low=quotation_to_decimal(candle.low),
|
||||
),
|
||||
volume=candle.volume,
|
||||
time=candle.time,
|
||||
is_complete=False,
|
||||
)
|
||||
@@ -0,0 +1,6 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class ITrader(Protocol):
|
||||
def trade(self):
|
||||
pass
|
||||
@@ -0,0 +1,197 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
import mplfinance as mpf
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from t_tech.invest.strategies.base.event import DataEvent, SignalEvent, StrategyEvent
|
||||
from t_tech.invest.strategies.base.signal import (
|
||||
CloseLongMarketOrder,
|
||||
CloseShortMarketOrder,
|
||||
OpenLongMarketOrder,
|
||||
OpenShortMarketOrder,
|
||||
OrderSignal,
|
||||
Signal,
|
||||
SignalDirection,
|
||||
)
|
||||
from t_tech.invest.strategies.moving_average.strategy_settings import (
|
||||
MovingAverageStrategySettings,
|
||||
)
|
||||
from t_tech.invest.strategies.plotting.plotter import PlotKwargs, StrategyPlotter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MovingAverageStrategyPlotter(StrategyPlotter):
|
||||
def __init__(self, settings: MovingAverageStrategySettings):
|
||||
self._was_not_executed_color = "grey"
|
||||
self._settings = settings
|
||||
self._signal_type_to_style_map: Dict[Type[Signal], Dict[str, Any]] = {
|
||||
OpenLongMarketOrder: {
|
||||
"type": "scatter",
|
||||
"markersize": 50,
|
||||
"marker": "^",
|
||||
"color": "green",
|
||||
},
|
||||
CloseLongMarketOrder: {
|
||||
"type": "scatter",
|
||||
"markersize": 50,
|
||||
"marker": "^",
|
||||
"color": "black",
|
||||
},
|
||||
OpenShortMarketOrder: {
|
||||
"type": "scatter",
|
||||
"markersize": 50,
|
||||
"marker": "v",
|
||||
"color": "red",
|
||||
},
|
||||
CloseShortMarketOrder: {
|
||||
"type": "scatter",
|
||||
"markersize": 50,
|
||||
"marker": "v",
|
||||
"color": "black",
|
||||
},
|
||||
}
|
||||
|
||||
self._signal_type_to_candle_point_map = {
|
||||
SignalDirection.LONG: lambda candle: candle.low,
|
||||
SignalDirection.SHORT: lambda candle: candle.high,
|
||||
}
|
||||
|
||||
def _filter_data_events(
|
||||
self, strategy_events: List[StrategyEvent]
|
||||
) -> List[DataEvent]:
|
||||
return cast(
|
||||
List[DataEvent],
|
||||
list(filter(lambda e: isinstance(e, DataEvent), strategy_events)),
|
||||
)
|
||||
|
||||
def _filter_signal_events(
|
||||
self, strategy_events: List[StrategyEvent]
|
||||
) -> List[SignalEvent]:
|
||||
return cast(
|
||||
List[SignalEvent],
|
||||
list(filter(lambda e: isinstance(e, SignalEvent), strategy_events)),
|
||||
)
|
||||
|
||||
def _get_interval_count_between_dates(
|
||||
self, start: datetime, end: datetime, interval_delta: timedelta
|
||||
) -> float:
|
||||
return (end - start) / interval_delta
|
||||
|
||||
def get_candle_plot_kwargs(
|
||||
self, strategy_events: List[StrategyEvent]
|
||||
) -> PlotKwargs:
|
||||
data_events = self._filter_data_events(strategy_events)
|
||||
quotes = {
|
||||
"open": [float(e.candle_event.candle.open) for e in data_events],
|
||||
"close": [float(e.candle_event.candle.close) for e in data_events],
|
||||
"high": [float(e.candle_event.candle.high) for e in data_events],
|
||||
"low": [float(e.candle_event.candle.low) for e in data_events],
|
||||
"volume": [float(e.candle_event.volume) for e in data_events],
|
||||
"time": [e.candle_event.time for e in data_events],
|
||||
}
|
||||
df = pd.DataFrame(quotes, index=quotes["time"])
|
||||
interval_count = self._get_interval_count_between_dates(
|
||||
start=df["time"].idxmin(),
|
||||
end=df["time"].idxmax(),
|
||||
interval_delta=self._settings.candle_interval_timedelta,
|
||||
)
|
||||
non_trading_coefficient = len(data_events) / interval_count
|
||||
mav = {
|
||||
"ma_short": int(
|
||||
self._settings.short_period
|
||||
/ self._settings.candle_interval_timedelta
|
||||
* non_trading_coefficient
|
||||
),
|
||||
"ma_long": int(
|
||||
self._settings.long_period
|
||||
/ self._settings.candle_interval_timedelta
|
||||
* non_trading_coefficient
|
||||
),
|
||||
}
|
||||
style = mpf.make_mpf_style(
|
||||
base_mpf_style="charles", mavcolors=["#1f77b4", "#ff7f0e", "#2ca02c"]
|
||||
)
|
||||
return cast(
|
||||
PlotKwargs,
|
||||
{
|
||||
"data": df,
|
||||
"type": "candle",
|
||||
"volume": True,
|
||||
"mav": tuple(mav.values()),
|
||||
"style": style,
|
||||
"returnfig": True,
|
||||
},
|
||||
)
|
||||
|
||||
def _get_plot_for_signal_type(
|
||||
self,
|
||||
signal_type: Type[Signal],
|
||||
signal_event_types_to_event_index: Dict[Type[Signal], Dict[int, SignalEvent]],
|
||||
data_events: List[DataEvent],
|
||||
was_executed_flag: bool,
|
||||
) -> Optional[PlotKwargs]:
|
||||
style = self._signal_type_to_style_map[signal_type]
|
||||
price = [np.nan] * len(data_events)
|
||||
color = style["color"]
|
||||
has_signal = False
|
||||
for index, signal_event in signal_event_types_to_event_index[
|
||||
signal_type
|
||||
].items():
|
||||
if was_executed_flag == signal_event.was_executed:
|
||||
has_signal = True
|
||||
candle = data_events[index].candle_event.candle
|
||||
signal = cast(OrderSignal, signal_event.signal)
|
||||
price[index] = self._signal_type_to_candle_point_map[signal.direction](
|
||||
candle
|
||||
)
|
||||
if not signal_event.was_executed:
|
||||
color = self._was_not_executed_color
|
||||
if not has_signal:
|
||||
return None
|
||||
style.update({"color": color})
|
||||
params = {
|
||||
"price": price,
|
||||
"time": [e.candle_event.time for e in data_events],
|
||||
}
|
||||
df = pd.DataFrame(params, index=params["time"])
|
||||
return cast(PlotKwargs, dict(data=df["price"], **style))
|
||||
|
||||
def get_signal_plot_kwargs(
|
||||
self, strategy_events: List[StrategyEvent]
|
||||
) -> List[PlotKwargs]:
|
||||
signal_events = self._filter_signal_events(strategy_events)
|
||||
data_events = self._filter_data_events(strategy_events)
|
||||
data_events.sort(key=lambda e: e.time)
|
||||
first_data_event, last_data_event = data_events[0], data_events[-1]
|
||||
data_events_timedelta = last_data_event.time - first_data_event.time
|
||||
|
||||
signal_event_types_to_event_index: Dict[
|
||||
Type[Signal], Dict[int, SignalEvent]
|
||||
] = {}
|
||||
for signal_event in signal_events:
|
||||
signal_type = type(signal_event.signal)
|
||||
event_index = int(
|
||||
((signal_event.time - first_data_event.time) / data_events_timedelta)
|
||||
* len(data_events)
|
||||
)
|
||||
event_index = min(event_index, len(data_events) - 1)
|
||||
if signal_type not in signal_event_types_to_event_index:
|
||||
signal_event_types_to_event_index[signal_type] = {}
|
||||
signal_event_types_to_event_index[signal_type][event_index] = signal_event
|
||||
|
||||
plots = []
|
||||
for was_executed_flag in [False, True]:
|
||||
for signal_type in signal_event_types_to_event_index:
|
||||
kwargs = self._get_plot_for_signal_type(
|
||||
signal_type=signal_type,
|
||||
signal_event_types_to_event_index=signal_event_types_to_event_index,
|
||||
data_events=data_events,
|
||||
was_executed_flag=was_executed_flag,
|
||||
)
|
||||
plots.append(kwargs)
|
||||
|
||||
return cast(List[PlotKwargs], list(filter(lambda p: p is not None, plots)))
|
||||
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
from functools import singledispatchmethod
|
||||
|
||||
from t_tech.invest.services import Services
|
||||
from t_tech.invest.strategies.base.errors import UnknownSignal
|
||||
from t_tech.invest.strategies.base.signal import (
|
||||
CloseLongMarketOrder,
|
||||
CloseShortMarketOrder,
|
||||
OpenLongMarketOrder,
|
||||
OpenShortMarketOrder,
|
||||
Signal,
|
||||
)
|
||||
from t_tech.invest.strategies.base.signal_executor_base import SignalExecutor
|
||||
from t_tech.invest.strategies.moving_average.strategy_settings import (
|
||||
MovingAverageStrategySettings,
|
||||
)
|
||||
from t_tech.invest.strategies.moving_average.strategy_state import (
|
||||
MovingAverageStrategyState,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MovingAverageSignalExecutor(SignalExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
services: Services,
|
||||
state: MovingAverageStrategyState,
|
||||
settings: MovingAverageStrategySettings,
|
||||
):
|
||||
super().__init__(services, settings)
|
||||
self._services = services
|
||||
self._state = state
|
||||
|
||||
@singledispatchmethod
|
||||
def execute(self, signal: Signal) -> None:
|
||||
raise UnknownSignal()
|
||||
|
||||
@execute.register
|
||||
def _execute_open_long_market_order(self, signal: OpenLongMarketOrder) -> None:
|
||||
self.execute_open_long_market_order(signal)
|
||||
self._state.long_open = True
|
||||
self._state.position = signal.lots
|
||||
logger.info("Signal executed %s", signal)
|
||||
|
||||
@execute.register
|
||||
def _execute_close_long_market_order(self, signal: CloseLongMarketOrder) -> None:
|
||||
self.execute_close_long_market_order(signal)
|
||||
self._state.long_open = False
|
||||
self._state.position = 0
|
||||
logger.info("Signal executed %s", signal)
|
||||
|
||||
@execute.register
|
||||
def _execute_open_short_market_order(self, signal: OpenShortMarketOrder) -> None:
|
||||
self.execute_open_short_market_order(signal)
|
||||
self._state.short_open = True
|
||||
self._state.position = signal.lots
|
||||
logger.info("Signal executed %s", signal)
|
||||
|
||||
@execute.register
|
||||
def _execute_close_short_market_order(self, signal: CloseShortMarketOrder) -> None:
|
||||
self.execute_close_short_market_order(signal)
|
||||
self._state.short_open = False
|
||||
self._state.position = 0
|
||||
logger.info("Signal executed %s", signal)
|
||||
@@ -0,0 +1,305 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Callable, Iterable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from t_tech.invest.strategies.base.account_manager import AccountManager
|
||||
from t_tech.invest.strategies.base.errors import (
|
||||
CandleEventForDateNotFound,
|
||||
NotEnoughData,
|
||||
OldCandleObservingError,
|
||||
)
|
||||
from t_tech.invest.strategies.base.models import CandleEvent
|
||||
from t_tech.invest.strategies.base.signal import (
|
||||
CloseLongMarketOrder,
|
||||
CloseShortMarketOrder,
|
||||
OpenLongMarketOrder,
|
||||
OpenShortMarketOrder,
|
||||
Signal,
|
||||
)
|
||||
from t_tech.invest.strategies.base.strategy_interface import InvestStrategy
|
||||
from t_tech.invest.strategies.moving_average.strategy_settings import (
|
||||
MovingAverageStrategySettings,
|
||||
)
|
||||
from t_tech.invest.strategies.moving_average.strategy_state import (
|
||||
MovingAverageStrategyState,
|
||||
)
|
||||
from t_tech.invest.utils import (
|
||||
candle_interval_to_timedelta,
|
||||
ceil_datetime,
|
||||
floor_datetime,
|
||||
now,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MovingAverageStrategy(InvestStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
settings: MovingAverageStrategySettings,
|
||||
account_manager: AccountManager,
|
||||
state: MovingAverageStrategyState,
|
||||
):
|
||||
self._data: List[CandleEvent] = []
|
||||
self._settings = settings
|
||||
self._account_manager = account_manager
|
||||
|
||||
self._state = state
|
||||
self._MA_LONG_START: Decimal
|
||||
self._candle_interval_timedelta = candle_interval_to_timedelta(
|
||||
self._settings.candle_interval
|
||||
)
|
||||
|
||||
def _ensure_enough_candles(self) -> None:
|
||||
candles_needed = (
|
||||
self._settings.short_period + self._settings.long_period
|
||||
) / self._settings.candle_interval_timedelta
|
||||
if candles_needed > len(self._data):
|
||||
raise NotEnoughData(
|
||||
f"Got {len(self._data)} candles but needed {candles_needed}"
|
||||
)
|
||||
logger.info("Got enough data for strategy")
|
||||
|
||||
def fit(self, candles: Iterable[CandleEvent]) -> None:
|
||||
logger.debug("Strategy fitting with candles %s", candles)
|
||||
for candle in candles:
|
||||
self.observe(candle)
|
||||
self._ensure_enough_candles()
|
||||
|
||||
def _append_candle_event(self, candle_event: CandleEvent) -> None:
|
||||
last_candle_event = self._data[-1]
|
||||
last_interval_floor = floor_datetime(
|
||||
last_candle_event.time, self._candle_interval_timedelta
|
||||
)
|
||||
last_interval_ceil = ceil_datetime(
|
||||
last_candle_event.time, self._candle_interval_timedelta
|
||||
)
|
||||
|
||||
if candle_event.time < last_interval_floor:
|
||||
raise OldCandleObservingError()
|
||||
if (
|
||||
candle_event.time < last_interval_ceil
|
||||
or candle_event.time == last_interval_floor
|
||||
):
|
||||
self._data[-1] = candle_event
|
||||
else:
|
||||
self._data.append(candle_event)
|
||||
|
||||
def observe(self, candle: CandleEvent) -> None:
|
||||
logger.debug("Observing candle event: %s", candle)
|
||||
|
||||
if len(self._data) > 0:
|
||||
self._append_candle_event(candle)
|
||||
else:
|
||||
self._data.append(candle)
|
||||
|
||||
@staticmethod
|
||||
def _get_newer_than_datetime_predicate(
|
||||
anchor: datetime,
|
||||
) -> Callable[[CandleEvent], bool]:
|
||||
def _(event: CandleEvent) -> bool:
|
||||
return event.time > anchor
|
||||
|
||||
return _
|
||||
|
||||
def _filter_from_the_end_with_early_stop(
|
||||
self, predicate: Callable[[CandleEvent], bool]
|
||||
) -> Iterable[CandleEvent]:
|
||||
for event in reversed(self._data):
|
||||
if not predicate(event):
|
||||
break
|
||||
yield event
|
||||
|
||||
def _select_for_period(self, period: timedelta):
|
||||
predicate = self._get_newer_than_datetime_predicate(now() - period)
|
||||
return self._filter_from_the_end_with_early_stop(predicate)
|
||||
|
||||
@staticmethod
|
||||
def _get_prices(events: Iterable[CandleEvent]) -> Iterable[Decimal]:
|
||||
for event in events:
|
||||
yield event.candle.close
|
||||
|
||||
def _calculate_moving_average(self, period: timedelta) -> Decimal:
|
||||
prices = list(self._get_prices(self._select_for_period(period)))
|
||||
logger.debug("Selected prices: %s", prices)
|
||||
return np.mean(prices, axis=0) # type: ignore
|
||||
|
||||
def _calculate_std(self, period: timedelta) -> Decimal:
|
||||
prices = list(self._get_prices(self._select_for_period(period)))
|
||||
return np.std(prices, axis=0) # type: ignore
|
||||
|
||||
def _get_first_candle_before(self, date: datetime) -> CandleEvent:
|
||||
predicate = self._get_newer_than_datetime_predicate(date)
|
||||
for event in reversed(self._data):
|
||||
if not predicate(event):
|
||||
return event
|
||||
raise CandleEventForDateNotFound()
|
||||
|
||||
def _init_MA_LONG_START(self):
|
||||
date = now() - self._settings.short_period
|
||||
event = self._get_first_candle_before(date)
|
||||
self._MA_LONG_START = event.candle.close
|
||||
|
||||
@staticmethod
|
||||
def _is_long_open_signal(
|
||||
MA_SHORT: Decimal,
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
MA_LONG_START: Decimal,
|
||||
) -> bool:
|
||||
logger.debug("Try long opening")
|
||||
logger.debug("\tMA_SHORT > MA_LONG, %s", MA_SHORT > MA_LONG)
|
||||
logger.debug(
|
||||
"\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
|
||||
abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
logger.debug("\tand MA_LONG < MA_LONG_START, %s", MA_LONG > MA_LONG_START)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
MA_SHORT > MA_LONG > MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
return (
|
||||
MA_SHORT > MA_LONG > MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_short_open_signal(
|
||||
MA_SHORT: Decimal,
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
MA_LONG_START: Decimal,
|
||||
) -> bool:
|
||||
logger.debug("Try short opening")
|
||||
logger.debug("\tMA_SHORT < MA_LONG, %s", MA_SHORT < MA_LONG)
|
||||
logger.debug(
|
||||
"\tand abs((PRICE - MA_LONG) / MA_LONG) < STD, %s",
|
||||
abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
logger.debug("\tand MA_LONG > MA_LONG_START, %s", MA_LONG < MA_LONG_START)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
MA_SHORT < MA_LONG < MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD,
|
||||
)
|
||||
return (
|
||||
MA_SHORT < MA_LONG < MA_LONG_START
|
||||
and abs((PRICE - MA_LONG) / MA_LONG) < STD
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_long_close_signal(
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
has_short_open_signal: bool,
|
||||
) -> bool:
|
||||
logger.debug("Try long closing")
|
||||
logger.debug("\tPRICE > MA_LONG + 10 * STD, %s", PRICE > MA_LONG + 10 * STD)
|
||||
logger.debug("\tor has_short_open_signal, %s", has_short_open_signal)
|
||||
logger.debug("\tor PRICE < MA_LONG - 3 * STD, %s", PRICE < MA_LONG - 3 * STD)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
PRICE > MA_LONG + 10 * STD
|
||||
or has_short_open_signal
|
||||
or PRICE < MA_LONG - 3 * STD,
|
||||
)
|
||||
return (
|
||||
PRICE > MA_LONG + 10 * STD
|
||||
or has_short_open_signal
|
||||
or PRICE < MA_LONG - 3 * STD
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_short_close_signal(
|
||||
MA_LONG: Decimal,
|
||||
PRICE: Decimal,
|
||||
STD: Decimal,
|
||||
has_long_open_signal: bool,
|
||||
) -> bool:
|
||||
logger.debug("Try short closing")
|
||||
logger.debug("\tPRICE < MA_LONG - 10 * STD, %s", PRICE < MA_LONG - 10 * STD)
|
||||
logger.debug("\tor has_long_open_signal, %s", has_long_open_signal)
|
||||
logger.debug("\tor PRICE > MA_LONG + 3 * STD, %s", PRICE > MA_LONG + 3 * STD)
|
||||
logger.debug(
|
||||
"== %s",
|
||||
PRICE < MA_LONG - 10 * STD # кажется, что не работает закрытие
|
||||
or has_long_open_signal
|
||||
or PRICE > MA_LONG + 3 * STD,
|
||||
)
|
||||
return (
|
||||
PRICE < MA_LONG - 10 * STD
|
||||
or has_long_open_signal
|
||||
or PRICE > MA_LONG + 3 * STD
|
||||
)
|
||||
|
||||
def predict(self) -> Iterable[Signal]:
|
||||
logger.info("Strategy predict")
|
||||
self._init_MA_LONG_START()
|
||||
MA_LONG_START = self._MA_LONG_START
|
||||
logger.debug("MA_LONG_START: %s", MA_LONG_START)
|
||||
PRICE = self._data[-1].candle.close
|
||||
logger.debug("PRICE: %s", PRICE)
|
||||
MA_LONG = self._calculate_moving_average(self._settings.long_period)
|
||||
logger.debug("MA_LONG: %s", MA_LONG)
|
||||
MA_SHORT = self._calculate_moving_average(self._settings.short_period)
|
||||
logger.debug("MA_SHORT: %s", MA_SHORT)
|
||||
STD = self._calculate_std(self._settings.std_period)
|
||||
logger.debug("STD: %s", STD)
|
||||
MONEY = self._account_manager.get_current_balance()
|
||||
logger.debug("MONEY: %s", MONEY)
|
||||
|
||||
has_long_open_signal = False
|
||||
has_short_open_signal = False
|
||||
|
||||
possible_lots = int(MONEY // PRICE)
|
||||
|
||||
if (
|
||||
not self._state.long_open
|
||||
and self._is_long_open_signal(
|
||||
MA_SHORT=MA_SHORT,
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
MA_LONG_START=MA_LONG_START,
|
||||
)
|
||||
and possible_lots > 0
|
||||
):
|
||||
has_long_open_signal = True
|
||||
yield OpenLongMarketOrder(lots=possible_lots)
|
||||
|
||||
if (
|
||||
not self._state.short_open
|
||||
and self._is_short_open_signal(
|
||||
MA_SHORT=MA_SHORT,
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
MA_LONG_START=MA_LONG_START,
|
||||
)
|
||||
and possible_lots > 0
|
||||
):
|
||||
has_short_open_signal = True
|
||||
yield OpenShortMarketOrder(lots=possible_lots)
|
||||
|
||||
if self._state.long_open and self._is_long_close_signal(
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
has_short_open_signal=has_short_open_signal,
|
||||
):
|
||||
yield CloseLongMarketOrder(lots=self._state.position)
|
||||
|
||||
if self._state.short_open and self._is_short_close_signal(
|
||||
MA_LONG=MA_LONG,
|
||||
PRICE=PRICE,
|
||||
STD=STD,
|
||||
has_long_open_signal=has_long_open_signal,
|
||||
):
|
||||
yield CloseShortMarketOrder(lots=self._state.position)
|
||||
@@ -0,0 +1,11 @@
|
||||
import dataclasses
|
||||
from datetime import timedelta
|
||||
|
||||
from t_tech.invest.strategies.base.strategy_settings_base import StrategySettings
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MovingAverageStrategySettings(StrategySettings):
|
||||
long_period: timedelta
|
||||
short_period: timedelta
|
||||
std_period: timedelta
|
||||
@@ -0,0 +1,29 @@
|
||||
class MovingAverageStrategyState:
|
||||
def __init__(self):
|
||||
self._long_open: bool = False
|
||||
self._short_open: bool = False
|
||||
self._position: int = 0
|
||||
|
||||
@property
|
||||
def long_open(self) -> bool:
|
||||
return self._long_open
|
||||
|
||||
@long_open.setter
|
||||
def long_open(self, value: bool) -> None:
|
||||
self._long_open = value
|
||||
|
||||
@property
|
||||
def short_open(self) -> bool:
|
||||
return self._short_open
|
||||
|
||||
@short_open.setter
|
||||
def short_open(self, value: bool) -> None:
|
||||
self._short_open = value
|
||||
|
||||
@property
|
||||
def position(self) -> int:
|
||||
return self._position
|
||||
|
||||
@position.setter
|
||||
def position(self, value: int) -> None:
|
||||
self._position = value
|
||||
@@ -0,0 +1,24 @@
|
||||
from itertools import chain
|
||||
from typing import Dict, Iterable, List, Type, cast
|
||||
|
||||
from t_tech.invest.strategies.base.errors import EventsWereNotSupervised
|
||||
from t_tech.invest.strategies.base.event import StrategyEvent
|
||||
from t_tech.invest.strategies.base.strategy_supervisor import StrategySupervisor
|
||||
|
||||
|
||||
class MovingAverageStrategySupervisor(StrategySupervisor):
|
||||
def __init__(self):
|
||||
self._events: Dict[Type[StrategyEvent], List[StrategyEvent]] = {}
|
||||
|
||||
def notify(self, event: StrategyEvent) -> None:
|
||||
if type(event) not in self._events:
|
||||
self._events[type(event)] = []
|
||||
self._events[type(event)].append(event)
|
||||
|
||||
def get_events(self) -> Iterable[StrategyEvent]:
|
||||
return cast(Iterable[StrategyEvent], chain(*self._events.values()))
|
||||
|
||||
def get_events_of_type(self, cls: Type[StrategyEvent]) -> List[StrategyEvent]:
|
||||
if cls in self._events:
|
||||
return self._events[cls]
|
||||
raise EventsWereNotSupervised()
|
||||
@@ -0,0 +1,172 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Iterator, List
|
||||
|
||||
import t_tech
|
||||
from t_tech.invest import (
|
||||
CandleInstrument,
|
||||
InvestError,
|
||||
MarketDataRequest,
|
||||
MarketDataResponse,
|
||||
SubscribeCandlesRequest,
|
||||
SubscriptionAction,
|
||||
SubscriptionInterval,
|
||||
)
|
||||
from t_tech.invest.services import Services
|
||||
from t_tech.invest.strategies.base.account_manager import AccountManager
|
||||
from t_tech.invest.strategies.base.errors import MarketDataNotAvailableError
|
||||
from t_tech.invest.strategies.base.event import DataEvent, SignalEvent
|
||||
from t_tech.invest.strategies.base.models import CandleEvent
|
||||
from t_tech.invest.strategies.base.signal import CloseSignal, OpenSignal, Signal
|
||||
from t_tech.invest.strategies.base.signal_executor_base import SignalExecutor
|
||||
from t_tech.invest.strategies.base.trader_base import Trader
|
||||
from t_tech.invest.strategies.moving_average.strategy import MovingAverageStrategy
|
||||
from t_tech.invest.strategies.moving_average.strategy_settings import (
|
||||
MovingAverageStrategySettings,
|
||||
)
|
||||
from t_tech.invest.strategies.moving_average.strategy_state import (
|
||||
MovingAverageStrategyState,
|
||||
)
|
||||
from t_tech.invest.strategies.moving_average.supervisor import (
|
||||
MovingAverageStrategySupervisor,
|
||||
)
|
||||
from t_tech.invest.utils import floor_datetime, now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MovingAverageStrategyTrader(Trader):
|
||||
def __init__(
|
||||
self,
|
||||
strategy: MovingAverageStrategy,
|
||||
settings: MovingAverageStrategySettings,
|
||||
services: Services,
|
||||
state: MovingAverageStrategyState,
|
||||
signal_executor: SignalExecutor,
|
||||
account_manager: AccountManager,
|
||||
supervisor: MovingAverageStrategySupervisor,
|
||||
):
|
||||
super().__init__(strategy, services, settings)
|
||||
self._settings: MovingAverageStrategySettings = settings
|
||||
self._strategy = strategy
|
||||
self._services = services
|
||||
self._data: List[CandleEvent]
|
||||
self._market_data_stream: Iterator[MarketDataResponse]
|
||||
self._state = state
|
||||
self._signal_executor = signal_executor
|
||||
self._account_manager = account_manager
|
||||
self._supervisor = supervisor
|
||||
|
||||
self._data = list(
|
||||
self._load_candles(
|
||||
(self._settings.short_period + self._settings.long_period) * 3
|
||||
)
|
||||
)
|
||||
for candle_event in self._data:
|
||||
self._supervisor.notify(self._convert_to_data_event(candle_event))
|
||||
self._strategy.fit(self._data)
|
||||
|
||||
self._ensure_marginal_trade_active()
|
||||
self._subscribe()
|
||||
|
||||
def _ensure_marginal_trade_active(self) -> None:
|
||||
self._account_manager.ensure_marginal_trade()
|
||||
|
||||
def _subscribe(self):
|
||||
current_instrument = CandleInstrument(
|
||||
figi=self._settings.share_id,
|
||||
interval=SubscriptionInterval.SUBSCRIPTION_INTERVAL_ONE_MINUTE,
|
||||
)
|
||||
candle_subscribe_request = MarketDataRequest(
|
||||
subscribe_candles_request=SubscribeCandlesRequest(
|
||||
subscription_action=SubscriptionAction.SUBSCRIPTION_ACTION_SUBSCRIBE,
|
||||
instruments=[current_instrument],
|
||||
)
|
||||
)
|
||||
|
||||
def request_iterator():
|
||||
yield candle_subscribe_request
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
self._market_data_stream = self._services.market_data_stream.market_data_stream(
|
||||
request_iterator()
|
||||
)
|
||||
|
||||
def _is_candle_fresh(self, candle: t_tech.invest.Candle) -> bool:
|
||||
is_fresh_border = floor_datetime(
|
||||
now(), delta=self._settings.candle_interval_timedelta
|
||||
)
|
||||
logger.debug(
|
||||
"Checking if candle is fresh: candle.time=%s > is_fresh_border=%s %s)",
|
||||
candle.time,
|
||||
is_fresh_border,
|
||||
candle.time >= is_fresh_border,
|
||||
)
|
||||
return candle.time >= is_fresh_border
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_data_event(candle_event: CandleEvent) -> DataEvent:
|
||||
return DataEvent(candle_event=candle_event, time=candle_event.time)
|
||||
|
||||
def _make_observations(self) -> None:
|
||||
while True:
|
||||
market_data_response: MarketDataResponse = next(self._market_data_stream)
|
||||
logger.debug("got market_data_response: %s", market_data_response)
|
||||
if market_data_response.candle is None:
|
||||
logger.debug("market_data_response didn't have candle")
|
||||
continue
|
||||
candle = market_data_response.candle
|
||||
logger.debug("candle extracted: %s", candle)
|
||||
candle_event = self._convert_candle(candle)
|
||||
self._strategy.observe(candle_event)
|
||||
self._supervisor.notify(self._convert_to_data_event(candle_event))
|
||||
if self._is_candle_fresh(candle):
|
||||
logger.info("Data refreshed")
|
||||
break
|
||||
|
||||
def _refresh_data(self) -> None:
|
||||
logger.info("Refreshing data")
|
||||
try:
|
||||
self._make_observations()
|
||||
except StopIteration as e:
|
||||
logger.info("Fresh quotations not available")
|
||||
raise MarketDataNotAvailableError() from e
|
||||
|
||||
def _filter_closing_signals(self, signals: List[Signal]) -> List[Signal]:
|
||||
return list(filter(lambda signal: isinstance(signal, CloseSignal), signals))
|
||||
|
||||
def _filter_opening_signals(self, signals: List[Signal]) -> List[Signal]:
|
||||
return list(filter(lambda signal: isinstance(signal, OpenSignal), signals))
|
||||
|
||||
def _execute(self, signal: Signal) -> None:
|
||||
logger.info("Trying to execute signal %s", signal)
|
||||
try:
|
||||
self._signal_executor.execute(signal) # type: ignore
|
||||
except InvestError:
|
||||
was_executed = False
|
||||
else:
|
||||
was_executed = True
|
||||
self._supervisor.notify(
|
||||
SignalEvent(signal=signal, was_executed=was_executed, time=now())
|
||||
)
|
||||
|
||||
def _get_signals(self) -> List[Signal]:
|
||||
signals = list(self._strategy.predict())
|
||||
return [
|
||||
*self._filter_closing_signals(signals),
|
||||
*self._filter_opening_signals(signals),
|
||||
]
|
||||
|
||||
def trade(self) -> None:
|
||||
"""Делает попытку следовать стратегии."""
|
||||
logger.info("Balance: %s", self._account_manager.get_current_balance())
|
||||
self._refresh_data()
|
||||
|
||||
signals = self._get_signals()
|
||||
if signals:
|
||||
logger.info("Got signals %s", signals)
|
||||
for signal in signals:
|
||||
self._execute(signal)
|
||||
if self._state.position == 0:
|
||||
logger.info("Trade try complete")
|
||||
@@ -0,0 +1,62 @@
|
||||
import abc
|
||||
import logging
|
||||
from typing import Any, Iterable, List, NewType, Protocol
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mplfinance as mpf
|
||||
from IPython.display import clear_output
|
||||
from matplotlib.gridspec import GridSpec
|
||||
|
||||
from t_tech.invest.strategies.base.event import StrategyEvent
|
||||
|
||||
PlotKwargs = NewType("PlotKwargs", dict)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IPlotter(Protocol):
|
||||
def plot(self, strategy_events: Iterable[StrategyEvent]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class StrategyPlotter(abc.ABC, IPlotter):
|
||||
@abc.abstractmethod
|
||||
def get_candle_plot_kwargs(
|
||||
self, strategy_events: List[StrategyEvent]
|
||||
) -> PlotKwargs:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_signal_plot_kwargs(
|
||||
self, strategy_events: List[StrategyEvent]
|
||||
) -> List[PlotKwargs]:
|
||||
pass
|
||||
|
||||
def get_plot_kwargs(
|
||||
self, strategy_events: Iterable[StrategyEvent], ax: Any
|
||||
) -> PlotKwargs:
|
||||
strategy_events = list(strategy_events)
|
||||
candle_plot = self.get_candle_plot_kwargs(strategy_events=strategy_events)
|
||||
if signal_plots := self.get_signal_plot_kwargs(strategy_events=strategy_events):
|
||||
add_plots = []
|
||||
for signal_plot in signal_plots:
|
||||
signal_plot.update({"ax": ax})
|
||||
ap = mpf.make_addplot(**signal_plot)
|
||||
add_plots.append(ap)
|
||||
|
||||
candle_plot.update({"addplot": add_plots})
|
||||
return candle_plot
|
||||
|
||||
def plot(self, strategy_events: Iterable[StrategyEvent]) -> None:
|
||||
_fig = plt.figure(figsize=(20, 20))
|
||||
gs = GridSpec(2, 1, height_ratios=[3, 1])
|
||||
_ax1 = plt.subplot(gs[0])
|
||||
_ax2 = plt.subplot(gs[1])
|
||||
|
||||
candle_plot_kwargs = self.get_plot_kwargs(strategy_events, ax=_ax1)
|
||||
candle_plot_kwargs.update({"ax": _ax1, "volume": _ax2})
|
||||
mpf.plot(**candle_plot_kwargs, warn_too_much_data=999999999)
|
||||
|
||||
clear_output(wait=True)
|
||||
_fig.canvas.draw()
|
||||
_fig.canvas.flush_events()
|
||||
Reference in New Issue
Block a user