Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mtls feature to rest transport #731

Merged
merged 5 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -239,21 +239,15 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
# Create SSL credentials for mutual TLS if needed.
use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))

ssl_credentials = None
client_cert_source_func = None
is_mtls = False
if use_client_cert:
if client_options.client_cert_source:
import grpc # type: ignore

cert, key = client_options.client_cert_source()
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
is_mtls = True
client_cert_source_func = client_options.client_cert_source
else:
creds = SslCredentials()
is_mtls = creds.is_mtls
ssl_credentials = creds.ssl_credentials if is_mtls else None
is_mtls = mtls.has_default_client_cert_source()
client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None

# Figure out which api endpoint to use.
if client_options.api_endpoint is not None:
Expand Down Expand Up @@ -292,7 +286,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
credentials_file=client_options.credentials_file,
host=api_endpoint,
scopes=client_options.scopes,
ssl_channel_credentials=ssl_credentials,
client_cert_source_for_mtls=client_cert_source_func,
quota_project_id=client_options.quota_project_id,
client_info=client_info,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -82,6 +83,10 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -98,6 +103,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
# provided.
Expand All @@ -107,8 +117,6 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)

host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
Expand Down Expand Up @@ -144,12 +152,18 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
api_mtls_endpoint: str = None,
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id=None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand Down Expand Up @@ -126,6 +127,10 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
``api_mtls_endpoint`` is None.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
A callback to provide client certificate bytes and private key bytes,
both in PEM format. It is used to configure mutual TLS channel. It is
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -141,6 +146,11 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
and ``credentials_file`` are passed.
"""
self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
if client_cert_source:
warnings.warn("client_cert_source is deprecated", DeprecationWarning)

if channel:
# Sanity check: Ensure that channel and credentials are not both
Expand All @@ -151,8 +161,6 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
self._grpc_channel = channel
self._ssl_channel_credentials = None
elif api_mtls_endpoint:
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)

host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
Expand Down Expand Up @@ -188,12 +196,18 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)

# create a new channel. The provided one is ignored.
self._grpc_channel = type(self).create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_channel_credentials,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
quota_project_id=quota_project_id,
options=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
credentials: credentials.Credentials = None,
credentials_file: str = None,
scopes: Sequence[str] = None,
ssl_channel_credentials: grpc.ChannelCredentials = None,
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
) -> None:
Expand All @@ -68,8 +68,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
This argument is ignored if ``channel`` is provided.
scopes (Optional(Sequence[str])): A list of scopes. This argument is
ignored if ``channel`` is provided.
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
for grpc channel. It is ignored if ``channel`` is provided.
client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
certificate to configure mutual TLS HTTP channel. It is ignored
if ``channel`` is provided.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -89,6 +90,8 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}
if client_cert_source_for_mtls:
self._session.configure_mtls_channel(client_cert_source_for_mtls)

{%- if service.has_lro %}

Expand Down
Loading