Skip to content

Commit 969e867

Browse files
committed
feat grpc: support call interruption in middlewares
commit_hash:31fbc1033e5635a40386d94c94fa9e305c7a8696
1 parent b02f596 commit 969e867

File tree

16 files changed

+117
-83
lines changed

16 files changed

+117
-83
lines changed

grpc/include/userver/ugrpc/server/impl/call_processor.hpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
#include <google/protobuf/message.h>
1111
#include <grpcpp/server_context.h>
1212

13+
#include <userver/logging/log.hpp>
1314
#include <userver/server/handlers/exceptions.hpp>
1415
#include <userver/tracing/in_place_span.hpp>
16+
#include <userver/utils/fast_scope_guard.hpp>
1517
#include <userver/utils/impl/internal_tag.hpp>
1618

1719
#include <userver/ugrpc/server/exceptions.hpp>
@@ -125,6 +127,9 @@ class CallProcessor final {
125127
RunOnCallStart();
126128

127129
bool finished = false;
130+
const utils::FastScopeGuard post_finish_hooks_guard([this, &finished]() noexcept {
131+
RunOnCallFinish(finished ? std::make_optional(std::move(status_)) : std::nullopt);
132+
});
128133

129134
// Don't keep the config snapshot for too long, especially for streaming RPCs.
130135
state_.config_snapshot.reset();
@@ -138,7 +143,7 @@ class CallProcessor final {
138143
}
139144

140145
if (!engine::current_task::ShouldCancel() && !responder_.IsInterrupted()) {
141-
RunOnCallFinish(response);
146+
RunPreFinishHooks(response);
142147
finished = impl::Finish(responder_, response, status_);
143148
}
144149

@@ -183,7 +188,7 @@ class CallProcessor final {
183188
}
184189
}
185190

186-
void RunOnCallFinish(std::optional<Response>& response) {
191+
void RunPreFinishHooks(std::optional<Response>& response) {
187192
const auto& mids = state_.middlewares;
188193
const auto rbegin = mids.rbegin() + (mids.size() - success_pre_hooks_count_);
189194
for (auto it = rbegin; it != mids.rend(); ++it) {
@@ -197,8 +202,20 @@ class CallProcessor final {
197202
}
198203
}
199204

200-
// We must call all OnRpcFinish despite the failures. So, don't check the status.
201-
RunWithCatch([this, &middleware] { middleware->OnCallFinish(middleware_call_context_, status_); });
205+
RunWithCatch([this, &middleware] { middleware->PreSendStatus(middleware_call_context_, status_); });
206+
}
207+
}
208+
209+
void RunOnCallFinish(const std::optional<grpc::Status>& status) {
210+
const auto& mids = state_.middlewares;
211+
const auto rbegin = mids.rbegin() + (mids.size() - success_pre_hooks_count_);
212+
for (auto it = rbegin; it != mids.rend(); ++it) {
213+
const auto& middleware = *it;
214+
try {
215+
middleware->OnCallFinish(middleware_call_context_, status);
216+
} catch (const std::exception& ex) {
217+
LOG_WARNING() << "Error in OnCallFinish: " << ex;
218+
}
202219
}
203220
}
204221

grpc/include/userver/ugrpc/server/middlewares/base.hpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/// @file userver/ugrpc/server/middlewares/base.hpp
44
/// @brief @copybrief ugrpc::server::MiddlewareBase
55

6+
#include <optional>
67
#include <string>
78

