"""Source code for zadeh.fis: fuzzy inference systems (FIS)."""

import json
import numpy as np

try:
    import pandas as pd
except ImportError:
    pd = None

try:
    import ipywidgets
except ImportError:
    ipywidgets = None

try:
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # Activates 3d perspective
    from matplotlib import cm
except ImportError:
    plt = None

from .variables import FuzzyVariable
from .rules import FuzzyRuleSet, FuzzyOr, FuzzyAnd, FuzzyValuation
from .context import FuzzyContext, set_fuzzy_context


class FIS:
    """A fuzzy inference system.

    Bundles the input variables, a rule set and the target variable together with
    a :class:`FuzzyContext` holding the fuzzy operators (defuzzification,
    aggregation, implication, AND, OR). That context is activated whenever the
    rules are evaluated by the methods of this class.
    """

    def __init__(self, variables, rules, target, defuzzification="centroid", aggregation="max",
                 implication="min", AND="min", OR="max"):
        """
        Args:
            variables (list of FuzzyVariable): Input variables. Their order defines the
                column order expected for array-like inputs (see :meth:`batch_predict`).
            rules (FuzzyRuleSet or other): The rules. Anything that is not already a
                FuzzyRuleSet is passed to the FuzzyRuleSet constructor.
            target (FuzzyVariable): The output variable.
            defuzzification (str): Name of the defuzzification method.
            aggregation (str): Name of the aggregation method.
            implication (str): Name of the implication method.
            AND (str): Name of the method used for the fuzzy AND.
            OR (str): Name of the method used for the fuzzy OR.
        """
        self.variables = variables
        if not isinstance(rules, FuzzyRuleSet):
            rules = FuzzyRuleSet(rules)
        self.rules = rules
        # TODO: Support multitarget
        self.target = target
        self.context = FuzzyContext(defuzzification=defuzzification, aggregation=aggregation,
                                    implication=implication, AND=AND, OR=OR)

    def save(self, path):
        """Save the FIS definition to a path as JSON (see :meth:`load`)."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._get_description(), f)

    @staticmethod
    def load(path):
        """Load a FIS from the JSON file at the given path (as written by :meth:`save`)."""
        with open(path, encoding="utf-8") as f:
            description = json.load(f)
        return FIS._from_description(description)

    def _get_description(self):
        """Return a JSON-serializable dict describing this system."""
        return {"variables": [v._get_description() for v in self.variables],
                "rules": self.rules._get_description(),
                "target": self.target._get_description(),
                "defuzzification": self.context.defuzzification,
                "aggregation": self.context.aggregation,
                "implication": self.context.implication,
                "AND": self.context.AND,
                "OR": self.context.OR,
                }

    @staticmethod
    def _from_description(description):
        """Build a FIS from a dict produced by :meth:`_get_description`.

        Missing operator entries fall back to the constructor defaults.
        """
        variables = [FuzzyVariable._from_description(d) for d in description["variables"]]
        target_variable = FuzzyVariable._from_description(description["target"])
        # Rules reference variables by name, including the target.
        variables_dict = {**{v.name: v for v in variables}, target_variable.name: target_variable}
        defuzzification = description.get("defuzzification", "centroid")
        aggregation = description.get("aggregation", "max")
        implication = description.get("implication", "min")
        OR = description.get("OR", "max")
        AND = description.get("AND", "min")
        return FIS(variables, FuzzyRuleSet._from_description(description["rules"], variables_dict),
                   target_variable, defuzzification=defuzzification, aggregation=aggregation,
                   implication=implication, AND=AND, OR=OR)

    def _to_c(self):
        """Return the C translation of the rule set, generated under this system's context."""
        with set_fuzzy_context(self.context):
            return self.rules._to_c()

    def get_output(self, values):
        """
        Get the output of the system as a fuzzy set

        Args:
            values (dict of str): A mapping from variables to their fuzzy values.

        Returns:
            FuzzySet: The fuzzy set

        """
        with set_fuzzy_context(self.context):
            return self.rules(values)

    def get_crisp_output(self, values):
        """
        Get the output of the system as a crisp value

        Args:
            values (dict of str): A mapping from variables to their fuzzy values.

        Returns:
            The fuzzy output defuzzified with the system's configured method
            (centroid by default).

        """
        with set_fuzzy_context(self.context):
            return self.target.domain.defuzzify(self.get_output(values))

    def batch_predict(self, X):
        """
        Get the crisp output for a batch of inputs

        Args:
            X (pd.DataFrame or np.array or list of list): Input values. If pandas dataframe, must have a column for
                                                          each of the variables with the same name. If array-like,
                                                          order must be consistent with the variables.

        Returns:
            np.array: An array with the predictions.

        """
        # Pandas dataframe syntax -- if available
        if pd is not None and isinstance(X, pd.DataFrame):
            return np.asarray(
                [self.get_crisp_output({v.name: x[v.name] for v in self.variables}) for _, x in X.iterrows()])
        # Assuming ordered array-like input
        return np.asarray([self.get_crisp_output({v.name: x[i] for i, v in enumerate(self.variables)})
                           for x in X])

    def dict_to_ordered(self, values):
        """Transform a dict of inputs into a list in the FIS variable order."""
        return [values[key.name] for key in self.variables]

    def get_interactive(self, continuous_update=False):
        """
        Display an interactive plot with the fuzzy output of the FIS

        Args:
            continuous_update (bool): Whether to continuously update with the widgets value.

        """
        if ipywidgets is None or plt is None:
            raise ModuleNotFoundError("ipywidgets and matplotlib are required")

        def plot(**kwargs):
            output = self.get_output(kwargs)
            with set_fuzzy_context(self.context):
                crisp_value = self.target.domain.defuzzify(output)
            self.target.domain.plot_set(output)
            plt.vlines(crisp_value, *plt.ylim(), color="red")
            plt.legend(["Fuzzy output", "Crisp"])
            plt.show()

        ipywidgets.interact(plot, **{v.name: v.domain.get_ipywidget(continuous_update=continuous_update)
                                     for v in self.variables})

    def plot_1d(self, variable, fixed_variables=None, axes=None):
        """
        Produce a plot with the output as a function of a variable when the rest are fixed.

        Args:
            variable (FuzzyVariable): The independent variable.
            fixed_variables (dict of str): A mapping with fuzzy values of the rest of the variables. An entry
                                           for ``variable`` itself, if present, is ignored.
            axes (plt.Axes): An existing axes instance to plot. If None, a new figure is created.

        Returns:
            plt.Axes: Axes for further tweaking

        """
        if plt is None:
            raise ModuleNotFoundError("matplotlib is required")
        if fixed_variables is None:
            fixed_variables = {}
        xx = variable.domain.get_mesh()
        # The swept value goes last so that it always wins over fixed_variables.
        output = [self.get_crisp_output({**fixed_variables, variable.name: x}) for x in xx]
        ax = axes if axes is not None else plt.figure().add_subplot(1, 1, 1)
        ax.plot(xx, output)
        ax.set_xlabel(variable.name)
        ax.set_ylabel(self.target.name)
        return ax

    def get_1d_interactive(self, variable, continuous_update=False):
        """
        Produce an interactive plot with the output as a function of a variable when the rest are fixed.

        Args:
            variable (FuzzyVariable): The independent variable.
            continuous_update (bool): Whether to continuously update with the widgets value.

        """
        if ipywidgets is None or plt is None:
            raise ModuleNotFoundError("ipywidgets and matplotlib are required")
        free_variables = [v for v in self.variables if v.name != variable.name]

        def plot(**kwargs):
            self.plot_1d(variable, kwargs)
            plt.show()

        ipywidgets.interact(plot, **{v.name: v.domain.get_ipywidget(continuous_update=continuous_update)
                                     for v in free_variables})

    def plot_2d(self, variable1, variable2, fixed_variables=None, axes=None):
        """
        Produce a plot with the output as a function of two variables when the rest are fixed.

        Args:
            variable1 (FuzzyVariable): The first independent variable.
            variable2 (FuzzyVariable): The second independent variable.
            fixed_variables (dict of str): A mapping with fuzzy values of the rest of the variables. Entries
                                           for ``variable1`` or ``variable2``, if present, are ignored.
            axes (plt.Axes): An existing axes instance to plot. An 3D projection must have been set on it. If None,
                             a new figure is created.

        Returns:
            plt.Axes: Axes for further tweaking

        """
        if plt is None:
            raise ModuleNotFoundError("matplotlib is required")
        if fixed_variables is None:
            fixed_variables = {}
        x_name = variable1.name
        y_name = variable2.name
        # TODO: Allow coarser mesh
        xx = variable1.domain.get_mesh()
        yy = variable2.domain.get_mesh()
        # Rows follow yy and columns follow xx, matching np.meshgrid's default "xy" indexing.
        zz = np.asarray(
            [
                [
                    self.get_crisp_output({**fixed_variables, x_name: x, y_name: y})
                    for x in xx
                ]
                for y in yy
            ]
        )
        # String coordinates must be converted to positions for this kind of plot;
        # the original labels are kept to be shown as ticks.
        x_labels = y_labels = None
        if xx.dtype.kind == 'U':
            x_labels, xx = xx, np.arange(len(xx))
        if yy.dtype.kind == 'U':
            y_labels, yy = yy, np.arange(len(yy))
        ax = axes if axes is not None else plt.figure().add_subplot(1, 1, 1, projection="3d")
        ax.plot_surface(*np.meshgrid(xx, yy), zz, cmap=cm.viridis)
        if x_labels is not None:
            ax.set_xticks(xx)
            ax.set_xticklabels(x_labels)
        if y_labels is not None:
            ax.set_yticks(yy)
            ax.set_yticklabels(y_labels)
        ax.set_xlabel(x_name)
        ax.set_ylabel(y_name)
        ax.set_zlabel(self.target.name)
        ax.invert_xaxis()  # Seems more natural to me
        return ax

    def get_2d_interactive(self, variable1, variable2, continuous_update=False):
        """
        Produce an interactive plot with the output as a function of two variables when the rest are fixed.

        Args:
            variable1 (FuzzyVariable): The first independent variable.
            variable2 (FuzzyVariable): The second independent variable.
            continuous_update (bool): Whether to continuously update with the widgets value.

        """
        if ipywidgets is None or plt is None:
            raise ModuleNotFoundError("ipywidgets and matplotlib are required")
        swept = {variable1.name, variable2.name}
        free_variables = [v for v in self.variables if v.name not in swept]

        def plot(**kwargs):
            self.plot_2d(variable1, variable2, kwargs)
            plt.show()

        ipywidgets.interact(plot, **{v.name: v.domain.get_ipywidget(continuous_update=continuous_update)
                                     for v in free_variables})

    def plot_rules(self, values, color="k"):
        """
        Produce a plot which can be used to explain the behaviour of the system for the given set of values

        The grid has one row per rule plus a final row, and one column per input variable plus a final
        column for the target. Rule rows show the antecedent memberships and the activated consequent;
        the final row shows the input values and the aggregated output with its crisp value.

        Args:
            values (dict of str): A mapping from variables to their values.
            color (str): Color used to plot rule activation

        Returns:
            2-tuple of (fig, axes) for further tweaking

        """
        if plt is None:
            raise ModuleNotFoundError("matplotlib is required")
        rule_list = self.rules.rule_list
        # Column index of each input variable; the last column holds the target.
        columns = {x.name: i for i, x in enumerate(self.variables)}
        last_row = len(rule_list)  # index of the inputs/output summary row
        fig, axes = plt.subplots(last_row + 1, len(self.variables) + 1, squeeze=False, figsize=(14, 14),
                                 sharex='col', sharey=True)

        # Evaluate everything with this system's operators, as get_output does.
        with set_fuzzy_context(self.context):
            for i, rule in enumerate(rule_list):
                # Rule input
                if isinstance(rule.antecedent, (FuzzyOr, FuzzyAnd)):
                    proposition_list = rule.antecedent.proposition_list
                elif isinstance(rule.antecedent, FuzzyValuation):
                    proposition_list = [rule.antecedent]
                else:
                    raise NotImplementedError(
                        f"Unsupported antecedent type: {type(rule.antecedent).__name__}")
                for proposition in proposition_list:
                    plt.sca(axes[i, columns[proposition.variable.name]])
                    proposition.variable.plot(value=proposition.value)
                    v = values[proposition.variable.name]
                    fv = proposition(values)
                    plt.axhline(fv, color=color, ls="--")
                    plt.plot([v], [fv], "o", color=color)
                for var, j in columns.items():
                    plt.sca(axes[i, j])
                    plt.axvline(values[var], color=color)
                # Rule output
                plt.sca(axes[i, -1])
                rule.consequent.variable.plot(value=rule.consequent.value)
                xx = rule.consequent.variable.domain.get_mesh()
                rule_output = rule(values)  # evaluate once, not once per mesh point
                plt.plot(xx, [rule_output(x) for x in xx], color=color)

            # Inputs
            for j, variable in enumerate(self.variables):
                plt.sca(axes[-1, j])
                axes[-1, j].axvline(values[variable.name], 0, 1, color=color)
                plt.xlim(variable.domain.min, variable.domain.max)
                plt.xlabel(variable.name)

            # Output
            plt.sca(axes[-1, -1])
            xx = self.target.domain.get_mesh()
            output = self.get_output(values)  # computed once for both the curve and the crisp value
            plt.plot(xx, [output(x) for x in xx], label="membership")
            plt.vlines(self.target.domain.defuzzify(output), 0, 1, color="C1", ls="--", label="crisp")
            plt.legend()

        # Final tuning: y-label on the first column only, x-labels on the bottom row only.
        for i in range(last_row + 1):
            for j, var in enumerate(list(self.variables) + [self.target]):
                plt.sca(axes[i, j])
                plt.ylabel("Membership function" if j == 0 else "")
                plt.xlabel(var.name if i == last_row else "")

        fig.subplots_adjust(hspace=0)
        fig.subplots_adjust(wspace=0.05)

        return fig, axes

    def plot_rules_interactive(self, continuous_update=False):
        """
        Create an interactive explorer for the plot_rules method

        Args:
            continuous_update (bool): Whether to continuously update with the widgets value.

        """
        if ipywidgets is None or plt is None:
            raise ModuleNotFoundError("ipywidgets and matplotlib are required")

        def plot(**kwargs):
            self.plot_rules(kwargs)
            plt.show()

        ipywidgets.interact(plot, **{v.name: v.domain.get_ipywidget(continuous_update=continuous_update)
                                     for v in self.variables})

    def compile(self):
        """Get a compiled version of the model"""
        from .compile import CompiledFIS
        return CompiledFIS.from_existing(self)

    @staticmethod
    def from_matlab(path):
        """Import a MATLAB® model from a .fis file"""
        from .mparser import read_mfis
        return read_mfis(path)