Skip to content

Commit

Permalink
feat: modified the Authentication class to use the same token until i…
Browse files Browse the repository at this point in the history
…t expires and request a new one via the refresh_token
  • Loading branch information
AntonioVentilii committed Apr 3, 2024
1 parent 7ace524 commit 97a53f8
Show file tree
Hide file tree
Showing 12 changed files with 301 additions and 14 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: Run Package Tests

on:
push:
branches:
- main
tags:
- 'v*'
pull_request:
branches:
- main
release:
types: [ created ]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ 3.7, 3.8, 3.9, '3.10', 3.11 ]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run tests
env:
TEST_CLIENT_ID: ${{ secrets.TEST_CLIENT_ID }}
TEST_CLIENT_SECRET: ${{ secrets.TEST_CLIENT_SECRET }}
run: |
pytest
5 changes: 5 additions & 0 deletions .github/workflows/publish-to-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@ on:
jobs:
build-and-publish:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build package
run: |
python setup.py sdist bdist_wheel
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
Expand Down
108 changes: 98 additions & 10 deletions deribit_wrapper/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import time
import uuid
import warnings

import requests
Expand All @@ -10,16 +11,24 @@

from .base import DeribitBase
from .exceptions import DeribitClientWarning
from .utilities import ScopeType


class Authentication(DeribitBase):
__AUTH = '/public/auth'

__GET_TIME = '/public/get_time'
__STATUS = '/public/status'
__TEST = '/public/test'

def __init__(self, env: str = 'prod', client_id: str = None, client_secret: str = None):
super().__init__(env=env)
self._client_id = None
self._client_secret = None
self.set_credentials(client_id, client_secret)
self._access_token = None
self._token_expiry = None
self._refresh_token = None

@property
def client_id(self) -> str:
Expand Down Expand Up @@ -54,7 +63,7 @@ def _request(self, uri: str, params: dict[str:str | int | float], give_results:
}
headers = None
if uri.startswith('/private'):
token = self._get_new_token()
token = self.access_token
headers = {'Authorization': 'bearer ' + token}
r = self._session.post(url=self.api_url, data=json.dumps(data), headers=headers)
if give_results:
Expand Down Expand Up @@ -82,15 +91,94 @@ def _request(self, uri: str, params: dict[str:str | int | float], give_results:
ret = r
return ret

def _get_new_token(self) -> str:
@property
def access_token(self) -> str:
self.refresh_token_if_expired()
return self._access_token

def is_token_expired(self) -> bool:
if self._token_expiry is None or self._access_token is None:
return True
current_time = int(time.time())
buffer = 60
return current_time >= self._token_expiry - buffer

def refresh_token_if_expired(self):
if self.is_token_expired():
self.get_new_token()

def create_new_scope(self, session_name: str = None, account: ScopeType = None,
trade: ScopeType = None, wallet: ScopeType = None, block_trade: ScopeType = None,
expires: int = 0, ip: str = '') -> str:
scope_parts = []

if session_name is None:
unique_part = uuid.uuid4()
timestamp = int(time.time())
session_name = f'{self.instance_name}_{timestamp}_{unique_part.hex}'
scope_parts.append(f'session:{session_name}')

if account:
scope_parts.append(f'account:{account}')

if trade:
scope_parts.append(f'trade:{trade}')

if wallet:
scope_parts.append(f'wallet:{wallet}')

if block_trade:
scope_parts.append(f'block_trade:{block_trade}')

if expires > 0:
scope_parts.append(f'expires:{expires}')

if ip:
scope_parts.append(f'ip:{ip}')

final_scope = ' '.join(scope_parts)
return final_scope

def get_new_token(self, use_refresh_token_if_available: bool = True) -> str:
assert self.client_id and self.client_secret, 'Cannot generate new token without Client ID and Client Secret'
uri = self.__AUTH
params = {
'grant_type': 'client_credentials',
'scope': 'session:first_test',
'client_id': self.client_id,
'client_secret': self.client_secret,
}
scope = self.create_new_scope()
if use_refresh_token_if_available and self._refresh_token:
params = {
'grant_type': 'refresh_token',
'refresh_token': self._refresh_token,
'scope': scope,
}
else:
params = {
'grant_type': 'client_credentials',
'client_id': self.client_id,
'client_secret': self.client_secret,
'scope': scope,
}
r = self._request(uri, params)
token = r['access_token']
return token
self._access_token = r['access_token']
self._token_expiry = int(time.time()) + r['expires_in']
self._refresh_token = r['refresh_token']
return self._access_token

def get_time(self) -> int:
uri = self.__GET_TIME
r = self._request(uri, {})
return r

def get_status(self) -> dict:
uri = self.__STATUS
r = self._request(uri, {})
return r

def get_locked_currencies(self) -> dict:
return self.get_status()['locked_currencies']

def test(self) -> dict:
uri = self.__TEST
r = self._request(uri, {})
return r

def get_api_version(self) -> str:
return self.test()['version']
2 changes: 1 addition & 1 deletion deribit_wrapper/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DeribitBase(object):
'prod': 'https://www.deribit.com'
}
__API_URL = '/api/v2'
_instance_count = 0 # Class variable to keep track of the number of instances
_instance_count = 0 # Class variable to keep track of the instance number

