Modules

API

Core

Utils

class lstm_forecast.utils.early_stopping.EarlyStopping(patience=10, delta=0.001, verbose=False, path='checkpoint.pt')

Bases: object

Monitors the validation loss during training and triggers early stopping when the loss fails to improve by at least delta for patience consecutive epochs. Also saves the model state with the best validation loss.

patience

Number of epochs to wait for an improvement in validation loss before stopping.

Type:

int

delta

Minimum change in the monitored validation loss to qualify as an improvement.

Type:

float

verbose

If True, logs information about validation loss improvements and early stopping.

Type:

bool

path

File path to save the model checkpoint with the best validation loss.

Type:

str

counter

Counts the number of consecutive epochs without improvement in validation loss.

Type:

int

best_loss

Best validation loss observed during training.

Type:

float or None

early_stop

Flag to indicate whether early stopping has been triggered.

Type:

bool

best_model_state

State dictionary of the model corresponding to the best validation loss.

Type:

dict or None

logger

Logger instance for logging messages related to early stopping.

Type:

logging.Logger

reset()

Resets the early stopping attributes to their initial states, so the same EarlyStopping instance can be reused for another training session.

Parameters:

None

Return type:

None

Returns:

None
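The class description above implies a simple loop: compare each validation loss with the best one seen so far, count epochs without sufficient improvement, and stop once the count reaches patience. The sketch below re-implements that logic under stated assumptions. The class name EarlyStoppingSketch and its step() method are hypothetical (this page does not document how the real class is called), an improvement is assumed to mean val_loss < best_loss - delta, and the model state is kept in memory instead of being written to path.

```python
import copy
import logging
from typing import Optional


class EarlyStoppingSketch:
    """Hypothetical stand-in mirroring the documented EarlyStopping attributes."""

    def __init__(self, patience: int = 10, delta: float = 0.001, verbose: bool = False):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.logger = logging.getLogger(__name__)
        self.reset()

    def reset(self) -> None:
        # Restore the initial state so the instance can be reused.
        self.counter = 0
        self.best_loss: Optional[float] = None
        self.early_stop = False
        self.best_model_state: Optional[dict] = None

    def step(self, val_loss: float, model_state: dict) -> None:
        # Hypothetical method name; assumed improvement rule: beat best_loss by more than delta.
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.best_model_state = copy.deepcopy(model_state)
            self.counter = 0
            if self.verbose:
                self.logger.info("Validation loss improved to %.6f", val_loss)
        else:
            # No sufficient improvement: count it and stop once patience is reached.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    self.logger.info("Early stopping triggered after %d epochs", self.counter)


stopper = EarlyStoppingSketch(patience=2, delta=0.01)
for epoch, loss in enumerate([1.0, 0.8, 0.795, 0.81, 0.9]):
    stopper.step(loss, {"epoch": epoch})
    if stopper.early_stop:
        break
```

Here 0.795 does not beat 0.8 by more than delta=0.01, so it counts as a non-improving epoch; the loop stops at epoch 3 with the epoch-1 state retained. Calling reset() afterwards clears counter, best_loss, early_stop and best_model_state, matching the documented purpose of reset().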

lstm_forecast.utils.logger.setup_logger(name, log_file='logs/lstm_forecast.log', level=20)
Return type:

Logger
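Only the signature of setup_logger is listed. The default level=20 is the numeric value of logging.INFO. The sketch below shows one common way such a helper is written with the standard logging module; the function name setup_logger_sketch, the file plus console handler pair, and the format string are assumptions, not the package's code.

```python
import logging
import os
import tempfile


def setup_logger_sketch(name: str, log_file: str = "logs/lstm_forecast.log",
                        level: int = logging.INFO) -> logging.Logger:
    # Assumed behaviour: log to both a file and the console.
    os.makedirs(os.path.dirname(log_file) or ".", exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Avoid attaching duplicate handlers when called twice with the same name.
    if not logger.handlers:
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.addHandler(stream_handler)
    return logger


# Write to a temporary directory instead of the default logs/ path.
log_path = os.path.join(tempfile.mkdtemp(), "logs", "demo.log")
log = setup_logger_sketch("demo", log_file=log_path)
log.info("training started")
```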

lstm_forecast.utils.predict_utils.inverse_transform_predictions(predictions, scaler_prices, scaler_volume, features, num_targets)
Return type:

ndarray

Parameters:
  • predictions (ndarray)

  • scaler_prices (StandardScaler)

  • scaler_volume (MinMaxScaler)

  • features (List[str])

  • num_targets (int)
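A scaler fitted on several feature columns can only inverse-transform arrays with that same number of columns, so predictions that cover only num_targets columns are typically zero-padded first. The sketch below illustrates that padding technique for the price scaler only. It assumes the targets are the first num_targets of the fitted columns, uses num_features in place of len(features), and leaves out scaler_volume because the column layout for volume is not documented; the function name is hypothetical.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler


def inverse_transform_prices_sketch(predictions: np.ndarray, scaler_prices: StandardScaler,
                                    num_features: int, num_targets: int) -> np.ndarray:
    # Assumption: scaler_prices was fit on num_features columns and the
    # model predicts the first num_targets of them.
    padded = np.zeros((predictions.shape[0], num_features))
    padded[:, :num_targets] = predictions
    # inverse_transform needs the full column count the scaler was fit on;
    # the padded columns are discarded afterwards.
    return scaler_prices.inverse_transform(padded)[:, :num_targets]


raw = np.array([[100.0, 10.0, 5.0], [110.0, 20.0, 7.0], [120.0, 30.0, 9.0]])
scaler = StandardScaler().fit(raw)
scaled_targets = scaler.transform(raw)[:, :2]
restored = inverse_transform_prices_sketch(scaled_targets, scaler, num_features=3, num_targets=2)
```

Because StandardScaler scales each column independently, the zero padding does not affect the recovered target columns.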

lstm_forecast.utils.predict_utils.create_candles(predictions, freq, start_date)
Return type:

DataFrame

Parameters:
  • predictions (ndarray)

  • freq (str)

  • start_date (Timestamp)
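Given the parameters (an array of predictions, a pandas frequency string, and a start Timestamp) and the DataFrame return type, a plausible shape for create_candles is a time-indexed OHLC frame. The sketch assumes the first four prediction columns are Open, High, Low and Close and that the first candle falls exactly on start_date; both the column mapping and the function name are hypothetical.

```python
import numpy as np
import pandas as pd


def create_candles_sketch(predictions: np.ndarray, freq: str, start_date: pd.Timestamp) -> pd.DataFrame:
    # Assumption: columns 0-3 of predictions hold Open, High, Low, Close.
    index = pd.date_range(start=start_date, periods=len(predictions), freq=freq)
    return pd.DataFrame(predictions[:, :4], index=index, columns=["Open", "High", "Low", "Close"])


preds = np.array([[1.0, 1.2, 0.9, 1.1], [1.1, 1.3, 1.0, 1.2]])
candles = create_candles_sketch(preds, "D", pd.Timestamp("2024-01-01"))
```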

lstm_forecast.utils.predict_utils.plot_predictions(symbol, filename, candles, predictions, future_predictions, freq, interval, logger)
Return type:

None

Parameters:
  • symbol (str)

  • filename (str)

  • candles (DataFrame)

  • predictions (ndarray)

  • future_predictions (ndarray)

  • freq (str)

  • interval (str)

  • logger

lstm_forecast.utils.predict_utils.add_candlestick_trace(fig, candles, name, increasing_color='green', decreasing_color='red')
Return type:

None

Parameters:
  • fig (Figure)

  • candles (DataFrame)

  • name (str)

  • increasing_color

  • decreasing_color

lstm_forecast.utils.predict_utils.update_layout(fig, symbol, interval)
Return type:

None

Parameters:
  • fig (Figure)

  • symbol (str)

  • interval (str)

lstm_forecast.utils.predict_utils.save_predictions_report(predictions, targets, start_date, freq, symbol)
Return type:

None

Parameters:
  • predictions (ndarray)

  • targets (List[str])

  • start_date (Timestamp)

  • freq (str)

  • symbol (str)
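The parameters suggest a report that labels each prediction column with its target name and each row with a date generated from start_date and freq. The sketch writes such a table to CSV. The output_dir parameter, the file-name pattern, and the CSV format are hypothetical, and unlike the documented function (which returns None) the sketch returns the written path for convenience.

```python
import os
import tempfile
from typing import List

import numpy as np
import pandas as pd


def save_predictions_report_sketch(predictions: np.ndarray, targets: List[str],
                                   start_date: pd.Timestamp, freq: str, symbol: str,
                                   output_dir: str = "reports") -> str:
    # output_dir and the file name are hypothetical; the documented function takes no path.
    os.makedirs(output_dir, exist_ok=True)
    index = pd.date_range(start=start_date, periods=len(predictions), freq=freq, name="Date")
    report = pd.DataFrame(predictions, index=index, columns=targets)
    path = os.path.join(output_dir, f"{symbol}_predictions.csv")
    report.to_csv(path)
    return path


out = save_predictions_report_sketch(
    np.array([[1.0, 2.0], [3.0, 4.0]]), ["Close", "Volume"],
    pd.Timestamp("2024-01-01"), "D", "XYZ", output_dir=tempfile.mkdtemp(),
)
```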

Backtesting

lstm_forecast.backtesting.metrics.calculate_sortino_ratio(historical_data)
Return type:

float

Parameters:

  • historical_data (DataFrame)
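The Sortino ratio divides the mean return by the downside deviation, which penalizes only negative returns. The sketch assumes simple per-period returns computed from a price column (the name "Close" is an assumption), a zero risk-free rate and target return, and no annualization; how calculate_sortino_ratio reads historical_data is not documented.

```python
import numpy as np
import pandas as pd


def calculate_sortino_ratio_sketch(historical_data: pd.DataFrame, column: str = "Close") -> float:
    # Assumptions: zero risk-free rate, zero target return, no annualization.
    returns = historical_data[column].pct_change().dropna()
    # Downside deviation: root mean square of the negative part of returns.
    downside = np.minimum(returns, 0.0)
    downside_deviation = np.sqrt(np.mean(downside ** 2))
    if downside_deviation == 0:
        return float("inf")  # no losing periods
    return float(returns.mean() / downside_deviation)


prices = pd.DataFrame({"Close": [100.0, 110.0, 99.0, 108.9]})
sortino = calculate_sortino_ratio_sketch(prices)
```

The example prices give returns of roughly +10%, -10%, +10%, so the mean is 0.1/3 and the downside deviation is 0.1/sqrt(3), for a ratio of 1/sqrt(3).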

lstm_forecast.backtesting.metrics.calculate_sharpe_ratio(historical_data)
Return type:

float

Parameters:

  • historical_data (DataFrame)
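The Sharpe ratio divides the mean return by the standard deviation of all returns. As with the Sortino sketch, this assumes a "Close" price column, a zero risk-free rate, and per-period (non-annualized) values; returning 0.0 for zero volatility is an arbitrary choice for the sketch.

```python
import numpy as np
import pandas as pd


def calculate_sharpe_ratio_sketch(historical_data: pd.DataFrame, column: str = "Close") -> float:
    # Assumptions: zero risk-free rate, no annualization.
    returns = historical_data[column].pct_change().dropna()
    std = returns.std()  # pandas default is the sample std (ddof=1)
    if std == 0:
        return 0.0
    return float(returns.mean() / std)


prices = pd.DataFrame({"Close": [100.0, 110.0, 99.0, 108.9]})
sharpe = calculate_sharpe_ratio_sketch(prices)
```

A common convention annualizes a per-period ratio by multiplying it by the square root of the number of periods per year (for example sqrt(252) for daily trading data); this page does not say whether calculate_sharpe_ratio does so.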

lstm_forecast.backtesting.plot.plot_predictions_with_orders(symbol, filename, historical_data, predictions, future_predictions, data, freq, transactions)
Return type:

None

Parameters:
  • symbol (str)

  • filename (str)

  • historical_data (ndarray)

  • predictions (ndarray)

  • future_predictions (ndarray)

  • data (DataFrame)

  • freq (str)

  • transactions (DataFrame)

Other

class lstm_forecast.config.Config(config_path)

Bases: object

load_config()
save()
update(key, value)
get(key, default=None)
lstm_forecast.config.load_config(config_path)
lstm_forecast.config.update_config(config, key, value)
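The Config methods are listed without descriptions. The sketch below shows one way a file-backed configuration object with load_config(), save(), update(key, value) and get(key, default=None) could behave. It assumes a JSON file (the real file format is not documented) and assumes update() changes only the in-memory values until save() is called; the class name ConfigSketch is hypothetical.

```python
import json
import os
import tempfile
from typing import Any


class ConfigSketch:
    """Hypothetical JSON-backed stand-in for lstm_forecast.config.Config."""

    def __init__(self, config_path: str):
        self.config_path = config_path
        self.config = self.load_config()

    def load_config(self) -> dict:
        with open(self.config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def save(self) -> None:
        with open(self.config_path, "w", encoding="utf-8") as f:
            json.dump(self.config, f, indent=2)

    def update(self, key: str, value: Any) -> None:
        # Assumed: in-memory change only; call save() to persist it.
        self.config[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        return self.config.get(key, default)


path = os.path.join(tempfile.mkdtemp(), "config.json")
with open(path, "w", encoding="utf-8") as f:
    json.dump({"epochs": 10}, f)

cfg = ConfigSketch(path)
cfg.update("epochs", 20)
cfg.save()
```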