[go: up one dir, main page]

Skip to content

Commit

Permalink
Merge pull request #11 from ryxx0811/revert-10-revert-9-ryxx
Browse files Browse the repository at this point in the history
Revert "Revert "update to caching API request""
  • Loading branch information
ryxx0811 authored Sep 2, 2024
2 parents beaf14c + 8d38f2d commit c51dc54
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 165 deletions.
199 changes: 110 additions & 89 deletions biocypher/_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Union, Optional
from typing import Optional
import shutil

import requests
Expand All @@ -23,6 +23,7 @@

logger.debug(f"Loading module {__name__}.")

from abc import ABC
from datetime import datetime, timedelta
from tempfile import TemporaryDirectory
import os
Expand All @@ -31,16 +32,15 @@

import pooch

from ._misc import to_list
from ._misc import to_list, is_nested


class Resource:
class Resource(ABC):
def __init__(
self,
name: str,
url_s: str | list[str],
lifetime: int = 0,
is_dir: bool = False,
):
"""
Expand All @@ -56,32 +56,53 @@ def __init__(
lifetime (int): The lifetime of the resource in days. If 0, the
resource is considered to be permanent.
is_dir (bool): Whether the resource is a directory or not.
"""
self.name = name
self.url_s = url_s
self.lifetime = lifetime


class FileDownload(Resource):
    def __init__(
        self,
        name: str,
        url_s: str | list[str],
        lifetime: int = 0,
        is_dir: bool = False,
    ):
        """
        Represents basic information for a File Download.

        Args:
            name (str): The name of the File Download.

            url_s (str | list[str]): The URL or URLs of the File Download.

            lifetime (int): The lifetime of the File Download in days. If 0,
                the File Download is cached indefinitely.

            is_dir (bool): Whether the File Download is a directory or not.
        """
        # name, url_s and lifetime are stored by the Resource base class;
        # only the directory flag is specific to file downloads.
        super().__init__(name, url_s, lifetime)
        self.is_dir = is_dir


class APIRequest:
def __init__(self, name: str, url_s: str | list, lifetime: int = 0):
class APIRequest(Resource):
def __init__(self, name: str, url_s: str | list[str], lifetime: int = 0):
"""
Represents basic information for an API request.
Represents basic information for an API Request.
Args:
name(str): The name of the API request.
name(str): The name of the API Request.
url(str|list): The URL of the API endpoint.
url_s(str|list): The URL of the API endpoint.
lifetime(int): The lifetime of the API request in days. If 0, the
API request is cached indefinitely.
lifetime(int): The lifetime of the API Request in days. If 0, the
API Request is cached indefinitely.
"""
self.name = name
self.url_s = url_s
self.lifetime = lifetime
super().__init__(name, url_s, lifetime)


class Downloader:
Expand All @@ -103,17 +124,17 @@ def __init__(self, cache_dir: Optional[str] = None) -> None:
self.cache_dict = self._load_cache_dict()

# download function that accepts a resource or a list of resources
def download(self, *resources: Union[Resource, APIRequest]):

def download(self, *resources: Resource):
"""
Download one or multiple file(s), APIRequest(s), or both.
Download one or multiple resources.
Args:
resources (Resource or APIRequest): The resource(s), i.e., file(s)
or API request(s), to download.
resources (Resource): The resource(s) to download.
Returns:
list[str]: The path or paths to the downloaded resource(s) or API
request(s).
list[str]: The path or paths to the downloaded resource(s).
"""
paths = []
for resource in resources:
Expand All @@ -125,44 +146,50 @@ def download(self, *resources: Union[Resource, APIRequest]):

return paths

def _download_or_cache(
self, resource: Union[Resource, APIRequest], cache: bool = True
):

def _download_or_cache(self, resource: Resource, cache: bool = True):
"""
Download a resource or an API request if it is not cached or exceeded
its lifetime.
Download a resource if it is not cached or exceeded its lifetime.
Args:
resource (Resource or APIRequest): The file or API request to
download.
resource (Resource): The resource to download.
Returns:
list[str]: The path or paths to the downloaded resource(s) or API
request(s).
list[str]: The path or paths to the downloaded resource(s).
"""
expired = self._is_cache_expired(resource)

if expired or not cache:
self._delete_expired_cache(resource)
if isinstance(resource, Resource):
if isinstance(resource, FileDownload):
logger.info(f"Asking for download of resource {resource.name}.")
paths = self._download_resource(cache, resource)
else:
paths = self._download_files(cache, resource)
elif isinstance(resource, APIRequest):

logger.info(
f"Asking for download of api request {resource.name}."
)
paths = self._download_api_request(resource)

else:
raise TypeError(f"Unknown resource type: {type(resource)}")

else:
paths = self.get_cached_version(resource)
self._update_cache_record(resource)
return paths

def _is_cache_expired(self, resource: Union[Resource, APIRequest]) -> bool:

def _is_cache_expired(self, resource: Resource) -> bool:


