[go: up one dir, main page]

Skip to content

Commit

Permalink
✨ Adapted serialization to new tensorflow version
Browse files: browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Oct 24, 2024
1 parent bb04028 commit 1d3b7ed
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 36 deletions.
14 changes: 14 additions & 0 deletions flama/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"WebSocketException",
"NotFoundException",
"MethodNotAllowedException",
"FrameworkNotInstalled",
"FrameworkVersionWarning",
]


Expand Down Expand Up @@ -139,3 +141,15 @@ def __repr__(self) -> str:
params = ("path", "params", "method", "allowed")
formatted_params = ", ".join([f"{x}={getattr(self, x)}" for x in params if getattr(self, x)])
return f"{self.__class__.__name__}({formatted_params})"


class FrameworkNotInstalled(Exception):
"""Cannot find an installed version of the framework."""

...


class FrameworkVersionWarning(Warning):
"""Warning for when a framework version does not match."""

...
3 changes: 2 additions & 1 deletion flama/models/models/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

class PyTorchModel(Model):
def predict(self, x: t.List[t.List[t.Any]]) -> t.Any:
- assert torch is not None, "`torch` must be installed to use PyTorchModel."
+ if torch is None: # noqa
+ raise exceptions.FrameworkNotInstalled("pytorch")

try:
return self.model(torch.Tensor(x)).tolist()
Expand Down
8 changes: 8 additions & 0 deletions flama/models/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
from flama import exceptions
from flama.models.base import Model

try:
import sklearn # type: ignore
except Exception: # pragma: no cover
sklearn = None


class SKLearnModel(Model):
def predict(self, x: t.List[t.List[t.Any]]) -> t.Any:
if sklearn is None: # noqa
raise exceptions.FrameworkNotInstalled("scikit-learn")

try:
return self.model.predict(x).tolist()
except ValueError as e:
Expand Down
13 changes: 11 additions & 2 deletions flama/models/models/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from flama import exceptions
from flama.models.base import Model

try:
import numpy as np # type: ignore
except Exception: # pragma: no cover
np = None

try:
import tensorflow as tf # type: ignore
except Exception: # pragma: no cover
Expand All @@ -11,9 +16,13 @@

class TensorFlowModel(Model):
def predict(self, x: t.List[t.List[t.Any]]) -> t.Any:
- assert tf is not None, "`tensorflow` must be installed to use TensorFlowModel."
+ if np is None: # noqa
+ raise exceptions.FrameworkNotInstalled("numpy")

+ if tf is None: # noqa
+ raise exceptions.FrameworkNotInstalled("tensorflow")

try:
- return self.model.predict(x).tolist()
+ return self.model.predict(np.array(x)).tolist()
except (tf.errors.OpError, ValueError): # type: ignore
raise exceptions.HTTPException(status_code=400)
4 changes: 2 additions & 2 deletions flama/serialize/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import weakref
from pathlib import Path

+ from flama import exceptions
from flama.serialize.base import Serializer
- from flama.serialize.exceptions import FrameworkVersionWarning
from flama.serialize.types import Framework

if sys.version_info < (3, 11): # PORT: Remove when stop supporting 3.10 # pragma: no cover
Expand Down Expand Up @@ -284,7 +284,7 @@ def from_dict(cls, data: t.Dict[str, t.Any], **kwargs) -> "ModelArtifact":
warnings.warn(
f"Model was built using {metadata.framework.lib.value} '{metadata.framework.version}' but detected "
f"version '{serializer.version()}' installed. This may cause unexpected behavior.",
- FrameworkVersionWarning,
+ exceptions.FrameworkVersionWarning,
)

return cls(model=model, meta=metadata, artifacts=artifacts)
Expand Down
13 changes: 0 additions & 13 deletions flama/serialize/exceptions.py

This file was deleted.

3 changes: 2 additions & 1 deletion flama/serialize/serializers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import io
import typing as t

- from flama.serialize import exceptions, types
+ from flama import exceptions
+ from flama.serialize import types
from flama.serialize.base import Serializer

try:
Expand Down
3 changes: 2 additions & 1 deletion flama/serialize/serializers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import typing as t
import warnings

- from flama.serialize import exceptions, types
+ from flama import exceptions
+ from flama.serialize import types
from flama.serialize.base import Serializer

if t.TYPE_CHECKING:
Expand Down
24 changes: 9 additions & 15 deletions flama/serialize/serializers/tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import codecs
import importlib.metadata
- import io
import json
- import tarfile
+ import tempfile
import typing as t
- from tempfile import TemporaryDirectory

- from flama.serialize import exceptions, types
+ from flama import exceptions
+ from flama.serialize import types
from flama.serialize.base import Serializer

try:
Expand All @@ -25,22 +24,17 @@ def dump(self, obj: t.Any, **kwargs) -> bytes:
if tf is None: # noqa
raise exceptions.FrameworkNotInstalled("tensorflow")

- buffer = io.BytesIO()
- with TemporaryDirectory() as saved_model_dir, tarfile.open(fileobj=buffer, mode="w") as model_tar:
- tf.keras.models.save_model(obj, saved_model_dir) # type: ignore
- model_tar.add(saved_model_dir, arcname="")
- buffer.seek(0)
- return codecs.encode(buffer.read(), "base64")
+ with tempfile.NamedTemporaryFile(mode="rb", suffix=".keras") as tmp_file:
+ tf.keras.models.save_model(obj, tmp_file.name) # type: ignore
+ return codecs.encode(tmp_file.read(), "base64")

def load(self, model: bytes, **kwargs) -> t.Any:
if tf is None: # noqa
raise exceptions.FrameworkNotInstalled("tensorflow")

- with TemporaryDirectory() as saved_model_dir, tarfile.open(
- fileobj=io.BytesIO(codecs.decode(model, "base64")), mode="r:"
- ) as model_tar:
- model_tar.extractall(saved_model_dir)
- return tf.keras.models.load_model(saved_model_dir) # type: ignore
+ with tempfile.NamedTemporaryFile(mode="wb", suffix=".keras") as tmp_file:
+ tmp_file.write(codecs.decode(model, "base64"))
+ return tf.keras.models.load_model(tmp_file.name) # type: ignore

def info(self, model: t.Any) -> t.Optional["JSONSchema"]:
model_info: "JSONSchema" = json.loads(model.to_json())
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _sklearn_pipeline(self):
def _tensorflow(self):
model = tf.keras.models.Sequential(
[
- tf.keras.layers.Flatten(input_shape=(2,)),
+ tf.keras.Input((2,)),
tf.keras.layers.Dense(10, activation="tanh"),
tf.keras.layers.Dense(1, activation="sigmoid"),
]
Expand Down

0 comments on commit 1d3b7ed

Please sign in to comment.