/*
 * This file is open source software, licensed to you under the terms
 * of the Apache License, Version 2.0 (the "License").  See the NOTICE file
 * distributed with this work for additional information regarding copyright
 * ownership.  You may not use this file except in compliance with the License.
 *
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*
 * Copyright (C) 2014 Cloudius Systems, Ltd.
 */

#include <random>

#include <sys/socket.h>
#include <linux/if.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <net/route.h>

#include <seastar/net/posix-stack.hh>
#include <seastar/net/net.hh>
#include <seastar/net/packet.hh>
#include <seastar/net/api.hh>
#include <seastar/net/inet_address.hh>
#include <seastar/util/std-compat.hh>
#include <netinet/tcp.h>
#include <netinet/sctp.h>

namespace std {

/// Hash for posix_ap_server_socket_impl's (protocol, socket address) keys:
/// the hashes of the two tuple elements are XOR-combined.
template <>
struct hash<seastar::net::posix_ap_server_socket_impl::protocol_and_socket_address> {
    size_t operator()(const seastar::net::posix_ap_server_socket_impl::protocol_and_socket_address& t_sa) const {
        const auto protocol_hash = std::hash<int>{}(std::get<0>(t_sa));
        const auto address_hash = std::hash<seastar::net::socket_address>{}(std::get<1>(t_sa));
        return protocol_hash ^ address_hash;
    }
};

}

namespace seastar {

namespace net {

using namespace seastar;

/// Per-protocol implementation of the socket options exposed through
/// connected_socket. Implementations are stateless; every call receives the
/// descriptor it should act on.
class posix_connected_socket_operations {
public:
    virtual ~posix_connected_socket_operations() = default;

    // "No delay" (Nagle) control.
    virtual void set_nodelay(file_desc& fd, bool nodelay) const = 0;
    virtual bool get_nodelay(file_desc& fd) const = 0;

    // Keepalive on/off switch.
    virtual void set_keepalive(file_desc& fd, bool keepalive) const = 0;
    virtual bool get_keepalive(file_desc& fd) const = 0;

    // Keepalive tuning; the active keepalive_params alternative is protocol specific.
    virtual void set_keepalive_parameters(file_desc& fd, const keepalive_params& params) const = 0;
    virtual keepalive_params get_keepalive_parameters(file_desc& fd) const = 0;
};

// Per-thread state of posix_ap_server_socket_impl (the listener used on
// non-main shards). `sockets` holds the promise of a pending accept() per
// (protocol, address) key; `conn_q` holds connections delivered by
// move_connected_socket() before any accept() was waiting for them.
thread_local posix_ap_server_socket_impl::sockets_map_t posix_ap_server_socket_impl::sockets{};
thread_local posix_ap_server_socket_impl::conn_map_t posix_ap_server_socket_impl::conn_q{};

class posix_tcp_connected_socket_operations : public posix_connected_socket_operations {
public:
    virtual void set_nodelay(file_desc& _fd, bool nodelay) const override {
        _fd.setsockopt(IPPROTO_TCP, TCP_NODELAY, int(nodelay));
    }
    virtual bool get_nodelay(file_desc& _fd) const override {
        return _fd.getsockopt<int>(IPPROTO_TCP, TCP_NODELAY);
    }
    virtual void set_keepalive(file_desc& _fd, bool keepalive) const override {
        _fd.setsockopt(SOL_SOCKET, SO_KEEPALIVE, int(keepalive));
    }
    virtual bool get_keepalive(file_desc& _fd) const override {
        return _fd.getsockopt<int>(SOL_SOCKET, SO_KEEPALIVE);
    }
    virtual void set_keepalive_parameters(file_desc& _fd, const keepalive_params& params) const override {
        const tcp_keepalive_params& pms = compat::get<tcp_keepalive_params>(params);
        _fd.setsockopt(IPPROTO_TCP, TCP_KEEPCNT, pms.count);
        _fd.setsockopt(IPPROTO_TCP, TCP_KEEPIDLE, int(pms.idle.count()));
        _fd.setsockopt(IPPROTO_TCP, TCP_KEEPINTVL, int(pms.interval.count()));
    }
    virtual keepalive_params get_keepalive_parameters(file_desc& _fd) const override {
        return tcp_keepalive_params {
            std::chrono::seconds(_fd.getsockopt<int>(IPPROTO_TCP, TCP_KEEPIDLE)),
            std::chrono::seconds(_fd.getsockopt<int>(IPPROTO_TCP, TCP_KEEPINTVL)),
            _fd.getsockopt<unsigned>(IPPROTO_TCP, TCP_KEEPCNT)
        };
    }
};

/// Socket options for SCTP over IPv4/IPv6.
///
/// "No delay" maps to SCTP_NODELAY. Keepalive maps to SCTP heartbeats, which
/// are read and written through the SCTP_PEER_ADDR_PARAMS option.
class posix_sctp_connected_socket_operations : public posix_connected_socket_operations {
public:
    virtual void set_nodelay(file_desc& _fd, bool nodelay) const override {
        _fd.setsockopt(SOL_SCTP, SCTP_NODELAY, int(nodelay));
    }
    virtual bool get_nodelay(file_desc& _fd) const override {
        return _fd.getsockopt<int>(SOL_SCTP, SCTP_NODELAY);
    }
    /// Turns heartbeats on or off.
    ///
    /// spp_flags must carry exactly one of SPP_HB_ENABLE / SPP_HB_DISABLE.
    /// The flags read back from the kernel may already hold SPP_HB_DISABLE, and
    /// setting SPP_HB_ENABLE on top of it is rejected with EINVAL. Clearing
    /// SPP_HB_ENABLE without setting SPP_HB_DISABLE is treated as "leave
    /// heartbeat unchanged". So both bits are cleared first, then the wanted
    /// one is set.
    virtual void set_keepalive(file_desc& _fd, bool keepalive) const override {
        auto heartbeat = _fd.getsockopt<sctp_paddrparams>(SOL_SCTP, SCTP_PEER_ADDR_PARAMS);
        heartbeat.spp_flags &= ~(SPP_HB_ENABLE | SPP_HB_DISABLE);
        heartbeat.spp_flags |= keepalive ? SPP_HB_ENABLE : SPP_HB_DISABLE;
        _fd.setsockopt(SOL_SCTP, SCTP_PEER_ADDR_PARAMS, heartbeat);
    }
    virtual bool get_keepalive(file_desc& _fd) const override {
        return _fd.getsockopt<sctp_paddrparams>(SOL_SCTP, SCTP_PEER_ADDR_PARAMS).spp_flags & SPP_HB_ENABLE;
    }
    /// Sets the heartbeat interval and the maximum retransmission count.
    /// Throws if `kpms` does not hold sctp_keepalive_params.
    virtual void set_keepalive_parameters(file_desc& _fd, const keepalive_params& kpms) const override {
        const sctp_keepalive_params& pms = compat::get<sctp_keepalive_params>(kpms);
        // Read-modify-write so that unrelated peer-address parameters are preserved.
        auto params = _fd.getsockopt<sctp_paddrparams>(SOL_SCTP, SCTP_PEER_ADDR_PARAMS);
        params.spp_hbinterval = pms.interval.count() * 1000; // kernel expects milliseconds
        params.spp_pathmaxrxt = pms.count;
        _fd.setsockopt(SOL_SCTP, SCTP_PEER_ADDR_PARAMS, params);
    }
    virtual keepalive_params get_keepalive_parameters(file_desc& _fd) const override {
        auto params = _fd.getsockopt<sctp_paddrparams>(SOL_SCTP, SCTP_PEER_ADDR_PARAMS);
        return sctp_keepalive_params {
            std::chrono::seconds(params.spp_hbinterval/1000), // milliseconds -> seconds
            params.spp_pathmaxrxt
        };
    }
};

/// Socket options for AF_UNIX stream sockets, which support neither "no delay"
/// nor keepalive. Setters do nothing and getters report fixed values.
class posix_unix_stream_connected_socket_operations : public posix_connected_socket_operations {
public:
    virtual void set_nodelay(file_desc& fd, bool nodelay) const override {
        // Turning no-delay *off* cannot be honored here; trap anyone who tries.
        assert(nodelay);
    }
    virtual bool get_nodelay(file_desc& fd) const override { return true; }

