Clean up tracking of service advertisements and conversion to Service objects.

This commit is contained in:
Dan Williams 2006-06-13 15:25:54 -04:00
parent 3e666c005f
commit 26ee2d57d8
2 changed files with 114 additions and 65 deletions

View File

@ -25,6 +25,39 @@ def _get_local_ip_address(ifname):
return addr return addr
class ServiceAdv(object):
"""Wrapper class for service attributes that Avahi passes back."""
def __init__(self, interface, protocol, name, stype, domain):
self._interface = interface
self._protocol = protocol
self._name = name
self._stype = stype
self._domain = domain
self._service = None
self._resolved = False
def interface(self):
return self._interface
def protocol(self):
return self._protocol
def name(self):
return self._name
def stype(self):
return self._stype
def domain(self):
return self._domain
def service(self):
return self._service
def set_service(self, service):
if not isinstance(service, Service.Service):
raise ValueError("must be a valid service.")
self._service = service
def resolved(self):
return self._resolved
def set_resolved(self, resolved):
self._resolved = resolved
class PresenceService(gobject.GObject): class PresenceService(gobject.GObject):
"""Object providing information about the presence of Buddies """Object providing information about the presence of Buddies
and what activities they make available to others.""" and what activities they make available to others."""
@ -74,11 +107,8 @@ class PresenceService(gobject.GObject):
self._service_type_browsers = {} self._service_type_browsers = {}
self._service_browsers = {} self._service_browsers = {}
# We only resolve services that our clients are interested in; # Resolved service list
# but we store unresolved services so that when a client does self._service_advs = []
# become interested in a new service type, we can quickly
# resolve it
self._unresolved_services = []
self._bus = dbus.SystemBus() self._bus = dbus.SystemBus()
self._server = dbus.Interface(self._bus.get_object(avahi.DBUS_NAME, self._server = dbus.Interface(self._bus.get_object(avahi.DBUS_NAME,
@ -116,22 +146,24 @@ class PresenceService(gobject.GObject):
def _resolve_service_error_handler(self, err): def _resolve_service_error_handler(self, err):
self._log("error resolving service: %s" % err) self._log("error resolving service: %s" % err)
def _find_service(self, slist, name=None, stype=None, domain=None, address=None, port=None): def _find_service_adv(self, interface=None, protocol=None, name=None, stype=None, domain=None):
"""Search a list of services for ones matching certain criteria.""" """Search a list of service advertisements for ones matching certain criteria."""
found = [] adv_list = []
for service in slist: for adv in self._service_advs:
if name and service.get_name() != name: if interface and adv.interface() != interface:
continue continue
if stype and service.get_type() != stype: if protocol and adv.protocol() != protocol:
continue continue
if domain and service.get_domain() != domain: if name and adv.name() != name:
continue continue
if address and service.get_address() != address: if stype and adv.stype() != stype:
continue continue
if port and service.get_port() != port: if domain and adv.domain() != domain:
continue continue
found.append(service) adv_list.append(adv)
return found if not len(adv_list):
return None
return adv_list
def _is_special_service_type(self, stype): def _is_special_service_type(self, stype):
"""Return True if the service type is a special, internal service """Return True if the service type is a special, internal service
@ -146,13 +178,13 @@ class PresenceService(gobject.GObject):
"""Deal with a new discovered service object.""" """Deal with a new discovered service object."""
# Once a service is resolved, we match it up to an existing buddy, # Once a service is resolved, we match it up to an existing buddy,
# or create a new Buddy if this is the first service known about the buddy # or create a new Buddy if this is the first service known about the buddy
added = was_valid = False buddy_was_valid = False
name = service.get_name() name = service.get_name()
buddy = None buddy = None
try: try:
buddy = self._buddies[name] buddy = self._buddies[name]
was_valid = buddy.is_valid() buddy_was_valid = buddy.is_valid()
added = buddy.add_service(service) buddy.add_service(service)
except KeyError: except KeyError:
# Should this service mark the owner? # Should this service mark the owner?
if service.get_address() in self._local_addrs.values(): if service.get_address() in self._local_addrs.values():
@ -161,20 +193,16 @@ class PresenceService(gobject.GObject):
else: else:
buddy = Buddy.Buddy(service) buddy = Buddy.Buddy(service)
self._buddies[name] = buddy self._buddies[name] = buddy
added = True if not buddy_was_valid and buddy.is_valid():
if not was_valid and buddy.is_valid():
self.emit("buddy-appeared", buddy) self.emit("buddy-appeared", buddy)
return buddy return buddy
def _handle_new_service_for_group(self, service, buddy): def _handle_new_service_for_group(self, service, buddy):
# If the serivce is a group service, merge it into our groups list # If the serivce is a group service, merge it into our groups list
if not buddy:
return
group = None group = None
if not self._groups.has_key(service.get_type()): if not self._groups.has_key(service.get_type()):
group = Group.Group(service) group = Group.Group(service)
else: self._groups[service.get_type()] = group
group = self._groups[service.get_type()]
def _resolve_service_reply_cb(self, interface, protocol, name, stype, domain, host, aprotocol, address, port, txt, flags): def _resolve_service_reply_cb(self, interface, protocol, name, stype, domain, host, aprotocol, address, port, txt, flags):
"""When the service discovery finally gets here, we've got enough information about the """When the service discovery finally gets here, we've got enough information about the
@ -183,22 +211,20 @@ class PresenceService(gobject.GObject):
# If this service was previously unresolved, remove it from the # If this service was previously unresolved, remove it from the
# unresolved list # unresolved list
found = self._find_service(self._unresolved_services, name=name, adv_list = self._find_service_adv(interface=interface, protocol=protocol, name=name,
stype=stype, domain=domain) stype=stype, domain=domain)
if not len(found): if not adv_list:
return False return False
for service in found: adv = adv_list[0]
self._unresolved_services.remove(service) adv.set_resolved(True)
# Update the service now that it's been resolved # Update the service now that it's been resolved
service = found[0] service = Service.Service(name, stype, domain, address, port, txt)
service.set_address(address) adv.set_service(service)
service.set_port(port)
service.set_properties(txt)
# Merge the service into our buddy and group lists, if needed # Merge the service into our buddy and group lists, if needed
buddy = self._handle_new_service_for_buddy(service) buddy = self._handle_new_service_for_buddy(service)
if service.is_group_service(): if buddy and service.is_group_service():
self._handle_new_service_for_group(service, buddy) self._handle_new_service_for_group(service, buddy)
return False return False
@ -220,11 +246,12 @@ class PresenceService(gobject.GObject):
self._log("found service '%s' (%d) of type '%s' in domain '%s' on %i.%i." % (name, flags, stype, domain, interface, protocol)) self._log("found service '%s' (%d) of type '%s' in domain '%s' on %i.%i." % (name, flags, stype, domain, interface, protocol))
# Add the service to our unresolved services list # Add the service to our unresolved services list
found = self._find_service(self._unresolved_services, name=name, adv_list = self._find_service_adv(interface=interface, protocol=protocol,
name=name, stype=stype, domain=domain)
if not adv_list:
adv = ServiceAdv(interface=interface, protocol=protocol, name=name,
stype=stype, domain=domain) stype=stype, domain=domain)
if not len(found): self._service_advs.append(adv)
service = Service.Service(name, stype, domain)
self._unresolved_services.append(service)
# Find out the IP address of this interface, if we haven't already # Find out the IP address of this interface, if we haven't already
if interface not in self._local_addrs.keys(): if interface not in self._local_addrs.keys():
@ -245,19 +272,19 @@ class PresenceService(gobject.GObject):
def _service_disappeared_cb(self, interface, protocol, name, stype, domain, flags): def _service_disappeared_cb(self, interface, protocol, name, stype, domain, flags):
self._log("service '%s' of type '%s' in domain '%s' on %i.%i disappeared." % (name, stype, domain, interface, protocol)) self._log("service '%s' of type '%s' in domain '%s' on %i.%i disappeared." % (name, stype, domain, interface, protocol))
# If it's an unresolved service, remove it from our unresolved list # If it's an unresolved service, remove it from our unresolved list
found = self._find_service(self._unresolved_services, name=name, adv_list = self._find_service_adv(interface=interface, protocol=protocol,
stype=stype, domain=domain) name=name, stype=stype, domain=domain)
for service in found: if not adv_list:
self._unresolved_services.remove(service) return False
# Unresolved services by definition aren't assigned to a buddy # Unresolved services by definition aren't assigned to a buddy
if not len(found):
try: try:
# Remove the service from the buddy # Remove the service from the buddy
buddy = self._buddies[name] buddy = self._buddies[name]
# FIXME: need to be more careful about how we remove services # FIXME: need to be more careful about how we remove services
# from buddies; this could be spoofed # from buddies; this could be spoofed
service = buddy.get_service_of_type(stype) adv = adv_list[0]
service = adv.get_service()
buddy.remove_service(service) buddy.remove_service(service)
if not buddy.is_valid(): if not buddy.is_valid():
self.emit("buddy-disappeared", buddy) self.emit("buddy-disappeared", buddy)
@ -334,9 +361,10 @@ class PresenceService(gobject.GObject):
# Find unresolved services that match the service type # Find unresolved services that match the service type
# we're now interested in, and resolve them # we're now interested in, and resolve them
found = self._find_service(self._unresolved_services, stype=stype) adv_list = self._find_service_adv(stype=stype)
for service in found: for adv in adv_list:
gobject.idle_add(self._resolve_service, interface, protocol, name, stype, domain, flags) gobject.idle_add(self._resolve_service, adv.interface(),
adv.protocol(), adv.name(), adv.stype(), adv.domain(), 0)
def untrack_service_type(self, stype): def untrack_service_type(self, stype):
"""Stop tracking a certain mDNS service.""" """Stop tracking a certain mDNS service."""
@ -347,6 +375,14 @@ class PresenceService(gobject.GObject):
if name in self._allowed_service_types: if name in self._allowed_service_types:
self._allowed_service_types.remove(stype) self._allowed_service_types.remove(stype)
def join_group(self, group):
"""Convenience function to join a group and notify other buddies
that you are a member of it."""
if not isinstance(group, Group.Group):
raise ValueError("group was not a valid group.")
gservice = group.get_service()
self.register_service(service)
def register_service(self, service): def register_service(self, service):
"""Register a new service, advertising it to other Buddies on the network.""" """Register a new service, advertising it to other Buddies on the network."""
if not self._started: if not self._started:
@ -355,7 +391,7 @@ class PresenceService(gobject.GObject):
rs_name = service.get_name() rs_name = service.get_name()
rs_stype = service.get_type() rs_stype = service.get_type()
rs_port = service.get_port() rs_port = service.get_port()
if type(rs_port) != type(1) and rs_port <= 1024: if type(rs_port) != type(1) and (rs_port <= 1024 or rs_port > 65536):
raise ValueError("invalid service port.") raise ValueError("invalid service port.")
rs_props = service.get_properties() rs_props = service.get_properties()
rs_domain = service.get_domain() rs_domain = service.get_domain()

View File

@ -18,7 +18,7 @@ def _txt_to_dict(txt):
prop_dict[key] = value prop_dict[key] = value
return prop_dict return prop_dict
def _is_multicast_address(address): def is_multicast_address(address):
"""Simple numerical check for whether an IP4 address """Simple numerical check for whether an IP4 address
is in the range for multicast addresses or not.""" is in the range for multicast addresses or not."""
if not address: if not address:
@ -30,10 +30,13 @@ def _is_multicast_address(address):
return True return True
return False return False
__GROUP_UID_TAG = "GroupUID"
class Service(object): class Service(object):
"""Encapsulates information about a specific ZeroConf/mDNS """Encapsulates information about a specific ZeroConf/mDNS
service as advertised on the network.""" service as advertised on the network."""
def __init__(self, name, stype, domain, address=None, port=-1, properties=None): def __init__(self, name, stype, domain, address=None, port=-1, properties=None, group=None):
# Validate immutable options # Validate immutable options
if not name or (type(name) != type("") and type(name) != type(u"")) or not len(name): if not name or (type(name) != type("") and type(name) != type(u"")) or not len(name):
raise ValueError("must specify a valid service name.") raise ValueError("must specify a valid service name.")
@ -49,9 +52,12 @@ class Service(object):
raise ValueError("must use the 'local' domain (for now).") raise ValueError("must use the 'local' domain (for now).")
# Group services must have multicast addresses # Group services must have multicast addresses
if Group.is_group_service_type(stype) and address and not _is_multicast_address(address): if Group.is_group_service_type(stype) and address and not is_multicast_address(address):
raise ValueError("group service type specified, but address was not multicast.") raise ValueError("group service type specified, but address was not multicast.")
if group and not isinstance(group, Group.Group):
raise ValueError("group was not a valid group object.")
self._name = name self._name = name
self._stype = stype self._stype = stype
self._domain = domain self._domain = domain
@ -61,6 +67,9 @@ class Service(object):
self.set_port(port) self.set_port(port)
self._properties = {} self._properties = {}
self.set_properties(properties) self.set_properties(properties)
self._group = group
if group:
self._properties[__GROUP_UID_TAG] = group.get_uid()
def get_name(self): def get_name(self):
"""Return the service's name, usually that of the """Return the service's name, usually that of the
@ -70,7 +79,7 @@ class Service(object):
def is_multicast_service(self): def is_multicast_service(self):
"""Return True if the service's address is a multicast address, """Return True if the service's address is a multicast address,
False if it is not.""" False if it is not."""
return _is_multicast_address(self._address) return is_multicast_address(self._address)
def is_group_service(self): def is_group_service(self):
"""Return True if the service represents a Group, """Return True if the service represents a Group,
@ -122,7 +131,7 @@ class Service(object):
raise ValueError("must specify a valid address.") raise ValueError("must specify a valid address.")
if not len(address): if not len(address):
raise ValueError("must specify a valid address.") raise ValueError("must specify a valid address.")
if Group.is_group_service_type(self._stype) and not _is_multicast_address(address): if Group.is_group_service_type(self._stype) and not is_multicast_address(address):
raise ValueError("group service type specified, but address was not multicast.") raise ValueError("group service type specified, but address was not multicast.")
self._address = address self._address = address
@ -130,6 +139,10 @@ class Service(object):
"""Return the ZeroConf/mDNS domain the service was found in.""" """Return the ZeroConf/mDNS domain the service was found in."""
return self._domain return self._domain
def get_group(self):
"""Return the group this service is associated with, if any."""
return self._group
################################################################# #################################################################
# Tests # Tests