Use hostselection constants in tests and check for key when anysigned used

This commit is contained in:
Andrew Heberle 2025-09-06 11:33:21 +08:00
parent 076f28e1ce
commit f5276eb878
6 changed files with 19 additions and 15 deletions

View File

@ -5,6 +5,7 @@ import (
"os"
"strings"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/confmap"
@ -143,7 +144,7 @@ func Load(configFile string) Configuration {
"Server.Tls": "auto",
"Server.Port": 443,
"Server.SessionStore": "cookie",
"Server.HostSelection": "roundrobin",
"Server.HostSelection": hostselection.RoundRobin,
"Server.Authentication": "openid",
"Server.AuthSocket": "/tmp/rdpgw-auth.sock",
"Server.BasicAuthTimeout": 5,
@ -212,7 +213,7 @@ func Load(configFile string) Configuration {
log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random")
}
if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 {
if (Conf.Server.HostSelection == hostselection.Signed || Conf.Server.HostSelection == hostselection.AnySigned) && len(Conf.Security.QueryTokenSigningKey) == 0 {
log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set")
}

View File

@ -7,7 +7,7 @@ import (
"log"
"strings"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
)
var (

View File

@ -2,9 +2,11 @@ package security
import (
"context"
"testing"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/protocol"
"testing"
)
var (
@ -26,18 +28,18 @@ func TestCheckHost(t *testing.T) {
Hosts = hosts
// check any
HostSelection = "any"
HostSelection = hostselection.Any
host := "try.my.server:3389"
if ok, err := CheckHost(ctx, host); !ok || err != nil {
t.Fatalf("%s should be allowed with host selection %s (err: %s)", host, HostSelection, err)
}
HostSelection = "signed"
HostSelection = hostselection.Signed
if ok, err := CheckHost(ctx, host); ok || err == nil {
t.Fatalf("signed host selection isnt supported at the moment")
}
HostSelection = "roundrobin"
HostSelection = hostselection.RoundRobin
if ok, err := CheckHost(ctx, host); ok {
t.Fatalf("%s should NOT be allowed with host selection %s (err: %s)", host, HostSelection, err)
}

View File

@ -16,7 +16,7 @@ import (
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
)

View File

@ -17,6 +17,7 @@ import (
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/config/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
@ -45,7 +46,7 @@ func contains(needle string, haystack []string) bool {
func TestGetHost(t *testing.T) {
ctx := context.Background()
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
}
h := c.NewHandler()
@ -64,7 +65,7 @@ func TestGetHost(t *testing.T) {
}
// check unsigned
c.HostSelection = "unsigned"
c.HostSelection = hostselection.Unsigned
vals.Set("host", "in.valid.host")
u.RawQuery = vals.Encode()
h = c.NewHandler()
@ -85,7 +86,7 @@ func TestGetHost(t *testing.T) {
}
// check any
c.HostSelection = "any"
c.HostSelection = hostselection.Any
test := "bla.bla.com"
vals.Set("host", test)
u.RawQuery = vals.Encode()
@ -99,7 +100,7 @@ func TestGetHost(t *testing.T) {
}
// check signed
c.HostSelection = "signed"
c.HostSelection = hostselection.Signed
c.QueryInfo = security.QueryInfo
issuer := "rdpgwtest"
security.QuerySigningKey = key
@ -136,7 +137,7 @@ func TestHandler_HandleDownload(t *testing.T) {
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,
@ -199,7 +200,7 @@ func TestHandler_HandleSignedDownload(t *testing.T) {
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,
@ -292,7 +293,7 @@ func TestHandler_HandleDownloadWithRdpTemplate(t *testing.T) {
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
HostSelection: hostselection.RoundRobin,
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,