import os

import dask.array as da
import napari
import numpy as np
import xarray as xr
from csbdeep.utils import normalize
from dask import delayed
from skimage import io
from stardist.models import StarDist2D
from stardist.plot import render_label
import dask

from clonedetective import utils
from clonedetective.clone_counters import LazyCloneCounter
# Create the clone counter for this run. The second argument is a raw regex
# string and the third a float; their exact meaning is defined by
# LazyCloneCounter (presumably an image-name pattern and a pixel size --
# TODO confirm against clonedetective).
foo = LazyCloneCounter("trying_stardist", r"a\dg\d\d?p\d", 0.275)
# Register one glob pattern per channel keyword (C0..C3). The paths are
# relative, so they resolve against the current working directory.
foo.add_images(
    C0="../current_imaging_analysis/MARCM2A_E7F1_refactoring/C0/C0_imgs/*.tif*",
    C1="../current_imaging_analysis/MARCM2A_E7F1_refactoring/C1/C1_imgs/*.tif*",
    C2="../current_imaging_analysis/MARCM2A_E7F1_refactoring/C2/C2_imgs/*.tif*",
    C3="../current_imaging_analysis/MARCM2A_E7F1_refactoring/C3/C3_imgs/*.tif*",
)
# Load the pretrained StarDist 2D model registered as "2D_versatile_fluo".
# map_stardist reads this module-level `model`, so it must be bound before
# any segmentation is computed.
model = StarDist2D.from_pretrained("2D_versatile_fluo")
Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.
2021-08-03 09:42:49.228354: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
def map_stardist(img_4d):
    """Segment one block with the module-level StarDist ``model``.

    The plane at index ``[0, 0]`` of the two leading axes of ``img_4d`` is
    normalized with ``normalize`` (default arguments) and passed to
    ``model.predict_instances``. Only the first element of that call's
    result is kept, and two leading length-1 axes are put back in front of
    it before it is returned.

    NOTE(review): only index 0 of each leading axis is read, so this
    presumably expects blocks shaped (1, 1, y, x) -- confirm the chunking
    of the array this is mapped over.
    """
    plane = img_4d[0, 0, ...]
    prediction = model.predict_instances(normalize(plane))
    labels = prediction[0]
    # Restore the two leading axes dropped above.
    return labels[np.newaxis, np.newaxis, ...]
# Apply map_stardist to every block of the array behind the "images"
# variable (presumably a dask array -- TODO confirm). dtype=np.int32 declares
# the output dtype to map_blocks up front.
bar = foo.image_data["images"].data.map_blocks(map_stardist, dtype=np.int32)
# Wrap the result in a DataArray that reuses the source coords and dims, then
# rename the "img_channels" dimension to "seg_channels".
labels_xarr = xr.DataArray(
    bar, coords=foo.image_data["images"].coords, dims=foo.image_data["images"].dims
).rename({"img_channels": "seg_channels"})
def save_stardist_seg(xarr_3d, fp):
    """Build delayed tasks that write each image of ``xarr_3d`` as a TIFF.

    Args:
        xarr_3d: Iterable of images. Each item must expose an ``img_name``
            coordinate whose value is used as the file stem, and a ``.data``
            array holding the pixels to write. (Presumably a 3-D
            ``xarray.DataArray`` iterated along its first dimension --
            confirm against the caller.)
        fp: Output directory. It is created, including parents, if it does
            not exist yet.

    Returns:
        A list with one ``delayed(io.imsave)`` call per image, in iteration
        order. No image is written until the list is computed.
    """
    # Create the target directory up front so the delayed imsave calls do
    # not fail on a missing path when they are eventually computed.
    os.makedirs(fp, exist_ok=True)
    return [
        delayed(io.imsave)(
            os.path.join(fp, img.coords["img_name"].values.tolist() + ".tif"),
            img.data,
        )
        for img in xarr_3d
    ]
# Build one delayed save per image of the first entry along labels_xarr's
# leading dimension (presumably channel C0, given the directory name --
# TODO confirm).
d_saves = save_stardist_seg(labels_xarr[0], "C0_stardist_segs")
# Run the delayed imsave calls; the files are written at this point.
dask.compute(d_saves)
# Show the directory contents to check what was written.
os.listdir('C0_stardist_segs')
['a2g09p2.tif',
 'a2g10p3.tif',
 'a2g12p1.tif',
 'a1g04p1.tif',
 'a2g10p2.tif',
 'a2g09p3.tif',
 'a2g09p1.tif',
 'a1g04p3.tif',
 'a2g12p2.tif',
 'a1g06p1.tif',
 'a1g04p2.tif',
 'a2g10p1.tif',
 'a1g02p1.tif',
 'a1g02p3.tif',
 'a1g02p2.tif',
 'a2g13p2.tif',
 'a1g07p1.tif',
 'a1g05p3.tif',
 'a2g08p1.tif',
 'a1g05p2.tif',
 'a2g11p1.tif',
 'a2g13p3.tif',
 'a2g13p1.tif',
 'a1g07p2.tif',
 'a2g11p3.tif',
 'a2g08p2.tif',
 'a2g08p3.tif',
 'a1g05p1.tif',
 'a2g11p2.tif',
 'a1g07p3.tif',
 'a1g03p3.tif',
 'a1g01p1.tif',
 'a1g03p2.tif',
 'a1g01p2.tif',
 'a1g01p3.tif',
 'a1g03p1.tif',
 'a1g14p2.tif',
 'a2g02p2.tif',
 'a1g14p1.tif',
 'a2g02p1.tif',
 'a1g09p1.tif',
 'a2g04p3.tif',
 'a2g06p1.tif',
 'a1g12p2.tif',
 'a1g12p3.tif',
 'a1g10p1.tif',
 'a2g04p2.tif',
 'a1g09p2.tif',
 'a2g06p2.tif',
 'a1g12p1.tif',
 'a2g06p3.tif',
 'a2g04p1.tif',
 'a1g09p3.tif',
 '.ipynb_checkpoints',
 'a2g01p2.tif',
 'a1g15p1.tif',
 'a2g01p3.tif',
 'a2g03p1.tif',
 'a2g03p3.tif',
 'a2g01p1.tif',
 'a1g15p2.tif',
 'a1g15p3.tif',
 'a2g03p2.tif',
 'a2g07p2.tif',
 'a1g13p1.tif',
 'a1g11p3.tif',
 'a1g08p2.tif',
 'a1g08p3.tif',
 'a1g11p2.tif',
 'a2g05p1.tif',
 'a2g07p3.tif',
 'a2g07p1.tif',
 'a2g05p3.tif',
 'a1g08p1.tif',
 'a1g11p1.tif',
 'a2g05p2.tif']