diff options
Diffstat (limited to 'netdisco/ssdp.py')
-rw-r--r-- | netdisco/ssdp.py | 290 |
1 files changed, 290 insertions, 0 deletions
diff --git a/netdisco/ssdp.py b/netdisco/ssdp.py new file mode 100644 index 0000000..55cb08b --- /dev/null +++ b/netdisco/ssdp.py @@ -0,0 +1,290 @@ +"""Module that implements SSDP protocol.""" +import re +import select +import socket +import logging +from datetime import datetime, timedelta +from xml.etree import ElementTree + +import requests +import zeroconf + +from netdisco.util import etree_to_dict + +DISCOVER_TIMEOUT = 2 +# MX is a suggested random wait time for a device to reply, so should be +# bound by our discovery timeout. +SSDP_MX = DISCOVER_TIMEOUT +SSDP_TARGET = ("239.255.255.250", 1900) + +RESPONSE_REGEX = re.compile(r'\n(.*?)\: *(.*)\r') + +MIN_TIME_BETWEEN_SCANS = timedelta(seconds=59) + +# Devices and services +ST_ALL = "ssdp:all" + +# Devices only, some devices will only respond to this query +ST_ROOTDEVICE = "upnp:rootdevice" + + +class SSDP: + """Control the scanning of uPnP devices and services and caches output.""" + + def __init__(self): + """Initialize the discovery.""" + self.entries = [] + self.last_scan = None + + def scan(self): + """Scan the network.""" + self.update() + + def all(self): + """Return all found entries. + + Will scan for entries if not scanned recently. + """ + self.update() + + return list(self.entries) + + # pylint: disable=invalid-name + def find_by_st(self, st): + """Return a list of entries that match the ST.""" + self.update() + + return [entry for entry in self.entries + if entry.st == st] + + def find_by_device_description(self, values): + """Return a list of entries that match the description. + + Pass in a dict with values to match against the device tag in the + description. + """ + self.update() + + seen = set() + results = [] + + # Make unique based on the location since we don't care about ST here + for entry in self.entries: + location = entry.location + + if location not in seen and entry.match_device_description(values): + results.append(entry) + seen.add(location) + + return results + + def update(self, force_update=False): + """Scan for new uPnP devices and services.""" + if self.last_scan is None or force_update or \ + datetime.now()-self.last_scan > MIN_TIME_BETWEEN_SCANS: + + self.remove_expired() + + self.entries.extend( + entry for entry in scan() + if entry not in self.entries) + + self.last_scan = datetime.now() + + def remove_expired(self): + """Filter out expired entries.""" + self.entries = [entry for entry in self.entries + if not entry.is_expired] + + +class UPNPEntry: + """Found uPnP entry.""" + + DESCRIPTION_CACHE = {'_NO_LOCATION': {}} + + def __init__(self, values): + """Initialize the discovery.""" + self.values = values + self.created = datetime.now() + + if 'cache-control' in self.values: + cache_directive = self.values['cache-control'] + max_age = re.findall(r'max-age *= *\d+', cache_directive) + if max_age: + cache_seconds = int(max_age[0].split('=')[1]) + self.expires = self.created + timedelta(seconds=cache_seconds) + else: + self.expires = None + else: + self.expires = None + + @property + def is_expired(self): + """Return if the entry is expired or not.""" + return self.expires is not None and datetime.now() > self.expires + + # pylint: disable=invalid-name + @property + def st(self): + """Return ST value.""" + return self.values.get('st') + + @property + def location(self): + """Return Location value.""" + return self.values.get('location') + + @property + def description(self): + """Return the description from the uPnP entry.""" + url = self.values.get('location', '_NO_LOCATION') + + if url not in UPNPEntry.DESCRIPTION_CACHE: + try: + xml = requests.get(url, timeout=5).text + if not xml: + # Samsung Smart TV sometimes returns an empty document the + # first time. Retry once. + xml = requests.get(url, timeout=5).text + + tree = ElementTree.fromstring(xml) + + UPNPEntry.DESCRIPTION_CACHE[url] = \ + etree_to_dict(tree).get('root', {}) + except requests.RequestException: + logging.getLogger(__name__).debug( + "Error fetching description at %s", url) + + UPNPEntry.DESCRIPTION_CACHE[url] = {} + + except ElementTree.ParseError: + logging.getLogger(__name__).debug( + "Found malformed XML at %s: %s", url, xml) + + UPNPEntry.DESCRIPTION_CACHE[url] = {} + + return UPNPEntry.DESCRIPTION_CACHE[url] + + def match_device_description(self, values): + """Fetch description and matches against it. + + Values should only contain lowercase keys. + """ + device = self.description.get('device') + + if device is None: + return False + + return all(device.get(key) in val + if isinstance(val, list) + else val == device.get(key) + for key, val in values.items()) + + @classmethod + def from_response(cls, response): + """Create a uPnP entry from a response.""" + return UPNPEntry({key.lower(): item for key, item + in RESPONSE_REGEX.findall(response)}) + + def __eq__(self, other): + """Return the comparison.""" + return (self.__class__ == other.__class__ and + self.values == other.values) + + def __repr__(self): + """Return the entry.""" + return "<UPNPEntry {} - {}>".format(self.location or '', self.st or '') + + +def ssdp_request(ssdp_st, ssdp_mx=SSDP_MX): + """Return request bytes for given st and mx.""" + return "\r\n".join([ + 'M-SEARCH * HTTP/1.1', + 'ST: {}'.format(ssdp_st), + 'MX: {:d}'.format(ssdp_mx), + 'MAN: "ssdp:discover"', + 'HOST: {}:{}'.format(*SSDP_TARGET), + '', '']).encode('utf-8') + + +# pylint: disable=invalid-name,too-many-locals,too-many-branches +def scan(timeout=DISCOVER_TIMEOUT): + """Send a message over the network to discover uPnP devices. + + Inspired by Crimsdings + https://github.com/crimsdings/ChromeCast/blob/master/cc_discovery.py + + Protocol explanation: + https://embeddedinn.wordpress.com/tutorials/upnp-device-architecture/ + """ + ssdp_requests = ssdp_request(ST_ALL), ssdp_request(ST_ROOTDEVICE) + + stop_wait = datetime.now() + timedelta(seconds=timeout) + + sockets = [] + for addr in zeroconf.get_all_addresses(): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + # Set the time-to-live for messages for local network + sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, + SSDP_MX) + sock.bind((addr, 0)) + sockets.append(sock) + except socket.error: + pass + + entries = {} + for sock in [s for s in sockets]: + try: + for req in ssdp_requests: + sock.sendto(req, SSDP_TARGET) + sock.setblocking(False) + except socket.error: + sockets.remove(sock) + sock.close() + + try: + while sockets: + time_diff = stop_wait - datetime.now() + seconds_left = time_diff.total_seconds() + if seconds_left <= 0: + break + + ready = select.select(sockets, [], [], seconds_left)[0] + + for sock in ready: + try: + data, address = sock.recvfrom(1024) + response = data.decode("utf-8") + except UnicodeDecodeError: + logging.getLogger(__name__).debug( + 'Ignoring invalid unicode response from %s', address) + continue + except socket.error: + logging.getLogger(__name__).exception( + "Socket error while discovering SSDP devices") + sockets.remove(sock) + sock.close() + continue + + entry = UPNPEntry.from_response(response) + entries[(entry.st, entry.location)] = entry + + finally: + for s in sockets: + s.close() + + return sorted(entries.values(), key=lambda entry: entry.location or '') + + +def main(): + """Test SSDP discovery.""" + from pprint import pprint + + print("Scanning SSDP..") + pprint(scan()) + + +if __name__ == "__main__": + main() |