Yume
overload.cpp
Go to the documentation of this file.
1#include "overload.hpp"
2#include "ast/ast.hpp"
4#include "ty/type.hpp"
5#include "util.hpp"
6#include <algorithm>
7#include <compare>
8#include <llvm/ADT/STLExtras.h>
9#include <llvm/Support/Casting.h>
10#include <llvm/Support/raw_ostream.h>
11#include <map>
12#include <stdexcept>
13#include <string>
14#include <utility>
15
16namespace yume::semantic {
17
18inline static constexpr auto get_val_ty = [](const ast::AST* ast) { return ast->val_ty(); };
19inline static constexpr auto indirect = [](const ty::Type& ty) { return &ty; };
20
21static auto join_args(const auto& iter, auto fn, llvm::raw_ostream& stream = errs()) {
22 for (auto& i : llvm::enumerate(iter)) {
23 if (i.index() != 0)
24 stream << ", ";
25 stream << fn(i.value())->name();
26 }
27}
28
29static auto overload_name(const ast::AST* ast) -> std::string {
30 if (const auto* call = dyn_cast<ast::CallExpr>(ast))
31 return (call->receiver.has_value() ? (call->receiver->ensure_ty().name() + ".") : ""s) + call->name;
32 if (const auto* ctor = dyn_cast<ast::CtorExpr>(ast))
33 return ctor->ensure_ty().name() + ":new";
34
35 llvm_unreachable("Cannot evaluate overload set against non-call, non-ctor");
36}
37
38static auto overload_receiver(const ast::AST* ast) -> optional<ty::Type> {
39 if (const auto* call = dyn_cast<ast::CallExpr>(ast))
40 return call->receiver.has_value() ? call->receiver->ensure_ty() : optional<ty::Type>{};
41 if (const auto* ctor = dyn_cast<ast::CtorExpr>(ast))
42 return ctor->ensure_ty();
43
44 llvm_unreachable("Cannot evaluate overload set against non-call, non-ctor");
45}
46
47void Overload::dump(llvm::raw_ostream& stream) const {
48 stream << fn->ast().location().to_string() << "\t";
49 if (fn->self_ty.has_value())
50 stream << fn->self_ty->name() << ".";
51 stream << fn->name() << "(";
52 join_args(fn->arg_types(), indirect, stream);
53 stream << ")";
54 if (!subs.empty()) {
55 stream << " with ";
56 int i = 0;
57 for (const auto& [k, v] : subs.mapping()) {
58 if (i++ > 0)
59 stream << ", ";
60
61 stream << k->name << " = " << v->name();
62 }
63 }
64}
65
66void OverloadSet::dump(llvm::raw_ostream& stream, bool hide_invalid) const {
67 stream << overload_name(call) << "(";
68 join_args(args, get_val_ty, stream);
69 stream << ")\n";
70 for (const auto& i_s : overloads) {
71 if (hide_invalid && !i_s.viable)
72 continue;
73 stream << " ";
74 i_s.dump(stream);
75 stream << "\n";
76 }
77}
78
79static auto literal_cast(ast::AST& arg, ty::Type target_type) -> ty::Compat {
80 if (arg.val_ty() == target_type)
81 return {.valid = true}; // Already the correct type
82
83 if (isa<ast::NumberExpr>(arg) && target_type.base_isa<ty::Int>()) {
84 auto& num_arg = cast<ast::NumberExpr>(arg);
85 const auto* int_type = target_type.base_cast<ty::Int>();
86
87 if (int_type->size() == 1)
88 return {}; // Can't implicitly cast to Bool
89
90 auto in_range = int_type->in_range(num_arg.val);
91
92 if (in_range)
93 return {.valid = true, .conv = {.dereference = false, .kind = ty::Conv::Int}};
94 }
95
96 return {};
97}
98
99auto parameter_count_matches(const vector<ast::AST*>& args, const Fn& fn) -> bool {
100 if (args.size() == fn.arg_count())
101 return true;
102
103 // Varargs functions may have more arguments than the amount of non-vararg parameters.
104 if (args.size() > fn.arg_count() && fn.varargs())
105 return true;
106
107 return false;
108}
109
110static auto generic_base(optional<ty::Type> type) -> optional<ty::Type> {
111 if (type.has_value())
112 return type->generic_base();
113 return type;
114}
115
// Body of `is_valid_overload(Overload& overload) -> bool`; its signature line is not part of this view.
// Decides whether `overload` can be called with the arguments `args` of `call`. At the first failing check it emits
// a reason to `notes` and returns false. On success, `overload.subs` holds the generic substitutions and
// `overload.compatibilities` holds one conversion entry per argument (dummy entries for variadic extras).
// Note that `overload.subs` may already have been modified when false is returned.
117 const auto& fn = *overload.fn;
118 auto parent = generic_base(fn.self_ty);
119
120 auto receiver = overload_receiver(call);
121
122 // Check if the call has a receiver matching the type of the struct this method is in.
123 // If the call has a receiver, it will always fail to match against a top level function.
124 // Note that a receiver is always a type, such as `Foo.method`. Calls with an "object" as a receiver look similar,
125 // but `foo.method` is always rewritten to `method(foo)` and thus uses the "argument dependent lookup" rules below.
126 if (generic_base(receiver) != parent) {
127 if (!parent.has_value()) {
128 notes->emit(overload.location()) << "Overload not considered due to ADL";
129 notes->emit(overload.location()) << " Because no receiver was specified";
130 return false;
131 }
132
133 // If there is no matching receiver, check if any arguments are of the type of the struct.
134 // This performs "argument dependent lookup" and is required for "member functions"
// Arguments are compared by generic base after stripping `mut`/opaque wrappers, presumably so that any
// instantiation of a generic struct counts as that struct.
135 if (std::ranges::none_of(args, [parent](ast::AST* ast) {
136 return ast->ensure_ty().without_mut().without_opaque().generic_base() == parent->generic_base();
137 })) {
138 notes->emit(overload.location()) << "Overload not considered due to ADL";
139 for (const auto* ast : args) {
140 notes->emit(overload.location()) << " Because `" << ast->ensure_ty().without_mut().without_opaque().name()
141 << "' is not `" << parent->name() << "'";
142 }
143 return false;
144 }
145 }
146
147 // The overload is only viable if the amount of arguments matches the amount of parameters.
148 if (!parameter_count_matches(args, fn)) {
149 notes->emit(overload.location()) << "Overload not considered due to mismatch in parameter count";
150 return false;
151 }
152
// When the receiver is a struct type, start from the substitutions that struct type already carries.
153 if (receiver.has_value() && receiver->base_isa<ty::Struct>())
154 overload.subs = *receiver->base_cast<ty::Struct>()->subs();
155
156 overload.compatibilities.reserve(args.size());
157
158 // Look at generic parameters first in order to deduce generic parameters. This is done first, since later arguments
159 // could change the deduced type via intersection, which would lead to mismatching types, if deduction was done at
160 // the same time as substitution.
161 // As `llvm::zip_first` stops at the end of its first range (the parameter types), we don't try to deduce anything about the
162 // "variadic" part of varargs functions, since variadics don't yet carry any type information. This will change in the
163 // future.
164 for (const auto& [param_type_r, arg] : llvm::zip_first(fn.arg_types(), args)) {
165 auto arg_type = arg->ensure_ty();
166 ty::Type param_type = param_type_r;
167
168 if (param_type.is_generic()) {
169 auto new_subs = arg_type.determine_generic_subs(param_type, overload.subs);
170
171 if (!new_subs) {
172 notes->emit(overload.location()) << "Overload not valid";
173 notes->emit(overload.location()) << " Because the generic variables of `" << param_type_r.name()
174 << "' weren't able to be determined"; // TODO(rymiel): why?
175 return false;
176 }
177 overload.subs = *new_subs;
178 }
179 }
180
181 if (!overload.subs.fully_substituted()) {
182 notes->emit(overload.location()) << "Overload not valid because not all generic parameters could be deduced";
183 overload.subs.dump(*notes->buffer_stream);
184 // TODO(rymiel): Try to find which ones failed
185 return false;
186 }
187
188 // Determine the type compatibility of each argument individually. The performed conversions are also recorded for
189 // each step.
190 // As `llvm::zip_first` stops at the end of its first range (the parameter types), we don't try to determine type
191 // compatibility of the "variadic" part of varargs functions. Currently, varargs methods can only be primitives and
192 // carry no type information for their variadic part. This will change in the future.
193 for (const auto& [param_type_r, arg] : llvm::zip_first(fn.arg_types(), args)) {
194 auto arg_type = arg->ensure_ty();
195 ty::Type param_type = param_type_r;
196
197 if (param_type.is_generic()) {
198 param_type = param_type.apply_generic_substitution(overload.subs);
199
200 YUME_ASSERT(!param_type.is_generic(), "Generic substitution must produce a fully-substituted type, but `"s +
201 param_type.name() + "' is not fully substituted");
202
203 // if (!param_type) {
204 // notes->emit(overload.location()) << "Overload not valid";
205 // auto subs_dumped = std::string{};
206 // auto os = llvm::raw_string_ostream{subs_dumped};
207 // overload.subs.dump(os); // TODO(rymiel): nasty
208 // notes->emit(overload.location()) << " Because the generic type `" << param_type_r.name()
209 // << "' wasn't able to be substituted with " << subs_dumped;
210 // return false;
211 // }
212 }
213
214 // Attempt to do a literal cast
215 auto compat = literal_cast(*arg, param_type);
216 // Couldn't perform a literal cast, try regular casts
217 if (!compat.valid)
218 compat = arg_type.compatibility(param_type);
219
220 // Couldn't perform any kind of valid cast: one invalid conversion disqualifies the function entirely
221 if (!compat.valid) {
222 // TODO(rymiel): #17 Actually keep track of *why* a type is not viable, for diagnostics.
223 notes->emit(overload.location()) << "Overload not valid";
224 notes->emit(overload.location()) << " Because `" << arg_type.name() << "' is not convertible to `"
225 << param_type.name() << "'";
226 return false;
227 }
228
229 // Save the steps needed to perform the conversion
230 overload.compatibilities.push_back(compat);
231 }
232
233 // Add dummy conversions for each argument which maps to a variadic
234 while (overload.compatibilities.size() < args.size())
235 overload.compatibilities.emplace_back();
236
237 // Must be valid!
238 return true;
239}
240
// Body of the OverloadSet member that marks candidate viability; its signature line is not part of this view.
// Runs is_valid_overload on every candidate and stores the verdict in the candidate's `viable` flag. As a side
// effect, is_valid_overload may fill in the candidate's substitutions and conversions and emit rejection notes.
242 // All `Overload`s are determined to not be viable by default, so determine the ones which actually are
243 for (auto& i : overloads)
244 i.viable = is_valid_overload(i);
245}
246
247static auto cmp(bool a, bool b) -> std::strong_ordering { return static_cast<int>(a) <=> static_cast<int>(b); }
248
249static auto compare_implicit_conversions(ty::Conv a, ty::Conv b) -> std::weak_ordering {
250 const auto& equal = std::strong_ordering::equal;
251
252 // No conversion is better than some conversion
253 if (auto c = cmp(a.kind == ty::Conv::None, b.kind == ty::Conv::None); c != equal)
254 return c;
255
256 // No dereference is better than performing a dereference
257 if (auto c = cmp(!a.dereference, !b.dereference); c != equal)
258 return c;
259
260 // Cannot distinguish!
261 return equal;
262}
263
264auto Overload::better_candidate_than(Overload other) const -> bool {
265 // Viable candidates are always better than non-viable ones
266 if (!other.viable)
267 return viable;
268 if (!viable)
269 return false;
270
271 // For each argument, determine which candidate has a "better" conversion.
272 for (const auto& [self_compat, other_compat] : llvm::zip_first(compatibilities, other.compatibilities)) {
273 auto comparison = compare_implicit_conversions(self_compat.conv, other_compat.conv);
274
275 if (is_gt(comparison))
276 return true;
277 if (is_lt(comparison))
278 return false;
279
280 // Cannot distinguish between these, try the next arguments
281 }
282
283 // If we got to here, it means all arguments were identical. Neither overload is better than the other
284 return false;
285}
286
// Body of `try_best_viable_overload() const -> const Overload*`; its signature line is not part of this view.
// Returns a pointer into `overloads` to the best viable candidate, or nullptr when no candidate is viable.
// A candidate only replaces the current best when it is strictly better, so among equally good candidates the
// earliest one in `overloads` is returned.
288 const Overload* best = nullptr;
289
290 for (const auto& candidate : overloads) {
291 if (candidate.viable) {
// NOTE(review): better_candidate_than takes its argument by value, so each comparison copies `*best`.
292 if (best == nullptr || candidate.better_candidate_than(*best))
293 best = &candidate;
294 }
295 }
296
297 return best;
298}
299
// Body of `best_viable_overload() const -> Overload`; its signature line is not part of this view.
// Returns a copy of the single best viable candidate.
// Throws std::logic_error in two cases:
//  - no candidate is viable: the message lists every candidate plus the accumulated notes;
//  - the best candidate is not strictly better than some other viable candidate: the message lists those
//    ambiguous candidates.
301 const auto* best = try_best_viable_overload();
302
303 if (best == nullptr) {
304 string str{};
305 llvm::raw_string_ostream ss{str};
306 ss << "No viable overload for " << overload_name(call) << " with argument types ";
308 ss << "\nNone of the following overloads were suitable:\n";
309 for (const auto& i : overloads) {
310 i.dump(ss);
311 ss << "\n";
312 }
313 ss << "\n";
314 ss << notes->buffer;
// NOTE(review): `str` is read without flushing `ss`. This relies on raw_string_ostream writing straight through
// to its string (unbuffered), which recent LLVM versions do. Confirm for the LLVM version in use.
315 throw std::logic_error(str);
316 }
317
// Collect every other viable candidate that `best` does not strictly beat; any such candidate makes the call
// ambiguous.
318 vector<const Overload*> ambiguous;
319
320 for (const auto& candidate : overloads) {
321 if (candidate.viable && &candidate != best) {
322 if (!best->better_candidate_than(candidate))
323 ambiguous.push_back(&candidate);
324 }
325 }
326
327 if (ambiguous.empty())
328 return *best;
329
330 ambiguous.push_back(best);
331
332 string str{};
333 llvm::raw_string_ostream ss{str};
334 ss << "Ambiguous call for " << overload_name(call) << " with argument types ";
336 ss << "\nCouldn't pick between the following overloads:\n";
337 for (const auto* i : ambiguous) {
338 i->dump(ss);
339 ss << "\n";
340 }
341
342 throw std::logic_error(str);
343}
344
345} // namespace yume::semantic
All nodes in the program's AST inherit from this class.
Definition: ast.hpp:224
auto ensure_ty() const -> ty::Type
Definition: ast.hpp:254
A built-in integral type, such as I32 or Bool.
Definition: type.hpp:35
auto in_range(int64_t num) const -> bool
Definition: type.cpp:536
A user-defined struct type with associated fields.
Definition: type.hpp:78
A "qualified" type, with a non-stackable qualifier, i.e. mut.
Definition: type_base.hpp:66
auto apply_generic_substitution(const Substitutions &sub) const -> Type
Definition: type.cpp:279
auto determine_generic_subs(Type generic, const Substitutions &subs) const -> optional< Substitutions >
Definition: type.cpp:96
auto is_generic() const noexcept -> bool
Definition: type.cpp:221
auto name() const -> string
Definition: type.cpp:450
static auto overload_name(const ast::AST *ast) -> std::string
Definition: overload.cpp:29
static auto compare_implicit_conversions(ty::Conv a, ty::Conv b) -> std::weak_ordering
Definition: overload.cpp:249
static auto join_args(const auto &iter, auto fn, llvm::raw_ostream &stream=errs())
Definition: overload.cpp:21
static constexpr auto get_val_ty
Definition: overload.cpp:18
static auto cmp(bool a, bool b) -> std::strong_ordering
Definition: overload.cpp:247
static auto literal_cast(ast::AST &arg, ty::Type target_type) -> ty::Compat
Definition: overload.cpp:79
static constexpr auto indirect
Definition: overload.cpp:19
static auto overload_receiver(const ast::AST *ast) -> optional< ty::Type >
Definition: overload.cpp:38
static auto generic_base(optional< ty::Type > type) -> optional< ty::Type >
Definition: overload.cpp:110
auto parameter_count_matches(const vector< ast::AST * > &args, const Fn &fn) -> bool
Definition: overload.cpp:99
A function declaration in the compiler.
Definition: vals.hpp:52
optional< ty::Type > self_ty
If this function is in the body of a struct, this points to its type. Used for the self type.
Definition: vals.hpp:57
auto arg_types() const -> vector< ty::Type >
Definition: vals.cpp:91
auto ast() const -> const ast::Stmt &
Definition: vals.cpp:74
auto name() const noexcept -> string
Definition: vals.cpp:64
auto empty() const -> bool
unique_ptr< diagnostic::StringNotesHolder > notes
Definition: overload.hpp:41
auto is_valid_overload(Overload &overload) -> bool
Definition: overload.cpp:116
vector< Overload > overloads
Definition: overload.hpp:39
auto try_best_viable_overload() const -> const Overload *
Definition: overload.cpp:287
vector< ast::AST * > args
Definition: overload.hpp:40
void dump(llvm::raw_ostream &stream, bool hide_invalid=false) const
Definition: overload.cpp:66
auto best_viable_overload() const -> Overload
Definition: overload.cpp:300
Substitutions subs
Definition: overload.hpp:25
void dump(llvm::raw_ostream &stream) const
Definition: overload.cpp:47
auto better_candidate_than(Overload other) const -> bool
Definition: overload.cpp:264
The compatibility between two types, for overload selection.
#define YUME_ASSERT(assertion, message)
Definition: util.hpp:81