From 9416c147ea347cecc97118f4aa71fda0c7a37152 Mon Sep 17 00:00:00 2001 From: Jeb Bearer Date: Wed, 1 May 2024 14:36:29 -0400 Subject: [PATCH] Support multi-segment and empty API prefixes It is sometimes useful to have an API where different modules internally look like parts of the same module to clients. For example, in the sequencer we have `state/fee/:height/:account` and `state/blocks/:height/index`, which look like two different endpoints in the `state` module, but are actually two separate modules, `state/fee` and `state/blocks`. For this, we need multi-segment API prefixes. Separately, it is often desirable when a service has only one module to host that module at the root URL, instead of adding some dummy prefix like `/api`. For this, we need empty API prefixes. This change generalizes API dispatching to sequences of path segments rather than single path segments, using a trie data structure to match (prefixes of) URL paths with API modules. We enforce that only the leaves of this trie contain actual data; in other words, it is disallowed, for clarity's sake, to have an API module whose path is a prefix of a different module's path. 
--- src/app.rs | 435 +++++++++++++++++++++++++++++++++------------- src/dispatch.rs | 353 +++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/middleware.rs | 6 +- 4 files changed, 671 insertions(+), 124 deletions(-) create mode 100644 src/dispatch.rs diff --git a/src/app.rs b/src/app.rs index c6d95c1c..67248f37 100644 --- a/src/app.rs +++ b/src/app.rs @@ -6,6 +6,7 @@ use crate::{ api::{Api, ApiError, ApiInner, ApiVersion}, + dispatch::{self, DispatchError, Trie}, healthcheck::{HealthCheck, HealthStatus}, http, method::Method, @@ -16,9 +17,9 @@ use crate::{ Html, StatusCode, }; use async_std::sync::Arc; +use derive_more::From; use futures::future::{BoxFuture, FutureExt}; use include_dir::{include_dir, Dir}; -use itertools::Itertools; use lazy_static::lazy_static; use maud::{html, PreEscaped}; use semver::Version; @@ -26,10 +27,7 @@ use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use snafu::{ResultExt, Snafu}; use std::{ - collections::{ - btree_map::{BTreeMap, Entry as BTreeEntry}, - hash_map::{Entry as HashEntry, HashMap}, - }, + collections::btree_map::BTreeMap, convert::Infallible, env, fmt::Display, @@ -58,24 +56,23 @@ pub use tide::listener::{Listener, ToListener}; /// use by any given API module may differ, depending on the supported version of the API. #[derive(Debug)] pub struct App { - // Map from base URL, major version to API. - pub(crate) apis: HashMap>>, + pub(crate) modules: Trie>, pub(crate) state: Arc, app_version: Option, } /// An error encountered while building an [App]. -#[derive(Clone, Debug, Snafu, PartialEq, Eq)] +#[derive(Clone, Debug, From, Snafu, PartialEq, Eq)] pub enum AppError { Api { source: ApiError }, - ModuleAlreadyExists, + Dispatch { source: DispatchError }, } impl App { /// Create a new [App] with a given state. 
pub fn with_state(state: State) -> Self { Self { - apis: HashMap::new(), + modules: Default::default(), state: Arc::new(state), app_version: None, } @@ -158,20 +155,8 @@ impl App { } }; - match self.apis.entry(base_url.to_string()) { - HashEntry::Occupied(mut e) => match e.get_mut().entry(major_version) { - BTreeEntry::Occupied(_) => { - return Err(AppError::ModuleAlreadyExists); - } - BTreeEntry::Vacant(e) => { - e.insert(api); - } - }, - HashEntry::Vacant(e) => { - e.insert([(major_version, api)].into()); - } - } - + self.modules + .insert(dispatch::split(base_url), major_version, api)?; Ok(self) } @@ -212,12 +197,17 @@ impl App { app_version: self.app_version.clone(), disco_version: env!("CARGO_PKG_VERSION").parse().unwrap(), modules: self - .apis + .modules .iter() - .map(|(name, versions)| { + .map(|module| { ( - name.clone(), - versions.values().rev().map(|api| api.version()).collect(), + module.path(), + module + .versions + .values() + .rev() + .map(|api| api.version()) + .collect(), ) }) .collect(), @@ -231,19 +221,22 @@ impl App { /// (due to type erasure) but can be queried using [module_health](Self::module_health) or by /// hitting the endpoint `GET /:module/healthcheck`. 
pub async fn health(&self, req: RequestParams, state: &State) -> AppHealth { - let mut modules = BTreeMap::>::new(); + let mut modules_health = BTreeMap::>::new(); let mut status = HealthStatus::Available; - for (name, versions) in &self.apis { - let module = modules.entry(name.clone()).or_default(); - for (version, api) in versions { + for module in &self.modules { + let versions_health = modules_health.entry(module.path()).or_default(); + for (version, api) in &module.versions { let health = StatusCode::from(api.health(req.clone(), state).await.status()); if health != StatusCode::Ok { status = HealthStatus::Unhealthy; } - module.insert(*version, health); + versions_health.insert(*version, health); } } - AppHealth { status, modules } + AppHealth { + status, + modules: modules_health, + } } /// Check the health of the named module. @@ -264,10 +257,10 @@ impl App { module: &str, major_version: Option, ) -> Option { - let versions = self.apis.get(module)?; + let module = self.modules.get(dispatch::split(module))?; let api = match major_version { - Some(v) => versions.get(&v)?, - None => versions.last_key_value()?.1, + Some(v) => module.versions.get(&v)?, + None => module.versions.last_key_value()?.1, }; Some(api.health(req, state).await) } @@ -317,35 +310,41 @@ where .allow_credentials(true), ); - for (name, versions) in &state.apis { - Self::register_api(&mut server, name.clone(), versions)?; + for module in &state.modules { + Self::register_api(&mut server, module.prefix.clone(), &module.versions)?; } - // Register app-level automatic routes: `healthcheck` and `version`. 
- server - .at("healthcheck") - .get(move |req: tide::Request>| async move { - let state = req.state().clone(); - let app_state = &*state.state; - let req = request_params(req, &[]).await?; - let accept = req.accept()?; - let res = state.health(req, app_state).await; - Ok(health_check_response::<_, VER>(&accept, res)) - }); - server - .at("version") - .get(move |req: tide::Request>| async move { - let accept = RequestParams::accept_from_headers(&req)?; - respond_with(&accept, req.state().version(), bind_version) - .map_err(|err| Error::from_route_error::(err).into_tide_error()) - }); - - // Serve documentation at the root URL for discoverability - server - .at("/") - .all(move |req: tide::Request>| async move { - Ok(tide::Response::from(Self::top_level_docs(req))) - }); + // Register app-level routes summarizing the status and documentation of all the registered + // modules. We skip this step if this is a singleton app with only one module registered at + // the root URL, as these app-level endpoints would conflict with the (probably more + // specific) API-level status endpoints. + if !state.modules.is_singleton() { + // Register app-level automatic routes: `healthcheck` and `version`. 
+ server + .at("healthcheck") + .get(move |req: tide::Request>| async move { + let state = req.state().clone(); + let app_state = &*state.state; + let req = request_params(req, &[]).await?; + let accept = req.accept()?; + let res = state.health(req, app_state).await; + Ok(health_check_response::<_, VER>(&accept, res)) + }); + server + .at("version") + .get(move |req: tide::Request>| async move { + let accept = RequestParams::accept_from_headers(&req)?; + respond_with(&accept, req.state().version(), bind_version) + .map_err(|err| Error::from_route_error::(err).into_tide_error()) + }); + + // Serve documentation at the root URL for discoverability + server + .at("/") + .all(move |req: tide::Request>| async move { + Ok(tide::Response::from(Self::top_level_docs(req))) + }); + } server.listen(listener).await } @@ -353,22 +352,22 @@ where fn list_apis(&self) -> Html { html! { ul { - @for (name, versions) in &self.apis { + @for module in &self.modules { li { // Link to the alias for the latest version as the primary link. - a href=(format!("/{}", name)) {(name)} + a href=(format!("/{}", module.path())) {(module.path())} // Add a superscript link (link a footnote) for each specific supported // version, linking to documentation for that specific version. - @for version in versions.keys().rev() { + @for version in module.versions.keys().rev() { sup { - a href=(format!("/v{version}/{name}")) { + a href=(format!("/v{version}/{}", module.path())) { (format!("[v{version}]")) } } } " " // Take the description of the latest supported version. 
- (PreEscaped(versions.last_key_value().unwrap().1.short_description())) + (PreEscaped(module.versions.last_key_value().unwrap().1.short_description())) } } } @@ -377,7 +376,7 @@ where fn register_api( server: &mut tide::Server>, - prefix: String, + prefix: Vec, versions: &BTreeMap>, ) -> io::Result<()> { for (version, api) in versions { @@ -388,7 +387,7 @@ where fn register_api_version( server: &mut tide::Server>, - prefix: &String, + prefix: &[String], version: u64, api: &ApiInner, ) -> io::Result<()> { @@ -400,11 +399,16 @@ where server .at("/public") .at(&format!("v{version}")) - .at(prefix) + .at(&prefix.join("/")) .serve_dir(api.public().unwrap_or_else(|| &DEFAULT_PUBLIC_PATH))?; // Register routes for this API. - let mut api_endpoint = server.at(&format!("/v{version}/{prefix}")); + let mut version_endpoint = server.at(&format!("/v{version}")); + let mut api_endpoint = if prefix.is_empty() { + version_endpoint + } else { + version_endpoint.at(&prefix.join("/")) + }; api_endpoint.with(AddErrorBody::new(api.error_handler())); for (path, routes) in api.routes_by_path() { let mut endpoint = api_endpoint.at(path); @@ -418,7 +422,7 @@ where // If there is a socket route with this pattern, add the socket middleware to // all endpoints registered under this pattern, so that any request with any // method that has the socket upgrade headers will trigger a WebSockets upgrade. - Self::register_socket(prefix.to_owned(), version, &mut endpoint, socket_route); + Self::register_socket(prefix.to_vec(), version, &mut endpoint, socket_route); } if let Some(metrics_route) = routes .iter() @@ -428,13 +432,13 @@ where // all endpoints registered under this pattern, so that a request to this path // with the right headers will return metrics instead of going through the // normal method-based dispatching. 
- Self::register_metrics(prefix.to_owned(), version, &mut endpoint, metrics_route); + Self::register_metrics(prefix.to_vec(), version, &mut endpoint, metrics_route); } // Register the HTTP routes. for route in routes { if let Method::Http(method) = route.method() { - Self::register_route(prefix.to_owned(), version, &mut endpoint, route, method); + Self::register_route(prefix.to_vec(), version, &mut endpoint, route, method); } } } @@ -442,26 +446,26 @@ where // Register automatic routes for this API: documentation, `healthcheck` and `version`. Serve // documentation at the root of the API (with or without a trailing slash). for path in ["", "/"] { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at(path) .all(move |req: tide::Request>| { let prefix = prefix.clone(); async move { - let api = &req.state().clone().apis[&prefix][&version]; + let api = &req.state().clone().modules[&prefix].versions[&version]; Ok(api.documentation()) } }); } { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at("*path") .all(move |req: tide::Request>| { let prefix = prefix.clone(); async move { // The request did not match any route. Serve documentation for the API. - let api = &req.state().clone().apis[&prefix][&version]; + let api = &req.state().clone().modules[&prefix].versions[&version]; let docs = html! { "No route matches /" (req.param("path")?) 
br{} @@ -474,13 +478,13 @@ where }); } { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at("healthcheck") .get(move |req: tide::Request>| { let prefix = prefix.clone(); async move { - let api = &req.state().clone().apis[&prefix][&version]; + let api = &req.state().clone().modules[&prefix].versions[&version]; let state = req.state().clone(); Ok(api .health(request_params(req, &[]).await?, &state.state) @@ -489,13 +493,13 @@ where }); } { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at("version") .get(move |req: tide::Request>| { let prefix = prefix.clone(); async move { - let api = &req.state().apis[&prefix][&version]; + let api = &req.state().modules[&prefix].versions[&version]; let accept = RequestParams::accept_from_headers(&req)?; api.version_handler()(&accept, api.version()) .map_err(|err| Error::from_route_error(err).into_tide_error()) @@ -507,7 +511,7 @@ where } fn register_route( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -518,7 +522,7 @@ where let name = name.clone(); let api = api.clone(); async move { - let route = &req.state().clone().apis[&api][&version][&name]; + let route = &req.state().clone().modules[&api].versions[&version][&name]; let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route @@ -534,7 +538,7 @@ where } fn register_metrics( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -560,7 +564,7 @@ where } fn register_socket( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -576,7 +580,7 @@ where let name = name.clone(); let api = api.clone(); async move { - let route = &req.state().clone().apis[&api][&version][&name]; + let route = &req.state().clone().modules[&api].versions[&version][&name]; let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route @@ -608,7 +612,7 @@ where } 
fn register_fallback( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -618,7 +622,7 @@ where let name = name.clone(); let api = api.clone(); async move { - let route = &req.state().clone().apis[&api][&version][&name]; + let route = &req.state().clone().modules[&api].versions[&version][&name]; route .default_handler() .map_err(|err| match err { @@ -637,12 +641,13 @@ where next: tide::Next>, ) -> BoxFuture { async move { - let Some(mut path) = req.url().path_segments() else { + let Some(path) = req.url().path_segments() else { // If we can't parse the path, we can't run this middleware. Do our best by // continuing the request processing lifecycle. return Ok(next.run(req).await); }; - let Some(seg1) = path.next() else { + let path = path.collect::>(); + let Some(seg1) = path.first() else { // This is the root URL, with no path segments. Nothing for this middleware to do. return Ok(next.run(req).await); }; @@ -651,32 +656,25 @@ where return Ok(next.run(req).await); } - // The first segment is either a version identifier or an API identifier (implicitly - // requesting the latest version of the API). We handle these cases differently. + // The first segment is either a version identifier or (part of) an API identifier + // (implicitly requesting the latest version of the API). We handle these cases + // differently. if let Some(version) = seg1.strip_prefix('v').and_then(|n| n.parse().ok()) { // If the version identifier is present, we probably don't need a redirect. However, // we still check if this is a valid version for the request API. If not, we will // serve documentation listing the available versions. - let Some(api) = path.next() else { - // A version identifier with no API is an error, serve documentation. 
- return Ok(Self::top_level_error( - req, - StatusCode::BadRequest, - "illegal version prefix without API specifier", - )); - }; - let Some(versions) = req.state().apis.get(api) else { - let message = format!("No API matches /{api}"); + let Some(module) = req.state().modules.search(&path[1..]) else { + let message = format!("No API matches /{}", path[1..].join("/")); return Ok(Self::top_level_error(req, StatusCode::NotFound, message)); }; - if versions.get(&version).is_none() { + if module.versions.get(&version).is_none() { // This version is not supported, list suported versions. return Ok(html! { "Unsupported version v" (version) ". Supported versions are:" ul { - @for v in versions.keys().rev() { + @for v in module.versions.keys().rev() { li { - a href=(format!("/v{v}/{api}")) { "v" (v) } + a href=(format!("/v{v}/{}", module.path())) { "v" (v) } } } } @@ -688,20 +686,21 @@ where // successfully by the route handlers for this API. Ok(next.run(req).await) } else { - // If the first path segment is not a version prefix, it is either the name of an - // API or one of the magic top-level endpoints (version, healthcheck), implicitly - // requesting the latest version. Validate the API and then redirect. - if ["version", "healthcheck"].contains(&seg1) { + // If the first path segment is not a version prefix, then the path is either the + // name of an API (implicitly requesting the latest version) or one of the magic + // top-level endpoints (version, healthcheck). Validate the API and then redirect. 
+ if !req.state().modules.is_singleton() && ["version", "healthcheck"].contains(seg1) + { return Ok(next.run(req).await); } - let Some(versions) = req.state().apis.get(seg1) else { - let message = format!("No API matches /{seg1}"); + let Some(module) = req.state().modules.search(&path) else { + let message = format!("No API matches /{}", path.join("/")); return Ok(Self::top_level_error(req, StatusCode::NotFound, message)); }; - let latest_version = *versions.last_key_value().unwrap().0; + let latest_version = *module.versions.last_key_value().unwrap().0; let path = path.join("/"); - Ok(tide::Redirect::permanent(format!("/v{latest_version}/{seg1}/{path}")).into()) + Ok(tide::Redirect::permanent(format!("/v{latest_version}/{path}")).into()) } } .boxed() @@ -779,6 +778,7 @@ pub struct AppVersion { /// Note that if anything goes wrong during module registration (for example, there is already an /// incompatible module registered with the same name), the drop implementation may panic. To handle /// errors without panicking, call [`register`](Self::register) explicitly. 
+#[derive(Debug)] pub struct Module<'a, State, Error, ModuleError, ModuleVersion> where State: Send + Sync + 'static, @@ -1118,7 +1118,14 @@ mod test { .module::("mod", v1_toml) .unwrap(); api.with_version("1.1.1".parse().unwrap()); - assert_eq!(api.register().unwrap_err(), AppError::ModuleAlreadyExists); + assert_eq!( + api.register().unwrap_err(), + DispatchError::ModuleAlreadyExists { + prefix: "mod".into(), + version: 1, + } + .into() + ); } { let mut v3 = app @@ -1396,10 +1403,7 @@ mod test { .text() .await .unwrap(); - assert!( - docs.contains("illegal version prefix without API specifier"), - "{docs}" - ); + assert!(docs.contains("No API matches /"), "{docs}"); assert!(docs.contains(&expected_list_item), "{docs}"); } @@ -1455,6 +1459,8 @@ mod test { #[async_std::test] async fn test_format_versions() { + setup_test(); + // Register two modules with different binary format versions, each in turn different from // the app-level version. Each module has two endpoints, one which always succeeds and one // which always fails, so we can test error serialization. @@ -1610,4 +1616,191 @@ mod test { check_err::(&client, "mod02/err").await; check_err::(&client, "mod03/err").await; } + + #[async_std::test] + async fn test_api_prefix() { + setup_test(); + + // It is illegal to register two API modules where one is a prefix (in terms of route + // segments) of another. + for (api1, api2) in [ + ("", "api"), + ("api", ""), + ("path", "path/sub"), + ("path/sub", "path"), + ] { + tracing::info!(api1, api2, "test case"); + let (prefix, conflict) = if api1.len() < api2.len() { + (api1.to_string(), api2.to_string()) + } else { + (api2.to_string(), api1.to_string()) + }; + + let mut app = App::<_, ServerError>::with_state(()); + let toml = toml! 
{ + route = {} + }; + app.module::(api1, toml.clone()) + .unwrap() + .register() + .unwrap(); + assert_eq!( + app.module::(api2, toml) + .unwrap() + .register() + .unwrap_err(), + DispatchError::ConflictingModules { prefix, conflict }.into() + ); + } + } + + #[async_std::test] + async fn test_singleton_api() { + setup_test(); + + // If there is only one API, it should be possible to register it with an empty prefix. + let toml = toml! { + [route.test] + PATH = ["/test"] + }; + let mut app = App::<_, ServerError>::with_state(()); + let mut api = app.module::("", toml).unwrap(); + api.with_version("0.1.0".parse().unwrap()) + .get("test", |_, _| async move { Ok("response") }.boxed()) + .unwrap(); + api.register().unwrap(); + + let port = pick_unused_port().unwrap(); + spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance())); + let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await; + + // Test an endpoint. + let res = client.get("/test").send().await.unwrap(); + assert_eq!( + res.status(), + StatusCode::Ok, + "{}", + res.text().await.unwrap() + ); + assert_eq!(res.json::().await.unwrap(), "response"); + + // Test healthcheck and version endpoints. Since these would ordinarily conflict with the + // app-level healthcheck and version endpoints for an API with no prefix, we only get the + // API-level endpoints, so that a singleton API behaves like a normal API, while app-level + // stuff is reserved for non-trivial applications with more than one API. 
+ let res = client.get("/healthcheck").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + HealthStatus::Available + ); + + let res = client.get("/version").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + ApiVersion { + api_version: Some("0.1.0".parse().unwrap()), + spec_version: "0.1.0".parse().unwrap(), + }, + ); + } + + #[async_std::test] + async fn test_multi_segment() { + setup_test(); + + let toml = toml! { + [route.test] + PATH = ["/test"] + }; + let mut app = App::<_, ServerError>::with_state(()); + + for name in ["a", "b"] { + let path = format!("api/{name}"); + let mut api = app + .module::(&path, toml.clone()) + .unwrap(); + api.with_version("0.1.0".parse().unwrap()) + .get("test", move |_, _| async move { Ok(name) }.boxed()) + .unwrap(); + api.register().unwrap(); + } + + let port = pick_unused_port().unwrap(); + spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance())); + let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await; + + for api in ["a", "b"] { + tracing::info!(api, "testing api"); + + // Test an endpoint. + let res = client.get(&format!("api/{api}/test")).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!(res.json::().await.unwrap(), api); + + // Test healthcheck. + let res = client + .get(&format!("api/{api}/healthcheck")) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + HealthStatus::Available + ); + + // Test version. + let res = client + .get(&format!("api/{api}/version")) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap().api_version.unwrap(), + "0.1.0".parse().unwrap() + ); + } + + // Test app-level healthcheck. 
+ let res = client.get("healthcheck").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + AppHealth { + status: HealthStatus::Available, + modules: [ + ("api/a".into(), [(0, StatusCode::Ok)].into()), + ("api/b".into(), [(0, StatusCode::Ok)].into()), + ] + .into() + } + ); + + // Test app-level version. + let res = client.get("version").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap().modules, + [ + ( + "api/a".into(), + vec![ApiVersion { + api_version: Some("0.1.0".parse().unwrap()), + spec_version: "0.1.0".parse().unwrap(), + }] + ), + ( + "api/b".into(), + vec![ApiVersion { + api_version: Some("0.1.0".parse().unwrap()), + spec_version: "0.1.0".parse().unwrap(), + }] + ), + ] + .into() + ); + } } diff --git a/src/dispatch.rs b/src/dispatch.rs new file mode 100644 index 00000000..49cac178 --- /dev/null +++ b/src/dispatch.rs @@ -0,0 +1,353 @@ +use itertools::Itertools; +use snafu::Snafu; +use std::{ + collections::{btree_map::Entry, BTreeMap}, + ops::Index, +}; + +pub use crate::join; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Module { + pub(crate) prefix: Vec, + pub(crate) versions: BTreeMap, +} + +impl Module { + fn new(prefix: Vec) -> Self { + Self { + prefix, + versions: Default::default(), + } + } + + pub(crate) fn path(&self) -> String { + self.prefix.join("/") + } +} + +#[derive(Clone, Debug, Snafu, PartialEq, Eq)] +pub enum DispatchError { + #[snafu(display("duplicate module {prefix} v{version}"))] + ModuleAlreadyExists { prefix: String, version: u64 }, + #[snafu(display("module {prefix} cannot be a prefix of module {conflict}"))] + ConflictingModules { prefix: String, conflict: String }, +} + +/// Mapping from route prefixes to APIs. +#[derive(Debug)] +pub(crate) enum Trie { + Branch { + /// The route prefix represented by this node. + prefix: Vec, + /// APIs with this prefix, indexed by the next route segment. 
+ children: BTreeMap>, + }, + Leaf { + /// APIs available at this prefix, sorted by version. + module: Module, + }, +} + +impl Default for Trie { + fn default() -> Self { + Self::Branch { + prefix: vec![], + children: Default::default(), + } + } +} + +impl Trie { + /// Whether this is a singleton [`Trie`]. + /// + /// A singleton [`Trie`] is one with only one module, registered under the empty prefix. Note + /// that any [`Trie`] with a module with an empty prefix must be singleton, because no other + /// modules would be permitted: the empty prefix is a prefix of every other module path. + pub(crate) fn is_singleton(&self) -> bool { + matches!(self, Self::Leaf { .. }) + } + + /// Insert a new API with a certain version under the given prefix. + pub(crate) fn insert( + &mut self, + prefix: I, + version: u64, + api: Api, + ) -> Result<(), DispatchError> + where + I: IntoIterator, + I::Item: Into, + { + let mut prefix = prefix.into_iter().map(|segment| segment.into()); + + // Traverse to a leaf matching `prefix`. + let mut curr = self; + while let Some(segment) = prefix.next() { + // If there are more segments in the prefix, we must be at a branch. + match curr { + Self::Branch { prefix, children } => { + // Move to the child associated with the next path segment, inserting an empty + // child if this is the first module we've seen that has this path as a prefix. + curr = children.entry(segment.clone()).or_insert_with(|| { + let mut prefix = prefix.clone(); + prefix.push(segment); + Box::new(Trie::Branch { + prefix, + children: Default::default(), + }) + }); + } + Self::Leaf { module } => { + // If there is a leaf here, then there is already a module registered which is a + // prefix of the new module. This is not allowed. 
+ return Err(DispatchError::ConflictingModules { + prefix: module.path(), + conflict: join!(&module.path(), &segment, &prefix.join("/")), + }); + } + } + } + + // If we have reached the end of the prefix, we must be at either a leaf or a temporary + // empty branch that we can turn into a leaf. + if let Self::Branch { prefix, children } = curr { + if children.is_empty() { + *curr = Self::Leaf { + module: Module::new(prefix.clone()), + }; + } else { + // If we have a non-trival branch at the end of the desired prefix, there is already + // a module registered for which `prefix` is a strict prefix of the registered path. + // This is not allowed. To give a useful error message, follow the existing trie + // down to a leaf so we can give an example of a module which conflicts with this + // prefix. + let prefix = prefix.join("/"); + let conflict = loop { + match curr { + Self::Branch { children, .. } => { + curr = children + .values_mut() + .next() + .expect("malformed dispatch trie: empty branch"); + } + Self::Leaf { module } => { + break module.path(); + } + } + }; + return Err(DispatchError::ConflictingModules { prefix, conflict }); + } + } + let Self::Leaf { module } = curr else { + unreachable!(); + }; + + // Insert the new API, as long as there isn't already an API with the same version in this + // module. + let Entry::Vacant(e) = module.versions.entry(version) else { + return Err(DispatchError::ModuleAlreadyExists { + prefix: module.path(), + version, + }); + }; + e.insert(api); + Ok(()) + } + + /// Get the module named by `prefix`. + /// + /// This function is similar to [`search`](Self::search), except the given `prefix` must exactly + /// match the prefix under which a module is registered. + pub(crate) fn get(&self, prefix: I) -> Option<&Module> + where + I: IntoIterator, + I::Item: AsRef, + { + let mut iter = prefix.into_iter(); + let module = self.traverse(&mut iter)?; + // Check for exact match. 
+ if iter.next().is_some() { + None + } else { + Some(module) + } + } + + /// Get the supported versions of the API identified by the given request path. + /// + /// If a prefix of `path` uniquely identifies a registered module, the module (with all + /// supported versions) is returned. + pub(crate) fn search(&self, path: I) -> Option<&Module> + where + I: IntoIterator, + I::Item: AsRef, + { + self.traverse(&mut path.into_iter()) + } + + /// Iterate over registered modules and their supported versions. + pub(crate) fn iter(&self) -> Iter { + Iter { stack: vec![self] } + } + + /// Internal implementation of `get` and `search`. + /// + /// Returns the matching module and advances the iterator past all the segments used in the + /// match. + fn traverse(&self, iter: &mut I) -> Option<&Module> + where + I: Iterator, + I::Item: AsRef, + { + let mut curr = self; + loop { + match curr { + Self::Branch { children, .. } => { + // Traverse to the next child based on the next segment in the path. + let segment = iter.next()?; + curr = children.get(segment.as_ref())?; + } + Self::Leaf { module } => return Some(module), + } + } + } +} + +pub(crate) struct Iter<'a, Api> { + stack: Vec<&'a Trie>, +} + +impl<'a, Api> Iterator for Iter<'a, Api> { + type Item = &'a Module; + + fn next(&mut self) -> Option { + loop { + match self.stack.pop()? { + Trie::Branch { children, .. } => { + // Push children onto the stack and start visiting them. We add them in reverse + // order so that we will visit the lexicographically first children first. 
+ self.stack + .extend(children.values().rev().map(|boxed| &**boxed)); + } + Trie::Leaf { module } => return Some(module), + } + } + } +} + +impl<'a, Api> IntoIterator for &'a Trie { + type IntoIter = Iter<'a, Api>; + type Item = &'a Module; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl Index for Trie +where + I: IntoIterator, + I::Item: AsRef, +{ + type Output = Module; + + fn index(&self, index: I) -> &Self::Output { + self.get(index).unwrap() + } +} + +/// Split a path prefix into its segments. +/// +/// Leading and trailing slashes are ignored. That is, `/prefix/` yields only the single segment +/// `prefix`, with no preceding or following empty segments. +pub(crate) fn split(s: &str) -> impl '_ + Iterator { + s.split('/').filter(|seg| !seg.is_empty()) +} + +/// Join two path strings, ensuring there are no leading or trailing slashes. +pub(crate) fn join(s1: &str, s2: &str) -> String { + let s1 = s1.strip_prefix('/').unwrap_or(s1); + let s1 = s1.strip_suffix('/').unwrap_or(s1); + let s2 = s2.strip_prefix('/').unwrap_or(s2); + let s2 = s2.strip_suffix('/').unwrap_or(s2); + if s1.is_empty() { + s2.to_string() + } else if s2.is_empty() { + s1.to_string() + } else { + format!("{s1}/{s2}") + } +} + +#[macro_export] +macro_rules! join { + () => { String::new() }; + ($s:expr) => { $s }; + ($head:expr$(, $($tail:expr),*)?) 
=> { + $crate::dispatch::join($head, &$crate::join!($($($tail),*)?)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_empty_trie() { + let t = Trie::<()>::default(); + assert_eq!(t.iter().next(), None); + assert_eq!(t.get(["mod"]), None); + } + + #[test] + fn test_branch_trie() { + let mut t = Trie::default(); + + let mod_a = Module { + prefix: vec!["mod".into(), "a".into()], + versions: [(0, 0)].into(), + }; + let mod_b = Module { + prefix: vec!["mod".into(), "b".into()], + versions: [(1, 1)].into(), + }; + + t.insert(["mod", "a"], 0, 0).unwrap(); + t.insert(["mod", "b"], 1, 1).unwrap(); + + assert_eq!(t.iter().collect::>(), [&mod_a, &mod_b]); + + assert_eq!(t.search(["mod", "a", "route"]), Some(&mod_a)); + assert_eq!(t.get(["mod", "a"]), Some(&mod_a)); + assert_eq!(t.get(["mod", "a", "route"]), None); + + assert_eq!(t.search(["mod", "b", "route"]), Some(&mod_b)); + assert_eq!(t.get(["mod", "b"]), Some(&mod_b)); + assert_eq!(t.get(["mod", "b", "route"]), None); + + // Cannot register a module which is a prefix or suffix of the already registered modules. + t.insert(["mod"], 0, 0).unwrap_err(); + t.insert(Vec::::new(), 0, 0).unwrap_err(); + t.insert(["mod", "a", "b"], 0, 0).unwrap_err(); + } + + #[test] + fn test_null_prefix() { + let mut t = Trie::default(); + + let module = Module { + prefix: vec![], + versions: [(0, 0)].into(), + }; + t.insert(Vec::::new(), 0, 0).unwrap(); + + assert_eq!(t.iter().collect::>(), [&module]); + assert_eq!(t.search(["anything"]), Some(&module)); + assert_eq!(t.get(Vec::::new()), Some(&module)); + assert_eq!(t.get(["anything"]), None); + + // Any other module has the null module as a prefix and is thus not allowed. 
+ t.insert(["anything"], 1, 1).unwrap_err(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 92f7aa73..b9d562fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -289,6 +289,7 @@ pub mod socket; pub mod status; pub mod testing; +mod dispatch; mod middleware; mod route; diff --git a/src/middleware.rs b/src/middleware.rs index 2040169e..1a306efa 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -99,12 +99,12 @@ where pub(crate) struct MetricsMiddleware { route: String, - api: String, + api: Vec, api_version: u64, } impl MetricsMiddleware { - pub(crate) fn new(route: String, api: String, api_version: u64) -> Self { + pub(crate) fn new(route: String, api: Vec, api_version: u64) -> Self { Self { route, api, @@ -148,7 +148,7 @@ where } // This is a metrics request, abort the rest of the dispatching chain and run the // metrics handler. - let route = &req.state().clone().apis[&api][&version][&route]; + let route = &req.state().clone().modules[&api].versions[&version][&route]; let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route