"""
Check if resource or API request cache is expired.
Args:
resource (Resource or APIRequest): The file or API request to
download.
resource (Resource): The resource or API request to download.
Returns:
bool: True if cache is expired, False if not.
Expand All @@ -178,44 +205,53 @@ def _is_cache_expired(self, resource: Union[Resource, APIRequest]) -> bool:
expired = True
return expired

def _delete_expired_cache(self, resource: Union[Resource, APIRequest]):
cache_path = self.cache_dir + "/" + resource.name
if os.path.exists(cache_path) and os.path.isdir(cache_path):
shutil.rmtree(cache_path)

def _download_resource(self, cache, resource):
def _delete_expired_cache(self, resource: Resource):
cache_resource_path = self.cache_dir + "/" + resource.name
if os.path.exists(cache_resource_path) and os.path.isdir(
cache_resource_path
):
shutil.rmtree(cache_resource_path)

def _download_files(self, cache, file_download: FileDownload):


"""Download a resource.
Args:
cache (bool): Whether to cache the resource or not.
resource (Resource): The resource to download.
file_download (FileDownload): The resource to download.
Returns:
list[str]: The path or paths to the downloaded resource(s).
"""
if resource.is_dir:
files = self._get_files(resource)
resource.url_s = [resource.url_s + "/" + file for file in files]
resource.is_dir = False
paths = self._download_or_cache(resource, cache)
elif isinstance(resource.url_s, list):
if file_download.is_dir:
files = self._get_files(file_download)
file_download.url_s = [
file_download.url_s + "/" + file for file in files
]
file_download.is_dir = False
paths = self._download_or_cache(file_download, cache)
elif isinstance(file_download.url_s, list):
paths = []
for url in resource.url_s:
fname = url[url.rfind("/") + 1 :]
for url in file_download.url_s:
fname = url[url.rfind("/") + 1 :].split("?")[0]
paths.append(
self._retrieve(
url=url,
fname=fname,
path=os.path.join(self.cache_dir, resource.name),
path=os.path.join(self.cache_dir, file_download.name),
)
)
else:
paths = []
fname = resource.url_s[resource.url_s.rfind("/") + 1 :]
fname = file_download.url_s[
file_download.url_s.rfind("/") + 1 :
].split("?")[0]
results = self._retrieve(
url=resource.url_s,
url=file_download.url_s,
fname=fname,
path=os.path.join(self.cache_dir, resource.name),
path=os.path.join(self.cache_dir, file_download.name),
)
if isinstance(results, list):
paths.extend(results)
Expand Down Expand Up @@ -265,17 +301,16 @@ def _download_api_request(self, api_request: APIRequest):
paths.append(api_path)
return paths

def get_cached_version(
self, resource: Union[Resource, APIRequest]
) -> list[str]:

def get_cached_version(self, resource: Resource) -> list[str]:
"""Get the cached version of a resource.
Args:
resource(Resource or APIRequest): The file or API request to get the
cached version of.
resource(Resource): The resource to get the cached version of.
Returns:
list[str]: The paths to the cached file(s) or API request(s).
list[str]: The paths to the cached resource(s).
"""
cached_location = os.path.join(self.cache_dir, resource.name)
logger.info(f"Use cached version from {cached_location}.")
Expand Down Expand Up @@ -341,23 +376,23 @@ def _retrieve(
progressbar=True,
)

def _get_files(self, resource: Resource):
def _get_files(self, file_download: FileDownload):
"""
Get the files contained in a directory resource.
Get the files contained in a directory file.
Args:
resource (Resource): The directory resource.
file_download (FileDownload): The directory file.
Returns:
list: The files contained in the directory.
"""
if resource.url_s.startswith("ftp://"):
if file_download.url_s.startswith("ftp://"):
# remove protocol
url = resource.url_s[6:]
url = file_download.url_s[6:]
# get base url
url = url[: url.find("/")]
# get directory (remove initial slash as well)
dir = resource.url_s[7 + len(url) :]
dir = file_download.url_s[7 + len(url) :]
# get files
ftp = ftplib.FTP(url)
ftp.login()
Expand Down Expand Up @@ -389,26 +424,26 @@ def _load_cache_dict(self):
logger.info(f"Loading cache file {self.cache_file}.")
return json.load(f)

def _get_cache_record(self, resource: Union[Resource, APIRequest]):

def _get_cache_record(self, resource: Resource):
"""
Get the cache record of a file or an API request.
Get the cache record of a resource.
Args:
resource (Resource or APIRequest): The file or API request to get
the cache record of.
resource (Resource): The resource to get the cache record of.
Returns:
The cache record of the resource.
"""
return self.cache_dict.get(resource.name, {})

def _update_cache_record(self, resource: Union[Resource, APIRequest]):

def _update_cache_record(self, resource: Resource):
"""
Update the cache record of a file or an API request.
Update the cache record of a resource.
Args:
resource (Resource or APIrequest): The file or API request to update
the cache record of.
resource (Resource): The resource to update the cache record of.
"""
cache_record = {}
cache_record["url"] = to_list(resource.url_s)
Expand All @@ -419,17 +454,3 @@ def _update_cache_record(self, resource: Union[Resource, APIRequest]):
json.dump(self.cache_dict, f, default=str)


def is_nested(lst):
"""
Check if a list is nested.
Args:
lst (list): The list to check.
Returns:
bool: True if the list is nested, False otherwise.
"""
for item in lst:
if isinstance(item, list):
return True
return False
16 changes: 16 additions & 0 deletions biocypher/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,19 @@ def to_lower_sentence_case(s: str) -> str:
return pascalcase_to_sentencecase(s)
else:
return s


def is_nested(lst) -> bool:
    """
    Check if a list is nested.

    Args:
        lst (list): The list to check.

    Returns:
        bool: True if at least one direct element of ``lst`` is itself a
            list, False otherwise (including when ``lst`` is empty).
    """
    # Only the top level is inspected, and only ``list`` instances count as
    # nesting; tuples, strings and other containers do not.
    return any(isinstance(item, list) for item in lst)
Loading

0 comments on commit c51dc54

Please sign in to comment.