# Source code for llmtuner.store

"""
Store information on datasets and workflows in a local SQLite database.
"""

import os
import json
import sqlite3
import ast
import logging

# Mapping of JSON schema type names to SQLite column types.
_JSON_TYPE_TO_SQLITE = {
    "string": "TEXT",
    "integer": "INTEGER",
    "number": "REAL",
    "boolean": "INTEGER",  # SQLite has no boolean type; stored as 0/1
    "array": "TEXT",  # serialized (e.g. as a JSON string)
    "object": "TEXT",  # nested objects are serialized as well
}


def map_json_type_to_sqlite(json_type):
    """Map a JSON schema type to an SQLite column type.

    Parameters:
        json_type (str | list[str] | None): The schema's ``type`` value. JSON
            Schema also allows a list of types (e.g. ``["integer", "null"]``).

    Returns:
        str: The SQLite type. Unknown or missing types map to ``"TEXT"``.
            For a list of types, ``"null"`` is ignored. If the remaining
            types all map to the same SQLite type, that type is returned;
            otherwise ``"TEXT"``.
    """
    if isinstance(json_type, (list, tuple)):
        candidates = {
            _JSON_TYPE_TO_SQLITE.get(entry, "TEXT")
            for entry in json_type
            if entry != "null"
        }
        # Mixed types cannot share one affinity, so fall back to TEXT.
        return candidates.pop() if len(candidates) == 1 else "TEXT"
    return _JSON_TYPE_TO_SQLITE.get(json_type, "TEXT")
