Skip to content

Commit

Permalink
Add news api (#555)
Browse files Browse the repository at this point in the history
* added news stream

* added news historical

* fixed flake8 format

* Added rest api test for news

* news api default limit and page limit, other pr comments

* pr comments

* pr comments
  • Loading branch information
ccnlui authored Jan 26, 2022
1 parent f36d6c4 commit 6132575
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 9 deletions.
10 changes: 10 additions & 0 deletions alpaca_trade_api/entity_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ def __init__(self, raw):
self[k] = _convert_or_none(QuoteV2, v)


class NewsV2(Entity):
def __init__(self, raw):
super().__init__(raw)


class NewsListV2(list):
def __init__(self, raw):
super().__init__([NewsV2(o) for o in raw])


def _convert_or_none(entityType, value):
if value:
return entityType(value)
Expand Down
77 changes: 70 additions & 7 deletions alpaca_trade_api/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
)
from .entity_v2 import (
BarV2, BarsV2, LatestBarsV2, LatestQuotesV2, LatestTradesV2,
SnapshotV2, SnapshotsV2, TradesV2, TradeV2, QuotesV2, QuoteV2)
SnapshotV2, SnapshotsV2, TradesV2, TradeV2, QuotesV2, QuoteV2,
NewsV2, NewsListV2
)

logger = logging.getLogger(__name__)
Positions = List[Position]
Expand All @@ -32,8 +34,10 @@
TradeIterator = Iterator[Union[Trade, dict]]
QuoteIterator = Iterator[Union[Quote, dict]]
BarIterator = Iterator[Union[Bar, dict]]
NewsIterator = Iterator[Union[NewsV2, dict]]

DATA_V2_MAX_LIMIT = 10000 # max items per api call
NEWS_MAX_LIMIT = 50 # max items per api call


class RetryException(Exception):
Expand Down Expand Up @@ -129,6 +133,14 @@ def validate(amount: int, unit: TimeFrameUnit):
TimeFrame.Day = TimeFrame(1, TimeFrameUnit.Day)


class Sort(Enum):
Asc = "asc"
Desc = "desc"

def __str__(self):
return self.value


class REST(object):
def __init__(self,
key_id: str = None,
Expand Down Expand Up @@ -609,27 +621,34 @@ def _data_get(self,
symbol_or_symbols: Union[str, List[str]],
api_version: str = 'v2',
endpoint_base: str = 'stocks',
resp_grouped_by_symbol: Optional[bool] = None,
page_limit: int = DATA_V2_MAX_LIMIT,
**kwargs):
page_token = None
total_items = 0
limit = kwargs.get('limit')
if resp_grouped_by_symbol is None:
resp_grouped_by_symbol = not isinstance(symbol_or_symbols, str)
while True:
actual_limit = None
if limit:
actual_limit = min(int(limit) - total_items, DATA_V2_MAX_LIMIT)
actual_limit = min(int(limit) - total_items, page_limit)
if actual_limit < 1:
break
data = kwargs
data['limit'] = actual_limit
data['page_token'] = page_token
if isinstance(symbol_or_symbols, str):
path = f'/{endpoint_base}/{symbol_or_symbols}/{endpoint}'
path = f'/{endpoint_base}'
if isinstance(symbol_or_symbols, str) and symbol_or_symbols:
path += f'/{symbol_or_symbols}'
else:
path = f'/{endpoint_base}/{endpoint}'
data['symbols'] = ','.join(symbol_or_symbols)
if endpoint:
path += f'/{endpoint}'
resp = self.data_get(path, data=data, api_version=api_version)
if isinstance(symbol_or_symbols, str):
for item in resp.get(endpoint, []) or []:
if not resp_grouped_by_symbol:
k = endpoint or endpoint_base
for item in resp.get(k, []) or []:
yield item
total_items += 1
else:
Expand Down Expand Up @@ -893,6 +912,50 @@ def get_crypto_snapshot(self, symbol: str, exchange: str) -> SnapshotV2:
api_version='v1beta1')
return self.response_wrapper(resp, SnapshotV2)