    virtual void set_keepalive(file_desc& fd, bool keepalive) const override {
        // nothing to configure
    }
    virtual bool get_keepalive(file_desc& fd) const override { return false; }

    virtual void set_keepalive_parameters(file_desc& fd, const keepalive_params& p) const override {
        // nothing to configure
    }
    virtual keepalive_params get_keepalive_parameters(file_desc& fd) const override {
        keepalive_params defaults{};
        return defaults;
    }
};

/// Returns the shared, stateless option handler for a socket of the given
/// address family and protocol. AF_UNIX ignores `protocol`. AF_INET/AF_INET6
/// accept IPPROTO_TCP or IPPROTO_SCTP. Anything else aborts the process.
static const posix_connected_socket_operations*
get_posix_connected_socket_ops(sa_family_t family, int protocol) {
    static posix_tcp_connected_socket_operations tcp_ops;
    static posix_sctp_connected_socket_operations sctp_ops;
    static posix_unix_stream_connected_socket_operations unix_ops;
    if (family == AF_UNIX) {
        return &unix_ops;
    }
    if (family != AF_INET && family != AF_INET6) {
        abort();
    }
    if (protocol == IPPROTO_TCP) {
        return &tcp_ops;
    }
    if (protocol == IPPROTO_SCTP) {
        return &sctp_ops;
    }
    abort();
}

class posix_connected_socket_impl final : public connected_socket_impl {
    lw_shared_ptr<pollable_fd> _fd;
    const posix_connected_socket_operations* _ops;
    conntrack::handle _handle;
    compat::polymorphic_allocator<char>* _allocator;
private:
    explicit posix_connected_socket_impl(sa_family_t family, int protocol, lw_shared_ptr<pollable_fd> fd, compat::polymorphic_allocator<char>* allocator=memory::malloc_allocator) :
        _fd(std::move(fd)), _ops(get_posix_connected_socket_ops(family, protocol)), _allocator(allocator) {}
    explicit posix_connected_socket_impl(sa_family_t family, int protocol, lw_shared_ptr<pollable_fd> fd, conntrack::handle&& handle,
        compat::polymorphic_allocator<char>* allocator=memory::malloc_allocator) : _fd(std::move(fd))
                , _ops(get_posix_connected_socket_ops(family, protocol)), _handle(std::move(handle)), _allocator(allocator) {}
public:
    virtual data_source source() override {
        return data_source(std::make_unique< posix_data_source_impl>(_fd, _allocator));
    }
    virtual data_sink sink() override {
        return data_sink(std::make_unique< posix_data_sink_impl>(_fd));
    }
    virtual void shutdown_input() override {
        _fd->shutdown(SHUT_RD);
    }
    virtual void shutdown_output() override {
        _fd->shutdown(SHUT_WR);
    }
    virtual void set_nodelay(bool nodelay) override {
        return _ops->set_nodelay(_fd->get_file_desc(), nodelay);
    }
    virtual bool get_nodelay() const override {
        return _ops->get_nodelay(_fd->get_file_desc());
    }
    void set_keepalive(bool keepalive) override {
        return _ops->set_keepalive(_fd->get_file_desc(), keepalive);
    }
    bool get_keepalive() const override {
        return _ops->get_keepalive(_fd->get_file_desc());
    }
    void set_keepalive_parameters(const keepalive_params& p) override {
        return _ops->set_keepalive_parameters(_fd->get_file_desc(), p);
    }
    keepalive_params get_keepalive_parameters() const override {
        return _ops->get_keepalive_parameters(_fd->get_file_desc());
    }
    friend class posix_server_socket_impl;
    friend class posix_ap_server_socket_impl;
    friend class posix_reuseport_server_socket_impl;
    friend class posix_network_stack;
    friend class posix_ap_network_stack;
    friend class posix_socket_impl;
};

/// Fills in the scope id of an IPv6 link-local destination address.
///
/// Applies only when `a` is AF_INET6, is link-local (IN6_IS_ADDR_LINKLOCAL)
/// and has no scope id yet. In that case the kernel IPv6 routing table
/// (/proc/net/ipv6_route) is scanned in file order. The first usable route
/// whose destination prefix covers `a` and whose device is one of the
/// reactor's network interfaces supplies sin6_scope_id. Every other address,
/// and `a` when no such route is found, is left unchanged.
///
/// @throws std::system_error if /proc/net/ipv6_route cannot be opened.
static void resolve_outgoing_address(socket_address& a) {
    if (a.family() != AF_INET6
        || a.as_posix_sockaddr_in6().sin6_scope_id != inet_address::invalid_scope
        || !IN6_IS_ADDR_LINKLOCAL(&a.as_posix_sockaddr_in6().sin6_addr)
    ) {
        return;
    }

    FILE* f = fopen("/proc/net/ipv6_route", "r");
    if (!f) {
        throw std::system_error(errno, std::system_category(), "resolve_address");
    }

    auto holder = std::unique_ptr<FILE, int(*)(FILE *)>(f, &::fclose);

    /**
      Here all configured IPv6 routes are shown in a special format. The example below shows only the loopback interface. The meaning of each field is shown below (see net/ipv6/route.c for more).

    # cat /proc/net/ipv6_route
    00000000000000000000000000000000 00 00000000000000000000000000000000 00 00000000000000000000000000000000 ffffffff 00000001 00000001 00200200 lo
    +------------------------------+ ++ +------------------------------+ ++ +------------------------------+ +------+ +------+ +------+ +------+ ++
    |                                |  |                                |  |                                |        |        |        |        |
    1                                2  3                                4  5                                6        7        8        9        10

    1: IPv6 destination network displayed in 32 hexadecimal chars without colons as separator

    2: IPv6 destination prefix length in hexadecimal

    3: IPv6 source network displayed in 32 hexadecimal chars without colons as separator

    4: IPv6 source prefix length in hexadecimal

    5: IPv6 next hop displayed in 32 hexadecimal chars without colons as separator

    6: Metric in hexadecimal

    7: Reference counter

    8: Use counter

    9: Flags

    10: Device name

    */

    uint32_t prefix_len, src_prefix_len;
    unsigned long flags;
    // Interface names are at most IFNAMSIZ - 1 == 15 characters. The "%15s"
    // conversion below must stay in sync with this buffer size. A narrower
    // width would truncate long names (e.g. "enp0s31f6"), leave the rest of
    // the name in the stream and desynchronize parsing of the following lines.
    char device[16];
    // 8 groups of 4 hex digits + 7 ':' separators + NUL.
    char dest_str[40];

    for (;;) {
        auto n = fscanf(f, "%4s%4s%4s%4s%4s%4s%4s%4s %02x "
                            "%*4s%*4s%*4s%*4s%*4s%*4s%*4s%*4s %02x "
                            "%*4s%*4s%*4s%*4s%*4s%*4s%*4s%*4s "
                            "%*08x %*08x %*08x %08lx %15s",
                            &dest_str[0], &dest_str[5], &dest_str[10], &dest_str[15],
                            &dest_str[20], &dest_str[25], &dest_str[30], &dest_str[35],
                            &prefix_len,
                            &src_prefix_len,
                            &flags, device);
        if (n != 12) {
            break;
        }

        // Skip routes that cannot be used to select an outgoing interface.
        if ((prefix_len > 128)  || (src_prefix_len != 0)
            || (flags & (RTF_POLICY | RTF_FLOW))
            || ((flags & RTF_REJECT) && prefix_len == 0) /* reject all */) {
            continue;
        }

        // Each %4s above stored a NUL right after its group; turn those into
        // ':' separators so dest_str becomes "xxxx:xxxx:...:xxxx".
        dest_str[4] = dest_str[9] = dest_str[14] = dest_str[19] = dest_str[24] = dest_str[29] = dest_str[34] = ':';
        dest_str[39] = '\0';

        struct in6_addr addr;
        // inet_pton() returns 1 on success, 0 for a malformed string and -1
        // for an unsupported family. Only 1 means `addr` was written.
        if (inet_pton(AF_INET6, dest_str, &addr) != 1) {
            /* not an Ipv6 address */
            continue;
        }

        auto bytes = prefix_len / 8;
        auto bits = prefix_len % 8;

        auto& src = a.as_posix_sockaddr_in6().sin6_addr;

        // Compare the whole bytes of the prefix, then the leading `bits` of the next byte.
        if (bytes > 0 && memcmp(&src, &addr, bytes)) {
            continue;
        }
        if (bits > 0) {
            auto c1 = src.s6_addr[bytes];
            auto c2 = addr.s6_addr[bytes];
            auto mask = 0xffu << (8 - bits);
            if ((c1 & mask) != (c2 & mask)) {
                continue;
            }
        }

        // Found a covering route; use it if its device is a known interface.
        for (auto& nif : engine().net().network_interfaces()) {
            if (nif.name() == device || nif.display_name() == device) {
                a.as_posix_sockaddr_in6().sin6_scope_id = nif.index();
                return;
            }
        }
    }
}

/// Client-side socket (seastar::socket) for the POSIX stack. Every connect
/// attempt creates a new pollable_fd in `_fd`; on success that descriptor is
/// wrapped in a posix_connected_socket_impl.
class posix_socket_impl final : public socket_impl {
    lw_shared_ptr<pollable_fd> _fd;
    compat::polymorphic_allocator<char>* _allocator;
    // Desired SO_REUSEADDR value; applied to each descriptor created for an IP connect.
    bool _reuseaddr = false;

