"""Sets up basic configuration for LLMTuner"""
import copy
import glob
import importlib
import json
import logging
import os
import uuid

import yaml

logging.basicConfig(level=logging.INFO)

from llmtuner.store import LLMTunerStore
class LLMTunerConfig:
    """Hold, resolve and persist the configuration for LLMTuner.

    Construction locates (or creates) the configuration folder, reads its
    ``tunerconfig.yml``, resolves the data folder and the database, exposes the
    settings as ``self.params`` and collects the talker classes in
    ``self.talkers``.
    """

    def __init__(self, configfilepath="", initialize=True):
        """Locate or create the configuration folder and load everything from it.

        Args:
            configfilepath: Folder that holds ``tunerconfig.yml``. Defaults to
                ``<cwd>/.llmtuner``.
            initialize: What to do when *configfilepath* does not exist:
                create it with a default configuration (True), or fall back to
                the default ``<cwd>/.llmtuner`` folder (False).
        """
        default_path = os.path.join(os.getcwd(), ".llmtuner")
        if not configfilepath:
            self.configpath = default_path
        elif initialize or os.path.exists(configfilepath):
            self.configpath = configfilepath
        else:
            logging.warning(
                "The filepath you provided does not exist. "
                "If you want to force creation of a folder, set 'initialize = True'. "
                "Falling back to the default folder %s.",
                default_path,
            )
            self.configpath = default_path
        # Only create the folder when it is missing: creating it writes a fresh
        # tunerconfig.yml, which must never clobber an existing configuration.
        if not os.path.exists(self.configpath):
            self._create_default_folder(self.configpath)

        self.codedir = os.path.dirname(os.path.abspath(__file__))
        self._read_configfile()
        self._set_datapath()
        self._initialize_db()
        self._initialize_parameters()
        self._set_talkerlist()

    def _create_default_folder(self, filepath):
        """Create *filepath* with a ``local`` subfolder and a default ``tunerconfig.yml``.

        The written configuration is a copy of ``default_configinfo`` whose
        ``datapath`` points at *filepath*; it is also kept in
        ``self.parameterdict``.
        """
        # makedirs on the subfolder creates *filepath* as well.
        os.makedirs(os.path.join(filepath, "local"), exist_ok=True)
        # Deep-copy so the module-level defaults are never mutated.
        self.parameterdict = copy.deepcopy(default_configinfo)
        self.parameterdict["datapath"] = filepath
        with open(os.path.join(filepath, "tunerconfig.yml"), "w", encoding="utf-8") as file:
            yaml.dump(self.parameterdict, file, default_flow_style=False)
        logging.info("Created standard configuration file.")

    def _read_configfile(self):
        """Load ``<configpath>/tunerconfig.yml`` into ``self.parameterdict``.

        Falls back to a copy of ``default_configinfo`` when the file is missing,
        cannot be parsed, or does not contain a mapping (e.g. it is empty), so
        ``self.parameterdict`` is always a dict afterwards.
        """
        configfile = os.path.join(self.configpath, "tunerconfig.yml")
        if not os.path.isfile(configfile):
            logging.warning("Could not find configuration file. Something is seriously wrong ...")
            self.parameterdict = copy.deepcopy(default_configinfo)
            return

        loaded = None
        with open(configfile, "r", encoding="utf-8") as stream:
            try:
                loaded = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                logging.error("Could not parse configuration file %s: %s", configfile, exc)

        if isinstance(loaded, dict):
            self.parameterdict = loaded
            logging.debug("Loaded parameters")
        else:
            logging.warning(
                "Configuration file %s holds no usable settings; using the default configuration.",
                configfile,
            )
            self.parameterdict = copy.deepcopy(default_configinfo)

    def _set_datapath(self, datapath=""):
        """Resolve ``self.datapath`` from the ``datapath`` configuration entry.

        * An entry starting with ``/`` and longer than three characters is
          treated as absolute and used when it exists; otherwise
          ``<configpath>/.llmdata`` (with a ``local`` subfolder) is created and
          used instead.
        * Any other non-empty entry is resolved relative to ``self.configpath``
          and used when it exists; otherwise ``self.configpath`` is used.
        * A missing, empty or non-string entry yields ``self.configpath``.

        Args:
            datapath: Unused; kept for backward compatibility.
        """
        configured = self.parameterdict.get("datapath")
        if not isinstance(configured, str) or not configured:
            self.datapath = self.configpath
        elif configured.startswith("/") and len(configured) > 3:
            if os.path.exists(configured):
                self.datapath = configured
            else:
                fallback = os.path.join(self.configpath, ".llmdata")
                logging.warning(
                    "Datapath %s not found. Setting default datapath %s.", configured, fallback
                )
                # The database and the store's "local" folder live below the
                # datapath, so the fallback has to exist before they are used.
                os.makedirs(os.path.join(fallback, "local"), exist_ok=True)
                self.datapath = fallback
        else:
            # Short values such as "/" also land here and are joined verbatim.
            candidate = self.configpath + "/" + configured
            if os.path.exists(candidate):
                self.datapath = candidate
            else:
                logging.info(
                    "Datapath %s not found. Setting default datapath %s.", candidate, self.configpath
                )
                self.datapath = self.configpath
        logging.debug("Got the datapath at %s", self.datapath)

    def _initialize_db(self):
        """Locate the database file and open ``self.store`` on it.

        ``databasename`` from the configuration is tried as given, then
        relative to ``self.datapath``; otherwise ``<datapath>/llmtuner.db`` is
        used. When the chosen file does not exist yet, one table is created per
        ``schemas/*.schema.json`` file found next to this module.
        """
        default_db = os.path.join(self.datapath, "llmtuner.db")
        dbname = self.parameterdict.get("databasename")
        if not isinstance(dbname, str) or not dbname:
            self.dbpath = default_db
        elif os.path.isfile(dbname):
            self.dbpath = dbname
        elif os.path.isfile(os.path.join(self.datapath, dbname)):
            self.dbpath = os.path.join(self.datapath, dbname)
        else:
            logging.info("Did not find database at given datapath. Will revert to standard location.")
            self.dbpath = default_db

        is_new = not os.path.isfile(self.dbpath)
        if is_new:
            logging.debug("Creating database at %s", self.dbpath)
        self.store = LLMTunerStore(self.dbpath, os.path.join(self.datapath, "local"))
        if is_new:
            schemas = glob.glob(os.path.join(self.codedir, "schemas", "*.schema.json"))
            for schemafile in schemas:
                with open(schemafile, "r", encoding="utf-8") as f:
                    schema = json.load(f)
                self.store.create_table(schema)

    def _initialize_parameters(self):
        """Expose ``self.parameterdict`` as attributes on ``self.params``.

        Warns when the AnythingLLM ``base_url``/``api_key`` are missing or
        still equal to the placeholders in ``default_configinfo``.
        """
        self.params = DictToObj(self.parameterdict)
        placeholders = default_configinfo["servers"]["anythingllm"]
        # getattr chain tolerates configs without "servers" or "anythingllm".
        anythingllm = getattr(getattr(self.params, "servers", None), "anythingllm", None)

        if not hasattr(anythingllm, "base_url"):
            logging.warning("You have not provided a URL for AnythingLLM.")
        elif anythingllm.base_url == placeholders["base_url"]:
            logging.warning("You have to adjust the default configuration to provide access to an anythingllm instance.")

        if not hasattr(anythingllm, "api_key") or anythingllm.api_key == placeholders["api_key"]:
            logging.warning("You have not provided an API key for Anything LLM. This will limit functionality")

    def _update_parameters(self, parentkey, valuedict):
        """Replace the settings stored under *parentkey* with *valuedict* and save.

        *parentkey* is matched against the top-level keys and against the keys
        of top-level sections that are dicts (e.g. a server name under
        ``servers``). Afterwards ``self.params`` is rebuilt and the
        configuration file rewritten.
        """
        found = False
        # Iterate over a snapshot because matching top-level values are replaced.
        for key, section in list(self.parameterdict.items()):
            if key == parentkey:
                logging.info("Updating configuration for %s", key)
                self.parameterdict[key] = valuedict
                found = True
            elif isinstance(section, dict) and parentkey in section:
                logging.info("Updating configuration for %s", parentkey)
                section[parentkey] = valuedict
                found = True
        if not found:
            logging.warning("No configuration entry named %s found; nothing was updated.", parentkey)

        self._initialize_parameters()
        self.write_configs()

    def get_store(self):
        """Return a new ``LLMTunerStore`` on this configuration's database.

        Uses the same database path and ``local`` folder as ``self.store``.
        """
        return LLMTunerStore(self.dbpath, os.path.join(self.datapath, "local"))

    def write_configs(self):
        """Save the current settings to ``<configpath>/tunerconfig.yml``.

        Values held in ``self.params`` (including edits made through attribute
        access) take precedence over ``self.parameterdict``; ``datapath`` and
        ``databasename`` are always set to the resolved ``self.datapath`` and
        ``self.dbpath``.
        """
        self.parameterdict.update(self.params.to_dict())
        self.parameterdict["datapath"] = self.datapath
        self.parameterdict["databasename"] = self.dbpath

        with open(os.path.join(self.configpath, "tunerconfig.yml"), "w", encoding="utf-8") as file:
            yaml.dump(self.parameterdict, file, default_flow_style=False)
        logging.info("Wrote current configurations to config file.")

    def _set_talkerlist(self):
        """Collect the talker classes in ``self.talkers``, keyed by interface name.

        ``open`` is always included, followed by every configured server except
        ``anythingllm`` and any names listed under ``interfaces``. Each name is
        imported from ``llmtuner.infogetter.<name>`` as class
        ``<Name>Talker`` (``name.capitalize() + "Talker"``).
        """
        servers = self.parameterdict.get("servers") or {}
        interfacenames = ["open"] + [name for name in servers if name != "anythingllm"]
        extra = self.parameterdict.get("interfaces") or []
        if isinstance(extra, str):
            # A single name would otherwise be split into characters.
            extra = [extra]
        interfacenames += extra

        self.talkers = {}
        for name in interfacenames:
            if name in self.talkers:
                continue  # first occurrence wins
            module = importlib.import_module("llmtuner.infogetter." + name)
            self.talkers[name] = getattr(module, name.capitalize() + "Talker")

    def write_local(self, content, filename=""):
        """Write the bytes *content* to ``<datapath>/<filename>`` and return that path.

        A random ``download_<8 hex chars>`` name is used when *filename* is empty.
        """
        if not filename:
            filename = "download_" + uuid.uuid4().hex[:8]
        # NOTE(review): filename is not sanitised; ".." components can leave datapath.
        filepath = self.datapath + "/" + filename
        with open(filepath, "wb") as file:
            file.write(content)
        return filepath

    def delete_local(self, filepath=""):
        """Delete *filepath* if it is an existing file; otherwise log a warning."""
        if os.path.isfile(filepath):
            os.remove(filepath)
            logging.info(f"File '{filepath}' has been deleted.")
        else:
            logging.warning(f"File '{filepath}' does not exist.")
 