def get_news_iter(self,
symbol: Optional[Union[str, List[str]]] = None,
start: Optional[str] = None,
end: Optional[str] = None,
limit: int = 10,
sort: Sort = Sort.Desc,
include_content: bool = False,
exclude_contentless: bool = False,
raw=False) -> NewsIterator:
symbol = symbol or []
# Avoid passing symbol as path param
if isinstance(symbol, str):
symbol = [symbol]
news = self._data_get('', symbol,
api_version='v1beta1', endpoint_base='news',
start=start, end=end, limit=limit, sort=sort,
include_content=include_content,
exclude_contentless=exclude_contentless,
resp_grouped_by_symbol=False,
page_limit=NEWS_MAX_LIMIT)
for n in news:
if raw:
yield n
else:
yield self.response_wrapper(n, NewsV2)

def get_news(self,
symbol: Optional[Union[str, List[str]]] = None,
start: Optional[str] = None,
end: Optional[str] = None,
limit: int = 10,
sort: Sort = Sort.Desc,
include_content: bool = False,
exclude_contentless: bool = False,

) -> NewsListV2:
news = list(self.get_news_iter(symbol=symbol,
start=start, end=end,
limit=limit, sort=sort,
include_content=include_content,
exclude_contentless=exclude_contentless,
raw=True))
return NewsListV2(news)

