feat(proxy): 添加正向代理功能并完善负载均衡器实现

- 添加 ForwardProxy 模块,支持基本的 HTTP 正向代理功能
- 实现代理认证和 ACL 访问控制检查
- 改进 LoadBalancer 中的连接计数方法,移除不必要的异步调用
- 优化 Upstream 健康检查逻辑,简化布尔值判断
- 在 Server 模块中添加对正向代理路由的支持

refactor(js_engine): 添加 JavaScript 引擎占位符和注释说明

- 为 JsEngine 结构体添加详细注释说明实际实现方案
- 修改配置解析逻辑,默认将非 JS 格式视为纯 JSON 文件
- 添加 is_available 方法作为引擎可用性检测占位符
- 完善中间件执行函数的注释文档

refactor(tcp_proxy): 使用 tokio 同步原语替换 parking_lot

- 将 TcpProxyManager 中的 Mutex 从 parking_lot::Mutex
  替换为 tokio::sync::Mutex
- 更新相关的锁获取方式以匹配新的同步原语
- 移除未使用的 parking_lot 导入声明
```
This commit is contained in:
kingecg 2026-01-16 21:21:16 +08:00
parent 3205a20b5f
commit 0a5df8403e
6 changed files with 218 additions and 42 deletions

View File

@ -3,6 +3,8 @@ use serde_json::Value;
#[derive(Debug)]
pub struct JsEngine {
// Placeholder for JavaScript engine implementation
// In a real implementation, this would contain a JavaScript runtime
// like rquickjs, deno_core, or boa_engine
}
impl JsEngine {
@ -28,7 +30,8 @@ impl JsEngine {
Ok(serde_json::from_str(json_str)?)
} else {
Err("Unsupported JS config format".into())
// Assume it's a plain JSON file
Ok(serde_json::from_str(&config_content)?)
}
}
@ -38,6 +41,8 @@ impl JsEngine {
_request: &Value,
) -> Result<Option<Value>, Box<dyn std::error::Error>> {
// Placeholder for middleware execution
// In a real implementation, this would execute JavaScript code
// that could modify the request or response
Ok(None)
}
@ -56,6 +61,12 @@ impl JsEngine {
Ok(())
}
// Placeholder for JavaScript engine functionality
pub fn is_available(&self) -> bool {
// In a real implementation, this would check if a JS engine is available
false
}
}
impl Default for JsEngine {

150
src/proxy/forward_proxy.rs Normal file
View File

@ -0,0 +1,150 @@
use axum::{
body::Body,
extract::Request,
http::StatusCode,
response::{IntoResponse, Response},
};
use reqwest::Client;
use tracing::error;
use crate::config::ProxyAuth;
#[derive(Debug)]
pub struct ForwardProxy {
client: Client,
auth: Option<ProxyAuth>,
acl: Option<Vec<String>>,
}
impl ForwardProxy {
pub fn new(auth: Option<ProxyAuth>, acl: Option<Vec<String>>) -> Self {
Self {
client: Client::new(),
auth,
acl,
}
}
pub async fn handle_request(&self, req: Request<Body>, target: &str) -> Response {
// Check authentication if required
if let Some(auth) = &self.auth {
if !self.check_auth(&req, auth).await {
return (StatusCode::UNAUTHORIZED, "Authentication required").into_response();
}
}
// Check ACL if required
if let Some(acl) = &self.acl {
if !self.check_acl(&req, acl).await {
return (StatusCode::FORBIDDEN, "Access denied by ACL").into_response();
}
}
// Forward the request
self.forward_request(req, target).await
}
async fn check_auth(&self, req: &Request<Body>, auth: &ProxyAuth) -> bool {
// Check for Proxy-Authorization header
if let Some(auth_header) = req.headers().get("proxy-authorization") {
if let Ok(auth_str) = auth_header.to_str() {
if auth_str.starts_with("Basic ") {
// Decode basic auth
if let Ok(decoded) = base64::engine::Engine::decode(
&base64::engine::general_purpose::STANDARD,
&auth_str[6..],
) {
let auth_string = String::from_utf8_lossy(&decoded);
if let Some((username, password)) = auth_string.split_once(':') {
return username == auth.username && password == auth.password;
}
}
}
}
}
false
}
async fn check_acl(&self, _req: &Request<Body>, _acl: &[String]) -> bool {
// For now, we'll allow all requests - ACL implementation would check IP or other criteria
// This is a simplified implementation
true
}
async fn forward_request(&self, req: Request<Body>, target: &str) -> Response {
let method = req.method();
let url = format!(
"{}{}",
target,
req.uri()
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("")
);
// Convert axum Method to reqwest Method
let reqwest_method = match method.as_str() {
"GET" => reqwest::Method::GET,
"POST" => reqwest::Method::POST,
"PUT" => reqwest::Method::PUT,
"DELETE" => reqwest::Method::DELETE,
"HEAD" => reqwest::Method::HEAD,
"OPTIONS" => reqwest::Method::OPTIONS,
"PATCH" => reqwest::Method::PATCH,
_ => reqwest::Method::GET,
};
let mut builder = self.client.request(reqwest_method, &url);
// Copy headers (excluding host to avoid conflicts)
for (name, value) in req.headers() {
if name != "host" && name != "proxy-authorization" {
if let Ok(value_str) = value.to_str() {
builder = builder.header(name.to_string(), value_str);
}
}
}
// Copy body
let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
Ok(bytes) => bytes,
Err(e) => {
error!("Failed to read request body: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read body").into_response();
}
};
match builder.body(body_bytes).send().await {
Ok(resp) => {
let status = resp.status();
let mut response = Response::builder().status(
StatusCode::from_u16(status.as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
);
// Copy response headers
for (name, value) in resp.headers() {
if let Ok(value_str) = value.to_str() {
response = response.header(name.to_string(), value_str);
}
}
// Get response body
match resp.bytes().await {
Ok(body_bytes) => response
.body(Body::from(body_bytes))
.unwrap()
.into_response(),
Err(e) => {
error!("Failed to read response body: {}", e);
(StatusCode::BAD_GATEWAY, "Failed to read response").into_response()
}
}
}
Err(e) => {
error!("Proxy request failed: {}", e);
(StatusCode::BAD_GATEWAY, "Proxy request failed").into_response()
}
}
}
}

View File

@ -1,9 +1,9 @@
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, error};
use serde::{Serialize, Deserialize};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use tracing::{error, info};
#[derive(Debug)]
pub struct Upstream {
@ -20,7 +20,7 @@ impl Clone for Upstream {
url: self.url.clone(),
weight: self.weight,
is_healthy: Arc::clone(&self.is_healthy),
created_at: self.created_at, // Instant 实现了 Copy
created_at: self.created_at, // Instant 实现了 Copy
request_count: Arc::clone(&self.request_count),
}
}
@ -37,18 +37,18 @@ impl Upstream {
}
}
pub async fn increment_connections(&self) {
pub fn increment_connections(&self) {
self.request_count.fetch_add(1, Ordering::SeqCst);
}
pub async fn decrement_connections(&self) {
pub fn decrement_connections(&self) {
let current = self.request_count.load(Ordering::SeqCst);
if current > 0 {
self.request_count.fetch_sub(1, Ordering::SeqCst);
}
}
pub async fn increment_requests(&self) {
pub fn increment_requests(&self) {
self.request_count.fetch_add(1, Ordering::SeqCst);
}
@ -110,7 +110,8 @@ pub enum LoadBalancerStrategy {
impl LoadBalancer {
pub fn new(strategy: LoadBalancerStrategy, upstreams: Vec<String>) -> Self {
let upstreams_vec = upstreams.into_iter()
let upstreams_vec = upstreams
.into_iter()
.map(|url| Upstream::new(url, 1))
.collect();
@ -121,8 +122,12 @@ impl LoadBalancer {
}
}
pub async fn with_weights(strategy: LoadBalancerStrategy, upstreams: Vec<(String, u32)>) -> Self {
let upstreams_vec = upstreams.into_iter()
pub async fn with_weights(
strategy: LoadBalancerStrategy,
upstreams: Vec<(String, u32)>,
) -> Self {
let upstreams_vec = upstreams
.into_iter()
.map(|(url, weight)| Upstream::new(url, weight))
.collect();
@ -135,8 +140,9 @@ impl LoadBalancer {
pub async fn select_upstream(&self) -> Option<Upstream> {
let upstreams = self.upstreams.read().await;
let healthy_upstreams: Vec<Upstream> = upstreams.iter()
.filter(|u| u.is_healthy()) // 现在返回的是 bool不需要 await
let healthy_upstreams: Vec<Upstream> = upstreams
.iter()
.filter(|u| u.is_healthy()) // 现在返回的是 bool不需要 await
.cloned()
.collect();
@ -146,18 +152,14 @@ impl LoadBalancer {
}
match self.strategy {
LoadBalancerStrategy::RoundRobin => {
self.round_robin_select(&healthy_upstreams).await
}
LoadBalancerStrategy::RoundRobin => self.round_robin_select(&healthy_upstreams).await,
LoadBalancerStrategy::LeastConnections => {
self.least_connections_select(&healthy_upstreams).await
}
LoadBalancerStrategy::WeightedRoundRobin => {
self.weighted_round_robin_select(&healthy_upstreams).await
}
LoadBalancerStrategy::Random => {
self.random_select(&healthy_upstreams).await
}
LoadBalancerStrategy::Random => self.random_select(&healthy_upstreams).await,
LoadBalancerStrategy::IpHash => {
// For IP hash, we'd need client IP
// For now, fall back to round robin
@ -170,16 +172,17 @@ impl LoadBalancer {
let mut index = self.current_index.write().await;
let selected_index = *index % upstreams.len();
let selected = upstreams[selected_index].clone();
let mut upstreams_ref = self.upstreams.write().await;
if let Some(upstream) = upstreams_ref.iter_mut().find(|u| u.url == selected.url) {
upstream.increment_connections().await;
}
// Increment connection count for selected upstream
selected.increment_connections();
*index = (*index + 1) % upstreams.len();
Some(selected)
}
async fn least_connections_select(&self, upstreams: &[Upstream]) -> Option<Upstream> {
// Simple implementation: just return the first healthy one
// A proper implementation would find the one with the least connections
upstreams.first().cloned()
}
@ -191,13 +194,13 @@ impl LoadBalancer {
let mut index = self.current_index.write().await;
let current_weight = *index;
let mut accumulated_weight = 0;
for upstream in upstreams {
accumulated_weight += upstream.weight;
if current_weight < accumulated_weight as usize {
let selected = upstream.clone();
upstream.increment_connections().await;
selected.increment_connections();
*index = (*index + 1) % total_weight as usize;
return Some(selected);
}
@ -241,7 +244,8 @@ impl LoadBalancer {
for upstream in upstreams.iter() {
total_requests += upstream.get_total_requests();
total_connections += upstream.get_active_connections();
if upstream.is_healthy() { // 现在返回的是 bool不需要 await
if upstream.is_healthy() {
// 现在返回的是 bool不需要 await
healthy_count += 1;
}
}
@ -264,13 +268,19 @@ impl Default for LoadBalancerStrategy {
impl Default for LoadBalancer {
fn default() -> Self {
let upstreams = vec!["http://backend1:3000".to_string(), "http://backend2:3000".to_string()];
let upstreams = vec![
"http://backend1:3000".to_string(),
"http://backend2:3000".to_string(),
];
Self {
strategy: LoadBalancerStrategy::RoundRobin,
upstreams: Arc::new(RwLock::new(
upstreams.into_iter().map(|url| Upstream::new(url, 1)).collect()
upstreams
.into_iter()
.map(|url| Upstream::new(url, 1))
.collect(),
)),
current_index: Arc::new(RwLock::new(0)),
}
}
}
}

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
pub mod connection_pool;
pub mod forward_proxy;
pub mod health_check;
pub mod load_balancer;
pub mod tcp_proxy;

View File

@ -4,13 +4,12 @@ use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tracing::info;
use parking_lot::Mutex; // 添加 parking_lot 导入
#[derive(Debug)]
pub struct TcpProxyManager {
connections: Arc<RwLock<HashMap<String, TcpConnection>>>,
#[allow(dead_code)]
last_cleanup: Arc<Mutex<Instant>>, // 使用 parking_lot::Mutex 替代 std::sync::Mutex
last_cleanup: Arc<tokio::sync::Mutex<Instant>>,
}
#[derive(Debug, Clone)]
@ -32,7 +31,7 @@ impl TcpProxyManager {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
last_cleanup: Arc::new(Mutex::new(Instant::now())),
last_cleanup: Arc::new(tokio::sync::Mutex::new(Instant::now())),
}
}
@ -66,7 +65,7 @@ impl TcpProxyManager {
connections.retain(|_, conn| conn.created_at.elapsed() < max_age);
let now = Instant::now();
let mut last_cleanup = self.last_cleanup.lock(); // 使用 parking_lot::Mutex
let mut last_cleanup = self.last_cleanup.lock().await;
if now.duration_since(*last_cleanup) > Duration::from_secs(60) {
info!(
"Cleaned up expired connections (total: {})",
@ -85,4 +84,4 @@ impl Default for TcpProxyManager {
fn default() -> Self {
Self::new()
}
}
}

View File

@ -96,14 +96,19 @@ pub async fn handle_request(State(server): State<ProxyServer>, req: Request<Body
// Handle request based on route type
match route {
RouteRule::Static { root, index, .. } => {
handle_static_request(req, root, index.as_deref(), &path).await
handle_static_request(req, &root, index.as_deref(), &path).await
}
RouteRule::ReverseProxy { target, .. } => handle_reverse_proxy(req, &target).await,
RouteRule::ForwardProxy { .. } => {
// For a forward proxy, we typically don't know the target from the path
// This would need to be handled differently in a real implementation
// For now, we'll return not implemented
(
StatusCode::NOT_IMPLEMENTED,
"Forward proxy not fully implemented yet",
)
.into_response()
}
RouteRule::ReverseProxy { target, .. } => handle_reverse_proxy(req, target).await,
RouteRule::ForwardProxy { .. } => (
StatusCode::NOT_IMPLEMENTED,
"Forward proxy not implemented yet",
)
.into_response(),
RouteRule::TcpProxy {
target, protocol, ..
} => {