Merge remote-tracking branch 'origin/master' into anysigned

This commit is contained in:
Andrew Heberle 2025-09-06 11:08:15 +08:00
commit 657362c2f9
16 changed files with 320 additions and 38 deletions

85
.gitignore vendored
View File

@ -1,3 +1,86 @@
# Go build artifacts
go.sum
bin
bin/
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
*.out
coverage.html
# Go workspace file
go.work
# Editor files
*.swp
*.swo
*~
.vscode/
*.code-workspace
# IDE files - IntelliJ IDEA
.idea/
*.iml
*.ipr
*.iws
# IDE files - Eclipse
.project
.classpath
.c9/
*.launch
.settings/
.metadata
# IDE files - NetBeans
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Temporary files
tmp/
temp/
*.tmp
*.temp
# Log files
*.log
logs/
# Configuration files (may contain secrets)
config.yaml
config.yml
*.env
.env
.env.local
.env.*.local
# SSL/TLS certificates and keys
*.pem
*.key
*.crt
*.csr
*.p12
*.pfx
# Database files
*.db
*.sqlite
*.sqlite3
# Backup files
*.bak
*.backup

View File

@ -328,6 +328,12 @@ Client:
SplitUserDomain: false
# If true, removes "username" (and "domain" if SplitUserDomain is true) from RDP file.
# NoUsername: true
# If both SigningCert and SigningKey are set the downloaded RDP file will be signed
# so the client can authenticate the validity of the RDP file and reduce warnings from
# the client if the CA that issued the certificate is trusted. Both should be PEM encoded
# and the key must be an unencrypted RSA private key.
# SigningCert: /path/to/signing.crt
# SigningKey: /path/to/signing.key
Security:
# a random string of 32 characters to secure cookies on the client
# make sure to share this amongst different pods

View File

@ -239,8 +239,9 @@ func main() {
ntlm := web.NTLMAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout}
rdp.NewRoute().HeadersRegexp("Authorization", "NTLM").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol))
rdp.NewRoute().HeadersRegexp("Authorization", "Negotiate").HandlerFunc(ntlm.NTLMAuth(gw.HandleGatewayProtocol))
auth.Register(`NTLM`)
auth.Register(`Negotiate`)
auth.Register([]string{`NTLM`, `Negotiate`}, func(r *http.Request) bool {
return r.Header.Get("Sec-WebSocket-Protocol") != "binary" // rdp client for ios is incompatible with this NTLM method.
})
}
// basic auth
@ -248,7 +249,7 @@ func main() {
log.Printf("enabling basic authentication")
q := web.BasicAuthHandler{SocketAddress: conf.Server.AuthSocket, Timeout: conf.Server.BasicAuthTimeout}
rdp.NewRoute().HeadersRegexp("Authorization", "Basic").HandlerFunc(q.BasicAuth(gw.HandleGatewayProtocol))
auth.Register(`Basic realm="restricted", charset="UTF-8"`)
auth.Register([]string{`Basic realm="restricted", charset="UTF-8"`}, nil)
}
// spnego / kerberos
@ -266,7 +267,7 @@ func main() {
// kdcproxy
k := kdcproxy.InitKdcProxy(conf.Kerberos.Krb5Conf)
r.HandleFunc(kdcProxyEndPoint, k.Handler).Methods("POST")
auth.Register("Negotiate")
auth.Register([]string{"Negotiate"}, nil)
}
// setup server

View File