    /// Connects an IP socket to `sa`, binding the local end to `local`.
    ///
    /// For TCP without an explicit local port, up to 5 attempts use a random
    /// port chosen so that port % smp::count == this shard's cpu id. An
    /// EADDRINUSE / EADDRNOTAVAIL on such a port triggers another attempt.
    /// After 5 attempts the requested port (0) is used and any error
    /// propagates. For other protocols, or an explicit port, a single attempt
    /// is made.
    future<> find_port_and_connect(socket_address sa, socket_address local, transport proto = transport::TCP) {
        static thread_local std::default_random_engine random_engine{std::random_device{}()};
        static thread_local std::uniform_int_distribution<uint16_t> u(49152/smp::count + 1, 65535/smp::count - 1);
        // If no explicit local address, set to dest address family wildcard.
        if (local.is_unspecified()) {
            local = net::inet_address(sa.addr().in_family());
        }
        // Link-local IPv6 destinations need a scope id; look it up in the routing table.
        resolve_outgoing_address(sa);
        // Note: sin_port is read/written through the sockaddr_in view for both
        // IPv4 and IPv6; this relies on sin_port and sin6_port sharing an offset.
        return repeat([this, sa, local, proto, attempts = 0, requested_port = ntoh(local.as_posix_sockaddr_in().sin_port)] () mutable {
            _fd = engine().make_pollable_fd(sa, int(proto));
            _fd->get_file_desc().setsockopt(SOL_SOCKET, SO_REUSEADDR, int(_reuseaddr));
            uint16_t port = attempts++ < 5 && requested_port == 0 && proto == transport::TCP ? u(random_engine) * smp::count + engine().cpu_id() : requested_port;
            local.as_posix_sockaddr_in().sin_port = hton(port);
            return futurize_apply([this, sa, local] { return engine().posix_connect(_fd, sa, local); }).then_wrapped([port, requested_port] (future<> f) {
                try {
                    f.get();
                    return stop_iteration::yes;
                } catch (std::system_error& err) {
                    // Only a port we picked ourselves is worth retrying.
                    if (port != requested_port && (err.code().value() == EADDRINUSE || err.code().value() == EADDRNOTAVAIL)) {
                        return stop_iteration::no;
                    }
                    throw;
                }
            });
        });
    }

