This commit is contained in:
Andrew Heberle 2026-02-04 14:01:53 +01:00 committed by GitHub
commit 178de0d387
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 64 additions and 25 deletions

View File

@ -167,6 +167,7 @@ Server:
# - signed, a listed host specified in the signed query parameter
# - unsigned, a listed host specified in the query parameter
# - any, insecurely allow any host specified in the query parameter
# - anysigned, allow any host specified in the signed query parameter
HostSelection: roundrobin
# a random strings of at least 32 characters to secure cookies on the client
# make sure to share this across the different pods

View File

@ -5,6 +5,7 @@ import (
"os"
"strings"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/confmap"
@ -17,9 +18,6 @@ const (
TlsDisable = "disable"
TlsAuto = "auto"
HostSelectionSigned = "signed"
HostSelectionRoundRobin = "roundrobin"
SessionStoreCookie = "cookie"
SessionStoreFile = "file"
@ -155,7 +153,7 @@ func Load(configFile string) Configuration {
"Server.Tls": "auto",
"Server.Port": 443,
"Server.SessionStore": "cookie",
"Server.HostSelection": "roundrobin",
"Server.HostSelection": hostselection.RoundRobin,
"Server.Authentication": "openid",
"Server.AuthSocket": "/tmp/rdpgw-auth.sock",
"Server.BasicAuthTimeout": 5,
@ -225,7 +223,7 @@ func Load(configFile string) Configuration {
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 {
if (Conf.Server.HostSelection == hostselection.Signed || Conf.Server.HostSelection == hostselection.AnySigned) && len(Conf.Security.QueryTokenSigningKey) == 0 {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
}

View File

@ -0,0 +1,9 @@
package hostselection
const (
Any = "any"
AnySigned = "anysigned"
RoundRobin = "roundrobin"
Signed = "signed"
Unsigned = "unsigned"
)

View File

@ -6,6 +6,8 @@ import (
"fmt"
"log"
"strings"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
)
var (
@ -15,12 +17,12 @@ var (
func CheckHost(ctx context.Context, host string) (bool, error) {
switch HostSelection {
case "any":
case hostselection.Any, hostselection.AnySigned:
return true, nil
case "signed":
case hostselection.Signed:
// todo get from context?
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
case "roundrobin", "unsigned":
case hostselection.RoundRobin, hostselection.Unsigned:
s := getTunnel(ctx)
if s.User.UserName() == "" {
return false, errors.New("no valid session info or username found in context")

View File

@ -2,9 +2,11 @@ package security
import (
"context"
"testing"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"testing"
)
var (
@ -26,18 +28,18 @@ func TestCheckHost(t *testing.T) {
Hosts = hosts
// check any
HostSelection = "any"
HostSelection = hostselection.Any
host := "try.my.server:3389"
if ok, err := CheckHost(ctx, host); !ok || err != nil {
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
}
HostSelection = "signed"
HostSelection = hostselection.Signed
if ok, err := CheckHost(ctx, host); ok || err == nil {
t.Fatalf("signed host selection isnt supported at the moment")
}
HostSelection = "roundrobin"
HostSelection = hostselection.RoundRobin
if ok, err := CheckHost(ctx, host); ok {
t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
}

View File

@ -20,6 +20,7 @@ import (
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
)
@ -92,8 +93,8 @@ type Handler struct {
}
func (c *Config) NewHandler() *Handler {
if len(c.Hosts) < 1 {
log.Fatal("Not enough hosts to connect to specified")
if len(c.Hosts) < 1 && (c.HostSelection != hostselection.Any && c.HostSelection != hostselection.AnySigned) {
log.Fatalf("Not enough hosts to connect to specified for %s host selection algorithm", c.HostSelection)
}
handler := &Handler{
@ -353,9 +354,16 @@ func (h *Handler) selectRandomHost() string {
func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
switch h.hostSelection {
case "roundrobin":
case hostselection.RoundRobin:
return h.selectRandomHost(), nil
case "signed":
case hostselection.AnySigned:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
return h.queryInfo(ctx, hosts[0], h.queryTokenIssuer)
case hostselection.Signed:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
@ -364,6 +372,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
if err != nil {
return "", err
}
found := false
for _, check := range h.hosts {
if check == host {
@ -376,7 +385,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
return "", errors.New("invalid host specified in query token")
}
return host, nil
case "unsigned":
case hostselection.Unsigned:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
@ -389,7 +398,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
// not found
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
case "any":
case hostselection.Any:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")

View File

@ -17,6 +17,7 @@ import (
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
@ -45,7 +46,7 @@ func contains(needle string, haystack []string) bool {
func TestGetHost(t *testing.T) {
ctx := context.Background()
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
}
h := c.NewHandler()
@ -64,7 +65,7 @@ func TestGetHost(t *testing.T) {
}
// check unsigned
c.HostSelection = "unsigned"
c.HostSelection = hostselection.Unsigned
vals.Set("host", "in.valid.host")
u.RawQuery = vals.Encode()
h = c.NewHandler()
@ -85,7 +86,7 @@ func TestGetHost(t *testing.T) {
}
// check any
c.HostSelection = "any"
c.HostSelection = hostselection.Any
test := "bla.bla.com"
vals.Set("host", test)
u.RawQuery = vals.Encode()
@ -99,7 +100,7 @@ func TestGetHost(t *testing.T) {
}
// check signed
c.HostSelection = "signed"
c.HostSelection = hostselection.Signed
c.QueryInfo = security.QueryInfo
issuer := "rdpgwtest"
security.QuerySigningKey = key
@ -117,6 +118,23 @@ func TestGetHost(t *testing.T) {
if host != hosts[0] {
t.Fatalf("%s does not equal %s", host, hosts[0])
}
// check anysigned (uses same issuer and querytoken as previous test)
c.HostSelection = hostselection.AnySigned
// should work with no hosts
c.Hosts = make([]string, 0)
c.QueryInfo = security.QueryInfo
security.QuerySigningKey = key
vals.Set("host", queryToken)
u.RawQuery = vals.Encode()
h = c.NewHandler()
host, err = h.getHost(ctx, u)
if err != nil {
t.Fatalf("Not accepted host %s is in hosts list (err: %s)", hosts[0], err)
}
if host != hosts[0] {
t.Fatalf("%s does not equal %s", host, hosts[0])
}
}
func TestHandler_HandleDownload(t *testing.T) {
@ -136,7 +154,7 @@ func TestHandler_HandleDownload(t *testing.T) {
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,
@ -199,7 +217,7 @@ func TestHandler_HandleSignedDownload(t *testing.T) {
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,
@ -292,7 +310,7 @@ func TestHandler_HandleDownloadWithRdpTemplate(t *testing.T) {
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,