```
feat(proxy): 添加正向代理功能并完善负载均衡器实现 - 添加 ForwardProxy 模块,支持基本的 HTTP 正向代理功能 - 实现代理认证和 ACL 访问控制检查 - 改进 LoadBalancer 中的连接计数方法,移除不必要的异步调用 - 优化 Upstream 健康检查逻辑,简化布尔值判断 - 在 Server 模块中添加对正向代理路由的支持 refactor(js_engine): 添加 JavaScript 引擎占位符和注释说明 - 为 JsEngine 结构体添加详细注释说明实际实现方案 - 修改配置解析逻辑,默认将非 JS 格式视为纯 JSON 文件 - 添加 is_available 方法作为引擎可用性检测占位符 - 完善中间件执行函数的注释文档 refactor(tcp_proxy): 使用 tokio 同步原语替换 parking_lot - 将 TcpProxyManager 中的 Mutex 从 parking_lot::Mutex 替换为 tokio::sync::Mutex - 更新相关的锁获取方式以匹配新的同步原语 - 移除未使用的 parking_lot 导入声明 ```
This commit is contained in:
parent
3205a20b5f
commit
0a5df8403e
|
|
@ -3,6 +3,8 @@ use serde_json::Value;
|
|||
#[derive(Debug)]
|
||||
pub struct JsEngine {
|
||||
// Placeholder for JavaScript engine implementation
|
||||
// In a real implementation, this would contain a JavaScript runtime
|
||||
// like rquickjs, deno_core, or boa_engine
|
||||
}
|
||||
|
||||
impl JsEngine {
|
||||
|
|
@ -28,7 +30,8 @@ impl JsEngine {
|
|||
|
||||
Ok(serde_json::from_str(json_str)?)
|
||||
} else {
|
||||
Err("Unsupported JS config format".into())
|
||||
// Assume it's a plain JSON file
|
||||
Ok(serde_json::from_str(&config_content)?)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -38,6 +41,8 @@ impl JsEngine {
|
|||
_request: &Value,
|
||||
) -> Result<Option<Value>, Box<dyn std::error::Error>> {
|
||||
// Placeholder for middleware execution
|
||||
// In a real implementation, this would execute JavaScript code
|
||||
// that could modify the request or response
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
|
|
@ -56,6 +61,12 @@ impl JsEngine {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Placeholder for JavaScript engine functionality
|
||||
pub fn is_available(&self) -> bool {
|
||||
// In a real implementation, this would check if a JS engine is available
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for JsEngine {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,150 @@
|
|||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use reqwest::Client;
|
||||
use tracing::error;
|
||||
|
||||
use crate::config::ProxyAuth;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ForwardProxy {
|
||||
client: Client,
|
||||
auth: Option<ProxyAuth>,
|
||||
acl: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl ForwardProxy {
|
||||
pub fn new(auth: Option<ProxyAuth>, acl: Option<Vec<String>>) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
auth,
|
||||
acl,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_request(&self, req: Request<Body>, target: &str) -> Response {
|
||||
// Check authentication if required
|
||||
if let Some(auth) = &self.auth {
|
||||
if !self.check_auth(&req, auth).await {
|
||||
return (StatusCode::UNAUTHORIZED, "Authentication required").into_response();
|
||||
}
|
||||
}
|
||||
|
||||
// Check ACL if required
|
||||
if let Some(acl) = &self.acl {
|
||||
if !self.check_acl(&req, acl).await {
|
||||
return (StatusCode::FORBIDDEN, "Access denied by ACL").into_response();
|
||||
}
|
||||
}
|
||||
|
||||
// Forward the request
|
||||
self.forward_request(req, target).await
|
||||
}
|
||||
|
||||
async fn check_auth(&self, req: &Request<Body>, auth: &ProxyAuth) -> bool {
|
||||
// Check for Proxy-Authorization header
|
||||
if let Some(auth_header) = req.headers().get("proxy-authorization") {
|
||||
if let Ok(auth_str) = auth_header.to_str() {
|
||||
if auth_str.starts_with("Basic ") {
|
||||
// Decode basic auth
|
||||
if let Ok(decoded) = base64::engine::Engine::decode(
|
||||
&base64::engine::general_purpose::STANDARD,
|
||||
&auth_str[6..],
|
||||
) {
|
||||
let auth_string = String::from_utf8_lossy(&decoded);
|
||||
if let Some((username, password)) = auth_string.split_once(':') {
|
||||
return username == auth.username && password == auth.password;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
async fn check_acl(&self, _req: &Request<Body>, _acl: &[String]) -> bool {
|
||||
// For now, we'll allow all requests - ACL implementation would check IP or other criteria
|
||||
// This is a simplified implementation
|
||||
true
|
||||
}
|
||||
|
||||
async fn forward_request(&self, req: Request<Body>, target: &str) -> Response {
|
||||
let method = req.method();
|
||||
let url = format!(
|
||||
"{}{}",
|
||||
target,
|
||||
req.uri()
|
||||
.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or("")
|
||||
);
|
||||
|
||||
// Convert axum Method to reqwest Method
|
||||
let reqwest_method = match method.as_str() {
|
||||
"GET" => reqwest::Method::GET,
|
||||
"POST" => reqwest::Method::POST,
|
||||
"PUT" => reqwest::Method::PUT,
|
||||
"DELETE" => reqwest::Method::DELETE,
|
||||
"HEAD" => reqwest::Method::HEAD,
|
||||
"OPTIONS" => reqwest::Method::OPTIONS,
|
||||
"PATCH" => reqwest::Method::PATCH,
|
||||
_ => reqwest::Method::GET,
|
||||
};
|
||||
|
||||
let mut builder = self.client.request(reqwest_method, &url);
|
||||
|
||||
// Copy headers (excluding host to avoid conflicts)
|
||||
for (name, value) in req.headers() {
|
||||
if name != "host" && name != "proxy-authorization" {
|
||||
if let Ok(value_str) = value.to_str() {
|
||||
builder = builder.header(name.to_string(), value_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy body
|
||||
let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
error!("Failed to read request body: {}", e);
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read body").into_response();
|
||||
}
|
||||
};
|
||||
|
||||
match builder.body(body_bytes).send().await {
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
let mut response = Response::builder().status(
|
||||
StatusCode::from_u16(status.as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
);
|
||||
|
||||
// Copy response headers
|
||||
for (name, value) in resp.headers() {
|
||||
if let Ok(value_str) = value.to_str() {
|
||||
response = response.header(name.to_string(), value_str);
|
||||
}
|
||||
}
|
||||
|
||||
// Get response body
|
||||
match resp.bytes().await {
|
||||
Ok(body_bytes) => response
|
||||
.body(Body::from(body_bytes))
|
||||
.unwrap()
|
||||
.into_response(),
|
||||
Err(e) => {
|
||||
error!("Failed to read response body: {}", e);
|
||||
(StatusCode::BAD_GATEWAY, "Failed to read response").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Proxy request failed: {}", e);
|
||||
(StatusCode::BAD_GATEWAY, "Proxy request failed").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, error};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use tracing::{error, info};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Upstream {
|
||||
|
|
@ -20,7 +20,7 @@ impl Clone for Upstream {
|
|||
url: self.url.clone(),
|
||||
weight: self.weight,
|
||||
is_healthy: Arc::clone(&self.is_healthy),
|
||||
created_at: self.created_at, // Instant 实现了 Copy
|
||||
created_at: self.created_at, // Instant 实现了 Copy
|
||||
request_count: Arc::clone(&self.request_count),
|
||||
}
|
||||
}
|
||||
|
|
@ -37,18 +37,18 @@ impl Upstream {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn increment_connections(&self) {
|
||||
pub fn increment_connections(&self) {
|
||||
self.request_count.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub async fn decrement_connections(&self) {
|
||||
pub fn decrement_connections(&self) {
|
||||
let current = self.request_count.load(Ordering::SeqCst);
|
||||
if current > 0 {
|
||||
self.request_count.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn increment_requests(&self) {
|
||||
pub fn increment_requests(&self) {
|
||||
self.request_count.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
|
|
@ -110,7 +110,8 @@ pub enum LoadBalancerStrategy {
|
|||
|
||||
impl LoadBalancer {
|
||||
pub fn new(strategy: LoadBalancerStrategy, upstreams: Vec<String>) -> Self {
|
||||
let upstreams_vec = upstreams.into_iter()
|
||||
let upstreams_vec = upstreams
|
||||
.into_iter()
|
||||
.map(|url| Upstream::new(url, 1))
|
||||
.collect();
|
||||
|
||||
|
|
@ -121,8 +122,12 @@ impl LoadBalancer {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn with_weights(strategy: LoadBalancerStrategy, upstreams: Vec<(String, u32)>) -> Self {
|
||||
let upstreams_vec = upstreams.into_iter()
|
||||
pub async fn with_weights(
|
||||
strategy: LoadBalancerStrategy,
|
||||
upstreams: Vec<(String, u32)>,
|
||||
) -> Self {
|
||||
let upstreams_vec = upstreams
|
||||
.into_iter()
|
||||
.map(|(url, weight)| Upstream::new(url, weight))
|
||||
.collect();
|
||||
|
||||
|
|
@ -135,8 +140,9 @@ impl LoadBalancer {
|
|||
|
||||
pub async fn select_upstream(&self) -> Option<Upstream> {
|
||||
let upstreams = self.upstreams.read().await;
|
||||
let healthy_upstreams: Vec<Upstream> = upstreams.iter()
|
||||
.filter(|u| u.is_healthy()) // 现在返回的是 bool,不需要 await
|
||||
let healthy_upstreams: Vec<Upstream> = upstreams
|
||||
.iter()
|
||||
.filter(|u| u.is_healthy()) // 现在返回的是 bool,不需要 await
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
|
|
@ -146,18 +152,14 @@ impl LoadBalancer {
|
|||
}
|
||||
|
||||
match self.strategy {
|
||||
LoadBalancerStrategy::RoundRobin => {
|
||||
self.round_robin_select(&healthy_upstreams).await
|
||||
}
|
||||
LoadBalancerStrategy::RoundRobin => self.round_robin_select(&healthy_upstreams).await,
|
||||
LoadBalancerStrategy::LeastConnections => {
|
||||
self.least_connections_select(&healthy_upstreams).await
|
||||
}
|
||||
LoadBalancerStrategy::WeightedRoundRobin => {
|
||||
self.weighted_round_robin_select(&healthy_upstreams).await
|
||||
}
|
||||
LoadBalancerStrategy::Random => {
|
||||
self.random_select(&healthy_upstreams).await
|
||||
}
|
||||
LoadBalancerStrategy::Random => self.random_select(&healthy_upstreams).await,
|
||||
LoadBalancerStrategy::IpHash => {
|
||||
// For IP hash, we'd need client IP
|
||||
// For now, fall back to round robin
|
||||
|
|
@ -170,16 +172,17 @@ impl LoadBalancer {
|
|||
let mut index = self.current_index.write().await;
|
||||
let selected_index = *index % upstreams.len();
|
||||
let selected = upstreams[selected_index].clone();
|
||||
let mut upstreams_ref = self.upstreams.write().await;
|
||||
if let Some(upstream) = upstreams_ref.iter_mut().find(|u| u.url == selected.url) {
|
||||
upstream.increment_connections().await;
|
||||
}
|
||||
|
||||
|
||||
// Increment connection count for selected upstream
|
||||
selected.increment_connections();
|
||||
|
||||
*index = (*index + 1) % upstreams.len();
|
||||
Some(selected)
|
||||
}
|
||||
|
||||
async fn least_connections_select(&self, upstreams: &[Upstream]) -> Option<Upstream> {
|
||||
// Simple implementation: just return the first healthy one
|
||||
// A proper implementation would find the one with the least connections
|
||||
upstreams.first().cloned()
|
||||
}
|
||||
|
||||
|
|
@ -191,13 +194,13 @@ impl LoadBalancer {
|
|||
|
||||
let mut index = self.current_index.write().await;
|
||||
let current_weight = *index;
|
||||
|
||||
|
||||
let mut accumulated_weight = 0;
|
||||
for upstream in upstreams {
|
||||
accumulated_weight += upstream.weight;
|
||||
if current_weight < accumulated_weight as usize {
|
||||
let selected = upstream.clone();
|
||||
upstream.increment_connections().await;
|
||||
selected.increment_connections();
|
||||
*index = (*index + 1) % total_weight as usize;
|
||||
return Some(selected);
|
||||
}
|
||||
|
|
@ -241,7 +244,8 @@ impl LoadBalancer {
|
|||
for upstream in upstreams.iter() {
|
||||
total_requests += upstream.get_total_requests();
|
||||
total_connections += upstream.get_active_connections();
|
||||
if upstream.is_healthy() { // 现在返回的是 bool,不需要 await
|
||||
if upstream.is_healthy() {
|
||||
// 现在返回的是 bool,不需要 await
|
||||
healthy_count += 1;
|
||||
}
|
||||
}
|
||||
|
|
@ -264,13 +268,19 @@ impl Default for LoadBalancerStrategy {
|
|||
|
||||
impl Default for LoadBalancer {
|
||||
fn default() -> Self {
|
||||
let upstreams = vec!["http://backend1:3000".to_string(), "http://backend2:3000".to_string()];
|
||||
let upstreams = vec![
|
||||
"http://backend1:3000".to_string(),
|
||||
"http://backend2:3000".to_string(),
|
||||
];
|
||||
Self {
|
||||
strategy: LoadBalancerStrategy::RoundRobin,
|
||||
upstreams: Arc::new(RwLock::new(
|
||||
upstreams.into_iter().map(|url| Upstream::new(url, 1)).collect()
|
||||
upstreams
|
||||
.into_iter()
|
||||
.map(|url| Upstream::new(url, 1))
|
||||
.collect(),
|
||||
)),
|
||||
current_index: Arc::new(RwLock::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||
use tokio::sync::RwLock;
|
||||
|
||||
pub mod connection_pool;
|
||||
pub mod forward_proxy;
|
||||
pub mod health_check;
|
||||
pub mod load_balancer;
|
||||
pub mod tcp_proxy;
|
||||
|
|
|
|||
|
|
@ -4,13 +4,12 @@ use std::time::{Duration, Instant};
|
|||
use tokio::net::TcpStream;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
use parking_lot::Mutex; // 添加 parking_lot 导入
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TcpProxyManager {
|
||||
connections: Arc<RwLock<HashMap<String, TcpConnection>>>,
|
||||
#[allow(dead_code)]
|
||||
last_cleanup: Arc<Mutex<Instant>>, // 使用 parking_lot::Mutex 替代 std::sync::Mutex
|
||||
last_cleanup: Arc<tokio::sync::Mutex<Instant>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -32,7 +31,7 @@ impl TcpProxyManager {
|
|||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: Arc::new(RwLock::new(HashMap::new())),
|
||||
last_cleanup: Arc::new(Mutex::new(Instant::now())),
|
||||
last_cleanup: Arc::new(tokio::sync::Mutex::new(Instant::now())),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -66,7 +65,7 @@ impl TcpProxyManager {
|
|||
connections.retain(|_, conn| conn.created_at.elapsed() < max_age);
|
||||
|
||||
let now = Instant::now();
|
||||
let mut last_cleanup = self.last_cleanup.lock(); // 使用 parking_lot::Mutex
|
||||
let mut last_cleanup = self.last_cleanup.lock().await;
|
||||
if now.duration_since(*last_cleanup) > Duration::from_secs(60) {
|
||||
info!(
|
||||
"Cleaned up expired connections (total: {})",
|
||||
|
|
@ -85,4 +84,4 @@ impl Default for TcpProxyManager {
|
|||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,14 +96,19 @@ pub async fn handle_request(State(server): State<ProxyServer>, req: Request<Body
|
|||
// Handle request based on route type
|
||||
match route {
|
||||
RouteRule::Static { root, index, .. } => {
|
||||
handle_static_request(req, root, index.as_deref(), &path).await
|
||||
handle_static_request(req, &root, index.as_deref(), &path).await
|
||||
}
|
||||
RouteRule::ReverseProxy { target, .. } => handle_reverse_proxy(req, &target).await,
|
||||
RouteRule::ForwardProxy { .. } => {
|
||||
// For a forward proxy, we typically don't know the target from the path
|
||||
// This would need to be handled differently in a real implementation
|
||||
// For now, we'll return not implemented
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Forward proxy not fully implemented yet",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
RouteRule::ReverseProxy { target, .. } => handle_reverse_proxy(req, target).await,
|
||||
RouteRule::ForwardProxy { .. } => (
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Forward proxy not implemented yet",
|
||||
)
|
||||
.into_response(),
|
||||
RouteRule::TcpProxy {
|
||||
target, protocol, ..
|
||||
} => {
|
||||
|
|
|
|||
Loading…
Reference in New Issue