def get_clock(self) -> Clock:
resp = self.get('/clock')
return self.response_wrapper(resp, Clock)
Expand Down
84 changes: 82 additions & 2 deletions alpaca_trade_api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
LULDV2,
CancelErrorV2,
CorrectionV2,
NewsV2,
)

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -432,6 +433,63 @@ def __init__(self,
self._name = 'crypto data'


class NewsDataStream(_DataStream):
def __init__(self,
key_id: str,
secret_key: str,
base_url: URL,
raw_data: bool):
self._key_id = key_id
self._secret_key = secret_key
base_url = re.sub(r'^http', 'ws', base_url)
endpoint = base_url + '/v1beta1/news'
super().__init__(endpoint=endpoint,
key_id=key_id,
secret_key=secret_key,
raw_data=raw_data,
)
self._handlers = {
'news': {},
}
self._name = 'news data'

def _cast(self, msg_type, msg):
result = super()._cast(msg_type, msg)
if not self._raw_data:
if msg_type == 'n':
result = NewsV2(msg)
return result

async def _dispatch(self, msg):
msg_type = msg.get('T')
symbol = msg.get('S')
if msg_type == 'n':
handler = self._handlers['news'].get(
symbol, self._handlers['news'].get('*', None))
if handler:
await handler(self._cast(msg_type, msg))
else:
await super()._dispatch(msg)

async def _unsubscribe(self, news=()):
if news:
await self._ws.send(
msgpack.packb({
'action': 'unsubscribe',
'news': news,
}))

def subscribe_news(self, handler, *symbols):
self._subscribe(handler, symbols, self._handlers['news'])

def unsubscribe_news(self, *symbols):
if self._running:
asyncio.get_event_loop().run_until_complete(
self._unsubscribe(news=symbols))
for symbol in symbols:
del self._handlers['news'][symbol]


class TradingStream:
def __init__(self,
key_id: str,
Expand Down Expand Up @@ -588,6 +646,10 @@ def __init__(self,
self._data_steam_url,
raw_data,
crypto_exchanges)
self._news_ws = NewsDataStream(self._key_id,
self._secret_key,
self._data_steam_url,
raw_data)

def subscribe_trade_updates(self, handler):
self._trading_ws.subscribe_trade_updates(handler)
Expand Down Expand Up @@ -634,6 +696,9 @@ def subscribe_crypto_bars(self, handler, *symbols):
def subscribe_crypto_daily_bars(self, handler, *symbols):
self._crypto_ws.subscribe_daily_bars(handler, *symbols)

def subscribe_news(self, handler, *symbols):
self._news_ws.subscribe_news(handler, *symbols)

def on_trade_update(self, func):
self.subscribe_trade_updates(func)
return func
Expand Down Expand Up @@ -722,6 +787,13 @@ def decorator(func):

return decorator

def on_news(self, *symbols):
def decorator(func):
self.subscribe_news(func, *symbols)
return func

return decorator

def unsubscribe_trades(self, *symbols):
self._data_ws.unsubscribe_trades(*symbols)
self._data_ws.unregister_handler("cancelErrors", *symbols)
Expand Down Expand Up @@ -754,10 +826,14 @@ def unsubscribe_crypto_bars(self, *symbols):
def unsubscribe_crypto_daily_bars(self, *symbols):
self._crypto_ws.unsubscribe_daily_bars(*symbols)

def unsubscribe_news(self, *symbols):
self._news_ws.unsubscribe_news(*symbols)

async def _run_forever(self):
await asyncio.gather(self._trading_ws._run_forever(),
self._data_ws._run_forever(),
self._crypto_ws._run_forever())
self._crypto_ws._run_forever(),
self._news_ws._run_forever())

def run(self):
loop = asyncio.get_event_loop()
Expand All @@ -780,12 +856,16 @@ async def stop_ws(self):
if self._crypto_ws:
await self._crypto_ws.stop_ws()

if self._news_ws:
await self._news_ws.stop_ws()

def is_open(self):
"""
Checks if either of the websockets is open
:return:
"""
open_ws = self._trading_ws._ws or self._data_ws._ws or self._crypto_ws._ws # noqa
open_ws = (self._trading_ws._ws or self._data_ws._ws
or self._crypto_ws._ws or self._news_ws) # noqa
if open_ws:
return True
return False
71 changes: 71 additions & 0 deletions tests/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,77 @@ def test_data(reqmock):
assert msft_snapshot.prev_daily_bar is None
assert snapshots.get('INVALID') is None

# News
reqmock.get(
'https://data.alpaca.markets/v1beta1/news' +
'?symbols=AAPL,TSLA&limit=2',
text='''
{
"news": [
{
"id": 24994117,
"headline": "'Tesla Approved...",
"author": "Benzinga Newsdesk",
"created_at": "2022-01-11T13:50:47Z",
"updated_at": "2022-01-11T13:50:47Z",
"summary": "",
"url": "https://www.benzinga.com/news/some/path",
"images": [],
"symbols": [
"TSLA"
],
"source": "benzinga"
},
{
"id": 24993189,
"headline": "Dogecoin Is Down 80% ...",
"author": "Samyuktha Sriram",
"created_at": "2022-01-11T13:49:40Z",
"updated_at": "2022-01-11T13:49:41Z",
"summary": "Popular meme-based cryptocurrency...",
"url": "https://www.benzinga.com/markets/some/path",
"images": [
{
"size": "large",
"url": "https://cdn.benzinga.com/files/some.jpeg"
},
{
"size": "small",
"url": "https://cdn.benzinga.com/files/some.jpeg"
},
{
"size": "thumb",
"url": "https://cdn.benzinga.com/files/some.jpeg"
}
],
"symbols": [
"BTCUSD",
"DOGEUSD",
"SHIBUSD",
"TSLA"
],
"source": "benzinga"
}
]
}
'''
)
news = api.get_news(['AAPL', 'TSLA'], limit=2)
assert len(news) == 2
first = news[0]
assert first is not None
assert first.author == 'Benzinga Newsdesk'
assert 'TSLA' in first.symbols
assert first.source == 'benzinga'
assert type(first) == tradeapi.entity_v2.NewsV2
second = news[1]
assert second is not None
assert second.headline != ''
assert type(second.images) == list
assert 'TSLA' in second.symbols
assert second.source == 'benzinga'
assert type(second) == tradeapi.entity_v2.NewsV2


def test_timeframe(reqmock):
# Custom timeframe: Minutes
Expand Down

0 comments on commit 6132575

Please sign in to comment.