330 lines
6.9 KiB
Go
330 lines
6.9 KiB
Go
package ssdp
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/rand"
|
|
"net"
|
|
"net/http"
|
|
"net/textproto"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/net/ipv4"
|
|
)
|
|
|
|
// SSDP protocol constants used by this package.
const (
	// AddrString is the IPv4 multicast group and port, in host:port form,
	// that notifications are sent to and searches are received on.
	AddrString = "239.255.255.250:1900"

	// rootDevice is the notification/search type that is always advertised
	// alongside the server's UUID.
	rootDevice = "upnp:rootdevice"

	// aliveNTS is the NTS header value of periodic announcements.
	aliveNTS = "ssdp:alive"

	// byebyeNTS is the NTS header value sent when the server shuts down.
	byebyeNTS = "ssdp:byebye"
)
|
|
|
|
// NetAddr is the resolved UDP form of AddrString. It is assigned once during
// package initialisation and is nil before that.
var NetAddr *net.UDPAddr
|
|
|
|
func init() {
|
|
var err error
|
|
NetAddr, err = net.ResolveUDPAddr("udp4", AddrString)
|
|
if err != nil {
|
|
log.Panicf("Could not resolve %s: %s", AddrString, err)
|
|
}
|
|
}
|
|
|
|
// badStringError reports a piece of input that could not be parsed.
type badStringError struct {
	what string // description of the problem
	str  string // the offending input, quoted in the error message
}