@ -3,17 +3,18 @@ package protocol
import (
"context"
"errors"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
"log"
"net"
"net/http"
"reflect"
"syscall"
"time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/patrickmn/go-cache"
)
const (
@ -140,7 +141,7 @@ func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error {
if !ptrSysFd.IsValid() {
return errors.New("cannot find Sysfd field")
}
fd := int(ptrSysFd.Int())
fd := int64ToFd(ptrSysFd.Int())
if g.ReceiveBuf > 0 {
err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf)

View File

@ -0,0 +1,8 @@
//go:build !windows
package protocol
// the fd arg to syscall.SetsockoptInt on Linix is of type int
func int64ToFd(n int64) int {
return int(n)
}

View File

@ -0,0 +1,10 @@
package protocol
import (
"syscall"
)
// the fd arg to syscall.SetsockoptInt on Windows is of type syscall.Handle
func int64ToFd(n int64) syscall.Handle {
return syscall.Handle(n)
}

View File

@ -49,7 +49,7 @@ type RdpSettings struct {
EnableRdsAasAuth bool `rdp:"enablerdsaadauth" default:"false"`
DisableConnectionSharing bool `rdp:"disableconnectionsharing" default:"false"`
AlternateShell string `rdp:"alternate shell"`
AutoReconnectionEnabled bool `rdp:"autoreconnectionenabled" default:"true"`
AutoReconnectionEnabled bool `rdp:"autoreconnection enabled" default:"true"`
BandwidthAutodetect bool `rdp:"bandwidthautodetect" default:"true"`
NetworkAutodetect bool `rdp:"networkautodetect" default:"true"`
Compression bool `rdp:"compression" default:"true"`

View File

@ -20,14 +20,13 @@ func TestRdpBuilder(t *testing.T) {
if !strings.Contains(s, "gatewayhostname:s:"+GatewayHostName+CRLF) {
t.Fatalf("%s does not contain `gatewayhostname:s:%s", s, GatewayHostName)
}
if strings.Contains(s, "autoreconnectionenabled") {
t.Fatalf("autoreconnectionenabled is in %s, but it's default value", s)
if strings.Contains(s, "autoreconnection enabled") {
t.Fatalf("autoreconnection enabled is in %s, but it's default value", s)
}
if !strings.Contains(s, "smart sizing:i:1"+CRLF) {
t.Fatalf("%s does not contain smart sizing:i:1", s)
}
log.Printf(builder.String())
log.Printf("%s", builder.String())
}
func TestInitStruct(t *testing.T) {

View File

@ -24,7 +24,7 @@ BitmapPersistenceEnabled:i:0
AudioRedirectionMode:i:2
EnablePortRedirection:i:0
EnableDriveRedirection:i:0
AutoReconnectEnabled:i:1
AutoReconnect Enabled:i:1
EnableSCardRedirection:i:1
EnablePrinterRedirection:i:0
BBarEnabled:i:0

View File

@ -5,21 +5,35 @@ import (
"net/http"
)
type AuthHeader struct {
header string
condition func(*http.Request) bool
}
type AuthMux struct {
headers []string
headers []AuthHeader
}
func NewAuthMux() *AuthMux {
return &AuthMux{}
}
func (a *AuthMux) Register(s string) {
a.headers = append(a.headers, s)
// Register adds authentication methods with optional condition function
func (a *AuthMux) Register(headers []string, condition func(*http.Request) bool) {
for _, header := range headers {
a.headers = append(a.headers, AuthHeader{
header: header,
condition: condition,
})
}
}
func (a *AuthMux) SetAuthenticate(w http.ResponseWriter, r *http.Request) {
for _, s := range a.headers {
w.Header().Add("WWW-Authenticate", s)
for _, authHeader := range a.headers {
// If condition is nil or condition returns true, add the header
if authHeader.condition == nil || authHeader.condition(r) {
w.Header().Add("WWW-Authenticate", authHeader.header)
}
}
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}

View File

@ -4,12 +4,13 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"net/http"
"time"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"net/http"
"time"
)
const (
@ -91,7 +92,7 @@ func (h *OIDC) HandleCallback(w http.ResponseWriter, r *http.Request) {
id.SetAuthTime(time.Now())
id.SetAttribute(identity.AttrAccessToken, oauth2Token.AccessToken)
if err = SaveSessionIdentity(r, w, id); err != nil {
if err := SaveSessionIdentity(r, w, id); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}

View File

@ -16,7 +16,6 @@ import (
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/hostselection"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
)
@ -187,7 +186,7 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
render := user
if opts.UsernameTemplate != "" {
render = fmt.Sprintf(h.rdpOpts.UsernameTemplate)
render = fmt.Sprint(h.rdpOpts.UsernameTemplate)
render = strings.Replace(render, "{{ username }}", user, 1)
if h.rdpOpts.UsernameTemplate == render {
log.Printf("Invalid username template. %s == %s", h.rdpOpts.UsernameTemplate, user)
@ -251,7 +250,7 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
d.Settings.GatewayCredentialMethod = 1
d.Settings.GatewayUsageMethod = 1
// no rdp siging
// no rdp siging so return as-is
if h.rdpSigner == nil {
http.ServeContent(w, r, fn, time.Now(), strings.NewReader(d.String()))
return
@ -260,7 +259,8 @@ func (h *Handler) HandleDownload(w http.ResponseWriter, r *http.Request) {
// get rdp content
rdpContent := d.String()
signedContent, err := h.rdpSigner.SignRdp(rdpContent)
// sign rdp content
signedContent, err := h.rdpSigner.Sign(rdpContent)
if err != nil {
log.Printf("Could not sign RDP file due to %s", err)
http.Error(w, errors.New("could not sign RDP file").Error(), http.StatusInternalServerError)

View File

@ -2,15 +2,25 @@ package web
import (
"context"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/andrewheberle/rdpsign"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/rdp"
"github.com/bolkedebruin/rdpgw/cmd/rdpgw/security"
"github.com/spf13/afero"
)
const (
@ -172,6 +182,89 @@ func TestHandler_HandleDownload(t *testing.T) {
}
func TestHandler_HandleSignedDownload(t *testing.T) {
req, err := http.NewRequest("GET", "/connect", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
id := identity.NewUser()
id.SetUserName(testuser)
id.SetAuthenticated(true)
req = identity.AddToRequestCtx(id, req)
ctx := req.Context()
u, _ := url.Parse(gateway)
c := Config{
HostSelection: "roundrobin",
Hosts: hosts,
PAATokenGenerator: paaTokenMock,
GatewayAddress: u,
RdpOpts: RdpOpts{SplitUserDomain: true},
}
h := c.NewHandler()
// set up rdp signer
fs := afero.NewMemMapFs()
if err := genKeypair(fs); err != nil {
t.Errorf("could not generate key pair for testing: %s", err)
}
signer, err := rdpsign.New("test.crt", "test.key", rdpsign.WithFs(fs))
if err != nil {
t.Errorf("could not create *rdpsign.Signer for testing: %s", err)
}
h.rdpSigner = signer
hh := http.HandlerFunc(h.HandleDownload)
hh.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if ctype := rr.Header().Get("Content-Type"); ctype != "application/x-rdp" {
t.Errorf("content type header does not match: got %v want %v",
ctype, "application/json")
}
if cdisp := rr.Header().Get("Content-Disposition"); cdisp == "" {
t.Errorf("content disposition is nil")
}
data := rdpToMap(strings.Split(rr.Body.String(), rdp.CRLF))
if data["username"] != testuser {
t.Errorf("username key in rdp does not match: got %v want %v", data["username"], testuser)
}
if data["gatewayhostname"] != u.Host {
t.Errorf("gatewayhostname key in rdp does not match: got %v want %v", data["gatewayhostname"], u.Host)
}
if token, _ := paaTokenMock(ctx, testuser, data["full address"]); token != data["gatewayaccesstoken"] {
t.Errorf("gatewayaccesstoken key in rdp does not match username_full address: got %v want %v",
data["gatewayaccesstoken"], token)
}
if !contains(data["full address"], hosts) {
t.Errorf("full address key in rdp is not in allowed hosts list: go %v want in %v",
data["full address"], hosts)
}
signscopeWant := "GatewayHostname,Full Address,GatewayCredentialsSource,GatewayProfileUsageMethod,GatewayUsageMethod,Alternate Full Address"
if data["signscope"] != signscopeWant {
t.Errorf("signscope key in rdp does not match: got %v want %v", data["signscope"], signscopeWant)
}
if _, found := data["signature"]; !found {
t.Errorf("no signature found in rdp")
}
}
func TestHandler_HandleDownloadWithRdpTemplate(t *testing.T) {
f, err := os.CreateTemp("", "rdp")
if err != nil {
@ -233,3 +326,68 @@ func rdpToMap(rdp []string) map[string]string {
return ret
}
func genKeypair(fs afero.Fs) error {
// generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
// convert to DER
der, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return err
}
// encode DER private key as PEM
if err := func() error {
f, err := fs.Create("test.key")
if err != nil {
return err
}
defer f.Close()
return pem.Encode(f, &pem.Block{
Type: "PRIVATE KEY",
Bytes: der,
})
}(); err != nil {
return err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Example Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute * 10),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return err
}
// encode cert as PEM
if err := func() error {
f, err := fs.Create("test.crt")
if err != nil {
return err
}
defer f.Close()
return pem.Encode(f, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
}(); err != nil {
return err
}
return nil
}

View File

@ -1,4 +1,4 @@
FROM golang:1
FROM golang:1.24
WORKDIR /src
ENV CGO_ENABLED 0
COPY go.mod go.sum ./

View File

@ -1,5 +1,5 @@
# builder stage
FROM golang:1.22-alpine as builder
FROM golang:1.24-alpine as builder
#RUN apt-get update && apt-get install -y libpam-dev
RUN apk --no-cache add git gcc musl-dev linux-pam-dev openssl

9
go.mod
View File

@ -1,14 +1,15 @@
module github.com/bolkedebruin/rdpgw
go 1.24.2
go 1.24.2
require (
github.com/andrewheberle/rdpsign v1.0.0
github.com/andrewheberle/rdpsign v1.1.0
github.com/bolkedebruin/gokrb5/v8 v8.5.0
github.com/coreos/go-oidc/v3 v3.9.0
github.com/fatih/structs v1.1.0
github.com/go-jose/go-jose/v4 v4.0.5
github.com/go-viper/mapstructure/v2 v2.3.0
github.com/go-viper/mapstructure/v2 v2.4.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/gorilla/sessions v1.2.2
@ -24,10 +25,11 @@ require (
github.com/msteinert/pam/v2 v2.0.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.19.0
github.com/spf13/afero v1.14.0
github.com/stretchr/testify v1.10.0
github.com/thought-machine/go-flags v1.6.3
golang.org/x/crypto v0.36.0
golang.org/x/oauth2 v0.18.0
golang.org/x/oauth2 v0.27.0
google.golang.org/grpc v1.62.1
google.golang.org/protobuf v1.33.0
)
@ -55,7 +57,6 @@ require (
golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240314234333-6e1732d8331c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)