From 2c0e491b3ef72ed34de28911034adf987f680032 Mon Sep 17 00:00:00 2001 From: y0sy4 Date: Tue, 24 Mar 2026 20:03:54 +0300 Subject: [PATCH] fix: upstream proxy support for --upstream-proxy flag (Issue #2) --- cmd/proxy/main.go | 37 +++++++++---- go.mod | 4 +- go.sum | 2 + internal/config/config.go | 19 +++---- internal/proxy/http_proxy.go | 45 +++++++++++++--- internal/proxy/proxy.go | 94 +++++++++++++++++++++++++-------- internal/websocket/websocket.go | 61 +++++++++++++++++---- mobile/mobile.go | 2 +- 8 files changed, 203 insertions(+), 61 deletions(-) create mode 100644 go.sum diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 044aa69..e13d878 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -116,15 +116,24 @@ func main() { if *auth != "" { cfg.Auth = *auth } - - // Setup logging - default to file if not specified - logPath := *logFile - if logPath == "" { - // Use default log file in app config directory - appDir := getAppDir() - logPath = filepath.Join(appDir, "proxy.log") + if *upstreamProxy != "" { + cfg.UpstreamProxy = *upstreamProxy + } + + // Setup logging - log to stdout if verbose, otherwise to file + var logger *log.Logger + logPath := *logFile + if cfg.Verbose && logPath == "" { + // Verbose mode: log to stdout + logger = setupLogging("", cfg.LogMaxMB, cfg.Verbose) + } else { + // File mode: log to file (default to app dir if not specified) + if logPath == "" { + appDir := getAppDir() + logPath = filepath.Join(appDir, "proxy.log") + } + logger = setupLogging(logPath, cfg.LogMaxMB, cfg.Verbose) } - logger := setupLogging(logPath, cfg.LogMaxMB, cfg.Verbose) // Log advanced features usage and start HTTP proxy if *httpPort != 0 { @@ -146,7 +155,7 @@ func main() { } // Create and start server - server, err := proxy.NewServer(cfg, logger) + server, err := proxy.NewServer(cfg, logger, cfg.UpstreamProxy) if err != nil { log.Fatalf("Failed to create server: %v", err) } @@ -248,8 +257,12 @@ func getAppDir() string { func setupLogging(logFile string, logMaxMB float64, verbose bool) *log.Logger { flags := log.LstdFlags | log.Lshortfile - if verbose { - flags |= log.Lshortfile + + // If verbose and no log file specified, log to stdout + if verbose && logFile == "" { + log.SetOutput(os.Stdout) + log.SetFlags(flags) + return log.New(os.Stdout, "", flags) } // Ensure directory exists @@ -260,6 +273,8 @@ func setupLogging(logFile string, logMaxMB float64, verbose bool) *log.Logger { f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Printf("Warning: failed to open log file %s: %v, using stdout", logFile, err) + log.SetOutput(os.Stdout) + log.SetFlags(flags) return log.New(os.Stdout, "", flags) } diff --git a/go.mod b/go.mod index 80e423d..5d5da38 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/Flowseal/tg-ws-proxy -go 1.21 +go 1.25.0 + +require golang.org/x/net v0.52.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e3b24b9 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= diff --git a/internal/config/config.go b/internal/config/config.go index 756d96c..a14bdd4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,15 +13,16 @@ import ( // Config holds the proxy configuration. type Config struct { - Port int `json:"port"` - Host string `json:"host"` - DCIP []string `json:"dc_ip"` - Verbose bool `json:"verbose"` - AutoStart bool `json:"autostart"` - LogMaxMB float64 `json:"log_max_mb"` - BufKB int `json:"buf_kb"` - PoolSize int `json:"pool_size"` - Auth string `json:"auth"` // username:password + Port int `json:"port"` + Host string `json:"host"` + DCIP []string `json:"dc_ip"` + Verbose bool `json:"verbose"` + AutoStart bool `json:"autostart"` + LogMaxMB float64 `json:"log_max_mb"` + BufKB int `json:"buf_kb"` + PoolSize int `json:"pool_size"` + Auth string `json:"auth"` // username:password + UpstreamProxy string `json:"upstream_proxy"` } // DefaultConfig returns the default configuration. diff --git a/internal/proxy/http_proxy.go b/internal/proxy/http_proxy.go index d0f5d13..3f7eeee 100644 --- a/internal/proxy/http_proxy.go +++ b/internal/proxy/http_proxy.go @@ -10,16 +10,45 @@ import ( "net/http" "net/url" "strings" + + "golang.org/x/net/proxy" ) // HTTPProxy represents an HTTP proxy server. type HTTPProxy struct { - port int - verbose bool - logger *log.Logger + port int + verbose bool + logger *log.Logger upstreamProxy *url.URL } +// dialWithUpstream creates a connection, optionally routing through an upstream proxy. +func (h *HTTPProxy) dialWithUpstream(network, addr string) (net.Conn, error) { + if h.upstreamProxy == nil { + return net.Dial(network, addr) + } + + switch h.upstreamProxy.Scheme { + case "socks5", "socks": + // Use proxy package for SOCKS5 + proxyDialer, err := proxy.FromURL(h.upstreamProxy, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("create SOCKS5 dialer: %w", err) + } + return proxyDialer.Dial(network, addr) + + case "http", "https": + // Use http.Transport with Proxy for HTTP CONNECT + transport := &http.Transport{ + Proxy: http.ProxyURL(h.upstreamProxy), + } + return transport.Dial(network, addr) + + default: + return nil, fmt.Errorf("unsupported upstream proxy scheme: %s", h.upstreamProxy.Scheme) + } +} + // NewHTTPProxy creates a new HTTP proxy server. func NewHTTPProxy(port int, verbose bool, logger *log.Logger, upstreamProxyURL string) (*HTTPProxy, error) { var upstreamProxy *url.URL @@ -87,18 +116,18 @@ func (h *HTTPProxy) handleConnect(conn net.Conn, req *http.Request) { if !strings.Contains(host, ":") { host = host + ":80" } - - // Connect to target - target, err := net.Dial("tcp", host) + + // Connect to target (with upstream proxy if configured) + target, err := h.dialWithUpstream("tcp", host) if err != nil { conn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) return } defer target.Close() - + // Send success response conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) - + // Bridge connections go io.Copy(target, conn) io.Copy(conn, target) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 13870c3..c54a6d1 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -9,6 +9,8 @@ import ( "io" "log" "net" + "net/http" + "net/url" "sort" "strings" "sync" @@ -20,6 +22,7 @@ import ( "github.com/Flowseal/tg-ws-proxy/internal/pool" "github.com/Flowseal/tg-ws-proxy/internal/socks5" "github.com/Flowseal/tg-ws-proxy/internal/websocket" + "golang.org/x/net/proxy" ) const ( @@ -146,37 +149,80 @@ func (s *Stats) Summary() string { // Server represents the TG WS Proxy server. type Server struct { - config *config.Config - dcOpt map[int]string - wsPool *pool.WSPool - stats *Stats - wsBlacklist map[pool.DCKey]bool - dcFailUntil map[pool.DCKey]time.Time - mu sync.RWMutex - listener net.Listener - logger *log.Logger + config *config.Config + dcOpt map[int]string + wsPool *pool.WSPool + stats *Stats + wsBlacklist map[pool.DCKey]bool + dcFailUntil map[pool.DCKey]time.Time + mu sync.RWMutex + listener net.Listener + logger *log.Logger + upstreamProxy string } // NewServer creates a new proxy server. -func NewServer(cfg *config.Config, logger *log.Logger) (*Server, error) { +func NewServer(cfg *config.Config, logger *log.Logger, upstreamProxy string) (*Server, error) { dcOpt, err := config.ParseDCIPList(cfg.DCIP) if err != nil { return nil, err } s := &Server{ - config: cfg, - dcOpt: dcOpt, - wsPool: pool.NewWSPool(cfg.PoolSize, defaultPoolMaxAge), - stats: &Stats{}, - wsBlacklist: make(map[pool.DCKey]bool), - dcFailUntil: make(map[pool.DCKey]time.Time), - logger: logger, + config: cfg, + dcOpt: dcOpt, + wsPool: pool.NewWSPool(cfg.PoolSize, defaultPoolMaxAge), + stats: &Stats{}, + wsBlacklist: make(map[pool.DCKey]bool), + dcFailUntil: make(map[pool.DCKey]time.Time), + logger: logger, + upstreamProxy: upstreamProxy, } return s, nil } +// dialWithUpstream creates a connection, optionally routing through an upstream proxy. +func (s *Server) dialWithUpstream(network, addr string, timeout time.Duration) (net.Conn, error) { + if s.upstreamProxy == "" { + return net.DialTimeout(network, addr, timeout) + } + + // Parse upstream proxy URL + u, err := url.Parse(s.upstreamProxy) + if err != nil { + return nil, fmt.Errorf("parse upstream proxy: %w", err) + } + + switch u.Scheme { + case "socks5", "socks": + var auth *proxy.Auth + if u.User != nil { + password, _ := u.User.Password() + auth = &proxy.Auth{ + User: u.User.Username(), + Password: password, + } + } + dialer, err := proxy.SOCKS5(network, u.Host, auth, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("create SOCKS5 dialer: %w", err) + } + return dialer.Dial(network, addr) + + case "http", "https": + // Use http.Transport with Proxy for HTTP CONNECT + transport := &http.Transport{ + Proxy: http.ProxyURL(u), + TLSHandshakeTimeout: timeout, + } + return transport.Dial(network, addr) + + default: + return nil, fmt.Errorf("unsupported upstream proxy scheme: %s", u.Scheme) + } +} + // Start starts the proxy server. func (s *Server) Start(ctx context.Context) error { addr := net.JoinHostPort(s.config.Host, fmt.Sprintf("%d", s.config.Port)) @@ -407,7 +453,9 @@ func (s *Server) getWebSocket(dcKey pool.DCKey, targetIP string, domains []strin s.logInfo("[%s] DC%d%s (%s:%d) -> %s via %s", label, dc, mediaTag, dst, port, url, targetIP) // Connect using targetIP, but use domain for TLS handshake - ws, wsErr = websocket.Connect(targetIP, domain, "/apiws", wsTimeout) + ws, wsErr = websocket.ConnectWithDialer(targetIP, domain, "/apiws", wsTimeout, func(network, addr string) (net.Conn, error) { + return s.dialWithUpstream(network, addr, wsTimeout) + }) if wsErr == nil { allRedirects = false break @@ -450,7 +498,7 @@ func (s *Server) getWebSocket(dcKey pool.DCKey, targetIP string, domains []strin } func (s *Server) handlePassthrough(conn net.Conn, dst string, port uint16, label string) { - remoteConn, err := net.DialTimeout("tcp", net.JoinHostPort(dst, fmt.Sprintf("%d", port)), 10*time.Second) + remoteConn, err := s.dialWithUpstream("tcp", net.JoinHostPort(dst, fmt.Sprintf("%d", port)), 10*time.Second) if err != nil { s.logWarning("[%s] passthrough failed to %s: %v", label, dst, err) conn.Write(socks5.Reply(socks5.ReplyFail)) @@ -465,7 +513,7 @@ func (s *Server) handlePassthrough(conn net.Conn, dst string, port uint16, label // handleIPv6Connection handles IPv6 connections via dual-stack or IPv4-mapped addresses. func (s *Server) handleIPv6Connection(conn net.Conn, ipv6Addr string, port uint16, label string) { // Try direct IPv6 first - remoteConn, err := net.DialTimeout("tcp6", net.JoinHostPort(ipv6Addr, fmt.Sprintf("%d", port)), 10*time.Second) + remoteConn, err := s.dialWithUpstream("tcp6", net.JoinHostPort(ipv6Addr, fmt.Sprintf("%d", port)), 10*time.Second) if err == nil { s.logInfo("[%s] IPv6 direct connection successful", label) defer remoteConn.Close() @@ -525,7 +573,7 @@ func extractIPv4FromNAT64(ipv6, prefix string) string { } func (s *Server) handleTCPFallback(conn net.Conn, dst string, port uint16, init []byte, label string, dc int, isMedia bool) { - remoteConn, err := net.DialTimeout("tcp", net.JoinHostPort(dst, fmt.Sprintf("%d", port)), 10*time.Second) + remoteConn, err := s.dialWithUpstream("tcp", net.JoinHostPort(dst, fmt.Sprintf("%d", port)), 10*time.Second) if err != nil { s.logWarning("[%s] TCP fallback to %s:%d failed: %v", label, dst, port, err) return @@ -672,7 +720,9 @@ func (s *Server) warmupPool() { go func(dcKey pool.DCKey, targetIP string, domains []string) { for s.wsPool.NeedRefill(dcKey) { for _, domain := range domains { - ws, err := websocket.Connect(targetIP, domain, "/apiws", wsConnectTimeout) + ws, err := websocket.ConnectWithDialer(targetIP, domain, "/apiws", wsConnectTimeout, func(network, addr string) (net.Conn, error) { + return s.dialWithUpstream(network, addr, wsConnectTimeout) + }) if err == nil { s.wsPool.Put(dcKey, ws) break diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index 962e5d8..3fac0f9 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -44,6 +44,12 @@ type WebSocket struct { // Connect establishes a WebSocket connection to the given domain via IP. func Connect(ip, domain, path string, timeout time.Duration) (*WebSocket, error) { + return ConnectWithDialer(ip, domain, path, timeout, nil) +} + +// ConnectWithDialer establishes a WebSocket connection using a custom dialer. +// If dialer is nil, it uses direct connection. +func ConnectWithDialer(ip, domain, path string, timeout time.Duration, dialFunc func(network, addr string) (net.Conn, error)) (*WebSocket, error) { if path == "" { path = "/apiws" } @@ -56,18 +62,55 @@ func Connect(ip, domain, path string, timeout time.Duration) (*WebSocket, error) wsKey := base64.StdEncoding.EncodeToString(keyBytes) // Dial TLS connection - dialer := &net.Dialer{Timeout: timeout} - tlsConfig := &tls.Config{ - ServerName: domain, - InsecureSkipVerify: true, + var rawConn net.Conn + var err error + + if dialFunc != nil { + // Use custom dialer + rawConn, err = dialFunc("tcp", net.JoinHostPort(ip, "443")) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + // Wrap with TLS + tlsConfig := &tls.Config{ + ServerName: domain, + InsecureSkipVerify: true, + } + rawConn = tls.Client(rawConn, tlsConfig) + // Set handshake timeout + if err := rawConn.SetDeadline(time.Now().Add(timeout)); err != nil { + rawConn.Close() + return nil, err + } + } else { + // Direct connection + dialer := &net.Dialer{Timeout: timeout} + tlsConfig := &tls.Config{ + ServerName: domain, + InsecureSkipVerify: true, + } + rawConn, err = tls.DialWithDialer(dialer, "tcp", net.JoinHostPort(ip, "443"), tlsConfig) + if err != nil { + return nil, fmt.Errorf("tls dial: %w", err) + } } - rawConn, err := tls.DialWithDialer(dialer, "tcp", net.JoinHostPort(ip, "443"), tlsConfig) - if err != nil { - return nil, fmt.Errorf("tls dial: %w", err) + + // Clear deadline after handshake + if err := rawConn.SetDeadline(time.Time{}); err != nil { + rawConn.Close() + return nil, err } // Set TCP_NODELAY and buffer sizes - if tcpConn, ok := rawConn.NetConn().(*net.TCPConn); ok { + if tcpConn, ok := rawConn.(*tls.Conn); ok { + if netConn := tcpConn.NetConn(); netConn != nil { + if tcpNetConn, ok := netConn.(*net.TCPConn); ok { + tcpNetConn.SetNoDelay(true) + tcpNetConn.SetReadBuffer(256 * 1024) + tcpNetConn.SetWriteBuffer(256 * 1024) + } + } + } else if tcpConn, ok := rawConn.(*net.TCPConn); ok { tcpConn.SetNoDelay(true) tcpConn.SetReadBuffer(256 * 1024) tcpConn.SetWriteBuffer(256 * 1024) @@ -115,7 +158,7 @@ func Connect(ip, domain, path string, timeout time.Duration) (*WebSocket, error) } return &WebSocket{ - conn: rawConn, + conn: rawConn.(*tls.Conn), reader: reader, writer: bufio.NewWriter(rawConn), maskKey: make([]byte, 4), diff --git a/mobile/mobile.go b/mobile/mobile.go index 821db7b..62926d8 100644 --- a/mobile/mobile.go +++ b/mobile/mobile.go @@ -44,7 +44,7 @@ func Start(host string, port int, dcIP string, verbose bool) string { var ctx context.Context ctx, cancel = context.WithCancel(context.Background()) - server, err = proxy.NewServer(cfg, logger) + server, err = proxy.NewServer(cfg, logger, "") if err != nil { cancel() return fmt.Sprintf("Failed to create server: %v", err)