Modules
API
Core
Utils
- class lstm_forecast.utils.early_stopping.EarlyStopping(patience=10, delta=0.001, verbose=False, path='checkpoint.pt')
Bases: object
Monitors the validation loss during training and triggers early stopping when the validation loss fails to improve for a specified number of consecutive epochs (patience). It also saves the model state with the best validation loss.
- patience
Number of epochs to wait for an improvement in validation loss before stopping.
- Type:
int
- delta
Minimum change in the monitored validation loss to qualify as an improvement.
- Type:
float
- verbose
If True, logs information about validation loss improvements and early stopping.
- Type:
bool
- path
File path to save the model checkpoint with the best validation loss.
- Type:
str
- counter
Counts the number of consecutive epochs without improvement in validation loss.
- Type:
int
- best_loss
Best validation loss observed during training.
- Type:
float or None
- early_stop
Flag to indicate whether early stopping has been triggered.
- Type:
bool
- best_model_state
State dictionary of the model corresponding to the best validation loss.
- Type:
dict or None
- logger
Logger instance for logging messages related to early stopping.
- Type:
logging.Logger
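The attributes above describe the usual patience/delta scheme. The following is a minimal, hypothetical sketch of that scheme in plain Python, not the library's implementation: the class name `EarlyStoppingSketch`, its `__call__(val_loss, model)` method, and the rule that an improvement means `val_loss < best_loss - delta` are all assumptions. The model only needs a `state_dict()` method (as PyTorch modules have), and writing the checkpoint to `path` is deliberately left out.

```python
import copy
import logging
from typing import Any, Dict, Optional, Protocol


class HasStateDict(Protocol):
    def state_dict(self) -> Dict[str, Any]: ...


class EarlyStoppingSketch:
    """Hypothetical re-implementation of the patience/delta logic described above."""

    def __init__(self, patience: int = 10, delta: float = 0.001,
                 verbose: bool = False, path: str = "checkpoint.pt") -> None:
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.path = path  # kept for parity only; this sketch never writes to disk
        self.counter = 0
        self.best_loss: Optional[float] = None
        self.early_stop = False
        self.best_model_state: Optional[Dict[str, Any]] = None
        self.logger = logging.getLogger(__name__)

    def __call__(self, val_loss: float, model: HasStateDict) -> None:
        # Assumed rule: an epoch is an improvement only if the loss drops
        # by more than `delta` below the best loss seen so far.
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0
            if self.verbose:
                self.logger.info("Validation loss improved to %.6f", val_loss)
        else:
            # No sufficient improvement: count it and stop once patience runs out.
            self.counter += 1
            if self.verbose:
                self.logger.info("No improvement for %d/%d epochs",
                                 self.counter, self.patience)
            if self.counter >= self.patience:
                self.early_stop = True


class DummyModel:
    """Stand-in model exposing only state_dict()."""

    def __init__(self) -> None:
        self.w = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"w": self.w}


stopper = EarlyStoppingSketch(patience=2, delta=0.01)
model = DummyModel()
for epoch, loss in enumerate([1.0, 0.8, 0.795, 0.792]):
    model.w = epoch
    stopper(loss, model)
    if stopper.early_stop:
        break
```

With `patience=2`, neither 0.795 nor 0.792 falls below 0.8 - 0.01, so those two epochs set `early_stop`, and `best_model_state` keeps the snapshot taken at the epoch whose loss was 0.8.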
- lstm_forecast.utils.logger.setup_logger(name, log_file='logs/lstm_forecast.log', level=20)
The default level=20 corresponds to logging.INFO.
- Return type:
Logger
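Only the signature and return type are documented here. A common pattern for a helper with this signature is sketched below; the function name `setup_logger_sketch`, the format string, the extra console handler, and the creation of the log directory are all assumptions rather than documented behaviour.

```python
import logging
import os


def setup_logger_sketch(name: str, log_file: str = "logs/lstm_forecast.log",
                        level: int = logging.INFO) -> logging.Logger:
    # Assumed behaviour: create the parent directory of the log file if needed.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Avoid attaching duplicate handlers when called twice with the same name.
    if not logger.handlers:
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
    return logger
```

Guarding on `logger.handlers` matters because `logging.getLogger(name)` returns the same object on every call, so repeated setup would otherwise duplicate each message.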
- lstm_forecast.utils.predict_utils.inverse_transform_predictions(predictions, scaler_prices, scaler_volume, features, num_targets)
- Return type:
ndarray
- Parameters:
predictions (ndarray)
scaler_prices (StandardScaler)
scaler_volume (MinMaxScaler)
features (List[str])
num_targets (int)
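The reference does not say which columns each scaler covers or how `features` and `num_targets` select them. The sketch below assumes, hypothetically, that every column except the last is a scaled price and the last is scaled volume, and it takes the fitted statistics directly instead of scaler objects so that only the arithmetic is shown: a `StandardScaler` is undone with `x = z * std + mean`, and a `MinMaxScaler` with its default `feature_range=(0, 1)` with `x = s * (max - min) + min`.

```python
import numpy as np


def inverse_transform_sketch(predictions, price_mean, price_std, vol_min, vol_max):
    """Hypothetical layout: columns [:-1] are scaled prices, column -1 is scaled volume."""
    out = np.asarray(predictions, dtype=float).copy()
    # Undo standardization on the price columns (per-column mean and std).
    out[:, :-1] = (out[:, :-1] * np.asarray(price_std, dtype=float)
                   + np.asarray(price_mean, dtype=float))
    # Undo min-max scaling on the volume column (default feature_range=(0, 1)).
    out[:, -1] = out[:, -1] * (vol_max - vol_min) + vol_min
    return out


preds = np.array([[0.0, 1.0, 0.5],
                  [1.0, -1.0, 1.0]])
restored = inverse_transform_sketch(preds, price_mean=[100.0, 50.0],
                                    price_std=[10.0, 5.0],
                                    vol_min=0.0, vol_max=1000.0)
# restored -> [[100.0, 55.0, 500.0], [110.0, 45.0, 1000.0]]
```

Using separate scalers matters because prices and volume live on very different scales; each column group must be inverted with the statistics it was fitted on.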
- lstm_forecast.utils.predict_utils.create_candles(predictions, freq, start_date)
- Return type:
DataFrame
- Parameters:
predictions (ndarray)
freq (str)
start_date (Timestamp)
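Only the signature is documented. Assuming, hypothetically, that the first four columns of `predictions` hold open, high, low and close prices and that the first candle falls on `start_date`, a pandas sketch could index the rows with `pd.date_range`, using `freq` as a pandas frequency string:

```python
import numpy as np
import pandas as pd


def create_candles_sketch(predictions: np.ndarray, freq: str,
                          start_date: pd.Timestamp) -> pd.DataFrame:
    # One timestamp per predicted row, spaced by the pandas frequency string `freq`.
    index = pd.date_range(start=start_date, periods=len(predictions), freq=freq)
    # Hypothetical column layout: Open, High, Low, Close in the first four columns.
    return pd.DataFrame(predictions[:, :4], index=index,
                        columns=["Open", "High", "Low", "Close"])


candles = create_candles_sketch(
    np.array([[1.0, 2.0, 0.5, 1.5],
              [1.5, 2.5, 1.0, 2.0]]),
    freq="D",
    start_date=pd.Timestamp("2024-01-01"),
)
```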
- lstm_forecast.utils.predict_utils.plot_predictions(symbol, filename, candles, predictions, future_predictions, freq, interval, logger)
- Return type:
None
- Parameters:
symbol (str)
filename (str)
candles (DataFrame)
predictions (ndarray)
future_predictions (ndarray)
freq (str)
interval (str)
- lstm_forecast.utils.predict_utils.add_candlestick_trace(fig, candles, name, increasing_color='green', decreasing_color='red')
- Return type:
None
- Parameters:
fig (Figure)
candles (DataFrame)
name (str)
Backtesting
- lstm_forecast.backtesting.metrics.calculate_sortino_ratio(historical_data)
- Return type:
float
- Parameters:
historical_data (DataFrame)
- lstm_forecast.backtesting.metrics.calculate_sharpe_ratio(historical_data)
- Return type:
float
- Parameters:
historical_data (DataFrame)
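The reference does not state the formulas behind these two metrics or which column they read. The sketch below shows common per-period definitions under stated assumptions: simple returns from a hypothetical `Close` column, a risk-free rate and target return of 0, and no annualization. The Sharpe ratio divides the mean excess return by its sample standard deviation; the Sortino ratio divides it by the downside deviation, the root mean square of the below-target part of every return.

```python
import numpy as np
import pandas as pd


def _simple_returns(historical_data: pd.DataFrame, price_col: str = "Close") -> np.ndarray:
    # Hypothetical column name; r_t = p_t / p_{t-1} - 1.
    prices = historical_data[price_col].to_numpy(dtype=float)
    return prices[1:] / prices[:-1] - 1.0


def sharpe_ratio_sketch(historical_data: pd.DataFrame, risk_free: float = 0.0) -> float:
    excess = _simple_returns(historical_data) - risk_free
    # Mean excess return over its sample standard deviation (ddof=1).
    return float(np.mean(excess) / np.std(excess, ddof=1))


def sortino_ratio_sketch(historical_data: pd.DataFrame, target: float = 0.0) -> float:
    excess = _simple_returns(historical_data) - target
    # Downside deviation: only returns below `target` contribute, averaged over all periods.
    downside_dev = np.sqrt(np.mean(np.minimum(excess, 0.0) ** 2))
    return float(np.mean(excess) / downside_dev)


prices = pd.DataFrame({"Close": [100.0, 110.0, 99.0, 108.9]})
sharpe = sharpe_ratio_sketch(prices)    # about 0.2887 (returns +10%, -10%, +10%)
sortino = sortino_ratio_sketch(prices)  # about 0.5774
```

Because the downside deviation ignores the two positive returns, the Sortino value here is twice the Sharpe value. If no return falls below `target`, the downside deviation is zero and NumPy returns inf with a runtime warning, so a real implementation would need to handle that case explicitly.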
- lstm_forecast.backtesting.plot.plot_predictions_with_orders(symbol, filename, historical_data, predictions, future_predictions, data, freq, transactions)
- Return type:
None
- Parameters:
symbol (str)
filename (str)
historical_data (ndarray)
predictions (ndarray)
future_predictions (ndarray)
data (DataFrame)
freq (str)
transactions (DataFrame)