// Error implements the error interface, rendering the description followed by
// the offending input as a quoted string.
func (e *badStringError) Error() string {
	return fmt.Sprintf("%s %q", e.what, e.str)
}
|
|
|
|
func ReadRequest(b *bufio.Reader) (req *http.Request, err error) {
|
|
tp := textproto.NewReader(b)
|
|
var s string
|
|
if s, err = tp.ReadLine(); err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err == io.EOF {
|
|
err = io.ErrUnexpectedEOF
|
|
}
|
|
}()
|
|
|
|
var f []string
|
|
// TODO a split that only allows N values?
|
|
if f = strings.SplitN(s, " ", 3); len(f) < 3 {
|
|
return nil, &badStringError{"malformed request line", s}
|
|
}
|
|
if f[1] != "*" {
|
|
return nil, &badStringError{"bad URL request", f[1]}
|
|
}
|
|
req = &http.Request{
|
|
Method: f[0],
|
|
}
|
|
var ok bool
|
|
if req.ProtoMajor, req.ProtoMinor, ok = http.ParseHTTPVersion(strings.TrimSpace(f[2])); !ok {
|
|
return nil, &badStringError{"malformed HTTP version", f[2]}
|
|
}
|
|
|
|
mimeHeader, err := tp.ReadMIMEHeader()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header = http.Header(mimeHeader)
|
|
return
|
|
}
|
|
|
|
// Server announces a set of devices and services over SSDP on one network
// interface and answers M-SEARCH requests for them. The exported fields are
// read while serving; Init must be called before Serve or Close.
type Server struct {
	// conn is the multicast UDP connection opened by Init.
	conn *net.UDPConn

	// Interface is the network interface to listen on; its addresses are
	// the ones passed to Location.
	Interface net.Interface

	// Server is the value of the SERVER header in every message.
	Server string

	// Services and Devices are the types advertised in addition to the
	// root device and the UUID.
	Services []string
	Devices  []string

	// Location returns the LOCATION header value to advertise for the
	// given local IP address.
	Location func(net.IP) string

	// UUID identifies this server; it is advertised as a type of its own
	// and prefixes the USN of every other type.
	UUID string

	// NotifyInterval is the pause between rounds of alive notifications.
	// The advertised max-age is derived from it.
	NotifyInterval time.Duration

	// closed is created by Init and closed by Close to stop background
	// work.
	closed chan struct{}
}
|
|
|
|
func makeConn(ifi net.Interface) (ret *net.UDPConn, err error) {
|
|
ret, err = net.ListenMulticastUDP("udp", &ifi, NetAddr)
|
|
if err != nil {
|
|
return
|
|
}
|
|
p := ipv4.NewPacketConn(ret)
|
|
if err := p.SetMulticastTTL(2); err != nil {
|
|
log.Println(err)
|
|
}
|
|
if err := p.SetMulticastLoopback(true); err != nil {
|
|
log.Println(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (me *Server) serve() {
|
|
for {
|
|
b := make([]byte, me.Interface.MTU)
|
|
n, addr, err := me.conn.ReadFromUDP(b)
|
|
select {
|
|
case <-me.closed:
|
|
return
|
|
default:
|
|
}
|
|
if err != nil {
|
|
log.Printf("error reading from UDP socket: %s", err)
|
|
break
|
|
}
|
|
go me.handle(b[:n], addr)
|
|
}
|
|
}
|
|
|
|
func (me *Server) Init() (err error) {
|
|
me.closed = make(chan struct{})
|
|
me.conn, err = makeConn(me.Interface)
|
|
return
|
|
}
|
|
|
|
func (me *Server) Close() {
|
|
close(me.closed)
|
|
me.sendByeBye()
|
|
me.conn.Close()
|
|
}
|
|
|
|
func (me *Server) Serve() (err error) {
|
|
go me.serve()
|
|
for {
|
|
addrs, err := me.Interface.Addrs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, addr := range addrs {
|
|
ip := func() net.IP {
|
|
switch val := addr.(type) {
|
|
case *net.IPNet:
|
|
return val.IP
|
|
case *net.IPAddr:
|
|
return val.IP
|
|
}
|
|
panic(fmt.Sprint("unexpected addr type:", addr))
|
|
}()
|
|
extraHdrs := [][2]string{
|
|
{"CACHE-CONTROL", fmt.Sprintf("max-age=%d", 5*me.NotifyInterval/2/time.Second)},
|
|
{"LOCATION", me.Location(ip)},
|
|
}
|
|
me.notifyAll(aliveNTS, extraHdrs)
|
|
}
|
|
time.Sleep(me.NotifyInterval)
|
|
}
|
|
}
|
|
|
|
func (me *Server) usnFromTarget(target string) string {
|
|
if target == me.UUID {
|
|
return target
|
|
}
|
|
return me.UUID + "::" + target
|
|
}
|
|
|
|
func (me *Server) makeNotifyMessage(target, nts string, extraHdrs [][2]string) []byte {
|
|
lines := [...][2]string{
|
|
{"HOST", AddrString},
|
|
{"NT", target},
|
|
{"NTS", nts},
|
|
{"SERVER", me.Server},
|
|
{"USN", me.usnFromTarget(target)},
|
|
}
|
|
buf := &bytes.Buffer{}
|
|
fmt.Fprint(buf, "NOTIFY * HTTP/1.1\r\n")
|
|
writeHdr := func(keyValue [2]string) {
|
|
fmt.Fprintf(buf, "%s: %s\r\n", keyValue[0], keyValue[1])
|
|
}
|
|
for _, pair := range lines {
|
|
writeHdr(pair)
|
|
}
|
|
for _, pair := range extraHdrs {
|
|
writeHdr(pair)
|
|
}
|
|
fmt.Fprint(buf, "\r\n")
|
|
return buf.Bytes()
|
|
}
|
|
|
|
func (me *Server) send(buf []byte, addr *net.UDPAddr) {
|
|
if n, err := me.conn.WriteToUDP(buf, addr); err != nil {
|
|
log.Printf("error writing to UDP socket: %s", err)
|
|
} else if n != len(buf) {
|
|
log.Printf("short write: %d/%d bytes", n, len(buf))
|
|
}
|
|
}
|
|
|
|
func (me *Server) delayedSend(delay time.Duration, buf []byte, addr *net.UDPAddr) {
|
|
go func() {
|
|
select {
|
|
case <-time.After(delay):
|
|
me.send(buf, addr)
|
|
case <-me.closed:
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (me *Server) log(args ...interface{}) {
|
|
args = append([]interface{}{me.Interface.Name + ":"}, args...)
|
|
log.Print(args...)
|
|
}
|
|
|
|
func (me *Server) sendByeBye() {
|
|
for _, type_ := range me.allTypes() {
|
|
buf := me.makeNotifyMessage(type_, byebyeNTS, nil)
|
|
me.send(buf, NetAddr)
|
|
}
|
|
}
|
|
|
|
func (me *Server) notifyAll(nts string, extraHdrs [][2]string) {
|
|
for _, type_ := range me.allTypes() {
|
|
buf := me.makeNotifyMessage(type_, nts, extraHdrs)
|
|
delay := time.Duration(rand.Int63n(int64(100 * time.Millisecond)))
|
|
me.delayedSend(delay, buf, NetAddr)
|
|
}
|
|
}
|
|
|
|
func (me *Server) allTypes() (ret []string) {
|
|
for _, a := range [][]string{
|
|
{rootDevice, me.UUID},
|
|
me.Devices,
|
|
me.Services,
|
|
} {
|
|
ret = append(ret, a...)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (me *Server) handle(buf []byte, sender *net.UDPAddr) {
|
|
req, err := ReadRequest(bufio.NewReader(bytes.NewReader(buf)))
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
if req.Method != "M-SEARCH" || req.Header.Get("man") != `"ssdp:discover"` {
|
|
return
|
|
}
|
|
var mx uint
|
|
if req.Header.Get("Host") == AddrString {
|
|
mxHeader := req.Header.Get("mx")
|
|
i, err := strconv.ParseUint(mxHeader, 0, 0)
|
|
if err != nil {
|
|
log.Printf("Invalid mx header %q: %s", mxHeader, err)
|
|
return
|
|
}
|
|
mx = uint(i)
|
|
} else {
|
|
mx = 1
|
|
}
|
|
types := func(st string) []string {
|
|
if st == "ssdp:all" {
|
|
return me.allTypes()
|
|
}
|
|
for _, t := range me.allTypes() {
|
|
if t == st {
|
|
return []string{t}
|
|
}
|
|
}
|
|
return nil
|
|
}(req.Header.Get("st"))
|
|
for _, ip := range func() (ret []net.IP) {
|
|
addrs, err := me.Interface.Addrs()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
for _, addr := range addrs {
|
|
if ip, ok := func() (net.IP, bool) {
|
|
switch data := addr.(type) {
|
|
case *net.IPNet:
|
|
if data.Contains(sender.IP) {
|
|
return data.IP, true
|
|
}
|
|
return nil, false
|
|
case *net.IPAddr:
|
|
return data.IP, true
|
|
}
|
|
panic(addr)
|
|
}(); ok {
|
|
ret = append(ret, ip)
|
|
}
|
|
}
|
|
return
|
|
}() {
|
|
for _, type_ := range types {
|
|
resp := me.makeResponse(ip, type_, req)
|
|
delay := time.Duration(rand.Int63n(int64(time.Second) * int64(mx)))
|
|
me.delayedSend(delay, resp, sender)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (me *Server) makeResponse(ip net.IP, targ string, req *http.Request) (ret []byte) {
|
|
resp := &http.Response{
|
|
StatusCode: 200,
|
|
ProtoMajor: 1,
|
|
ProtoMinor: 1,
|
|
Header: make(http.Header),
|
|
Request: req,
|
|
}
|
|
for _, pair := range [...][2]string{
|
|
{"CACHE-CONTROL", fmt.Sprintf("max-age=%d", 5*me.NotifyInterval/2/time.Second)},
|
|
{"EXT", ""},
|
|
{"LOCATION", me.Location(ip)},
|
|
{"SERVER", me.Server},
|
|
{"ST", targ},
|
|
{"USN", me.usnFromTarget(targ)},
|
|
} {
|
|
resp.Header.Set(pair[0], pair[1])
|
|
}
|
|
buf := &bytes.Buffer{}
|
|
if err := resp.Write(buf); err != nil {
|
|
panic(err)
|
|
}
|
|
return buf.Bytes()
|
|
}
|