89
#include <google/protobuf/message.h>
@@ -118,14 +119,23 @@ class MiddlewareBase {
118119
/// * stream: once per response message, that is, 0, 1, more times per Call (RPC).
119120
virtual void PreSendMessage(MiddlewareCallContext& context, google::protobuf::Message& response) const;
120121

121-
/// @brief This hook is invoked once per Call (RPC), after the handler function returns, but before the message is
122-
/// sent to the upstream client.
122+
/// @brief The function is invoked before sending the final status of the call.
123123
///
124-
/// All OnCallStart invoked in the reverse order relatively OnCallFinish. You can change grpc status and it will
125-
/// apply for a rpc call.
124+
/// PreSendStatus is called exactly once per Call (RPC), right before sending
125+
/// the final gRPC status to the client. This allows middlewares to inspect
126+
/// and potentially modify the status that will be sent to the client.
127+
virtual void PreSendStatus(MiddlewareCallContext& context, grpc::Status& status) const;
128+
129+
/// @brief This hook is invoked once per Call (RPC), after the handler function
130+
/// has finished execution and the final status is determined.
126131
///
127-
/// @warning If Handler (grpc method) returns !ok status, OnCallFinish won't be called.
128-
virtual void OnCallFinish(MiddlewareCallContext& context, const grpc::Status& status) const;
132+
/// OnCallFinish is called exactly once per Call (RPC), regardless of whether
133+
/// the call succeeded or failed. It's the final middleware hook in the call chain.
134+
/// This is useful for cleanup operations, logging, or metrics collection that should
135+
/// happen after the RPC is completely processed.
136+
/// @param context The middleware call context containing call information
137+
/// @param status The final status of the call, if available
138+
virtual void OnCallFinish(MiddlewareCallContext& context, const std::optional<grpc::Status>& status) const;
129139
};
130140

131141
/// @ingroup userver_base_classes

grpc/include/userver/ugrpc/server/middlewares/deadline_propagation/middleware.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Middleware final : public MiddlewareBase {
2525

2626
void OnCallStart(MiddlewareCallContext& context) const override;
2727

28-
void OnCallFinish(MiddlewareCallContext& context, const grpc::Status& status) const override;
28+
void PreSendStatus(MiddlewareCallContext& context, grpc::Status& status) const override;
2929
};
3030

3131
} // namespace ugrpc::server::middlewares::deadline_propagation

grpc/src/ugrpc/server/impl/call_processor.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
#include <userver/ugrpc/server/impl/call_processor.hpp>
22

