mirror of
https://github.com/bolkedebruin/rdpgw.git
synced 2026-02-10 18:07:44 +00:00
Add anysigned host selection algorithm
This commit is contained in:
parent
c79484091d
commit
b2794edb3f
@ -17,9 +17,6 @@ const (
|
||||
TlsDisable = "disable"
|
||||
TlsAuto = "auto"
|
||||
|
||||
HostSelectionSigned = "signed"
|
||||
HostSelectionRoundRobin = "roundrobin"
|
||||
|
||||
SessionStoreCookie = "cookie"
|
||||
SessionStoreFile = "file"
|
||||
|
||||
|
||||
9
cmd/rdpgw/hostselection/hostselection.go
Normal file
9
cmd/rdpgw/hostselection/hostselection.go
Normal file
@ -0,0 +1,9 @@
|
||||
package hostselection
|
||||
|
||||
const (
|
||||
Any = "any"
|
||||
AnySigned = "anysigned"
|
||||
RoundRobin = "roundrobin"
|
||||
Signed = "signed"
|
||||
Unsigned = "unsigned"
|
||||
)
|
||||
@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/hostselection"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -15,12 +17,12 @@ var (
|
||||
|
||||
func CheckHost(ctx context.Context, host string) (bool, error) {
|
||||
switch HostSelection {
|
||||
case "any":
|
||||
case hostselection.Any, hostselection.AnySigned:
|
||||
return true, nil
|
||||
case "signed":
|
||||
case hostselection.Signed:
|
||||
// todo get from context?
|
||||
return false, errors.New("cannot verify host in 'signed' mode as token data is missing")
|
||||
case "roundrobin", "unsigned":
|
||||
case hostselection.RoundRobin, hostselection.Unsigned:
|
||||
s := getTunnel(ctx)
|
||||
if s.User.UserName() == "" {
|
||||
return false, errors.New("no valid session info or username found in context")
|
||||
|
||||
@ -16,6 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/andrewheberle/rdpsign"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/hostselection"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
|
||||
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
|
||||
)
|
||||
@ -60,8 +61,8 @@ type Handler struct {
|
||||
}
|
||||
|
||||
func (c *Config) NewHandler() *Handler {
|
||||
if len(c.Hosts) < 1 {
|
||||
log.Fatal("Not enough hosts to connect to specified")
|
||||
if len(c.Hosts) < 1 && (c.HostSelection != hostselection.Any && c.HostSelection != hostselection.AnySigned) {
|
||||
log.Fatalf("Not enough hosts to connect to specified for %s host selection algorithm", c.HostSelection)
|
||||
}
|
||||
|
||||
handler := &Handler{
|
||||
@ -98,9 +99,16 @@ func (h *Handler) selectRandomHost() string {
|
||||
|
||||
func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
switch h.hostSelection {
|
||||
case "roundrobin":
|
||||
case hostselection.RoundRobin:
|
||||
return h.selectRandomHost(), nil
|
||||
case "signed":
|
||||
case hostselection.AnySigned:
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
}
|
||||
|
||||
return h.queryInfo(ctx, hosts[0], h.queryTokenIssuer)
|
||||
case hostselection.Signed:
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
@ -109,6 +117,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, check := range h.hosts {
|
||||
if check == host {
|
||||
@ -121,7 +130,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
return "", errors.New("invalid host specified in query token")
|
||||
}
|
||||
return host, nil
|
||||
case "unsigned":
|
||||
case hostselection.Unsigned:
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
@ -134,7 +143,7 @@ func (h *Handler) getHost(ctx context.Context, u *url.URL) (string, error) {
|
||||
// not found
|
||||
log.Printf("Invalid host %s specified in client request", hosts[0])
|
||||
return "", errors.New("invalid host specified in query parameter")
|
||||
case "any":
|
||||
case hostselection.Any:
|
||||
hosts, ok := u.Query()["host"]
|
||||
if !ok {
|
||||
return "", errors.New("invalid query parameter")
|
||||
@ -242,21 +251,22 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
|
||||
d.Settings.GatewayCredentialMethod = 1
|
||||
d.Settings.GatewayUsageMethod = 1
|
||||
|
||||
if h.rdpSigner != nil {
|
||||
// get rdp content
|
||||
rdpContent := d.String()
|
||||
|
||||
signedContent, err := h.rdpSigner.SignRdp(rdpContent)
|
||||
if err != nil {
|
||||
log.Printf("Could not sign RDP file due to %s", err)
|
||||
http.Error(w, errors.New("could not sign RDP file").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// return signd rdp file
|
||||
http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent))
|
||||
// no rdp siging
|
||||
if h.rdpSigner == nil {
|
||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String()))
|
||||
return
|
||||
}
|
||||
|
||||
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String()))
|
||||
// get rdp content
|
||||
rdpContent := d.String()
|
||||
|
||||
signedContent, err := h.rdpSigner.SignRdp(rdpContent)
|
||||
if err != nil {
|
||||
log.Printf("Could not sign RDP file due to %s", err)
|
||||
http.Error(w, errors.New("could not sign RDP file").Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// return signd rdp file
|
||||
http.ServeContent(w, r, fn, time.Now(), bytes.NewReader(signedContent))
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user