Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send ports forwarded to control server #2392

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/portforward/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type Service interface {
Start(ctx context.Context) (runError <-chan error, err error)
Stop() (err error)
GetPortsForwarded() (ports []uint16)
SetPortsForwarded(ctx context.Context, ports []uint16) (err error)
}

type Routing interface {
Expand Down
11 changes: 11 additions & 0 deletions internal/portforward/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package portforward

import (
"context"
"errors"
"fmt"
"net/http"
"sync"
Expand Down Expand Up @@ -166,6 +167,16 @@ func (l *Loop) GetPortsForwarded() (ports []uint16) {
return l.service.GetPortsForwarded()
}

var ErrServiceNotStarted = errors.New("port forwarding service not started")

func (l *Loop) SetPortsForwarded(ports []uint16) (err error) {
if l.service == nil {
return fmt.Errorf("%w", ErrServiceNotStarted)
}

return l.service.SetPortsForwarded(l.runCtx, ports)
}

func ptrTo[T any](value T) *T {
return &value
}
28 changes: 28 additions & 0 deletions internal/portforward/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package service

import (
"context"
"fmt"
"net/http"
"slices"
"sync"
)

Expand Down Expand Up @@ -50,3 +52,29 @@ func (s *Service) GetPortsForwarded() (ports []uint16) {
copy(ports, s.ports)
return ports
}

func (s *Service) SetPortsForwarded(ctx context.Context, ports []uint16) (err error) {
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()

s.portMutex.Lock()
defer s.portMutex.Unlock()
slices.Sort(ports)
if slices.Equal(s.ports, ports) {
return nil
}

err = s.cleanup()
if err != nil {
return fmt.Errorf("cleaning up: %w", err)
}

err = s.onNewPorts(ctx, ports)
if err != nil {
return fmt.Errorf("handling new ports: %w", err)
}

s.logger.Info("updated: " + portsToString(s.ports))

return nil
}
73 changes: 43 additions & 30 deletions internal/portforward/service/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"slices"

"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/provider/utils"
Expand Down Expand Up @@ -47,30 +48,67 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
return nil, fmt.Errorf("port forwarding for the first time: %w", err)
}

err = s.onNewPorts(ctx, ports)
if err != nil {
return nil, err
}

keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
s.keepPortCancel = keepPortCancel
runErrorCh := make(chan error)
keepPortDoneCh := make(chan struct{})
s.keepPortDoneCh = keepPortDoneCh

readyCh := make(chan struct{})
go func(ctx context.Context, portForwarder PortForwarder,
obj utils.PortForwardObjects, readyCh chan<- struct{},
runError chan<- error, doneCh chan<- struct{},
) {
defer close(doneCh)
close(readyCh)
err = portForwarder.KeepPortForward(ctx, obj)
crashed := ctx.Err() == nil
if !crashed { // stopped by Stop call
return
}
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
_ = s.cleanup()
runError <- err
}(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh)
<-readyCh

return runErrorCh, nil
}

func (s *Service) onNewPorts(ctx context.Context, ports []uint16) (err error) {
slices.Sort(ports)

s.logger.Info(portsToString(ports))

for _, port := range ports {
err = s.portAllower.SetAllowedPort(ctx, port, s.settings.Interface)
if err != nil {
return nil, fmt.Errorf("allowing port in firewall: %w", err)
return fmt.Errorf("allowing port in firewall: %w", err)
}

if s.settings.ListeningPort != 0 {
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
if err != nil {
return nil, fmt.Errorf("redirecting port in firewall: %w", err)
return fmt.Errorf("redirecting port in firewall: %w", err)
}
}
}

err = s.writePortForwardedFile(ports)
if err != nil {
_ = s.cleanup()
return nil, fmt.Errorf("writing port file: %w", err)
return fmt.Errorf("writing port file: %w", err)
}

s.portMutex.Lock()
s.ports = ports
s.ports = make([]uint16, len(ports))
copy(s.ports, ports)
s.portMutex.Unlock()

if s.settings.UpCommand != "" {
Expand All @@ -81,30 +119,5 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
}
}

