from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession

from pyspark.sql.types import IntegerType, BooleanType, DateType, StringType, FloatType
from pyspark.sql.functions import lit, avg, sum

from pyspark.sql.types import StructType

if __name__ == '__main__':
    # Local Spark session for the exercises below. It is stopped in the
    # `finally` clause so driver/JVM resources are released even when a step
    # raises.
    sc = SparkContext('local')
    spark = SparkSession(sc)

    try:
        # Explicit schema (every column nullable) instead of type inference.
        schema = (
            StructType()
            .add('storeId', IntegerType(), True)
            .add('open', IntegerType(), True)
            .add('openDate', IntegerType(), True)
            .add('division', StringType(), True)
            .add('sqft', IntegerType(), True)
            .add('numberOfEmployees', IntegerType(), True)
            .add('customerSatisfaction', FloatType(), True)
        )

        # The first line of data3.txt is treated as a header row.
        df = spark.read.option('header', True).csv(
            './data3.txt',
            schema=schema,
        )

        df.show()

        # a) Only the store id and division columns.
        df_a = df.select(['storeId', 'division'])
        df_a.show()

        # b) Every column except sqft and customerSatisfaction.
        excluded = {'sqft', 'customerSatisfaction'}
        df_b = df.select([c for c in df.columns if c not in excluded])
        df_b.show()

        # c) Stores smaller than 25000 sqft.
        df_c = df.filter(df.sqft < 25000)
        df_c.show()

        # d) Stores smaller than 25000 sqft OR with customerSatisfaction > 30.
        # NOTE(review): the 30 threshold assumes the satisfaction scale of
        # data3.txt goes above 30 — confirm against the data.
        df_d = df.filter((df.sqft < 25000) | (df.customerSatisfaction > 30))
        df_d.show()

        # e) sqft expressed in hundreds (true division, so the column is a double).
        df_e = df.withColumn('sqft100', df.sqft / 100)
        df_e.show()

        # f) Constant column: one manager per store.
        df_f = df.withColumn('numberOfManagers', lit(1))
        df_f.show()

        # g) Attach the global mean sqft to every row. The mean is collected
        # to the driver once and then broadcast as a literal.
        sqft_mean = df.select(avg('sqft').alias('sqft_mean')).first().sqft_mean
        df_g = df.withColumn('sqftMean', lit(sqft_mean))
        df_g.show()

        # h) Number of rows.
        print(df.count())

        # i) Total sqft. `sum` here is pyspark.sql.functions.sum (imported at
        # the top of the file), which shadows the Python builtin.
        print(df.select(sum('sqft').alias('sqft_sum')).first().sqft_sum)

        # j) Random ~75% sample without replacement. `fraction` is a per-row
        # probability, so the sample size is only approximately 75%.
        sample_j = df.sample(fraction=0.75, withReplacement=False)
        sample_j.show()

        # k) Print the schema.
        df.printSchema()

        # l) Repartition into 12 partitions. repartition is lazy, so report
        # the resulting partition count to make the step observable.
        df_l = df.repartition(12)
        print(df_l.rdd.getNumPartitions())

        # o) Write as JSON (left disabled).
        # df.write.json('storesDF')

        # join: every pair of distinct stores sharing the same `open` value.
        # withColumnsRenamed requires Spark 3.4+.
        left = (
            df.withColumnsRenamed({'storeId': 'storeId1', 'division': 'division1'})
            .select(['storeId1', 'division1', 'open'])
        )
        right = (
            df.withColumnsRenamed({'storeId': 'storeId2', 'division': 'division2'})
            .select(['storeId2', 'division2', 'open'])
        )
        df_join = left.join(right, 'open')

        # storeId1 < storeId2 drops self-pairs and keeps each unordered
        # pair only once.
        df_join.filter(df_join.storeId1 < df_join.storeId2).show()
    finally:
        spark.stop()