diff --git a/pico_w/access_point/CMakeLists.txt b/pico_w/access_point/CMakeLists.txt index 7ee8112..fc2c244 100644 --- a/pico_w/access_point/CMakeLists.txt +++ b/pico_w/access_point/CMakeLists.txt @@ -1,12 +1,14 @@ add_executable(picow_access_point_background picow_access_point.c dhcpserver/dhcpserver.c + dnsserver/dnsserver.c ) target_include_directories(picow_access_point_background PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${CMAKE_CURRENT_LIST_DIR}/.. # for our common lwipopts ${CMAKE_CURRENT_LIST_DIR}/dhcpserver + ${CMAKE_CURRENT_LIST_DIR}/dnsserver ) target_link_libraries(picow_access_point_background @@ -19,15 +21,16 @@ pico_add_extra_outputs(picow_access_point_background) add_executable(picow_access_point_poll picow_access_point.c dhcpserver/dhcpserver.c + dnsserver/dnsserver.c ) target_include_directories(picow_access_point_poll PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${CMAKE_CURRENT_LIST_DIR}/.. # for our common lwipopts ${CMAKE_CURRENT_LIST_DIR}/dhcpserver + ${CMAKE_CURRENT_LIST_DIR}/dnsserver ) target_link_libraries(picow_access_point_poll pico_cyw43_arch_lwip_poll pico_stdlib ) pico_add_extra_outputs(picow_access_point_poll) - diff --git a/pico_w/access_point/dhcpserver/dhcpserver.c b/pico_w/access_point/dhcpserver/dhcpserver.c index 502aefd..d77b32c 100644 --- a/pico_w/access_point/dhcpserver/dhcpserver.c +++ b/pico_w/access_point/dhcpserver/dhcpserver.c @@ -63,7 +63,6 @@ #define PORT_DHCP_SERVER (67) #define PORT_DHCP_CLIENT (68) -#define DEFAULT_DNS MAKE_IP4(8, 8, 8, 8) #define DEFAULT_LEASE_TIME_S (24 * 60 * 60) // in seconds #define MAC_LEN (6) @@ -274,7 +273,7 @@ static void dhcp_server_process(void *arg, struct udp_pcb *upcb, struct pbuf *p, opt_write_n(&opt, DHCP_OPT_SERVER_ID, 4, &ip4_addr_get_u32(ip_2_ip4(&d->ip))); opt_write_n(&opt, DHCP_OPT_SUBNET_MASK, 4, &ip4_addr_get_u32(ip_2_ip4(&d->nm))); opt_write_n(&opt, DHCP_OPT_ROUTER, 4, &ip4_addr_get_u32(ip_2_ip4(&d->ip))); // aka gateway; can have mulitple addresses - opt_write_u32(&opt, DHCP_OPT_DNS, DEFAULT_DNS); // can have mulitple addresses + opt_write_n(&opt, DHCP_OPT_DNS, 4, &ip4_addr_get_u32(ip_2_ip4(&d->ip))); // this server is the dns opt_write_u32(&opt, DHCP_OPT_IP_LEASE_TIME, DEFAULT_LEASE_TIME_S); *opt++ = DHCP_OPT_END; dhcp_socket_sendto(&d->udp, &dhcp_msg, opt - (uint8_t *)&dhcp_msg, 0xffffffff, PORT_DHCP_CLIENT); diff --git a/pico_w/access_point/dnsserver/dnsserver.c b/pico_w/access_point/dnsserver/dnsserver.c new file mode 100644 index 0000000..b55bc31 --- /dev/null +++ b/pico_w/access_point/dnsserver/dnsserver.c @@ -0,0 +1,235 @@ +/** + * Copyright (c) 2022 Raspberry Pi (Trading) Ltd. + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include +#include +#include +#include +#include + +#include "dnsserver.h" +#include "lwip/udp.h" + +#define PORT_DNS_SERVER 53 +#define DUMP_DATA 0 + +#define DEBUG_printf(...) +#define ERROR_printf printf + +typedef struct dns_header_t_ { + uint16_t id; + uint16_t flags; + uint16_t question_count; + uint16_t answer_record_count; + uint16_t authority_record_count; + uint16_t additional_record_count; +} dns_header_t; + +#define MAX_DNS_MSG_SIZE 300 + +static int dns_socket_new_dgram(struct udp_pcb **udp, void *cb_data, udp_recv_fn cb_udp_recv) { + *udp = udp_new(); + if (*udp == NULL) { + return -ENOMEM; + } + udp_recv(*udp, cb_udp_recv, (void *)cb_data); + return ERR_OK; +} + +static void dns_socket_free(struct udp_pcb **udp) { + if (*udp != NULL) { + udp_remove(*udp); + *udp = NULL; + } +} + +static int dns_socket_bind(struct udp_pcb **udp, uint32_t ip, uint16_t port) { + ip_addr_t addr; + IP4_ADDR(&addr, ip >> 24 & 0xff, ip >> 16 & 0xff, ip >> 8 & 0xff, ip & 0xff); + err_t err = udp_bind(*udp, &addr, port); + if (err != ERR_OK) { + ERROR_printf("dns failed to bind to port %u: %d", port, err); + assert(false); + } + return err; +} + +#if DUMP_DATA +static void dump_bytes(const uint8_t *bptr, uint32_t len) { + unsigned int i = 0; + + for (i = 0; i < len;) { + if ((i & 0x0f) == 0) { + printf("\n"); + } else if ((i & 0x07) == 0) { + printf(" "); + } + printf("%02x ", bptr[i++]); + } + printf("\n"); +} +#endif + +static int dns_socket_sendto(struct udp_pcb **udp, const void *buf, size_t len, const ip_addr_t *dest, uint16_t port) { + if (len > 0xffff) { + len = 0xffff; + } + + struct pbuf *p = pbuf_alloc(PBUF_TRANSPORT, len, PBUF_RAM); + if (p == NULL) { + ERROR_printf("DNS: Failed to send message out of memory\n"); + return -ENOMEM; + } + + memcpy(p->payload, buf, len); + err_t err = udp_sendto(*udp, p, dest, port); + + pbuf_free(p); + + if (err != ERR_OK) { + ERROR_printf("DNS: Failed to send message %d\n", err); + return err; + } + +#if DUMP_DATA + dump_bytes(buf, len); +#endif + return len; +} + +static void dns_server_process(void *arg, struct udp_pcb *upcb, struct pbuf *p, const ip_addr_t *src_addr, u16_t src_port) { + dns_server_t *d = arg; + DEBUG_printf("dns_server_process %u\n", p->tot_len); + + uint8_t dns_msg[MAX_DNS_MSG_SIZE]; + dns_header_t *dns_hdr = (dns_header_t*)dns_msg; + + size_t msg_len = pbuf_copy_partial(p, dns_msg, sizeof(dns_msg), 0); + if (msg_len < sizeof(dns_header_t)) { + goto ignore_request; + } + +#if DUMP_DATA + dump_bytes(dns_msg, msg_len); +#endif + + uint16_t flags = lwip_ntohs(dns_hdr->flags); + uint16_t question_count = lwip_ntohs(dns_hdr->question_count); + + DEBUG_printf("len %d\n", msg_len); + DEBUG_printf("dns flags 0x%x\n", flags); + DEBUG_printf("dns question count 0x%x\n", question_count); + + // flags from rfc1035 + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + + // Check QR indicates a query + if (((flags >> 15) & 0x1) != 0) { + DEBUG_printf("Ignoring non-query\n"); + goto ignore_request; + } + + // Check for standard query + if (((flags >> 11) & 0xf) != 0) { + DEBUG_printf("Ignoring non-standard query\n"); + goto ignore_request; + } + + // Check question count + if (question_count < 1) { + DEBUG_printf("Invalid question count\n"); + goto ignore_request; + } + + // Print the question + DEBUG_printf("question: "); + const uint8_t *question_ptr_start = dns_msg + sizeof(dns_header_t); + const uint8_t *question_ptr_end = dns_msg + msg_len; + const uint8_t *question_ptr = question_ptr_start; + while(question_ptr < question_ptr_end) { + if (*question_ptr == 0) { + question_ptr++; + break; + } else { + if (question_ptr > question_ptr_start) { + DEBUG_printf("."); + } + int label_len = *question_ptr++; + if (label_len > 63) { + DEBUG_printf("Invalid label\n"); + goto ignore_request; + } + DEBUG_printf("%.*s", label_len, question_ptr); + question_ptr += label_len; + } + } + DEBUG_printf("\n"); + + // Check question length + if (question_ptr - question_ptr_start > 255) { + DEBUG_printf("Invalid question length\n"); + goto ignore_request; + } + + // Skip QNAME and QTYPE + question_ptr += 4; + + // Generate answer + uint8_t *answer_ptr = dns_msg + (question_ptr - dns_msg); + *answer_ptr++ = 0xc0; // pointer + *answer_ptr++ = question_ptr_start - dns_msg; // pointer to question + + *answer_ptr++ = 0; + *answer_ptr++ = 1; // host address + + *answer_ptr++ = 0; + *answer_ptr++ = 1; // Internet class + + *answer_ptr++ = 0; + *answer_ptr++ = 0; + *answer_ptr++ = 0; + *answer_ptr++ = 60; // ttl 60s + + *answer_ptr++ = 0; + *answer_ptr++ = 4; // length + memcpy(answer_ptr, &d->ip.addr, 4); // use our address + answer_ptr += 4; + + dns_hdr->flags = lwip_htons( + 0x1 << 15 | // QR = response + 0x1 << 10 | // AA = authoritive + 0x1 << 7); // RA = authenticated + dns_hdr->question_count = lwip_htons(1); + dns_hdr->answer_record_count = lwip_htons(1); + dns_hdr->authority_record_count = 0; + dns_hdr->additional_record_count = 0; + + // Send the reply + DEBUG_printf("Sending %d byte reply to %s:%d\n", answer_ptr - dns_msg, ipaddr_ntoa(src_addr), src_port); + dns_socket_sendto(&d->udp, &dns_msg, answer_ptr - dns_msg, src_addr, src_port); + +ignore_request: + pbuf_free(p); +} + +void dns_server_init(dns_server_t *d, ip_addr_t *ip) { + if (dns_socket_new_dgram(&d->udp, d, dns_server_process) != ERR_OK) { + DEBUG_printf("dns server failed to start\n"); + return; + } + if (dns_socket_bind(&d->udp, 0, PORT_DNS_SERVER) != ERR_OK) { + DEBUG_printf("dns server failed to bind\n"); + return; + } + ip_addr_copy(d->ip, *ip); + DEBUG_printf("dns server listening on port %d\n", PORT_DNS_SERVER); +} + +void dns_server_deinit(dns_server_t *d) { + dns_socket_free(&d->udp); +} diff --git a/pico_w/access_point/dnsserver/dnsserver.h b/pico_w/access_point/dnsserver/dnsserver.h new file mode 100644 index 0000000..d23534c --- /dev/null +++ b/pico_w/access_point/dnsserver/dnsserver.h @@ -0,0 +1,20 @@ +/** + * Copyright (c) 2022 Raspberry Pi (Trading) Ltd. + * + * SPDX-License-Identifier: BSD-3-Clause + */ + +#ifndef _DNSSERVER_H_ +#define _DNSSERVER_H_ + +#include "lwip/ip_addr.h" + +typedef struct dns_server_t_ { + struct udp_pcb *udp; + ip_addr_t ip; +} dns_server_t; + +void dns_server_init(dns_server_t *d, ip_addr_t *ip); +void dns_server_deinit(dns_server_t *d); + +#endif diff --git a/pico_w/access_point/picow_access_point.c b/pico_w/access_point/picow_access_point.c index a41add5..899e61f 100644 --- a/pico_w/access_point/picow_access_point.c +++ b/pico_w/access_point/picow_access_point.c @@ -6,86 +6,239 @@ #include -#include "pico/stdlib.h" #include "pico/cyw43_arch.h" +#include "pico/stdlib.h" #include "lwip/pbuf.h" #include "lwip/tcp.h" #include "dhcpserver.h" - -#ifndef USE_LED -#define USE_LED 1 -#endif +#include "dnsserver.h" #define TCP_PORT 80 #define DEBUG_printf printf +#define POLL_TIME_S 5 +#define HTTP_GET "GET" +#define HTTP_RESPONSE_HEADERS "HTTP/1.1 %d OK\nContent-Length: %d\nContent-Type: text/html; charset=utf-8\nConnection: close\n\n" +#define LED_TEST_BODY "

Hello from Pico W.

Led is %s

Turn led %s" +#define LED_PARAM "led=%d" +#define LED_TEST "/ledtest" +#define LED_GPIO 0 +#define HTTP_RESPONSE_REDIRECT "HTTP/1.1 302 Redirect\nLocation: http://%s" LED_TEST "\n\n" -typedef struct TCP_ASERVER_T_ { +typedef struct TCP_SERVER_T_ { struct tcp_pcb *server_pcb; - struct tcp_pcb *client_pcb; bool complete; + ip_addr_t gw; } TCP_SERVER_T; -static err_t tcp_server_close(void *arg) { - TCP_SERVER_T *state = (TCP_SERVER_T*)arg; - err_t err = ERR_OK; - if (state->client_pcb != NULL) { - tcp_arg(state->client_pcb, NULL); - tcp_poll(state->client_pcb, NULL, 0); - tcp_sent(state->client_pcb, NULL); - tcp_recv(state->client_pcb, NULL); - tcp_err(state->client_pcb, NULL); - err = tcp_close(state->client_pcb); +typedef struct TCP_CONNECT_STATE_T_ { + struct tcp_pcb *pcb; + int sent_len; + char headers[128]; + char result[256]; + int header_len; + int result_len; + ip_addr_t *gw; +} TCP_CONNECT_STATE_T; + +static err_t tcp_close_client_connection(TCP_CONNECT_STATE_T *con_state, struct tcp_pcb *client_pcb, err_t close_err) { + if (client_pcb) { + assert(con_state && con_state->pcb == client_pcb); + tcp_arg(client_pcb, NULL); + tcp_poll(client_pcb, NULL, 0); + tcp_sent(client_pcb, NULL); + tcp_recv(client_pcb, NULL); + tcp_err(client_pcb, NULL); + err_t err = tcp_close(client_pcb); if (err != ERR_OK) { DEBUG_printf("close failed %d, calling abort\n", err); - tcp_abort(state->client_pcb); - err = ERR_ABRT; + tcp_abort(client_pcb); + close_err = ERR_ABRT; + } + if (con_state) { + free(con_state); } - state->client_pcb = NULL; } + return close_err; +} + +static void tcp_server_close(TCP_SERVER_T *state) { if (state->server_pcb) { tcp_arg(state->server_pcb, NULL); tcp_close(state->server_pcb); state->server_pcb = NULL; } - return err; } -static err_t tcp_ap_result(void *arg, int status) { - TCP_SERVER_T *state = (TCP_SERVER_T*)arg; - if (status == 0) { - DEBUG_printf("test success\n"); - } else { - DEBUG_printf("test failed %d\n", status); +static err_t tcp_server_sent(void *arg, struct tcp_pcb *pcb, u16_t len) { + TCP_CONNECT_STATE_T *con_state = (TCP_CONNECT_STATE_T*)arg; + DEBUG_printf("tcp_server_sent %u\n", len); + con_state->sent_len += len; + if (con_state->sent_len >= con_state->header_len + con_state->result_len) { + DEBUG_printf("all done\n"); + return tcp_close_client_connection(con_state, pcb, ERR_OK); + } + return ERR_OK; +} + +static int test_server_content(const char *request, const char *params, char *result, size_t max_result_len) { + int len = 0; + if (strncmp(request, LED_TEST, sizeof(LED_TEST) - 1) == 0) { + // Get the state of the led + bool value; + cyw43_gpio_get(&cyw43_state, LED_GPIO, &value); + int led_state = value; + + // See if the user changed it + if (params) { + int led_param = sscanf(params, LED_PARAM, &led_state); + if (led_param == 1) { + if (led_state) { + // Turn led on + cyw43_gpio_set(&cyw43_state, 0, true); + } else { + // Turn led off + cyw43_gpio_set(&cyw43_state, 0, false); + } + } + } + // Generate result + if (led_state) { + len = snprintf(result, max_result_len, LED_TEST_BODY, "ON", 0, "OFF"); + } else { + len = snprintf(result, max_result_len, LED_TEST_BODY, "OFF", 1, "ON"); + } + } + return len; +} + +err_t tcp_server_recv(void *arg, struct tcp_pcb *pcb, struct pbuf *p, err_t err) { + TCP_CONNECT_STATE_T *con_state = (TCP_CONNECT_STATE_T*)arg; + if (!p) { + DEBUG_printf("connection closed\n"); + return tcp_close_client_connection(con_state, pcb, ERR_OK); + } + assert(con_state && con_state->pcb == pcb); + if (p->tot_len > 0) { + DEBUG_printf("tcp_server_recv %d err %d\n", p->tot_len, err); +#if 0 + for (struct pbuf *q = p; q != NULL; q = q->next) { + DEBUG_printf("in: %.*s\n", q->len, q->payload); + } +#endif + // Copy the request into the buffer + pbuf_copy_partial(p, con_state->headers, p->tot_len > sizeof(con_state->headers) - 1 ? sizeof(con_state->headers) - 1 : p->tot_len, 0); + + // Handle GET request + if (strncmp(HTTP_GET, con_state->headers, sizeof(HTTP_GET) - 1) == 0) { + char *request = con_state->headers + sizeof(HTTP_GET); // + space + char *params = strchr(request, '?'); + if (params) { + if (*params) { + char *space = strchr(request, ' '); + *params++ = 0; + if (space) { + *space = 0; + } + } else { + params = NULL; + } + } + + // Generate content + con_state->result_len = test_server_content(request, params, con_state->result, sizeof(con_state->result)); + DEBUG_printf("Request: %s?%s\n", request, params); + DEBUG_printf("Result: %d\n", con_state->result_len); + + // Check we had enough buffer space + if (con_state->result_len > sizeof(con_state->result) - 1) { + DEBUG_printf("Too much result data %d\n", con_state->result_len); + return tcp_close_client_connection(con_state, pcb, ERR_CLSD); + } + + // Generate web page + if (con_state->result_len > 0) { + con_state->header_len = snprintf(con_state->headers, sizeof(con_state->headers), HTTP_RESPONSE_HEADERS, + 200, con_state->result_len); + if (con_state->header_len > sizeof(con_state->headers) - 1) { + DEBUG_printf("Too much header data %d\n", con_state->header_len); + return tcp_close_client_connection(con_state, pcb, ERR_CLSD); + } + } else { + // Send redirect + con_state->header_len = snprintf(con_state->headers, sizeof(con_state->headers), HTTP_RESPONSE_REDIRECT, + ipaddr_ntoa(con_state->gw)); + DEBUG_printf("Sending redirect %s", con_state->headers); + } + + // Send the headers to the client + con_state->sent_len = 0; + err_t err = tcp_write(pcb, con_state->headers, con_state->header_len, 0); + if (err != ERR_OK) { + DEBUG_printf("failed to write header data %d\n", err); + return tcp_close_client_connection(con_state, pcb, err); + } + + // Send the body to the client + if (con_state->result_len) { + err = tcp_write(pcb, con_state->result, con_state->result_len, 0); + if (err != ERR_OK) { + DEBUG_printf("failed to write result data %d\n", err); + return tcp_close_client_connection(con_state, pcb, err); + } + } + } + tcp_recved(pcb, p->tot_len); + } + pbuf_free(p); + return ERR_OK; +} + +static err_t tcp_server_poll(void *arg, struct tcp_pcb *pcb) { + TCP_CONNECT_STATE_T *con_state = (TCP_CONNECT_STATE_T*)arg; + DEBUG_printf("tcp_server_poll_fn\n"); + return tcp_close_client_connection(con_state, pcb, ERR_OK); // Just disconnect clent? +} + +static void tcp_server_err(void *arg, err_t err) { + TCP_CONNECT_STATE_T *con_state = (TCP_CONNECT_STATE_T*)arg; + if (err != ERR_ABRT) { + DEBUG_printf("tcp_client_err_fn %d\n", err); + tcp_close_client_connection(con_state, con_state->pcb, err); } - state->complete = true; - return tcp_server_close(arg); } static err_t tcp_server_accept(void *arg, struct tcp_pcb *client_pcb, err_t err) { -// TCP_SERVER_T *state = (TCP_SERVER_T*)arg; + TCP_SERVER_T *state = (TCP_SERVER_T*)arg; if (err != ERR_OK || client_pcb == NULL) { - DEBUG_printf("Failure in accept\n"); - tcp_ap_result(arg, err); + DEBUG_printf("failure in accept\n"); return ERR_VAL; } - DEBUG_printf("Client connected\n"); + DEBUG_printf("client connected\n"); - /*state->client_pcb = client_pcb; - tcp_arg(client_pcb, state); + // Create the state for the connection + TCP_CONNECT_STATE_T *con_state = calloc(1, sizeof(TCP_CONNECT_STATE_T)); + if (!con_state) { + DEBUG_printf("failed to allocate connect state\n"); + return ERR_MEM; + } + con_state->pcb = client_pcb; // for checking + con_state->gw = &state->gw; + + // setup connection to client + tcp_arg(client_pcb, con_state); tcp_sent(client_pcb, tcp_server_sent); tcp_recv(client_pcb, tcp_server_recv); tcp_poll(client_pcb, tcp_server_poll, POLL_TIME_S * 2); tcp_err(client_pcb, tcp_server_err); - return tcp_server_send_data(arg, state->client_pcb);*/ return ERR_OK; } static bool tcp_server_open(void *arg) { TCP_SERVER_T *state = (TCP_SERVER_T*)arg; - DEBUG_printf("Starting server on port %u\n", TCP_PORT); + DEBUG_printf("starting server on port %u\n", TCP_PORT); struct tcp_pcb *pcb = tcp_new_ip_type(IPADDR_TYPE_ANY); if (!pcb) { @@ -124,7 +277,7 @@ int main() { } if (cyw43_arch_init()) { - printf("failed to initialise\n"); + DEBUG_printf("failed to initialise\n"); return 1; } const char *ap_name = "picow_test"; @@ -136,30 +289,24 @@ int main() { cyw43_arch_enable_ap_mode(ap_name, password, CYW43_AUTH_WPA2_AES_PSK); - ip_addr_t gw, mask; - IP4_ADDR(ip_2_ip4(&gw), 192, 168, 4, 1); + ip4_addr_t mask; + IP4_ADDR(ip_2_ip4(&state->gw), 192, 168, 4, 1); IP4_ADDR(ip_2_ip4(&mask), 255, 255, 255, 0); // Start the dhcp server dhcp_server_t dhcp_server; - dhcp_server_init(&dhcp_server, &gw, &mask); + dhcp_server_init(&dhcp_server, &state->gw, &mask); + + // Start the dns server + dns_server_t dns_server; + dns_server_init(&dns_server, &state->gw); if (!tcp_server_open(state)) { - tcp_ap_result(state, -1); + DEBUG_printf("failed to open server\n"); + return 1; } while(!state->complete) { - #if USE_LED - static absolute_time_t led_time; - static int led_on = true; - - // Invert the led - if (absolute_time_diff_us(get_absolute_time(), led_time) < 0) { - led_on = !led_on; - cyw43_arch_gpio_put(CYW43_WL_GPIO_LED_PIN, led_on); - led_time = make_timeout_time_ms(1000); - } - #endif // the following #ifdef is only here so this same example can be used in multiple modes; // you do not need it in your code #if PICO_CYW43_ARCH_POLL @@ -174,6 +321,7 @@ int main() { sleep_ms(1000); #endif } + dns_server_deinit(&dns_server); dhcp_server_deinit(&dhcp_server); cyw43_arch_deinit(); return 0;