3-
#include <chrono>
4-
5-
#include <boost/container/flat_map.hpp>
6-
7-
#include <userver/logging/impl/logger_base.hpp>
8-
#include <userver/logging/log.hpp>
93
#include <userver/tracing/opentelemetry.hpp>
104
#include <userver/tracing/tags.hpp>
115
#include <userver/utils/algo.hpp>
@@ -134,11 +128,6 @@ void ReportFinished(const grpc::Status& status, CallState& state) noexcept {
134128

135129
void ReportInterrupted(CallState& state) noexcept {
136130
try {
137-
// RPC interruption leads to asynchronous task cancellation by RpcFinishedEvent,
138-
// so the task either is already cancelled, or is going to be cancelled.
139-
LOG_WARNING()
140-
<< "RPC interrupted in '" << state.call_name
141-
<< "'. The previously logged cancellation or network exception, if any, is likely caused by it.";
142131
state.statistics_scope.OnNetworkError();
143132
auto& span = state.GetSpan();
144133
span.AddNonInheritableTag(tracing::kErrorFlag, true);

grpc/src/ugrpc/server/impl/format_log_message.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66

77
#include <userver/logging/impl/logger_base.hpp>
88
#include <userver/logging/impl/timestamp.hpp>
9-
#include <userver/ugrpc/impl/to_string.hpp>
109
#include <userver/utils/datetime.hpp>
1110
#include <userver/utils/encoding/tskv.hpp>
1211
#include <userver/utils/text_light.hpp>
1312

13+
#include <userver/ugrpc/impl/to_string.hpp>
14+
#include <userver/ugrpc/status_codes.hpp>
15+
1416
USERVER_NAMESPACE_BEGIN
1517

1618
namespace ugrpc::server::impl {
@@ -53,7 +55,7 @@ logging::impl::LogExtraTskvFormatter FormatLogMessage(
5355
std::string_view peer,
5456
std::chrono::system_clock::time_point start_time,
5557
std::string_view call_name,
56-
grpc::StatusCode code,
58+
std::optional<grpc::StatusCode> code,
5759
const logging::LogExtra* log_extra
5860
) {
5961
static const auto kTimezone = utils::datetime::LocalTimezoneTimestring(start_time, "%z");
@@ -96,8 +98,8 @@ logging::impl::LogExtraTskvFormatter FormatLogMessage(
9698
// TODO remove, this is for safe migration from old access log parsers.
9799
request_time_seconds.count(),
98100
request_time_milliseconds.count(),
99-
static_cast<int>(code),
100-
ToString(code)
101+
code.has_value() ? std::to_string(static_cast<int>(*code)) : "-",
102+
code.has_value() ? ToString(*code) : "-"
101103
);
102104

103105
if (log_extra) {

grpc/src/ugrpc/server/impl/format_log_message.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#pragma once
22

3-
#include <google/protobuf/message.h>
3+
#include <chrono>
4+
#include <map>
5+
#include <optional>
6+
#include <string_view>
7+
8+
#include <grpcpp/support/status_code_enum.h>
49
#include <grpcpp/support/string_ref.h>
510

6-
#include <userver/dynamic_config/snapshot.hpp>
711
#include <userver/logging/impl/log_extra_tskv_formatter.hpp>
8-
#include <userver/logging/impl/logger_base.hpp>
912
#include <userver/logging/log_extra.hpp>
10-
#include <userver/ugrpc/status_codes.hpp>
1113

1214
USERVER_NAMESPACE_BEGIN
1315

@@ -18,7 +20,7 @@ logging::impl::LogExtraTskvFormatter FormatLogMessage(
1820
std::string_view peer,
1921
std::chrono::system_clock::time_point start_time,
2022
std::string_view call_name,
21-
grpc::StatusCode code,
23+
std::optional<grpc::StatusCode> code,
2224
const logging::LogExtra* log_extra
2325
);
2426

grpc/src/ugrpc/server/middlewares/access_log/middleware.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace {
1313

1414
void WriteAccessLog(
1515
MiddlewareCallContext& context,
16-
const grpc::Status& status,
16+
const std::optional<grpc::Status>& status,
1717
logging::TextLoggerRef access_tskv_logger
1818
) noexcept {
1919
try {
@@ -30,7 +30,7 @@ void WriteAccessLog(
3030
server_context.peer(),
3131
context.GetSpan().GetStartSystemTime(),
3232
context.GetCallName(),
33-
status.error_code(),
33+
status.has_value() ? std::make_optional(status->error_code()) : std::nullopt,
3434
log_extra
3535
)
3636
.ExtractTextLogItem()
@@ -47,7 +47,7 @@ Middleware::Middleware(Settings&& settings)
4747
: logger_(std::move(settings.access_tskv_logger))
4848
{}
4949

50-
void Middleware::OnCallFinish(MiddlewareCallContext& context, const grpc::Status& status) const {
50+
void Middleware::OnCallFinish(MiddlewareCallContext& context, const std::optional<grpc::Status>& status) const {
5151
WriteAccessLog(context, status, *logger_);
5252
}
5353

grpc/src/ugrpc/server/middlewares/access_log/middleware.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Middleware final : public MiddlewareBase {
1717
public:
1818
explicit Middleware(Settings&& settings);
1919

20-
void OnCallFinish(MiddlewareCallContext& context, const grpc::Status& status) const override;
20+
void OnCallFinish(MiddlewareCallContext& context, const std::optional<grpc::Status>& status) const override;
2121

2222
private:
2323
logging::TextLoggerPtr logger_;

grpc/src/ugrpc/server/middlewares/base.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ void MiddlewareBase::PostRecvMessage(MiddlewareCallContext&, google::protobuf::M
5151

5252
void MiddlewareBase::PreSendMessage(MiddlewareCallContext&, google::protobuf::Message&) const {}
5353

54-
void MiddlewareBase::OnCallFinish(MiddlewareCallContext&, const grpc::Status&) const {}
54+
void MiddlewareBase::PreSendStatus(MiddlewareCallContext&, grpc::Status&) const {}
55+
56+
void MiddlewareBase::OnCallFinish(MiddlewareCallContext&, const std::optional<grpc::Status>&) const {}
5557

5658
} // namespace ugrpc::server
5759

grpc/src/ugrpc/server/middlewares/deadline_propagation/middleware.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ void Middleware::OnCallStart(MiddlewareCallContext& context) const {
9191
}
9292
}
9393

94-
void Middleware::OnCallFinish(MiddlewareCallContext& context, const grpc::Status& status) const {
94+
void Middleware::PreSendStatus(MiddlewareCallContext& context, grpc::Status& status) const {
9595
const auto* const inherited_data = USERVER_NAMESPACE::server::request::kTaskInheritedData.GetOptional();
9696

9797
// if !USERVER_DEADLINE_PROPAGATION_ENABLED, inherited_data must be nullptr

0 commit comments

Comments
 (0)