From bd6a36f72ee3fedc348ab304485d43fe23258866 Mon Sep 17 00:00:00 2001 From: akth <32414918+AthLw@users.noreply.github.com> Date: Wed, 30 Oct 2024 08:37:47 +0800 Subject: [PATCH] add connect subcommand (#57) * first commit * modify time depend version * separate connect handler from origin * read config from yaml --- bin/Cargo.toml | 2 + bin/src/cs.rs | 97 +++++++- bin/src/main.rs | 15 ++ bin/src/manager.rs | 43 +++- bin/src/peer/conn.rs | 84 ++++++- bin/src/peer/connect.rs | 525 ++++++++++++++++++++++++++++++++++++++++ bin/src/peer/mod.rs | 50 +++- 7 files changed, 804 insertions(+), 12 deletions(-) create mode 100644 bin/src/peer/connect.rs diff --git a/bin/Cargo.toml b/bin/Cargo.toml index baa896a6..2e87d3cc 100644 --- a/bin/Cargo.toml +++ b/bin/Cargo.toml @@ -24,3 +24,5 @@ webrtc = "0.9.0" serde_yaml = "0.9.30" notify = { version = "6.1.1", default-features = false, features = ["macos_kqueue"] } futures = "0.3.30" +time = "0.3.35" +reqwest = { version = "0.11", features = ["json"] } diff --git a/bin/src/cs.rs b/bin/src/cs.rs index 8a0df5a1..fd9b3723 100644 --- a/bin/src/cs.rs +++ b/bin/src/cs.rs @@ -20,10 +20,19 @@ #![allow(unused)] use clap::Args; +use log::info; use serde::{Deserialize, Serialize}; -use std::ffi::{c_char, c_void, CString}; +use std::{ffi::{c_char, c_void, CString}, fmt::Debug, process::ExitCode}; include!("cs_bindings.rs"); +use std::fs; +use std::path::Path; +use std::process; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use crate::peer::*; + #[derive(Args, Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] pub struct ServerArgs { /// Config file path @@ -38,6 +47,13 @@ pub struct ClientArgs { pub config: Option, } +#[derive(Args, Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)] +pub struct ConnectArgs { + /// Config file path + #[arg(short, long)] + pub config: Option, +} + fn convert_to_go_slices(vec: &Vec) -> (GoSlice, Vec) { let mut go_slices: Vec = Vec::with_capacity(vec.len()); @@ -57,6 +73,85 @@ fn convert_to_go_slices(vec: &Vec) -> (GoSlice, Vec) { go_slices, ) } + +fn load_config(config_path: &str) -> Result> { + // 验证文件是否存在 + if !Path::new(config_path).exists() { + return Err(format!("Config file '{}' does not exist", config_path).into()); + } + + // 读取文件内容 + let config_content = fs::read_to_string(config_path) + .map_err(|e| format!("Failed to read config file '{}': {}", config_path, e))?; + + // 验证文件不为空 + if config_content.trim().is_empty() { + return Err("Config file is empty".into()); + } + + // 解析 YAML + let config: ConnectConfig = serde_yaml::from_str(&config_content) + .map_err(|e| format!("Failed to parse YAML config: {}", e))?; + + match serde_yaml::from_str::(&config_content) { + Ok(config) => println!("解析成功: {:?}", config), + Err(e) => println!("解析错误: {}", e), + } + + // 验证必要的字段 + validate_config(&config)?; + + Ok(config) +} + +fn validate_config(config: &ConnectConfig) -> Result<(), Box> { + // 配置验证 + if config.options.tcp_forward_addr.trim().is_empty() { + return Err("tcp_forward_addr cannot be empty".into()); + } + if config.options.tcp_forward_host_prefix.trim().is_empty() { + return Err("tcp_forward_host_prefix cannot be empty".into()); + } + Ok(()) +} + + +pub fn run_connect(connect_args: ConnectArgs) { + let mut args = if let Some(config_path) = &connect_args.config { + match load_config(config_path) { + Ok(config) => { + println!("Successfully loaded config from '{}'", config_path); + println!("Config details:"); + println!(" TCP Forward Address: {}", config.options.tcp_forward_addr); + println!(" TCP Forward Host Prefix: {}", config.options.tcp_forward_host_prefix); + config + }, + Err(e) => { + eprintln!("Error loading config: {}", e); + process::exit(1); + } + } + } else { + println!("No config file specified, using default configuration"); + ConnectConfig::default() + }; + info!("Run connect cmd."); + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async move { + info!("Runtime started."); + let connect_reader = tokio::io::stdin(); + let connect_writer = tokio::io::stdout(); + // let reader = Arc::new(Mutex::new(connect_reader)); + // let writer = Arc::new(Mutex::new(connect_writer)); + if let Err(e) = process_connect(connect_reader, connect_writer, args).await { + eprintln!("process p2p connect: {}", e); + process::exit(1); + }; + }); + unsafe {} + // TODO +} + pub fn run_client(client_args: ClientArgs) { let mut args = if let Some(config) = client_args.config { vec!["client".to_owned(), "-config".to_owned(), config] diff --git a/bin/src/main.rs b/bin/src/main.rs index 151b47dc..94182a32 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -18,6 +18,7 @@ use std::path::PathBuf; use clap::Parser; use clap::Subcommand; +use cs::ConnectArgs; use env_logger::Env; use log::{error, info}; @@ -50,6 +51,8 @@ enum Commands { Server(ServerArgs), /// Run GT Client Client(ClientArgs), + /// Run GT Connect + Connect(ConnectArgs), #[command(hide = true)] SubP2P, @@ -57,6 +60,8 @@ enum Commands { SubServer(ServerArgs), #[command(hide = true)] SubClient(ClientArgs), + #[command(hide = true)] + SubConnect(ConnectArgs), } fn main() { @@ -75,6 +80,7 @@ fn main() { depth: cli.depth, server_args: None, client_args: None, + connect_args: None, }; if let Some(command) = cli.command { match command { @@ -84,6 +90,9 @@ fn main() { Commands::Client(args) => { manager_args.client_args = Some(args); } + Commands::Connect(args) => { + manager_args.connect_args = Some(args); + } Commands::SubP2P => { info!("GT SubP2P"); peer::start_peer_connection(); @@ -102,6 +111,12 @@ fn main() { info!("GT SubClient done"); return; } + Commands::SubConnect(args) => { + info!("GT SubConnect"); + cs::run_connect(args); + info!("GT SubConnect done"); + return; + } } } diff --git a/bin/src/manager.rs b/bin/src/manager.rs index 2fd9a5fe..09085c63 100644 --- a/bin/src/manager.rs +++ b/bin/src/manager.rs @@ -41,7 +41,7 @@ use tokio::sync::{mpsc, Mutex, oneshot}; use tokio::sync::oneshot::{Receiver, Sender}; use tokio::time::timeout; -use crate::cs::{ClientArgs, ServerArgs}; +use crate::cs::{ClientArgs, ConnectArgs, ServerArgs}; #[derive(Debug)] pub struct ManagerArgs { @@ -49,6 +49,7 @@ pub struct ManagerArgs { pub depth: Option, pub server_args: Option, pub client_args: Option, + pub connect_args: Option, } #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] @@ -71,6 +72,7 @@ enum ProcessConfigEnum { Config(PathBuf), Server(ServerArgs), Client(ClientArgs), + Connect(ConnectArgs), } pub struct Manager { @@ -323,6 +325,11 @@ impl Manager { $cmd.arg("-c").arg(path.clone()); } } + ProcessConfigEnum::Connect(args) => { + if let Some(path) = &args.config { + $cmd.arg("-c").arg(path.clone()); + } + } } $cmd.stdin(Stdio::piped()); $cmd.stdout(Stdio::piped()); @@ -452,17 +459,21 @@ impl Manager { async fn run_configs(&self, configs: Vec) -> Result<()> { let mut server_config = vec![]; let mut client_config = vec![]; + let mut connect_config = vec![]; for config in configs { match &config { ProcessConfigEnum::Config(path) => { if is_client_config_path(path).context("is_client_config_path failed")? { client_config.push(config); - } else { + } else if is_server_config_path(path).context("is_server_config_path failed")? { server_config.push(config); + } else { + connect_config.push(config); } } ProcessConfigEnum::Server(_) => server_config.push(config), ProcessConfigEnum::Client(_) => client_config.push(config), + ProcessConfigEnum::Connect(_) => connect_config.push(config), } } if !server_config.is_empty() { @@ -476,6 +487,12 @@ impl Manager { .await .context("run_client failed")?; } + + if !connect_config.is_empty() { + Self::run(self.cmds.clone(), connect_config, "sub-connect") + .await + .context("run_connect failed")?; + } Ok(()) } @@ -802,6 +819,11 @@ fn is_client_config_path(path: &PathBuf) -> Result { is_client_config(&yaml) } +fn is_server_config_path(path: &PathBuf) -> Result { + let yaml = fs::read_to_string(path)?; + is_server_config(&yaml) +} + fn is_client_config(yaml: &str) -> Result { let c = serde_yaml::from_str::(yaml)?; if c.services.is_some() { @@ -811,6 +833,23 @@ fn is_client_config(yaml: &str) -> Result { return match typ.as_str() { "client" => Ok(true), "server" => Ok(false), + "connect" => Ok(false), + t => Err(anyhow!("invalid config type {}", t)), + }; + } + Ok(false) +} + +fn is_server_config(yaml: &str) -> Result { + let s = serde_yaml::from_str::(yaml)?; + if s.services.is_some() { + return Ok(false); + } + if let Some(typ) = s.typ { + return match typ.as_str() { + "client" => Ok(false), + "server" => Ok(true), + "connect" => Ok(false), t => Err(anyhow!("invalid config type {}", t)), }; } diff --git a/bin/src/peer/conn.rs b/bin/src/peer/conn.rs index c8c25bee..8719069e 100644 --- a/bin/src/peer/conn.rs +++ b/bin/src/peer/conn.rs @@ -73,13 +73,25 @@ where debug!("config json: {}", &json); let op = serde_json::from_str::(&json) .with_context(|| format!("deserialize config json failed: {}", json))?; + // let op: OP = OP::Config(Config { + // stuns: vec!["stun:127.0.0.1:3478".to_owned()], + // http_routes: HashMap::from([("@".to_owned(), "http://www.baidu.com".to_owned())]), + // ..Default::default() + // }); + // write json config to stdout + let output = Arc::new(Mutex::new(tokio::io::stdout())); + write_json(Arc::clone(&output), &serde_json::to_string(&op).unwrap()) + .await + .map_err(|e| println!("write json error: {:?}", e)) + .expect("write json"); + // bind origin op to config let config = match op { OP::Config(config) => config, _ => { - bail!("invalid config json {}", &json); + bail!("invalid config json."); } }; - + // init webrtc configuration let rtc_config = RTCConfiguration { ice_servers: vec![RTCIceServer { urls: config.stuns, @@ -88,35 +100,43 @@ where ..Default::default() }; + // configure media engine let mut m = MediaEngine::default(); m.register_default_codecs() .context("register default codecs")?; + // register default registry let mut registry = Registry::new(); registry = register_default_interceptors(registry, &mut m) .context("register default interceptors")?; + // set min_port and max_port let mut s = SettingEngine::default(); s.set_udp_network(UDPNetwork::Ephemeral( udp_network::EphemeralUDP::new(config.port_min, config.port_max) .context("create udp network")?, )); + // first detach data channel s.detach_data_channels(); + // build api with configuration let api = APIBuilder::new() .with_media_engine(m) .with_interceptor_registry(registry) .with_setting_engine(s) .build(); + // create pc connection let peer_connection = Arc::new( api.new_peer_connection(rtc_config) .await .context("new pc")?, ); + // config max timeout value let timeout = config.timeout.max(5); + // create pc success Ok(Arc::new(PeerConnHandler { reader, writer, @@ -129,6 +149,23 @@ where })) } + // pub async fn send_offer(self: Arc) -> Result<()> { + // let pc = Arc::clone(&self.peer_connection); + // let offer = pc.create_offer(None).await.context("create offer")?; + // let sdp = serde_json::to_string(&offer).context("serialize answer")?; + // let op = OP::OfferSDP(sdp); + // write_json( + // Arc::clone(&self.writer), + // &serde_json::to_string(&op).context("encode op")?, + // ) + // .await + // .context("write answer sdp to stdout")?; + // pc.set_local_description(offer) + // .await + // .context("set local description")?; + // Ok(()) + // } + fn setup_data_channel(self: Arc, d: Arc) { let dc = Arc::clone(&d); d.on_open(Box::new(|| { @@ -217,6 +254,7 @@ where pub async fn handle(self: Arc) -> Result<()> { let writer_on_ice_candidate = Arc::clone(&self.writer); self.peer_connection + // register ice_candidate process function .on_ice_candidate(Box::new(move |c: Option| { info!("on_ice_candidate {:?}", c); let writer_on_ice_candidate = Arc::clone(&writer_on_ice_candidate); @@ -248,6 +286,7 @@ where error!("failed to write ice candidate: {}", e); } } else { + // build candidate with default value let op = OP::Candidate("".to_owned()); let json = match serde_json::to_string(&op) { Err(e) => { @@ -266,18 +305,42 @@ where let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::>(1); self.peer_connection + // register pc state change function .on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { info!("peer Connection State has changed: {s}"); match s { - RTCPeerConnectionState::Unspecified => {} - RTCPeerConnectionState::New => {} - RTCPeerConnectionState::Connecting => {} - RTCPeerConnectionState::Connected => {} - RTCPeerConnectionState::Disconnected => {} + RTCPeerConnectionState::Unspecified => { + // 未指定状态,通常不需要特殊处理 + info!("Connection state unspecified"); + } + RTCPeerConnectionState::New => { + // 新建连接,记录初始ICE连接状态 + info!("New peer connection established"); + } + RTCPeerConnectionState::Connecting => { + // 正在建立连接 + info!("Establishing peer connection..."); + } + RTCPeerConnectionState::Connected => { + // 连接成功建立 + info!("Peer connection successfully established"); + } + RTCPeerConnectionState::Disconnected => { + // 连接断开,可以尝试重连 + warn!("Peer connection disconnected, may attempt reconnection"); + // 可以在这里添加重连逻辑 + } RTCPeerConnectionState::Failed => { + // 连接失败,发送错误信号 + error!("Peer connection failed"); let _ = done_tx.try_send(Err(anyhow!("peer connection state failed"))); } - RTCPeerConnectionState::Closed => {} + RTCPeerConnectionState::Closed => { + // 连接已关闭 + info!("Peer connection closed"); + // 可以在这里进行清理工作 + let _ = done_tx.try_send(Ok(())); + } } Box::pin(async {}) @@ -285,6 +348,7 @@ where let handler = Arc::clone(&self); self.peer_connection + // register data channel build function .on_data_channel(Box::new(move |d: Arc| { info!("new dataChannel {} {}", d.label(), d.id()); let handler = Arc::clone(&handler); @@ -328,6 +392,7 @@ where let pc = Arc::clone(&self.peer_connection); match op { + // receive offer from remote, return answer to remote OP::OfferSDP(sdp) => { let sdp = serde_json::from_str::(&sdp) .context("offer sdp from op")?; @@ -347,6 +412,7 @@ where .await .context("set local description")?; } + // receive candidate from remote, add candidate OP::Candidate(candidate) => { if candidate.is_empty() { continue; @@ -357,6 +423,7 @@ where .await .context("add candidate")?; } + // create data channel and local offer OP::GetOfferSDP { channel_name } => { let data_channel = pc .create_data_channel(&channel_name, None) @@ -377,6 +444,7 @@ where .await .context("set local description")?; } + // receive answer from remote, set remote sdp OP::AnswerSDP(sdp) => { let sdp = serde_json::from_str::(&sdp) .context("answer sdp from op")?; diff --git a/bin/src/peer/connect.rs b/bin/src/peer/connect.rs new file mode 100644 index 00000000..7099ddf8 --- /dev/null +++ b/bin/src/peer/connect.rs @@ -0,0 +1,525 @@ +/* + * Copyright (c) 2022 Institute of Software, Chinese Academy of Sciences (ISCAS) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + use std::collections::HashMap; + use std::future::Future; + use std::pin::Pin; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::time::Duration; + + use anyhow::{anyhow, bail, Context, Result}; + use log::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpStream; + use tokio::sync::Mutex; + use tokio::{io, select, time}; + use url::Url; + use webrtc::api::interceptor_registry::register_default_interceptors; + use webrtc::api::media_engine::MediaEngine; + use webrtc::api::setting_engine::SettingEngine; + use webrtc::api::APIBuilder; + use webrtc::data::data_channel::PollDataChannel; + use webrtc::data_channel::RTCDataChannel; + use webrtc::ice::udp_network; + use webrtc::ice::udp_network::UDPNetwork; + use webrtc::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; + use webrtc::ice_transport::ice_server::RTCIceServer; + use webrtc::interceptor::registry::Registry; + use webrtc::peer_connection::configuration::RTCConfiguration; + use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; + use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; + use webrtc::peer_connection::RTCPeerConnection; + use reqwest::{Client, header}; + + use crate::peer::{read_json, write_json, LibError, OP, Config, ConnectConfig}; + +use super::ConnectOptions; + + pub(crate) struct ConnectPeerConnHandler { + http_routes: HashMap, + tcp_routes: HashMap, + reader: Arc>, + writer: Arc>, + channel_count: AtomicUsize, + no_channel_id: AtomicUsize, + peer_connection: Arc, + timeout: u16, + options: Arc, + } + + impl ConnectPeerConnHandler + where + R: AsyncReadExt + Unpin + Send + 'static, + W: AsyncWriteExt + Unpin + Send + 'static, + { + pub async fn new(reader: R, writer: W, args: ConnectConfig) -> Result> { + let reader = Arc::new(Mutex::new(reader)); + let writer = Arc::new(Mutex::new(writer)); + let tempargs = Arc::clone(&Arc::new(args.options)); + let op: OP = OP::Config(Config { + stuns: vec![tempargs.stun_addr.to_owned()], + http_routes: HashMap::from([("@".to_owned(), "http://www.baidu.com".to_owned())]), + ..Default::default() + }); + // write json config to stdout + let output = Arc::new(Mutex::new(tokio::io::stdout())); + write_json(Arc::clone(&output), &serde_json::to_string(&op).unwrap()) + .await + .map_err(|e| println!("write json error: {:?}", e)) + .expect("write json"); + // bind origin op to config + let config = match op { + OP::Config(config) => config, + _ => { + bail!("invalid config json."); + } + }; + // init webrtc configuration + let rtc_config = RTCConfiguration { + ice_servers: vec![RTCIceServer { + urls: config.stuns, + ..Default::default() + }], + ..Default::default() + }; + + // configure media engine + let mut m = MediaEngine::default(); + m.register_default_codecs() + .context("register default codecs")?; + + // register default registry + let mut registry = Registry::new(); + + registry = register_default_interceptors(registry, &mut m) + .context("register default interceptors")?; + + // set min_port and max_port + let mut s = SettingEngine::default(); + s.set_udp_network(UDPNetwork::Ephemeral( + udp_network::EphemeralUDP::new(config.port_min, config.port_max) + .context("create udp network")?, + )); + // first detach data channel + s.detach_data_channels(); + + // build api with configuration + let api = APIBuilder::new() + .with_media_engine(m) + .with_interceptor_registry(registry) + .with_setting_engine(s) + .build(); + + // create pc connection + let peer_connection = Arc::new( + api.new_peer_connection(rtc_config) + .await + .context("new pc")?, + ); + + // config max timeout value + let timeout = config.timeout.max(5); + // create pc success + Ok(Arc::new(ConnectPeerConnHandler { + reader, + writer, + peer_connection, + timeout, + http_routes: config.http_routes, + tcp_routes: config.tcp_routes, + channel_count: Default::default(), + no_channel_id: Default::default(), + options: tempargs, + })) + } + + pub async fn send_http_request(self: Arc, + url: &str, + method: &str, + host: Option<&str>, + headers: Option>, + body: Option<&str>, + ) -> Result { + let client = Client::new(); + + // 创建请求构建器 + let mut request_builder = match method.to_uppercase().as_str() { + "GET" => client.get(url), + "POST" => client.post(url), + "PUT" => client.put(url), + "DELETE" => client.delete(url), + _ => todo!(), + }; + + // 如果提供了 host,则设置 Host 头 + if let Some(host_value) = host { + request_builder = request_builder.header(header::HOST, host_value); + } + + // 添加其他自定义头 + if let Some(custom_headers) = headers { + for (key, value) in custom_headers { + request_builder = request_builder.header(key, value); + } + } + + // 如果提供了 body,则添加到请求中 + if let Some(body_content) = body { + request_builder = request_builder.body(body_content.to_string()); + } + + // 发送请求并获取响应 + let response = request_builder.send().await?; + + // 检查状态码 + if !response.status().is_success() { + error!("HTTP error: {}", response.status()); + } + + // 获取响应体 + let body = response.text().await?; + + Ok(body) + } + + pub async fn forward_data_with_server(self: Arc, msg: &str) -> Result { + // let ya = serde_yaml::from_str::(yaml)?; + let options = Arc::clone(&self.options); + let url = &options.tcp_forward_addr; + let method = "GET"; + let host = Some(options.tcp_forward_host_prefix.as_str()); + let headers = Some(vec![ + ("Users-Agent".to_string(), "gt-connect".to_string()), + ]); + let body = Some(msg); + + let resp = self.send_http_request(&url, method, host, headers, body).await?; + info!("Response from remote: {}", resp); + Ok(true) + } + + pub async fn send_offer(self: Arc) -> Result<()> { + let pc = Arc::clone(&self.peer_connection); + let offer = pc.create_offer(None).await.context("create offer")?; + let sdp = serde_json::to_string(&offer).context("serialize answer")?; + let op = OP::OfferSDP(sdp); + write_json( + Arc::clone(&self.writer), + &serde_json::to_string(&op).context("encode op")?, + ) + .await + .context("write answer sdp to stdout")?; + pc.set_local_description(offer) + .await + .context("set local description")?; + Ok(()) + } + + fn setup_data_channel(self: Arc, d: Arc) { + let dc = Arc::clone(&d); + d.on_open(Box::new(|| { + self.channel_count.fetch_add(1, Ordering::Relaxed); + self.new_data_channel_process_handler(dc) + })); + } + + fn new_data_channel_process_handler( + self: Arc, + d: Arc, + ) -> Pin + Sized>> { + Box::pin(async move { + let label = d.label(); + info!("data channel '{}'-'{}' open.", label, d.id()); + let target = label.split_once('/').map_or_else( + || self.http_routes.get("@"), + |(t, _)| { + t.get(0..1).map_or_else( + || self.http_routes.get("@"), + |c| { + t.get(1..).map_or_else( + || self.http_routes.get(t), + |r| { + if c == "@" && !r.is_empty() { + self.http_routes.get(r) + } else if c == ":" && !r.is_empty() { + self.tcp_routes.get(r) + } else { + self.http_routes.get(t) + } + }, + ) + }, + ) + }, + ); + if let Some(target) = target { + info!("{} connect to {}", label, target); + let dc = Arc::clone(&d); + if let Err(err) = self.connect_target(target, dc).await { + info!("{} failed to connect to {}: {}", label, target, err); + } + } else { + error!("no routes for {}", label); + } + info!("data channel '{}'-'{}' done.", label, d.id()); + let _ = self + .channel_count + .fetch_update(Ordering::Release, Ordering::Relaxed, |v| { + if v == 1 { + self.no_channel_id.fetch_add(1, Ordering::Relaxed); + } + Some(v - 1) + }); + }) + } + + async fn connect_target(&self, target: &str, d: Arc) -> Result<()> { + let url = Url::parse(target).context("invalid url")?; + let addrs = url + .socket_addrs(|| match url.scheme() { + "http" | "ws" | "tcp" => Some(80), + "https" | "wss" | "tls" => Some(443), + _ => Some(80), + }) + .context("no address")?; + let raw = d.detach().await.context("detach data channel")?; + + let mut s = TcpStream::connect(&*addrs) + .await + .context("connect to service")?; + let result = io::copy_bidirectional(&mut PollDataChannel::new(raw), &mut s).await; + match result { + Ok((a, b)) => { + info!("{} copy done: {}, {}", d.label(), a, b); + } + Err(err) => { + error!("{} copy err: {}", d.label(), err); + bail!(err); + } + } + Ok(()) + } + + pub async fn handle(self: Arc) -> Result<()> { + let writer_on_ice_candidate = Arc::clone(&self.writer); + self.peer_connection + // register ice_candidate process function + .on_ice_candidate(Box::new(move |c: Option| { + info!("on_ice_candidate {:?}", c); + let writer_on_ice_candidate = Arc::clone(&writer_on_ice_candidate); + Box::pin(async move { + if let Some(c) = c { + let json = match c.to_json() { + Err(e) => { + error!("failed to serialize ice candidate: {}", e); + return; + } + Ok(json) => json, + }; + let json = match serde_json::to_string(&json) { + Err(e) => { + error!("failed to serialize ice candidate init: {}", e); + return; + } + Ok(json) => json, + }; + let op = OP::Candidate(json); + let json = match serde_json::to_string(&op) { + Err(e) => { + error!("failed to serialize op: {}", e); + return; + } + Ok(json) => json, + }; + if let Err(e) = write_json(writer_on_ice_candidate, &json).await { + error!("failed to write ice candidate: {}", e); + } + } else { + // build candidate with default value + let op = OP::Candidate("".to_owned()); + let json = match serde_json::to_string(&op) { + Err(e) => { + error!("failed to serialize op: {}", e); + return; + } + Ok(json) => json, + }; + if let Err(e) = write_json(writer_on_ice_candidate, &json).await { + error!("failed to write ice candidate: {}", e); + } + } + }) + })); + + let (done_tx, mut done_rx) = tokio::sync::mpsc::channel::>(1); + + self.peer_connection + // register pc state change function + .on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { + info!("peer Connection State has changed: {s}"); + match s { + RTCPeerConnectionState::Unspecified => { + // 未指定状态,通常不需要特殊处理 + info!("Connection state unspecified"); + } + RTCPeerConnectionState::New => { + // 新建连接,记录初始ICE连接状态 + info!("New peer connection established"); + } + RTCPeerConnectionState::Connecting => { + // 正在建立连接 + info!("Establishing peer connection..."); + } + RTCPeerConnectionState::Connected => { + // 连接成功建立 + info!("Peer connection successfully established"); + } + RTCPeerConnectionState::Disconnected => { + // 连接断开,可以尝试重连 + warn!("Peer connection disconnected, may attempt reconnection"); + // 可以在这里添加重连逻辑 + } + RTCPeerConnectionState::Failed => { + // 连接失败,发送错误信号 + error!("Peer connection failed"); + let _ = done_tx.try_send(Err(anyhow!("peer connection state failed"))); + } + RTCPeerConnectionState::Closed => { + // 连接已关闭 + info!("Peer connection closed"); + // 可以在这里进行清理工作 + let _ = done_tx.try_send(Ok(())); + } + } + + Box::pin(async {}) + })); + + let handler = Arc::clone(&self); + self.peer_connection + // register data channel build function + .on_data_channel(Box::new(move |d: Arc| { + info!("new dataChannel {} {}", d.label(), d.id()); + let handler = Arc::clone(&handler); + handler.setup_data_channel(d); + Box::pin(async {}) + })); + + let mut no_channel_id: usize = 0; + loop { + let sleep = time::sleep(Duration::from_secs(self.timeout as u64)); + tokio::pin!(sleep); + let json = select! { + result = read_json(Arc::clone(&self.reader)) => { + result? + }, + rx = done_rx.recv() => { + return match rx { + None => { + Ok(()) + } + Some(result) => { + result + } + } + } + _ = &mut sleep => { + if self.channel_count.load(Ordering::Acquire) == 0 { + let id = self.no_channel_id.load(Ordering::Relaxed); + if no_channel_id == id { + return Err(LibError::NoChannelInPeerConnectionTimeout.into()); + } else { + no_channel_id = id; + } + } + continue; + } + }; + debug!("op json: {}", &json); + let op = serde_json::from_str::(&json) + .with_context(|| format!("parse op json: {}", json))?; + + let pc = Arc::clone(&self.peer_connection); + match op { + // receive offer from remote, return answer to remote + OP::OfferSDP(sdp) => { + let sdp = serde_json::from_str::(&sdp) + .context("offer sdp from op")?; + pc.set_remote_description(sdp) + .await + .context("set remote description")?; + let answer = pc.create_answer(None).await.context("create answer")?; + let sdp = serde_json::to_string(&answer).context("serialize answer")?; + let op = OP::AnswerSDP(sdp); + write_json( + Arc::clone(&self.writer), + &serde_json::to_string(&op).context("encode op")?, + ) + .await + .context("write answer sdp to stdout")?; + pc.set_local_description(answer) + .await + .context("set local description")?; + } + // receive candidate from remote, add candidate + OP::Candidate(candidate) => { + if candidate.is_empty() { + continue; + } + let candidate = serde_json::from_str::(&candidate) + .context("candidate from op")?; + pc.add_ice_candidate(candidate) + .await + .context("add candidate")?; + } + // create data channel and local offer + OP::GetOfferSDP { channel_name } => { + let data_channel = pc + .create_data_channel(&channel_name, None) + .await + .context("create data channel")?; + let handler = Arc::clone(&self); + handler.setup_data_channel(data_channel); + let offer = pc.create_offer(None).await.context("create offer")?; + let sdp = serde_json::to_string(&offer).context("serialize answer")?; + let op = OP::OfferSDP(sdp); + write_json( + Arc::clone(&self.writer), + &serde_json::to_string(&op).context("encode op")?, + ) + .await + .context("write answer sdp to stdout")?; + pc.set_local_description(offer) + .await + .context("set local description")?; + } + // receive answer from remote, set remote sdp + OP::AnswerSDP(sdp) => { + let sdp = serde_json::from_str::(&sdp) + .context("answer sdp from op")?; + pc.set_remote_description(sdp) + .await + .context("set remote description")?; + } + _ => { + bail!("invalid op {:?}", op) + } + }; + } + } + } + \ No newline at end of file diff --git a/bin/src/peer/mod.rs b/bin/src/peer/mod.rs index 48faa29d..111a853d 100644 --- a/bin/src/peer/mod.rs +++ b/bin/src/peer/mod.rs @@ -23,10 +23,12 @@ use log::*; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::io; -use tokio::io::{stdin, stdout}; +use tokio::io::{stdin, stdout, AsyncBufReadExt, BufReader}; use tokio::sync::Mutex; +use tokio::sync::mpsc; mod conn; +mod connect; pub fn start_peer_connection() { let rt = tokio::runtime::Builder::new_current_thread() @@ -56,6 +58,35 @@ where handler.handle().await } +pub async fn process_connect(reader: R, writer: W, args: ConnectConfig) -> Result<()> +where + R: io::AsyncReadExt + Unpin + Send + 'static, + W: io::AsyncWriteExt + Unpin + Send + 'static, +{ + let handler = connect::ConnectPeerConnHandler::new(reader, writer, args).await?; + let _ = Arc::clone(&handler).send_offer(); + let (tx, mut rx) = mpsc::channel(8); + // 在一个独立的任务中读取标准输入 + tokio::spawn(async move { + let mut stdin = BufReader::new(tokio::io::stdin()).lines(); + while let Some(line) = stdin.next_line().await.unwrap() { + tx.send(line).await.unwrap(); + } + }); + // 在主任务中处理读取到的行 + while let Some(line) = rx.recv().await { + println!("Received line: {}", line); + // 将读取的行发送给服务端转发 + let handler = Arc::clone(&handler); + match handler.forward_data_with_server(&line).await { + Ok(_) => println!("Successfully forwarded: {}", line), + Err(e) => eprintln!("Error forwarding data: {}", e), + } + } + let _ = Arc::clone(&handler).handle().await; + Ok(()) +} + #[derive(Serialize, Deserialize, Debug, Default)] #[serde(default, rename_all = "camelCase")] pub struct Config { @@ -80,6 +111,23 @@ pub enum OP { }, } +#[derive(Serialize, Deserialize, Debug, Default)] +#[serde(default)] +pub struct ConnectConfig { + #[serde(rename = "type")] + pub typ: String, + pub options: ConnectOptions, +} + +#[derive(Serialize, Deserialize, Debug, Default)] +#[serde(default)] +pub struct ConnectOptions { + pub remote: String, + pub stun_addr: String, + pub tcp_forward_addr: String, + pub tcp_forward_host_prefix: String, +} + pub async fn read_json(reader: Arc>) -> Result where R: io::AsyncReadExt + Unpin,