diff --git a/main.go b/main.go index 585c0b2..ae48c3f 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "encoding/json" "flag" "fmt" @@ -16,6 +17,7 @@ import ( "time" "github.com/conduitio/bwlimit" + "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" ) @@ -58,7 +60,6 @@ var h2client = &http.Client{ }, } -// https://github.com/lucas-clemente/quic-go/issues/2836 var client *http.Client // Same user agent as Invidious @@ -93,6 +94,8 @@ var ipv6_only = false var version string +var h3s bool + type statusJson struct { Version string `json:"version"` RequestCount int64 `json:"requestCount"` @@ -169,13 +172,20 @@ func beforeAll(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { defer panicHandler(w) + if h3s { + w.Header().Set("Alt-Svc", "h3=\":8443\"; ma=86400") + } + + if req.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + if req.Method != "GET" && req.Method != "HEAD" { io.WriteString(w, "Only GET and HEAD requests are allowed.") return } - atomic.AddInt64(&stats_.RequestCount, 1) - // To look like more like a browser req.Header.Add("Origin", "https://www.youtube.com") req.Header.Add("Referer", "https://www.youtube.com/") @@ -184,11 +194,7 @@ func beforeAll(next http.HandlerFunc) http.HandlerFunc { w.Header().Set("Access-Control-Allow-Headers", "*") w.Header().Set("Access-Control-Max-Age", "1728000") - if req.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - + atomic.AddInt64(&stats_.RequestCount, 1) next(w, req) } } @@ -197,15 +203,23 @@ func main() { var sock string var host string var port string + var tls_cert string var tls_key string + var ipv6 bool + var https bool + var h3c bool - path_prefix = os.Getenv("PREFIX_PATH") - ipv6_only = os.Getenv("IPV6_ONLY") == "1" + ua = os.Getenv("USER_AGENT") + https = os.Getenv("HTTPS") == "1" + h3c = os.Getenv("H3C") == "1" + h3s = os.Getenv("H3S") == "1" + ipv6 = os.Getenv("IPV6_ONLY") == "1" - var https = flag.Bool("https", false, "Use built-in https server (recommended)") - var h3 = flag.Bool("h3", false, "Use HTTP/3 for requests (high CPU usage)") - var ipv6 = flag.Bool("ipv6_only", false, "Only use ipv6 for requests") + flag.BoolVar(&https, "https", false, "Use built-in https server (recommended)") + flag.BoolVar(&h3s, "h3c", false, "Use HTTP/3 for client requests (high CPU usage)") + flag.BoolVar(&h3s, "h3s", true, "Use HTTP/3 for server requests") + flag.BoolVar(&ipv6_only, "ipv6_only", false, "Only use ipv6 for requests") flag.StringVar(&tls_cert, "tls-cert", "", "TLS Certificate path") flag.StringVar(&tls_key, "tls-key", "", "TLS Certificate Key path") flag.StringVar(&sock, "s", "/tmp/http-ytproxy.sock", "Specify a socket name") @@ -213,25 +227,25 @@ func main() { flag.StringVar(&host, "l", "0.0.0.0", "Specify a listen address") flag.Parse() - if *h3 { + if h3c { client = h3client } else { client = h2client } - if *https { + if https { if len(tls_cert) <= 0 { - fmt.Println("tls-cert argument is missing") - fmt.Println("You need a TLS certificate for HTTPS") + fmt.Println("tls-cert argument is missing, you need a TLS certificate for HTTPS") + os.Exit(1) } if len(tls_key) <= 0 { - fmt.Println("tls-key argument is missing") - fmt.Println("You need a TLS key for HTTPS") + fmt.Println("tls-key argument is missing, you need a TLS key for HTTPS") + os.Exit(1) } } - ipv6_only = *ipv6 + ipv6_only = ipv6 mux := http.NewServeMux() @@ -262,11 +276,25 @@ func main() { ) ln = bwlimit.NewListener(ln, writeLimit, readLimit) + // srvDialer := bwlimit.NewDialer(&net.Dialer{}, writeLimit, readLimit) srv := &http.Server{ + Handler: mux, ReadTimeout: 5 * time.Second, WriteTimeout: 1 * time.Hour, - Handler: mux, + } + + srvh3 := &http3.Server{ + Handler: mux, + EnableDatagrams: false, // https://quic.video/blog/never-use-datagrams/ (Read it) + IdleTimeout: 120 * time.Second, + TLSConfig: http3.ConfigureTLSConfig(&tls.Config{}), + QUICConfig: &quic.Config{ + // KeepAlivePeriod: 10 * time.Second, + MaxIncomingStreams: 256, // I'm not sure if this is correct. + MaxIncomingUniStreams: 256, // Same as above + }, + Addr: host + ":" + port, } syscall.Unlink(sock) @@ -289,11 +317,22 @@ func main() { go srv.Serve(socket_listener) fmt.Println("Unix socket listening at:", string(sock)) - if *https { + if https { fmt.Println("Serving HTTPS at port", string(port)) - if err := srv.ServeTLS(ln, tls_cert, tls_key); err != nil { - log.Fatal(err) + go func() { + if err := srv.ServeTLS(ln, tls_cert, tls_key); err != nil { + log.Fatal(err) + } + }() + if h3s { + fmt.Println("Serving HTTPS via QUIC at port", string(port)) + go func() { + if err := srvh3.ListenAndServeTLS(tls_cert, tls_key); err != nil { + log.Fatal(err) + } + }() } + select {} } else { fmt.Println("Serving HTTP at port", string(port)) if err := srv.Serve(ln); err != nil {