class LLMTunerStore:
    """Small DB to hold data sets and processing info for model development.

    Wraps one SQLite database file. Table definitions are kept as JSON
    schemas in ``self.schemas`` and persisted to a ``dbschemas.json`` file
    in the same directory as the database file.

    Parameters:
        filename (str): Path of the SQLite database file.
        datapath (str): Stored on the instance; not used by this class.
        keepopen (bool): If True, the connection stays open between
            operations; otherwise it is closed after each operation.
    """

    def __init__(self, filename="llmtuner.db", datapath="local", keepopen=False):
        self.filename = filename
        self.datapath = datapath
        self.keepopen = keepopen
        self.connection = sqlite3.connect(self.filename)
        self.cursor = self.connection.cursor()
        # Lets _open() reuse the live connection instead of leaking a new
        # one on every call.
        self._connected = True
        self.schemas = {}

    def _schema_filename(self):
        """Return the path of the JSON file that stores the table schemas.

        The file sits next to the database file. For a bare filename such as
        ``"llmtuner.db"``, this is ``"dbschemas.json"`` in the working
        directory.
        """
        return os.path.join(os.path.dirname(self.filename), "dbschemas.json")

    def _open(self):
        """Open the connection (reusing an open one) and load schemas if needed."""
        if not getattr(self, "_connected", False):
            self.connection = sqlite3.connect(self.filename)
            self.cursor = self.connection.cursor()
            self._connected = True
        if not self.schemas:
            filename = self._schema_filename()
            if os.path.isfile(filename):
                with open(filename, "r", encoding="utf-8") as file:
                    self.schemas = json.loads(file.read())
            else:
                logging.warning("Could not load database schemas from "+ filename)

    def _close(self, confirm=False):
        """Close the connection unless keepopen is set, then persist the schemas.

        Set confirm = True to force closing of the connection.
        """
        if (not self.keepopen or confirm) and getattr(self, "_connected", False):
            self.connection.close()
            self._connected = False
        with open(self._schema_filename(), "w", encoding="utf-8") as file:
            file.write(json.dumps(self.schemas))

    def _keep_open(self, keepopen=True):
        """Set to false if connection should be closed after each operation"""
        self.keepopen = keepopen

    @staticmethod
    def _query_create_table_from_schema(schema, table_name=""):
        """Build a CREATE TABLE statement from a JSON schema.

        Parameters:
            schema (dict): JSON schema with ``properties`` and, optionally,
                ``required`` and ``title``.
            table_name (str): Table name. Defaults to ``schema['title']``.

        Returns:
            tuple: ``(table_name, sql_command)``.
        """
        if not table_name:
            table_name = schema['title']
        required_fields = schema.get('required', [])
        columns = []
        for field, attributes in schema['properties'].items():
            # Properties without a "type" (e.g. enum-only) map to TEXT.
            column_definition = f"{field} {map_json_type_to_sqlite(attributes.get('type'))}"
            if field in required_fields:
                column_definition += " NOT NULL"
            columns.append(column_definition)
        sqlcommand = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)});"
        return table_name, sqlcommand

    @staticmethod
    def _check_metadata_agreement(dictref, dictcomp):
        """Compare the metadata in 'dictcomp' against the reference 'dictref'.

        Values are compared by their ``str()`` form.

        Returns:
            True if every key of 'dictref' is present in 'dictcomp' with an
            equal value. Otherwise, a dict ``{"missing": [...],
            "divergent": [...]}`` listing the offending keys.
        """
        mismatch = {"missing": [], "divergent": []}
        for key in dictref:
            if key in dictcomp:
                if str(dictref[key]) != str(dictcomp[key]):
                    mismatch["divergent"].append(key)
            else:
                mismatch["missing"].append(key)
        if not mismatch["missing"] and not mismatch["divergent"]:
            return True
        return mismatch

    @staticmethod
    def format_data_to_sql(schema, data):
        """
        Reformat the entries of the dictionary 'data' according to the JSON schema
        'schema' to fit into SQL tables. Arrays and objects are converted to JSON
        strings, and booleans are converted to integers.

        Parameters:
            schema (dict): The JSON schema defining the structure and types of 'data'.
            data (dict): The data dictionary to be formatted.

        Returns:
            dict: A new dictionary with the values reformatted according to the schema.
        """
        formatted_data = {}
        for field, value in data.items():
            field_type = schema['properties'].get(field, {}).get('type')
            if field_type == 'array' or field_type == 'object':
                formatted_data[field] = json.dumps(value)
            elif field_type == 'boolean':
                formatted_data[field] = int(value)
            else:
                formatted_data[field] = value
        return formatted_data

    def list_tables(self):
        """Return the tables in the database as a list of ``(name,)`` rows."""
        self._open()
        try:
            self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            return self.cursor.fetchall()
        finally:
            self._close()

    def get_table_metadata(self, table_name):
        """Get information for table 'table_name'.

        Returns:
            dict: ``{"number_rows": int, "columns": {name: info}}``, where info
            holds cid, name, type, notnull, default and primary_key as
            reported by ``PRAGMA table_info``.
        """
        self._open()
        try:
            self.cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
            length = self.cursor.fetchone()[0]
            self.cursor.execute(f"PRAGMA table_info({table_name})")
            metadata = {
                col[1]: {"cid": col[0], "name": col[1], "type": col[2],
                         "notnull": col[3], "default": col[4], "primary_key": col[5]}
                for col in self.cursor.fetchall()
            }
        finally:
            self._close()
        return {"number_rows": length, "columns": metadata}

    def create_table(self, schema, table_name=""):
        """Create a DB table from a JSON schema and remember the schema."""
        self._open()
        try:
            table_name, query = self._query_create_table_from_schema(schema, table_name)
            self.cursor.execute(query)
            self.connection.commit()
            self.schemas[table_name] = schema
            print ("Created DB table", table_name)
        finally:
            self._close()

    def delete_table(self, table_name):
        """Delete the table 'table_name' from the database and drop its stored schema."""
        self._open()
        try:
            self.cursor.execute(f"DROP TABLE IF EXISTS {table_name};")
            self.connection.commit()
            # Keep the persisted schemas consistent with the actual tables.
            self.schemas.pop(table_name, None)
        finally:
            self._close()

    def add_data_to_table(self, table_name, data_dict, uniquekeys=None, update=False, mergekeys=None):
        """Add the row described by 'data_dict' to table 'table_name'.

        Parameters:
            table_name (str): Target table. Its schema must be in self.schemas.
            data_dict (dict): Column -> value. Values are converted with
                format_data_to_sql before storing.
            uniquekeys (str | list[str] | None): If given, the row is only
                inserted when no existing row matches on all of these columns.
                The lookup goes through get_data_from_table, so string values
                are compared with LIKE.
            update (bool): If True and a matching row exists, it is updated
                (matched with '=' on the unique keys) instead of skipped.
            mergekeys: Accepted for backward compatibility; currently unused.
        """
        self._open()
        try:
            data_dict_formatted = self.format_data_to_sql(self.schemas[table_name], data_dict)
            if isinstance(uniquekeys, str):
                uniquekeys = [uniquekeys]
            uniquekeys = list(uniquekeys or [])
            if uniquekeys:
                keyvals = {key: data_dict_formatted[key] for key in uniquekeys}
                founddata = self.get_data_from_table(table_name, key_values=keyvals, asdicts=False)
                # get_data_from_table closes the connection unless keepopen.
                self._open()
                logging.debug("Looking for prior entries with "+str(keyvals)+". Found "+str(len(founddata))+ " entries.")
                if founddata:
                    if update:
                        set_clause = ", ".join(f"{col} = ?" for col in data_dict_formatted)
                        where_clause = " AND ".join(f"{key} = ?" for key in uniquekeys)
                        sql_query = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
                        values = list(data_dict_formatted.values()) + [keyvals[key] for key in uniquekeys]
                        logging.debug("query: "+sql_query)
                        self.cursor.execute(sql_query, values)
                        # Without a commit, closing the connection would discard the update.
                        self.connection.commit()
                        logging.info(f"Entries with {keyvals} already exist. Updated")
                    else:
                        logging.info(f"Entries with {keyvals} already exist. Skipping")
                    return
            fields = ', '.join(data_dict_formatted.keys())
            placeholders = ', '.join('?' for _ in data_dict_formatted)
            insert_sql = f"INSERT INTO {table_name} ({fields}) VALUES ({placeholders})"
            self.cursor.execute(insert_sql, tuple(data_dict_formatted.values()))
            self.connection.commit()
            logging.debug("Added: "+insert_sql)
        finally:
            self._close()

    def delete_data_from_table(self, table_name, key_values=None):
        """Delete rows of 'table_name' whose columns equal all values in 'key_values'.

        Values are bound as query parameters, so any type is accepted and the
        values cannot alter the SQL. An empty/None 'key_values' deletes
        every row of the table.
        """
        key_values = key_values or {}
        query = f"DELETE FROM {table_name}"
        if key_values:
            query += " WHERE " + " AND ".join(f"{key} = ?" for key in key_values)
        query += ";"
        self._open()
        try:
            self.cursor.execute(query, list(key_values.values()))
            self.connection.commit()
        finally:
            self._close()

    def get_data_from_table(self, table_name, key_values=None, asdicts=False):
        """Get rows from a table, optionally filtered by 'key_values'.

        All filters are combined with AND:
          * list/tuple value -> ``column IN (...)``
          * str value        -> ``column LIKE value``. A value of the form
            'select:XXX' matches rows whose serialized array column contains
            the element "XXX". Note that LIKE is case-insensitive for ASCII,
            and '%' / '_' in the value act as wildcards.
          * other value      -> ``column = value``

        The caller's 'key_values' dict is not modified.

        Returns:
            list: Row tuples, or dicts keyed by column name if 'asdicts'.
        """
        clauses = []
        values = []
        for key, value in (key_values or {}).items():
            if isinstance(value, (list, tuple)):
                clauses.append(f"{key} IN ({', '.join('?' for _ in value)})")
                values.extend(value)
            elif isinstance(value, str):
                if value.startswith("select:"):
                    # Strip the literal prefix (str.lstrip would strip a character set).
                    value = '%"' + value[len("select:"):] + '"%'
                clauses.append(f"{key} LIKE ?")
                values.append(value)
            else:
                clauses.append(f"{key} = ?")
                values.append(value)
        sql_query = "SELECT * FROM " + table_name
        if clauses:
            sql_query += " WHERE " + " AND ".join(clauses)
        self._open()
        try:
            logging.debug("Executing query: '"+sql_query+"' with values "+str(values))
            self.cursor.execute(sql_query, values)
            rows = self.cursor.fetchall()
            if asdicts:
                column_names = [description[0] for description in self.cursor.description]
                return [dict(zip(column_names, row)) for row in rows]
            return rows
        finally:
            self._close()