Files

198 lines
7.4 KiB
Python

import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Type, cast
import mplfinance as mpf
import numpy as np
import pandas as pd
from t_tech.invest.strategies.base.event import DataEvent, SignalEvent, StrategyEvent
from t_tech.invest.strategies.base.signal import (
CloseLongMarketOrder,
CloseShortMarketOrder,
OpenLongMarketOrder,
OpenShortMarketOrder,
OrderSignal,
Signal,
SignalDirection,
)
from t_tech.invest.strategies.moving_average.strategy_settings import (
MovingAverageStrategySettings,
)
from t_tech.invest.strategies.plotting.plotter import PlotKwargs, StrategyPlotter
logger = logging.getLogger(__name__)
class MovingAverageStrategyPlotter(StrategyPlotter):
def __init__(self, settings: MovingAverageStrategySettings):
self._was_not_executed_color = "grey"
self._settings = settings
self._signal_type_to_style_map: Dict[Type[Signal], Dict[str, Any]] = {
OpenLongMarketOrder: {
"type": "scatter",
"markersize": 50,
"marker": "^",
"color": "green",
},
CloseLongMarketOrder: {
"type": "scatter",
"markersize": 50,
"marker": "^",
"color": "black",
},
OpenShortMarketOrder: {
"type": "scatter",
"markersize": 50,
"marker": "v",
"color": "red",
},
CloseShortMarketOrder: {
"type": "scatter",
"markersize": 50,
"marker": "v",
"color": "black",
},
}
self._signal_type_to_candle_point_map = {
SignalDirection.LONG: lambda candle: candle.low,
SignalDirection.SHORT: lambda candle: candle.high,
}
def _filter_data_events(
self, strategy_events: List[StrategyEvent]
) -> List[DataEvent]:
return cast(
List[DataEvent],
list(filter(lambda e: isinstance(e, DataEvent), strategy_events)),
)
def _filter_signal_events(
self, strategy_events: List[StrategyEvent]
) -> List[SignalEvent]:
return cast(
List[SignalEvent],
list(filter(lambda e: isinstance(e, SignalEvent), strategy_events)),
)
def _get_interval_count_between_dates(
self, start: datetime, end: datetime, interval_delta: timedelta
) -> float:
return (end - start) / interval_delta
def get_candle_plot_kwargs(
self, strategy_events: List[StrategyEvent]
) -> PlotKwargs:
data_events = self._filter_data_events(strategy_events)
quotes = {
"open": [float(e.candle_event.candle.open) for e in data_events],
"close": [float(e.candle_event.candle.close) for e in data_events],
"high": [float(e.candle_event.candle.high) for e in data_events],
"low": [float(e.candle_event.candle.low) for e in data_events],
"volume": [float(e.candle_event.volume) for e in data_events],
"time": [e.candle_event.time for e in data_events],
}
df = pd.DataFrame(quotes, index=quotes["time"])
interval_count = self._get_interval_count_between_dates(
start=df["time"].idxmin(),
end=df["time"].idxmax(),
interval_delta=self._settings.candle_interval_timedelta,
)
non_trading_coefficient = len(data_events) / interval_count
mav = {
"ma_short": int(
self._settings.short_period
/ self._settings.candle_interval_timedelta
* non_trading_coefficient
),
"ma_long": int(
self._settings.long_period
/ self._settings.candle_interval_timedelta
* non_trading_coefficient
),
}
style = mpf.make_mpf_style(
base_mpf_style="charles", mavcolors=["#1f77b4", "#ff7f0e", "#2ca02c"]
)
return cast(
PlotKwargs,
{
"data": df,
"type": "candle",
"volume": True,
"mav": tuple(mav.values()),
"style": style,
"returnfig": True,
},
)
def _get_plot_for_signal_type(
self,
signal_type: Type[Signal],
signal_event_types_to_event_index: Dict[Type[Signal], Dict[int, SignalEvent]],
data_events: List[DataEvent],
was_executed_flag: bool,
) -> Optional[PlotKwargs]:
style = self._signal_type_to_style_map[signal_type]
price = [np.nan] * len(data_events)
color = style["color"]
has_signal = False
for index, signal_event in signal_event_types_to_event_index[
signal_type
].items():
if was_executed_flag == signal_event.was_executed:
has_signal = True
candle = data_events[index].candle_event.candle
signal = cast(OrderSignal, signal_event.signal)
price[index] = self._signal_type_to_candle_point_map[signal.direction](
candle
)
if not signal_event.was_executed:
color = self._was_not_executed_color
if not has_signal:
return None
style.update({"color": color})
params = {
"price": price,
"time": [e.candle_event.time for e in data_events],
}
df = pd.DataFrame(params, index=params["time"])
return cast(PlotKwargs, dict(data=df["price"], **style))
def get_signal_plot_kwargs(
self, strategy_events: List[StrategyEvent]
) -> List[PlotKwargs]:
signal_events = self._filter_signal_events(strategy_events)
data_events = self._filter_data_events(strategy_events)
data_events.sort(key=lambda e: e.time)
first_data_event, last_data_event = data_events[0], data_events[-1]
data_events_timedelta = last_data_event.time - first_data_event.time
signal_event_types_to_event_index: Dict[
Type[Signal], Dict[int, SignalEvent]
] = {}
for signal_event in signal_events:
signal_type = type(signal_event.signal)
event_index = int(
((signal_event.time - first_data_event.time) / data_events_timedelta)
* len(data_events)
)
event_index = min(event_index, len(data_events) - 1)
if signal_type not in signal_event_types_to_event_index:
signal_event_types_to_event_index[signal_type] = {}
signal_event_types_to_event_index[signal_type][event_index] = signal_event
plots = []
for was_executed_flag in [False, True]:
for signal_type in signal_event_types_to_event_index:
kwargs = self._get_plot_for_signal_type(
signal_type=signal_type,
signal_event_types_to_event_index=signal_event_types_to_event_index,
data_events=data_events,
was_executed_flag=was_executed_flag,
)
plots.append(kwargs)
return cast(List[PlotKwargs], list(filter(lambda p: p is not None, plots)))