use axum::http::{HeaderName, HeaderValue, Uri, header, request::Parts}; use itertools::Itertools; use url::Url; /// Fields from a "Forwarded" header per [RFC7239 sec 4](https://www.rfc-editor.org/rfc/rfc7239#section-4) #[derive(Debug, Clone)] pub struct ForwardedHeader { pub for_field: Vec, pub by: Option, pub host: Option, pub proto: Option, } impl ForwardedHeader { /// Return the 'for' headers as a list of [std::net::IpAddr]'s. pub fn for_as_ipaddr(self) -> Vec { self.for_field .iter() .filter_map(|ip| { if ip.contains(']') { // this is an IPv6 address, get what's between the [] ip.split(']') .next()? .split('[') .next_back()? .parse::() .ok() } else { ip.parse::().ok() } }) .collect::>() } } /// This parses the Forwarded header, and returns a list of the IPs in the /// "for=" fields. Per [RFC7239 sec 4](https://www.rfc-editor.org/rfc/rfc7239#section-4) impl TryFrom for ForwardedHeader { type Error = String; fn try_from(forwarded: HeaderValue) -> Result { ForwardedHeader::try_from(&forwarded) } } /// This parses the Forwarded header, and returns a list of the IPs in the /// "for=" fields. Per [RFC7239 sec 4](https://www.rfc-editor.org/rfc/rfc7239#section-4) impl TryFrom<&HeaderValue> for ForwardedHeader { type Error = String; fn try_from(forwarded: &HeaderValue) -> Result { let mut for_field: Vec = Vec::new(); let mut by: Option = None; let mut host: Option = None; let mut proto: Option = None; // first get the k=v pairs forwarded .to_str() .map_err(|err| err.to_string())? .split(';') .for_each(|s| { let s = s.trim().to_lowercase(); // The for value can look like this: // for=192.0.2.43, for=198.51.100.17 // so we need to handle this case if s.starts_with("for=") || s.starts_with("for =") { // we have a valid thing to grab let chunks: Vec = s .split(',') .filter_map(|chunk| { chunk.trim().split('=').next_back().map(|c| c.to_string()) }) .collect::>(); for_field.extend(chunks); } else if s.starts_with("by=") { by = s.split('=').next_back().map(|c| c.to_string()); } else if s.starts_with("host=") { host = s.split('=').next_back().map(|c| c.to_string()); } else if s.starts_with("proto=") { proto = s.split('=').next_back().map(|c| c.to_string()); } else { // probably need to work out what to do here } }); Ok(ForwardedHeader { for_field, by, host, proto, }) } } #[derive(Clone, Debug)] pub struct ForwardedRelatedInfo { pub forwarded: Option, pub x_forwarded_proto: Option, pub x_forwarded_host: Option, pub x_forwarded_for: Option>, pub host: Option, pub uri: Uri, pub origin: Option, } impl ForwardedRelatedInfo { pub fn from_request_parts(request_parts: &Parts) -> ForwardedRelatedInfo { let headers = &request_parts.headers; let forwarded = headers .get(header::FORWARDED) .and_then(|s| ForwardedHeader::try_from(s.clone()).ok()); let x_forwarded_proto = headers .get(HeaderName::from_static("x-forwarded-proto")) .and_then(|s| s.to_str().map(String::from).ok()); let x_forwarded_host = headers .get(HeaderName::from_static("x-forwarded-host")) .and_then(|s| s.to_str().map(String::from).ok()); let x_forwarded_for = headers .get(HeaderName::from_static("x-forwarded-for")) .and_then(|s| s.to_str().ok()) .and_then(|s| { let l = s.split(",").map(|s| s.trim().to_string()).collect_vec(); if l.is_empty() { None } else { Some(l) } }); let host = headers .get(header::HOST) .and_then(|s| s.to_str().map(String::from).ok()); let origin = headers .get(header::ORIGIN) .and_then(|s| s.to_str().map(String::from).ok()); ForwardedRelatedInfo { host, x_forwarded_for, x_forwarded_host, x_forwarded_proto, forwarded, uri: request_parts.uri.clone(), origin, } } pub fn resolved_protocol(&self) -> Option<&str> { self.forwarded .as_ref() .and_then(|s| s.proto.as_deref()) .or(self.x_forwarded_proto.as_deref()) .or(self.uri.scheme_str()) } pub fn resolved_host(&self) -> Option<&str> { self.forwarded .as_ref() .and_then(|s| s.host.as_deref()) .or(self.x_forwarded_host.as_deref()) .or(self.uri.host()) } pub fn resolved_origin(&self) -> Option { if let (Some(protocol), Some(host)) = (self.resolved_protocol(), self.resolved_host()) { let origin = format!("{protocol}://{host}"); Url::parse(&origin).ok() } else { None } } }