Add anysigned host selection algorithm

This commit is contained in:
Andrew Heberle 2025-08-30 16:00:54 +08:00
parent c79484091d
commit b2794edb3f
4 changed files with 44 additions and 26 deletions

View File

@ -17,9 +17,6 @@ const (
TlsDisable = "disable"
TlsAuto = "auto"
HostSelectionSigned = "signed"
HostSelectionRoundRobin = "roundrobin"
SessionStoreCookie = "cookie"
SessionStoreFile = "file"

View File

@ -0,0 +1,9 @@
package hostselection
const (
Any = "any"
AnySigned = "anysigned"
RoundRobin = "roundrobin"
Signed = "signed"
Unsigned = "unsigned"
)

View File

@ -6,6 +6,8 @@ import (
"fmt"
"log"
"strings"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/hostselection"
)
var (
@ -15,12 +17,12 @@ var (
func CheckHost(ctx context.Context, host string) (bool, error) {
switch HostSelection {
case "any":
case hostselection.Any, hostselection.AnySigned:
return true, nil
case "signed":
case hostselection.Signed:
// todo get from context?
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
case "roundrobin", "unsigned":
case hostselection.RoundRobin, hostselection.Unsigned:
s := getTunnel(ctx)
if s.User.UserName() == "" {
return false, errors.New("no valid session info or username found in context")

View File

@ -16,6 +16,7 @@ import (
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
)
@ -60,8 +61,8 @@ type Handler struct {
}
func (c *Config) NewHandler() *Handler {
if len(c.Hosts) < 1 {
log.Fatal("Not enough hosts to connect to specified")
if len(c.Hosts) < 1 && (c.HostSelection != hostselection.Any && c.HostSelection != hostselection.AnySigned) {
log.Fatalf("Not enough hosts to connect to specified for %s host selection algorithm", c.HostSelection)
}
handler := &Handler{
@ -98,9 +99,16 @@ func (h *Handler) selectRandomHost() string {
func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
switch h.hostSelection {
case "roundrobin":
case hostselection.RoundRobin:
return h.selectRandomHost(), nil
case "signed":
case hostselection.AnySigned:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
}
return h.queryInfo(ctx, hosts[0], h.queryTokenIssuer)
case hostselection.Signed:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
@ -109,6 +117,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
if err != nil {
return "", err
}
found := false
for _, check := range h.hosts {
if check == host {
@ -121,7 +130,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
return "", errors.New("invalid host specified in query token")
}
return host, nil
case "unsigned":
case hostselection.Unsigned:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
@ -134,7 +143,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
// not found
log.Printf("Invalid host %s specified in client request", hosts[0])
return "", errors.New("invalid host specified in query parameter")
case "any":
case hostselection.Any:
hosts, ok := u.Query()["host"]
if !ok {
return "", errors.New("invalid query parameter")
@ -242,21 +251,22 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
d.Settings.GatewayCredentialMethod = 1
d.Settings.GatewayUsageMethod = 1
if h.rdpSigner != nil {
// get rdp content
rdpContent := d.String()
signedContent, err := h.rdpSigner.SignRdp(rdpContent)
if err != nil {
log.Printf("Could not sign RDP file due to %s", err)
http.Error(w, errors.New("could not sign RDP file").Error(), http.StatusInternalServerError)
return
}
// return signd rdp file
http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent))
// no rdp siging
if h.rdpSigner == nil {
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String()))
return
}
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String()))
// get rdp content
rdpContent := d.String()
signedContent, err := h.rdpSigner.SignRdp(rdpContent)
if err != nil {
log.Printf("Could not sign RDP file due to %s", err)
http.Error(w, errors.New("could not sign RDP file").Error(), http.StatusInternalServerError)
return
}
// return signd rdp file
http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent))
}