diff --git a/README.md b/README.md index 695a75fbc..810518360 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ | `mongodb`/ `mongo` | `host` | 否 | `String` | Mongodb 进行数据库连接使用的 host | 默认为`localhost` | | `mongodb`/ `mongo` | `port` | 否 | `String` | Mongodb 进行数据库连接使用的端口 | 默认为`27017` | | `mongodb`/ `mongo` | `db_name` | 是 | `String` | Mongodb 进行数据库连接使用的数据库 | 无 | + | `mongodb`/ `mongo` | `query` | 否 | `HashMap` | Mongodb 进行数据库连接使用的参数 | 默认为`{}` | | `user_auth`/ `auth` | `jwt`/`jwt-key` | 否 | `String` | 用户鉴权使用的`Jwt`密钥 | 最大长度不超过 32 位。过长部分将会被截断,过短部分将会被随机数填充 | | `user_auth`/ `auth` | `header`/`header_name` | 否 | `String` | 获取 token 的 Header | 默认为`Token` | | `user_auth`/ `auth` | `mob_header` | 否 | `String` | 获取 mob_id 的 Header | 默认为`mob-id` | diff --git a/persistence/database/mongo_connection/Cargo.toml b/persistence/database/mongo_connection/Cargo.toml index bfb4a9f0a..7483972cb 100644 --- a/persistence/database/mongo_connection/Cargo.toml +++ b/persistence/database/mongo_connection/Cargo.toml @@ -15,6 +15,7 @@ async-trait = { workspace = true } database_traits = { path = "../database_traits" } tracing = { workspace = true } time-utils = { workspace = true, features = ["with-mongo"] } +url.workspace = true [dependencies.status-err] path = "../../../libs/status-err" diff --git a/persistence/database/mongo_connection/src/config.rs b/persistence/database/mongo_connection/src/config.rs index 60bd2a728..35e008a77 100644 --- a/persistence/database/mongo_connection/src/config.rs +++ b/persistence/database/mongo_connection/src/config.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + pub trait DbConnectConfig: serde::de::DeserializeOwned { fn scheme(&self) -> &str; fn username(&self) -> &str; @@ -5,6 +7,7 @@ pub trait DbConnectConfig: serde::de::DeserializeOwned { fn host(&self) -> &str; fn port(&self) -> u16; fn name(&self) -> &str; + fn query(&self) -> &HashMap; } #[derive(Debug, serde::Deserialize)] @@ -16,6 +19,8 @@ pub struct MongoDbConfig { #[serde(default = "port_default")] port: u16, db_name: String, + #[serde(default = "query_default")] + query: HashMap, } impl DbConnectConfig for MongoDbConfig { @@ -30,8 +35,12 @@ impl DbConnectConfig for MongoDbConfig { fn port(&self) -> u16 { self.port } fn name(&self) -> &str { &self.db_name } + + fn query(&self) -> &HashMap { &self.query } } fn host_default() -> String { "localhost".into() } fn port_default() -> u16 { 27017 } + +fn query_default() -> HashMap { HashMap::new() } diff --git a/persistence/database/mongo_connection/src/mongo_connect.rs b/persistence/database/mongo_connection/src/mongo_connect.rs index eb369dee7..d5618d0fd 100644 --- a/persistence/database/mongo_connection/src/mongo_connect.rs +++ b/persistence/database/mongo_connection/src/mongo_connect.rs @@ -1,5 +1,6 @@ use mongo_migrate_util::MigratorTrait; use mongodb::{options::ClientOptions, Database}; +use url::Url; use crate::{ database::builder::DatabaseBuilder, static_vars::set_mongo_database, @@ -52,16 +53,83 @@ async fn init_mongodb(url: &str) -> Result { } fn format_url(cfg: &impl DbConnectConfig) -> String { - let s = format!( - "{}://{}:{}@{}:{}/{}?authSource=admin", + let mut s = Url::parse(&format!( + "{}://{}:{}@{}:{}/{}", cfg.scheme(), cfg.username(), urlencoding::encode(cfg.password()), cfg.host(), cfg.port(), cfg.name() - ); + )) + .expect("MongoDb 连接URL生成异常"); - tracing::info!(mongodb.URL = s); - s + // 添加查询参数 + for (key, value) in cfg.query() { + s.query_pairs_mut().append_pair(key, value); + } + + tracing::info!(mongodb.URL = s.to_string()); + s.to_string() +} + +#[cfg(test)] +mod tests { + + use std::collections::HashMap; + + use super::*; + + #[derive(Debug, serde::Deserialize)] + pub struct MongoDbConfig { + username: String, + password: String, + #[serde(default = "host_default")] + host: String, + #[serde(default = "port_default")] + port: u16, + db_name: String, + query: HashMap, + } + + impl DbConnectConfig for MongoDbConfig { + fn scheme(&self) -> &str { "mongodb" } + + fn username(&self) -> &str { &self.username } + + fn password(&self) -> &str { &self.password } + + fn host(&self) -> &str { &self.host } + + fn port(&self) -> u16 { self.port } + + fn name(&self) -> &str { &self.db_name } + + fn query(&self) -> &HashMap { &self.query } + } + + fn host_default() -> String { "localhost".into() } + + fn port_default() -> u16 { 27017 } + + #[test] + fn test_format_url() { + let mut query = HashMap::new(); + query.insert("authSource".to_string(), "admin".to_string()); + query.insert("directConnection".to_string(), "true".to_string()); + + let config = MongoDbConfig { + username: "user".to_string(), + password: "password".to_string(), + host: "localhost".to_string(), + port: 27017, + db_name: "mydb".to_string(), + query, + }; + + let expected_url = "mongodb://user:password@localhost:27017/mydb?\ + authSource=admin&directConnection=true"; + let result = format_url(&config); + assert_eq!(result, expected_url); + } }