Julia’s multiple dispatch feature sounds interesting. I’m going to implement it in this post.

Describing the Problem

I’m going to assume you know a bit about subtyping and function overloading, but here’s a quick refresher. Subtyping is a way to relate types. For example, if type B is a subtype of type A, that means B is a kind of A, so you can use a B anywhere you need an A.

Function overloading is a way to define multiple functions with the same name but different parameters. The function that gets called is the one whose parameters match the arguments “best”.

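Let’s say we have a subtyping hierarchy like this. It’s Julia-flavored pseudocode: actual Julia doesn’t let a struct such as Complex have subtypes, and its built-in Real and Complex are both direct subtypes of Number.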

abstract type Number
struct Complex <: Number
struct Real <: Complex

And functions like this:

function foo(x::Number, y::Number)
    println("Number, Number")
end

function foo(x::Complex, y::Complex)
    println("Complex, Complex")
end

function foo(x::Real, y::Real)
    println("Real, Real")
end

We want to select the most specific function for the given arguments. For example, foo(Real(), Real()) should call the last definition, while foo(Complex(), Number()) should call the first one, because we can’t pass a Number where a Complex is needed. The thing is, for foo(Real(), Real()) it would also be totally fine to call the second definition; the code would work just fine. So we need not only a way to find the methods that accept the arguments but also a way to rank them.

Modeling Subtyping

Let’s start by modeling subtyping first. I want something like this:

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

puts real.is?(number) # true
puts real.is?(complex) # true
puts real.is?(real) # true
puts real.is?(string) # false
puts real.is?(any) # true
puts string.is?(any) # true

All types are subtypes of Any, which has no supertype. You’ll notice that a type is a subtype of itself. Why? Because I can use a Real anywhere I need a Real. Any is also a supertype of all types, which includes itself, which I think is quite nice.

We could implement this API like this:

class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    # Every type is a subtype of itself.
    return true if type == self
    # Any has no supertype.
    return false if @supertype.nil?
    # Otherwise, defer to my supertype: is it `type` or a subtype of `type`?
    @supertype.is?(type)
  end

  def ==(type)
    @name == type.name
  end
end

Modeling Function Signatures

A signature is just a list of types.

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  def ==(signature)
    @types == signature.types
  end
end

We need a way to know if it is legal to call a function that has a given signature with a given list of argument types. Let’s look at a simple case:

function f(x::Real) end

We obviously shouldn’t be able to call this function with f(Complex()) because Complex is not a subtype of Real. This generalizes to multiple arguments as well: given f(x::Real, y::Real), we should be able to call f(Real(), Real()) but not f(Complex(), Real()).

The first argument isn’t special. We should be able to call a function with argument types (t1, t2, ..., tn) where the signature of the function is (s1, s2, ..., sn) if each ti is a subtype of the corresponding si.

Here’s the implementation:

class Signature
  # Other stuff...

  def conforms?(signature)
    # This signature (the argument types of a call) conforms to `signature`
    # (the parameter types of a definition) if:
    # 1. both have the same number of types
    return false if signature.types.length != @types.length
    # 2. each of our types is a subtype of the corresponding type in `signature`
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end
end
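To make the direction of conforms? concrete, here is a small self-contained sketch. It repeats trimmed copies of the Type and Signature classes from above so it runs on its own; the variable name call is just my label for the argument types of a call.

```ruby
# Trimmed copies of the post's Type and Signature so this runs on its own.
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  # True if self is `type` or a descendant of `type`.
  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(type)
    @name == type.name
  end
end

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  # Can a call with our types use a definition whose parameters are `signature`?
  def conforms?(signature)
    return false if signature.types.length != @types.length
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

# The argument types of a hypothetical call f(Real(), Complex()).
call = Signature.new([real, complex])

puts call.conforms?(Signature.new([complex, number])) # true
puts call.conforms?(Signature.new([real, real]))      # false: Complex is not a Real
puts call.conforms?(Signature.new([number]))          # false: wrong number of arguments
```

Note that the receiver is always the call and the argument is always the definition; swapping them asks a different question.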

Ranking Conforming Signatures

Now we come to the meat of the problem. How do we select the most specific function for a given list of argument types? We need a way to rank them. Let’s look at the simplest case again.

function f(x::Number) end
function f(x::Complex) end
f(Real())

If I asked you which function should be called, you would say the second one, right? Why? Well, in the subtype hierarchy Real <: Complex <: Number <: Any, the type closest to Real is Complex, so you choose that. This is the main idea behind the ranking algorithm.

We need a way to get the distance from a type to one of its supertypes:

class Type
  # Other stuff...

  def distance(type)
    # Distance is only defined from a type up to one of its supertypes
    # (e.g. a String is not a Number, and a Number is not a Real).
    raise "Not a subtype" unless is?(type)
    # You are 0 distance away from yourself.
    return 0 if self == type
    # We are whatever distance `type` away from our supertype + 1.
    # Example:
    #   real.distance(any) = 1 + complex.distance(any)
    #                      = 1 + 1 + number.distance(any)
    #                      = 1 + 1 + 1 + any.distance(any)
    #                      = 1 + 1 + 1 + 0
    #                      = 3
    1 + @supertype.distance(type)
  end
end
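Here is a quick, self-contained check of distance, again with a trimmed copy of Type so it runs on its own. The String type is only there to show the error case.

```ruby
# Trimmed copy of the post's Type, including distance.
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  # Number of steps up the hierarchy from self to `type`.
  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end

  def ==(type)
    @name == type.name
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

puts real.distance(real)    # 0
puts real.distance(complex) # 1
puts real.distance(any)     # 3
puts string.distance(any)   # 1

# Distance only goes upwards: Complex is not a subtype of Real.
begin
  complex.distance(real)
rescue RuntimeError => e
  puts e.message # Not a subtype
end
```

So this distance is not symmetric like a mathematical metric: real.distance(complex) is 1, but complex.distance(real) raises.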

This can rank single-argument functions. Supporting multiple arguments is straightforward. It’s like Euclidean distance, $\sqrt{(2 - 3)^2 + (4 - 5)^2}$, but easier: find the distance between each pair of corresponding types and sum them up, like $\lvert 2 - 3 \rvert + \lvert 4 - 5 \rvert$.

class Signature
  # Other stuff...

  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
end

This is called Manhattan distance.

If we put this to work:

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

f1 = Signature.new([number, number])
f2 = Signature.new([complex, complex])
f3 = Signature.new([real, real])

call_signature = Signature.new([real, real])
puts [f1, f2, f3].min_by { |f| call_signature.distance(f) } == f3 # true
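To see the numbers behind a min_by like this, here is a self-contained sketch that prints the summed distance to each candidate. It uses a mixed call, (Real, Complex), which is my own example rather than one from above, and the labels in the hash are only for printing.

```ruby
# Trimmed copies of the post's Type and Signature (only what distance needs).
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end

  def ==(type)
    @name == type.name
  end
end

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  # Sum of the per-argument distances.
  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

call = Signature.new([real, complex])
candidates = {
  "f(Number, Number)" => Signature.new([number, number]),
  "f(Complex, Complex)" => Signature.new([complex, complex]),
  "f(Real, Complex)" => Signature.new([real, complex])
}

candidates.each { |label, sig| puts "#{label}: #{call.distance(sig)}" }
# f(Number, Number): 3
# f(Complex, Complex): 1
# f(Real, Complex): 0
```

The first line is 2 + 1 (Real is two steps from Number, Complex is one), the second is 1 + 0, and an exact match is always 0.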

A function isn’t just a signature though, it also has a name.

class Function
  attr_accessor :name, :signature

  def initialize(name, signature)
    @name = name
    @signature = signature
  end

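  def to_s
    # Signature has no to_s of its own, so print the type names, e.g. "f(Real, Real)".
    "#{@name}(#{@signature.types.map(&:name).join(", ")})"
  end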

  def ==(other)
    @name == other.name && @signature == other.signature
  end
end

And we’ll have a table of functions that contains all the definitions and gives the most specific one for a given call.

class FunctionTable
  attr_accessor :functions

  def initialize()
    @functions = []
  end

  def add(function)
    raise "Function already exists" if @functions.include?(function)
    @functions << function
  end

  def find(function)
    # find all the signatures that conform to the given signature.
    candidates = @functions.select { |m| function.signature.conforms?(m.signature) && m.name == function.name }
    # sort them by distance from closest to furthest.
    sorted_by_distance = candidates.sort_by { |m| function.signature.distance(m.signature) }

    # find the closest one.
    # There may be more than one with the same distance, so we find all of them.
    distances = sorted_by_distance.map { |m| function.signature.distance(m.signature) }
    min_distance = distances.min
    closest_functions = sorted_by_distance.select { |m| function.signature.distance(m.signature) == min_distance }
    raise "Ambiguous function call between #{closest_functions}" if closest_functions.length > 1
    closest_functions.first
  end
end

The most interesting method is find.

  1. It finds all the functions with the same name and conforming signature.
  2. Sorts them by distance.
  3. Finds the closest one.
  4. If there is more than one with the same distance, it raises an error. You could also return the second closest method instead of raising an error. (I wonder if there is a metric where two different points can’t have the same distance from a third point.)
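The tie in step 4 is easy to trigger. The sketch below is self-contained, so it repeats compact copies of the classes; its find follows the same selection rule as the one above but skips the sort, and the two definitions f(Real, Number) and f(Number, Real) are my own example, not from the post. It also shows what happens when nothing conforms.

```ruby
# Compact copies of the post's classes so this runs on its own.
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end

  def ==(type)
    @name == type.name
  end
end

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  def conforms?(signature)
    return false if signature.types.length != @types.length
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end

  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
end

Function = Struct.new(:name, :signature)

class FunctionTable
  def initialize
    @functions = []
  end

  def add(function)
    @functions << function
  end

  # Same selection rule as the post's find, without the intermediate sort.
  def find(function)
    call = function.signature
    candidates = @functions.select { |m| m.name == function.name && call.conforms?(m.signature) }
    min_distance = candidates.map { |m| call.distance(m.signature) }.min
    closest = candidates.select { |m| call.distance(m.signature) == min_distance }
    raise "Ambiguous function call" if closest.length > 1
    closest.first
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

table = FunctionTable.new
table.add(Function.new("f", Signature.new([real, number])))
table.add(Function.new("f", Signature.new([number, real])))

# Both definitions accept (Real, Real), and both are 0 + 2 = 2 away.
begin
  table.find(Function.new("f", Signature.new([real, real])))
rescue RuntimeError => e
  puts e.message # Ambiguous function call
end

# Neither definition accepts (Complex, Complex), so there is no candidate.
p table.find(Function.new("f", Signature.new([complex, complex]))) # nil
```

Note the second case: with no conforming definition, find quietly returns nil rather than raising, so a caller has to check for that.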

If we put it all together:

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

f1 = Function.new("f", Signature.new([number, number]))
f2 = Function.new("f", Signature.new([complex, complex]))
f3 = Function.new("f", Signature.new([real, real]))

table = FunctionTable.new
table.add(f1)
table.add(f2)
table.add(f3)

call_signature = Signature.new([real, real])
puts table.find(Function.new("f", call_signature)) == f3 # true

Full Code

Further Reading