# Fallback configuration: written to a freshly created tunerconfig.yml and used
# when no configuration file can be read. The AnythingLLM "base_url" and
# "api_key" values are placeholders that LLMTunerConfig warns about when they
# are left unchanged.
default_configinfo = dict(
    datapath=".llmdata",
    databasename="llmtuner.db",
    servers=dict(
        anythingllm=dict(
            base_url="http://your-anythingllm-instance.com/api/v1/",
            api_key="YOUR API KEY",
        ),
        arxiv=dict(
            base_url="http://export.arxiv.org/api/query",
            sourcetype="",
        ),
        wiki=dict(
            base_url="https://wiki.km3net.de/api.php",
            username="",
            password="",
        ),
    ),
)
class DictToObj:
    """Wrap a (possibly nested) dictionary so its keys read as attributes.

    Values that are ``dict`` instances are wrapped recursively; every other
    value (including lists) is stored unchanged. ``to_dict`` reverses the
    conversion.
    """

    def __init__(self, dictionary):
        for key, value in dictionary.items():
            if isinstance(value, dict):
                value = DictToObj(value)
            setattr(self, key, value)

    def __repr__(self):
        return f"{type(self).__name__}({self.to_dict()!r})"

    def to_dict(self):
        """Return a plain nested ``dict`` mirroring this object's current attributes.

        Nested ``DictToObj`` attributes are converted back recursively.
        """
        return {
            key: value.to_dict() if isinstance(value, DictToObj) else value
            for key, value in vars(self).items()
        }