from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession

if __name__ == '__main__':
    # Compute, per product, the average unit value across all transactions:
    # the summed (quantity * price) divided by the summed quantity.
    sc = SparkContext('local')
    spark = SparkSession(sc)
    try:
        # Each line is a '#'-separated record; fields 3, 4 and 5 are used below
        # as product id, quantity and price. Blank lines (e.g. a trailing
        # newline at end of file) are skipped so short records cannot make
        # the field indexing below raise IndexError.
        tran_file = sc.textFile('./ch04_data_transactions.txt')
        tran_data = (
            tran_file
            .filter(lambda line: line.strip())
            .map(lambda t: t.split('#'))
            # Cached because it is consumed by the collect() below and by
            # both derived RDDs; otherwise the file is re-read each time.
            .cache()
        )

        print(tran_data.collect())

        # Pairs of (prod_id, quantity * price).
        val_by_prod = tran_data.map(lambda t: (int(t[3]), int(t[4]) * float(t[5])))
        # Pairs of (prod_id, quantity).
        quant_by_prod = tran_data.map(lambda t: (int(t[3]), int(t[4])))

        sum_val_by_prod = val_by_prod.reduceByKey(lambda x, y: x + y)
        sum_quant_by_prod = quant_by_prod.reduceByKey(lambda x, y: x + y)

        # Pairs of (prod_id, (total_value, total_quantity)).
        joined_data = sum_val_by_prod.join(sum_quant_by_prod)
        print(joined_data.collect())

        # Pairs of (prod_id, total_value / total_quantity).
        avg_val_by_prod = joined_data.mapValues(lambda t: t[0] / t[1])
        print(avg_val_by_prod.collect())
    finally:
        # SparkSession.stop() also stops the underlying SparkContext, so the
        # local Spark resources are released even if a job above fails.
        spark.stop()