Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update the auth providers to be async. #7935

Merged
merged 5 commits into from
Jul 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7935.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert the auth providers to be async/await.
187 changes: 94 additions & 93 deletions docs/password_auth_providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,102 +19,103 @@ password auth provider module implementations:

Password auth provider classes must provide the following methods:

*class* `SomeProvider.parse_config`(*config*)
* `parse_config(config)`
This method is passed the `config` object for this module from the
homeserver configuration file.

> This method is passed the `config` object for this module from the
> homeserver configuration file.
>
> It should perform any appropriate sanity checks on the provided
> configuration, and return an object which is then passed into
> `__init__`.
It should perform any appropriate sanity checks on the provided
configuration, and return an object which is then passed into

*class* `SomeProvider`(*config*, *account_handler*)
This method should have the `@staticmethod` decoration.

> The constructor is passed the config object returned by
> `parse_config`, and a `synapse.module_api.ModuleApi` object which
> allows the password provider to check if accounts exist and/or create
> new ones.
* `__init__(self, config, account_handler)`

The constructor is passed the config object returned by
`parse_config`, and a `synapse.module_api.ModuleApi` object which
allows the password provider to check if accounts exist and/or create
new ones.

## Optional methods

Password auth provider classes may optionally provide the following
methods.

*class* `SomeProvider.get_db_schema_files`()

> This method, if implemented, should return an Iterable of
> `(name, stream)` pairs of database schema files. Each file is applied
> in turn at initialisation, and a record is then made in the database
> so that it is not re-applied on the next start.

`someprovider.get_supported_login_types`()

> This method, if implemented, should return a `dict` mapping from a
> login type identifier (such as `m.login.password`) to an iterable
> giving the fields which must be provided by the user in the submission
> to the `/login` api. These fields are passed in the `login_dict`
> dictionary to `check_auth`.
>
> For example, if a password auth provider wants to implement a custom
> login type of `com.example.custom_login`, where the client is expected
> to pass the fields `secret1` and `secret2`, the provider should
> implement this method and return the following dict:
>
> {"com.example.custom_login": ("secret1", "secret2")}

`someprovider.check_auth`(*username*, *login_type*, *login_dict*)

> This method is the one that does the real work. If implemented, it
> will be called for each login attempt where the login type matches one
> of the keys returned by `get_supported_login_types`.
>
> It is passed the (possibly UNqualified) `user` provided by the client,
> the login type, and a dictionary of login secrets passed by the
> client.
>
> The method should return a Twisted `Deferred` object, which resolves
> to the canonical `@localpart:domain` user id if authentication is
> successful, and `None` if not.
>
> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in
> which case the second field is a callback which will be called with
> the result from the `/login` call (including `access_token`,
> `device_id`, etc.)

`someprovider.check_3pid_auth`(*medium*, *address*, *password*)

> This method, if implemented, is called when a user attempts to
> register or log in with a third party identifier, such as email. It is
> passed the medium (ex. "email"), an address (ex.
> "<jdoe@example.com>") and the user's password.
>
> The method should return a Twisted `Deferred` object, which resolves
> to a `str` containing the user's (canonical) User ID if
> authentication was successful, and `None` if not.
>
> As with `check_auth`, the `Deferred` may alternatively resolve to a
> `(user_id, callback)` tuple.

`someprovider.check_password`(*user_id*, *password*)

> This method provides a simpler interface than
> `get_supported_login_types` and `check_auth` for password auth
> providers that just want to provide a mechanism for validating
> `m.login.password` logins.
>
> Iif implemented, it will be called to check logins with an
> `m.login.password` login type. It is passed a qualified
> `@localpart:domain` user id, and the password provided by the user.
>
> The method should return a Twisted `Deferred` object, which resolves
> to `True` if authentication is successful, and `False` if not.