    /// an aux function to handle unix-domain-specific requests
    future<connected_socket> connect_unix_domain(socket_address sa, socket_address local) {
        // note that if the 'local' address was not set by the client, it was created as an undefined address
        if (local.is_unspecified()) {
            local = socket_address{unix_domain_addr{std::string{}}};
        }

        _fd = engine().make_pollable_fd(sa, 0);
        return engine().posix_connect(_fd, sa, local).then(
            [fd = _fd, allocator = _allocator](){
                // posix_connected_socket_impl's constructors are private (this class is a
                // friend), so std::make_unique cannot be used and `new` is called directly.
                // `fd` is a const capture here, so std::move copies the lw_shared_ptr.
                std::unique_ptr<connected_socket_impl> csi;
                csi.reset(new posix_connected_socket_impl{AF_UNIX, 0, std::move(fd), allocator});
                return make_ready_future<connected_socket>(connected_socket(std::move(csi)));
            }
        );
    }

public:
    explicit posix_socket_impl(compat::polymorphic_allocator<char>* allocator=memory::malloc_allocator) : _allocator(allocator) {}

    /// Connects to `sa`. Unix-domain addresses take the dedicated path; IP
    /// addresses go through find_port_and_connect() and use `proto` as the
    /// socket protocol.
    virtual future<connected_socket> connect(socket_address sa, socket_address local, transport proto = transport::TCP) override {
        if (sa.is_af_unix()) {
            return connect_unix_domain(sa, local);
        }
        return find_port_and_connect(sa, local, proto).then([this, sa, proto, allocator = _allocator] () mutable {
            std::unique_ptr<connected_socket_impl> csi;
            csi.reset(new posix_connected_socket_impl(sa.family(), static_cast<int>(proto), _fd, allocator));
            return make_ready_future<connected_socket>(connected_socket(std::move(csi)));
        });
    }

    /// Records the SO_REUSEADDR preference and applies it immediately if a
    /// descriptor already exists.
    void set_reuseaddr(bool reuseaddr) override {
        _reuseaddr = reuseaddr;
        if (_fd) {
            _fd->get_file_desc().setsockopt(SOL_SOCKET, SO_REUSEADDR, int(reuseaddr));
        }
    }

    /// Reads SO_REUSEADDR from the descriptor if one exists, otherwise returns
    /// the recorded preference.
    bool get_reuseaddr() const override {
        if(_fd) {
            return _fd->get_file_desc().getsockopt<int>(SOL_SOCKET, SO_REUSEADDR);
        } else {
            return _reuseaddr;
        }
    }

