Skip to content

Commit

Permalink
feat(target_chains/starknet): add query_price_feed methods (#1627)
Browse files Browse the repository at this point in the history
  • Loading branch information
Riateche authored May 29, 2024
1 parent 8fc116c commit a666702
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
35 changes: 35 additions & 0 deletions target_chains/starknet/contracts/src/pyth.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,41 @@ mod pyth {
Result::Ok(price)
}

fn query_price_feed_no_older_than(
self: @ContractState, price_id: u256, age: u64
) -> Result<PriceFeed, GetPriceNoOlderThanError> {
let feed = self.query_price_feed_unsafe(price_id).map_err_into()?;
if !is_no_older_than(feed.price.publish_time, age) {
return Result::Err(GetPriceNoOlderThanError::StalePrice);
}
Result::Ok(feed)
}

fn query_price_feed_unsafe(
self: @ContractState, price_id: u256
) -> Result<PriceFeed, GetPriceUnsafeError> {
let info = self.latest_price_info.read(price_id);
if info.publish_time == 0 {
return Result::Err(GetPriceUnsafeError::PriceFeedNotFound);
}
let feed = PriceFeed {
id: price_id,
price: Price {
price: info.price,
conf: info.conf,
expo: info.expo,
publish_time: info.publish_time,
},
ema_price: Price {
price: info.ema_price,
conf: info.ema_conf,
expo: info.expo,
publish_time: info.publish_time,
},
};
Result::Ok(feed)
}

fn update_price_feeds(ref self: ContractState, data: ByteArray) {
self.update_price_feeds_internal(data, array![], 0, 0, false);
}
Expand Down
4 changes: 4 additions & 0 deletions target_chains/starknet/contracts/src/pyth/interface.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ pub trait IPyth<T> {
self: @T, price_id: u256, age: u64
) -> Result<Price, GetPriceNoOlderThanError>;
fn get_ema_price_unsafe(self: @T, price_id: u256) -> Result<Price, GetPriceUnsafeError>;
fn query_price_feed_no_older_than(
self: @T, price_id: u256, age: u64
) -> Result<PriceFeed, GetPriceNoOlderThanError>;
fn query_price_feed_unsafe(self: @T, price_id: u256) -> Result<PriceFeed, GetPriceUnsafeError>;
fn update_price_feeds(ref self: T, data: ByteArray);
fn update_price_feeds_if_necessary(
ref self: T, update: ByteArray, required_publish_times: Array<PriceFeedPublishTime>
Expand Down
31 changes: 30 additions & 1 deletion target_chains/starknet/contracts/tests/pyth.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use snforge_std::{
use pyth::pyth::{
IPythDispatcher, IPythDispatcherTrait, DataSource, Event as PythEvent, PriceFeedUpdated,
WormholeAddressSet, GovernanceDataSourceSet, ContractUpgraded, DataSourcesSet, FeeSet,
PriceFeedPublishTime, GetPriceNoOlderThanError, Price, PriceFeed,
PriceFeedPublishTime, GetPriceNoOlderThanError, Price, PriceFeed, GetPriceUnsafeError,
};
use pyth::byte_array::{ByteArray, ByteArrayImpl};
use pyth::util::{array_try_into, UnwrapWithFelt252};
Expand Down Expand Up @@ -131,6 +131,19 @@ fn update_price_feeds_works() {
assert!(last_ema_price.conf == 4096812700);
assert!(last_ema_price.expo == -8);
assert!(last_ema_price.publish_time == 1712589206);

let feed = pyth
.query_price_feed_unsafe(0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43)
.unwrap_with_felt252();
assert!(feed.id == 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43);
assert!(feed.price.price == 7192002930010);
assert!(feed.price.conf == 3596501465);
assert!(feed.price.expo == -8);
assert!(feed.price.publish_time == 1712589206);
assert!(feed.ema_price.price == 7181868900000);
assert!(feed.ema_price.conf == 4096812700);
assert!(feed.ema_price.expo == -8);
assert!(feed.ema_price.publish_time == 1712589206);
}

#[test]
Expand Down Expand Up @@ -407,10 +420,18 @@ fn test_get_no_older_works() {
let fee_contract = deploy_fee_contract(user);
let pyth = deploy_default(wormhole.contract_address, fee_contract.contract_address);
let price_id = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43;
let err = pyth.get_price_unsafe(price_id).unwrap_err();
assert!(err == GetPriceUnsafeError::PriceFeedNotFound);
let err = pyth.get_ema_price_unsafe(price_id).unwrap_err();
assert!(err == GetPriceUnsafeError::PriceFeedNotFound);
let err = pyth.query_price_feed_unsafe(price_id).unwrap_err();
assert!(err == GetPriceUnsafeError::PriceFeedNotFound);
let err = pyth.get_price_no_older_than(price_id, 100).unwrap_err();
assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);
let err = pyth.get_ema_price_no_older_than(price_id, 100).unwrap_err();
assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);
let err = pyth.query_price_feed_no_older_than(price_id, 100).unwrap_err();
assert!(err == GetPriceNoOlderThanError::PriceFeedNotFound);

start_prank(CheatTarget::One(fee_contract.contract_address), user.try_into().unwrap());
fee_contract.approve(pyth.contract_address, 10000);
Expand All @@ -425,6 +446,8 @@ fn test_get_no_older_works() {
assert!(err == GetPriceNoOlderThanError::StalePrice);
let err = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_err();
assert!(err == GetPriceNoOlderThanError::StalePrice);
let err = pyth.query_price_feed_no_older_than(price_id, 3).unwrap_err();
assert!(err == GetPriceNoOlderThanError::StalePrice);

start_warp(CheatTarget::One(pyth.contract_address), 1712589208);
let val = pyth.get_price_no_older_than(price_id, 3).unwrap_with_felt252();
Expand All @@ -433,6 +456,9 @@ fn test_get_no_older_works() {
let val = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_with_felt252();
assert!(val.publish_time == 1712589206);
assert!(val.price == 7181868900000);
let val = pyth.query_price_feed_no_older_than(price_id, 3).unwrap_with_felt252();
assert!(val.price.publish_time == 1712589206);
assert!(val.price.price == 7192002930010);

start_warp(CheatTarget::One(pyth.contract_address), 1712589204);
let val = pyth.get_price_no_older_than(price_id, 3).unwrap_with_felt252();
Expand All @@ -441,6 +467,9 @@ fn test_get_no_older_works() {
let val = pyth.get_ema_price_no_older_than(price_id, 3).unwrap_with_felt252();
assert!(val.publish_time == 1712589206);
assert!(val.price == 7181868900000);
let val = pyth.query_price_feed_no_older_than(price_id, 3).unwrap_with_felt252();
assert!(val.price.publish_time == 1712589206);
assert!(val.price.price == 7192002930010);

stop_warp(CheatTarget::One(pyth.contract_address));
}
Expand Down

0 comments on commit a666702

Please sign in to comment.