`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*)

> This method, if implemented, is called when a user logs out. It is
> passed the qualified user ID, the ID of the deactivated device (if
> any: access tokens are occasionally created without an associated
> device ID), and the (now deactivated) access token.
>
> It may return a Twisted `Deferred` object; the logout request will
> wait for the deferred to complete but the result is ignored.
Password auth provider classes may optionally provide the following methods:

* `get_db_schema_files(self)`

This method, if implemented, should return an Iterable of
`(name, stream)` pairs of database schema files. Each file is applied
in turn at initialisation, and a record is then made in the database
so that it is not re-applied on the next start.

* `get_supported_login_types(self)`

This method, if implemented, should return a `dict` mapping from a
login type identifier (such as `m.login.password`) to an iterable
giving the fields which must be provided by the user in the submission
to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
These fields are passed in the `login_dict` dictionary to `check_auth`.

For example, if a password auth provider wants to implement a custom
login type of `com.example.custom_login`, where the client is expected
to pass the fields `secret1` and `secret2`, the provider should
implement this method and return the following dict:

```python
{"com.example.custom_login": ("secret1", "secret2")}
```

* `check_auth(self, username, login_type, login_dict)`

This method does the real work. If implemented, it
will be called for each login attempt where the login type matches one
of the keys returned by `get_supported_login_types`.

It is passed the (possibly unqualified) `user` field provided by the client,
the login type, and a dictionary of login secrets passed by the
client.

The method should return an `Awaitable` object, which resolves
to the canonical `@localpart:domain` user ID if authentication is
successful, and `None` if not.

Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in
which case the second field is a callback which will be called with
the result from the `/login` call (including `access_token`,
`device_id`, etc.)

* `check_3pid_auth(self, medium, address, password)`

This method, if implemented, is called when a user attempts to
register or log in with a third party identifier, such as email. It is
passed the medium (ex. "email"), an address (ex.
"<jdoe@example.com>") and the user's password.

The method should return an `Awaitable` object, which resolves
to a `str` containing the user's (canonical) User id if
authentication was successful, and `None` if not.

As with `check_auth`, the `Awaitable` may alternatively resolve to a
`(user_id, callback)` tuple.

* `check_password(self, user_id, password)`

This method provides a simpler interface than
`get_supported_login_types` and `check_auth` for password auth
providers that just want to provide a mechanism for validating
`m.login.password` logins.

If implemented, it will be called to check logins with an
`m.login.password` login type. It is passed a qualified
`@localpart:domain` user id, and the password provided by the user.

The method should return an `Awaitable` object, which resolves
to `True` if authentication is successful, and `False` if not.

* `on_logged_out(self, user_id, device_id, access_token)`

This method, if implemented, is called when a user logs out. It is
passed the qualified user ID, the ID of the deactivated device (if
any: access tokens are occasionally created without an associated
device ID), and the (now deactivated) access token.

It may return an `Awaitable` object; the logout request will
wait for the `Awaitable` to complete, but the result is ignored.
7 changes: 6 additions & 1 deletion synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import time
import unicodedata
Expand Down Expand Up @@ -863,11 +864,15 @@ async def delete_access_token(self, access_token: str):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
await provider.on_logged_out(
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
)
if inspect.isawaitable(result):
await result

# delete pushers associated with this access token
if user_info["token_id"] is not None:
Expand Down
35 changes: 17 additions & 18 deletions synapse/handlers/ui_auth/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.

import logging
from typing import Any

from canonicaljson import json

from twisted.internet import defer
from twisted.web.client import PartialDownloadError

from synapse.api.constants import LoginType
Expand All @@ -33,25 +33,25 @@ class UserInteractiveAuthChecker:
def __init__(self, hs):
pass

def is_enabled(self):
def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work

Returns:
bool: True if this login type is enabled.
True if this login type is enabled.
"""

def check_auth(self, authdict, clientip):
async def check_auth(self, authdict: dict, clientip: str) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were you holding off specifying the possibilities here so that you wouldn't need to specify Deferred? If so, it's worth noting that isinstance(a_deferred, Awaitable) resolves to True, so I think something not too cumbersome like Optional[Union[str, Tuple]] should work here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't specify the possibilities because I did the typing before reading the separate documentation and the docstring didn't specify what was returned. 😄 I can double check the return types to the documentation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait, I know why I did this. (I should have ☕ before replying to things...)

Although this method is also named check_auth, it is NOT the same as above. It gets called twice:

result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)

checker = self.checkers.get(login_type)
if checker is not None:
res = await checker.check_auth(authdict, clientip=clientip)
return res

While the one from a password provider gets called:

result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
return result

Anyway, the result of UserInteractiveAuthChecker.check_auth gets saved into the database and sometimes inspected in other places.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, fair enough then! Thanks for the clear explanation :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're welcome! I'm now realizing that putting these changes in the same PR was confusing...since they're not really related. 😢 Sorry about that!

"""Given the authentication dict from the client, attempt to check this step

Args:
authdict (dict): authentication dictionary from the client
clientip (str): The IP address of the client.
authdict: authentication dictionary from the client
clientip: The IP address of the client.

Raises:
SynapseError if authentication failed

Returns:
Deferred: the result of authentication (to pass back to the client?)
The result of authentication (to pass back to the client?)
"""
raise NotImplementedError()

Expand All @@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True

def check_auth(self, authdict, clientip):
return defer.succeed(True)
async def check_auth(self, authdict, clientip):
return True


class TermsAuthChecker(UserInteractiveAuthChecker):
Expand All @@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True

def check_auth(self, authdict, clientip):
return defer.succeed(True)
async def check_auth(self, authdict, clientip):
return True


class RecaptchaAuthChecker(UserInteractiveAuthChecker):
Expand All @@ -89,8 +89,7 @@ def __init__(self, hs):
def is_enabled(self):
return self._enabled

@defer.inlineCallbacks
def check_auth(self, authdict, clientip):
async def check_auth(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
Expand All @@ -107,7 +106,7 @@ def check_auth(self, authdict, clientip):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
resp_body = yield self._http_client.post_urlencoded_get_json(
resp_body = await self._http_client.post_urlencoded_get_json(
self._url,
args={
"secret": self._secret,
Expand Down Expand Up @@ -219,8 +218,8 @@ def is_enabled(self):
ThreepidBehaviour.LOCAL,
)

def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("email", authdict))
async def check_auth(self, authdict, clientip):
return await self._check_threepid("email", authdict)


class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
Expand All @@ -233,8 +232,8 @@ def __init__(self, hs):
def is_enabled(self):
return bool(self.hs.config.account_threepid_delegate_msisdn)

def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
async def check_auth(self, authdict, clientip):
return await self._check_threepid("msisdn", authdict)


INTERACTIVE_AUTH_CHECKERS = [
Expand Down