Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TCP Echo Server example #96

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions examples/21_tcp_echo_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//! This example demonstrates how to gracefully shutdown a server
//! that spawns an indefinite number of connection tasks.
//!
//! The server is a simple TCP echo server, capitalizing the data
//! it echos (to demonstrate that it computes things).
//! On shutdown, it transmits a goodbye message, to demonstrate
//! that during shutdown we can still perform cleanup steps.
//!
//! This example is similar to the hyper example; for a more complex
//! version of this same example, look there.

use miette::{Context, IntoDiagnostic, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::Duration;
use tokio_graceful_shutdown::errors::CancelledByShutdown;
use tokio_graceful_shutdown::{FutureExt, SubsystemBuilder, SubsystemHandle, Toplevel};

use std::net::SocketAddr;

use tokio::net::{TcpListener, TcpStream};
use tokio_util::task::TaskTracker;

async fn echo_connection(tcp: &mut TcpStream) -> Result<()> {
tcp.write_all(b"Hello!\r\n").await.into_diagnostic()?;

let mut buffer = [0u8; 256];
loop {
match tcp.read(&mut buffer).await {
Ok(0) => return Ok(()),
Err(e) => return Err(e).into_diagnostic(),
Ok(len) => {
let bytes = &mut buffer[..len];
for byte in bytes.iter_mut() {
*byte = byte.to_ascii_uppercase();
}
tcp.write_all(bytes).await.into_diagnostic()?;
}
}
}
}

async fn echo_connection_shutdown(tcp: &mut TcpStream) -> Result<()> {
tcp.write_all(b"Goodbye.\r\n").await.into_diagnostic()?;
tcp.shutdown().await.into_diagnostic()?;

Ok(())
}

async fn connection_handler(
subsys: SubsystemHandle,
listener: TcpListener,
connection_tracker: TaskTracker,
) -> Result<()> {
loop {
let connection = match listener.accept().cancel_on_shutdown(&subsys).await {
Ok(connection) => connection,
Err(CancelledByShutdown) => break,
};
let (mut tcp, addr) = connection
.into_diagnostic()
.context("Error while waiting for connection")?;

// Spawn handler on connection tracker to give the parent subsystem
// the chance to wait for the shutdown to finish
connection_tracker.spawn({
let cancellation_token = subsys.create_cancellation_token();
async move {
tracing::info!("Connected to {} ...", addr);

let result = tokio::select! {
e = echo_connection(&mut tcp) => e,
_ = cancellation_token.cancelled() => {
tracing::info!("Shutting down {} ...", addr);
echo_connection_shutdown(&mut tcp).await
},
};

if let Err(err) = result {
tracing::warn!("Error serving connection: {:?}", err);
} else {
tracing::info!("Connection to {} closed.", addr);
}
}
});
}

Ok(())
}

async fn echo_subsystem(subsys: SubsystemHandle) -> Result<()> {
let addr: SocketAddr = ([127, 0, 0, 1], 12345).into();

// Bind to the port and listen for incoming TCP connections
let listener = TcpListener::bind(addr)
.await
.into_diagnostic()
.context("Unable to start tcp server")?;
tracing::info!("Listening on {}", addr);

// Use a tasktracker instead of spawning a subsystem for every connection,
// as this would result in a lot of overhead.
let connection_tracker = TaskTracker::new();

let listener = subsys.start(SubsystemBuilder::new("Echo Listener", {
let connection_tracker = connection_tracker.clone();
move |subsys| connection_handler(subsys, listener, connection_tracker)
}));

// Make sure no more tasks can be spawned before we close the tracker
listener.join().await?;

// Wait for connections to close
connection_tracker.close();
connection_tracker.wait().await;

Ok(())
}

#[tokio::main]
async fn main() -> Result<()> {
// Init logging
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.init();

// Setup and execute subsystem tree
Toplevel::new(|s| async move {
s.start(SubsystemBuilder::new("EchoServer", echo_subsystem));
})
.catch_signals()
.handle_shutdown_requests(Duration::from_secs(5))
.await
.map_err(Into::into)
}
Loading