Merge branch 'master' into ADG-10852-rewrites-enabled

This commit is contained in:
Stanislav Chzhen 2025-10-09 18:44:18 +03:00
commit 497441d595
115 changed files with 2270 additions and 734 deletions

View File

@ -1,7 +1,7 @@
'name': 'build'
'env':
'GO_VERSION': '1.25.1'
'GO_VERSION': '1.25.2'
'NODE_VERSION': '20'
'on':

View File

@ -1,7 +1,7 @@
'name': 'lint'
'env':
'GO_VERSION': '1.25.1'
'GO_VERSION': '1.25.2'
'on':
'push':

View File

@ -17,6 +17,11 @@ See also the [v0.107.68 GitHub milestone][ms-v0.107.68].
NOTE: Add new changes BELOW THIS COMMENT.
-->
### Security
- Go version has been updated to prevent the possibility of exploiting the Go vulnerabilities fixed in [1.25.2][go-1.25.2].
### Added
- New DNS rewrite settings endpoints `GET /control/rewrite/settings` and `PUT /control/rewrite/settings/update`. See `openapi/openapi.yaml` for details.
@ -53,6 +58,8 @@ In this release, the schema version has changed from 30 to 31.
To roll back this change, set `schema_version` back to `30`.
[go-1.25.2]: https://groups.google.com/g/golang-announce/c/4Emdl2iQ_bI
<!--
NOTE: Add new changes ABOVE THIS COMMENT.
-->

View File

@ -25,7 +25,7 @@ DIST_DIR = dist
GOAMD64 = v1
GOPROXY = https://proxy.golang.org|direct
GOTELEMETRY = off
GOTOOLCHAIN = go1.25.1
GOTOOLCHAIN = go1.25.2
GPG_KEY = devteam@adguard.com
GPG_KEY_PASSPHRASE = not-a-real-password
NPM = npm

View File

@ -8,7 +8,7 @@
'variables':
'channel': 'edge'
'dockerFrontend': 'adguard/home-js-builder:3.1'
'dockerGo': 'adguard/go-builder:1.25.1--1'
'dockerGo': 'adguard/go-builder:1.25.2--1'
'stages':
- 'Build frontend':

View File

@ -6,7 +6,7 @@
'name': 'AdGuard Home - Build and run tests'
'variables':
'dockerFrontend': 'adguard/home-js-builder:3.1'
'dockerGo': 'adguard/go-builder:1.25.1--1'
'dockerGo': 'adguard/go-builder:1.25.2--1'
'channel': 'development'
'stages':

View File

@ -203,7 +203,7 @@
"new_blocklist": "Новы чорны спіс",
"new_allowlist": "Новы белы спіс",
"edit_blocklist": "Рэдагаваць чорны спіс",
"edit_allowlist": "Рэдагаваць белы спіс",
"edit_allowlist": "Рэдагаваць спіс дазволеных",
"choose_blocklist": "Абярыце спісы блакаванняў",
"choose_allowlist": "Выберите списки разрешённых",
"enter_valid_blocklist": "Дадайце дзейны URL-адрас у чорны спіс.",
@ -247,7 +247,7 @@
"add_persistent_client": "Дадаць у захаваныя кліенты",
"time_table_header": "Час",
"date": "Дата",
"domain_name_table_header": "Дамен",
"domain_name_table_header": "Даменнае імя",
"domain_or_client": "Дамен ці кліент",
"type_table_header": "Тып",
"response_table_header": "Адказ",
@ -446,7 +446,7 @@
"down": "Уніз",
"fix": "Выправіць",
"dns_providers": "<0>Спіс вядомых DNS-правайдараў</0> на выбар.",
"update_now": "Абнавіць цяпер",
"update_now": "Абнавіць",
"update_failed": "Памылка аўто-абнаўлення. Калі ласка, <a>кіруйцеся інструкцыі</a> для абнаўлення ручна.",
"manual_update": "Калі ласка, <a>кіруйцеся інструкцыі</a> для абнаўлення ручна.",
"processing_update": "Калі ласка, пачакайце, AdGuard Home абнаўляецца",

View File

@ -13,7 +13,7 @@
"fallback_dns_desc": "Lista de servidores DNS alternativos utilizados cuando los proveedores DNS no responden. La sintaxis es la misma que en el campo de los principales proveedores DNS anterior.",
"fallback_dns_placeholder": "Ingresa un servidor DNS alternativo por línea",
"local_ptr_title": "Servidores DNS inversos y privados",
"local_ptr_desc": "Los servidores DNS que AdGuard Home utiliza para consultas PTR, SOA y NS privadas. La petición se considera privada si solicita un dominio ARPA que contiene una subred dentro de rangos IP privados, por ejemplo \"192.168.12.34\", y procede de un cliente con dirección privada. Si no se configura, AdGuard Home utiliza las direcciones de los resolvedores DNS predeterminados de tu sistema operativo, excepto las direcciones del propio AdGuard Home.",
"local_ptr_desc": "Servidores DNS que AdGuard Home utiliza para peticiones privadas PTR, SOA y NS. Una petición se considera privada si solicita un dominio ARPA que contiene una subred dentro de rangos de IP privados (como \"192.168.12.34\") y proviene de un cliente con una dirección IP privada. Si no se configura, se utilizarán los resolutores DNS predeterminados de tu sistema operativo, excepto para las direcciones IP de AdGuard Home.",
"local_ptr_default_resolver": "Por defecto, AdGuard Home utiliza los siguientes resolutores DNS inversos: {{ip}}.",
"local_ptr_no_default_resolver": "AdGuard Home no pudo determinar los resolutores DNS inversos y privados adecuados para este sistema.",
"local_ptr_placeholder": "Ingresa una dirección IP por línea",
@ -62,7 +62,7 @@
"dhcp_form_range_end": "Final de rango",
"dhcp_form_lease_title": "Tiempo de asignación DHCP (en segundos)",
"dhcp_form_lease_input": "Duración de asignación",
"dhcp_interface_select": "Seleccione la interfaz DHCP",
"dhcp_interface_select": "Seleccionar interfaz DHCP",
"dhcp_hardware_address": "Dirección MAC",
"dhcp_ip_addresses": "Direcciones IP",
"ip": "IP",
@ -122,7 +122,7 @@
"stats_query_domain": "Dominios más consultados",
"for_last_hours": "de la última {{count}} hora",
"for_last_hours_plural": "de las últimas {{count}} horas",
"for_last_days": "durante los últimos {{count}} días",
"for_last_days": "durante el último {{count}} día",
"for_last_days_plural": "durante los últimos {{count}} días",
"stats_disabled": "Las estadísticas se han deshabilitado. Puedes habilitarlas desde la <0>página de configuración</0>.",
"stats_disabled_short": "Las estadísticas se han deshabilitado",
@ -277,7 +277,7 @@
"query_log_configuration": "Configuración de registros",
"query_log_disabled": "El registro de consultas está deshabilitado y se puede configurar en la <0>configuración</0>",
"query_log_strict_search": "Usar comillas dobles para una búsqueda estricta",
"query_log_retention_confirm": "¿Está seguro de que deseas cambiar la rotación del registro de consultas? Si reduces el valor del intervalo, se perderán algunos datos",
"query_log_retention_confirm": "¿Estás seguro de que deseas cambiar la rotación del registro de consultas? Si disminuyes el valor del intervalo, se perderán algunos datos",
"anonymize_client_ip": "Anonimizar IP del cliente",
"anonymize_client_ip_desc": "No guarda la dirección IP completa del cliente en registros o estadísticas",
"dns_config": "Configuración del servidor DNS",
@ -292,8 +292,8 @@
"blocking_ipv4": "Bloqueo de IPv4",
"blocking_ipv6": "Bloqueo de IPv6",
"blocked_response_ttl": "Respuesta TTL bloqueada",
"blocked_response_ttl_desc": "Especifica durante cuántos segundos los clientes deben almacenar en cache una respuesta filtrada",
"form_enter_blocked_response_ttl": "Ingresa el TTL de respuesta bloqueada (segundos)",
"blocked_response_ttl_desc": "Especifica durante cuántos segundos los clientes deben almacenar en caché una respuesta filtrada",
"form_enter_blocked_response_ttl": "Ingresa el TTL de respuesta bloqueada (en segundos)",
"upstream_timeout": "Tiempo de espera del proveedor DNS",
"upstream_timeout_desc": "Especifica el número de segundos que se debe esperar para recibir una respuesta del proveedor DNS",
"form_enter_upstream_timeout": "Ingresa la duración de tiempo de espera del proveedor DNS en segundos",
@ -313,17 +313,17 @@
"edns_enable": "Habilitar subred de cliente EDNS",
"edns_cs_desc": "Añade la opción subred de cliente EDNS (ECS) a las peticiones del proveedor DNS y registra los valores enviados por los clientes en el registro de consultas.",
"edns_use_custom_ip": "Usar IP personalizada para EDNS",
"edns_use_custom_ip_desc": "Permitir el uso de IP personalizadas para EDNS",
"edns_use_custom_ip_desc": "Permitir el uso de IP personalizada para EDNS",
"rate_limit_desc": "Número de peticiones por segundo permitidas por cliente. Establecerlo en 0 significa que no hay límite.",
"rate_limit_subnet_len_ipv4": "Longitud del prefijo de subred para direcciones IPv4",
"rate_limit_subnet_len_ipv4_desc": "Longitud del prefijo de subred para direcciones IPv4 utilizadas para limitar la velocidad. El valor predeterminado es 24",
"rate_limit_subnet_len_ipv4_desc": "Longitud del prefijo de subred para direcciones IPv4 utilizadas para limitar la cantidad. El valor predeterminado es 24",
"rate_limit_subnet_len_ipv4_error": "La longitud del prefijo de subred IPv4 debe estar entre 0 y 32",
"rate_limit_subnet_len_ipv6": "Longitud del prefijo de subred para direcciones IPv6",
"rate_limit_subnet_len_ipv6_desc": "Longitud del prefijo de subred para direcciones IPv6 utilizadas para limitar la velocidad. El valor predeterminado es 56",
"rate_limit_subnet_len_ipv6_desc": "Longitud del prefijo de subred para direcciones IPv6 utilizadas para limitar la cantidad. El valor predeterminado es 56",
"rate_limit_subnet_len_ipv6_error": "La longitud del prefijo de subred IPv6 debe estar entre 0 y 128",
"form_enter_rate_limit_subnet_len": "Ingresa la longitud del prefijo de subred para limitar la velocidad",
"rate_limit_whitelist": "Lista de permitidos de limitación de velocidad",
"rate_limit_whitelist_desc": "Direcciones IP excluidas de la limitación de velocidad",
"form_enter_rate_limit_subnet_len": "Ingresa la longitud del prefijo de subred para limitar la cantidad",
"rate_limit_whitelist": "Lista de permitido de límite de cantidad",
"rate_limit_whitelist_desc": "Direcciones IP excluidas del límite de cantidad",
"rate_limit_whitelist_placeholder": "Ingresa una dirección IP por línea",
"blocking_ipv4_desc": "Dirección IP devolverá una petición A bloqueada",
"blocking_ipv6_desc": "Dirección IP devolverá una petición AAAA bloqueada",
@ -403,7 +403,7 @@
"encryption_server_enter": "Ingresa el nombre del dominio",
"encryption_server_desc": "Si se configura, AdGuard Home detecta los ID de clientes, responde a las consultas DDR y realiza validaciones de conexión adicionales. Si no se configura, estas funciones se deshabilitarán. Debe coincidir con uno de los nombres DNS del certificado.",
"encryption_redirect": "Redireccionar a HTTPS automáticamente",
"encryption_redirect_desc": "Si está marcado, AdGuard Home redireccionará automáticamente de HTTP a las direcciones HTTPS.",
"encryption_redirect_desc": "Si está marcada, AdGuard Home lo redireccionará automáticamente de direcciones HTTP a HTTPS.",
"encryption_https": "Puerto HTTPS",
"encryption_https_desc": "Si el puerto HTTPS está configurado, la interfaz de administración de AdGuard Home será accesible a través de HTTPS, y también proporcionará DNS mediante HTTPS en la ubicación '/dns-query'.",
"encryption_dot": "Puerto DNS mediante TLS",
@ -479,7 +479,7 @@
"client_confirm_delete": "¿Estás seguro de que deseas eliminar el cliente \"{{key}}\"?",
"list_confirm_delete": "¿Estás seguro de que deseas eliminar esta lista?",
"auto_clients_title": "Clientes activos",
"auto_clients_desc": "Información sobre las direcciones IP de los dispositivos que usan o pueden usar AdGuard Home. Esta información se recopila de varias fuentes, incluidos ficheros de host, DNS inverso, etc.",
"auto_clients_desc": "Información sobre las direcciones IP de los dispositivos que utilizan o pueden utilizar AdGuard Home. Esta información se recopila de varias fuentes, incluidos archivos hosts, DNS inverso, etc.",
"access_title": "Configuración de acceso",
"access_desc": "Aquí puedes configurar las reglas de acceso para el servidor DNS de AdGuard Home",
"access_allowed_title": "Clientes permitidos",
@ -513,9 +513,9 @@
"setup_dns_notice": "Para utilizar <1>DNS mediante HTTPS</1> o <1>DNS mediante TLS</1>, debes <0>configurar el cifrado</0> en la configuración de AdGuard Home.",
"rewrite_added": "Reescritura DNS para \"{{key}}\" añadido correctamente",
"rewrite_deleted": "Reescritura DNS para \"{{key}}\" eliminado correctamente",
"rewrite_updated": "Reconfiguración de DNS actualizada correctamente",
"rewrite_updated": "Reescritura DNS actualizada correctamente",
"rewrite_add": "Añadir reescritura DNS",
"rewrite_edit": "Editar reconfiguración de DNS",
"rewrite_edit": "Editar reescritura DNS",
"rewrite_not_found": "No se han encontrado reescrituras DNS",
"rewrite_confirm_delete": "¿Estás seguro de que deseas eliminar la reescritura DNS para \"{{key}}\"?",
"rewrite_desc": "Permite configurar fácilmente la respuesta DNS personalizada para un nombre de dominio específico.",
@ -557,10 +557,10 @@
"filter_updated": "La lista ha sido actualizada correctamente",
"statistics_configuration": "Configuración de estadísticas",
"statistics_retention": "Retención de estadísticas",
"statistics_retention_desc": "Si disminuye el valor del intervalo, se perderán algunos datos",
"statistics_retention_desc": "Si disminuyes el valor del intervalo, se perderán algunos datos",
"statistics_clear": "Borrar estadísticas",
"statistics_clear_confirm": "¿Estás seguro de que deseas borrar las estadísticas?",
"statistics_retention_confirm": "¿Estás seguro de que deseas cambiar la retención de estadísticas? Si disminuye el valor del intervalo, se perderán algunos datos",
"statistics_retention_confirm": "¿Estás seguro de que deseas cambiar la retención de estadísticas? Si disminuyes el valor del intervalo, se perderán algunos datos",
"statistics_cleared": "Estadísticas borradas correctamente",
"statistics_enable": "Habilitar estadísticas",
"ignore_domains": "Dominios ignorados (separados por una nueva línea)",
@ -599,7 +599,7 @@
"rewrite_A": "<0>A</0>: valor especial, mantiene registros <0>A</0> del proveedor DNS",
"rewrite_AAAA": "<0>AAAA</0>: valor especial, mantiene registros <0>AAAA</0> del proveedor DNS",
"disable_ipv6": "Deshabilitar resolución de direcciones IPv6",
"disable_ipv6_desc": "Descarta todas las consultas de DNS para direcciones IPv6 (tipo AAAA) y elimina las sugerencias de IPv6 de las respuestas HTTPS.",
"disable_ipv6_desc": "Descarta todas las consultas DNS para direcciones IPv6 (tipo AAAA) y elimina las sugerencias IPv6 de las respuestas HTTPS.",
"fastest_addr": "Dirección IP más rápida",
"fastest_addr_desc": "Espera respuestas de <b>todos</b> los servidores DNS, mide la velocidad de conexión TCP de cada servidor y devuelve la dirección IP del servidor con la velocidad de conexión más rápida.<br/>Este modo puede ralentizar significativamente las consultas DNS, si uno o más proveedores DNS no responden. Asegúrate de que tus proveedores DNS sean estables y de que el tiempo de espera tu proveedor DNS sea bajo.",
"autofix_warning_text": "Si haces clic en \"Corregir\", AdGuard Home configurará tu sistema para utilizar el servidor DNS de AdGuard Home.",
@ -607,7 +607,7 @@
"autofix_warning_result": "Como resultado, todas las peticiones DNS de tu sistema serán procesadas por AdGuard Home de manera predeterminada.",
"tags_title": "Etiquetas",
"tags_desc": "Puedes seleccionar las etiquetas que correspondan al cliente. Incluye etiquetas en las reglas de filtrado para aplicarlas con mayor precisión. <0>Más información</0>.",
"form_select_tags": "Seleccione las etiquetas del cliente",
"form_select_tags": "Selecciona las etiquetas del cliente",
"check_title": "Comprobar filtrado",
"check_desc": "Comprueba si un nombre del host está siendo filtrado.",
"check": "Comprobar",
@ -621,7 +621,7 @@
"check_reason": "Razón: {{reason}}",
"check_service": "Nombre del servicio: {{service}}",
"check_hostname": "Nombre de host o nombre de dominio",
"check_client_id": "Identificador del cliente (ClientID o dirección IP)",
"check_client_id": "Identificador del cliente (ID de cliente o dirección IP)",
"check_enter_client_id": "Ingresa el identificador del cliente",
"check_dns_record": "Selecciona el tipo de registro DNS",
"service_name": "Nombre del servicio",
@ -655,8 +655,8 @@
"safe_search": "Búsqueda segura",
"blocklist": "Lista de bloqueo",
"milliseconds_abbreviation": "ms",
"cache_enabled": "Activar caché",
"cache_enabled_desc": "Almacene las respuestas de DNS localmente.",
"cache_enabled": "Habilitar caché",
"cache_enabled_desc": "Almacena las respuestas DNS localmente.",
"cache_size": "Tamaño de la caché",
"cache_size_validation": "El tamaño de la cache debe ser mayor que cero cuando está habilitado.",
"cache_size_desc": "Tamaño de la caché DNS (en bytes). Para deshabilitar el almacenamiento en caché, establécelo en 0.",
@ -665,8 +665,8 @@
"enter_cache_size": "Ingresa el tamaño de la caché (bytes)",
"enter_cache_ttl_min_override": "Ingresa el TTL mínimo (en segundos)",
"enter_cache_ttl_max_override": "Ingresa el TTL máximo (en segundos)",
"cache_ttl_min_override_desc": "Amplía el corto tiempo de vida (segundos) de los valores recibidos del proveedor DNS al almacenar en caché las respuestas DNS.",
"cache_ttl_max_override_desc": "Establece un valor de tiempo de vida (segundos) máximo para las entradas en la caché DNS.",
"cache_ttl_min_override_desc": "Amplía el corto tiempo de vida (en segundos) de los valores recibidos del proveedor DNS al almacenar en caché las respuestas DNS.",
"cache_ttl_max_override_desc": "Establece un valor de tiempo de vida (en segundos) máximo para las entradas en la caché DNS.",
"ttl_cache_validation": "La anulación TTL mínimo de la caché debe ser menor o igual al máximo",
"cache_optimistic": "Caché optimista",
"cache_optimistic_desc": "Haz que AdGuard Home responda desde la caché incluso cuando las entradas estén expiradas y también intente actualizarlas.",

