Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat(post-rules) execute post-rules after every update to the firewall
  • Loading branch information
atasmohammadi committed Jun 6, 2025
commit 50fb2c3303ca9476696febe4794e5851dd7dd26c
6 changes: 6 additions & 0 deletions internal/firewall/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type Config struct { //nolint:maligned
stateMutex sync.Mutex
}

// applyUserPostRules applies user-defined post firewall rules
func (c *Config) applyUserPostRules(ctx context.Context) error {
const remove = false
return c.runUserPostRules(ctx, c.customRulesPath, remove)
}

// NewConfig creates a new Config instance and returns an error
// if no iptables implementation is available.
func NewConfig(ctx context.Context, logger Logger,
Expand Down
9 changes: 9 additions & 0 deletions internal/firewall/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,15 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
} else if err != nil {
return err
}

// Log when post-rules are being applied
if !remove {
c.logger.Info("applying user-defined post firewall rules from " + filepath)
} else {
c.logger.Info("removing user-defined post firewall rules from " + filepath)
}


b, err := io.ReadAll(file)
if err != nil {
_ = file.Close()
Expand Down
57 changes: 37 additions & 20 deletions internal/firewall/outboundsubnets.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,47 @@ import (
"github.com/qdm12/gluetun/internal/subnet"
)

func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []netip.Prefix) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed subnets internal list")
c.outboundSubnets = make([]netip.Prefix, len(subnets))
copy(c.outboundSubnets, subnets)
return nil
}
func (c *Config) SetOutboundSubnets(ctx context.Context, outboundSubnets []netip.Prefix) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

c.logger.Info("setting allowed subnets...")
if !c.enabled {
c.outboundSubnets = outboundSubnets
return nil
}

subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets)
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
return nil
}
// Remove previous outbound subnet rules
for _, subnet := range c.outboundSubnets {
subnetIsIPv6 := subnet.Addr().Is6()
for _, defaultRoute := range c.defaultRoutes {
defaultRouteIsIPv6 := defaultRoute.Family == netlink.FamilyV6
ipFamilyMatch := subnetIsIPv6 == defaultRouteIsIPv6
if !ipFamilyMatch {
continue
}

c.removeOutboundSubnets(ctx, subnetsToRemove)
if err := c.addOutboundSubnets(ctx, subnetsToAdd); err != nil {
return fmt.Errorf("setting allowed outbound subnets: %w", err)
}
const remove = true
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
defaultRoute.AssignedIP, subnet, remove)
if err != nil {
return err
}
}
}

return nil
c.outboundSubnets = outboundSubnets

// Add new outbound subnet rules
if err = c.allowOutboundSubnets(ctx); err != nil {
return err
}

// Re-apply user post-rules after subnet changes
if err = c.applyUserPostRules(ctx); err != nil {
return fmt.Errorf("re-applying user post-rules after outbound subnet change: %w", err)
}

return nil
}

func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Prefix) {
Expand Down
66 changes: 30 additions & 36 deletions internal/firewall/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,36 @@ import (
)

func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

if port == 0 {
return nil
}

if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal state")
existingInterfaces, ok := c.allowedInputPorts[port]
if !ok {
existingInterfaces = make(map[string]struct{})
}
existingInterfaces[intf] = struct{}{}
c.allowedInputPorts[port] = existingInterfaces
return nil
}

netInterfaces, has := c.allowedInputPorts[port]
if !has {
netInterfaces = make(map[string]struct{})
} else if _, exists := netInterfaces[intf]; exists {
return nil
}

c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")

const remove = false
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("allowing input to port %d through interface %s: %w",
port, intf, err)
}
netInterfaces[intf] = struct{}{}
c.allowedInputPorts[port] = netInterfaces

return nil
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

interfaceSet, ok := c.allowedInputPorts[port]
if !ok {
interfaceSet = make(map[string]struct{})
c.allowedInputPorts[port] = interfaceSet
}

_, alreadySet := interfaceSet[intf]
if alreadySet {
return nil
}

if c.enabled {
const remove = false
err = c.acceptInputToPort(ctx, intf, port, remove)
if err != nil {
return fmt.Errorf("accepting input port %d on interface %s: %w",
port, intf, err)
}

// ADD THIS: Re-apply user post-rules after port changes
if err = c.applyUserPostRules(ctx); err != nil {
return fmt.Errorf("re-applying user post-rules after port change: %w", err)
}
}

interfaceSet[intf] = struct{}{}
return nil
}

func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error) {
Expand Down
91 changes: 41 additions & 50 deletions internal/firewall/vpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,54 +7,45 @@ import (
"github.com/qdm12/gluetun/internal/models"
)

func (c *Config) SetVPNConnection(ctx context.Context,
connection models.Connection, vpnIntf string,
) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

if !c.enabled {
c.logger.Info("firewall disabled, only updating internal VPN connection")
c.vpnConnection = connection
return nil
}

c.logger.Info("allowing VPN connection...")

if c.vpnConnection.Equal(connection) {
return nil
}

remove := true
if c.vpnConnection.IP.IsValid() {
for _, defaultRoute := range c.defaultRoutes {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
}
}
}
c.vpnConnection = models.Connection{}

if c.vpnIntf != "" {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
}
}
c.vpnIntf = ""

remove = false

for _, defaultRoute := range c.defaultRoutes {
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
}
}
c.vpnConnection = connection

if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
}
c.vpnIntf = vpnIntf

return nil
func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connection, intf string) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

if !c.enabled {
c.vpnConnection = connection
c.vpnIntf = intf
return nil
}

// Remove previous VPN rules
if c.vpnConnection.IP.IsValid() {
const remove = true
interfacesSeen := make(map[string]struct{}, len(c.defaultRoutes))
for _, defaultRoute := range c.defaultRoutes {
_, seen := interfacesSeen[defaultRoute.NetInterface]
if seen {
continue
}
interfacesSeen[defaultRoute.NetInterface] = struct{}{}
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
if err != nil {
return fmt.Errorf("removing output traffic through VPN: %w", err)
}
}
}

c.vpnConnection = connection
c.vpnIntf = intf

// Add new VPN rules
if err = c.allowVPNIP(ctx); err != nil {
return err
}

// Re-apply user post-rules after VPN changes
if err = c.applyUserPostRules(ctx); err != nil {
return fmt.Errorf("re-applying user post-rules after VPN change: %w", err)
}

return nil
}