    /// Shuts down both directions of the current descriptor, if any.
    /// ENOTCONN (e.g. not connected yet) is ignored; other errors propagate.
    virtual void shutdown() override {
        if (_fd) {
            try {
                _fd->shutdown(SHUT_RDWR);
            } catch (std::system_error& e) {
                if (e.code().value() != ENOTCONN) {
                    throw;
                }
            }
        }
    }
};

/// Accepts the next connection on the main shard's listening socket and
/// chooses the shard that will own it according to `_lba`.
///
/// A connection for this shard is returned directly. A connection for another
/// shard is sent there via posix_ap_server_socket_impl::move_connected_socket(),
/// and this function calls accept() again, so the returned future always
/// resolves with a connection owned by the current shard.
future<accept_result>
posix_server_socket_impl::accept() {
    return _lfd.accept().then([this] (std::tuple<pollable_fd, socket_address> fd_sa) {
        auto& fd = std::get<0>(fd_sa);
        auto& sa = std::get<1>(fd_sa);
        // Pick the owning shard:
        //  - connection_distribution: let _conntrack choose;
        //  - port: client port modulo smp::count;
        //  - fixed: always _fixed_cpu.
        auto cth = [this, &sa] {
            switch(_lba) {
            case server_socket::load_balancing_algorithm::connection_distribution:
                return _conntrack.get_handle();
            case server_socket::load_balancing_algorithm::port:
                return _conntrack.get_handle(ntoh(sa.as_posix_sockaddr_in().sin_port) % smp::count);
            case server_socket::load_balancing_algorithm::fixed:
                return _conntrack.get_handle(_fixed_cpu);
            default: abort();
            }
        } ();
        auto cpu = cth.cpu();
        if (cpu == engine().cpu_id()) {
            std::unique_ptr<connected_socket_impl> csi(
                    new posix_connected_socket_impl(sa.family(), _protocol, make_lw_shared(std::move(fd)), std::move(cth), _allocator));
            return make_ready_future<accept_result>(
                    accept_result{connected_socket(std::move(csi)), sa});
        } else {
            // Only the raw file_desc crosses shards; the target shard wraps it in its own pollable_fd.
            // FIXME: future is discarded (any failure of the cross-shard delivery goes unnoticed)
            (void)smp::submit_to(cpu, [protocol = _protocol, ssa = _sa, fd = std::move(fd.get_file_desc()), sa, cth = std::move(cth), allocator = _allocator] () mutable {
                posix_ap_server_socket_impl::move_connected_socket(protocol, ssa, pollable_fd(std::move(fd)), sa, std::move(cth), allocator);
            });
            // Keep accepting until a connection for this shard arrives.
            return accept();
        }
    });
}

/// Aborts reading on the listening descriptor, which is what accept() waits
/// on (presumably failing a pending accept() — confirm pollable_fd semantics).
void posix_server_socket_impl::abort_accept() {
    _lfd.abort_reader();
}

/// The address the listening descriptor is bound to, as reported by the OS.
socket_address
posix_server_socket_impl::local_address() const {
    auto& listen_fd = _lfd.get_file_desc();
    return listen_fd.get_address();
}

/// accept() for listeners on non-main shards. Connections are delivered by the
/// main shard through move_connected_socket(). A connection already queued for
/// this (protocol, address) key is returned at once. Otherwise a promise is
/// registered for the key and its future returned. At most one accept() may
/// be pending per key (asserted).
future<accept_result> posix_ap_server_socket_impl::accept() {
    auto key = std::make_tuple(_protocol, _sa);
    auto queued = conn_q.find(key);
    if (queued == conn_q.end()) {
        try {
            auto inserted = sockets.emplace(std::piecewise_construct, std::make_tuple(key), std::make_tuple());
            assert(inserted.second);
            return inserted.first->second.get_future();
        } catch (...) {
            return make_exception_future<accept_result>(std::current_exception());
        }
    }
    connection c = std::move(queued->second);
    conn_q.erase(queued);
    try {
        std::unique_ptr<connected_socket_impl> csi(
                new posix_connected_socket_impl(_sa.family(), _protocol, make_lw_shared(std::move(c.fd)), std::move(c.connection_tracking_handle)));
        return make_ready_future<accept_result>(accept_result{connected_socket(std::move(csi)), std::move(c.addr)});
    } catch (...) {
        return make_exception_future<accept_result>(std::current_exception());
    }
}

void
posix_ap_server_socket_impl::abort_accept() {
    auto t_sa = std::make_tuple(_protocol, _sa);
    conn_q.erase(t_sa);
    auto i = sockets.find(t_sa);
    if (i != sockets.end()) {
        i->second.set_exception(std::system_error(ECONNABORTED, std::system_category()));
        sockets.erase(i);
    }
}

/// Accepts on this shard's own listening descriptor; the connection always
/// stays on the current shard.
future<accept_result>
posix_reuseport_server_socket_impl::accept() {
    auto wrap = [allocator = _allocator, protocol = _protocol] (std::tuple<pollable_fd, socket_address> fd_sa) {
        auto& peer = std::get<1>(fd_sa);
        auto shared_fd = make_lw_shared(std::move(std::get<0>(fd_sa)));
        std::unique_ptr<connected_socket_impl> csi(
                new posix_connected_socket_impl(peer.family(), protocol, std::move(shared_fd), allocator));
        return make_ready_future<accept_result>(accept_result{connected_socket(std::move(csi)), peer});
    };
    return _lfd.accept().then(std::move(wrap));
}

/// Aborts reading on this shard's listening descriptor (the one accept() waits on).
void posix_reuseport_server_socket_impl::abort_accept() {
    _lfd.abort_reader();
}

/// The address this shard's listening descriptor is bound to.
socket_address
posix_reuseport_server_socket_impl::local_address() const {
    auto& listen_fd = _lfd.get_file_desc();
    return listen_fd.get_address();
}

/// Runs on the target shard to deliver a connection accepted by the main
/// shard. If an accept() is waiting for the same (protocol, address) key, its
/// promise is completed (or failed if building the socket throws). Otherwise
/// the connection is queued in conn_q for the next accept().
/// NOTE(review): `allocator` is not stored with a queued connection, so the
/// socket later built by accept() uses the constructor's default allocator.
void
posix_ap_server_socket_impl::move_connected_socket(int protocol, socket_address sa, pollable_fd fd, socket_address addr, conntrack::handle cth, compat::polymorphic_allocator<char>* allocator) {
    auto key = std::make_tuple(protocol, sa);
    auto waiter = sockets.find(key);
    if (waiter == sockets.end()) {
        conn_q.emplace(std::piecewise_construct, std::make_tuple(key), std::make_tuple(std::move(fd), std::move(addr), std::move(cth)));
        return;
    }
    try {
        std::unique_ptr<connected_socket_impl> csi(new posix_connected_socket_impl(sa.family(), protocol, make_lw_shared(std::move(fd)), std::move(cth), allocator));
        waiter->second.set_value(accept_result{connected_socket(std::move(csi)), std::move(addr)});
    } catch (...) {
        waiter->second.set_exception(std::current_exception());
    }
    sockets.erase(waiter);
}

/// Reads whatever is available (up to _buf_size bytes) into the current buffer
/// and returns that buffer trimmed to the bytes read. A fresh buffer is then
/// allocated from _buffer_allocator for the next call.
future<temporary_buffer<char>>
posix_data_source_impl::get() {
    return _fd->read_some(_buf.get_write(), _buf_size).then([this] (size_t size) {
        _buf.trim(size);
        temporary_buffer<char> filled(std::move(_buf));
        _buf = make_temporary_buffer<char>(_buffer_allocator, _buf_size);
        return make_ready_future<temporary_buffer<char>>(std::move(filled));
    });
}

/// Shuts down the read direction of the socket; completes immediately.
future<>
posix_data_source_impl::close() {
    _fd->shutdown(SHUT_RD);
    return make_ready_future<>();
}

/// Builds an iovec array referencing (not copying) each fragment of `p`; the
/// packet must outlive any use of the result.
std::vector<struct iovec> to_iovec(const packet& p) {
    std::vector<struct iovec> iov;
    iov.reserve(p.nr_frags());
    for (auto&& frag : p.fragments()) {
        struct iovec entry;
        entry.iov_base = frag.base;
        entry.iov_len = frag.size;
        iov.push_back(entry);
    }
    return iov;
}

/// Builds an iovec array pointing at the writable contents of each buffer in
/// `buf_vec`; the buffers must outlive any use of the result.
std::vector<iovec> to_iovec(std::vector<temporary_buffer<char>>& buf_vec) {
    std::vector<iovec> iov;
    iov.reserve(buf_vec.size());
    for (auto& buf : buf_vec) {
        iovec entry;
        entry.iov_base = buf.get_write();
        entry.iov_len = buf.size();
        iov.push_back(entry);
    }
    return iov;
}

/// Writes the whole buffer. The buffer's deleter is moved into the
/// continuation so the memory stays alive until write_all() completes.
future<>
posix_data_sink_impl::put(temporary_buffer<char> buf) {
    const char* data = buf.get();
    auto len = buf.size();
    return _fd->write_all(data, len).then([d = buf.release()] {});
}

/// Writes the whole packet. The packet is parked in `_p` for the duration of
/// the write and released once it completes.
future<>
posix_data_sink_impl::put(packet p) {
    _p = std::move(p);
    auto release = [this] { _p.reset(); };
    return _fd->write_all(_p).then(std::move(release));
}

/// Shuts down the write direction of the socket; completes immediately.
future<> posix_data_sink_impl::close() {
    _fd->shutdown(SHUT_WR);
    return make_ready_future<>();
}

/// Creates a listening socket on the main-shard POSIX stack.
/// An unspecified address means the IPv4 wildcard. Unix-domain addresses always
/// get a posix_server_socket_impl with protocol 0. IP addresses get a
/// posix_reuseport_server_socket_impl when _reuseport is set, and otherwise a
/// posix_server_socket_impl, which assigns each connection to a shard per
/// opt.lba / opt.fixed_cpu.
server_socket
posix_network_stack::listen(socket_address sa, listen_options opt) {
    using server_socket = seastar::api_v2::server_socket;
    if (sa.is_unspecified()) {
        // No bind address given: listen on the IPv4 wildcard.
        sa = inet_address(inet_address::family::INET);
    }
    if (sa.is_af_unix()) {
        return server_socket(std::make_unique<posix_server_socket_impl>(0, sa, engine().posix_listen(sa, opt), opt.lba, opt.fixed_cpu, _allocator));
    }
    const auto protocol = static_cast<int>(opt.proto);
    if (_reuseport) {
        return server_socket(std::make_unique<posix_reuseport_server_socket_impl>(protocol, sa, engine().posix_listen(sa, opt), _allocator));
    }
    return server_socket(std::make_unique<posix_server_socket_impl>(protocol, sa, engine().posix_listen(sa, opt), opt.lba, opt.fixed_cpu, _allocator));
}

/// Returns a new, unconnected client socket that uses this stack's allocator.
::seastar::socket posix_network_stack::socket() {
    auto impl = std::make_unique<posix_socket_impl>(_allocator);
    return ::seastar::socket(std::move(impl));
}

/// listen() for non-main shards. An unspecified address means the IPv4 wildcard.
/// Unless _reuseport is set, no engine().posix_listen() call is made here. The
/// returned posix_ap_server_socket_impl instead receives connections accepted on
/// the main shard. With _reuseport, this shard gets its own listener.
server_socket
posix_ap_network_stack::listen(socket_address sa, listen_options opt) {
    using server_socket = seastar::api_v2::server_socket;
    if (sa.is_unspecified()) {
        // No bind address given: use the IPv4 wildcard.
        sa = inet_address(inet_address::family::INET);
    }
    if (sa.is_af_unix()) {
        return server_socket(std::make_unique<posix_ap_server_socket_impl>(0, sa));
    }
    const auto protocol = static_cast<int>(opt.proto);
    if (!_reuseport) {
        return server_socket(std::make_unique<posix_ap_server_socket_impl>(protocol, sa));
    }
    return server_socket(std::make_unique<posix_reuseport_server_socket_impl>(protocol, sa, engine().posix_listen(sa, opt)));
}

struct cmsg_with_pktinfo {
    struct cmsghdrcmh;
    union {
        struct in_pktinfo pktinfo;
        struct in6_pktinfo pkt6info;
    };
};

/// udp_channel implementation over a non-blocking POSIX datagram socket.
///
/// The socket is bound at construction, using the IPv4 wildcard when the given
/// address is unspecified. IP_PKTINFO is enabled so receive() can report each
/// datagram's destination address. SO_REUSEPORT is set when the reactor
/// reports it available.
class posix_udp_channel : public udp_channel_impl {
private:
    // Largest UDP payload over IPv4: 65535 - 20 (IP header) - 8 (UDP header).
    static constexpr int MAX_DATAGRAM_SIZE = 65507;
    /// State for one recvmsg() call: header, a single iovec over a freshly
    /// allocated datagram buffer, the sender address and the packet-info
    /// control buffer.
    struct recv_ctx {
        struct msghdr _hdr;
        struct iovec _iov;
        socket_address _src_addr;
        char* _buffer;
        cmsg_with_pktinfo _cmsg;

