Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dev = [
"pytest-mock>=3.14.0,<4",
"pytest-asyncio>=0.23.0,<1.3",
"pytest-watch>=4.2.0,<5",
"ruff>=0.9.1,<0.15.0",
"ruff>=0.9.1",
"pyright>=1.1.403",
]

Expand Down
88 changes: 78 additions & 10 deletions tests/test_ibkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from ib_async import IB, Contract, Order, OrderStatus, Stock, Ticker, Trade

from thetagang import log
from thetagang.ibkr import IBKR, RequiredFieldValidationError, TickerField
from thetagang.ibkr import (
IBKR,
IBKRRequestTimeout,
RequiredFieldValidationError,
TickerField,
)

# Mark all tests in this module as asyncio
pytestmark = pytest.mark.asyncio
Expand Down Expand Up @@ -51,9 +56,6 @@ def mock_trade(mocker):
return trade


# --- Tests for get_ticker_for_contract ---


async def test_get_ticker_for_contract_success(ibkr, mock_ib, mock_ticker, mocker):
"""Test get_ticker_for_contract when all waits succeed."""
mocker.patch.object(
Expand Down Expand Up @@ -161,9 +163,6 @@ def mock_handler_logic(field):
assert "MIDPOINT" in mock_log_warning.call_args[0][0]


# --- Tests for wait_for_submitting_orders ---


async def test_wait_for_submitting_orders_success(ibkr, mock_trade, mocker):
"""Test wait_for_submitting_orders when all waits succeed."""
mocker.patch.object(
Expand Down Expand Up @@ -215,9 +214,6 @@ async def mock_wait(*args, **kwargs):
assert ibkr.__trade_wait_for_condition__.call_count == 2


# --- Tests for wait_for_orders_complete ---


async def test_wait_for_orders_complete_success(ibkr, mock_trade, mocker):
"""Test wait_for_orders_complete when all waits succeed."""
mocker.patch.object(
Expand Down Expand Up @@ -265,3 +261,75 @@ async def mock_wait(*args, **kwargs):
assert "Timeout waiting for orders to complete" in mock_log_warning.call_args[0][0]
assert "FAIL (OrderId: 2)" in mock_log_warning.call_args[0][0]
assert "PASS (OrderId: 1)" not in mock_log_warning.call_args[0][0]


async def test_refresh_account_updates_uses_timeout_wrapper(ibkr, mocker):
"""refresh_account_updates delegates to _await_with_timeout."""
req_future: asyncio.Future = asyncio.get_running_loop().create_future()
req_future.set_result(None)
ibkr.ib.reqAccountUpdatesAsync = mocker.Mock(return_value=req_future)
await_wrapper = mocker.patch.object(
ibkr, "_await_with_timeout", new=mocker.AsyncMock(return_value=None)
)

await ibkr.refresh_account_updates("ACC123")

ibkr.ib.reqAccountUpdatesAsync.assert_called_once_with("ACC123")
assert await_wrapper.await_count == 1
await_args = await_wrapper.await_args
assert await_args.args[0] is req_future
assert await_args.args[1] == "account updates"


async def test_refresh_positions_uses_timeout_wrapper(ibkr, mocker):
"""refresh_positions delegates to _await_with_timeout."""
req_future: asyncio.Future = asyncio.get_running_loop().create_future()
req_future.set_result([])
ibkr.ib.reqPositionsAsync = mocker.Mock(return_value=req_future)
await_wrapper = mocker.patch.object(
ibkr, "_await_with_timeout", new=mocker.AsyncMock(return_value=[])
)

result = await ibkr.refresh_positions()

assert result == []
ibkr.ib.reqPositionsAsync.assert_called_once_with()
assert await_wrapper.await_count == 1
await_args = await_wrapper.await_args
assert await_args.args[0] is req_future
assert await_args.args[1] == "positions snapshot"


async def test_refresh_account_updates_propagates_timeout(ibkr, mocker):
"""refresh_account_updates re-raises IBKRRequestTimeout."""
ibkr.ib.reqAccountUpdatesAsync = mocker.Mock(return_value=object())
mocker.patch.object(
ibkr,
"_await_with_timeout",
new=mocker.AsyncMock(
side_effect=IBKRRequestTimeout(
"account updates", ibkr.api_response_wait_time
)
),
)

with pytest.raises(IBKRRequestTimeout):
await ibkr.refresh_account_updates("ACC123")


async def test_await_with_timeout_wraps_timeout_error(ibkr, mocker):
"""_await_with_timeout raises IBKRRequestTimeout on asyncio timeout."""

async def dummy() -> None:
return None

async def fake_wait_for(awaitable, timeout):
await awaitable
raise asyncio.TimeoutError()

mocker.patch("thetagang.ibkr.asyncio.wait_for", new=fake_wait_for)

with pytest.raises(IBKRRequestTimeout) as excinfo:
await ibkr._await_with_timeout(dummy(), "positions snapshot")

assert "positions snapshot" in str(excinfo.value)
94 changes: 94 additions & 0 deletions tests/test_portfolio_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from types import SimpleNamespace

import pytest
from ib_async import IB, Stock, Ticker

from thetagang.ibkr import IBKRRequestTimeout
from thetagang.portfolio_manager import PortfolioManager


Expand All @@ -23,6 +26,8 @@ def mock_config(mocker):
config.ib_async.api_response_wait_time = 1
config.orders = mocker.Mock()
config.orders.exchange = "SMART"
config.cash_management = mocker.Mock()
config.cash_management.cash_fund = "MMDA1"
return config


Expand Down Expand Up @@ -238,6 +243,95 @@ async def mock_track_async(tasks, description):
for symbol, _, _, _ in to_write:
assert symbol != "AAPL"

@pytest.mark.asyncio
async def test_get_portfolio_positions_success(self, portfolio_manager, mocker):
"""Returns filtered positions when both portfolio and snapshot succeed."""
portfolio_manager.config.symbols = {"AAPL": mocker.Mock()}

portfolio_item = SimpleNamespace(
account="TEST123",
contract=SimpleNamespace(symbol="AAPL", conId=1),
position=5,
averageCost=100.0,
marketPrice=105.0,
marketValue=525.0,
unrealizedPNL=25.0,
)
snapshot_position = SimpleNamespace(
account="TEST123",
contract=SimpleNamespace(symbol="AAPL", conId=1),
position=5,
)

portfolio_manager.ibkr.refresh_account_updates = mocker.AsyncMock()
portfolio_manager.ibkr.portfolio = mocker.Mock(return_value=[portfolio_item])
portfolio_manager.ibkr.refresh_positions = mocker.AsyncMock(
return_value=[snapshot_position]
)

result = await portfolio_manager.get_portfolio_positions()

assert result == {"AAPL": [portfolio_item]}
portfolio_manager.ibkr.refresh_account_updates.assert_awaited_once_with(
"TEST123"
)
portfolio_manager.ibkr.refresh_positions.assert_awaited_once()

@pytest.mark.asyncio
async def test_get_portfolio_positions_retries_on_account_timeout(
self, portfolio_manager, mocker
):
"""Retries when account update snapshot times out, then returns data."""
portfolio_manager.config.symbols = {}

sleep_mock = mocker.patch(
"thetagang.portfolio_manager.asyncio.sleep", new=mocker.AsyncMock()
)

portfolio_manager.ibkr.refresh_account_updates = mocker.AsyncMock(
side_effect=[
IBKRRequestTimeout("account updates", 1),
None,
]
)
portfolio_manager.ibkr.portfolio = mocker.Mock(return_value=[])
portfolio_manager.ibkr.refresh_positions = mocker.AsyncMock(return_value=[])

result = await portfolio_manager.get_portfolio_positions()

assert result == {}
assert portfolio_manager.ibkr.refresh_account_updates.await_count == 2
sleep_mock.assert_awaited()

@pytest.mark.asyncio
async def test_get_portfolio_positions_raises_after_missing_positions(
self, portfolio_manager, mocker
):
"""Raises when portfolio snapshot never includes tracked positions."""
portfolio_manager.config.symbols = {"AAPL": mocker.Mock()}

sleep_mock = mocker.patch(
"thetagang.portfolio_manager.asyncio.sleep", new=mocker.AsyncMock()
)

portfolio_manager.ibkr.refresh_account_updates = mocker.AsyncMock()
portfolio_manager.ibkr.portfolio = mocker.Mock(return_value=[])

tracked_position = SimpleNamespace(
account="TEST123",
contract=SimpleNamespace(symbol="AAPL", conId=1),
position=5,
)
portfolio_manager.ibkr.refresh_positions = mocker.AsyncMock(
return_value=[tracked_position]
)

with pytest.raises(RuntimeError):
await portfolio_manager.get_portfolio_positions()

assert portfolio_manager.ibkr.refresh_positions.await_count == 3
sleep_mock.assert_awaited()

@pytest.mark.asyncio
async def test_check_buy_only_positions(self, portfolio_manager, mocker):
"""Test check_buy_only_positions method."""
Expand Down
40 changes: 40 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,46 @@ def test_position_pnl() -> None:
)
assert round(position_pnl(spy_put), 2) == -0.13

zero_avg_cost = PortfolioItem(
contract=Stock(
conId=999001,
symbol="IWM",
right="0",
primaryExchange="ARCA",
currency="USD",
localSymbol="IWM",
tradingClass="IWM",
),
position=50.0,
marketPrice=200.0,
marketValue=10000.0,
averageCost=0.0,
unrealizedPNL=500.0,
realizedPNL=0.0,
account="DU2962946",
)
assert position_pnl(zero_avg_cost) == 0.0

flat_position = PortfolioItem(
contract=Stock(
conId=999002,
symbol="QQQ",
right="0",
primaryExchange="NASDAQ",
currency="USD",
localSymbol="QQQ",
tradingClass="QQQ",
),
position=0.0,
marketPrice=300.0,
marketValue=0.0,
averageCost=150.0,
unrealizedPNL=0.0,
realizedPNL=0.0,
account="DU2962946",
)
assert position_pnl(flat_position) == 0.0


def test_get_delta() -> None:
target_config = TargetConfigFactory.build(delta=0.5, puts=None, calls=None)
Expand Down
36 changes: 35 additions & 1 deletion thetagang/ibkr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from enum import Enum
from typing import Any, Awaitable, Callable, Coroutine, List, Optional
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, TypeVar

from ib_async import (
IB,
Expand All @@ -10,6 +10,7 @@
OptionChain,
Order,
PortfolioItem,
Position,
Stock,
Ticker,
Trade,
Expand All @@ -35,6 +36,18 @@ def __init__(self, message: str) -> None:
super().__init__(self.message)


class IBKRRequestTimeout(RuntimeError):
"""Raised when an IBKR request does not complete within the configured timeout."""

def __init__(self, description: str, timeout_seconds: int) -> None:
super().__init__(
f"Timed out waiting for {description} after {timeout_seconds} seconds"
)


T = TypeVar("T")


class IBKR:
def __init__(
self, ib: IB, api_response_wait_time: int, default_order_exchange: str
Expand Down Expand Up @@ -79,6 +92,19 @@ def place_order(self, contract: Contract, order: Order) -> Trade:
def cancel_order(self, order: Order) -> None:
self.ib.cancelOrder(order)

async def refresh_account_updates(self, account: str) -> None:
await self._await_with_timeout(
self.ib.reqAccountUpdatesAsync(account), "account updates"
)

async def refresh_positions(self) -> List[Position]:
return await self._await_with_timeout(
self.ib.reqPositionsAsync(), "positions snapshot"
)

def positions(self, account: str) -> List[Position]:
return self.ib.positions(account)

async def get_chains_for_contract(self, contract: Contract) -> List[OptionChain]:
return await self.ib.reqSecDefOptParamsAsync(
contract.symbol, "", contract.secType, contract.conId
Expand Down Expand Up @@ -246,6 +272,14 @@ def orderStatusEvent(self, trade: Trade) -> None:
f"{trade.contract.symbol}: Order updated with status={trade.orderStatus.status}"
)

async def _await_with_timeout(self, awaitable: Awaitable[T], description: str) -> T:
try:
return await asyncio.wait_for(
awaitable, timeout=self.api_response_wait_time
)
except asyncio.TimeoutError as exc:
raise IBKRRequestTimeout(description, self.api_response_wait_time) from exc

async def __market_data_streaming_handler__(
self,
contract: Contract,
Expand Down
Loading