From 3dd582097dd42dd85fc49ffaa8a4e6a9e48b3b05 Mon Sep 17 00:00:00 2001 From: Christopher Schramm Date: Mon, 21 Mar 2022 17:54:57 +0100 Subject: [PATCH] Client-side TLS 1.2 RFC 5746 Secure Renegotiation --- rustls/Cargo.toml | 1 + rustls/src/client/hs.rs | 36 ++++++--- rustls/src/client/tls12.rs | 144 +++++++++++++++++++++++++++-------- rustls/src/common_state.rs | 21 +++-- rustls/src/error.rs | 1 + rustls/src/msgs/deframer.rs | 2 - rustls/src/msgs/handshake.rs | 3 + rustls/src/record_layer.rs | 4 + 8 files changed, 160 insertions(+), 52 deletions(-) diff --git a/rustls/Cargo.toml b/rustls/Cargo.toml index 4ec52f860d..143257b90c 100644 --- a/rustls/Cargo.toml +++ b/rustls/Cargo.toml @@ -31,6 +31,7 @@ aws_lc_rs = ["dep:aws-lc-rs", "webpki/aws_lc_rs"] ring = ["dep:ring", "webpki/ring"] tls12 = [] read_buf = ["rustversion"] +renegotiation = [] [dev-dependencies] base64 = "0.21" diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs index 26ce6383f3..9b4edfe17d 100644 --- a/rustls/src/client/hs.rs +++ b/rustls/src/client/hs.rs @@ -9,7 +9,7 @@ use crate::error::{Error, PeerIncompatible, PeerMisbehaved}; use crate::hash_hs::HandshakeHashBuffer; #[cfg(feature = "logging")] use crate::log::{debug, trace}; -use crate::msgs::base::Payload; +use crate::msgs::base::{Payload, PayloadU8}; use crate::msgs::enums::{Compression, ExtensionType}; use crate::msgs::enums::{ECPointFormat, PSKKeyExchangeMode}; use crate::msgs::handshake::ConvertProtocolNameList; @@ -104,7 +104,11 @@ pub(super) fn start_handshake( transcript_buffer.set_client_auth_enabled(); } - let mut resuming = find_session(&server_name, &config, cx); + let mut resuming = if cx.common.client_verify_data.is_empty() { + find_session(&server_name, &config, cx) + } else { + None + }; let key_share = if config.supports_version(ProtocolVersion::TLSv1_3) { Some(tls13::initial_key_share(&config, &server_name)?) @@ -271,12 +275,6 @@ fn emit_client_hello_for_retry( // Do we have a SessionID or ticket cached for this host? let tls13_session = prepare_resumption(&input.resuming, &mut exts, suite, cx, config); - // Note what extensions we sent. - input.hello.sent_extensions = exts - .iter() - .map(ClientExtension::get_type) - .collect(); - let mut cipher_suites: Vec<_> = config .provider .cipher_suites @@ -286,8 +284,20 @@ fn emit_client_hello_for_retry( false => None, }) .collect(); - // We don't do renegotiation at all, in fact. - cipher_suites.push(CipherSuite::TLS_EMPTY_RENEGOTIATION_INFO_SCSV); + + if cx.common.client_verify_data.is_empty() { + cipher_suites.push(CipherSuite::TLS_EMPTY_RENEGOTIATION_INFO_SCSV); + } else { + exts.push(ClientExtension::RenegotiationInfo(PayloadU8::new( + cx.common.client_verify_data.clone(), + ))); + } + + // Note what extensions we sent. + input.hello.sent_extensions = exts + .iter() + .map(ClientExtension::get_type) + .collect(); let mut chp = HandshakeMessagePayload { typ: HandshakeType::ClientHello, @@ -312,7 +322,7 @@ fn emit_client_hello_for_retry( // "This value MUST be set to 0x0303 for all records generated // by a TLS 1.3 implementation other than an initial ClientHello // (i.e., one not generated after a HelloRetryRequest)" - version: if retryreq.is_some() { + version: if retryreq.is_some() || !cx.common.client_verify_data.is_empty() { ProtocolVersion::TLSv1_2 } else { ProtocolVersion::TLSv1_0 @@ -329,7 +339,8 @@ fn emit_client_hello_for_retry( trace!("Sending ClientHello {:#?}", ch); transcript_buffer.add_message(&ch); - cx.common.send_msg(ch, false); + cx.common + .send_msg(ch, !cx.common.client_verify_data.is_empty()); // Calculate the hash of ClientHello and use it to derive EarlyTrafficSecret let early_key_schedule = early_key_schedule.map(|(resuming_suite, schedule)| { @@ -662,6 +673,7 @@ impl State for ExpectServerHello { randoms, using_ems: self.input.using_ems, transcript, + message_decrypter: None, } .handle_server_hello(cx, suite, server_hello, tls13_supported) } diff --git a/rustls/src/client/tls12.rs b/rustls/src/client/tls12.rs index d572ffda2c..f8afbb36e9 100644 --- a/rustls/src/client/tls12.rs +++ b/rustls/src/client/tls12.rs @@ -1,6 +1,7 @@ use crate::check::{inappropriate_handshake_message, inappropriate_message}; use crate::common_state::{CommonState, Side, State}; use crate::conn::ConnectionRandoms; +use crate::crypto::cipher::MessageDecrypter; use crate::enums::ProtocolVersion; use crate::enums::{AlertDescription, ContentType, HandshakeType}; use crate::error::{Error, InvalidMessage, PeerMisbehaved}; @@ -41,6 +42,7 @@ pub(super) use server_hello::CompleteServerHelloHandling; mod server_hello { use crate::msgs::enums::ExtensionType; use crate::msgs::handshake::HasServerExtensions; + use crate::msgs::handshake::ServerExtension; use crate::msgs::handshake::ServerHelloPayload; use super::*; @@ -52,6 +54,7 @@ mod server_hello { pub(in crate::client) randoms: ConnectionRandoms, pub(in crate::client) using_ems: bool, pub(in crate::client) transcript: HandshakeHash, + pub(in crate::client) message_decrypter: Option>, } impl CompleteServerHelloHandling { @@ -102,6 +105,35 @@ mod server_hello { debug!("Server may staple OCSP response"); } + let renegotiation_info = server_hello + .find_extension(ExtensionType::RenegotiationInfo) + .map(|ext| match ext { + ServerExtension::RenegotiationInfo(renegotiation_info) => renegotiation_info, + _ => unreachable!(), + }); + + if cx.common.secure_renegotiation.is_none() { + if renegotiation_info.map_or(false, |it| !it.0.is_empty()) { + return Err(cx.common.send_fatal_alert( + AlertDescription::HandshakeFailure, + Error::PeerMisbehaved(PeerMisbehaved::NonEmptyRenegotiationInfo), + )); + } + + cx.common.secure_renegotiation = Some(renegotiation_info.is_some()); + } else if renegotiation_info.map_or(true, |it| { + it.0 != [ + &cx.common.client_verify_data[..], + &cx.common.server_verify_data[..], + ] + .concat() + }) { + return Err(cx.common.send_fatal_alert( + AlertDescription::HandshakeFailure, + Error::PeerMisbehaved(PeerMisbehaved::NonEmptyRenegotiationInfo), + )); + } + // See if we're successfully resuming. if let Some(resuming) = self.resuming_session { if resuming.session_id == server_hello.session_id { @@ -145,6 +177,7 @@ mod server_hello { resuming: true, cert_verified, sig_verified, + message_decrypter: self.message_decrypter, })) } else { Ok(Box::new(ExpectCcs { @@ -159,6 +192,7 @@ mod server_hello { resuming: true, cert_verified, sig_verified, + message_decrypter: self.message_decrypter, })) }; } @@ -416,6 +450,7 @@ fn emit_certificate( transcript: &mut HandshakeHash, cert_chain: CertificateChain, common: &mut CommonState, + must_encrypt: bool, ) { let cert = Message { version: ProtocolVersion::TLSv1_2, @@ -426,10 +461,15 @@ fn emit_certificate( }; transcript.add_message(&cert); - common.send_msg(cert, false); + common.send_msg(cert, must_encrypt); } -fn emit_clientkx(transcript: &mut HandshakeHash, common: &mut CommonState, pub_key: &[u8]) { +fn emit_clientkx( + transcript: &mut HandshakeHash, + common: &mut CommonState, + pub_key: &[u8], + must_encrypt: bool, +) { let mut buf = Vec::new(); let ecpoint = PayloadU8::new(Vec::from(pub_key)); ecpoint.encode(&mut buf); @@ -444,13 +484,14 @@ fn emit_clientkx(transcript: &mut HandshakeHash, common: &mut CommonState, pub_k }; transcript.add_message(&ckx); - common.send_msg(ckx, false); + common.send_msg(ckx, must_encrypt); } fn emit_certverify( transcript: &mut HandshakeHash, signer: &dyn Signer, common: &mut CommonState, + must_encrypt: bool, ) -> Result<(), Error> { let message = transcript .take_handshake_buf() @@ -469,17 +510,17 @@ fn emit_certverify( }; transcript.add_message(&m); - common.send_msg(m, false); + common.send_msg(m, must_encrypt); Ok(()) } -fn emit_ccs(common: &mut CommonState) { +fn emit_ccs(common: &mut CommonState, must_encrypt: bool) { let ccs = Message { version: ProtocolVersion::TLSv1_2, payload: MessagePayload::ChangeCipherSpec(ChangeCipherSpecPayload {}), }; - common.send_msg(ccs, false); + common.send_msg(ccs, must_encrypt); } fn emit_finished( @@ -488,8 +529,8 @@ fn emit_finished( common: &mut CommonState, ) { let vh = transcript.get_current_hash(); - let verify_data = secrets.client_verify_data(&vh); - let verify_data_payload = Payload::new(verify_data); + common.client_verify_data = secrets.client_verify_data(&vh); + let verify_data_payload = Payload::new(common.client_verify_data.clone()); let f = Message { version: ProtocolVersion::TLSv1_2, @@ -752,13 +793,15 @@ impl State for ExpectServerDone { }; cx.common.peer_certificates = Some(st.server_cert.cert_chain); + let must_encrypt = !cx.common.client_verify_data.is_empty(); + // 4. if let Some(client_auth) = &st.client_auth { let certs = match client_auth { ClientAuthDetails::Empty { .. } => CertificateChain::default(), ClientAuthDetails::Verify { certkey, .. } => CertificateChain(certkey.cert.clone()), }; - emit_certificate(&mut st.transcript, certs, cx.common); + emit_certificate(&mut st.transcript, certs, cx.common, must_encrypt); } // 5a. @@ -777,7 +820,7 @@ impl State for ExpectServerDone { // 5b. let mut transcript = st.transcript; - emit_clientkx(&mut transcript, cx.common, kx.pub_key()); + emit_clientkx(&mut transcript, cx.common, kx.pub_key(), must_encrypt); // Note: EMS handshake hash only runs up to ClientKeyExchange. let ems_seed = st .using_ems @@ -785,11 +828,11 @@ impl State for ExpectServerDone { // 5c. if let Some(ClientAuthDetails::Verify { signer, .. }) = &st.client_auth { - emit_certverify(&mut transcript, signer.as_ref(), cx.common)?; + emit_certverify(&mut transcript, signer.as_ref(), cx.common, must_encrypt)?; } // 5d. - emit_ccs(cx.common); + emit_ccs(cx.common, must_encrypt); // 5e. Now commit secrets. let secrets = ConnectionSecrets::from_key_exchange( @@ -805,12 +848,25 @@ impl State for ExpectServerDone { &secrets.randoms.client, &secrets.master_secret, ); + + let (dec, enc) = secrets.make_cipher_pair(Side::Client); + cx.common - .start_encryption_tls12(&secrets, Side::Client); + .record_layer + .prepare_message_encrypter(enc); cx.common .record_layer .start_encrypting(); + let message_decrypter = if cx.common.record_layer.is_decrypting() { + Some(dec) + } else { + cx.common + .record_layer + .prepare_message_decrypter(dec); + None + }; + // 6. emit_finished(&secrets, &mut transcript, cx.common); @@ -826,6 +882,7 @@ impl State for ExpectServerDone { resuming: false, cert_verified, sig_verified, + message_decrypter, })) } else { Ok(Box::new(ExpectCcs { @@ -840,6 +897,7 @@ impl State for ExpectServerDone { resuming: false, cert_verified, sig_verified, + message_decrypter, })) } } @@ -856,6 +914,7 @@ struct ExpectNewTicket { resuming: bool, cert_verified: verify::ServerCertVerified, sig_verified: verify::HandshakeSignatureValid, + message_decrypter: Option>, } impl State for ExpectNewTicket { @@ -884,6 +943,7 @@ impl State for ExpectNewTicket { resuming: self.resuming, cert_verified: self.cert_verified, sig_verified: self.sig_verified, + message_decrypter: self.message_decrypter, })) } } @@ -901,6 +961,7 @@ struct ExpectCcs { resuming: bool, cert_verified: verify::ServerCertVerified, sig_verified: verify::HandshakeSignatureValid, + message_decrypter: Option>, } impl State for ExpectCcs { @@ -918,6 +979,12 @@ impl State for ExpectCcs { // message. cx.common.check_aligned_handshake()?; + if let Some(dec) = self.message_decrypter { + cx.common + .record_layer + .prepare_message_decrypter(dec); + } + // Note: msgs layer validates trivial contents of CCS. cx.common .record_layer @@ -1005,19 +1072,19 @@ impl State for ExpectFinished { // Work out what verify_data we expect. let vh = st.transcript.get_current_hash(); - let expect_verify_data = st.secrets.server_verify_data(&vh); + cx.common.server_verify_data = st.secrets.server_verify_data(&vh); // Constant-time verification of this is relatively unimportant: they only // get one chance. But it can't hurt. - let _fin_verified = match ConstantTimeEq::ct_eq(&expect_verify_data[..], &finished.0).into() - { - true => verify::FinishedMessageVerified::assertion(), - false => { - return Err(cx - .common - .send_fatal_alert(AlertDescription::DecryptError, Error::DecryptError)); - } - }; + let _fin_verified = + match ConstantTimeEq::ct_eq(&cx.common.server_verify_data[..], &finished.0).into() { + true => verify::FinishedMessageVerified::assertion(), + false => { + return Err(cx + .common + .send_fatal_alert(AlertDescription::DecryptError, Error::DecryptError)); + } + }; // Hash this message too. st.transcript.add_message(&m); @@ -1025,7 +1092,7 @@ impl State for ExpectFinished { st.save_session(cx); if st.resuming { - emit_ccs(cx.common); + emit_ccs(cx.common, false); cx.common .record_layer .start_encrypting(); @@ -1038,6 +1105,8 @@ impl State for ExpectFinished { _cert_verified: st.cert_verified, _sig_verified: st.sig_verified, _fin_verified, + server_name: st.server_name, + config: st.config, })) } @@ -1060,22 +1129,31 @@ struct ExpectTraffic { _cert_verified: verify::ServerCertVerified, _sig_verified: verify::HandshakeSignatureValid, _fin_verified: verify::FinishedMessageVerified, + server_name: ServerName<'static>, + config: Arc, } impl State for ExpectTraffic { fn handle(self: Box, cx: &mut ClientContext<'_>, m: Message) -> hs::NextStateOrError { match m.payload { - MessagePayload::ApplicationData(payload) => cx - .common - .take_received_plaintext(payload), - payload => { - return Err(inappropriate_message( - &payload, - &[ContentType::ApplicationData], - )); + MessagePayload::ApplicationData(payload) => { + cx.common + .take_received_plaintext(payload); + Ok(self) } + MessagePayload::Handshake { + parsed: + HandshakeMessagePayload { + payload: HandshakePayload::HelloRequest, + .. + }, + .. + } => hs::start_handshake(self.server_name, vec![], self.config, cx), + payload => Err(inappropriate_message( + &payload, + &[ContentType::ApplicationData], + )), } - Ok(self) } fn export_keying_material( diff --git a/rustls/src/common_state.rs b/rustls/src/common_state.rs index bc9cea92ce..ec3d62e37d 100644 --- a/rustls/src/common_state.rs +++ b/rustls/src/common_state.rs @@ -49,6 +49,9 @@ pub struct CommonState { pub(crate) protocol: Protocol, pub(crate) quic: quic::Quic, pub(crate) enable_secret_extraction: bool, + pub(crate) secure_renegotiation: Option, + pub(crate) client_verify_data: Vec, + pub(crate) server_verify_data: Vec, } impl CommonState { @@ -76,6 +79,9 @@ impl CommonState { protocol: Protocol::Tcp, quic: quic::Quic::default(), enable_secret_extraction: false, + secure_renegotiation: None, + client_verify_data: Vec::new(), + server_verify_data: Vec::new(), } } @@ -153,12 +159,17 @@ impl CommonState { // renegotiation requests. These can occur any time. if self.may_receive_application_data && !self.is_tls13() { let reject_ty = match self.side { - Side::Client => HandshakeType::HelloRequest, - Side::Server => HandshakeType::ClientHello, + #[cfg(feature = "renegotiation")] + Side::Client => None, + #[cfg(not(feature = "renegotiation"))] + Side::Client => Some(HandshakeType::HelloRequest), + Side::Server => Some(HandshakeType::ClientHello), }; - if msg.is_handshake_type(reject_ty) { - self.send_warning_alert(AlertDescription::NoRenegotiation); - return Ok(state); + if let Some(reject_ty) = reject_ty { + if msg.is_handshake_type(reject_ty) { + self.send_warning_alert(AlertDescription::NoRenegotiation); + return Ok(state); + } } } diff --git a/rustls/src/error.rs b/rustls/src/error.rs index 7d692b7f32..8760f3784e 100644 --- a/rustls/src/error.rs +++ b/rustls/src/error.rs @@ -205,6 +205,7 @@ pub enum PeerMisbehaved { MissingKeyShare, MissingPskModesExtension, MissingQuicTransportParameters, + NonEmptyRenegotiationInfo, OfferedDuplicateKeyShares, OfferedEarlyDataWithOldProtocolVersion, OfferedEmptyApplicationProtocol, diff --git a/rustls/src/msgs/deframer.rs b/rustls/src/msgs/deframer.rs index 1760340293..c8880e3a7c 100644 --- a/rustls/src/msgs/deframer.rs +++ b/rustls/src/msgs/deframer.rs @@ -90,8 +90,6 @@ impl MessageDeframer { let end = start + rd.used(); let version_is_tls13 = matches!(negotiated_version, Some(ProtocolVersion::TLSv1_3)); let allowed_plaintext = match m.typ { - // CCS messages are always plaintext. - ContentType::ChangeCipherSpec => true, // Alerts are allowed to be plaintext if-and-only-if: // * The negotiated protocol version is TLS 1.3. - In TLS 1.2 it is unambiguous when // keying changes based on the CCS message. Only TLS 1.3 requires these heuristics. diff --git a/rustls/src/msgs/handshake.rs b/rustls/src/msgs/handshake.rs index a0fb927390..4b903f9280 100644 --- a/rustls/src/msgs/handshake.rs +++ b/rustls/src/msgs/handshake.rs @@ -555,6 +555,7 @@ pub enum ClientExtension { TransportParameters(Vec), TransportParametersDraft(Vec), EarlyData, + RenegotiationInfo(PayloadU8), Unknown(UnknownExtension), } @@ -577,6 +578,7 @@ impl ClientExtension { Self::TransportParameters(_) => ExtensionType::TransportParameters, Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft, Self::EarlyData => ExtensionType::EarlyData, + Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo, Self::Unknown(ref r) => r.typ, } } @@ -606,6 +608,7 @@ impl Codec for ClientExtension { Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => { nested.buf.extend_from_slice(r); } + Self::RenegotiationInfo(ref r) => r.encode(nested.buf), Self::Unknown(ref r) => r.encode(nested.buf), } } diff --git a/rustls/src/record_layer.rs b/rustls/src/record_layer.rs index 3f8f3ed4ba..839c9cadcb 100644 --- a/rustls/src/record_layer.rs +++ b/rustls/src/record_layer.rs @@ -53,6 +53,10 @@ impl RecordLayer { } } + pub(crate) fn is_decrypting(&self) -> bool { + self.decrypt_state == DirectionState::Active + } + /// Decrypt a TLS message. /// /// `encr` is a decoded message allegedly received from the peer.