        recv_ctx() {
            memset(&_hdr, 0, sizeof(_hdr));
            _hdr.msg_iov = &_iov;
            _hdr.msg_iovlen = 1;
            _hdr.msg_name = &_src_addr.u.sa;
            _hdr.msg_namelen = sizeof(_src_addr.u.sas);
            memset(&_cmsg, 0, sizeof(_cmsg));
            _hdr.msg_control = &_cmsg;
            _hdr.msg_controllen = sizeof(_cmsg);
        }

        /// Allocates the buffer for the next datagram (freed by the received
        /// packet's deleter or on error) and re-arms the header.
        void prepare() {
            _buffer = new char[MAX_DATAGRAM_SIZE];
            _iov.iov_base = _buffer;
            _iov.iov_len = MAX_DATAGRAM_SIZE;
            // recvmsg() overwrites msg_namelen and msg_controllen with the
            // lengths actually returned. Without resetting them, one datagram
            // that arrives with a short (or no) control message would shrink
            // the space offered for every later receive, losing its pktinfo.
            _hdr.msg_namelen = sizeof(_src_addr.u.sas);
            _hdr.msg_controllen = sizeof(_cmsg);
        }
    };
    /// State for one sendmsg() call. It owns the packet being sent and the
    /// iovec array pointing into it. It is reused by every send(packet).
    struct send_ctx {
        struct msghdr _hdr;
        std::vector<struct iovec> _iovecs;
        socket_address _dst;
        packet _p;

        send_ctx() {
            memset(&_hdr, 0, sizeof(_hdr));
            _hdr.msg_name = &_dst.u.sa;
            _hdr.msg_namelen = sizeof(_dst.u.sas);
        }