keepPortCtx, keepPortCancel := context.WithCancel(context.Background())
s.keepPortCancel = keepPortCancel
runErrorCh := make(chan error)
keepPortDoneCh := make(chan struct{})
s.keepPortDoneCh = keepPortDoneCh

readyCh := make(chan struct{})
go func(ctx context.Context, portForwarder PortForwarder,
obj utils.PortForwardObjects, readyCh chan<- struct{},
runError chan<- error, doneCh chan<- struct{},
) {
defer close(doneCh)
close(readyCh)
err = portForwarder.KeepPortForward(ctx, obj)
crashed := ctx.Err() == nil
if !crashed { // stopped by Stop call
return
}
s.startStopMutex.Lock()
defer s.startStopMutex.Unlock()
_ = s.cleanup()
runError <- err
}(keepPortCtx, s.settings.PortForwarder, obj, readyCh, runErrorCh, keepPortDoneCh)
<-readyCh

return runErrorCh, nil
return nil
}
4 changes: 2 additions & 2 deletions internal/server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
authSettings auth.Settings,
buildInfo models.BuildInformation,
vpnLooper VPNLooper,
pfGetter PortForwardedGetter,
pf PortForwarding,
dnsLooper DNSLoop,
updaterLooper UpdaterLooper,
publicIPLooper PublicIPLoop,
Expand All @@ -25,7 +25,7 @@ func newHandler(ctx context.Context, logger Logger, logging bool,
handler := &handler{}

vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger)
openvpn := newOpenvpnHandler(ctx, vpnLooper, pf, logger)
dns := newDNSHandler(ctx, dnsLooper, logger)
updater := newUpdaterHandler(ctx, updaterLooper, logger)
publicip := newPublicIPHandler(publicIPLooper, logger)
Expand Down
3 changes: 2 additions & 1 deletion internal/server/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ type DNSLoop interface {
GetStatus() (status models.LoopStatus)
}

type PortForwardedGetter interface {
type PortForwarding interface {
GetPortsForwarded() (ports []uint16)
SetPortsForwarded(ports []uint16) (err error)
}

type PublicIPLoop interface {
Expand Down
30 changes: 27 additions & 3 deletions internal/server/openvpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

Expand All @@ -11,20 +12,20 @@ import (
)

func newOpenvpnHandler(ctx context.Context, looper VPNLooper,
pfGetter PortForwardedGetter, w warner,
portForwarding PortForwarding, w warner,
) http.Handler {
return &openvpnHandler{
ctx: ctx,
looper: looper,
pf: pfGetter,
pf: portForwarding,
warner: w,
}
}

type openvpnHandler struct {
ctx context.Context //nolint:containedctx
looper VPNLooper
pf PortForwardedGetter
pf PortForwarding
warner warner
}

Expand All @@ -51,6 +52,8 @@ func (h *openvpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
h.getPortForwarded(w)
case http.MethodPut:
h.setPortForwarded(w, r)
default:
errMethodNotSupported(w, r.Method)
}
Expand Down Expand Up @@ -142,3 +145,24 @@ func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
w.WriteHeader(http.StatusInternalServerError)
}
}

func (h *openvpnHandler) setPortForwarded(w http.ResponseWriter, r *http.Request) {
var data portsWrapper

decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&data)
if err != nil {
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
http.Error(w, "failed setting forwarded ports", http.StatusBadRequest)
return
}

err = h.pf.SetPortsForwarded(data.Ports)
if err != nil {
h.warner.Warn(fmt.Sprintf("failed setting forwarded ports: %s", err))
http.Error(w, "failed setting forwarded ports", http.StatusInternalServerError)
return
}

w.WriteHeader(http.StatusOK)
}
4 changes: 2 additions & 2 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

func New(ctx context.Context, address string, logEnabled bool, logger Logger,
authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, dnsLooper DNSLoop,
pf PortForwarding, dnsLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
ipv6Supported bool) (
server *httpserver.Server, err error,
Expand All @@ -27,7 +27,7 @@ func New(ctx context.Context, address string, logEnabled bool, logger Logger,
}

handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo,
openvpnLooper, pfGetter, dnsLooper, updaterLooper, publicIPLooper,
openvpnLooper, pf, dnsLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)
if err != nil {
return nil, fmt.Errorf("creating handler: %w", err)
Expand Down