View File

@ -28,12 +28,6 @@ export default {
"homepage": "https://badmojr.github.io/1Hosts/",
"source": "https://adguardteam.github.io/HostlistsRegistry/assets/filter_24.txt"
},
"1hosts_pro": {
"name": "1Hosts (Pro)",
"categoryId": "general",
"homepage": "https://badmojr.github.io/1Hosts/",
"source": "https://adguardteam.github.io/HostlistsRegistry/assets/filter_64.txt"
},
"1hosts_xtra": {
"name": "1Hosts (Xtra)",
"categoryId": "general",

24
go.mod
View File

@ -1,10 +1,10 @@
module github.com/AdguardTeam/AdGuardHome
go 1.25.1
go 1.25.2
require (
github.com/AdguardTeam/dnsproxy v0.76.2
github.com/AdguardTeam/golibs v0.34.1
github.com/AdguardTeam/dnsproxy v0.77.0
github.com/AdguardTeam/golibs v0.35.0
github.com/AdguardTeam/urlfilter v0.22.0
github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.4.0
@ -35,15 +35,15 @@ require (
go.yaml.in/yaml/v4 v4.0.0-rc.2
golang.org/x/crypto v0.42.0
golang.org/x/exp v0.0.0-20250911091902-df9299821621
golang.org/x/net v0.44.0
golang.org/x/sys v0.36.0
golang.org/x/net v0.45.0
golang.org/x/sys v0.37.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
howett.net/plist v1.0.1
)
require (
cloud.google.com/go v0.123.0 // indirect
cloud.google.com/go/auth v0.16.5 // indirect
cloud.google.com/go/auth v0.17.0 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/ameshkov/dnsstamps v1.0.3 // indirect
@ -87,19 +87,19 @@ require (
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.uber.org/mock v0.6.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20250911091902-df9299821621 // indirect
golang.org/x/exp/typeparams v0.0.0-20251002181428-27f1f14c8bb9 // indirect
golang.org/x/mod v0.28.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/telemetry v0.0.0-20250930190813-8e6447515a8c // indirect
golang.org/x/telemetry v0.0.0-20251001141935-4eae98a72453 // indirect
golang.org/x/term v0.35.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.org/x/tools v0.37.0 // indirect
golang.org/x/vuln v1.1.4 // indirect
gonum.org/v1/gonum v0.16.0 // indirect
google.golang.org/genai v1.26.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.9 // indirect
google.golang.org/genai v1.28.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251007200510-49b9836ed3ff // indirect
google.golang.org/grpc v1.76.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
honnef.co/go/tools v0.6.1 // indirect
mvdan.cc/editorconfig v0.3.0 // indirect

44
go.sum
View File

@ -1,13 +1,13 @@
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.16.5 h1:mFWNQ2FEVWAliEQWpAdH80omXFokmrnbDhUS9cBywsI=
cloud.google.com/go/auth v0.16.5/go.mod h1:utzRfHMP+Vv0mpOkTRQoWD2q3BatTOoWbA7gCc2dUhQ=
cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4=
cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
github.com/AdguardTeam/dnsproxy v0.76.2 h1:Az8r9mHaI4Tmz/Hs6xh8/GRs76gW578qxDBtBhH7oZc=
github.com/AdguardTeam/dnsproxy v0.76.2/go.mod h1:dUg1PVDa993/6Px+pN0rf45Yex8cXTwm4Bb81jax17o=
github.com/AdguardTeam/golibs v0.34.1 h1:RyBpZiXnJqlO3T+xjWldlxsEZDelmaFfKvXiJHDZZFQ=
github.com/AdguardTeam/golibs v0.34.1/go.mod h1:K4C2EbfSEM1zY5YXoti9SfbTAHN/kIX97LpDtCwORrM=
github.com/AdguardTeam/dnsproxy v0.77.0 h1:hQUqNeSDx4hK9bx90lJ4nTE1gRbhjoYYCApm6drlgPU=
github.com/AdguardTeam/dnsproxy v0.77.0/go.mod h1:tWS7JZj0uOGXaiK4NvOANc1hAL8VVPYGrNe2FuuZOHY=
github.com/AdguardTeam/golibs v0.35.0 h1:O990+tbZ5W5yB0ybtaUJy4FUb0bXxyzeUC7t8cr1pCg=
github.com/AdguardTeam/golibs v0.35.0/go.mod h1:y552twxCtvOD8KKQ7ESjo10KZBAE+HSj24yAuAvz9IA=
github.com/AdguardTeam/urlfilter v0.22.0 h1:ybOz3FywbpGDGC+8gFFkM1LMUOSosY7CWSBXIYXnG1U=
github.com/AdguardTeam/urlfilter v0.22.0/go.mod h1:q0lWKapXlYTA4TUWUM1YDwU6Q0PKvQEokztcvRV2OW0=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
@ -201,8 +201,8 @@ golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/exp v0.0.0-20250911091902-df9299821621 h1:2id6c1/gto0kaHYyrixvknJ8tUK/Qs5IsmBtrc+FtgU=
golang.org/x/exp v0.0.0-20250911091902-df9299821621/go.mod h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk=
golang.org/x/exp/typeparams v0.0.0-20250911091902-df9299821621 h1:Yl4H5w2RV7L/dvSHp2GerziT5K2CORgFINPaMFxWGWw=
golang.org/x/exp/typeparams v0.0.0-20250911091902-df9299821621/go.mod h1:4Mzdyp/6jzw9auFDJ3OMF5qksa7UvPnzKqTVGcb04ms=
golang.org/x/exp/typeparams v0.0.0-20251002181428-27f1f14c8bb9 h1:EvjuVHWMoRaAxH402KMgrQpGUjoBy/OWvZjLOqQnwNk=
golang.org/x/exp/typeparams v0.0.0-20251002181428-27f1f14c8bb9/go.mod h1:4Mzdyp/6jzw9auFDJ3OMF5qksa7UvPnzKqTVGcb04ms=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U=
@ -213,8 +213,8 @@ golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
@ -227,10 +227,10 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20250930190813-8e6447515a8c h1:MsueSJrtclpfbcgPRFyj7XAy73cuuce6EYRcSutPUZY=
golang.org/x/telemetry v0.0.0-20250930190813-8e6447515a8c/go.mod h1:+nZKN+XVh4LCiA9DV3ywrzN4gumyCnKjau3NGb9SGoE=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20251001141935-4eae98a72453 h1:UMcclxirvpV79c6GDkin5Z9OeBachvXq6x4cUCGWhWY=
golang.org/x/telemetry v0.0.0-20251001141935-4eae98a72453/go.mod h1:+nZKN+XVh4LCiA9DV3ywrzN4gumyCnKjau3NGb9SGoE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ=
golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA=
@ -252,14 +252,14 @@ golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genai v1.26.0 h1:r4HGL54kFv/WCRMTAbZg05Ct+vXfhAbTRlXhFyBkEQo=
google.golang.org/genai v1.26.0/go.mod h1:OClfdf+r5aaD+sCd4aUSkPzJItmg2wD/WON9lQnRPaY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI=
google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
google.golang.org/genai v1.28.0 h1:6qpUWFH3PkHPhxNnu3wjaCVJ6Jri1EIR7ks07f9IpIk=
google.golang.org/genai v1.28.0/go.mod h1:7pAilaICJlQBonjKKJNhftDFv3SREhZcTe9F6nRcjbg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251007200510-49b9836ed3ff h1:A90eA31Wq6HOMIQlLfzFwzqGKBTuaVztYu/g8sn+8Zc=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251007200510-49b9836ed3ff/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A=
google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

View File

@ -20,10 +20,10 @@ import (
// TODO(e.burkov, a.garipov): Get rid of it.
type RegisterFunc func(method, url string, handler http.HandlerFunc)
// OK responds with word OK.
func OK(w http.ResponseWriter) {
// OK writes "OK\n" to the response. l and w must not be nil.
func OK(ctx context.Context, l *slog.Logger, w http.ResponseWriter) {
if _, err := io.WriteString(w, "OK\n"); err != nil {
log.Error("couldn't write body: %s", err)
l.WarnContext(ctx, "writing ok body", slogutil.KeyError, err)
}
}
@ -36,8 +36,8 @@ func Error(r *http.Request, w http.ResponseWriter, code int, format string, args
http.Error(w, text, code)
}
// ErrorAndLog writes formatted message to w and also logs it with the specified
// logging level.
// ErrorAndLog writes a formatted HTTP error response and logs it at
// [slog.LevelError] level. l, r, and w must not be nil.
func ErrorAndLog(
ctx context.Context,
l *slog.Logger,
@ -48,13 +48,14 @@ func ErrorAndLog(
args ...any,
) {
text := fmt.Sprintf(format, args...)
l.ErrorContext(
l.WarnContext(
ctx,
"http error",
"host", r.Host,
"method", r.Method,
"raddr", r.RemoteAddr,
"request_uri", r.RequestURI,
"status", code,
slogutil.KeyError, text,
)
@ -74,13 +75,18 @@ const textPlainDeprMsg = `using this api with the text/plain content-type is dep
// WriteTextPlainDeprecated responds to the request with a message about
// deprecation and removal of a plain-text API if the request is made with the
// "text/plain" content-type.
func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainText bool) {
// "text/plain" Content-Type. All arguments must not be nil.
func WriteTextPlainDeprecated(
ctx context.Context,
l *slog.Logger,
w http.ResponseWriter,
r *http.Request,
) (isPlainText bool) {
if r.Header.Get(httphdr.ContentType) != HdrValTextPlain {
return false
}
Error(r, w, http.StatusUnsupportedMediaType, textPlainDeprMsg)
ErrorAndLog(ctx, l, r, w, http.StatusUnsupportedMediaType, textPlainDeprMsg)
return true
}

View File

@ -1,6 +1,16 @@
package aghnet
// CheckOtherDHCP tries to discover another DHCP server in the network.
func CheckOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
return checkOtherDHCP(ifaceName)
import (
"context"
"log/slog"
)
// CheckOtherDHCP tries to discover another DHCP server in the network. l must
// not be nil.
func CheckOtherDHCP(
ctx context.Context,
l *slog.Logger,
ifaceName string,
) (ok4, ok6 bool, err4, err6 error) {
return checkOtherDHCP(ctx, l, ifaceName)
}

View File

@ -4,14 +4,16 @@ package aghnet
import (
"bytes"
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"os"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv6"
@ -23,7 +25,11 @@ import (
// response.
const defaultDiscoverTime = 3 * time.Second
func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
func checkOtherDHCP(
ctx context.Context,
l *slog.Logger,
ifaceName string,
) (ok4, ok6 bool, err4, err6 error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
err = fmt.Errorf("couldn't find interface by name %s: %w", ifaceName, err)
@ -32,8 +38,8 @@ func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
return false, false, err4, err6
}
ok4, err4 = checkOtherDHCPv4(iface)
ok6, err6 = checkOtherDHCPv6(iface)
ok4, err4 = checkOtherDHCPv4(ctx, l, iface)
ok6, err6 = checkOtherDHCPv6(ctx, l, iface)
return ok4, ok6, err4, err6
}
@ -69,8 +75,13 @@ func ifaceIPv4Subnet(iface *net.Interface) (subnet netip.Prefix, err error) {
}
// checkOtherDHCPv4 sends a DHCP request to the specified network interface, and
// waits for a response for a period defined by defaultDiscoverTime.
func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) {
// waits for a response for a period defined by defaultDiscoverTime. l must not
// be nil.
func checkOtherDHCPv4(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
) (ok bool, err error) {
var subnet netip.Prefix
if subnet, err = ifaceIPv4Subnet(iface); err != nil {
return false, err
@ -88,12 +99,18 @@ func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) {
return false, fmt.Errorf("couldn't get hostname: %w", err)
}
return discover4(iface, dstAddr, hostname)
return discover4(ctx, l, iface, dstAddr, hostname)
}
// discover4 sends a DHCPv4 discovery to the specified network interface and
// waits for response. iface and dstAddr must not be nil.
func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok bool, err error) {
func discover4(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
dstAddr *net.UDPAddr,
hostname string,
) (ok bool, err error) {
var req *dhcpv4.DHCPv4
if req, err = dhcpv4.NewDiscovery(iface.HardwareAddr); err != nil {
return false, fmt.Errorf("dhcpv4.NewDiscovery: %w", err)
@ -124,7 +141,7 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok
for {
var next bool
ok, next, err = tryConn4(req, c, iface)
ok, next, err = tryConn4(ctx, l, req, c, iface)
if next {
continue
}
@ -143,6 +160,8 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok
// TODO(a.garipov): Refactor further. Inspect error handling, remove parameter
// next, address the TODO, merge with tryConn6, etc.
func tryConn4(
ctx context.Context,
l *slog.Logger,
req *dhcpv4.DHCPv4,
c net.PacketConn,
iface *net.Interface,
@ -153,13 +172,13 @@ func tryConn4(
// TODO: replicate dhclient's behavior of retrying several times with
// progressively longer timeouts.
log.Tracef("dhcpv4: waiting %v for an answer", defaultDiscoverTime)
l.Log(ctx, slogutil.LevelTrace, "waiting for an answer", "timeout", defaultDiscoverTime)
b := make([]byte, 1500)
n, _, err := c.ReadFrom(b)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
log.Debug("dhcpv4: didn't receive dhcp response")
l.DebugContext(ctx, "did not receive response")
return false, false, nil
}
@ -167,16 +186,16 @@ func tryConn4(
return false, false, fmt.Errorf("receiving packet: %w", err)
}
log.Tracef("dhcpv4: received packet, %d bytes", n)
l.Log(ctx, slogutil.LevelTrace, "received packet", "size", n)
response, err := dhcpv4.FromBytes(b[:n])
if err != nil {
log.Debug("dhcpv4: encoding: %s", err)
l.DebugContext(ctx, "encoding", slogutil.KeyError, err)
return false, true, err
}
log.Debug("dhcpv4: received message from server: %s", response.Summary())
l.DebugContext(ctx, "received message from server", "summary", response.Summary())
switch {
case
@ -185,19 +204,24 @@ func tryConn4(
!bytes.Equal(response.ClientHWAddr, iface.HardwareAddr),
response.TransactionID != req.TransactionID,
!response.Options.Has(dhcpv4.OptionDHCPMessageType):
log.Debug("dhcpv4: received response doesn't match the request")
l.DebugContext(ctx, "dhcpv4: received response does not match the request")
return false, true, nil
default:
log.Tracef("dhcpv4: the packet is from an active dhcp server")
l.Log(ctx, slogutil.LevelTrace, "packet is from an active dhcp server")
return true, false, nil
}
}
// checkOtherDHCPv6 sends a DHCP request to the specified network interface, and
// waits for a response for a period defined by defaultDiscoverTime.
func checkOtherDHCPv6(iface *net.Interface) (ok bool, err error) {
// waits for a response for a period defined by defaultDiscoverTime. l must not
// be nil.
func checkOtherDHCPv6(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
) (ok bool, err error) {
ifaceIPNet, err := IfaceIPAddrs(iface, IPVersion6)
if err != nil {
return false, fmt.Errorf("getting ipv6 addrs for iface %s: %w", iface.Name, err)
@ -224,18 +248,24 @@ func checkOtherDHCPv6(iface *net.Interface) (ok bool, err error) {
return false, fmt.Errorf("dhcpv6: Couldn't resolve UDP address %s: %w", dst, err)
}
return discover6(iface, udpAddr, dstAddr)
return discover6(ctx, l, iface, udpAddr, dstAddr)
}
// discover6 sends a DHCPv6 discovery to the specified network interface and
// waits for response. iface, updAddr and dstAddr must not be nil.
func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, err error) {
func discover6(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
udpAddr *net.UDPAddr,
dstAddr *net.UDPAddr,
) (ok bool, err error) {
req, err := dhcpv6.NewSolicit(iface.HardwareAddr)
if err != nil {
return false, fmt.Errorf("dhcpv6: dhcpv6.NewSolicit: %w", err)
}
log.Debug("DHCPv6: Listening to udp6 %+v", udpAddr)
l.DebugContext(ctx, "listening on udp6", "addr", udpAddr)
c, err := nclient6.NewIPv6UDPConn(iface.Name, dhcpv6.DefaultClientPort)
if err != nil {
return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err)
@ -249,7 +279,7 @@ func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, er
for {
var next bool
ok, next, err = tryConn6(req, c)
ok, next, err = tryConn6(ctx, l, req, c)
if next {
continue
}
@ -266,10 +296,15 @@ func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, er
// the original request. req and c must not be nil.
//
// TODO(a.garipov): See the comment on tryConn4. Sigh…
func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error) {
func tryConn6(
ctx context.Context,
l *slog.Logger,
req *dhcpv6.Message,
c net.PacketConn,
) (ok, next bool, err error) {
// TODO: replicate dhclient's behavior of retrying several times with
// progressively longer timeouts.
log.Tracef("dhcpv6: waiting %v for an answer", defaultDiscoverTime)
l.Log(ctx, slogutil.LevelTrace, "waiting for an answer", "timeout", defaultDiscoverTime)
b := make([]byte, 4096)
err = c.SetDeadline(time.Now().Add(defaultDiscoverTime))
@ -280,7 +315,7 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
n, _, err := c.ReadFrom(b)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
log.Debug("dhcpv6: didn't receive dhcp response")
l.DebugContext(ctx, "did not receive response")
return false, false, nil
}
@ -288,21 +323,21 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
return false, false, fmt.Errorf("receiving packet: %w", err)
}
log.Tracef("dhcpv6: received packet, %d bytes", n)
l.Log(ctx, slogutil.LevelTrace, "dhcpv6: received packet", "size", n)
response, err := dhcpv6.FromBytes(b[:n])
if err != nil {
log.Debug("dhcpv6: encoding: %s", err)
l.DebugContext(ctx, "encoding", slogutil.KeyError, err)
return false, true, err
}
log.Debug("dhcpv6: received message from server: %s", response.Summary())
l.DebugContext(ctx, "received message from server", "summary", response.Summary())
cid := req.Options.ClientID()
msg, err := response.GetInnerMessage()
if err != nil {
log.Debug("dhcpv6: resp.GetInnerMessage(): %s", err)
l.DebugContext(ctx, "getting inner message", slogutil.KeyError, err)
return false, true, err
}
@ -313,12 +348,12 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
rcid != nil &&
cid.Equal(rcid)) {
log.Debug("dhcpv6: received message from server doesn't match our request")
l.DebugContext(ctx, "received message from server does not match our request")
return false, true, nil
}
log.Tracef("dhcpv6: the packet is from an active dhcp server")
l.Log(ctx, slogutil.LevelTrace, "dhcpv6: the packet is from an active dhcp server")
return true, false, nil
}

View File

@ -2,9 +2,18 @@
package aghnet
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
import (
"context"
"log/slog"
func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func checkOtherDHCP(
_ context.Context,
_ *slog.Logger,
ifaceName string,
) (ok4, ok6 bool, err4, err6 error) {
return false,
false,
aghos.Unsupported("CheckIfOtherDHCPServersPresentV4"),

View File

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/netip"
"path"
"sync/atomic"
@ -12,17 +13,21 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// hostsContainerPrefix is a prefix for logging and wrapping errors in
// HostsContainer's methods.
// hostsContainerPrefix is a prefix for wrapping errors in HostsContainer's
// methods.
const hostsContainerPrefix = "hosts container"
// HostsContainer stores the relevant hosts database provided by the OS and
// processes both A/AAAA and PTR DNS requests for those.
type HostsContainer struct {
// done is the channel to sign closing the container.
// logger is used for logging the operation of the hosts container. It must
// not be nil.
logger *slog.Logger
// done is the channel to signal closing the container.
done chan struct{}
// updates is the channel for receiving updated hosts.
@ -31,10 +36,12 @@ type HostsContainer struct {
// current is the last set of hosts parsed.
current atomic.Pointer[hostsfile.DefaultStorage]
// fsys is the working file system to read hosts files from.
// fsys is the working file system to read hosts files from. It must not be
// nil.
fsys fs.FS
// watcher tracks the changes in specified files and directories.
// watcher tracks the changes in specified files and directories. It must
// not be nil.
watcher aghos.FSWatcher
// patterns stores specified paths in the fs.Glob-compatible form.
@ -45,11 +52,12 @@ type HostsContainer struct {
// the HostsContainer.
const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided"
// NewHostsContainer creates a container of hosts, that watches the paths with
// w. listID is used as an identifier of the underlying rules list. paths
// shouldn't be empty and each of paths should locate either a file or a
// directory in fsys. fsys and w must be non-nil.
// NewHostsContainer creates a container of hosts that watches the paths with w.
// paths shouldn't be empty, and each path should refer either a file or a
// directory in fsys. l, fsys, and w must be non-nil.
func NewHostsContainer(
ctx context.Context,
l *slog.Logger,
fsys fs.FS,
w aghos.FSWatcher,
paths ...string,
@ -69,6 +77,7 @@ func NewHostsContainer(
}
hc = &HostsContainer{
logger: l,
done: make(chan struct{}, 1),
updates: make(chan *hostsfile.DefaultStorage, 1),
fsys: fsys,
@ -76,10 +85,10 @@ func NewHostsContainer(
patterns: patterns,
}
log.Debug("%s: starting", hostsContainerPrefix)
l.DebugContext(ctx, "starting")
// Load initially.
if err = hc.refresh(); err != nil {
if err = hc.refresh(ctx); err != nil {
return nil, err
}
@ -89,22 +98,24 @@ func NewHostsContainer(
return nil, fmt.Errorf("adding path: %w", err)
}
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPrefix, p)
l.DebugContext(ctx, "expected path does not exist", "path", p)
}
}
go hc.handleEvents()
go hc.handleEvents(ctx)
return hc, nil
}
// Close implements the [io.Closer] interface for *HostsContainer. It closes
// both itself and its [aghos.FSWatcher]. Close must only be called once.
//
// TODO(s.chzhen): Implement [service.Interface].
func (hc *HostsContainer) Close() (err error) {
log.Debug("%s: closing", hostsContainerPrefix)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
hc.logger.DebugContext(ctx, "closing")
err = errors.Annotate(hc.watcher.Shutdown(ctx), "closing fs watcher: %w")
// Go on and close the container either way.
@ -159,8 +170,8 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error)
// handleEvents concurrently handles the file system events. It closes the
// update channel of HostsContainer when finishes. It is intended to be used as
// a goroutine.
func (hc *HostsContainer) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix))
func (hc *HostsContainer) handleEvents(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, hc.logger)
defer close(hc.updates)
@ -170,13 +181,13 @@ func (hc *HostsContainer) handleEvents() {
select {
case _, ok = <-eventsCh:
if !ok {
log.Debug("%s: watcher closed the events channel", hostsContainerPrefix)
hc.logger.DebugContext(ctx, "watcher closed the events channel")
continue
}
if err := hc.refresh(); err != nil {
log.Error("%s: warning: refreshing: %s", hostsContainerPrefix, err)
if err := hc.refresh(ctx); err != nil {
hc.logger.ErrorContext(ctx, "refreshing", slogutil.KeyError, err)
}
case _, ok = <-hc.done:
// Go on.
@ -185,8 +196,8 @@ func (hc *HostsContainer) handleEvents() {
}
// sendUpd tries to send the parsed data to the ch.
func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) {
log.Debug("%s: sending upd", hostsContainerPrefix)
func (hc *HostsContainer) sendUpd(ctx context.Context, recs *hostsfile.DefaultStorage) {
hc.logger.DebugContext(ctx, "sending update")
ch := hc.updates
select {
@ -194,11 +205,11 @@ func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) {
// Updates are delivered. Go on.
case <-ch:
ch <- recs
log.Debug("%s: replaced the last update", hostsContainerPrefix)
hc.logger.DebugContext(ctx, "replaced the last update")
case ch <- recs:
// The previous update was just read and the next one pushed. Go on.
default:
log.Error("%s: the updates channel is broken", hostsContainerPrefix)
hc.logger.ErrorContext(ctx, "updates channel is broken")
}
}
@ -206,14 +217,16 @@ func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) {
// needed.
//
// TODO(e.burkov): Accept a parameter to specify the files to refresh.
func (hc *HostsContainer) refresh() (err error) {
log.Debug("%s: refreshing", hostsContainerPrefix)
func (hc *HostsContainer) refresh(ctx context.Context) (err error) {
hc.logger.DebugContext(ctx, "refreshing")
// The error is always nil here since no readers passed.
strg, _ := hostsfile.NewDefaultStorage()
strg, _ := hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{
Logger: hc.logger,
})
_, err = aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
// Don't wrap the error since it's already informative enough as is.
return nil, true, hostsfile.Parse(strg, r, nil)
return nil, true, hostsfile.Parse(ctx, strg, r, nil)
}).Walk(hc.fsys, hc.patterns...)
if err != nil {
// Don't wrap the error since it's informative enough as is.
@ -223,7 +236,7 @@ func (hc *HostsContainer) refresh() (err error) {
// TODO(e.burkov): Serialize updates using [time.Time].
if !hc.current.Load().Equal(strg) {
hc.current.Store(strg)
hc.sendUpd(strg)
hc.sendUpd(ctx, strg)
}
return nil

View File

@ -67,7 +67,9 @@ func TestNewHostsContainer(t *testing.T) {
return eventsCh
}
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) {
panic(testutil.UnexpectedCall(ctx))
},
@ -95,8 +97,9 @@ func TestNewHostsContainer(t *testing.T) {
}
t.Run("nil_fs", func(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
require.Panics(t, func() {
_, _ = aghnet.NewHostsContainer(nil, &aghtest.FSWatcher{
_, _ = aghnet.NewHostsContainer(ctx, testLogger, nil, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) {
panic(testutil.UnexpectedCall(ctx))
},
@ -110,7 +113,8 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_watcher", func(t *testing.T) {
require.Panics(t, func() {
_, _ = aghnet.NewHostsContainer(testFS, nil, p)
ctx := testutil.ContextWithTimeout(t, testTimeout)
_, _ = aghnet.NewHostsContainer(ctx, testLogger, testFS, nil, p)
})
})
@ -124,7 +128,8 @@ func TestNewHostsContainer(t *testing.T) {
OnShutdown: func(_ context.Context) (err error) { return nil },
}
hc, err := aghnet.NewHostsContainer(testFS, errWatcher, p)
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, errWatcher, p)
require.ErrorIs(t, err, errOnAdd)
assert.Nil(t, hc)
@ -173,12 +178,15 @@ func TestHostsContainer_refresh(t *testing.T) {
OnShutdown: func(_ context.Context) (err error) { return nil },
}
hc, err := aghnet.NewHostsContainer(testFS, w, "dir")
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, w, "dir")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
strg, _ := hostsfile.NewDefaultStorage()
strg.Add(r1)
strg, _ := hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{
Logger: testLogger,
})
strg.Add(ctx, r1)
t.Run("initial_refresh", func(t *testing.T) {
upd, ok := testutil.RequireReceive(t, hc.Upd(), 1*time.Second)
@ -187,7 +195,7 @@ func TestHostsContainer_refresh(t *testing.T) {
assert.True(t, strg.Equal(upd))
})
strg.Add(r2)
strg.Add(ctx, r2)
t.Run("second_refresh", func(t *testing.T) {
testFS["dir/file2"] = &fstest.MapFile{Data: r2Data}

View File

@ -1,11 +1,11 @@
package aghnet
import (
"context"
"fmt"
"log/slog"
"net"
"time"
"github.com/AdguardTeam/golibs/log"
)
// IPVersion is a alias for int for documentation purposes. Use it when the
@ -71,7 +71,7 @@ func ipFromAddr(addr net.Addr, ipv IPVersion) (ip net.IP) {
// IfaceDNSIPAddrs returns IP addresses of the interface suitable to send to
// clients as DNS addresses. If err is nil, addrs contains either no addresses
// or at least two.
// or at least two. l must not be nil.
//
// It makes up to maxAttempts attempts to get the addresses if there are none,
// each time using the provided backoff. Sometimes an interface needs a few
@ -79,6 +79,8 @@ func ipFromAddr(addr net.Addr, ipv IPVersion) (ip net.IP) {
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2304.
func IfaceDNSIPAddrs(
ctx context.Context,
l *slog.Logger,
iface NetIface,
ipv IPVersion,
maxAttempts int,
@ -95,7 +97,7 @@ func IfaceDNSIPAddrs(
break
}
log.Debug("dhcpv%d: attempt %d: no ip addresses", ipv, n)
l.DebugContext(ctx, "no ip addresses", "attempt", n, "ipv", ipv)
time.Sleep(backoff)
}
@ -107,7 +109,7 @@ func IfaceDNSIPAddrs(
// Don't return errors in case the users want to try and enable the DHCP
// server later.
t := time.Duration(n) * backoff
log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t)
l.ErrorContext(ctx, "no ip addresses for iface", "attempts", n, "duration", t, "ipv", ipv)
return nil, nil
case 1:
@ -116,13 +118,13 @@ func IfaceDNSIPAddrs(
// address.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/1708.
log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv)
l.DebugContext(ctx, "setting secondary dns ip to itself", "ipv", ipv)
addrs = append(addrs, addrs[0])
default:
// Go on.
}
log.Debug("dhcpv%d: got addresses %s after %d attempts", ipv, addrs, n)
l.DebugContext(ctx, "got addresses", "addrs", addrs, "attempts", n, "ipv", ipv)
return addrs, nil
}

View File

@ -220,7 +220,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := aghnet.IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
ctx := testutil.ContextWithTimeout(t, testTimeout)
got, err := aghnet.IfaceDNSIPAddrs(ctx, testLogger, tc.iface, tc.ipv, 2, 0)
require.ErrorIs(t, err, tc.wantErr)
assert.Equal(t, tc.want, got)

View File

@ -1,6 +1,4 @@
// Package aghnet contains networking utilities.
//
// TODO(s.chzhen): Use slog.
package aghnet
import (
@ -9,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"net/url"
@ -17,7 +16,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
@ -51,23 +50,25 @@ func IfaceHasStaticIP(
return ifaceHasStaticIP(ctx, cmdCons, ifaceName)
}
// IfaceSetStaticIP sets a static IP address for network interface. cmdCons
// must not be nil.
// IfaceSetStaticIP sets a static IP address for network interface. l and
// cmdCons must not be nil.
func IfaceSetStaticIP(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (err error) {
return ifaceSetStaticIP(ctx, cmdCons, ifaceName)
return ifaceSetStaticIP(ctx, l, cmdCons, ifaceName)
}
// GatewayIP returns the gateway IP address for the interface. cmdCons must not
// be nil.
// GatewayIP returns the gateway IP address for the interface. l and cmdCons
// must not be nil.
//
// TODO(e.burkov): Investigate if the gateway address may be fetched in another
// way since not every machine has the software installed.
func GatewayIP(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (ip netip.Addr) {
@ -83,11 +84,11 @@ func GatewayIP(
)
if err != nil {
if code, ok := executil.ExitCodeFromError(err); ok {
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
} else {
log.Debug("%s", err)
err = fmt.Errorf("unexpected exit code %d: %w", code, err)
}
l.DebugContext(ctx, "fetching gateway ip", slogutil.KeyError, err)
return netip.Addr{}
}
@ -104,9 +105,9 @@ func GatewayIP(
}
// CanBindPrivilegedPorts checks if current process can bind to privileged
// ports.
func CanBindPrivilegedPorts() (can bool, err error) {
return canBindPrivilegedPorts()
// ports. l must not be nil.
func CanBindPrivilegedPorts(ctx context.Context, l *slog.Logger) (can bool, err error) {
return canBindPrivilegedPorts(ctx, l)
}
// NetInterface represents an entry of network interfaces map.
@ -265,13 +266,13 @@ func InterfaceByIP(ip netip.Addr) (ifaceName string) {
}
// GetSubnet returns the subnet corresponding to the interface of zero prefix if
// the search fails.
// the search fails. l must not be nil.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func GetSubnet(ifaceName string) (p netip.Prefix) {
func GetSubnet(ctx context.Context, l *slog.Logger, ifaceName string) (p netip.Prefix) {
netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil {
log.Error("Could not get network interfaces info: %v", err)
l.ErrorContext(ctx, "could not get network interfaces info", slogutil.KeyError, err)
return p
}

View File

@ -2,8 +2,13 @@
package aghnet
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
import (
"context"
"log/slog"
func canBindPrivilegedPorts() (can bool, err error) {
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func canBindPrivilegedPorts(_ context.Context, _ *slog.Logger) (can bool, err error) {
return aghos.HaveAdminRights()
}

View File

@ -8,6 +8,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"regexp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@ -119,6 +120,7 @@ func getHardwarePortInfo(
// ifaceSetStaticIP sets a static IP on ifaceName. cmdCons must not be nil.
func ifaceSetStaticIP(
ctx context.Context,
_ *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (err error) {

View File

@ -249,7 +249,7 @@ func TestIfaceSetStaticIP(t *testing.T) {
substRootDirFS(t, tc.fsys)
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := IfaceSetStaticIP(ctx, tc.cmdCons, "en0")
err := IfaceSetStaticIP(ctx, testLogger, tc.cmdCons, "en0")
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}

View File

@ -7,6 +7,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@ -58,6 +59,11 @@ func (n interfaceName) rcConfStaticConfig(r io.Reader) (_ []string, cont bool, e
return nil, true, s.Err()
}
func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) {
func ifaceSetStaticIP(
_ context.Context,
_ *slog.Logger,
_ executil.CommandConstructor,
_ string,
) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
@ -24,6 +25,9 @@ const testTimeout = 1 * time.Second
// testCmdCons is the common command constructor for tests.
var testCmdCons = executil.EmptyCommandConstructor{}
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(tb testing.TB, fsys fs.FS) {
@ -87,7 +91,7 @@ func TestGatewayIP(t *testing.T) {
t.Parallel()
ctx := testutil.ContextWithTimeout(t, testTimeout)
assert.Equal(t, tc.want, GatewayIP(ctx, tc.cmdCons, ifaceName))
assert.Equal(t, tc.want, GatewayIP(ctx, testLogger, tc.cmdCons, ifaceName))
})
}
}

View File

@ -7,13 +7,14 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/v2/maybe"
@ -23,7 +24,7 @@ import (
// dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem.
const dhcpcdConf = "etc/dhcpcd.conf"
func canBindPrivilegedPorts() (can bool, err error) {
func canBindPrivilegedPorts(ctx context.Context, l *slog.Logger) (can bool, err error) {
res, err := unix.PrctlRetInt(
unix.PR_CAP_AMBIENT,
unix.PR_CAP_AMBIENT_IS_SET,
@ -35,7 +36,11 @@ func canBindPrivilegedPorts() (can bool, err error) {
if errors.Is(err, unix.EINVAL) {
// Older versions of Linux kernel do not support this. Print a
// warning and check admin rights.
log.Info("warning: cannot check capability cap_net_bind_service: %s", err)
l.WarnContext(
ctx,
"checking capability cap_net_bind_service",
slogutil.KeyError, err,
)
} else {
return false, err
}
@ -154,13 +159,14 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
}
// ifaceSetStaticIP configures the system to retain its current IP on the
// interface through dhcpcd.conf. cmdCons must not be nil.
// interface through dhcpcd.conf. l and cmdCons must not be nil.
func ifaceSetStaticIP(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (err error) {
ipNet := GetSubnet(ifaceName)
ipNet := GetSubnet(ctx, l, ifaceName)
if !ipNet.Addr().IsValid() {
return errors.Error("can't get IP address")
}
@ -170,7 +176,7 @@ func ifaceSetStaticIP(
return err
}
gatewayIP := GatewayIP(ctx, cmdCons, ifaceName)
gatewayIP := GatewayIP(ctx, l, cmdCons, ifaceName)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP)
body = append(body, []byte(add)...)

View File

@ -7,6 +7,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@ -45,6 +46,11 @@ func hostnameIfStaticConfig(r io.Reader) (_ []string, ok bool, err error) {
return nil, true, s.Err()
}
func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) {
func ifaceSetStaticIP(
_ context.Context,
_ *slog.Logger,
_ executil.CommandConstructor,
_ string,
) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@ -5,13 +5,21 @@ import (
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}

View File

@ -5,6 +5,7 @@ package aghnet
import (
"context"
"io"
"log/slog"
"syscall"
"time"
@ -14,7 +15,7 @@ import (
"golang.org/x/sys/windows"
)
func canBindPrivilegedPorts() (can bool, err error) {
func canBindPrivilegedPorts(_ context.Context, _ *slog.Logger) (can bool, err error) {
return true, nil
}
@ -26,7 +27,12 @@ func ifaceHasStaticIP(
return false, aghos.Unsupported("checking static ip")
}
func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) {
func ifaceSetStaticIP(
_ context.Context,
_ *slog.Logger,
_ executil.CommandConstructor,
_ string,
) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@ -153,10 +153,12 @@ func TestStorage_Add_hostsfile(t *testing.T) {
t.Run("add_hosts", func(t *testing.T) {
var s *hostsfile.DefaultStorage
s, err = hostsfile.NewDefaultStorage()
s, err = hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{
Logger: testLogger,
})
require.NoError(t, err)
s.Add(&hostsfile.Record{
s.Add(ctx, &hostsfile.Record{
Addr: cliIP1,
Names: []string{cliName1},
})
@ -173,10 +175,12 @@ func TestStorage_Add_hostsfile(t *testing.T) {
t.Run("update_hosts", func(t *testing.T) {
var s *hostsfile.DefaultStorage
s, err = hostsfile.NewDefaultStorage()
s, err = hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{
Logger: testLogger,
})
require.NoError(t, err)
s.Add(&hostsfile.Record{
s.Add(ctx, &hostsfile.Record{
Addr: cliIP2,
Names: []string{cliName2},
})
@ -452,10 +456,12 @@ func TestClientsDHCP(t *testing.T) {
require.True(t, t.Run("find_runtime_higher_priority", func(t *testing.T) {
// Add a higher-priority client.
s, strgErr := hostsfile.NewDefaultStorage()
s, strgErr := hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{
Logger: testLogger,
})
require.NoError(t, strgErr)
s.Add(&hostsfile.Record{
s.Add(ctx, &hostsfile.Record{
Addr: cliIP1,
Names: []string{cliName1},
})
@ -476,7 +482,9 @@ func TestClientsDHCP(t *testing.T) {
//
// TODO(a.garipov): Consider adding ways of explicitly clearing runtime
// sources by source.
s, strgErr = hostsfile.NewDefaultStorage()
s, strgErr = hostsfile.NewDefaultStorage(ctx, &hostsfile.DefaultStorageConfig{
Logger: testLogger,
})
require.NoError(t, strgErr)
testutil.RequireSend(t, etcHostsCh, s, testTimeout)
@ -576,6 +584,7 @@ func TestClientsAddExisting(t *testing.T) {
// First, init a DHCP server with a single static lease.
config := &dhcpd.ServerConfig{
Logger: testLogger,
Enabled: true,
DataDir: t.TempDir(),
Conf4: dhcpd.V4ServerConf{
@ -587,7 +596,8 @@ func TestClientsAddExisting(t *testing.T) {
},
}
dhcpServer, err := dhcpd.Create(config)
ctx = testutil.ContextWithTimeout(t, testTimeout)
dhcpServer, err := dhcpd.Create(ctx, config)
require.NoError(t, err)
storage, err := client.NewStorage(ctx, &client.StorageConfig{

View File

@ -0,0 +1,13 @@
package configmigrate
import (
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()

View File

@ -0,0 +1,13 @@
package configmigrate_test
import (
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()

View File

@ -12,17 +12,25 @@ import (
"github.com/stretchr/testify/require"
)
// TODO(e.burkov): Cover all migrations, use a testdata/ dir.
func TestUpgradeSchema1to2(t *testing.T) {
diskConf := testDiskConf(1)
m := New(&Config{
// emptyMigrator is a helper function that returns initialized with empty values
// *Migrator and no-op implementations for tests.
func emptyMigrator() (m *Migrator) {
return New(&Config{
Logger: testLogger,
WorkingDir: "",
DataDir: "",
})
}
err := m.migrateTo2(diskConf)
func TestUpgradeSchema1to2(t *testing.T) {
t.Parallel()
diskConf := testDiskConf(1)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo2(ctx, diskConf)
require.NoError(t, err)
require.Equal(t, diskConf["schema_version"], 2)
@ -43,9 +51,13 @@ func TestUpgradeSchema1to2(t *testing.T) {
}
func TestUpgradeSchema2to3(t *testing.T) {
t.Parallel()
diskConf := testDiskConf(2)
err := migrateTo3(diskConf)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo3(ctx, diskConf)
require.NoError(t, err)
require.Equal(t, diskConf["schema_version"], 3)
@ -75,6 +87,8 @@ func TestUpgradeSchema2to3(t *testing.T) {
}
func TestUpgradeSchema5to6(t *testing.T) {
t.Parallel()
const newSchemaVer = 6
testCases := []struct {
@ -156,7 +170,11 @@ func TestUpgradeSchema5to6(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo6(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo6(ctx, tc.in)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, tc.in)
})
@ -164,6 +182,8 @@ func TestUpgradeSchema5to6(t *testing.T) {
}
func TestUpgradeSchema7to8(t *testing.T) {
t.Parallel()
const host = "1.2.3.4"
oldConf := yobj{
"dns": yobj{
@ -172,7 +192,9 @@ func TestUpgradeSchema7to8(t *testing.T) {
"schema_version": 7,
}
err := migrateTo8(oldConf)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo8(ctx, oldConf)
require.NoError(t, err)
require.Equal(t, oldConf["schema_version"], 8)
@ -190,9 +212,13 @@ func TestUpgradeSchema7to8(t *testing.T) {
}
func TestUpgradeSchema8to9(t *testing.T) {
t.Parallel()
const tld = "foo"
t.Run("with_autohost_tld", func(t *testing.T) {
t.Parallel()
oldConf := yobj{
"dns": yobj{
"autohost_tld": tld,
@ -200,7 +226,9 @@ func TestUpgradeSchema8to9(t *testing.T) {
"schema_version": 8,
}
err := migrateTo9(oldConf)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo9(ctx, oldConf)
require.NoError(t, err)
require.Equal(t, oldConf["schema_version"], 9)
@ -218,12 +246,16 @@ func TestUpgradeSchema8to9(t *testing.T) {
})
t.Run("without_autohost_tld", func(t *testing.T) {
t.Parallel()
oldConf := yobj{
"dns": yobj{},
"schema_version": 8,
}
err := migrateTo9(oldConf)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo9(ctx, oldConf)
require.NoError(t, err)
require.Equal(t, oldConf["schema_version"], 9)
@ -315,6 +347,8 @@ func testDNSConf(schemaVersion int) (dnsConf yobj) {
}
func TestAddQUICPort(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
ups string
@ -387,6 +421,8 @@ func TestAddQUICPort(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
withPort := addQUICPort(tc.ups, 784)
assert.Equal(t, tc.want, withPort)
@ -395,6 +431,8 @@ func TestAddQUICPort(t *testing.T) {
}
func TestUpgradeSchema9to10(t *testing.T) {
t.Parallel()
const ultimateAns = 42
testCases := []struct {
@ -427,7 +465,11 @@ func TestUpgradeSchema9to10(t *testing.T) {
"schema_version": 9,
}
t.Run(tc.name, func(t *testing.T) {
err := migrateTo10(conf)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo10(ctx, conf)
if tc.wantErr != "" {
testutil.AssertErrorMsg(t, tc.wantErr, err)
@ -452,13 +494,21 @@ func TestUpgradeSchema9to10(t *testing.T) {
}
t.Run("no_dns", func(t *testing.T) {
err := migrateTo10(yobj{})
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo10(ctx, yobj{})
assert.NoError(t, err)
})
t.Run("bad_dns", func(t *testing.T) {
err := migrateTo10(yobj{
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo10(ctx, yobj{
"dns": ultimateAns,
})
@ -467,10 +517,14 @@ func TestUpgradeSchema9to10(t *testing.T) {
}
func TestUpgradeSchema10to11(t *testing.T) {
t.Parallel()
check := func(t *testing.T, conf yobj) {
rlimit, _ := conf["rlimit_nofile"].(int)
err := migrateTo11(conf)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo11(ctx, conf)
require.NoError(t, err)
require.Equal(t, conf["schema_version"], 11)
@ -498,6 +552,8 @@ func TestUpgradeSchema10to11(t *testing.T) {
const rlimit = 42
t.Run("with_rlimit", func(t *testing.T) {
t.Parallel()
conf := yobj{
"rlimit_nofile": rlimit,
"schema_version": 10,
@ -506,6 +562,8 @@ func TestUpgradeSchema10to11(t *testing.T) {
})
t.Run("without_rlimit", func(t *testing.T) {
t.Parallel()
conf := yobj{
"schema_version": 10,
}
@ -514,6 +572,8 @@ func TestUpgradeSchema10to11(t *testing.T) {
}
func TestUpgradeSchema11to12(t *testing.T) {
t.Parallel()
testCases := []struct {
ivl any
want any
@ -539,7 +599,11 @@ func TestUpgradeSchema11to12(t *testing.T) {
"schema_version": 11,
}
t.Run(tc.name, func(t *testing.T) {
err := migrateTo12(conf)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo12(ctx, conf)
if tc.wantErr != "" {
require.Error(t, err)
@ -568,13 +632,21 @@ func TestUpgradeSchema11to12(t *testing.T) {
}
t.Run("no_dns", func(t *testing.T) {
err := migrateTo12(yobj{})
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo12(ctx, yobj{})
assert.NoError(t, err)
})
t.Run("bad_dns", func(t *testing.T) {
err := migrateTo12(yobj{
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo12(ctx, yobj{
"dns": 0,
})
@ -582,11 +654,15 @@ func TestUpgradeSchema11to12(t *testing.T) {
})
t.Run("no_field", func(t *testing.T) {
t.Parallel()
conf := yobj{
"dns": yobj{},
}
err := migrateTo12(conf)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo12(ctx, conf)
require.NoError(t, err)
dns, ok := conf["dns"]
@ -609,6 +685,8 @@ func TestUpgradeSchema11to12(t *testing.T) {
}
func TestUpgradeSchema12to13(t *testing.T) {
t.Parallel()
const newSchemaVer = 13
testCases := []struct {
@ -646,7 +724,11 @@ func TestUpgradeSchema12to13(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo13(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo13(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -655,6 +737,8 @@ func TestUpgradeSchema12to13(t *testing.T) {
}
func TestUpgradeSchema13to14(t *testing.T) {
t.Parallel()
const newSchemaVer = 14
testClient := yobj{
@ -728,7 +812,11 @@ func TestUpgradeSchema13to14(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo14(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo14(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -737,6 +825,8 @@ func TestUpgradeSchema13to14(t *testing.T) {
}
func TestUpgradeSchema14to15(t *testing.T) {
t.Parallel()
const newSchemaVer = 15
defaultWantObj := yobj{
@ -776,7 +866,11 @@ func TestUpgradeSchema14to15(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo15(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo15(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -785,6 +879,8 @@ func TestUpgradeSchema14to15(t *testing.T) {
}
func TestUpgradeSchema15to16(t *testing.T) {
t.Parallel()
const newSchemaVer = 16
defaultWantObj := yobj{
@ -835,7 +931,11 @@ func TestUpgradeSchema15to16(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo16(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo16(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -844,6 +944,8 @@ func TestUpgradeSchema15to16(t *testing.T) {
}
func TestUpgradeSchema16to17(t *testing.T) {
t.Parallel()
const newSchemaVer = 17
defaultWantObj := yobj{
@ -896,7 +998,11 @@ func TestUpgradeSchema16to17(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo17(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo17(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -905,6 +1011,8 @@ func TestUpgradeSchema16to17(t *testing.T) {
}
func TestUpgradeSchema17to18(t *testing.T) {
t.Parallel()
const newSchemaVer = 18
defaultWantObj := yobj{
@ -955,7 +1063,11 @@ func TestUpgradeSchema17to18(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo18(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo18(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -964,6 +1076,8 @@ func TestUpgradeSchema17to18(t *testing.T) {
}
func TestUpgradeSchema18to19(t *testing.T) {
t.Parallel()
const newSchemaVer = 19
defaultWantObj := yobj{
@ -1039,7 +1153,11 @@ func TestUpgradeSchema18to19(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo19(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo19(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -1048,6 +1166,8 @@ func TestUpgradeSchema18to19(t *testing.T) {
}
func TestUpgradeSchema19to20(t *testing.T) {
t.Parallel()
testCases := []struct {
ivl any
want any
@ -1078,7 +1198,11 @@ func TestUpgradeSchema19to20(t *testing.T) {
"schema_version": 19,
}
t.Run(tc.name, func(t *testing.T) {
err := migrateTo20(conf)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo20(ctx, conf)
if tc.wantErr != "" {
require.Error(t, err)
@ -1107,13 +1231,21 @@ func TestUpgradeSchema19to20(t *testing.T) {
}
t.Run("no_stats", func(t *testing.T) {
err := migrateTo20(yobj{})
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo20(ctx, yobj{})
assert.NoError(t, err)
})
t.Run("bad_stats", func(t *testing.T) {
err := migrateTo20(yobj{
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo20(ctx, yobj{
"statistics": 0,
})
@ -1121,11 +1253,15 @@ func TestUpgradeSchema19to20(t *testing.T) {
})
t.Run("no_field", func(t *testing.T) {
t.Parallel()
conf := yobj{
"statistics": yobj{},
}
err := migrateTo20(conf)
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo20(ctx, conf)
require.NoError(t, err)
statsVal, ok := conf["statistics"]
@ -1148,6 +1284,8 @@ func TestUpgradeSchema19to20(t *testing.T) {
}
func TestUpgradeSchema20to21(t *testing.T) {
t.Parallel()
const newSchemaVer = 21
testCases := []struct {
@ -1182,7 +1320,11 @@ func TestUpgradeSchema20to21(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo21(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo21(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -1191,6 +1333,8 @@ func TestUpgradeSchema20to21(t *testing.T) {
}
func TestUpgradeSchema21to22(t *testing.T) {
t.Parallel()
const newSchemaVer = 22
testCases := []struct {
@ -1252,7 +1396,11 @@ func TestUpgradeSchema21to22(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo22(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo22(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -1261,6 +1409,8 @@ func TestUpgradeSchema21to22(t *testing.T) {
}
func TestUpgradeSchema22to23(t *testing.T) {
t.Parallel()
const newSchemaVer = 23
testCases := []struct {
@ -1305,7 +1455,11 @@ func TestUpgradeSchema22to23(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo23(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo23(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -1314,6 +1468,8 @@ func TestUpgradeSchema22to23(t *testing.T) {
}
func TestUpgradeSchema23to24(t *testing.T) {
t.Parallel()
const newSchemaVer = 24
testCases := []struct {
@ -1372,7 +1528,11 @@ func TestUpgradeSchema23to24(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo24(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo24(ctx, tc.in)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, tc.in)
@ -1381,6 +1541,8 @@ func TestUpgradeSchema23to24(t *testing.T) {
}
func TestUpgradeSchema24to25(t *testing.T) {
t.Parallel()
const newSchemaVer = 25
testCases := []struct {
@ -1459,7 +1621,11 @@ func TestUpgradeSchema24to25(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo25(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo25(ctx, tc.in)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, tc.in)
@ -1468,6 +1634,8 @@ func TestUpgradeSchema24to25(t *testing.T) {
}
func TestUpgradeSchema25to26(t *testing.T) {
t.Parallel()
const newSchemaVer = 26
testCases := []struct {
@ -1558,7 +1726,11 @@ func TestUpgradeSchema25to26(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo26(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo26(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -1567,6 +1739,8 @@ func TestUpgradeSchema25to26(t *testing.T) {
}
func TestUpgradeSchema26to27(t *testing.T) {
t.Parallel()
const newSchemaVer = 27
testCases := []struct {
@ -1641,7 +1815,11 @@ func TestUpgradeSchema26to27(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo27(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo27(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
@ -1650,6 +1828,8 @@ func TestUpgradeSchema26to27(t *testing.T) {
}
func TestUpgradeSchema27to28(t *testing.T) {
t.Parallel()
const newSchemaVer = 28
testCases := []struct {
@ -1722,7 +1902,11 @@ func TestUpgradeSchema27to28(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := migrateTo28(tc.in)
t.Parallel()
m := emptyMigrator()
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := m.migrateTo28(ctx, tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)

View File

@ -2,14 +2,19 @@ package configmigrate
import (
"bytes"
"context"
"fmt"
"log/slog"
"github.com/AdguardTeam/golibs/log"
yaml "go.yaml.in/yaml/v4"
)
// Config is a the configuration for initializing a [Migrator].
type Config struct {
// Logger is used to log the operation of configuration migrator. It must
// not be nil.
Logger *slog.Logger
// WorkingDir is the absolute path to the working directory of AdGuardHome.
WorkingDir string
@ -19,6 +24,7 @@ type Config struct {
// Migrator performs the YAML configuration file migrations.
type Migrator struct {
logger *slog.Logger
workingDir string
dataDir string
}
@ -26,6 +32,7 @@ type Migrator struct {
// New creates a new Migrator.
func New(c *Config) (m *Migrator) {
return &Migrator{
logger: c.Logger,
workingDir: c.WorkingDir,
dataDir: c.DataDir,
}
@ -35,7 +42,11 @@ func New(c *Config) (m *Migrator) {
// schema version, if needed. It returns the body of the upgraded config file,
// whether the file was upgraded, and an error, if any. If upgraded is false,
// the body is the same as the input.
func (m *Migrator) Migrate(body []byte, target uint) (newBody []byte, upgraded bool, err error) {
func (m *Migrator) Migrate(
ctx context.Context,
body []byte,
target uint,
) (newBody []byte, upgraded bool, err error) {
diskConf := yobj{}
err = yaml.Unmarshal(body, &diskConf)
if err != nil {
@ -49,7 +60,7 @@ func (m *Migrator) Migrate(body []byte, target uint) (newBody []byte, upgraded b
}
current := uint(currentInt)
log.Debug("got schema version %v", current)
m.logger.DebugContext(ctx, "got", "schema_version", current)
if err = validateVersion(current, target); err != nil {
// Don't wrap the error, since it's informative enough as is.
@ -58,7 +69,7 @@ func (m *Migrator) Migrate(body []byte, target uint) (newBody []byte, upgraded b
return body, false, nil
}
if err = m.upgradeConfigSchema(current, target, diskConf); err != nil {
if err = m.upgradeConfigSchema(ctx, current, target, diskConf); err != nil {
// Don't wrap the error, since it's informative enough as is.
return body, false, err
}
@ -89,41 +100,45 @@ func validateVersion(current, target uint) (err error) {
}
// migrateFunc is a function that upgrades a config and returns an error.
type migrateFunc = func(diskConf yobj) (err error)
type migrateFunc = func(ctx context.Context, diskConf yobj) (err error)
// upgradeConfigSchema upgrades the configuration schema in diskConf from
// current to target version. current must be less than target, and both must
// be non-negative and less or equal to [LastSchemaVersion].
func (m *Migrator) upgradeConfigSchema(current, target uint, diskConf yobj) (err error) {
func (m *Migrator) upgradeConfigSchema(
ctx context.Context,
current, target uint,
diskConf yobj,
) (err error) {
upgrades := [LastSchemaVersion]migrateFunc{
0: m.migrateTo1,
1: m.migrateTo2,
2: migrateTo3,
3: migrateTo4,
4: migrateTo5,
5: migrateTo6,
6: migrateTo7,
7: migrateTo8,
8: migrateTo9,
9: migrateTo10,
10: migrateTo11,
11: migrateTo12,
12: migrateTo13,
13: migrateTo14,
14: migrateTo15,
15: migrateTo16,
16: migrateTo17,
17: migrateTo18,
18: migrateTo19,
19: migrateTo20,
20: migrateTo21,
21: migrateTo22,
22: migrateTo23,
23: migrateTo24,
24: migrateTo25,
25: migrateTo26,
26: migrateTo27,
27: migrateTo28,
2: m.migrateTo3,
3: m.migrateTo4,
4: m.migrateTo5,
5: m.migrateTo6,
6: m.migrateTo7,
7: m.migrateTo8,
8: m.migrateTo9,
9: m.migrateTo10,
10: m.migrateTo11,
11: m.migrateTo12,
12: m.migrateTo13,
13: m.migrateTo14,
14: m.migrateTo15,
15: m.migrateTo16,
16: m.migrateTo17,
17: m.migrateTo18,
18: m.migrateTo19,
19: m.migrateTo20,
20: m.migrateTo21,
21: m.migrateTo22,
22: m.migrateTo23,
23: m.migrateTo24,
24: m.migrateTo25,
25: m.migrateTo26,
26: m.migrateTo27,
27: m.migrateTo28,
28: m.migrateTo29,
29: m.migrateTo30,
30: m.migrateTo31,
@ -133,9 +148,9 @@ func (m *Migrator) upgradeConfigSchema(current, target uint, diskConf yobj) (err
cur := current + uint(i)
next := current + uint(i) + 1
log.Printf("Upgrade yaml: %d to %d", cur, next)
m.logger.InfoContext(ctx, "upgrade yaml", "from", cur, "to", next)
if err = migrate(diskConf); err != nil {
if err = migrate(ctx, diskConf); err != nil {
return fmt.Errorf("migrating schema %d to %d: %w", cur, next, err)
}
}

View File

@ -16,10 +16,6 @@ import (
"golang.org/x/crypto/bcrypt"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// testdata is a virtual filesystem containing test data.
var testdata = os.DirFS("testdata")
@ -195,6 +191,10 @@ func TestMigrateConfig_Migrate(t *testing.T) {
yamlEqFunc: require.YAMLEq,
name: "v27",
targetVersion: 27,
}, {
yamlEqFunc: require.YAMLEq,
name: "v28",
targetVersion: 28,
}, {
yamlEqFunc: require.YAMLEq,
name: "v30",
@ -216,10 +216,12 @@ func TestMigrateConfig_Migrate(t *testing.T) {
require.NoError(t, err)
migrator := configmigrate.New(&configmigrate.Config{
Logger: testLogger,
WorkingDir: t.Name(),
DataDir: filepath.Join(t.Name(), "data"),
})
newBody, upgraded, err := migrator.Migrate(body, tc.targetVersion)
ctx := testutil.ContextWithTimeout(t, testTimeout)
newBody, upgraded, err := migrator.Migrate(ctx, body, tc.targetVersion)
require.NoError(t, err)
require.True(t, upgraded)
@ -230,6 +232,8 @@ func TestMigrateConfig_Migrate(t *testing.T) {
// TODO(a.garipov): Consider ways of merging into the previous one.
func TestMigrateConfig_Migrate_v29(t *testing.T) {
t.Parallel()
const (
pathUnix = `/path/to/file.txt`
userDirPatUnix = `TestMigrateConfig_Migrate/v29/data/userfilters/*`
@ -257,11 +261,13 @@ func TestMigrateConfig_Migrate_v29(t *testing.T) {
wantBody = bytes.ReplaceAll(wantBody, []byte("USERFILTERSPATH"), []byte(patternToReplace))
migrator := configmigrate.New(&configmigrate.Config{
Logger: testLogger,
WorkingDir: t.Name(),
DataDir: "TestMigrateConfig_Migrate/v29/data",
})
newBody, upgraded, err := migrator.Migrate(body, 29)
ctx := testutil.ContextWithTimeout(t, testTimeout)
newBody, upgraded, err := migrator.Migrate(ctx, body, 29)
require.NoError(t, err)
require.True(t, upgraded)

View File

@ -0,0 +1,124 @@
http:
address: 127.0.0.1:3000
session_ttl: 3h
pprof:
enabled: true
port: 6060
users:
- name: testuser
password: testpassword
dns:
bind_hosts:
- 127.0.0.1
port: 53
parental_sensitivity: 0
upstream_dns:
- tls://1.1.1.1
- tls://1.0.0.1
- quic://8.8.8.8:784
bootstrap_dns:
- 8.8.8.8:53
all_servers: true
edns_client_subnet:
enabled: true
use_custom: false
custom_ip: ""
filtering:
filtering_enabled: true
parental_enabled: false
safebrowsing_enabled: false
safe_search:
enabled: false
bing: true
duckduckgo: true
google: true
pixabay: true
yandex: true
youtube: true
protection_enabled: true
blocked_services:
schedule:
time_zone: Local
ids:
- 500px
blocked_response_ttl: 10
filters:
- url: https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt
name: ""
enabled: true
- url: https://adaway.org/hosts.txt
name: AdAway
enabled: false
- url: https://hosts-file.net/ad_servers.txt
name: hpHosts - Ad and Tracking servers only
enabled: false
- url: http://www.malwaredomainlist.com/hostslist/hosts.txt
name: MalwareDomainList.com Hosts List
enabled: false
clients:
persistent:
- name: localhost
ids:
- 127.0.0.1
- aa:aa:aa:aa:aa:aa
use_global_settings: true
use_global_blocked_services: true
filtering_enabled: false
parental_enabled: false
safebrowsing_enabled: false
safe_search:
enabled: true
bing: true
duckduckgo: true
google: true
pixabay: true
yandex: true
youtube: true
blocked_services:
schedule:
time_zone: Local
ids:
- 500px
runtime_sources:
whois: true
arp: true
rdns: true
dhcp: true
hosts: true
dhcp:
enabled: false
interface_name: vboxnet0
local_domain_name: local
dhcpv4:
gateway_ip: 192.168.0.1
subnet_mask: 255.255.255.0
range_start: 192.168.0.10
range_end: 192.168.0.250
lease_duration: 1234
icmp_timeout_msec: 10
schema_version: 27
user_rules: []
querylog:
enabled: true
file_enabled: true
interval: 720h
size_memory: 1000
ignored:
- '|.^'
statistics:
enabled: true
interval: 240h
ignored:
- '|.^'
os:
group: ''
rlimit_nofile: 123
user: ''
log:
file: ""
max_backups: 0
max_size: 100
max_age: 3
compress: true
local_time: false
verbose: true

View File

@ -0,0 +1,124 @@
http:
address: 127.0.0.1:3000
session_ttl: 3h
pprof:
enabled: true
port: 6060
users:
- name: testuser
password: testpassword
dns:
bind_hosts:
- 127.0.0.1
port: 53
parental_sensitivity: 0
upstream_dns:
- tls://1.1.1.1
- tls://1.0.0.1
- quic://8.8.8.8:784
bootstrap_dns:
- 8.8.8.8:53
upstream_mode: parallel
edns_client_subnet:
enabled: true
use_custom: false
custom_ip: ""
filtering:
filtering_enabled: true
parental_enabled: false
safebrowsing_enabled: false
safe_search:
enabled: false
bing: true
duckduckgo: true
google: true
pixabay: true
yandex: true
youtube: true
protection_enabled: true
blocked_services:
schedule:
time_zone: Local
ids:
- 500px
blocked_response_ttl: 10
filters:
- url: https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt
name: ""
enabled: true
- url: https://adaway.org/hosts.txt
name: AdAway
enabled: false
- url: https://hosts-file.net/ad_servers.txt
name: hpHosts - Ad and Tracking servers only
enabled: false
- url: http://www.malwaredomainlist.com/hostslist/hosts.txt
name: MalwareDomainList.com Hosts List
enabled: false
clients:
persistent:
- name: localhost
ids:
- 127.0.0.1
- aa:aa:aa:aa:aa:aa
use_global_settings: true
use_global_blocked_services: true
filtering_enabled: false
parental_enabled: false
safebrowsing_enabled: false
safe_search:
enabled: true
bing: true
duckduckgo: true
google: true
pixabay: true
yandex: true
youtube: true
blocked_services:
schedule:
time_zone: Local
ids:
- 500px
runtime_sources:
whois: true
arp: true
rdns: true
dhcp: true
hosts: true
dhcp:
enabled: false
interface_name: vboxnet0
local_domain_name: local
dhcpv4:
gateway_ip: 192.168.0.1
subnet_mask: 255.255.255.0
range_start: 192.168.0.10
range_end: 192.168.0.250
lease_duration: 1234
icmp_timeout_msec: 10
schema_version: 28
user_rules: []
querylog:
enabled: true
file_enabled: true
interval: 720h
size_memory: 1000
ignored:
- '|.^'
statistics:
enabled: true
interval: 240h
ignored:
- '|.^'
os:
group: ''
rlimit_nofile: 123
user: ''
log:
file: ""
max_backups: 0
max_size: 100
max_age: 3
compress: true
local_time: false
verbose: true

View File

@ -1,11 +1,12 @@
package configmigrate
import (
"context"
"os"
"path/filepath"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// migrateTo1 performs the following changes:
@ -19,14 +20,14 @@ import (
//
// It also deletes the unused dnsfilter.txt file, since the following versions
// store filters in data/filters/.
func (m *Migrator) migrateTo1(diskConf yobj) (err error) {
func (m *Migrator) migrateTo1(ctx context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 1
dnsFilterPath := filepath.Join(m.workingDir, "dnsfilter.txt")
log.Printf("deleting %s as we don't need it anymore", dnsFilterPath)
m.logger.InfoContext(ctx, "deleting file as we do not need it anymore", "path", dnsFilterPath)
err = os.Remove(dnsFilterPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Info("warning: %s", err)
m.logger.InfoContext(ctx, "failed to delete", slogutil.KeyError, err)
// Go on.
}

View File

@ -1,6 +1,7 @@
package configmigrate
import (
"context"
"fmt"
"net/url"
"strconv"
@ -30,7 +31,7 @@ import (
// - 'quic://some-upstream.com:784'
// # …
// # …
func migrateTo10(diskConf yobj) (err error) {
func (m *Migrator) migrateTo10(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 10
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo11 performs the following changes:
//
// # BEFORE:
@ -14,7 +16,7 @@ package configmigrate
// 'rlimit_nofile': 42
// 'user': ''
// # …
func migrateTo11(diskConf yobj) (err error) {
func (m *Migrator) migrateTo11(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 11
rlimit, _, err := fieldVal[int](diskConf, "rlimit_nofile")

View File

@ -1,6 +1,7 @@
package configmigrate
import (
"context"
"time"
"github.com/AdguardTeam/golibs/timeutil"
@ -17,7 +18,7 @@ import (
// 'schema_version': 12
// 'querylog_interval': '2160h'
// # …
func migrateTo12(diskConf yobj) (err error) {
func (m *Migrator) migrateTo12(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 12
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo13 performs the following changes:
//
// # BEFORE:
@ -15,7 +17,7 @@ package configmigrate
// 'local_domain_name': 'lan'
// # …
// # …
func migrateTo13(diskConf yobj) (err error) {
func (m *Migrator) migrateTo13(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 13
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo14 performs the following changes:
//
// # BEFORE:
@ -27,7 +29,7 @@ package configmigrate
// 'dhcp': true
// 'hosts': true
// # …
func migrateTo14(diskConf yobj) (err error) {
func (m *Migrator) migrateTo14(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 14
persistent, ok, err := fieldVal[yarr](diskConf, "clients")

View File

@ -1,6 +1,10 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
import (
"context"
"github.com/AdguardTeam/golibs/errors"
)
// migrateTo15 performs the following changes:
//
@ -28,7 +32,7 @@ import "github.com/AdguardTeam/golibs/errors"
// 'ignored': []
// # …
// # …
func migrateTo15(diskConf yobj) (err error) {
func (m *Migrator) migrateTo15(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 15
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo16 performs the following changes:
//
// # BEFORE:
@ -43,7 +45,7 @@ package configmigrate
// 'ignored': []
// # …
// # …
func migrateTo16(diskConf yobj) (err error) {
func (m *Migrator) migrateTo16(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 16
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo17 performs the following changes:
//
// # BEFORE:
@ -18,7 +20,7 @@ package configmigrate
// 'custom_ip': ""
// # …
// # …
func migrateTo17(diskConf yobj) (err error) {
func (m *Migrator) migrateTo17(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 17
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo18 performs the following changes:
//
// # BEFORE:
@ -22,7 +24,7 @@ package configmigrate
// 'youtube': true
// # …
// # …
func migrateTo18(diskConf yobj) (err error) {
func (m *Migrator) migrateTo18(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 18
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,6 +1,10 @@
package configmigrate
import "github.com/AdguardTeam/golibs/log"
import (
"context"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// migrateTo19 performs the following changes:
//
@ -30,7 +34,7 @@ import "github.com/AdguardTeam/golibs/log"
// # …
// # …
// # …
func migrateTo19(diskConf yobj) (err error) {
func (m *Migrator) migrateTo19(ctx context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 19
clients, ok, err := fieldVal[yobj](diskConf, "clients")
@ -62,7 +66,7 @@ func migrateTo19(diskConf yobj) (err error) {
err = moveVal[bool](c, safeSearch, "safesearch_enabled", "enabled")
if err != nil {
log.Debug("migrating to version 19: %s", err)
m.logger.DebugContext(ctx, "migrating to", "version", 19, slogutil.KeyError, err)
}
c["safe_search"] = safeSearch

View File

@ -1,11 +1,12 @@
package configmigrate
import (
"context"
"os"
"path/filepath"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// migrateTo2 performs the following changes:
@ -21,14 +22,14 @@ import (
// # …
//
// It also deletes the Corefile file, since it isn't used anymore.
func (m *Migrator) migrateTo2(diskConf yobj) (err error) {
func (m *Migrator) migrateTo2(ctx context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 2
coreFilePath := filepath.Join(m.workingDir, "Corefile")
log.Printf("deleting %s as we don't need it anymore", coreFilePath)
m.logger.InfoContext(ctx, "deleting file as we do not need it anymore", "path", coreFilePath)
err = os.Remove(coreFilePath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Info("warning: %s", err)
m.logger.WarnContext(ctx, "failed to delete", slogutil.KeyError, err)
// Go on.
}

View File

@ -1,6 +1,7 @@
package configmigrate
import (
"context"
"time"
"github.com/AdguardTeam/golibs/timeutil"
@ -21,7 +22,7 @@ import (
// 'interval': 24h
// # …
// # …
func migrateTo20(diskConf yobj) (err error) {
func (m *Migrator) migrateTo20(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 20
stats, ok, err := fieldVal[yobj](diskConf, "statistics")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo21 performs the following changes:
//
// # BEFORE:
@ -22,7 +24,7 @@ package configmigrate
// 'time_zone': 'Local'
// # …
// # …
func migrateTo21(diskConf yobj) (err error) {
func (m *Migrator) migrateTo21(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 21
const field = "blocked_services"

View File

@ -1,6 +1,7 @@
package configmigrate
import (
"context"
"fmt"
)
@ -30,7 +31,7 @@ import (
// # …
// # …
// # …
func migrateTo22(diskConf yobj) (err error) {
func (m *Migrator) migrateTo22(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 22
const field = "blocked_services"

View File

@ -1,6 +1,7 @@
package configmigrate
import (
"context"
"fmt"
"net/netip"
"time"
@ -23,7 +24,7 @@ import (
// 'address': '1.2.3.4:8080'
// 'session_ttl': '720h'
// # …
func migrateTo23(diskConf yobj) (err error) {
func (m *Migrator) migrateTo23(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 23
bindHost, ok, err := fieldVal[string](diskConf, "bind_host")

View File

@ -1,6 +1,10 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
import (
"context"
"github.com/AdguardTeam/golibs/errors"
)
// migrateTo24 performs the following changes:
//
@ -26,7 +30,7 @@ import "github.com/AdguardTeam/golibs/errors"
// 'local_time': false
// 'verbose': false
// # …
func migrateTo24(diskConf yobj) (err error) {
func (m *Migrator) migrateTo24(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 24
logObj := yobj{}

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo25 performs the following changes:
//
// # BEFORE:
@ -14,7 +16,7 @@ package configmigrate
// 'enabled': true
// 'port': 6060
// # …
func migrateTo25(diskConf yobj) (err error) {
func (m *Migrator) migrateTo25(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 25
httpObj, ok, err := fieldVal[yobj](diskConf, "http")

View File

@ -1,6 +1,10 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
import (
"context"
"github.com/AdguardTeam/golibs/errors"
)
// migrateTo26 performs the following changes:
//
@ -71,7 +75,7 @@ import "github.com/AdguardTeam/golibs/errors"
// 'dns'
// # …
// # …
func migrateTo26(diskConf yobj) (err error) {
func (m *Migrator) migrateTo26(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 26
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo27 performs the following changes:
//
// # BEFORE:
@ -27,7 +29,7 @@ package configmigrate
// - # …
// # …
// # …
func migrateTo27(diskConf yobj) (err error) {
func (m *Migrator) migrateTo27(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 27
keys := []string{"querylog", "statistics"}

View File

@ -1,6 +1,8 @@
package configmigrate
import (
"context"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
)
@ -18,7 +20,7 @@ import (
// 'upstream_mode': 'parallel'
// # …
// # …
func migrateTo28(diskConf yobj) (err error) {
func (m *Migrator) migrateTo28(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 28
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,6 +1,7 @@
package configmigrate
import (
"context"
"fmt"
"path/filepath"
)
@ -26,7 +27,7 @@ import (
// - '/opt/AdGuardHome/data/userfilters/*'
// - '/path/to/file.txt'
// # …
func (m Migrator) migrateTo29(diskConf yobj) (err error) {
func (m Migrator) migrateTo29(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 29
filterVals, ok, err := fieldVal[[]any](diskConf, "filters")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo3 performs the following changes:
//
// # BEFORE:
@ -14,7 +16,7 @@ package configmigrate
// 'bootstrap_dns':
// - '1.1.1.1'
// # …
func migrateTo3(diskConf yobj) (err error) {
func (m *Migrator) migrateTo3(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 3
dnsConfig, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo30 performs the following changes:
//
// # BEFORE:
@ -14,7 +16,7 @@ package configmigrate
// # …
//
// If cache_size is zero, then cache_enabled should be false.
func (m Migrator) migrateTo30(diskConf yobj) (err error) {
func (m Migrator) migrateTo30(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 30
dnsConf, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo31 performs the following changes:
//
// # BEFORE:
@ -18,7 +20,7 @@ package configmigrate
// 'enabled': true
// # …
// # …
func (m Migrator) migrateTo31(diskConf yobj) (err error) {
func (m *Migrator) migrateTo31(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 31
fltConf, ok, err := fieldVal[yobj](diskConf, "filtering")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo4 performs the following changes:
//
// # BEFORE:
@ -14,7 +16,7 @@ package configmigrate
// - 'use_global_blocked_services': true
// # …
// # …
func migrateTo4(diskConf yobj) (err error) {
func (m *Migrator) migrateTo4(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 4
clients, ok, _ := fieldVal[yarr](diskConf, "clients")

View File

@ -1,6 +1,7 @@
package configmigrate
import (
"context"
"fmt"
"golang.org/x/crypto/bcrypt"
@ -20,7 +21,7 @@ import (
// - 'name': …
// 'password': <hashed>
// # …
func migrateTo5(diskConf yobj) (err error) {
func (m *Migrator) migrateTo5(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 5
user := yobj{}

View File

@ -1,6 +1,9 @@
package configmigrate
import "fmt"
import (
"context"
"fmt"
)
// migrateTo6 performs the following changes:
//
@ -24,7 +27,7 @@ import "fmt"
// - 'AA:AA:AA:AA:AA:AA'
// # …
// # …
func migrateTo6(diskConf yobj) (err error) {
func (m *Migrator) migrateTo6(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 6
clients, ok, err := fieldVal[yarr](diskConf, "clients")

View File

@ -1,6 +1,10 @@
package configmigrate
import "github.com/AdguardTeam/golibs/errors"
import (
"context"
"github.com/AdguardTeam/golibs/errors"
)
// migrateTo7 performs the following changes:
//
@ -30,7 +34,7 @@ import "github.com/AdguardTeam/golibs/errors"
// 'lease_duration': 86400
// 'icmp_timeout_msec': 1000
// # …
func migrateTo7(diskConf yobj) (err error) {
func (m *Migrator) migrateTo7(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 7
dhcp, ok, _ := fieldVal[yobj](diskConf, "dhcp")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo8 performs the following changes:
//
// # BEFORE:
@ -16,7 +18,7 @@ package configmigrate
// - '127.0.0.1'
// # …
// # …
func migrateTo8(diskConf yobj) (err error) {
func (m *Migrator) migrateTo8(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 8
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,5 +1,7 @@
package configmigrate
import "context"
// migrateTo9 performs the following changes:
//
// # BEFORE:
@ -15,7 +17,7 @@ package configmigrate
// 'local_domain_name': 'lan'
// # …
// # …
func migrateTo9(diskConf yobj) (err error) {
func (m *Migrator) migrateTo9(_ context.Context, diskConf yobj) (err error) {
diskConf["schema_version"] = 9
dns, ok, err := fieldVal[yobj](diskConf, "dns")

View File

@ -1,7 +1,9 @@
package dhcpd
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"time"
@ -17,6 +19,10 @@ import (
// ServerConfig is the configuration for the DHCP server. The order of YAML
// fields is important, since the YAML configuration file follows it.
type ServerConfig struct {
// Logger is used for logging the operation of the DHCP server. It must not
// be nil.
Logger *slog.Logger `yaml:"-"`
// CommandConstructor is used to run external commands. It must not be nil.
CommandConstructor executil.CommandConstructor `yaml:"-"`
@ -85,7 +91,7 @@ type DHCPServer interface {
WriteDiskConfig6(c *V6ServerConf)
// Start - start server
Start() (err error)
Start(ctx context.Context) (err error)
// Stop - stop server
Stop() (err error)
getLeasesRef() []*dhcpsvc.Lease
@ -93,6 +99,10 @@ type DHCPServer interface {
// V4ServerConf - server configuration
type V4ServerConf struct {
// Logger is used for logging the operation of the DHCPv4 server. It must
// not be nil.
Logger *slog.Logger `yaml:"-" json:"-"`
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
@ -232,6 +242,10 @@ func (c *V4ServerConf) Validate() (err error) {
// V6ServerConf - server configuration
type V6ServerConf struct {
// Logger is used for logging the operation of the DHCPv6 server. It must
// not be nil.
Logger *slog.Logger `yaml:"-" json:"-"`
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`

View File

@ -2,6 +2,7 @@
package dhcpd
import (
"context"
"fmt"
"net"
"net/netip"
@ -9,7 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil"
)
@ -53,7 +54,7 @@ const (
// Interface is the DHCP server that deals with both IP address families.
type Interface interface {
Start() (err error)
Start(ctx context.Context) (err error)
Stop() (err error)
// Enabled returns true if the DHCP server is running.
@ -104,9 +105,10 @@ var _ Interface = (*server)(nil)
// Create initializes and returns the DHCP server handling both address
// families. It also registers the corresponding HTTP API endpoints.
func Create(conf *ServerConfig) (s *server, err error) {
func Create(ctx context.Context, conf *ServerConfig) (s *server, err error) {
s = &server{
conf: &ServerConfig{
Logger: conf.Logger,
CommandConstructor: conf.CommandConstructor,
ConfModifier: conf.ConfModifier,
@ -125,7 +127,7 @@ func Create(conf *ServerConfig) (s *server, err error) {
// [aghhttp.RegisterFunc].
s.registerHandlers()
v4Enabled, v6Enabled, err := s.setServers(conf)
v4Enabled, v6Enabled, err := s.setServers(ctx, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
@ -158,8 +160,12 @@ func Create(conf *ServerConfig) (s *server, err error) {
// setServers updates DHCPv4 and DHCPv6 servers created from the provided
// configuration conf. It returns the status of both the DHCPv4 and the DHCPv6
// servers, which is always false for corresponding server on any error.
func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err error) {
func (s *server) setServers(
ctx context.Context,
conf *ServerConfig,
) (v4Enabled, v6Enabled bool, err error) {
v4conf := conf.Conf4
v4conf.Logger = s.conf.Logger.With("ip_version", "4")
v4conf.InterfaceName = s.conf.InterfaceName
v4conf.notify = s.onNotify
v4conf.Enabled = s.conf.Enabled && v4conf.RangeStart.IsValid()
@ -170,10 +176,11 @@ func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err
return false, false, fmt.Errorf("creating dhcpv4 srv: %w", err)
}
log.Debug("dhcpd: warning: creating dhcpv4 srv: %s", err)
s.conf.Logger.WarnContext(ctx, "creating dhcpv4 server", slogutil.KeyError, err)
}
v6conf := conf.Conf6
v6conf.Logger = s.conf.Logger.With("ip_version", "6")
v6conf.InterfaceName = s.conf.InterfaceName
v6conf.notify = s.onNotify
v6conf.Enabled = s.conf.Enabled && len(v6conf.RangeStart) != 0
@ -213,7 +220,9 @@ func (s *server) onNotify(flags uint32) {
if flags == LeaseChangedDBStore {
err := s.dbStore()
if err != nil {
log.Error("updating db: %s", err)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
s.conf.Logger.ErrorContext(ctx, "updating db", slogutil.KeyError, err)
}
return
@ -239,13 +248,13 @@ func (s *server) WriteDiskConfig(c *ServerConfig) {
}
// Start will listen on port 67 and serve DHCP requests.
func (s *server) Start() (err error) {
err = s.srv4.Start()
func (s *server) Start(ctx context.Context) (err error) {
err = s.srv4.Start(ctx)
if err != nil {
return err
}
err = s.srv6.Start()
err = s.srv6.Start(ctx)
if err != nil {
return err
}

View File

@ -0,0 +1,13 @@
package dhcpd
import (
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()

View File

@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/netip"
@ -205,7 +206,7 @@ func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, er
}
if !hasStaticIP {
err = aghnet.IfaceSetStaticIP(ctx, cmdCons, ifaceName)
err = aghnet.IfaceSetStaticIP(ctx, s.conf.Logger, cmdCons, ifaceName)
if err != nil {
err = fmt.Errorf("setting static ip: %w", err)
@ -213,7 +214,7 @@ func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, er
}
}
err = s.Start()
err = s.Start(ctx)
if err != nil {
return http.StatusBadRequest, fmt.Errorf("starting dhcp server: %w", err)
}
@ -313,6 +314,7 @@ func (s *server) createServers(conf *dhcpServerConfigJSON) (srv4, srv6 DHCPServe
// HTTP API.
func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.conf.Logger
conf := &dhcpServerConfigJSON{}
conf.Enabled = aghalg.BoolToNullBool(s.conf.Enabled)
@ -320,21 +322,29 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(conf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse new dhcp config json: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"failed to parse new dhcp config json: %s",
err,
)
return
}
srv4, srv6, err := s.createServers(conf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
err = s.Stop()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
return
}
@ -344,7 +354,15 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
err = s.dbLoad()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "loading leases db: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusInternalServerError,
"loading leases db: %s",
err,
)
return
}
@ -353,7 +371,7 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
var code int
code, err = s.enableDHCP(ctx, conf.InterfaceName)
if err != nil {
aghhttp.Error(r, w, code, "enabling dhcp: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, code, "enabling dhcp: %s", err)
}
}
}
@ -390,11 +408,22 @@ type netInterfaceJSON struct {
// handleDHCPInterfaces is the handler for the GET /control/dhcp/interfaces
// HTTP API.
func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.conf.Logger
resp := map[string]*netInterfaceJSON{}
ifaces, err := net.Interfaces()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusInternalServerError,
"Couldn't get interfaces: %s",
err,
)
return
}
@ -410,9 +439,9 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
continue
}
jsonIface, iErr := newNetInterfaceJSON(r.Context(), iface, s.conf.CommandConstructor)
jsonIface, iErr := newNetInterfaceJSON(ctx, s.conf.Logger, iface, s.conf.CommandConstructor)
if iErr != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", iErr)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", iErr)
return
}
@ -426,9 +455,10 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
}
// newNetInterfaceJSON creates a JSON object from a [net.Interface] iface.
// cmdCons must not be nil.
// l and cmdCons must not be nil.
func newNetInterfaceJSON(
ctx context.Context,
l *slog.Logger,
iface net.Interface,
cmdCons executil.CommandConstructor,
) (out *netInterfaceJSON, err error) {
@ -483,7 +513,7 @@ func newNetInterfaceJSON(
return nil, nil
}
out.GatewayIP = aghnet.GatewayIP(ctx, cmdCons, iface.Name)
out.GatewayIP = aghnet.GatewayIP(ctx, l, cmdCons, iface.Name)
return out, nil
}
@ -533,21 +563,24 @@ type findActiveServerReq struct {
// 2. check if a static IP is configured for the network interface;
// 3. responds with the results.
func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
if aghhttp.WriteTextPlainDeprecated(w, r) {
ctx := r.Context()
l := s.conf.Logger
if aghhttp.WriteTextPlainDeprecated(ctx, l, w, r) {
return
}
req := &findActiveServerReq{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
ifaceName := req.Interface
if ifaceName == "" {
aghhttp.Error(r, w, http.StatusBadRequest, "empty interface name")
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "empty interface name")
return
}
@ -569,24 +602,28 @@ func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
}
cmdCons := s.conf.CommandConstructor
if isStaticIP, serr := aghnet.IfaceHasStaticIP(r.Context(), cmdCons, ifaceName); serr != nil {
if isStaticIP, serr := aghnet.IfaceHasStaticIP(ctx, cmdCons, ifaceName); serr != nil {
result.V4.StaticIP.Static = "error"
result.V4.StaticIP.Error = serr.Error()
} else if !isStaticIP {
result.V4.StaticIP.Static = "no"
// TODO(e.burkov): The returned IP should only be of version 4.
result.V4.StaticIP.IP = aghnet.GetSubnet(ifaceName).String()
result.V4.StaticIP.IP = aghnet.GetSubnet(ctx, s.conf.Logger, ifaceName).String()
}
setOtherDHCPResult(ifaceName, result)
s.setOtherDHCPResult(ctx, ifaceName, result)
aghhttp.WriteJSONResponseOK(w, r, result)
}
// setOtherDHCPResult sets the results of the check for another DHCP server in
// result.
func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ifaceName)
// result. result must not be nil.
func (s *server) setOtherDHCPResult(
ctx context.Context,
ifaceName string,
result *dhcpSearchResult,
) {
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ctx, s.conf.Logger, ifaceName)
if err4 != nil {
result.V4.OtherServer.Found = "error"
result.V4.OtherServer.Error = err4.Error()
@ -634,52 +671,64 @@ func (s *server) parseLease(r io.Reader) (srv DHCPServer, lease *dhcpsvc.Lease,
// handleDHCPAddStaticLease is the handler for the POST
// /control/dhcp/add_static_lease HTTP API.
func (s *server) handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.conf.Logger
srv, lease, err := s.parseLease(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
if err = srv.AddStaticLease(lease); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
}
}
// handleDHCPRemoveStaticLease is the handler for the POST
// /control/dhcp/remove_static_lease HTTP API.
func (s *server) handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.conf.Logger
srv, lease, err := s.parseLease(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
if err = srv.RemoveStaticLease(lease); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
}
}
// handleDHCPUpdateStaticLease is the handler for the POST
// /control/dhcp/update_static_lease HTTP API.
func (s *server) handleDHCPUpdateStaticLease(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.conf.Logger
srv, lease, err := s.parseLease(r.Body)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
if err = srv.UpdateStaticLease(lease); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
}
}
func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.conf.Logger
err := s.Stop()
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "stopping dhcp: %s", err)
return
}
@ -713,14 +762,16 @@ func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
}
s.srv6, _ = v6Create(v6conf)
s.conf.ConfModifier.Apply(r.Context())
s.conf.ConfModifier.Apply(ctx)
}
func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
err := s.resetLeases()
if err != nil {
msg := "resetting leases: %s"
aghhttp.Error(r, w, http.StatusInternalServerError, msg, err)
aghhttp.ErrorAndLog(ctx, s.conf.Logger, r, w, http.StatusInternalServerError, msg, err)
return
}

View File

@ -11,6 +11,7 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -84,7 +85,9 @@ func TestServer_handleDHCPStatus(t *testing.T) {
Hostname: staticName,
}
s, err := Create(&ServerConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
s, err := Create(ctx, &ServerConfig{
Logger: testLogger,
Enabled: true,
Conf4: *defaultV4ServerConf(),
DataDir: t.TempDir(),
@ -178,7 +181,9 @@ func TestServer_HandleUpdateStaticLease(t *testing.T) {
},
}
s, err := Create(&ServerConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
s, err := Create(ctx, &ServerConfig{
Logger: testLogger,
Enabled: true,
Conf4: *defaultV4ServerConf(),
Conf6: V6ServerConf{},
@ -266,7 +271,9 @@ func TestServer_HandleUpdateStaticLease_validation(t *testing.T) {
Hostname: anotherV4Name,
}}
s, err := Create(&ServerConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
s, err := Create(ctx, &ServerConfig{
Logger: testLogger,
Enabled: true,
Conf4: *defaultV4ServerConf(),
Conf6: V6ServerConf{},

View File

@ -5,6 +5,7 @@ package dhcpd
// 'u-root/u-root' package, a dependency of 'insomniacslk/dhcp' package, doesn't build on Windows
import (
"context"
"net"
"net/netip"
@ -25,7 +26,7 @@ func (winServer) UpdateStaticLease(_ *dhcpsvc.Lease) (err error) { return
func (winServer) FindMACbyIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}
func (winServer) Start() (err error) { return nil }
func (winServer) Start(_ context.Context) (err error) { return nil }
func (winServer) Stop() (err error) { return nil }
func (winServer) HostByIP(_ netip.Addr) (host string) { return "" }
func (winServer) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} }

View File

@ -4,6 +4,7 @@ package dhcpd
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
@ -1299,7 +1300,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
}
// Start starts the IPv4 DHCP server.
func (s *v4Server) Start() (err error) {
func (s *v4Server) Start(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if !s.enabled() {
@ -1315,6 +1316,8 @@ func (s *v4Server) Start() (err error) {
log.Debug("dhcpv4: starting...")
dnsIPAddrs, err := aghnet.IfaceDNSIPAddrs(
ctx,
s.conf.Logger,
iface,
aghnet.IPVersion4,
defaultMaxAttempts,

View File

@ -4,6 +4,7 @@ package dhcpd
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
@ -657,8 +658,13 @@ func (s *v6Server) packetHandler(conn net.PacketConn, peer net.Addr, req dhcpv6.
// configureDNSIPAddrs updates v6Server configuration with the slice of DNS IP
// addresses of provided interface iface. Initializes RA module.
func (s *v6Server) configureDNSIPAddrs(iface *net.Interface) (ok bool, err error) {
func (s *v6Server) configureDNSIPAddrs(
ctx context.Context,
iface *net.Interface,
) (ok bool, err error) {
dnsIPAddrs, err := aghnet.IfaceDNSIPAddrs(
ctx,
s.conf.Logger,
iface,
aghnet.IPVersion6,
defaultMaxAttempts,
@ -700,7 +706,7 @@ func (s *v6Server) initRA(iface *net.Interface) (err error) {
}
// Start starts the IPv6 DHCP server.
func (s *v6Server) Start() (err error) {
func (s *v6Server) Start(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if !s.conf.Enabled {
@ -715,7 +721,7 @@ func (s *v6Server) Start() (err error) {
log.Debug("dhcpv6: starting...")
ok, err := s.configureDNSIPAddrs(iface)
ok, err := s.configureDNSIPAddrs(ctx, iface)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err

View File

@ -230,18 +230,19 @@ func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error
// handleAccessSet handles requests to the POST /control/access/set endpoint.
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.logger
list := &accessListJSON{}
err := json.NewDecoder(r.Body).Decode(&list)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "decoding request: %s", err)
return
}
err = validateAccessSet(list)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
@ -249,12 +250,12 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
var a *accessManager
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "creating access ctx: %s", err)
return
}
defer s.logger.DebugContext(
defer l.DebugContext(
ctx,
"updated access lists",
"allowed", len(list.AllowedClients),

View File

@ -209,11 +209,19 @@ func createServerTLSConfig(tb testing.TB) (*tls.Config, []byte, []byte) {
}
template.DNSNames = append(template.DNSNames, tlsServerName)
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey)
derBytes, err := x509.CreateCertificate(
rand.Reader,
&template,
&template,
publicKey(privateKey),
privateKey,
)
require.NoErrorf(tb, err, "failed to create certificate: %s", err)
certPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
keyPem := pem.EncodeToMemory(
&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)},
)
cert, err := tls.X509KeyPair(certPem, keyPem)
require.NoErrorf(tb, err, "failed to create certificate: %s", err)
@ -1065,10 +1073,22 @@ func TestNullBlockedRequest(t *testing.T) {
reply, err := dns.Exchange(&req, addr.String())
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
assert.Truef(t, a.A.IsUnspecified(), "dns server %s returned wrong answer instead of 0.0.0.0: %v", addr, a.A)
require.Lenf(
t,
reply.Answer,
1,
"dns server %s returned reply with wrong number of answers - %d",
addr,
len(reply.Answer),
)
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
assert.Truef(
t,
a.A.IsUnspecified(),
"dns server %s returned wrong answer instead of 0.0.0.0: %v",
addr,
a.A,
)
}
func TestBlockedCustomIP(t *testing.T) {
@ -1184,10 +1204,30 @@ func TestBlockedByHosts(t *testing.T) {
reply, err := dns.Exchange(req, addr.String())
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
require.Lenf(
t,
reply.Answer,
1,
"dns server %s returned reply with wrong number of answers - %d",
addr,
len(reply.Answer),
)
a, ok := reply.Answer[0].(*dns.A)
require.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0])
assert.Equalf(t, net.IP{127, 0, 0, 1}, a.A, "dns server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A)
require.Truef(
t,
ok,
"dns server %s returned wrong answer type instead of A: %v",
addr,
reply.Answer[0],
)
assert.Equalf(
t,
net.IP{127, 0, 0, 1},
a.A,
"dns server %s returned wrong answer instead of 8.8.8.8: %v",
addr,
a.A,
)
}
func TestBlockedBySafeBrowsing(t *testing.T) {
@ -1235,7 +1275,14 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
reply, err := dns.Exchange(req, addr.String())
require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err)
require.Lenf(t, reply.Answer, 1, "dns server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer))
require.Lenf(
t,
reply.Answer,
1,
"dns server %s returned reply with wrong number of answers - %d",
addr,
len(reply.Answer),
)
assertResponse(t, reply, ans4)
}
@ -1469,7 +1516,8 @@ func TestPTRResponseFromHosts(t *testing.T) {
}
var eventsCalledCounter uint32
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
@ -1646,7 +1694,9 @@ func TestServer_Exchange(t *testing.T) {
extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
doubleTTL(
aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost)),
),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)

View File

@ -520,11 +520,12 @@ func checkInclusion(ptr *int, minN, maxN int) (err error) {
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.logger
req := &jsonDNSConfig{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "decoding request: %s", err)
return
}
@ -533,14 +534,22 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
ourAddrs, err := s.conf.ourAddrsSet()
if err != nil {
// TODO(e.burkov): Put into openapi.
aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusInternalServerError,
"getting our addresses: %s",
err,
)
return
}
err = req.validate(ourAddrs, s.sysResolvers, s.privateNets, s.conf.CacheSize)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
@ -551,7 +560,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
if restart {
err = s.Reconfigure(ctx, nil)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", err)
}
}
}
@ -682,10 +691,21 @@ func closeBoots(boots []*upstream.UpstreamResolver) {
// handleTestUpstreamDNS handles requests to the POST /control/test_upstream_dns
// endpoint.
func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.logger
req := &upstreamJSON{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to read request body: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"Failed to read request body: %s",
err,
)
return
}
@ -701,7 +721,15 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
var boots []*upstream.UpstreamResolver
opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"Failed to parse bootstrap servers: %s",
err,
)
return
}
@ -730,10 +758,13 @@ type protectionJSON struct {
// handleSetProtection is a handler for the POST /control/protection HTTP API.
func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.logger
protectionReq := &protectionJSON{}
err := json.NewDecoder(r.Body).Decode(protectionReq)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
@ -741,7 +772,9 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
var disabledUntil *time.Time
if protectionReq.Duration > 0 {
if protectionReq.Enabled {
aghhttp.Error(
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
@ -762,9 +795,9 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
s.dnsFilter.SetProtectionStatus(protectionReq.Enabled, disabledUntil)
}()
s.conf.ConfModifier.Apply(r.Context())
s.conf.ConfModifier.Apply(ctx)
aghhttp.OK(w)
aghhttp.OK(ctx, s.logger, w)
}
// handleDoH is the DNS-over-HTTPs handler.
@ -778,14 +811,24 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
// -> proxy.handleDNSRequest
// -> dnsforward.handleDNSRequest
func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := s.logger
if !s.conf.TLSAllowUnencryptedDoH && r.TLS == nil {
aghhttp.Error(r, w, http.StatusNotFound, "Not Found")
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusNotFound, "Not Found")
return
}
if !s.IsRunning() {
aghhttp.Error(r, w, http.StatusInternalServerError, "dns server is not running")
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusInternalServerError,
"dns server is not running",
)
return
}

View File

@ -292,7 +292,11 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
s.dnsFilter.SetBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{})
s.dnsFilter.SetBlockingMode(
filtering.BlockingModeDefault,
netip.Addr{},
netip.Addr{},
)
s.conf = defaultConf
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{}
s.dnsFilter.SetBlockedResponseTTL(testBlockedRespTTL)
@ -362,7 +366,10 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
}).String()
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(
ctx,
testLogger,
fstest.MapFS{
hostsFileName: &fstest.MapFile{
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),

View File

@ -166,7 +166,7 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
list := []string{}
err := json.NewDecoder(r.Body).Decode(&list)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.ErrorAndLog(ctx, d.logger, r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
@ -200,18 +200,19 @@ func (d *DNSFilter) handleBlockedServicesGet(w http.ResponseWriter, r *http.Requ
// /control/blocked_services/update HTTP API.
func (d *DNSFilter) handleBlockedServicesUpdate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
bsvc := &BlockedServices{}
err := json.NewDecoder(r.Body).Decode(bsvc)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
err = bsvc.Validate()
if err != nil {
aghhttp.Error(r, w, http.StatusUnprocessableEntity, "validating: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusUnprocessableEntity, "validating: %s", err)
return
}
@ -227,7 +228,7 @@ func (d *DNSFilter) handleBlockedServicesUpdate(w http.ResponseWriter, r *http.R
d.conf.BlockedServices = bsvc
}()
d.logger.DebugContext(ctx, "updated blocked services schedule", "len", len(bsvc.IDs))
l.DebugContext(ctx, "updated blocked services schedule", "len", len(bsvc.IDs))
d.conf.ConfModifier.Apply(ctx)
}

View File

@ -6,7 +6,6 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@ -53,7 +52,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
`
conf := &filtering.Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
SafeBrowsingCacheSize: 10000,
ParentalCacheSize: 10000,
SafeSearchCacheSize: 1000,

View File

@ -8,18 +8,13 @@ import (
"os"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTimeout is the common timeout for tests.
const testTimeout = 5 * time.Second
// serveHTTPLocally starts a new HTTP server, that handles its index with h. It
// also gracefully closes the listener when the test under t finishes.
func serveHTTPLocally(tb testing.TB, h http.Handler) (urlStr string) {
@ -86,7 +81,7 @@ func newDNSFilter(tb testing.TB) (d *DNSFilter) {
tb.Helper()
dnsFilter, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
DataDir: tb.TempDir(),
HTTPClient: &http.Client{
Timeout: testTimeout,

View File

@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
@ -18,6 +19,9 @@ import (
"github.com/stretchr/testify/require"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
const (
sbBlocked = "wmconvirus.narod.ru"
pcBlocked = "pornhub.com"

View File

@ -0,0 +1,13 @@
package filtering_test
import (
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()

View File

@ -22,6 +22,9 @@ const (
cacheSize = 10000
)
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
func TestChcker_getQuestion(t *testing.T) {
const suf = "sb.dns.adguard.com."
@ -45,7 +48,7 @@ func TestChcker_getQuestion(t *testing.T) {
assert.False(t, slices.Contains(hashes, hash))
c := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
TXTSuffix: suf,
})
@ -100,7 +103,7 @@ func TestChecker_storeInCache(t *testing.T) {
const testTimeout = 1 * time.Second
c := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
CacheTime: cacheTime,
})
@ -158,7 +161,7 @@ func TestChecker_storeInCache(t *testing.T) {
assert.True(t, ok)
c = New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
CacheTime: cacheTime,
})
@ -195,7 +198,7 @@ func TestChecker_Check(t *testing.T) {
for _, tc := range testCases {
c := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
CacheTime: cacheTime,
CacheSize: cacheSize,
})

View File

@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
@ -49,12 +48,14 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
OnAdd: func(name string) (err error) { return nil },
OnShutdown: func(_ context.Context) (err error) { return nil },
}
hc, err := aghnet.NewHostsContainer(files, watcher, "hosts")
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, files, watcher, "hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
conf := &filtering.Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
EtcHosts: hc,
}
f, err := filtering.New(conf, nil)

View File

@ -63,17 +63,28 @@ type filterAddJSON struct {
}
func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
fj := filterAddJSON{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse request body json: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"Failed to parse request body json: %s",
err,
)
return
}
err = d.validateFilterURL(fj.URL)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
@ -81,7 +92,16 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Check for duplicates
if d.filterExists(fj.URL) {
err = errFilterExists
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", fj.URL, err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"Filter with URL %q: %s",
fj.URL,
err,
)
return
}
@ -100,7 +120,9 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// Download the filter contents
ok, err := d.update(&filt)
if err != nil {
aghhttp.Error(
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
@ -113,7 +135,9 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
}
if !ok {
aghhttp.Error(
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
@ -128,17 +152,34 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
// file and reload it to engines.
err = d.filterAdd(filt)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Filter with URL %q: %s", filt.URL, err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"Filter with URL %q: %s",
filt.URL,
err,
)
return
}
d.conf.ConfModifier.Apply(r.Context())
d.conf.ConfModifier.Apply(ctx)
d.EnableFilters(true)
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Couldn't write body: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusInternalServerError,
"Couldn't write body: %s",
err,
)
}
}
@ -153,7 +194,15 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
req := request{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to parse request body json: %s", err)
aghhttp.ErrorAndLog(
ctx,
d.logger,
r,
w,
http.StatusBadRequest,
"failed to parse request body json: %s",
err,
)
return
}
@ -213,7 +262,15 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
_, err = fmt.Fprintf(w, "OK %d rules\n", deleted.RulesCount)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "couldn't write body: %s", err)
aghhttp.ErrorAndLog(
ctx,
d.logger,
r,
w,
http.StatusInternalServerError,
"couldn't write body: %s",
err,
)
}
}
@ -230,23 +287,34 @@ type filterURLReq struct {
}
func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
fj := filterURLReq{}
err := json.NewDecoder(r.Body).Decode(&fj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "decoding request: %s", err)
return
}
if fj.Data == nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", errors.Error("data is absent"))
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"%s",
errors.Error("data is absent"),
)
return
}
err = d.validateFilterURL(fj.Data.URL)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "invalid url: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "invalid url: %s", err)
return
}
@ -259,12 +327,12 @@ func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
restart, err := d.filterSetProperties(fj.URL, filt, fj.Whitelist)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
d.conf.ConfModifier.Apply(r.Context())
d.conf.ConfModifier.Apply(ctx)
if restart {
d.EnableFilters(true)
}
@ -276,20 +344,22 @@ type filteringRulesReq struct {
}
func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Request) {
if aghhttp.WriteTextPlainDeprecated(w, r) {
ctx := r.Context()
if aghhttp.WriteTextPlainDeprecated(ctx, d.logger, w, r) {
return
}
req := &filteringRulesReq{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
aghhttp.ErrorAndLog(ctx, d.logger, r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
d.conf.UserRules = req.Rules
d.conf.ConfModifier.Apply(r.Context())
d.conf.ConfModifier.Apply(ctx)
d.EnableFilters(true)
}
@ -299,10 +369,12 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
}
var err error
ctx := r.Context()
req := Req{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.ErrorAndLog(ctx, d.logger, r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
@ -313,7 +385,9 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
}{}
resp.Updated, _, ok = d.tryRefreshFilters(!req.White, req.White, true)
if !ok {
aghhttp.Error(
aghhttp.ErrorAndLog(
ctx,
d.logger,
r,
w,
http.StatusInternalServerError,
@ -385,16 +459,19 @@ func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request
// Set filtering configuration
func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
req := filteringConfig{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !ValidateUpdateIvl(req.Interval) {
aghhttp.Error(r, w, http.StatusBadRequest, "Unsupported interval")
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "Unsupported interval")
return
}
@ -407,7 +484,7 @@ func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request
d.conf.FiltersUpdateIntervalHours = req.Interval
}()
d.conf.ConfModifier.Apply(r.Context())
d.conf.ConfModifier.Apply(ctx)
d.EnableFilters(true)
}
@ -443,10 +520,14 @@ type checkHostResp struct {
// handleCheckHost is the handler for the GET /control/filtering/check_host HTTP
// API.
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query()
host := query.Get("name")
if host == "" {
aghhttp.Error(
aghhttp.ErrorAndLog(
ctx,
d.logger,
r,
w,
http.StatusBadRequest,
@ -460,7 +541,9 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
qTypeStr := query.Get("qtype")
qType, err := stringToDNSType(qTypeStr)
if err != nil {
aghhttp.Error(
aghhttp.ErrorAndLog(
ctx,
d.logger,
r,
w,
http.StatusUnprocessableEntity,
@ -487,7 +570,9 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
result, err := d.CheckHost(host, qType, setts)
if err != nil {
aghhttp.Error(
aghhttp.ErrorAndLog(
ctx,
d.logger,
r,
w,
http.StatusInternalServerError,

View File

@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -110,7 +109,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
confModifiedCalled = true
}
d, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
FilteringEnabled: true,
Filters: tc.initial,
HTTPClient: &http.Client{
@ -195,7 +194,7 @@ func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) {
}
d, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ConfModifier: confModifier,
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
@ -282,7 +281,7 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) {
}
d, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ConfModifier: confModifier,
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
@ -384,7 +383,7 @@ func TestDNSFilter_HandleCheckHost(t *testing.T) {
}
dnsFilter, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
BlockedServices: &BlockedServices{
Schedule: schedule.EmptyWeekly(),
},

View File

@ -4,7 +4,6 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/stretchr/testify/assert"
)
@ -65,7 +64,7 @@ func TestIDGenerator_Fix(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
g := newIDGenerator(1, slogutil.NewDiscardLogger())
g := newIDGenerator(1, testLogger)
g.fix(tc.in)
assertUniqueIDs(t, tc.in)

View File

@ -16,6 +16,9 @@ import (
// testListID is the common rule-list ID for tests.
const testListID rules.ListID = 1
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
func TestNewDefaultStorage(t *testing.T) {
items := []*Item{{
Domain: "example.com",
@ -23,7 +26,7 @@ func TestNewDefaultStorage(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: testListID,
})
@ -36,7 +39,7 @@ func TestDefaultStorage_CRUD(t *testing.T) {
var items []*Item
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: testListID,
})
@ -125,7 +128,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: testListID,
})
@ -301,7 +304,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: testListID,
})
@ -373,7 +376,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: testListID,
})
@ -441,7 +444,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: testListID,
})

View File

@ -51,11 +51,12 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
rwJSON := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&rwJSON)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
@ -71,11 +72,11 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
Enabled: enabled,
}
err = rw.normalize(ctx, d.logger)
err = rw.normalize(ctx, l)
if err != nil {
// Shouldn't happen currently, since normalize only returns a non-nil
// error when a rewrite is nil, but be change-proof.
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "normalizing: %s", err)
return
}
@ -85,7 +86,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
defer d.confMu.Unlock()
d.conf.Rewrites = append(d.conf.Rewrites, rw)
d.logger.DebugContext(
l.DebugContext(
ctx,
"added rewrite element",
"domain", rw.Domain,
@ -101,11 +102,12 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
// API.
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
jsent := rewriteEntryJSON{}
err := json.NewDecoder(r.Body).Decode(&jsent)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
@ -128,7 +130,7 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
continue
}
d.logger.DebugContext(
l.DebugContext(
ctx,
"removed rewrite element",
"domain", ent.Domain,
@ -149,11 +151,12 @@ type rewriteUpdateJSON struct {
// API.
func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
updateJSON := rewriteUpdateJSON{}
err := json.NewDecoder(r.Body).Decode(&updateJSON)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
@ -168,11 +171,11 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request)
Answer: updateJSON.Update.Answer,
}
err = rwAdd.normalize(ctx, d.logger)
err = rwAdd.normalize(ctx, l)
if err != nil {
// Shouldn't happen currently, since normalize only returns a non-nil
// error when a rewrite is nil, but be change-proof.
aghhttp.Error(r, w, http.StatusBadRequest, "normalizing: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "normalizing: %s", err)
return
}
@ -189,7 +192,7 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request)
index = slices.IndexFunc(d.conf.Rewrites, rwDel.equal)
if index == -1 {
aghhttp.Error(r, w, http.StatusBadRequest, "target rule not found")
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "target rule not found")
return
}
@ -203,14 +206,14 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request)
d.conf.Rewrites = slices.Replace(d.conf.Rewrites, index, index+1, rwAdd)
d.logger.DebugContext(
l.DebugContext(
ctx,
"removed rewrite element",
"domain", rwDel.Domain,
"answer", rwDel.Answer,
"enabled", rwDel.Enabled,
)
d.logger.DebugContext(
l.DebugContext(
ctx,
"added rewrite element",
"domain", rwAdd.Domain,

View File

@ -9,12 +9,10 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -42,9 +40,6 @@ type rewriteUpdateJSON struct {
}
const (
// testTimeout is the common timeout for tests.
testTimeout = 100 * time.Millisecond
listURL = "/control/rewrite/list"
addURL = "/control/rewrite/add"
deleteURL = "/control/rewrite/delete"
@ -264,7 +259,7 @@ func TestDNSFilter_HandleRewriteHTTP(t *testing.T) {
}
d, err := filtering.New(&filtering.Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ConfModifier: confModifier,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
@ -365,7 +360,7 @@ func TestDNSFilter_HandleRewriteSettings(t *testing.T) {
handlers := make(map[string]http.Handler)
d, err := filtering.New(&filtering.Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ConfModifier: confModifier,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler

View File

@ -41,6 +41,9 @@ var testConf = filtering.SafeSearchConfig{
YouTube: true,
}
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
// yandexIP is the expected IP address of Yandex safe search results. Keep in
// sync with the rules data.
var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
@ -49,7 +52,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
conf := testConf
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
@ -111,7 +114,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
func TestDefault_CheckHost_google(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
@ -163,7 +166,7 @@ func (r *testResolver) LookupIP(
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
@ -186,7 +189,7 @@ func TestDefault_Update(t *testing.T) {
conf := testConf
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,

View File

@ -43,11 +43,12 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
// HTTP API.
func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
req := &SafeSearchConfig{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "reading req: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "reading req: %s", err)
return
}
@ -55,7 +56,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
conf := *req
err = d.safeSearch.Update(ctx, conf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "updating: %s", err)
return
}
@ -69,5 +70,5 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
d.conf.ConfModifier.Apply(ctx)
aghhttp.OK(w)
aghhttp.OK(ctx, l, w)
}

View File

@ -103,7 +103,7 @@ func (web *webAPI) handleLogin(w http.ResponseWriter, r *http.Request) {
req := loginJSON{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
@ -173,7 +173,7 @@ func (web *webAPI) handleLogin(w http.ResponseWriter, r *http.Request) {
h.Set(httphdr.Pragma, "no-cache")
h.Set(httphdr.Expires, "0")
aghhttp.OK(w)
aghhttp.OK(ctx, web.logger, w)
}
// newCookie creates a new authentication cookie. rateLimiter must not be nil.
@ -264,11 +264,11 @@ func (web *webAPI) handleLogout(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusFound)
}
// RegisterAuthHandlers - register handlers
func RegisterAuthHandlers(web *webAPI) {
globalContext.mux.Handle(
// registerAuthHandlers registers authentication handlers.
func (web *webAPI) registerAuthHandlers() {
web.conf.mux.Handle(
"/control/login",
postInstallHandler(ensureHandler(http.MethodPost, web.handleLogin)),
web.postInstallHandler(ensure(http.MethodPost, web.handleLogin)),
)
httpRegister(http.MethodGet, "/control/logout", web.handleLogout)
}

View File

@ -321,8 +321,6 @@ func authRequest(path string, c *http.Cookie, user, pass string) (r *http.Reques
func TestAuth_ServeHTTP_firstRun(t *testing.T) {
storeGlobals(t)
globalContext.firstRun = true
mux := http.NewServeMux()
globalContext.mux = mux
@ -335,8 +333,10 @@ func TestAuth_ServeHTTP_firstRun(t *testing.T) {
testLogger,
nil,
nil,
mux,
agh.EmptyConfigModifier{},
false,
true,
)
require.NoError(t, err)
@ -484,7 +484,8 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
globalContext.mux = http.NewServeMux()
baseMux := http.NewServeMux()
globalContext.mux = baseMux
tlsMgr, err := newTLSManager(testutil.ContextWithTimeout(t, testTimeout), &tlsManagerConfig{
logger: testLogger,
@ -501,17 +502,19 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
testLogger,
tlsMgr,
auth,
baseMux,
agh.EmptyConfigModifier{},
false,
false,
)
require.NoError(t, err)
globalContext.web = web
mux := auth.middleware().Wrap(globalContext.mux)
mux := auth.middleware().Wrap(baseMux)
auth.isGLiNet = true
gliNetMw := auth.middleware().Wrap(globalContext.mux)
gliNetMw := auth.middleware().Wrap(baseMux)
loginCookie := generateAuthCookie(t, mux, userName, userPassword)
@ -642,24 +645,28 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
globalContext.mux = http.NewServeMux()
baseMux := http.NewServeMux()
globalContext.mux = baseMux
ctx := testutil.ContextWithTimeout(t, testTimeout)
web, err := initWeb(ctx,
web, err := initWeb(
ctx,
options{},
nil,
nil,
testLogger,
nil,
auth,
baseMux,
agh.EmptyConfigModifier{},
false,
false,
)
require.NoError(t, err)
globalContext.web = web
mux := auth.middleware().Wrap(globalContext.mux)
mux := auth.middleware().Wrap(baseMux)
loginCookie := generateAuthCookie(t, mux, userName, userPassword)

View File

@ -327,25 +327,34 @@ func clientToJSON(c *client.Persistent) (cj *clientJSON) {
// handleAddClient is the handler for POST /control/clients/add HTTP API.
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := clients.logger
cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"failed to process request body: %s",
err,
)
return
}
c, err := clients.jsonToClient(ctx, cj, nil)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.storage.Add(ctx, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
@ -356,23 +365,32 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := clients.logger
cj := clientJSON{}
err := json.NewDecoder(r.Body).Decode(&cj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"failed to process request body: %s",
err,
)
return
}
if len(cj.Name) == 0 {
aghhttp.Error(r, w, http.StatusBadRequest, "client's name must be non-empty")
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "client's name must be non-empty")
return
}
if !clients.storage.RemoveByName(ctx, cj.Name) {
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "Client not found")
return
}
@ -391,31 +409,40 @@ type updateJSON struct {
// TODO(s.chzhen): Accept updated parameters instead of whole structure.
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := clients.logger
dj := updateJSON{}
err := json.NewDecoder(r.Body).Decode(&dj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
aghhttp.ErrorAndLog(
ctx,
l,
r,
w,
http.StatusBadRequest,
"failed to process request body: %s",
err,
)
return
}
if len(dj.Name) == 0 {
aghhttp.Error(r, w, http.StatusBadRequest, "Invalid request")
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "Invalid request")
return
}
c, err := clients.jsonToClient(ctx, dj.Data, nil)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.storage.Update(ctx, dj.Name, c)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
@ -502,10 +529,20 @@ type searchClientJSON struct {
// handleSearchClient is the handler for the POST /control/clients/search HTTP
// API.
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
q := searchQueryJSON{}
err := json.NewDecoder(r.Body).Decode(&q)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "failed to process request body: %s", err)
aghhttp.ErrorAndLog(
ctx,
clients.logger,
r,
w,
http.StatusBadRequest,
"failed to process request body: %s",
err,
)
return
}
@ -518,7 +555,7 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
err = params.Set(idStr)
if err != nil {
clients.logger.DebugContext(
r.Context(),
ctx,
"searching client",
"id", idStr,
slogutil.KeyError, err,

View File

@ -640,12 +640,14 @@ func parseConfig(ctx context.Context, l *slog.Logger) (err error) {
}
migrator := configmigrate.New(&configmigrate.Config{
Logger: l.With(slogutil.KeyPrefix, "config_migrator"),
WorkingDir: globalContext.workDir,
DataDir: globalContext.getDataDir(),
})
var upgraded bool
config.fileData, upgraded, err = migrator.Migrate(
ctx,
config.fileData,
configmigrate.LastSchemaVersion,
)

View File

@ -121,7 +121,7 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
if err != nil {
// Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses.
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err)
return
}
@ -171,9 +171,13 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
}
// registerControlHandlers sets up HTTP handlers for various control endpoints.
// web must not be nil.
func registerControlHandlers(web *webAPI) {
globalContext.mux.HandleFunc("/control/version.json", postInstall(web.handleVersionJSON))
func (web *webAPI) registerControlHandlers() {
mux := web.conf.mux
mux.Handle(
"/control/version.json",
web.postInstallHandler(http.HandlerFunc(web.handleVersionJSON)),
)
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
@ -182,23 +186,32 @@ func registerControlHandlers(web *webAPI) {
httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", web.handlePutProfile)
// No auth is necessary for DoH/DoT configurations
globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
globalContext.mux.HandleFunc("/apple/dot.mobileconfig", postInstall(handleMobileConfigDoT))
RegisterAuthHandlers(web)
// No authentication is required for DoH/DoT configuration endpoints.
mux.Handle(
"/apple/doh.mobileconfig",
web.postInstallHandler(http.HandlerFunc(handleMobileConfigDoH)),
)
mux.Handle(
"/apple/dot.mobileconfig",
web.postInstallHandler(http.HandlerFunc(handleMobileConfigDoT)),
)
web.registerAuthHandlers()
}
// httpRegister registers an HTTP handler.
//
// TODO(s.chzhen): Do not use [globalContext.mux].
func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
globalContext.mux.HandleFunc(url, postInstall(handler))
globalContext.mux.Handle(url, postInstallHandler(handler))
return
}
globalContext.mux.Handle(
url,
postInstallHandler(gziphandler.GzipHandler(ensureHandler(method, handler))),
postInstallHandler(gziphandler.GzipHandler(ensure(method, handler))),
)
}
@ -207,7 +220,7 @@ func httpRegister(method, url string, handler http.HandlerFunc) {
func ensure(
method string,
handler func(http.ResponseWriter, *http.Request),
) (wrapped func(http.ResponseWriter, *http.Request)) {
) (wrapped http.HandlerFunc) {
return func(w http.ResponseWriter, r *http.Request) {
m := r.Method
if m != method {
@ -265,53 +278,20 @@ func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
return false
}
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure(http.MethodPost, handler)
}
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return ensure(http.MethodGet, handler)
}
// Bridge between http.Handler object and Go function
type httpHandler struct {
handler func(http.ResponseWriter, *http.Request)
}
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handler(w, r)
}
func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler {
h := httpHandler{}
h.handler = ensure(method, handler)
return &h
}
// preInstall lets the handler run only if firstRun is true, no redirects
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if !globalContext.firstRun {
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
// preInstallHandler lets the handler run only if firstRun is true; it does not
// perform redirects.
func (web *webAPI) preInstallHandler(handler http.Handler) (wrapped http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !web.conf.firstRun {
// If it's not first run, do not allow access to install-only routes
// (for example, /install.html once configuration is complete).
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
handler(w, r)
}
}
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
type preInstallHandlerStruct struct {
handler http.Handler
}
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
preInstall(p.handler.ServeHTTP)(w, r)
}
// preInstallHandler returns http.Handler interface for preInstall wrapper
func preInstallHandler(handler http.Handler) http.Handler {
return &preInstallHandlerStruct{handler}
handler.ServeHTTP(w, r)
})
}
// handleHTTPSRedirect redirects the request to HTTPS, if needed, and adds some
@ -401,34 +381,52 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
}
}
// postInstall lets the handler to run only if firstRun is false. Otherwise, it
// redirects to /install.html. It also enforces HTTPS if it is enabled and
// configured and sets appropriate access control headers.
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
// postInstallHandler lets the handler to run only if firstRun is false.
// Otherwise, it redirects to /install.html. It also enforces HTTPS if it is
// enabled and configured and sets appropriate access control headers.
//
// TODO(s.chzhen): Replace with [web.postInstall] after fixing its usage in
// [httpRegister], which is called by [dhcpd.Create] before [web] is
// initialized.
func postInstallHandler(handler http.Handler) (wrapped http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if globalContext.web == nil {
aghhttp.Error(r, w, http.StatusTooEarly, "it is not initialized yet")
return
}
path := r.URL.Path
if globalContext.firstRun && !strings.HasPrefix(path, "/install.") &&
if globalContext.web.conf.firstRun &&
!strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "install.html", http.StatusFound)
return
}
proceed := handleHTTPSRedirect(w, r)
if proceed {
handler(w, r)
if handleHTTPSRedirect(w, r) {
handler.ServeHTTP(w, r)
}
}
})
}
type postInstallHandlerStruct struct {
handler http.Handler
}
// postInstallHandler lets the handler to run only if firstRun is false.
// Otherwise, it redirects to /install.html. It also enforces HTTPS if it is
// enabled and configured and sets appropriate access control headers.
func (web *webAPI) postInstallHandler(handler http.Handler) (wrapped http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if web.conf.firstRun &&
!strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "install.html", http.StatusFound)
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
postInstall(p.handler.ServeHTTP)(w, r)
}
return
}
func postInstallHandler(handler http.Handler) http.Handler {
return &postInstallHandlerStruct{handler}
if handleHTTPSRedirect(w, r) {
handler.ServeHTTP(w, r)
}
})
}

Some files were not shown because too many files have changed in this diff Show More