        void prepare(const socket_address& dst, packet p) {
            _dst = dst;
            _p = std::move(p);
            _iovecs = to_iovec(_p);
            _hdr.msg_iov = _iovecs.data();
            _hdr.msg_iovlen = _iovecs.size();
            resolve_outgoing_address(_dst);
        }
    };
    std::unique_ptr<pollable_fd> _fd;
    socket_address _address;
    recv_ctx _recv;
    send_ctx _send;
    bool _closed;
public:
    posix_udp_channel(const socket_address& bind_address)
            : _closed(false) {
        auto sa = bind_address.is_unspecified() ? socket_address(inet_address(inet_address::family::INET)) : bind_address;
        file_desc fd = file_desc::socket(sa.u.sa.sa_family, SOCK_DGRAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
        // Deliver an IP_PKTINFO control message with each datagram; receive()
        // uses it to report the datagram's destination address.
        fd.setsockopt(SOL_IP, IP_PKTINFO, 1);
        if (engine().posix_reuseport_available()) {
            fd.setsockopt(SOL_SOCKET, SO_REUSEPORT, 1);
        }
        fd.bind(sa.u.sa, sizeof(sa.u.sas));
        // Record the actual bound address (e.g. the kernel-chosen port).
        _address = fd.get_address();
        _fd = std::make_unique<pollable_fd>(std::move(fd));
    }
    virtual ~posix_udp_channel() { if (!_closed) close(); }
    virtual future<udp_datagram> receive() override;
    virtual future<> send(const socket_address& dst, const char *msg) override;
    virtual future<> send(const socket_address& dst, packet p) override;
    virtual void shutdown_input() override {
        _fd->abort_reader();
    }
    virtual void shutdown_output() override {
        _fd->abort_writer();
    }
    virtual void close() override {
        _closed = true;
        _fd.reset();
    }
    virtual bool is_closed() const override { return _closed; }
    socket_address local_address() const override {
        assert(_address.u.sas.ss_family != AF_INET6 || (_address.addr_length > 20));
        return _address;
    }
};

/// Sends the NUL-terminated string `message` (without its terminator) to
/// `dst` as a single datagram and asserts that it was sent in full.
future<> posix_udp_channel::send(const socket_address& dst, const char *message) {
    const auto len = strlen(message);
    socket_address target = dst;
    resolve_outgoing_address(target);
    return _fd->sendto(target, message, len).then([len] (size_t sent) {
        assert(sent == len);
    });
}

/// Sends `p` to `dst` as a single datagram via the shared send context and
/// asserts that it was sent in full.
/// NOTE(review): `_send` is a single member reused by every call, so this
/// presumably assumes send() calls on a channel do not overlap — confirm.
future<> posix_udp_channel::send(const socket_address& dst, packet p) {
    const auto expected = p.len();
    _send.prepare(dst, std::move(p));
    return _fd->sendmsg(&_send._hdr).then([expected] (size_t sent) {
        assert(sent == expected);
    });
}

/// Creates a UDP channel bound to `addr`.
udp_channel
posix_network_stack::make_udp_channel(const socket_address& addr) {
    auto impl = std::make_unique<posix_udp_channel>(addr);
    return udp_channel(std::move(impl));
}

bool
posix_network_stack::supports_ipv6() const {
    static bool has_ipv6 = [] {
        try {
            posix_udp_channel c(ipv6_addr{"::1"});
            c.close();
            return true;
        } catch (...) {}
        return false;
    }();

    return has_ipv6;
}

/// A received datagram: its payload plus source and destination addresses.
class posix_datagram : public udp_datagram_impl {
private:
    socket_address _src;
    socket_address _dst;
    packet _p;
public:
    posix_datagram(const socket_address& src, const socket_address& dst, packet p)
        : _src(src)
        , _dst(dst)
        , _p(std::move(p)) {
    }
    virtual socket_address get_src() override {
        return _src;
    }
    virtual socket_address get_dst() override {
        return _dst;
    }
    virtual uint16_t get_dst_port() override {
        return _dst.port();
    }
    virtual packet& get_data() override {
        return _p;
    }
};

/// Receives one datagram into a freshly allocated MAX_DATAGRAM_SIZE buffer.
/// The destination address is taken from the first IP_PKTINFO / IPV6_PKTINFO
/// control message, combined with the channel's bound port. If no such message
/// is present, the destination stays a default-constructed socket_address.
/// The buffer is owned by the returned packet's deleter on success and freed
/// by the exception handler on failure.
/// NOTE(review): if anything after the packet is constructed throws (e.g.
/// make_unique), the packet's deleter and the handler below may both delete
/// the buffer — verify ownership on that path.
future<udp_datagram>
posix_udp_channel::receive() {
    _recv.prepare();
    return _fd->recvmsg(&_recv._hdr).then([this] (size_t size) {
        socket_address dst;
        for (auto* cmsg = CMSG_FIRSTHDR(&_recv._hdr); cmsg != nullptr; cmsg = CMSG_NXTHDR(&_recv._hdr, cmsg)) {
            if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {
                dst = ipv4_addr(reinterpret_cast<const in_pktinfo*>(CMSG_DATA(cmsg))->ipi_addr, _address.port());
                break;
            } else if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) {
                dst = ipv6_addr(reinterpret_cast<const in6_pktinfo*>(CMSG_DATA(cmsg))->ipi6_addr, _address.port());
                break;
            }
        }
        return make_ready_future<udp_datagram>(udp_datagram(std::make_unique<posix_datagram>(
            _recv._src_addr, dst, packet(fragment{_recv._buffer, size}, make_deleter([buf = _recv._buffer] { delete[] buf; })))));
    }).handle_exception([p = _recv._buffer](auto ep) {
        // recvmsg failed: nobody else owns the buffer yet.
        delete[] p;
        return make_exception_future<udp_datagram>(std::move(ep));
    });
}

/// Registers the "posix" network stack with no extra command-line options.
/// The main thread gets a posix_network_stack and every other thread gets a
/// posix_ap_network_stack. The final `true` is forwarded to
/// register_network_stack (presumably marking it as the default stack — confirm).
void register_posix_stack() {
    auto create = [](boost::program_options::variables_map ops) {
        const bool on_main_thread = smp::main_thread();
        return on_main_thread ? posix_network_stack::create(ops)
                              : posix_ap_network_stack::create(ops);
    };
    register_network_stack("posix", boost::program_options::options_description(), create, true);
}

// nw interface stuff

std::vector<network_interface> posix_network_stack::network_interfaces() {
    class posix_network_interface_impl : public network_interface_impl {
    public:
        uint32_t _index = 0, _mtu = 0;
        sstring _name, _display_name;
        std::vector<net::inet_address> _addresses;
        std::vector<uint8_t> _hardware_address;
        bool _loopback = false, _virtual = false, _up = false;

        uint32_t index() const override {
            return _index;
        }
        uint32_t mtu() const override {
            return _mtu;
        }
        const sstring& name() const override {
            return _name;   
        }
        const sstring& display_name() const override {
            return _display_name.empty() ? name() : _display_name;
        }
        const std::vector<net::inet_address>& addresses() const override {
            return _addresses;            
        }
        const std::vector<uint8_t> hardware_address() const override {
            return _hardware_address;
        }
        bool is_loopback() const override {
            return _loopback;   
        }
        bool is_virtual() const override {
            return _virtual;
        }
        bool is_up() const override {
            // TODO: should be checked on query?
            return _up;
        }
        bool supports_ipv6() const override {
            // TODO: this is not 100% correct.
            return std::any_of(_addresses.begin(), _addresses.end(), std::mem_fn(&inet_address::is_ipv6));
        }
    };

    // For now, keep an immutable set of interfaces created on start, shared across 
    // shards
    static const std::vector<posix_network_interface_impl> global_interfaces = [] {
        auto fd = ::socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
        throw_system_error_on(fd < 0, "could not open netlink socket");

        std::unique_ptr<int, void(*)(int*)> fd_guard(&fd, [](int* p) { ::close(*p); });

        auto pid = ::getpid();

        sockaddr_nl local = { 0, };
        local.nl_family = AF_NETLINK;
        local.nl_pid = pid;
        local.nl_groups = RTMGRP_IPV6_IFADDR|RTMGRP_IPV4_IFADDR;

        throw_system_error_on(bind(fd, (struct sockaddr *) &local, sizeof(local)) < 0, "could not bind netlink socket");

        /* RTNL socket is ready for use, prepare and send requests */

        std::vector<posix_network_interface_impl> res;

        for (auto msg : { RTM_GETLINK, RTM_GETADDR}) {
            struct nl_req {
                nlmsghdr hdr;
                union {
                    rtgenmsg gen;
                    ifaddrmsg addr; 
                }; 
            } req = { 0, };

            sockaddr_nl kernel = { 0, }; 
            msghdr rtnl_msg = { 0, };
    
            kernel.nl_family = AF_NETLINK; /* fill-in kernel address (destination of our message) */

            req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct rtgenmsg));
            req.hdr.nlmsg_type = msg;
            req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ROOT; 
            req.hdr.nlmsg_seq = 1;
            req.hdr.nlmsg_pid = pid;

            if (msg == RTM_GETLINK) {
                req.gen.rtgen_family = AF_PACKET; /*  no preferred AF, we will get *all* interfaces */
            } else {
                req.addr.ifa_family = AF_UNSPEC;
            }

            iovec io;

            io.iov_base = &req;
            io.iov_len = req.hdr.nlmsg_len;

            rtnl_msg.msg_iov = &io;
            rtnl_msg.msg_iovlen = 1;
            rtnl_msg.msg_name = &kernel;
            rtnl_msg.msg_namelen = sizeof(kernel);

            throw_system_error_on(::sendmsg(fd, (struct msghdr *) &rtnl_msg, 0) < 0, "could not send netlink request");
            /* parse reply */

            constexpr size_t reply_buffer_size = 8192;
            char reply[reply_buffer_size]; 

            bool done = false;

            while (!done) {
                msghdr rtnl_reply = { 0, };
                iovec io_reply = { 0, };

                io_reply.iov_base = reply;
                io_reply.iov_len = reply_buffer_size;
                rtnl_reply.msg_iov = &io_reply;
                rtnl_reply.msg_iovlen = 1;
                rtnl_reply.msg_name = &kernel;
                rtnl_reply.msg_namelen = sizeof(kernel);

                auto len = ::recvmsg(fd, &rtnl_reply, 0); /* read as much data as fits in the receive buffer */
                if (len <= 0) {
                    return res;
                }

                for (auto* msg_ptr = (struct nlmsghdr *) reply; NLMSG_OK(msg_ptr, len); msg_ptr = NLMSG_NEXT(msg_ptr, len)) {
                    switch(msg_ptr->nlmsg_type) {
                    case NLMSG_DONE: // that is all
                        done = true;
                        break;                    
                    case RTM_NEWLINK: 
                    {
                        auto* iface = reinterpret_cast<const ifinfomsg*>(NLMSG_DATA(msg_ptr));
                        auto ilen = msg_ptr->nlmsg_len - NLMSG_LENGTH(sizeof(ifinfomsg));

                        // todo: filter any non-network interfaces (family)

                        posix_network_interface_impl nwif;
                        
                        nwif._index = iface->ifi_index;
                        nwif._loopback = (iface->ifi_flags & IFF_LOOPBACK) != 0;
                        nwif._up = (iface->ifi_flags & IFF_UP) != 0;
    #if defined(IFF_802_1Q_VLAN) && defined(IFF_EBRIDGE) && defined(IFF_SLAVE_INACTIVE)
                        nwif._virtual = (iface->ifi_flags & (IFF_802_1Q_VLAN|IFF_EBRIDGE|IFF_SLAVE_INACTIVE)) != 0;
    #endif                                        
                        for (auto* attribute = IFLA_RTA(iface); RTA_OK(attribute, ilen); attribute = RTA_NEXT(attribute, ilen)) {
                            switch(attribute->rta_type) {
                            case IFLA_IFNAME:
                                nwif._name = reinterpret_cast<const char *>(RTA_DATA(attribute));
                                break;
                            case IFLA_MTU:
                                nwif._mtu = *reinterpret_cast<const uint32_t *>(RTA_DATA(attribute));                            
                                break;
                            case IFLA_ADDRESS:
                                nwif._hardware_address.assign(reinterpret_cast<const uint8_t *>(RTA_DATA(attribute)), reinterpret_cast<const uint8_t *>(RTA_DATA(attribute)) + RTA_PAYLOAD(attribute));
                                break;
                            default:
                                break;
                            }
                        }

                        res.emplace_back(std::move(nwif));

                        break;
                    }
                    case RTM_NEWADDR:
                    {
                        auto* addr = reinterpret_cast<const ifaddrmsg*>(NLMSG_DATA(msg_ptr));
                        auto ilen = msg_ptr->nlmsg_len - NLMSG_LENGTH(sizeof(ifaddrmsg));
                        
                        for (auto& nwif : res) {
                            if (nwif._index == addr->ifa_index) {
                                for (auto* attribute = IFA_RTA(addr); RTA_OK(attribute, ilen); attribute = RTA_NEXT(attribute, ilen)) {
                                    compat::optional<inet_address> ia;
                                    
                                    switch(attribute->rta_type) {
                                    case IFA_LOCAL:
                                    case IFA_ADDRESS: // ipv6 addresses are reported only as "ADDRESS"

                                        if (RTA_PAYLOAD(attribute) == sizeof(::in_addr)) {
                                            ia.emplace(*reinterpret_cast<const ::in_addr *>(RTA_DATA(attribute)));
                                        } else if (RTA_PAYLOAD(attribute) == sizeof(::in6_addr)) {
                                            ia.emplace(*reinterpret_cast<const ::in6_addr *>(RTA_DATA(attribute)), nwif.index());
                                        }
                                        
                                        if (ia && std::find(nwif._addresses.begin(), nwif._addresses.end(), *ia) == nwif._addresses.end()) {
                                            nwif._addresses.emplace_back(*ia);
                                        }

                                        break;
                                    default:
                                        break;
                                    }
                                }

                                break;
                            }
                        }
                    }
                    default:
                        break;
                    }
                }      
            }
        }

        return res;
    }();

    // And a similarly immutable set of shared_ptr to network_interface_impl per shard, ready 
    // to be handed out to callers with minimal overhead
    static const thread_local std::vector<shared_ptr<posix_network_interface_impl>> thread_local_interfaces = [] {
        std::vector<shared_ptr<posix_network_interface_impl>> res;
        res.reserve(global_interfaces.size());
        std::transform(global_interfaces.begin(), global_interfaces.end(), std::back_inserter(res), [](const posix_network_interface_impl& impl) {
            return make_shared<posix_network_interface_impl>(impl);
        });
        return res;
    }();

    return std::vector<network_interface>(thread_local_interfaces.begin(), thread_local_interfaces.end());
}

}

}
