diff --git a/go.mod b/go.mod index da9d1dc..d3a170a 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.23.0 require github.com/quic-go/quic-go v0.48.1 require ( + github.com/conduitio/bwlimit v0.1.0 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/google/pprof v0.0.0-20241029010322-833c56d90c8e // indirect github.com/onsi/ginkgo/v2 v2.20.2 // indirect @@ -19,5 +20,6 @@ require ( golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/text v0.19.0 // indirect + golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.26.0 // indirect ) diff --git a/main.go b/main.go index 0bdf1c1..687105e 100644 --- a/main.go +++ b/main.go @@ -15,9 +15,15 @@ import ( "syscall" "time" + "github.com/conduitio/bwlimit" "github.com/quic-go/quic-go/http3" ) +var ( + wl = flag.Int("w", 8000, "Write limit in Kbps") + rl = flag.Int("r", 8000, "Read limit in Kbps") +) + var h3client = &http.Client{ Transport: &http3.Transport{}, Timeout: 10 * time.Second, @@ -194,7 +200,6 @@ func main() { var https = flag.Bool("https", false, "Use built-in https server (recommended)") var h3 = flag.Bool("h3", false, "Use HTTP/3 for requests (high CPU usage)") var ipv6 = flag.Bool("ipv6_only", false, "Only use ipv6 for requests") - flag.StringVar(&ua, "u", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36", "Set User-Agent (do not modify)") flag.StringVar(&tls_cert, "tls-cert", "", "TLS Certificate path") flag.StringVar(&tls_key, "tls-key", "", "TLS Certificate Key path") flag.StringVar(&sock, "s", "/tmp/http-ytproxy.sock", "Specify a socket name") @@ -239,45 +244,55 @@ func main() { go requestPerSecond() go requestPerMinute() + ln, err := net.Listen("tcp", host+":"+port) + if err != nil { + log.Fatalf("Failed to listen: %v", err) + } + + // 1Kbit = 125Bytes + var ( + writeLimit = bwlimit.Byte(*wl) * bwlimit.Byte(125) + readLimit = bwlimit.Byte(*rl) * bwlimit.Byte(125) + ) + + ln = bwlimit.NewListener(ln, writeLimit, readLimit) + srv := &http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 1 * time.Hour, - Addr: string(host) + ":" + string(port), Handler: mux, } - socket := string(sock) - syscall.Unlink(socket) - listener, err := net.Listen("unix", socket) - fmt.Println("Unix socket listening at:", string(sock)) + syscall.Unlink(sock) + socket_listener, err := net.Listen("unix", sock) if err != nil { - fmt.Println("Failed to bind to UDS, please check the socket name, falling back to TCP/IP") + fmt.Println("Failed to bind to UDS, please check the socket name") fmt.Println(err.Error()) - err := srv.ListenAndServe() - if err != nil { - fmt.Println("Cannot bind to port", string(port), "Error:", err) - fmt.Println("Please try changing the port number") - } } else { - defer listener.Close() + defer socket_listener.Close() // To allow everyone to access the socket - err = os.Chmod(socket, 0777) + err = os.Chmod(sock, 0777) if err != nil { fmt.Println("Error setting permissions:", err) return } else { fmt.Println("Setting socket permissions to 777") } - go srv.Serve(listener) + + go srv.Serve(socket_listener) + fmt.Println("Unix socket listening at:", string(sock)) + if *https { fmt.Println("Serving HTTPS at port", string(port)) - if err := srv.ListenAndServeTLS(tls_cert, tls_key); err != nil { + if err := srv.ServeTLS(ln, tls_cert, tls_key); err != nil { log.Fatal(err) } } else { fmt.Println("Serving HTTP at port", string(port)) - srv.ListenAndServe() + if err := srv.Serve(ln); err != nil { + log.Fatal(err) + } } } }