def __init__(self, env: str = 'prod', instance_name: str = None):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion deribit_wrapper/trading.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _error_handler(self, ret: dict, uri: str, params: dict, exclude_codes: list[
if code != 10041:
break

else:
elif code not in exclude_codes:
print(f'Error code {code} not handled yet.')

return ret
Expand Down
4 changes: 3 additions & 1 deletion deribit_wrapper/utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import absolute_import, annotations

from datetime import datetime
from typing import List, Tuple, Union
from typing import List, Literal, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -12,6 +12,8 @@
DatetimeType = Union[datetime, str, float]
StrikeType = Union[str, float]

ScopeType = Literal['read', 'read_write', 'none']

DEFAULT_START = '2000-01-01'
DEFAULT_END = 'now'

Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ numpy
pandas
urllib3
progressbar2
pytest
pytest-mock
python-dotenv
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='deribit_wrapper',
version='0.2.1',
version='0.2.3',
packages=find_packages(),
description='A Python wrapper for seamless integration with Deribit\'s trading API, offering easy access to '
'market data, account management, and trading operations.',
Expand Down
Empty file added tests/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import os

from dotenv import load_dotenv

load_dotenv()

client_id = os.environ.get("DERIBIT_CLIENT_ID")
client_secret = os.environ.get("DERIBIT_CLIENT_SECRET")
98 changes: 98 additions & 0 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
from unittest import TestCase
from unittest.mock import patch

import pytest
from dotenv import load_dotenv

from deribit_wrapper.authentication import Authentication
from deribit_wrapper.exceptions import DeribitClientWarning

load_dotenv()

token_mock_response = {
'access_token': 'new_access_token',
'expires_in': 3600,
'refresh_token': 'new_refresh_token',
}


@pytest.fixture
def auth_instance():
"""Fixture to create an Authentication instance with credentials loaded from environment variables."""
client_id = os.environ.get("TEST_CLIENT_ID")
client_secret = os.environ.get("TEST_CLIENT_SECRET")
return Authentication(env='test', client_id=client_id, client_secret=client_secret)


def test_credentials_set_correctly(auth_instance):
"""Test that client ID and client secret are set correctly from environment variables."""
assert auth_instance.client_id == os.environ.get("TEST_CLIENT_ID")
assert auth_instance.client_secret == os.environ.get("TEST_CLIENT_SECRET")


def test_warning_raised_when_credentials_not_provided():
"""Test that a warning is raised when credentials are not provided."""
with pytest.warns(DeribitClientWarning):
Authentication(env='test')


@patch('deribit_wrapper.authentication.Authentication._request')
def test_authentication_process(mock_request, auth_instance):
"""Test the authentication process, assuming successful token retrieval."""
# Mock the _request method to return a mock access token response
mock_request.return_value = token_mock_response

token = auth_instance.get_new_token()
assert token == 'new_access_token'
mock_request.assert_called_once()


@patch('deribit_wrapper.authentication.Authentication._request',
side_effect=Exception("Cannot generate new token without Client ID and Client Secret"))
def test_authentication_failure_leads_to_exception(mock_request, auth_instance):
"""Test that an exception is raised when the authentication request fails."""
with pytest.raises(Exception) as excinfo:
auth_instance.get_new_token()
assert "Cannot generate new token without Client ID and Client Secret" in str(excinfo.value)


def test_get_new_token_retrieves_new_token():
mock_response = token_mock_response

with patch('deribit_wrapper.authentication.Authentication._request') as mock_request, \
patch('deribit_wrapper.authentication.Authentication.create_new_scope',
return_value='session:fixed_session_name') as mock_create_new_scope:
mock_request.return_value = mock_response

auth = Authentication(env='test', client_id='dummy_id', client_secret='dummy_secret')
new_token = auth.get_new_token()

assert new_token == 'new_access_token'
mock_request.assert_called_once_with('/public/auth', {
'grant_type': 'client_credentials',
'client_id': 'dummy_id',
'client_secret': 'dummy_secret',
'scope': 'session:fixed_session_name',
})
mock_create_new_scope.assert_called()


class TestDeribitIntegration(TestCase):
def setUp(self):
env = 'test'
client_id = os.environ.get('TEST_CLIENT_ID')
client_secret = os.environ.get('TEST_CLIENT_SECRET')
self.auth = Authentication(env=env, client_id=client_id, client_secret=client_secret)

def test_get_new_token(self):
token = self.auth.access_token
self.assertIsNotNone(token)

def test_get_time(self):
time = self.auth.get_time()
self.assertIsInstance(time, int)

def test_get_api_version(self):
version = self.auth.get_api_version()
self.assertIsInstance(version, str)
Loading

0 comments on commit 97a53f8

Please sign in to comment.