Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Coroutine closures implement regular Fn traits, when possible
  • Loading branch information
compiler-errors committed Feb 6, 2024
commit b8c93f1223695217cbabc1f3f1e428c358bb4e7a
17 changes: 12 additions & 5 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,18 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// It's always helpful for inference if we know the kind of
// closure sooner rather than later, so first examine the expected
// type, and see if can glean a closure kind from there.
let (expected_sig, expected_kind) = match expected.to_option(self) {
Some(ty) => {
self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty))
}
None => (None, None),
let (expected_sig, expected_kind) = match closure.kind {
hir::ClosureKind::Closure => match expected.to_option(self) {
Some(ty) => {
self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty))
}
None => (None, None),
},
// We don't want to deduce a signature from `Fn` bounds for coroutines
// or coroutine-closures, because the former does not implement `Fn`
// ever, and the latter's signature doesn't correspond to the coroutine
// type that it returns.
hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => (None, None),
};

let ClosureSignatures { bound_sig, mut liberated_sig } =
Expand Down
74 changes: 70 additions & 4 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2074,7 +2074,9 @@ fn confirm_select_candidate<'cx, 'tcx>(
} else if lang_items.async_iterator_trait() == Some(trait_def_id) {
confirm_async_iterator_candidate(selcx, obligation, data)
} else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() {
if obligation.predicate.self_ty().is_closure() {
if obligation.predicate.self_ty().is_closure()
|| obligation.predicate.self_ty().is_coroutine_closure()
{
confirm_closure_candidate(selcx, obligation, data)
} else {
confirm_fn_pointer_candidate(selcx, obligation, data)
Expand Down Expand Up @@ -2386,11 +2388,75 @@ fn confirm_closure_candidate<'cx, 'tcx>(
obligation: &ProjectionTyObligation<'tcx>,
nested: Vec<PredicateObligation<'tcx>>,
) -> Progress<'tcx> {
let tcx = selcx.tcx();
let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
let ty::Closure(_, args) = self_ty.kind() else {
unreachable!("expected closure self type for closure candidate, found {self_ty}")
let closure_sig = match *self_ty.kind() {
ty::Closure(_, args) => args.as_closure().sig(),

// Construct a "normal" `FnOnce` signature for coroutine-closure. This is
// basically duplicated with the `AsyncFnOnce::CallOnce` confirmation, but
// I didn't see a good way to unify those.
ty::CoroutineClosure(def_id, args) => {
let args = args.as_coroutine_closure();
let kind_ty = args.kind_ty();
args.coroutine_closure_sig().map_bound(|sig| {
// If we know the kind and upvars, use that directly.
// Otherwise, defer to `AsyncFnKindHelper::Upvars` to delay
// the projection, like the `AsyncFn*` traits do.
let output_ty = if let Some(_) = kind_ty.to_opt_closure_kind() {
sig.to_coroutine_given_kind_and_upvars(
tcx,
args.parent_args(),
tcx.coroutine_for_closure(def_id),
ty::ClosureKind::FnOnce,
tcx.lifetimes.re_static,
args.tupled_upvars_ty(),
args.coroutine_captures_by_ref_ty(),
)
} else {
let async_fn_kind_trait_def_id =
tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
let upvars_projection_def_id = tcx
.associated_items(async_fn_kind_trait_def_id)
.filter_by_name_unhygienic(sym::Upvars)
.next()
.unwrap()
.def_id;
let tupled_upvars_ty = Ty::new_projection(
tcx,
upvars_projection_def_id,
[
ty::GenericArg::from(kind_ty),
Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce).into(),
tcx.lifetimes.re_static.into(),
sig.tupled_inputs_ty.into(),
args.tupled_upvars_ty().into(),
args.coroutine_captures_by_ref_ty().into(),
],
);
sig.to_coroutine(
tcx,
args.parent_args(),
Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
tcx.coroutine_for_closure(def_id),
tupled_upvars_ty,
)
};
tcx.mk_fn_sig(
[sig.tupled_inputs_ty],
output_ty,
sig.c_variadic,
sig.unsafety,
sig.abi,
)
})
}

_ => {
unreachable!("expected closure self type for closure candidate, found {self_ty}");
}
};
let closure_sig = args.as_closure().sig();

let Normalized { value: closure_sig, obligations } = normalize_with_depth(
selcx,
obligation.param_env,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,31 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
}
}
ty::CoroutineClosure(def_id, args) => {
let is_const = self.tcx().is_const_fn_raw(def_id);
match self.infcx.closure_kind(self_ty) {
Some(closure_kind) => {
let no_borrows = self
.infcx
.shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty())
.tuple_fields()
.is_empty();
if no_borrows && closure_kind.extends(kind) {
candidates.vec.push(ClosureCandidate { is_const });
} else if kind == ty::ClosureKind::FnOnce {
candidates.vec.push(ClosureCandidate { is_const });
}
}
None => {
if kind == ty::ClosureKind::FnOnce {
candidates.vec.push(ClosureCandidate { is_const });
} else {
// This stays ambiguous until kind+upvars are determined.
candidates.ambiguous = true;
}
}
}
}
ty::Infer(ty::TyVar(_)) => {
debug!("assemble_unboxed_closure_candidates: ambiguous self-type");
candidates.ambiguous = true;
Expand Down
26 changes: 17 additions & 9 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,17 +865,25 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let ty::Closure(closure_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
let trait_ref = match *self_ty.kind() {
ty::Closure(_, args) => {
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
}
ty::CoroutineClosure(_, args) => {
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
ty::TraitRef::new(
self.tcx(),
obligation.predicate.def_id(),
[self_ty, sig.tupled_inputs_ty],
)
})
}
_ => {
bug!("closure candidate for non-closure {:?}", obligation);
}
};

let trait_ref =
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_);
let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;

debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");

Ok(nested)
self.confirm_poly_trait_refs(obligation, trait_ref)
}

#[instrument(skip(self), level = "debug")]
Expand Down
18 changes: 18 additions & 0 deletions compiler/rustc_ty_utils/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,24 @@ fn resolve_associated_item<'tcx>(
def: ty::InstanceDef::FnPtrShim(trait_item_id, rcvr_args.type_at(0)),
args: rcvr_args,
}),
ty::CoroutineClosure(coroutine_closure_def_id, args) => {
// When a coroutine-closure implements the `Fn` traits, then it
// always dispatches to the `FnOnce` implementation. This is to
// ensure that the `closure_kind` of the resulting closure is in
// sync with the built-in trait implementations (since all of the
// implementations return `FnOnce::Output`).
if ty::ClosureKind::FnOnce == args.as_coroutine_closure().kind() {
Some(Instance::new(coroutine_closure_def_id, args))
} else {
Some(Instance {
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind: ty::ClosureKind::FnOnce,
},
args,
})
}
}
_ => bug!(
"no built-in definition for `{trait_ref}::{}` for non-fn type",
tcx.item_name(trait_item_id)
Expand Down